In [1]:
import torch
from torch import nn
import torchinfo

In [3]:
dummy_image = nn.Parameter(torch.randn(1, 3, 224, 224))

In [6]:
class PatchEmbedding(nn.Module):
    """Takes in a 2D image and creates a 1D sequence learnable embedding vector

        in_channels (int) : number of color channels in the image // default : 2
        out_channels (int) : size of patches of images // default : 16
        embedding_dim (int) : size of embedding to turn every image into // default : 16*16*3
    """
    def __init__(self,
                 in_channels : int=3,
                 patch_size : int=16,
                 embedding_dim : int=768):
        super().__init__()
        # 1. Conv layer to turn the image into patches
        self.patch_size = patch_size

        self.conv_patch_layer = nn.Conv2d(in_channels=in_channels,
                                          out_channels=embedding_dim,
                                          kernel_size=patch_size,
                                          stride=patch_size,
                                          padding=0
                                        )
        
        # 2. Flatten to make a 1D representation
        self.flatten = nn.Flatten(start_dim=2,
                                  end_dim=3)
        
    def forward(self, x):

        assert x.shape[-1] % self.patch_size == 0, f"Input image must be divisible by patch size, patch_size -> {self.patch_size} and input image dims -> {x.shape[-1]}"
        # Generate patches
        patched = self.conv_patch_layer(x)
        print(f'After creating patches : {patched.shape}')

        # Generate flattened 1D representation
        flattened = self.flatten(patched)
        print(f'After flattening : {flattened.shape}')

        return flattened.permute(0, 2, 1)
        

In [7]:
patchembedding = PatchEmbedding(in_channels=3,
                                patch_size=16,
                                embedding_dim=16*16*3)
patched = patchembedding(dummy_image)
print(f"After patching and flattening size : {patched.shape} -> (batch, num_patches, embedding_dim)")

After creating patches : torch.Size([1, 768, 14, 14])
After flattening : torch.Size([1, 768, 196])
After patching and flattening size : torch.Size([1, 196, 768])


In [8]:
# Putting it all together and creating ViT


class ViT(nn.Module):
  def __init__(self,
               img_size=224,
               num_channels=3,
               patch_size=16,
               embedding_dim=768,
               dropout=0.1,
               mlp_size=3072,
               num_transformer_layers=12,
               num_heads=12,
               num_classes=1000):
    super().__init__()

    # Assert image size

    assert img_size % patch_size == 0, "Image size must be divisible by patch size."
    # 1. Create path embedding
    self.patch_embedding = PatchEmbedding(in_channels=num_channels,
                                          patch_size=patch_size,
                                          embedding_dim=embedding_dim)
    # 2. Create class tokens
    self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim),
                                    requires_grad=True)
    # 3. Create positional embedding
    num_patches = (img_size * img_size) // patch_size**2  # N = H*W/P^2
    self.positional_embedding = nn.Parameter(torch.randn(1, num_patches+1, embedding_dim)) # Positional embeddings are used to keep track of where the patches appear in a sequence

    # 4. Create patch + positional embedding dropout
    self.embedding_dropout = nn.Dropout(p=dropout)

    # 5. Create Transformer Encoder layer (single)
    # self.transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim,
    #                                                             nhead=num_heads,
    #                                                             dim_feedforward=mlp_size,
    #                                                             activation='gelu',
    #                                                             batch_first=True,
    #                                                             norm_first=True
    #                                                           )
    # 6. Create stack Tranformer Encoder Layers
    self.transformer_encoder= nn.TransformerEncoder(encoder_layer=nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                    nhead=num_heads,
                                                    dim_feedforward=mlp_size,
                                                    activation='gelu',
                                                    batch_first=True,
                                                    norm_first=True,)
                                                    ,num_layers=num_transformer_layers)
    # 7. Create MLP head
    self.mlp_head = nn.Sequential(
        nn.LayerNorm(normalized_shape=embedding_dim),
        nn.Linear(in_features=embedding_dim,
                  out_features=num_classes)
    )

  def forward(self, x):
    # Get some dimensions from x
    batch_size = x.shape[0]

    # Create the path embedding
    x = self.patch_embedding(x)

    # First expand the class token across the batch size
    class_token = self.class_token.expand(batch_size, -1, -1) # "-1" means infer the dimension

    # Prepend the class token to the patch embedding
    x = torch.cat((class_token, x), dim=1)

    # Add the positional embedding to patch embedding with class token
    x = self.positional_embedding + x

    # Dropout on patch + positional embedding
    x = self.embedding_dropout(x)

    # Pass embedding through transformer encoder stack
    x = self.transformer_encoder(x)

    # Pass the 0th index of x through MLP head
    x = self.mlp_head(x[:, 0])

    return x

In [18]:
batch_size = 32
num_classes = 1000
vit = ViT(num_classes=3)
img = nn.Parameter(torch.randn(1, 3, 224, 224))
vit(dummy_image)



After creating patches : torch.Size([1, 768, 14, 14])
After flattening : torch.Size([1, 768, 196])


tensor([[ 0.1849, -0.0814, -0.3336]], grad_fn=<AddmmBackward0>)