In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Union, Optional

## Coding the VAE

In [10]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    One fused linear layer produces queries, keys and values. These are split
    into ``n_heads`` heads of size ``d_head = d_embed // n_heads``. Each head
    computes ``softmax(q @ k^T / sqrt(d_head)) @ v``. The heads are then
    concatenated and passed through an output projection.

    Args:
        d_embed: feature size of each token; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection has a bias term.
        out_bias: whether the output projection has a bias term.

    Raises:
        ValueError: if ``d_embed`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_embed, n_heads: int = 4, qkv_bias=True, out_bias=True) -> None:
        super().__init__()
        if d_embed % n_heads != 0:
            raise ValueError(f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        # Scaled dot-product attention normalises by the per-head width, not the full embedding.
        self.scale = self.d_head ** -0.5
        self.QKV = nn.Linear(d_embed, d_embed * 3, bias=qkv_bias)
        self.O = nn.Linear(d_embed, d_embed, bias=out_bias)

    def forward(self, x: torch.Tensor, mask: bool = False) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (batch_size, seq_len, d_embed), e.g. (batch_size, height*width, channels).
            mask: if True, apply a causal mask so position i only attends to positions <= i.

        Returns:
            Tensor of shape (batch_size, seq_len, d_embed).
        """
        bs, seq_len, d_embed = x.shape
        # (bs, seq_len, d_embed) -> (bs, seq_len, 3*d_embed) -> 3 x (bs, seq_len, d_embed)
        q, k, v = self.QKV(x).chunk(3, dim=-1)

        # (bs, seq_len, d_embed) -> (bs, seq_len, n_heads, d_head) -> (bs, n_heads, seq_len, d_head)
        q = q.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # torch.bmm only accepts 3-D inputs; matmul batches over both the batch and head dims.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (bs, n_heads, seq_len, seq_len)
        if mask:
            # (seq_len, seq_len) upper-triangular mask broadcasts over batch and heads.
            # The diagonal is never masked, so no row becomes all -inf.
            causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
            attn_scores = attn_scores.masked_fill(causal, float("-inf"))
        weights = F.softmax(attn_scores, dim=-1)

        output = torch.matmul(weights, v)  # (bs, n_heads, seq_len, d_head)
        # (bs, n_heads, seq_len, d_head) -> (bs, seq_len, n_heads, d_head) -> (bs, seq_len, d_embed)
        output = output.transpose(1, 2).reshape(bs, seq_len, d_embed)
        return self.O(output)  # (bs, seq_len, d_embed)
        

In [7]:
class VaeAttentionBlock(nn.Module):
    """Residual self-attention over the spatial positions of a feature map.

    The input is group-normalised and flattened to a sequence of ``H*W`` tokens
    with ``C`` features each. It is passed through ``SelfAttention``, reshaped
    back to ``(B, C, H, W)``, and added to the un-normalised input.

    Args:
        n_channels: number of channels ``C``; must be divisible by 32 (GroupNorm groups).
    """

    def __init__(self, n_channels: int) -> None:
        super().__init__()
        # Attribute name keeps its original spelling so existing state_dict keys still load.
        self.gourp_norm_1 = nn.GroupNorm(num_groups=32, num_channels=n_channels)
        self.attention = SelfAttention(n_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply normalised spatial self-attention with a residual connection.

        Args:
            x: feature map of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W).
        """
        residue = x
        n, c, h, w = x.shape
        # Pre-norm: normalise before attention; the residual uses the raw input.
        x = self.gourp_norm_1(x)
        x = x.reshape(n, c, h * w).transpose(-1, -2)  # (B, C, H*W) -> (B, H*W, C)
        # The attention output must replace x; otherwise the block reduces to x + x.
        x = self.attention(x)  # (B, H*W, C)
        x = x.transpose(-1, -2).reshape(n, c, h, w)  # (B, H*W, C) -> (B, C, H, W)
        return x + residue

In [9]:
class VaeResidualBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> SiLU -> 3x3 Conv) stages plus a skip path.

    The 3x3 convolutions use padding=1, so height and width are preserved.
    The skip path is an identity when ``in_channels == out_channels``. Otherwise
    it is a 1x1 convolution projecting ``in_channels`` to ``out_channels``.
    Both GroupNorms use 32 groups, so both channel counts must be divisible by 32.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, **same_size)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, **same_size)
        # "gourp" spelling is kept as-is: it is part of the state_dict key names.
        self.gourp_norm_1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.group_norm_2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)

        self.skip = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, in_channels, H, W) to (B, out_channels, H, W)."""
        shortcut = self.skip(x)

        hidden = self.gourp_norm_1(x)
        hidden = F.silu(hidden)
        hidden = self.conv_1(hidden)

        hidden = self.group_norm_2(hidden)
        hidden = F.silu(hidden)
        hidden = self.conv_2(hidden)

        return hidden + shortcut

In [6]:
class VaeEncoder(nn.Sequential):
    """Convolutional VAE encoder mapping an RGB image to a scaled latent sample.

    Three stride-2 convolutions reduce height and width by 8x while the residual
    stages widen channels 128 -> 256 -> 512. The head emits 8 channels, which
    ``forward`` splits into a 4-channel mean and a 4-channel log-variance.
    Shape comments below assume a 3x512x512 input.
    """

    # Multiplier applied to the sampled latent; 0.18215 matches the latent
    # scaling factor used by Stable Diffusion v1 VAEs.
    LATENT_SCALE: float = 0.18215

    def __init__(self) -> None:
        super().__init__(
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, padding=1), # 3x512x512 -> 128x512x512
            VaeResidualBlock(in_channels=128, out_channels=128), # 128x512x512 -> 128x512x512
            VaeResidualBlock(in_channels=128, out_channels=128), # 128x512x512 -> 128x512x512

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=0), # 128x512x512 -> 128x256x256
            VaeResidualBlock(in_channels=128, out_channels=256), # 128x256x256 -> 256x256x256
            VaeResidualBlock(in_channels=256, out_channels=256), # 256x256x256 -> 256x256x256

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=0), # 256x256x256 -> 256x128x128
            VaeResidualBlock(in_channels=256, out_channels=512), # 256x128x128 -> 512x128x128
            VaeResidualBlock(in_channels=512, out_channels=512), # 512x128x128 -> 512x128x128

            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=0), # 512x128x128 -> 512x64x64
            VaeResidualBlock(in_channels=512, out_channels=512), # 512x64x64 -> 512x64x64
            VaeResidualBlock(in_channels=512, out_channels=512), # 512x64x64 -> 512x64x64
            VaeResidualBlock(in_channels=512, out_channels=512), # 512x64x64 -> 512x64x64

            VaeAttentionBlock(n_channels=512), # 512x64x64 -> 512x64x64

            VaeResidualBlock(in_channels=512, out_channels=512), # 512x64x64 -> 512x64x64

            nn.GroupNorm(num_groups=32, num_channels=512),
            nn.SiLU(),

            nn.Conv2d(in_channels=512, out_channels=8, kernel_size=3, padding=1), # 512x64x64 -> 8x64x64
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=1, padding=0), # 8x64x64 -> 8x64x64
        )

    def forward(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and draw a reparameterised latent sample.

        Args:
            x: image batch of shape (B, 3, H, W).
            noise: tensor combined elementwise with the predicted std; shape
                (B, 4, H/8, W/8) for H and W divisible by 8.

        Returns:
            ``(mean + std * noise) * LATENT_SCALE``, shape (B, 4, H/8, W/8).
        """
        for layer in self:
            # Stride-2 convs are built with padding=0; pad right/bottom by one
            # so each downsampling step yields exactly half the size.
            if getattr(layer, 'stride', None) == (2, 2):
                x = F.pad(x, (0, 1, 0, 1))
            x = layer(x)
        # (B, 8, h, w) -> 2 x (B, 4, h, w)
        mean, log_var = x.chunk(2, dim=1)
        # Clamp the log-variance before exponentiating; std = sqrt(exp(log_var)).
        std = log_var.clamp(-30, 20).exp().sqrt()
        return (mean + std * noise) * self.LATENT_SCALE

In [None]:
class VaeDecoder(nn.Sequential):
    """Placeholder for the VAE decoder (latent -> image).

    No layers are registered yet. As an empty ``nn.Sequential`` it therefore
    returns its input unchanged.
    TODO: add the decoder stack.
    """

    def __init__(self) -> None:
        nn.Sequential.__init__(self)

In [None]:
class VAE(nn.Module):
    """Module that holds a VAE encoder and decoder as registered submodules.

    Args:
        encoder: the encoder network (annotated as ``VaeEncoder``).
        decoder: the decoder network (annotated as ``VaeDecoder``).
    """

    def __init__(self, encoder: VaeEncoder, decoder: VaeDecoder) -> None:
        super(VAE, self).__init__()
        # Assigning nn.Module instances as attributes registers them as children,
        # so their parameters are included in .parameters(), .to() and state_dict.
        self.encoder, self.decoder = encoder, decoder