In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Union, Optional

## Coding the VAE

In [54]:
class SelfAttention(nn.Module):
    
    def __init__(self, d_embed, n_heads: int = 4, qkv_bias=True, out_bias=True) -> None:
        super().__init__()
        
        self.scale = d_embed ** -0.5
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        self.QKV = nn.Linear(d_embed, d_embed * 3, bias=qkv_bias)
        self.O = nn.Linear(d_embed, d_embed, bias=out_bias)
        
    def forward(self, x: torch.Tensor, mask: bool = False) -> torch.Tensor:
        # x: (batch_size, height*width, channels)
        # x: (batch_size, seq_len, d_embed)
        in_shape = x.shape
        bs, seq_len, d_embed = x.shape
        q, k, v = self.QKV(x).chunk(3, dim=-1) # (batch_size, seq_len, d_embed)@(d_embed, d_embed*3) -> (batch_size, seq_len, d_embed*3) -> (3x) (batch_size, seq_len, d_embed)
        
        # (batch_size, seq_len, d_embed) -> (batch_size, seq_len, n_heads, d_head) -> (batch_size, n_heads, seq_len, d_head)
        q = q.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        
        print(f'-------------> x: {in_shape}, q: {q.shape}, k: {k.shape}')
        # print this q, k.transpose(-2, -1)
        print(f'-------------> q: {q.shape}, k: {k.transpose(-2, -1).shape}')
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not False:
            mask = torch.ones_like(attn_scores).bool().triu(1) # (batch_size, n_heads, seq_len, seq_len)
            attn_scores.masked_fill_(mask, -1e9)
        weights = F.softmax(attn_scores, dim=-1)
        output = weights @ v # (batch_size, n_heads, seq_len, seq_len) -> (batch_size, n_heads, seq_len, d_head) -> (batch_size, n_heads, seq_len, d_head)
        output = output.transpose(1, 2).contiguous().view(in_shape) # (batch_size, n_heads, seq_len, d_head) -> (batch_size, seq_len, n_heads, d_head) -> (batch_size, seq_len, d_embed)
        return self.O(output) # (batch_size, seq_len, d_embed)@(d_embed, d_embed) -> (batch_size, seq_len, d_embed)

In [55]:
class VaeAttentionBlock(nn.Module):
    
    def __init__(self, n_channels: int) -> None:
        super().__init__()
        self.gourp_norm_1 = nn.GroupNorm(num_groups=32, num_channels=n_channels)
        self.attention = SelfAttention(n_channels)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residue = x 
        n,c,h,w = x.shape # BxCxHxW
        x = x.view(n,c,h*w) # BxCxH*W
        x = x.transpose(-1, -2) # BxH*WxC
        attn = self.attention(x) # BxH*WxC
        x = x.transpose(-1, -2).view((n, c, h, w)) # BxCxHxW
        x = x + residue
        return x

In [56]:
class VaeResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.gourp_norm_1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.group_norm_2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = self.conv_1(F.silu(self.gourp_norm_1(x)))
        res = self.conv_2(F.silu(self.group_norm_2(res)))
        return res + self.skip(x)

In [57]:
class VaeEncoder(nn.Sequential):
    def __init__(self) -> None:
        layers = [
            # Initial layers
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, padding=1),

            # Residual blocks 512x512
            VaeResidualBlock(in_channels=128, out_channels=128),
            VaeResidualBlock(in_channels=128, out_channels=128),

            # Downsample to 256x256
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2),
            VaeResidualBlock(in_channels=128, out_channels=256),
            VaeResidualBlock(in_channels=256, out_channels=256),

            # Downsample to 128x128
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2),
            VaeResidualBlock(in_channels=256, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),

            # Downsample to 64x64
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),

            # Attention block 64x64
            VaeAttentionBlock(n_channels=512),
            
            # Additional residual block 64x64
            VaeResidualBlock(in_channels=512, out_channels=512),

            # Final layers
            nn.GroupNorm(num_groups=32, num_channels=512),
            nn.SiLU(),
            nn.Conv2d(in_channels=512, out_channels=8, kernel_size=3, padding=1),
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=1)
        ]
        super(VaeEncoder, self).__init__(*layers)
        
    def forward(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        for layer in self:
            if getattr(layer, 'stride', None) == (2,2):
                x = F.pad(x, (0,1,0,1))
            x = layer(x)
        # Bx8x64x64 -> (2x) of Bx4x64x64
        mean, log_var = x.chunk(2, dim=1)
        std = log_var.clamp(-30, 20).exp().sqrt()
        latent = (mean + std * noise) * 0.18215
        return latent # Bx4x64x64


In [58]:
import torch.nn as nn

class VaeDecoder(nn.Sequential):
    def __init__(self) -> None:
        layers = [
            # Initial layers
            nn.Conv2d(in_channels=4, out_channels=4, kernel_size=1, padding=0),
            nn.Conv2d(in_channels=4, out_channels=512, kernel_size=3, padding=1),
            
            # Residual blocks 64x64
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeAttentionBlock(n_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            
            # Upsampling to 128x128
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            VaeResidualBlock(in_channels=512, out_channels=512),
            
            # Upsampling to 256x256
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            VaeResidualBlock(in_channels=512, out_channels=256),
            VaeResidualBlock(in_channels=256, out_channels=256),
            VaeResidualBlock(in_channels=256, out_channels=256),
            
            # Upsampling to 512x512
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            VaeResidualBlock(in_channels=256, out_channels=128),
            VaeResidualBlock(in_channels=128, out_channels=128),
            VaeResidualBlock(in_channels=128, out_channels=128),
            
            nn.GroupNorm(num_groups=32, num_channels=128),
            nn.SiLU(),
            nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, padding=1)
        ]
        super(VaeDecoder, self).__init__(*layers)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x /= 0.18215
        for layer in self: x = layer(x)
        return x


In [59]:
class VAE(nn.Module):
    
    def __init__(self) -> None:
        super().__init__()
        self.encoder = VaeEncoder()
        self.decoder = VaeDecoder()
        
    def encode(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        return self.encoder(x, noise)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(x)
        
    def forward(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        latent = self.encoder(x, noise)
        return self.decoder(latent)

In [60]:
VAE()

VAE(
  (encoder): VaeEncoder(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): VaeResidualBlock(
      (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (gourp_norm_1): GroupNorm(32, 128, eps=1e-05, affine=True)
      (group_norm_2): GroupNorm(32, 128, eps=1e-05, affine=True)
      (skip): Identity()
    )
    (2): VaeResidualBlock(
      (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (gourp_norm_1): GroupNorm(32, 128, eps=1e-05, affine=True)
      (group_norm_2): GroupNorm(32, 128, eps=1e-05, affine=True)
      (skip): Identity()
    )
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
    (4): VaeResidualBlock(
      (conv_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv

In [62]:
device = 'mps'
x = torch.randn(1, 3, 512, 512).to(device)
vae = VAE().to(device)
enc_x = vae.encode(x, torch.randn(1, 4, 64, 64).to(device))

-------------> x: torch.Size([1, 4096, 512]), q: torch.Size([1, 4, 4096, 128]), k: torch.Size([1, 4, 4096, 128])
-------------> q: torch.Size([1, 4, 4096, 128]), k: torch.Size([1, 4, 128, 4096])


In [63]:
enc_x.shape

torch.Size([1, 4, 64, 64])

-------------> x: torch.Size([1, 4096, 512]), q: torch.Size([1, 4, 4096, 128]), k: torch.Size([1, 4, 4096, 128])
-------------> q: torch.Size([1, 4, 4096, 128]), k: torch.Size([1, 4, 128, 4096])


In [65]:
dec_x.shape

torch.Size([1, 3, 512, 512])