In [1]:
import torch
import torch.nn as nn
from torchinfo import summary

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's RNG so randomly initialised weights and the random
# sanity-check inputs are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device

'cpu'

In [3]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    A ``Conv2d`` whose ``kernel_size`` and ``stride`` both equal ``patch_size``
    produces exactly one output location per patch, which is equivalent to
    flattening every patch and applying the same linear projection to it.

    Args:
        input_channel: Number of channels in the input image.
        patch_size: Side length (pixels) of each square patch. Both the image
            height and width must be divisible by it.
        embedding_size: Dimension of the embedding produced for each patch.

    Input:  ``[batch_size, input_channel, height, width]``
    Output: ``[batch_size, (height // patch_size) * (width // patch_size), embedding_size]``
    """

    def __init__(self,
                 input_channel = 3,
                 patch_size = 16,
                 embedding_size = 768):
        super().__init__()

        self.patch_size = patch_size

        # kernel_size == stride == patch_size, no padding -> one output pixel per patch.
        self.patcher = nn.Conv2d(in_channels=input_channel,
                                 out_channels=embedding_size,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)

        # Merge the two spatial (patch-grid) dims: [B, E, H/P, W/P] -> [B, E, N].
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        """Embed ``x`` of shape [B, C, H, W] into patch tokens of shape [B, N, E]."""
        height, width = x.shape[-2:]
        # Check BOTH spatial dims: Conv2d would otherwise silently drop the
        # trailing rows/columns that do not fill a whole patch.
        assert height % self.patch_size == 0 and width % self.patch_size == 0, (
            f"Input image size must be divisible by patch size: "
            f"got {height}x{width}, patch size {self.patch_size}"
        )

        x = self.patcher(x)        # [B, E, H/P, W/P], e.g. [B, 768, 14, 14]
        x = self.flatten(x)        # [B, E, N],        e.g. [B, 768, 196]
        return x.permute(0, 2, 1)  # [B, N, E],        e.g. [B, 196, 768]
    

# Sanity check: 32 RGB 224x224 images -> 196 (= 14 * 14) patch embeddings of size 768.
# Named `patch_embedding` (not `model`) so it is not confused with the ViT `model`
# built in a later cell.
image = torch.randn((32, 3, 224, 224))
patch_embedding = PatchEmbedding()

with torch.no_grad():  # shape check only; no need to record an autograd graph
    assert patch_embedding(image).shape == (32, 196, 768)


In [4]:
# Stand-alone pre-norm Transformer encoder with the same hyperparameters the ViT
# below uses by default (hidden 768, 12 heads, MLP 3072, 12 layers).
transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=768,
                                                       nhead=12,
                                                       dim_feedforward=3072,
                                                       dropout=0.1,
                                                       activation='gelu',
                                                       batch_first=True,
                                                       norm_first=True)

# PyTorch does not use the nested-tensor fast path for norm_first=True layers and
# warns if enable_nested_tensor is left at its default; disable it explicitly.
transformer_encoder = nn.TransformerEncoder(transformer_encoder_layer,
                                            num_layers=12,
                                            enable_nested_tensor=False)

In [5]:
class ViT(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable [class] token -> add a
    learnable positional embedding -> dropout -> pre-norm Transformer encoder ->
    LayerNorm + Linear head applied to the [class] token's output.

    The defaults match the ViT-Base/16 configuration (patch 16, 12 layers,
    12 heads, hidden size 768, MLP size 3072) on 224x224 inputs.

    Args:
        img_size: Height and width (pixels) of the square images the model accepts.
        num_channels: Number of input image channels.
        patch_size: Side length of each square patch; must divide ``img_size``.
        embedding_dim: Token embedding dimension (``d_model`` of the encoder).
        dropout: Dropout probability after the embeddings and inside the encoder.
        mlp_size: Hidden size of each encoder layer's feed-forward block.
        num_transformer_layer: Number of stacked encoder layers.
        num_heads: Number of attention heads per encoder layer.
        num_classes: Number of output logits.

    Input:  ``[batch_size, num_channels, img_size, img_size]``
    Output: ``[batch_size, num_classes]`` (unnormalised logits; no softmax applied)
    """

    def __init__(self,
                 img_size = 224,
                 num_channels = 3,
                 patch_size = 16,
                 embedding_dim = 768,
                 dropout = 0.1,
                 mlp_size = 3072,
                 num_transformer_layer = 12,
                 num_heads = 12,
                 num_classes = 1000):
        super().__init__()

        assert img_size % patch_size == 0, 'Image size must be divisible by patch size'

        # Remembered so forward() can reject inputs that don't match the
        # positional-embedding table built below.
        self.img_size = img_size

        # Patch embedding: [B, C, H, W] -> [B, N, D].
        # NOTE: the attribute name "path_embedding" (sic) is kept deliberately: it
        # determines state_dict keys, so renaming it would break saved checkpoints.
        self.path_embedding = PatchEmbedding(input_channel=num_channels,
                                             patch_size=patch_size,
                                             embedding_size=embedding_dim)

        # Learnable [class] token, broadcast over the batch in forward().
        self.class_token = nn.Parameter(torch.randn((1, 1, embedding_dim), requires_grad=True))

        # One learnable position vector per patch, plus one for the [class] token.
        num_patches = (img_size * img_size) // (patch_size ** 2)  # N = HW / P^2
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
        self.embedding_dropout = nn.Dropout(dropout)

        # Pre-norm (norm_first=True), batch-first encoder stack over [B, N+1, D].
        # PyTorch does not use the nested-tensor fast path for norm_first=True
        # layers and warns if enable_nested_tensor is left at its default, so
        # disable it explicitly.
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                     nhead=num_heads,
                                                     dim_feedforward=mlp_size,
                                                     dropout=dropout,
                                                     activation='gelu',
                                                     batch_first=True,
                                                     norm_first=True),
            num_layers=num_transformer_layer,
            enable_nested_tensor=False,
        )

        # Classification head, applied only to the [class] token position.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, num_classes)
        )

    def forward(self, x):
        """Map images [B, C, img_size, img_size] to class logits [B, num_classes]."""
        # Check both spatial dims against the size the positional embedding was
        # built for. Without this, e.g. a 112x448 image (also 196 patches at P=16)
        # would pass silently with positions mapped onto the wrong patch grid.
        assert tuple(x.shape[-2:]) == (self.img_size, self.img_size), (
            f"Expected input images of size {self.img_size}x{self.img_size}, "
            f"got {x.shape[-2]}x{x.shape[-1]}"
        )
        batch_size = x.shape[0]

        x = self.path_embedding(x)                                 # [B, N, D]
        class_token = self.class_token.expand(batch_size, -1, -1)  # [B, 1, D]
        x = torch.cat((class_token, x), dim=1)                     # [B, N+1, D]
        x = self.positional_embedding + x                          # [B, N+1, D]
        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)                            # [B, N+1, D]

        # x[:, 0] is the [class] token's final representation: [B, D].
        return self.mlp_head(x[:, 0])                              # [B, num_classes]
    

# Sanity check the full model: 32 RGB 224x224 images -> 1000 class logits each.
model = ViT()
img = torch.randn((32, 3, 224, 224))

# Shape check only: skip the autograd graph, which would otherwise retain every
# intermediate activation of a 32-image ViT-Base forward pass.
with torch.no_grad():
    assert model(img).shape == (32, 1000)

In [6]:
# Layer-by-layer table of shapes and parameter counts for a single 224x224 RGB image.
# For a more compact table, pass col_names=["input_size"] instead.
summary(
    model=model,
    input_size=(1, 3, 224, 224),  # (batch_size, color_channels, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
ViT (ViT)                                          [1, 3, 224, 224]     [1, 1000]            152,064              True
├─PatchEmbedding (path_embedding)                  [1, 3, 224, 224]     [1, 196, 768]        --                   True
│    └─Conv2d (patcher)                            [1, 3, 224, 224]     [1, 768, 14, 14]     590,592              True
│    └─Flatten (flatten)                           [1, 768, 14, 14]     [1, 768, 196]        --                   --
├─Dropout (embedding_dropout)                      [1, 197, 768]        [1, 197, 768]        --                   --
├─TransformerEncoder (transformer_encoder)         [1, 197, 768]        [1, 197, 768]        --                   True
│    └─ModuleList (layers)                         --                   --                   --                   True
│    │    └─TransformerEncoderLayer (0)        