In [1]:
import torch
import torch.nn as nn

from einops import rearrange # it is a function

from einops.layers.torch import Rearrange # it is used as a layer



In [2]:
def pair(inp):
    return inp if isinstance(inp, tuple) else (inp, inp)

class Attention(nn.Module):
    """
    Multi-Head Self-Attention block of each layer
    Args:
        dim: Input dimension to MHSA block, i.e dimension of each token corresponding to each patch
        heads: Total attention heads 
        dim_head: Dimension of Q, K, V vector
    """
    def __init__(self, dim, heads, dim_head):
        """
        - Recall, each self-attention layer(head) in a MHSA unit has its own Q, K, V vector.
        - But we have several heads, if heads=8, we need 8 such vectors for each Q, K, V.
        - We love vectorised implementation, so instead of maintaining 8 different vectors of dimension dim_head for each Q, K, V
          we are maintaining just three vector(for Q, K, V) of dimenstion dim * dim_head, 
          which you should understand as concatenating each of the 8 Q, K, V vectors of each heads(=8 here) as a single vector for Q, K, V each. 
        """
        super().__init__()
        # Dimension of vector (Q/K/V) including all heads
        inner_dim = dim_head * heads 
        self.heads = heads

        # used for dividing product of Q and K vector. Refer Self attention mechanism: https://jalammar.github.io/illustrated-transformer/
        self.scale = dim_head ** 0.5 

        self.norm = nn.LayerNorm(dim)

        # dim=-1 because after multiplying Q and K, last dimension in that product corresponds to columns.
        # We need to apply softmax across columns since each column corresponds to different token, and after 
        # doing softmax, it essentialy shows the probability of a token attenting to each of the other token
        self.attend = nn.Softmax(dim=-1) 

        # 3 is multiplied because as I said, we need 3 vectors of each Q, K, V (each is a concatenated version for entire heads, hence innerdim)
        # We dont need bias term, i.e "a = Wx" is enough
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # To convert each of the Q, K, V vector back to encoder's dimension, so that it can be processed by FF block.
        # Again bias is insignificant/irrelevant here
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x):
        # If say crop size is 256 by 256 by 3
        # say Patch size is 8 by 8 by 3, so h and w = 265/8 = 32 
        # say Total batch is 15
        # Dimension of x -> (15, h*w , dim)
        # assume total heads=8
        x = self.norm(x)

        # Converts dimension from (15, h*w, dim) to (3, 15, h*w, inner_dim) as it is chunked into 3. 
        # Dim=-1 because actual output of to_qkv returns dimension (15, h*w, inner_dim * 3), axis=-1 corresponds to last dimension
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        # Each of the 3 chunk of dimension (15, h*w, inner_dim) is passed into the rearrange function below
        # n: total patches per image, i.e h*w
        # (h d): inner_dim = heads * dim_head
        # below einops operation seperates "concatenated" vector for each heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        # Dimension(as per above assumed numbers) of
        # q: (15, 8, h*w, dim_head)
        # k.transpose(-1,-2): (15, 8, dim_head, h*w)
        # dot_prod: (15, 8, h*w, h*w)
        dot_prod = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        print('dot_prod :', dot_prod.shape)
        # dimension: (15, 8, h*w, h*w)
        attn_weights = self.attend(dot_prod)

        # attn_scores dimension: (15, 8, h*w, dim_head)
        # dimension of v: (15, 8, h*w, dim_head)
        attn_scores = torch.matmul(attn_weights, v)

        # Output dimension: (15, h*w, inner_dim)
        output = rearrange(attn_scores, 'b h n d -> b n (h d)')

        # Output dimension: (15, h*w, dim)
        output = self.to_out(output)
        return output

class FeedForward(nn.Module):
    """
    Feed Forward(FF) block(Essentially MLP) inside each encoder block
    Args:
        dim: Input dimension to FF layer. Which is the internal dimension of each patch
        mlp_dim: Hidden layer dimension of this MLP 
    """
    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.layer = nn.Sequential(
            # Layer norm is applied here because in the forward function of transformer only residual connection is performed, layernorm is performed here instead
            nn.LayerNorm(dim),
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim)
        )
    def forward(self, x):
        return self.layer(x)



class Transformer(nn.Module):
    """
    Encoder Block of ViT
    Args:
        depth: Total layers/blocks in encoder
        heads: Total self-attention blocks in a single MHSA unit of a single layer
        mlp_dim: Hidden layer dimension of FeedForward network in each layer/block in encoder
        dim_head: Dimension of Q,K,V vector
        dim: Input dimension of each token corresponding to each patch
    """
    def __init__(self, dim, dim_head, depth, heads, mlp_dim):
        super().__init__()

        # All layers/blocks in encoder. Each layer contains a Multi-head self-attention block and MLP block
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, dim_head, heads),
                FeedForward(dim, mlp_dim),
            ]))
        
        # TODO: For quite easier view
        #self.norm = nn.LayerNorm(dim) 
        #x = attn(self.norm(x)) + x 
        #x = attn(self.norm(x)) + x 

    def forward(self, x):
        for attn, ff in self.layers:
            # Note: Layer norm is applied inside attn(x) and ff(x). Only residual connection is performed below. Refer architecture figure in original paper.
            # Adding x for residual connection 
            x = attn(x) + x 

            # addition of MHSA block's output and its input is passed as input to ff layer
            x = ff(x) + x 
        return x

In [3]:
class ViT(nn.Module):
    """
    Args:
        image_size: Crop dimension say 256 by 256
        patch_size: Patchification of each crop, say 8 by 8
        num_class: Total classes for classification
        dim: Encoder input dimension. Each patch vector is converted to dim dimensional vector
        depth: Total number of encoder layers
        heads: Total number of self-attention layers in Multi-head Self-attention block of each encoder layer
        mlp_dim: Hidden layer dimension of FeedForward network in each layer/block in encoder
        channels: Input image channel. 3 for RGB, 1 for Grey scale.
        dim_head: Dimension of Q, K, V vector used for Self-Attention. Q and K are usually of same dimension, V can be arbitary.
    """
    def __init__(self, image_size, patch_size, num_class, dim, depth, heads, mlp_dim, channels=3, dim_head=64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)


        # For now forcing the crop image dimension to be divisible by patch size
        # TODO: Create overlapping patches when they are not divisible

        assert image_height % patch_height == 0 and image_width % patch_width == 0, "Image dimension should be divisible by patch size"

        # s = HW/p^2 (refer paper, where H, W are height and width of crop image, p is patch dimension-assuming it to be square)
        num_patches = (image_height * image_width) // (patch_height * patch_width)
        
        # Vectorising the patch along its channels
        patch_dim = patch_height * patch_width * channels

        # Converting the vectorised patch to a fixed dimension so that it can be passed into encoder as input
        self.to_patch_embedding = nn.Sequential(
            # refer Einops: https://einops.rocks/
            # Idea is simple, if einops isn't used it is going to be a headache to maintain the dimensions, matrix multiplications etc as we are dealing with lot of different entities like image crop, image patches in batches
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1=patch_height, p2=patch_width),

            # Lets say original input be (15, 3, 256, 256) and p1=p2=16
            # After rearranging (15, 32, 32, 8 * 8 * 3)
            nn.LayerNorm(patch_dim), # normalising across the channels (now 8 * 8 * 3)

            nn.Linear(patch_dim, dim), # Converting the vectorised patches to fixed dimension to get processed by encoder
            nn.LayerNorm(dim)
        )

        # Creating Encoder with multiple layers
        self.transformer = Transformer(dim, dim_head, depth, heads, mlp_dim)

        # No hidden layer in the MLP used for classification outside of encoder
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_class)
        )

    def forward(self, image_crop):
        """
        Assuming batch size=15
        """
        x = self.to_patch_embedding(image_crop)
        # pe = ...
        
        # Dimension of x changed to: (15, 32*32, 8*8*3)
        x = rearrange(x, 'b ... d -> b (...) d') 

        # Dimension of x returned from transformer: (15, h*w, dim)
        x = self.transformer(x)

        # In a updated version of ViT (by same authors, referenced in README.md), 
        # class token is removed and instead the classification is made by taking mean 
        # over tokens for all patches.
        # Dimension of x now: (15, dim)
        x = x.mean(dim=1)
        
        # If total classes=2 (say binary classification problem)
        # then dimension of x: (15, 2)
        x = self.linear_head(x)
        return x

In [5]:
import numpy as np
input = np.random.rand(15, 3, 256, 256)
input = torch.from_numpy(input).to(torch.float32)

model = ViT(
    image_size=256,
    patch_size=8,
    num_class=2,
    dim=128,
    depth=4,
    heads=8,
    mlp_dim=1024,
    channels=3,
    dim_head=64)

In [6]:
model(input).shape

dot_prod : torch.Size([15, 64, 1024, 1024])
dot_prod : torch.Size([15, 64, 1024, 1024])
dot_prod : torch.Size([15, 64, 1024, 1024])
dot_prod : torch.Size([15, 64, 1024, 1024])


torch.Size([15, 2])