In [22]:
import torch
import torch.nn as nn
from torchinfo import summary

In [23]:
class ImagePatcher(nn.Module):
    '''
    Turn a batch of images into a sequence of patch embeddings.

    A strided convolution whose kernel size and stride both equal ``patch_size`` cuts the image
    into non-overlapping patches and projects each patch to ``embedding_dimension`` channels.
    '''

    def __init__(self, image_channels, patch_size, embedding_dimension):
        '''
        Build the patch-projection and flattening layers.

        Args:
            image_channels: passed to the convolution as its number of input channels.
            patch_size: used as both the kernel size and the stride of the convolution.
            embedding_dimension: number of output channels of the convolution.
        '''
        super().__init__()

        # kernel_size == stride, so each kernel application covers exactly one patch.
        self.conv = nn.Conv2d(
            in_channels=image_channels,
            out_channels=embedding_dimension,
            kernel_size=patch_size,
            stride=patch_size,
        )

        # Collapse dims 2 and 3 (the convolution's output grid) into a single patch axis.
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, input_image):
        '''
        Embed the patches of ``input_image``.

        Args:
            input_image: image batch; expected to be (batch, image_channels, height, width).

        Returns:
            Tensor with the patch axis second and the embedding axis last,
            i.e. (batch, num_patches, embedding_dimension).
        '''
        patch_grid = self.conv(input_image)        # (batch, embedding_dimension, rows, cols)
        patch_sequence = self.flatten(patch_grid)  # (batch, embedding_dimension, rows * cols)
        # Swap the last two axes so each patch becomes one token for the transformer.
        return patch_sequence.transpose(1, 2)


In [24]:
class VisionTransformer(nn.Module):
    '''
    Vision Transformer image classifier.

    Images are split into patches and embedded, a learnable class token is prepended, learnable
    positional embeddings are added, the sequence is run through a pre-norm transformer encoder,
    and the class-token output is mapped to class logits.
    '''

    def __init__(self,
                 num_classes,
                 image_size=224,
                 image_channels=3,
                 patch_size=16,
                 embed_dimension=768,
                 dropout=0.1,
                 mlp_size=3072,
                 num_transformer_layers=12,
                 num_heads=12
                 ):
        '''
        Build the model.

        Args:
            num_classes: number of output logits produced by the classification head.
            image_size: side length of the (square) input images; must be divisible by patch_size.
            image_channels: number of channels in the input images.
            patch_size: side length of each square patch.
            embed_dimension: width of the token embeddings used throughout the model.
            dropout: dropout rate applied to the embeddings and inside every encoder layer.
            mlp_size: hidden width of the feed-forward network in each encoder layer.
            num_transformer_layers: number of stacked encoder layers.
            num_heads: number of attention heads per encoder layer.

        Raises:
            AssertionError: if image_size is not divisible by patch_size.
        '''
        super().__init__()

        # The image has to tile exactly into patch_size x patch_size patches.
        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"

        # Cuts the image into patches and projects each patch to embed_dimension.
        self.patch_embedding = ImagePatcher(image_channels, patch_size, embed_dimension)

        # Learnable token prepended to every sequence; its final state feeds the classifier.
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dimension), requires_grad=True)

        # Patches per image: (patches per side) squared.
        num_patches = (image_size // patch_size) ** 2

        # One learnable position vector per token; the +1 accounts for the class token.
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dimension), requires_grad=True)

        # Dropout applied to the embedded sequence before the encoder.
        self.dropout = nn.Dropout(dropout)

        # Forward `dropout` to the encoder layers as well; otherwise they silently keep their own
        # default rate and the constructor argument only affects the embedding dropout.
        # NOTE(review): the feed-forward activation is left at nn.TransformerEncoderLayer's
        # default; confirm whether activation="gelu" (as in the ViT paper) was intended.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                embed_dimension, num_heads, mlp_size, dropout=dropout, batch_first=True, norm_first=True
            ), num_transformer_layers
        )

        # Classification head: layer norm followed by a linear map to num_classes logits.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dimension),
            nn.Linear(embed_dimension, num_classes)
        )

    def forward(self, x):
        '''
        Classify a batch of images.

        Args:
            x: image batch; expected to be (batch, image_channels, image_size, image_size).

        Returns:
            Unnormalized class logits of shape (batch, num_classes). No softmax is applied.
        '''
        batch_size = x.shape[0]

        # (batch, num_patches, embed_dimension)
        x = self.patch_embedding(x)

        # Prepend the class token to every sequence in the batch.
        x = torch.cat((self.class_token.expand(batch_size, -1, -1), x), dim=1)

        # Out-of-place add: avoids mutating the concatenated tensor in place.
        x = x + self.positional_embedding

        x = self.dropout(x)
        x = self.transformer(x)

        # Only the class token (position 0) is passed to the classification head.
        return self.mlp_head(x[:, 0])


In [25]:
class_names = ["tiger", "bee", "dog"]

# Seed the global RNG so the demo batch and the model's initial weights are the same on every
# Restart & Run All.
torch.manual_seed(42)

# Random stand-in batch: 32 images, 3 channels, 224x224 (the model's default image size).
demo_images = torch.randn(32, 3, 224, 224)

vit = VisionTransformer(num_classes=len(class_names))

# Smoke-test the forward pass without recording an autograd graph (saves memory).
# NOTE: the model is still in training mode here, so dropout is active in this pass.
with torch.no_grad():
    demo_logits = vit(demo_images)

# One row of logits per image, one column per class.
demo_logits




tensor([[-0.3264, -0.9058,  0.1786],
        [-0.1934, -0.7492, -0.0589],
        [-0.4165, -0.9355, -0.1744],
        [-0.0490, -0.6681,  0.1470],
        [-0.1315, -0.7534, -0.1728],
        [-0.3976, -1.0031, -0.3422],
        [-0.4002, -0.7935,  0.2675],
        [-0.1943, -1.1838, -0.1940],
        [-0.5635, -0.4801, -0.1250],
        [-0.5193, -0.8771, -0.0634],
        [-0.2233, -0.9906,  0.1719],
        [-0.3281, -1.1864,  0.1244],
        [-0.3698, -0.6208, -0.1982],
        [-0.4077, -0.7038,  0.2094],
        [-0.5601, -0.8584, -0.0521],
        [-0.3592, -0.7675, -0.1853],
        [-0.3318, -0.8554,  0.0860],
        [-0.3778, -0.6205, -0.2744],
        [-0.0699, -1.1418, -0.1092],
        [-0.4991, -0.7855, -0.0435],
        [-0.2633, -0.9802,  0.0239],
        [-0.3655, -1.0466, -0.0472],
        [-0.3331, -0.7122, -0.1524],
        [-0.3538, -0.9178,  0.0150],
        [-0.4726, -0.9563, -0.1842],
        [-0.4875, -0.9477, -0.1244],
        [-0.2309, -0.8826, -0.3113],
 

In [26]:
# Per-layer table of output shapes and parameter counts, traced with the demo batch's shape.
summary(vit, input_size=demo_images.shape)

Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [32, 3]                   152,064
├─ImagePatcher: 1-1                           [32, 196, 768]            --
│    └─Conv2d: 2-1                            [32, 768, 14, 14]         590,592
│    └─Flatten: 2-2                           [32, 768, 196]            --
├─Dropout: 1-2                                [32, 197, 768]            --
├─TransformerEncoder: 1-3                     [32, 197, 768]            --
│    └─ModuleList: 2-3                        --                        --
│    │    └─TransformerEncoderLayer: 3-1      [32, 197, 768]            7,087,872
│    │    └─TransformerEncoderLayer: 3-2      [32, 197, 768]            7,087,872
│    │    └─TransformerEncoderLayer: 3-3      [32, 197, 768]            7,087,872
│    │    └─TransformerEncoderLayer: 3-4      [32, 197, 768]            7,087,872
│    │    └─TransformerEncoderLayer: 3-5      [32, 197, 7