In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [14]:
class PositionalEmbedding(nn.Module):
    def __init__(self, ):
        super().__init__()
    
    def forward(self, x):
        return x

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, emb_dim=768, patch_size=16,
                 in_channels=3, img_size=224, bias = True):
        super().__init__()
        self.emb_dim, self.patch_size, self.in_channels, self.img_size = (emb_dim, patch_size,
                 in_channels, img_size)
        self.n_patches = ( img_size // patch_size ) ** 2
        self.patchify = nn.Conv2d( in_channels=in_channels, out_channels=emb_dim,
                                  kernel_size=patch_size, stride = patch_size, bias=bias )
    
    def forward(self, x):
        """
        receives (B C H W)
        1. conv into patches
        """
        x = self.patchify(x)            # (B C H W ) -> (B, emd_dim, sqrt(n_patches), sqrt(n_patches))
        x = x.flatten(2)                # abve to (B emd_dim n_patches)
        return x

pe = PatchEmbedding()
x = torch.randn( size=(2, 3, 224, 224) )
pe(x).shape

torch.Size([2, 768, 196])

In [16]:
class Block(nn.Module):
    def __init__(self, ):
        super().__init__()
    
    def forward(self, x):
        return x

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, emb_dim,
                 patch_size,
                 in_channels,
                 img_size, 
                 n_classes, pooling_type,
                 n_blocks
                 ):
        super().__init__()
        self.n_classes = n_classes
        self.pooling_type = pooling_type
        self.patch_emb = PatchEmbedding(
                 emb_dim, patch_size,
                 in_channels, img_size
                 )
        self.pos_emb = PositionalEmbedding(emb_dim, self.pooling_type)
        self.out_norm = nn.LayerNorm(emb_dim)
        self.out_proj = nn.Linear(emb_dim, self.n_classes)
        self.blocks = nn.Sequential(* [ Block() for _ in range(n_blocks) ] )
    
    def forward(self, x, targets = None):
        B = x.shape[0]
        x = self.patch_emd(x)
        
        if self.pooling_type == "cls":
            pass
        elif self.pooling_type == "avg":
            pass

        x = x + self.pos_emb(x)
        for block in self.blocks:
            x = block(x)
        
        x = self.out_norm(x)

        if self.pooling_type == "cls":
            agg_score = x[:, 0]                         # Take the cls token for out_proj
        elif self.pooling_type == "avg":
            agg_score = x.mean(dim=1)                   # take avg across time step dim

        logits = self.out_proj(agg_score)
        loss = None
        if targets is not None:
            pass
        
        return logits, loss

In [None]:
model = VisionTransformer()
x = torch.randn( size=(2, 3, 224, 224) ) # (B, image channels, height, width)
model(x)

torch.Size([2, 3, 224, 224])
