# ViT

> Components & defs for ViT-based Encoder & Decoder

In [None]:
#| default_exp vit

In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F 

## ViT components 

These components are used in both the encoder and the decoder.

In [None]:
#| export
class RoPE2D(nn.Module):
    """2D rotary position embedding for tokens laid out row-major on a patch grid.

    The first half of the channels in each head is rotated by angles derived from
    the token's grid row (`n // grid_w`), the second half by angles derived from
    its grid column (`n % grid_w`). Within each half, channels are rotated in
    (even, odd) pairs.
    """
    def __init__(self, head_dim):  # per-head channel count; must be a multiple of 4
        super().__init__()
        # Each half (head_dim // 2 channels) is rotated pairwise, so it needs
        # head_dim // 4 frequencies; any other head_dim fails later with a shape error.
        if head_dim % 4 != 0:
            raise ValueError(f"head_dim must be divisible by 4 for 2D RoPE, got {head_dim}")
        i = torch.arange(0, head_dim // 4)
        self.register_buffer('ifreq', 1.0 / (10000 ** (2 * i / head_dim)))
        # Cache key for the row/column index buffers; -1 means "nothing cached yet".
        self.num_patches = -1
        self.grid_w = -1

    def calc_pi(self, num_patches, grid_w):
        "(Re)build the cached row (`pih`) and column (`piw`) index of every token."
        self.num_patches, self.grid_w = num_patches, grid_w
        pi = torch.arange(self.num_patches, device=self.ifreq.device)
        self.register_buffer('pih', pi // grid_w)
        self.register_buffer('piw', pi % grid_w)

    def _rotate(self, x, sin, cos):
        "Rotate each (even, odd) channel pair of `x` by the angle whose sine/cosine are given."
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        x_out = torch.empty_like(x)
        x_out[..., 0::2] = x_even * cos - x_odd * sin
        x_out[..., 1::2] = x_odd * cos + x_even * sin
        return x_out

    def forward(self, x, nphw=(16,16)):  # x: (batch, heads, num_patches, head_dim); nphw: grid (rows, cols)
        num_patches, grid_w, head_dim = x.shape[2], nphw[-1], x.shape[-1]
        # Rebuild the indices when EITHER the token count or the grid width changes.
        # Keying on the token count alone would reuse stale rows/columns for a
        # different grid that happens to have the same number of tokens.
        if num_patches != self.num_patches or grid_w != self.grid_w:
            self.calc_pi(num_patches, grid_w)

        # (1, 1, num_patches, head_dim // 4) angle tables, broadcast over batch and heads
        freqs_h = torch.outer(self.pih.float(), self.ifreq)[None, None]
        freqs_w = torch.outer(self.piw.float(), self.ifreq)[None, None]
        sin_h, cos_h = torch.sin(freqs_h), torch.cos(freqs_h)
        sin_w, cos_w = torch.sin(freqs_w), torch.cos(freqs_w)

        # first half of the channels encodes the row, second half the column
        x_h, x_w = x[..., :head_dim//2], x[..., head_dim//2:]
        x_h_out = self._rotate(x_h, sin_h, cos_h)
        x_w_out = self._rotate(x_w, sin_w, cos_w)

        return torch.cat([x_h_out, x_w_out], dim=-1)

In [None]:
# testing: rotate a random (batch, heads, num_patches, head_dim) tensor; the shape should be unchanged
head_dim = 768//4
x = torch.rand((2, 8, 256, head_dim))
rope = RoPE2D(head_dim)
rot_x = rope(x)  # default nphw=(16,16), i.e. 256 tokens on a 16x16 grid
print("rot_x.shape = ",rot_x.shape) 

rot_x.shape =  torch.Size([2, 8, 256, 192])


In [None]:
#| export
class Attention(nn.Module):
    """Multi-head self-attention with 2D rotary position embedding applied to q and k.

    `dim_qkv` is the total width of each of q, k and v (split evenly across
    `heads`); it defaults to `dim`. The output is projected back to `dim`.
    """
    def __init__(self,
                dim,            # input / output embedding dimension
                heads,          # number of attention heads
                dim_qkv=None):  # total q/k/v width; defaults to `dim`
        super().__init__()
        if dim_qkv is None: dim_qkv = dim
        # Fail early with a clear message instead of a reshape error inside forward().
        if dim_qkv % heads != 0:
            raise ValueError(f"dim_qkv ({dim_qkv}) must be divisible by heads ({heads})")
        self.heads, self.dim_qkv  = heads, dim_qkv
        self.head_dim = dim_qkv // heads
        self.rope = RoPE2D(self.head_dim)
        self.qkv = nn.Linear(dim, dim_qkv * 3)
        self.proj = nn.Linear(dim_qkv, dim)

    def forward(self, x, nphw=(16,16)):  # x: (batch, num_patches, dim); nphw: patch grid (rows, cols) for RoPE
        B, N = x.shape[:2]
        # Project and split into q, k, v
        qkv = self.qkv(x)  # (B, N, dim_qkv * 3)
        qkv = qkv.reshape(B, N, 3, self.heads, self.head_dim)  # split into 3, heads, head_dim
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: (B, heads, N, head_dim)
        # The grid is forwarded explicitly so callers whose grid is not 16 patches
        # wide get correct row/column positions; the default keeps the old behaviour.
        q, k = self.rope(q, nphw), self.rope(k, nphw)

        out = F.scaled_dot_product_attention(q, k, v)  # (B, heads, N, head_dim), "flash attention"
        out = out.transpose(1, 2).reshape(B, N, self.dim_qkv)  # Merge heads
        return self.proj(out)  # project back to (B, N, dim)

In [None]:
# testing: the output shape must match the input shape, whatever dim_qkv is
x = torch.rand(2, 256, 768)
attn = Attention(768, 8)  # dim_qkv defaults to dim (768)
a = attn(x) 
print("Done: a.shape = ",a.shape)

attn2 = Attention(768, 8, 64)  # narrower q/k/v (64 total, 8 per head), still projected back to 768
a2 = attn2(x) 
print("Done: a2.shape = ",a2.shape)

Done: a.shape =  torch.Size([2, 256, 768])
Done: a2.shape =  torch.Size([2, 256, 768])


In [None]:
#| export
class TransformerBlock(nn.Module):
    "Pre-norm transformer block: self-attention then a two-layer GELU MLP, each with a residual connection."
    def __init__(self,
                dim,            # embedding dimension
                heads,          # number of attention heads
                dim_qkv=None,   # total q/k/v width, forwarded to Attention
                hdim=None):     # hidden width of the MLP; defaults to 4*dim
        super().__init__()
        self.attn = Attention(dim, heads, dim_qkv)
        hdim = 4*dim if hdim is None else hdim
        self.lin1 = nn.Linear(dim, hdim)
        self.lin2 = nn.Linear(hdim, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.act = nn.GELU()

    def forward(self, x):  # x: (batch, num_patches, dim)
        # attention branch: normalise first ("pre-norm"), then add back to the residual stream
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        # MLP branch: normalise, expand to hdim, GELU, project back to dim
        mlp_out = self.lin2(self.act(self.lin1(self.norm2(x))))
        return x + mlp_out  # (batch, num_patches, dim)

In [None]:
# testing: a block maps (batch, num_patches, dim) to the same shape
x = torch.randn(2, 256, 768) 
trans = TransformerBlock(768, 8) 
out = trans(x) 
print("out.shape = ",out.shape) 

out.shape =  torch.Size([2, 256, 768])


## Encoder

Performs patch embedding, prepends a CLS token, and then applies a stack of transformer blocks.

In [None]:
#| export
class PatchEmbedding(nn.Module):
    """Embed non-overlapping square patches with a strided conv and flag empty patches.

    `forward` returns `(conv_patches, pmask, pos)`:
    `conv_patches` is (B, num_patches, dim); `pmask` is a bool (B, num_patches)
    tensor that is True where the patch is non-empty; `pos` is (num_patches, 2)
    holding each patch's (row, col) grid coordinate in row-major order.
    """
    def __init__(self, 
                in_channels=1,  # 1 for solo piano, for midi PR's, = # of instruments
                patch_size=16,  # assuming square patches, e.g. 16 implies 16x16
                dim=768,        # embedding dimension
                empty_thresh=0.2):  # a patch is "non-empty" if its max value exceeds this
        super().__init__()
        self.conv = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        self.empty_thresh = empty_thresh
        
    def forward(self, x):  # x: (batch, channels, height, width)
        k = self.conv.kernel_size[0]
        assert all(s % k == 0 for s in x.shape[-2:]), \
            f"Image size {x.shape[-2:]} must be divisible by patch_size {k}"
        conv_patches = self.conv(x).flatten(2).permute(0,2,1)  # (B, num_patches, dim)
        # Check if each patch region in the image is empty
        patches = x.unfold(2, k, k).unfold(3, k, k)  # (B, C, H/k, W/k, k, k)
        # Reduce over the channel axis as well as the two in-patch axes: a patch
        # counts as non-empty if ANY channel has a value above the threshold.
        # (Squeezing the channel axis only worked for in_channels == 1; with more
        # channels the mask came out as (B, C*num_patches).)
        pmask = (patches.amax(dim=(1, -1, -2)) > self.empty_thresh).flatten(1) # (B, num_patches), 0=empty, 1=not

        # save patch position "coordinates" (for masking later)
        H, W = x.shape[-2] // k, x.shape[-1] // k  # grid dimensions
        rows = torch.arange(H, device=x.device).repeat_interleave(W)
        cols = torch.arange(W, device=x.device).repeat(H)
        pos = torch.stack([rows, cols], dim=-1)  # (num_patches, 2)
        return conv_patches, pmask, pos  


In [None]:
# testing: a 256x256 single-channel image with the default 16x16 patches gives a 16x16 grid (256 patches)
pe = PatchEmbedding()
x = torch.randn(2, 1, 256, 256)
z, pmask, pos = pe(x) 
print("z.shape, pmask.shape, pos.shape = ",z.shape, pmask.shape, pos.shape) 
print("pos = \n",pos.tolist())  # (row, col) of every patch, row-major

In [None]:
#| export
class ViTEncoder(nn.Module):
    """Vision Transformer Encoder for piano rolls

    Embeds the image into patch tokens, prepends a learned CLS token, and runs
    the sequence through `depth` transformer blocks. After every block, tokens
    belonging to empty patches are scaled down by 1e-3.
    """
    def __init__(self, 
                in_channels,  # number of image channels
                image_size,   # tuple (H,W), e.g. (256, 256); accepted but not used here
                patch_size,   # assuming square patches, e.g 16
                dim,          # embedding dim, e.g. 768
                depth,        # number of transformerblock layers -- 4? 
                heads):       # number of attention heads - 8? 
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels=in_channels, patch_size=patch_size, dim=dim)
        layers = [TransformerBlock(dim, heads) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, x, return_cls_only=True):
        # tokens: (B, num_patches, dim); pmask: True for non-empty patches; positions are not needed here
        tokens, pmask, _ = self.patch_embed(x)
        # prepend the CLS token to every sequence in the batch
        cls_tok = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls_tok, tokens], dim=1)
        # the CLS token is always treated as "non-empty"
        cls_keep = pmask.new_ones(pmask.shape[0], 1)
        pmask = torch.cat([cls_keep, pmask], dim=1)  # (B, 1 + num_patches)
        # NOTE(review): the blocks apply RoPE with its default 16x16 grid while the
        # CLS token sits at index 0, so patch positions look offset by one -- confirm intended.
        keep = pmask.unsqueeze(-1)
        for block in self.blocks:
            tokens = block(tokens)
            # empty patches go to small but nonzero values
            tokens = torch.where(keep, tokens, tokens * 1e-3)
        if return_cls_only:
            return tokens[:, 0], pmask
        return tokens, pmask

In [None]:
#| eval: false
# testing: the encoder returns (output, pmask); by default the output is only the CLS token's embedding
B, C, H, W = 4, 1, 256, 256
patch_size, dim, depth, heads = 16, 768, 4, 8 
x = torch.randn(B,C,H,W) 
encoder = ViTEncoder( C, (H,W), patch_size, dim, depth, heads) 
out, pmask = encoder(x) 
print("out.shape = ",out.shape) 

# with return_cls_only=False the full token sequence (CLS + patches) is returned
out2, pmask = encoder(x, return_cls_only=False) 
print("out2.shape = ",out2.shape) 

## Decoder

Like the Encoder, except that instead of applying `PatchEmbedding` on the front end, it applies `Unpatchify` on the back end to reassemble an image from the patch tokens.

In [None]:
#| export
class Unpatchify(nn.Module):
    """Take patches and assemble an image

    Each patch embedding is linearly projected to `patch_size * patch_size *
    out_channels` values, and the patches are tiled row-major into an image of
    size `image_size`. Raises ValueError if `image_size` is not a multiple of
    `patch_size`, or if a batched input does not hold exactly one token per patch.
    """
    def __init__(self, 
                out_channels=1,  # 1 for solo piano, for midi PR's, = # of instruments
                image_size = (128, 128),  # h,w for output image  
                patch_size=16,  # assuming square patches, e.g. 16 implies 16x16
                dim=768):       # embedding dimension
        super().__init__()
        # A ragged grid cannot be reshaped back to image_size.
        if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
            raise ValueError(f"image_size {tuple(image_size)} must be divisible by patch_size {patch_size}")
        self.image_size, self.patch_size, self.out_channels = image_size, patch_size, out_channels
        # NB: image_size is (h, w), so npatches_x counts grid rows and npatches_y grid columns
        self.npatches_x, self.npatches_y = image_size[0]//patch_size, image_size[1]//patch_size 
        self.lin = nn.Linear(dim, out_channels * patch_size * patch_size )  # (B, 64, 768) -> (B, 64, 256) 
        
    def forward(self, z):  # z: patch embeddings (batch, num_patches, dim)
        # The reshapes below infer the batch size with -1, so a wrong patch count
        # (e.g. a CLS token that was not stripped) can be silently folded into the
        # batch dimension. Check it explicitly for batched input.
        expected = self.npatches_x * self.npatches_y
        if z.dim() >= 3 and z.shape[-2] != expected:
            raise ValueError(f"Expected {expected} patches for image_size {tuple(self.image_size)} "
                             f"and patch_size {self.patch_size}, got {z.shape[-2]}")
        out = self.lin(z)  # (B, N, patch_size * patch_size * out_channels)
        out = out.reshape(-1, self.npatches_x, self.npatches_y, self.patch_size, self.patch_size, self.out_channels)
        out = out.permute(0, 5, 1, 3, 2, 4)  # (B, C, grid_rows, patch, grid_cols, patch), e.g. (B, 1, 8, 16, 8, 16)
        out = out.reshape(-1, self.out_channels, self.image_size[0], self.image_size[1])        
        return out # (B, out_channels, H, W) 

In [None]:
#| eval: false 
# testing: 64 patch embeddings (an 8x8 grid of 16x16 patches) are assembled into a 128x128 image
z = torch.randn([3, 64, 768]) 
unpatch = Unpatchify() # keep the defaults
img = unpatch(z) 
print("img.shape = ",img.shape) 

In [None]:
#| export
class ViTDecoder(nn.Module):
    """Vision Transformer Decoder for piano rolls

    Runs the token sequence through `depth` transformer blocks, optionally drops
    the leading (CLS) token, and assembles the remaining patch tokens into an image.
    """
    def __init__(self, 
                out_channels,  # number of channels in the output image
                image_size,   # tuple (H,W), e.g. (256, 256)
                patch_size,   # assuming square patches, e.g 16
                dim,          # embedding dim, e.g. 768
                depth=4,        # number of transformerblock layers -- 4? 
                heads=8):       # number of attention heads - 8? 
        super().__init__()
        layers = [TransformerBlock(dim, heads) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.unpatch = Unpatchify(out_channels, image_size, patch_size, dim)

    def forward(self, z, strip_cls_token=True):  # z: (batch, num_tokens, dim)
        # NOTE(review): the blocks apply RoPE with its default 16x16 grid regardless
        # of image_size/patch_size -- confirm this is intended for other grid sizes.
        tokens = z
        for block in self.blocks:
            tokens = block(tokens)
        if strip_cls_token:
            tokens = tokens[:, 1:]  # drop the first token before unpatchifying
        return self.unpatch(tokens)

In [None]:
# testing: 65 tokens = 1 CLS + 64 patches (8x8 grid of 16x16 patches for a 128x128 image)
z = torch.randn(3, 65, 768)  # batch of 3, with CLS token
decoder = ViTDecoder(out_channels=1, image_size=(128,128), patch_size=16, dim=768)
img = decoder(z)  # strip_cls_token defaults to True, so the first token is dropped before unpatchifying
print(img.shape)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()