In [1]:
%%writefile models/attention.py
import torch
from torch import nn
import math
import logging

class MultiheadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V and output projections.

    Despite the name, ``forward`` takes separate ``q``, ``k`` and ``v`` inputs, so it can
    also be used for cross-attention: ``k``/``v`` may have a different sequence length
    than ``q`` (they must share the batch size and ``embedding_dim``).

    Args:
        num_heads: number of attention heads; must be positive and divide ``embedding_dim``.
        embedding_dim: size of the last dimension of the inputs and of the output.
        use_bias: whether the four linear projections carry a bias term.

    Raises:
        ValueError: if ``embedding_dim`` cannot be split evenly across ``num_heads``.
    """

    def __init__(self, num_heads: int, embedding_dim: int, use_bias=True):
        super().__init__()
        # Validate up front: otherwise a bad split only surfaces later as an opaque
        # .view() shape error inside forward().
        if num_heads <= 0 or embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.proj_q = nn.Linear(embedding_dim, embedding_dim, bias=use_bias)
        self.proj_k = nn.Linear(embedding_dim, embedding_dim, bias=use_bias)
        self.proj_v = nn.Linear(embedding_dim, embedding_dim, bias=use_bias)
        self.num_heads = num_heads
        self.head_dim = embedding_dim // self.num_heads
        self.proj_out = nn.Linear(embedding_dim, embedding_dim, bias=use_bias)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, lookahead_mask: bool=True) -> torch.Tensor:
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: (batch_size, q_len, embedding_dim) queries.
            k: (batch_size, kv_len, embedding_dim) keys.
            v: (batch_size, kv_len, embedding_dim) values.
            lookahead_mask: if True (the default), query position i may only attend to
                key positions j <= i (causal mask). Pass False for full attention.

        Returns:
            (batch_size, q_len, embedding_dim) tensor.
        """
        batch_size, q_len, embedding_dim = q.shape
        kv_len = k.shape[1]

        # (batch, len, embedding_dim) -> (batch, len, num_heads, head_dim) -> (batch, num_heads, len, head_dim)
        q = self.proj_q(q).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.proj_k(k).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.proj_v(v).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        # (batch, num_heads, q_len, head_dim) @ (batch, num_heads, head_dim, kv_len)
        #   -> (batch, num_heads, q_len, kv_len)
        scores = q @ k.transpose(-1, -2)
        if lookahead_mask:
            # True strictly above the diagonal, i.e. key positions j > i for query i.
            # A 2-D (q_len, kv_len) mask broadcasts over the batch and head dims.
            mask = torch.ones(q_len, kv_len, dtype=torch.bool, device=scores.device).triu(1)
            scores = scores.masked_fill(mask, -torch.inf)
        scores = scores / math.sqrt(self.head_dim)
        weights = torch.softmax(scores, dim=-1)

        # (batch, num_heads, q_len, kv_len) @ (batch, num_heads, kv_len, head_dim)
        #   -> (batch, num_heads, q_len, head_dim)
        context = weights @ v

        # (batch, num_heads, q_len, head_dim) -> (batch, q_len, num_heads, head_dim) -> (batch, q_len, embedding_dim)
        context = context.transpose(1, 2).reshape(batch_size, q_len, embedding_dim)
        return self.proj_out(context)

Overwriting models/attention.py


In [2]:
%%writefile models/resnet.py
import torch
from torch import nn
from torch.nn import functional as F

class ResidualBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip path.

    Args:
        in_channels: channels of the input feature map (GroupNorm uses 32 groups, so it
            must be divisible by 32).
        out_channels: channels of the output feature map (also divisible by 32).
        dropout: dropout probability applied between the second activation and conv.

    When ``in_channels != out_channels`` the skip path is a 3x3 convolution that matches
    the channel count; otherwise the input is added back unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float=0.0):
        super().__init__()
        # Keep this construction order: it fixes parameter registration order and the
        # sequence of RNG draws used for weight initialisation.
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.proj_input = (
            nn.Conv2d(in_channels, out_channels, 3, 1, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.proj_input(x)
        hidden = self.conv_1(F.silu(self.groupnorm_1(x)))
        hidden = self.conv_2(self.dropout(F.silu(self.groupnorm_2(hidden))))
        return hidden + shortcut
        

Overwriting models/resnet.py


In [3]:
%%writefile models/vae.py
import torch
from torch import nn
from torch.nn import functional as F
from .resnet import ResidualBlock
from .attention import MultiheadSelfAttention
from typing import List

class Downsample(nn.Module):
    """Halve H and W (rounding down) with a stride-2 3x3 convolution.

    The input is zero-padded by one pixel on the right and bottom only (asymmetric
    padding), then convolved with no implicit padding. Channel count is unchanged.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # F.pad order for the last two dims: (left, right, top, bottom)
        self.pad = (0, 1, 0, 1)
        self.conv = nn.Conv2d(in_channels, in_channels, 3, 2, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(F.pad(x, self.pad))

class UpSample(nn.Module):
    """Double H and W: nearest-neighbour upsampling, then a 3x3 'same' convolution.

    Channel count is unchanged.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # nn.Upsample defaults to mode='nearest'
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv = nn.Conv2d(in_channels, in_channels, 3, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.upsample(x))
        
class AttentionBlock(nn.Module):
    """Residual single-head self-attention over the spatial positions of a feature map.

    Each of the H*W pixels becomes one token of dimension ``in_channels``; the attention
    output is reshaped back to the input layout and added to the input.

    Args:
        in_channels: channel count of the feature map (GroupNorm uses 32 groups, so it
            must be divisible by 32).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.attn = MultiheadSelfAttention(num_heads=1, embedding_dim=in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, channels, h, w)
        batch_size, channels, h, w = x.shape
        x_norm = self.groupnorm(x)

        # (batch, channels, h, w) -> (batch, channels, h*w) -> (batch, h*w, channels)
        tokens = x_norm.reshape(batch_size, channels, h * w).transpose(1, 2)

        # Pixels have no sequential order, so every position must attend to every other.
        # MultiheadSelfAttention applies a causal mask by default (lookahead_mask=True),
        # which would let each pixel see only earlier pixels in raster order.
        out = self.attn(q=tokens, k=tokens, v=tokens, lookahead_mask=False)

        # (batch, h*w, channels) -> (batch, channels, h*w) -> (batch, channels, h, w)
        out = out.transpose(1, 2).reshape(batch_size, channels, h, w)
        return out + x
        
class VAE_Encoder(nn.Module):
    """Convolutional encoder producing concatenated latent moments for a VAE.

    The output has ``2 * z_channels`` channels; VAE.encode splits it into a mean and a
    log-variance half. Every level except the last halves H and W, so the default
    four-level ``ch_mult`` gives an 8x spatial reduction.

    Args:
        in_channels: channels of the input image.
        ch_mult: per-level channel multipliers applied to ``ch``; must be non-empty.
        dropout: dropout probability forwarded to every ResidualBlock.
        z_channels: latent channels (the output carries twice this many).
        ch: base channel width. GroupNorm uses 32 groups, so ``ch * m`` must be
            divisible by 32 for every multiplier ``m``.

    Raises:
        ValueError: if ``ch_mult`` is empty.
    """

    def __init__(self, in_channels: int, ch_mult: List[int]=[1, 2, 4, 8], dropout: float=0.0, z_channels: int=8, ch: int=128):
        super().__init__()
        # Copy so the shared default list is never aliased; also accepts tuples.
        ch_mult = list(ch_mult)
        if not ch_mult:
            raise ValueError("ch_mult must contain at least one multiplier")

        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # Downsampling path: level i maps ch*in_ch_mult[i] -> ch*ch_mult[i] channels.
        self.down = nn.ModuleList()
        in_ch_mult = [1] + ch_mult
        last_level = len(ch_mult) - 1
        for i, mult in enumerate(ch_mult):
            block_in = ch * in_ch_mult[i]
            block_out = ch * mult
            down = nn.Module()
            down.block = nn.Sequential(
                ResidualBlock(block_in, block_out, dropout),
                ResidualBlock(block_out, block_out, dropout),
            )
            # No spatial reduction after the last level.
            down.downsample = Downsample(block_out) if i != last_level else nn.Identity()
            self.down.append(down)
        curr_channels = ch * ch_mult[-1]

        # Bottleneck at the lowest resolution (dropout applied like every other ResidualBlock).
        self.mid = nn.Module()
        self.mid.res_block_1 = ResidualBlock(curr_channels, curr_channels, dropout)
        self.mid.attn_block_1 = AttentionBlock(in_channels=curr_channels)
        self.mid.res_block_2 = ResidualBlock(curr_channels, curr_channels, dropout)

        # Project to 2*z_channels (mean and log-variance halves).
        self.out = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=curr_channels),
            nn.SiLU(),
            nn.Conv2d(curr_channels, 2*z_channels, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(2*z_channels, 2*z_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, in_channels, H, W) images to (batch, 2*z_channels, h, w) moments."""
        x = self.conv_in(x)

        for down in self.down:
            x = down.block(x)
            x = down.downsample(x)

        x = self.mid.res_block_1(x)
        x = self.mid.attn_block_1(x)
        x = self.mid.res_block_2(x)

        return self.out(x)


class VAE_Decoder(nn.Module):
    """Convolutional decoder mapping a latent tensor back to an image.

    Mirrors VAE_Encoder: a bottleneck (ResidualBlock, AttentionBlock, ResidualBlock) at
    the lowest resolution, then one upsampling level per entry of ``ch_mult`` (visited in
    reverse). Every level except the last doubles H and W.

    Args:
        ch_mult: per-level channel multipliers applied to ``ch``; must be non-empty.
        dropout: dropout probability forwarded to every ResidualBlock.
        z_channels: channels of the latent input.
        ch: base channel width. GroupNorm uses 32 groups, so ``ch * m`` must be
            divisible by 32 for every multiplier ``m``.
        out_channels: channels of the reconstructed image.

    Raises:
        ValueError: if ``ch_mult`` is empty.
    """

    def __init__(self, ch_mult: List[int]=[1, 2, 4, 8], dropout: float=0.0, z_channels: int=8, ch: int=128, out_channels: int=3):
        super().__init__()
        # Copy so the shared default list is never aliased; also accepts tuples.
        ch_mult = list(ch_mult)
        if not ch_mult:
            raise ValueError("ch_mult must contain at least one multiplier")

        block_in = ch * ch_mult[-1]
        self.conv_in = nn.Sequential(nn.Conv2d(z_channels, z_channels, kernel_size=1, padding=0),
                                     nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1))

        # mid
        self.mid = nn.Module()
        self.mid.res_block_1 = ResidualBlock(block_in, block_in, dropout)
        self.mid.attn_block_1 = AttentionBlock(in_channels=block_in)
        self.mid.res_block_2 = ResidualBlock(block_in, block_in, dropout)

        # upsampling: walk the levels from deepest to shallowest
        self.up = nn.ModuleList()
        for i in reversed(range(len(ch_mult))):
            block_out = ch * ch_mult[i]
            up = nn.Module()
            up.block = nn.Sequential(
                ResidualBlock(block_in, block_out, dropout),
                ResidualBlock(block_out, block_out, dropout),
                ResidualBlock(block_out, block_out, dropout),
            )
            # The shallowest level (i == 0) is already at full resolution.
            up.upsample = UpSample(in_channels=block_out) if i != 0 else nn.Identity()
            self.up.append(up)
            block_in = block_out

        # block_in now equals ch * ch_mult[0], the width produced by the last level.
        self.out = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=block_in),
            nn.SiLU(),
            nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode (batch, z_channels, h, w) latents to (batch, out_channels, H, W) images."""
        # Divide by 0.18215 (presumably the Stable-Diffusion latent scaling constant, so
        # inputs are expected to be pre-multiplied by it). Out-of-place: the original
        # `x /= ...` silently modified the caller's tensor.
        x = x / 0.18215
        x = self.conv_in(x)

        x = self.mid.res_block_1(x)
        x = self.mid.attn_block_1(x)
        x = self.mid.res_block_2(x)

        for up in self.up:
            x = up.block(x)
            x = up.upsample(x)

        return self.out(x)
        
        
class VAE(nn.Module):
    """Image VAE pairing a VAE_Encoder (3 input channels) with a VAE_Decoder (8 latent channels)."""

    # Latent scaling constant. VAE_Decoder.forward divides its input by 0.18215, so
    # encode() multiplies by the same value; decode(encode(x)) then hands the decoder's
    # first conv the latent on the scale the encoder actually produced.
    scaling_factor = 0.18215

    def __init__(self):
        super().__init__()
        self.encoder = VAE_Encoder(in_channels=3)
        self.decoder = VAE_Decoder(z_channels=8)

    def encode(self, x: torch.Tensor, noise=None) -> torch.Tensor:
        """Encode images and draw a reparameterised latent sample.

        Args:
            x: image batch, (batch_size, 3, H, W).
            noise: optional tensor shaped like the latent (presumably standard normal);
                when None, fresh noise is drawn with ``torch.randn_like``.

        Returns:
            ``(mean + stdev * noise) * scaling_factor``, shaped (batch_size, 8, h, w)
            (h = H/8, w = W/8 with the default encoder).
        """
        moments = self.encoder(x)
        mean, log_variance = moments.chunk(2, dim=1)
        # Clamp for numerical safety before exponentiating.
        log_variance = torch.clamp(log_variance, -20, 30)
        variance = log_variance.exp()
        stdev = torch.sqrt(variance)
        # `if noise:` would raise for any multi-element tensor; test for presence instead.
        if noise is None:
            noise = torch.randn_like(stdev)
        return (mean + stdev * noise) * self.scaling_factor

    def decode(self, z: torch.Tensor):
        """Decode latents (as returned by ``encode``) back to images."""
        return self.decoder(z)

class VQ_VAE(nn.Module):
    """Placeholder class (presumably for a vector-quantised VAE, judging by the name).

    NOTE(review): only the nn.Module base is initialised; no submodules and no
    forward/encode/decode methods are defined yet.
    """
    def __init__(self):
        super().__init__()
        

Overwriting models/vae.py


In [4]:
from models.vae import VAE
import torch 

# Shape smoke test: encode a random 64x64 RGB batch and decode it back.
torch.manual_seed(0)  # reproducible input and weight initialisation

random_tensor = torch.randn((1, 3, 64, 64))
encoder = VAE()  # NOTE: the full VAE (encoder + decoder), despite the variable name
with torch.no_grad():  # shape check only; no autograd graph needed
    latents = encoder.encode(random_tensor)
    out_tensor = encoder.decode(latents)
print(latents.shape)
print(out_tensor.shape)

# The default encoder downsamples by 8 into 8 latent channels; the decoder restores the input shape.
assert latents.shape == (1, 8, 8, 8)
assert out_tensor.shape == random_tensor.shape

torch.Size([1, 8, 8, 8])
torch.Size([1, 3, 64, 64])
