In [17]:
# Vision Transformer (ViT) Implementation
from torch import nn
import torch 
import math 

In [18]:
## Patch Embeddings 
class PatchEmbedding(nn.Module):
    """
    Cut an image into non-overlapping square patches and project each patch
    to a ``d_model``-dimensional vector.

    Args:
        image_size: side length of the (square) input image; expected to be a
            multiple of ``patch_size`` (e.g. image_size=400, patch_size=40).
        patch_size: side length of each square patch.
        in_channels: number of channels of the input image.
        d_model: size of each patch embedding.
        device: device on which the projection weights are allocated.
    """

    def __init__(self, image_size, patch_size, in_channels, d_model, device):
        super().__init__()
        self.image_size, self.patch_size = image_size, patch_size
        self.in_channels, self.d_model = in_channels, d_model

        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side ** 2

        # A convolution whose kernel and stride both equal the patch size
        # visits every patch exactly once, i.e. it is a per-patch linear map.
        self.emb = nn.Conv2d(
            in_channels,
            d_model,
            kernel_size=patch_size,
            stride=patch_size,
            device=device,
        )

    def forward(self, x):
        """
        Args:
            x: image batch of shape (batch_size, in_channels, image_size, image_size).

        Returns:
            Tensor of shape (batch_size, num_patches, d_model).
        """
        # conv output: (batch_size, d_model, H // patch_size, W // patch_size)
        feature_map = self.emb(x)
        # -> (batch_size, d_model, num_patches) -> (batch_size, num_patches, d_model)
        return feature_map.flatten(start_dim=2).transpose(1, 2)

## Positional Embeddings 
class PositionalEmbedding(nn.Module):
    """
    Learnable positional embedding for the [CLS] token plus every image patch.

    Args:
        d_model: embedding size.
        image_size: side length of the square input image.
        patch_size: side length of each square patch.
        device: device on which the embedding table is allocated
            (``None`` -> PyTorch's default device).
    """

    def __init__(self, d_model, image_size, patch_size, device):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # one extra position for the prepended [CLS] token
        self.max_len = num_patches + 1
        self.d_model = d_model

        # Allocated on `device`, like the patch projection, so the `device`
        # argument is actually honoured.
        self.encoding = nn.Parameter(torch.zeros(1, self.max_len, d_model, device=device))

    def forward(self, x):
        """
        Return the positional embeddings for the first ``seq_len`` positions.

        Args:
            x: tensor of shape (batch_size, seq_len, d_model); only its
                sequence length is used.

        Returns:
            Tensor of shape (1, seq_len, d_model), broadcastable against ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of learned positions.
        """
        _, seq_len, _ = x.shape
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {self.max_len} learned positions"
            )
        # Slice the sequence axis (dim 1); dim 0 is the singleton broadcast axis.
        return self.encoding[:, :seq_len, :]
    

class TransformerEmbedding(nn.Module):
    """
    Input stage of the ViT: patch embedding, prepended [CLS] token, positional
    embedding and dropout.

    Args:
        image_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of image channels.
        d_model: embedding size.
        drop_prob: dropout probability applied to the summed embeddings.
        device: device for the embedding parameters (``None`` -> default device).
    """

    def __init__(self, image_size, patch_size, in_channels, d_model, drop_prob, device):
        super().__init__()
        self.patch_emb = PatchEmbedding(image_size, patch_size, in_channels, d_model, device)
        self.pos_emb = PositionalEmbedding(d_model, image_size, patch_size, device)
        self.dropout = nn.Dropout(p=drop_prob)

        # Similar to BERT, a learnable [CLS] token is prepended to the sequence.
        # It is created on `device` so that torch.cat in forward does not mix
        # it with patch embeddings that were produced on `device`.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model, device=device))

    def forward(self, x):
        """
        Args:
            x: image batch of shape (batch_size, in_channels, image_size, image_size).

        Returns:
            Tensor of shape (batch_size, num_patches + 1, d_model).
        """
        batch_size, _, _, _ = x.shape

        patches = self.patch_emb(x)  # (batch_size, num_patches, d_model)

        # expand creates a view: every sample shares the same [CLS] parameter.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, d_model)

        tokens = torch.cat((cls_tokens, patches), dim=1)
        return self.dropout(tokens + self.pos_emb(tokens))

In [19]:
## Attention Block 
class SelfAttentionBlock(nn.Module):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    No mask is applied, so every token may attend to every other token.
    """

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, eps=1e-12):
        """
        Args:
            q, k, v: 4-D tensors laid out as (batch_size, n_head, seq_len, d_tensor);
                the caller is expected to pass tensors of the same shape.
            eps: accepted for API compatibility; it is not used.

        Returns:
            Tensor of shape (batch_size, n_head, seq_len, d_tensor).
        """
        # Image patches need neither a padding mask nor a look-ahead mask.
        _, _, _, head_dim = k.shape
        scale = math.sqrt(head_dim)

        # scores: (batch_size, n_head, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(2, 3)) / scale
        return torch.matmul(self.softmax(scores), v)

## Multihead Attention Block 
class MultiheadAttentionBlock(nn.Module):
    """
    Multi-head self-attention.

    Projects the input to queries, keys and values, splits each into
    ``n_head`` heads of width ``d_model // n_head``, runs scaled dot-product
    attention per head, merges the heads and applies an output projection.

    Args:
        n_head: number of attention heads.
        d_model: model width; must be divisible by ``n_head``.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, n_head, d_model):
        super().__init__()
        # Validate up front; otherwise the mismatch only surfaces later as an
        # opaque reshape error in `split` (or a ZeroDivisionError for n_head=0).
        if n_head <= 0 or d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_head ({n_head})"
            )
        self.n_head = n_head
        self.d_model = d_model

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)

        self.attention = SelfAttentionBlock()

        self.Wconcat = nn.Linear(d_model, d_model)

    def split(self, tensor):
        """
        Split the last dimension into heads.

        (batch_size, seq_len, d_model) -> (batch_size, n_head, seq_len, d_tensor)
        """
        batch_size, seq_len, d_model = tensor.shape

        d_tensor = d_model // self.n_head

        tensor = tensor.reshape(batch_size, seq_len, self.n_head, d_tensor).transpose(1, 2)

        return tensor

    def concat(self, tensor):
        """
        Inverse of `split`: merge the heads back into the last dimension.

        (batch_size, n_head, seq_len, d_tensor) -> (batch_size, seq_len, d_model)
        """
        batch_size, n_head, seq_len, d_tensor = tensor.shape

        # reshape (not view) because the transposed tensor is non-contiguous
        tensor = tensor.transpose(1, 2).reshape(batch_size, seq_len, n_head * d_tensor)

        return tensor

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        # project to q, k, v and split into heads
        query = self.split(self.Wq(x))
        key = self.split(self.Wk(x))
        value = self.split(self.Wv(x))

        # per-head scaled dot-product attention
        out = self.attention(query, key, value)

        # merge heads and mix them with the output projection
        out = self.concat(out)

        out = self.Wconcat(out)

        return out
        

In [20]:
## Define FeedForward Network 
class FeedForwardBlock(nn.Module):
    """
    Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        ffn_hidden: width of the hidden layer.
        drop_prob: dropout probability on the hidden activations.
    """

    def __init__(self, d_model, ffn_hidden, drop_prob=0.1):
        super().__init__()
        # Creation order is kept so parameter initialisation and state_dict
        # layout match exactly.
        self.linear1 = nn.Linear(d_model, ffn_hidden)
        self.linear2 = nn.Linear(ffn_hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Map (..., d_model) -> (..., d_model), independently at every position."""
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)

In [21]:
# Define Transformer EncoderBlock and Encoder 
class EncoderBlock(nn.Module):
    """
    One pre-norm Transformer encoder layer:

        x = x + Dropout(MultiheadAttention(LayerNorm_attn(x)))
        x = x + Dropout(FeedForward(LayerNorm_ffn(x)))

    Args:
        n_head: number of attention heads.
        d_model: model width.
        ffn_hidden: hidden width of the feed-forward sub-layer.
        drop_prob: dropout probability applied to each sub-layer's output.
    """

    def __init__(self, n_head, d_model, ffn_hidden, drop_prob=0.1):
        super().__init__()
        # Pre-attention LayerNorm.
        self.norm = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.multihead_attn = MultiheadAttentionBlock(n_head, d_model)
        self.ffn = FeedForwardBlock(d_model, ffn_hidden, drop_prob)
        self.dropout2 = nn.Dropout(p=drop_prob)

        # Separate pre-FFN LayerNorm so the two sub-layers do not share
        # normalisation scale/shift parameters.
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        Args:
            x: token sequence of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of the same shape.
        """
        # Dropout is applied to each sub-layer's output *before* the residual
        # add, so the skip path itself is never zeroed during training.
        x = x + self.dropout1(self.multihead_attn(self.norm(x)))
        x = x + self.dropout2(self.ffn(self.norm2(x)))
        return x
        
class Encoder(nn.Module):
    """
    ViT encoder: embedding stage, ``n_layers`` stacked encoder blocks and a
    final LayerNorm.

    Args:
        image_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of image channels.
        n_head: attention heads per block.
        d_model: model width.
        ffn_hidden: hidden width of each block's feed-forward layer.
        n_layers: number of encoder blocks.
        device: forwarded to the embedding stage.
        drop_prob: dropout probability used throughout.
    """

    def __init__(self, image_size, patch_size, in_channels, n_head, d_model, ffn_hidden, n_layers, device, drop_prob=0.1):
        super().__init__()
        self.emb = TransformerEmbedding(
            image_size, patch_size, in_channels, d_model, drop_prob, device
        )
        blocks = [
            EncoderBlock(n_head, d_model, ffn_hidden, drop_prob)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Map an image batch to token features of shape (batch_size, num_patches + 1, d_model)."""
        hidden = self.emb(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)

In [22]:
## Define MLP head for final classification 
class MLPHead(nn.Module):
    """
    Classification head: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: width of the input feature vector.
        class_num: number of output classes.
        mlp_hidden: width of the hidden layer.
        drop_prob: dropout probability on the hidden activations.
    """

    def __init__(self, d_model, class_num, mlp_hidden, drop_prob):
        super().__init__()
        self.linear1 = nn.Linear(d_model, mlp_hidden)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(mlp_hidden, class_num)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """
        Args:
            x: features of shape (..., d_model).

        Returns:
            Unnormalised logits of shape (..., class_num).
        """
        hidden = self.dropout(self.gelu(self.linear1(x)))
        # No dropout after the final projection: dropping logits would zero
        # random class scores and rescale the rest during training.
        return self.linear2(hidden)

In [25]:
class ViT(nn.Module):
    """
    Vision Transformer classifier.

    The image is encoded into a sequence of [CLS] + patch tokens; the final
    hidden state at position 0 (the [CLS] token) is fed to an MLP head.

    Args:
        image_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of image channels.
        n_head: attention heads per encoder block.
        d_model: model width.
        ffn_hidden: hidden width of each encoder feed-forward layer.
        mlp_hidden: hidden width of the classification head.
        n_layers: number of encoder blocks.
        class_num: number of output classes.
        device: device to place the model on (``None`` -> leave parameters
            where they were created).
        drop_prob: dropout probability used throughout.
    """

    def __init__(self, image_size, patch_size, in_channels, n_head, d_model, ffn_hidden, mlp_hidden, n_layers, class_num, device, drop_prob=0.1):
        super().__init__()
        self.encoder = Encoder(image_size, patch_size, in_channels, n_head, d_model, ffn_hidden, n_layers, device, drop_prob)
        self.mlp_head = MLPHead(d_model, class_num, mlp_hidden, drop_prob)
        # Only some sub-modules create their parameters on `device` directly
        # (e.g. the patch projection); move the whole model so every parameter
        # ends up on the same device.
        if device is not None:
            self.to(device)

    def forward(self, img):
        """
        Args:
            img: image batch of shape (batch_size, in_channels, image_size, image_size).

        Returns:
            Logits of shape (batch_size, class_num).
        """
        features = self.encoder(img)
        cls_features = features[:, 0, :]  # [CLS] token at position 0
        return self.mlp_head(cls_features)