In [12]:
# official code (JAX) from several Google transformer papers: https://github.com/google-research/vision_transformer
# code taken from mildlyoverfitted's tutorial: https://www.youtube.com/watch?v=ovB0ddFtzzA&ab_channel=mildlyoverfitted
# PyTorch code used in tutorial with pretrained weights: https://github.com/huggingface/pytorch-image-models

import torch # pytorch 2.0.1 (https://pytorch.org/get-started/pytorch-2.0/)
import torch.nn as nn    

In [15]:
class PatchEmbed(nn.Module):
    # turns a batch of images into a sequence of patch embeddings that a transformer can consume
    # PARAMS:
        # img_size: side length of the (square) input image
        # patch_size: side length of each (square) patch
        # channels: num of input channels
        # embed_dim: length of the embedding vector produced for every patch
    
    # ATTRIBUTES:
        # img_size, patch_size: stored copies of the constructor arguments
        # patches: num of patches per image, i.e. (img_size // patch_size) squared
        # proj (nn.Conv2d): convolution that cuts the image into patches and embeds each one
    
    def __init__(self, img_size, patch_size, channels=3, embed_dim=768):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        
        patches_per_side = img_size // patch_size
        self.patches = patches_per_side ** 2
        
        # kernel size == stride == patch size, so the kernel visits each patch exactly once
        # and neighbouring patches never overlap
        self.proj = nn.Conv2d(
            in_channels=channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
    
    def forward(self, x):
        # run a forward pass over a batch of images.
        # PARAMS:
            # x (torch.Tensor): shape (n_samples, channels, img_size, img_size) for a square
        
        # OUTPUT: a 3D tensor holding one embedding per patch
            # torch.Tensor: shape (n_samples, n_patches, embed_dim)
        
        # the convolution yields a 4D tensor:
        # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        feature_map = self.proj(x)
        
        # merge the two spatial axes into a single axis of length n_patches:
        # (n_samples, embed_dim, n_patches)
        patch_major = feature_map.flatten(start_dim=2)
        
        # swap axes 1 and 2 so the patch axis comes before the embedding axis
        return patch_major.transpose(1, 2)
        

class Attention(nn.Module):
    # multi-head attention mechanism (work in progress: the forward pass below is not finished yet)
    # PARAMS:
        # token_dim: the input and output dimension of the per-token features
        # heads: number of attention heads
        # qkv_bias: bool, if a bias value is included in the query, key, and value (qkv) projection
        # attn_p: Dropout probability (ratio) used by the attn_drop layer
        # proj_p: Dropout probability used by the proj_drop layer
    
    # ATTRIBUTES:
        # heads, token_dim: stored copies of the constructor arguments
        # head_dim: per-head feature size, token_dim // heads
        # scale: normalizing constant for the dot product (attention matrix), head_dim ** -0.5
        # qkv (nn.Linear): linear projection token_dim -> 3 * token_dim producing q, k and v together
        # proj (nn.Linear): linear mapping token_dim -> token_dim for the concatenated heads
        # attn_drop, proj_drop: Dropout layers built from attn_p and proj_p
    
    def __init__(self, token_dim, heads=12, qkv_bias=True, attn_p=0, proj_p=0):
        super().__init__()
        self.heads = heads
        self.token_dim = token_dim
        self.head_dim = token_dim // heads
        # head dimensions are specified in this way so that once the resulting attention heads are
        # concatenated, the new tensor will have the same dimensions as the input
        # (this only works out exactly when token_dim is divisible by heads)
        
        self.scale = self.head_dim ** -0.5
        # this scaling value comes from the "Attention is All you Need" paper
        # its purpose is to prevent very large values from being fed into the softmax layer,
        # which would otherwise cause small gradients
        
        self.qkv = nn.Linear(token_dim, token_dim * 3, bias=qkv_bias)
        # linear mapping that will take token embedding as an input and produce qkv values
        self.proj = nn.Linear(token_dim, token_dim)
        # linear mapping that takes concatenated heads as an input and maps into new space
        
        self.attn_drop = nn.Dropout(attn_p)
        self.proj_drop = nn.Dropout(proj_p)
    
    def forward(self, x):
        # run forward pass. the +1 to n_patches is to ensure that the class token of each sample
        # is always included as the first token in the sequence
        # PARAMS:
            # x (torch.Tensor): shape (samples, patches + 1, dim)
        
        # OUTPUT (intended, same shape as the input):
            # torch.Tensor: shape (samples, patches + 1, dim)
        
        # RAISES:
            # ValueError: if the last axis of x does not match the token_dim given to the constructor
        
        samples, tokens, dim = x.shape
        
        # check whether input embedding dimension matches declared dimension in constructor.
        # the constructor stores that value as self.token_dim (there is no self.dim attribute)
        if dim != self.token_dim:
            raise ValueError(
                f"expected input embedding dimension {self.token_dim}, got {dim}"
            )
        
        qkv = self.qkv(x) # (samples, patches + 1, 3 * dim)
        
        # to be continued: learning how linear layer behaves in 