<a href="https://colab.research.google.com/github/ugomezjr/Image-Recognition-with-Transformers/blob/main/pytorch_vision_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from torch import nn
import torch
import matplotlib.pyplot as plt
try:
  from torchinfo import summary
except ImportError:
  print("[INFO] Couldn't find torchinfo... installing it.")
  !pip install -q torchinfo
  from torchinfo import summary

[INFO] Couldn't find torchinfo... installing it.


## Equation 1: Patch + Position Embedding

Split an image into fixed-size patches, linearly embed each of them, and add position embeddings.

\begin{aligned}
\mathbf{z}_{0} &=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{N} \mathbf{E}\right]+\mathbf{E}_{\text {pos }}, & & \mathbf{E} \in \mathbb{R}^{\left(P^{2} \cdot C\right) \times D}, \mathbf{E}_{\text {pos }} \in \mathbb{R}^{(N+1) \times D}
\end{aligned}

In [None]:
# [0] Calculate patch embedding input and output shapes

# Create variables to mimic terms
height = 224 # H (height)
width = 224 # W (width)
color_channels = 3 # C (color channels)
patch_size = 16 # P (image patch resolution)

# N = HW/P^2 (number of patches)
number_of_patches = int((height * width) / patch_size**2)

# Input shape (training resolution)
embedding_layer_input_shape = (height, width, color_channels)

# Output shape (sequence of flattened 2D patches): x_p has shape N x (P^2 * C)
embedding_layer_output_shape = (number_of_patches, patch_size**2 * color_channels) # (number of patches, P^2 * C = 768)

In [None]:
print(f"Number of Patches: {number_of_patches}")

Number of Patches: 196


In [None]:
print(f"Input shape: {embedding_layer_input_shape}\nOutput shape: {embedding_layer_output_shape}")

Input shape: (224, 224, 3)
Output shape: (196, 768)


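### Hybrid Architecture

As an alternative to raw image patches, the input sequence can be formed from the feature maps of a CNN.

In this hybrid model, the patch embedding projection $\mathbf{E}$ is applied to patches extracted from a CNN feature map.

As a special case, the patches can have a spatial size of 1x1, in which case the input sequence is obtained by simply flattening the spatial dimensions of the feature map and projecting to the Transformer dimension.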
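A minimal sketch of the 1x1 hybrid case, assuming a small stand-in CNN backbone (the two-layer `backbone` and the `HybridEmbedding` class below are hypothetical and are not the ResNet used in the ViT paper). The backbone turns a 224x224 image into a 14x14 feature map, and each 1x1 location of that map is flattened into a token and projected to $D=768$ by an `nn.Linear` layer playing the role of $\mathbf{E}$.

```python
import torch
from torch import nn

class HybridEmbedding(nn.Module):
    """Hypothetical hybrid stem: tiny CNN backbone -> flatten -> project to D."""
    def __init__(self, in_channels=3, feature_channels=256, embedding_dim=768):
        super().__init__()
        # Stand-in backbone (NOT the paper's ResNet): total stride 16, so 224 -> 14
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=4),                # 224 -> 56
            nn.ReLU(),
            nn.Conv2d(64, feature_channels, kernel_size=4, stride=4),           # 56 -> 14
            nn.ReLU(),
        )
        # 1x1 "patches": the projection E is applied at every spatial location
        self.projection = nn.Linear(feature_channels, embedding_dim)

    def forward(self, x):
        feature_map = self.backbone(x)                    # (B, C', H', W')
        tokens = feature_map.flatten(2).transpose(1, 2)   # (B, H'*W', C')
        return self.projection(tokens)                    # (B, H'*W', D)

hybrid = HybridEmbedding()
hybrid_tokens = hybrid(torch.randn(2, 3, 224, 224))
print(hybrid_tokens.shape)  # torch.Size([2, 196, 768])
```

The output has the same `(batch, number of tokens, D)` layout as the pure patch embedding, so the rest of the model would not need to change.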

In [None]:
# [1] Split an image into fixed-sized patches

image = torch.randn(3, 224, 224)

# Create the Conv2d layer with hyperparameters from the ViT paper

# Set "kernel_size" and "stride" equal to "patch_size" to effectively get a layer that
# splits our image into patches and creates a learnable embedding (referred to as a "Linear Projection"
# in the ViT paper) of each patch.

conv2d = nn.Conv2d(in_channels=color_channels, # number of color channels
                   out_channels=(patch_size**2 * color_channels), # Hidden size D, this is the embedding size (768)
                   kernel_size=patch_size,
                   stride=patch_size,
                   padding=0)
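To see why a `Conv2d` with `kernel_size=stride=patch_size` amounts to "split into patches, then multiply by $\mathbf{E}$", the check below cuts an image into patches explicitly with `torch.nn.functional.unfold` and reuses the convolution weights as the matrix $\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}$. It is a sketch for intuition only; the variable names are illustrative and chosen so they don't overwrite the notebook's `image` or `conv2d`.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
C, H, W, P, D = 3, 224, 224, 16, 768
demo_image = torch.randn(1, C, H, W)
demo_conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# Explicitly cut the image into non-overlapping P x P patches: (1, C*P*P, N)
patches = F.unfold(demo_image, kernel_size=P, stride=P)
patches = patches.transpose(1, 2)  # (1, N, C*P*P): one flattened patch x_p^i per row

# Reuse the conv kernels as the projection matrix E with shape (C*P*P, D)
E = demo_conv.weight.reshape(D, -1).T
linear_tokens = patches @ E + demo_conv.bias  # (1, N, D)

# Same computation done by the convolution, flattened and permuted as in the notebook
conv_tokens = demo_conv(demo_image).flatten(2).permute(0, 2, 1)  # (1, N, D)

print(linear_tokens.shape)  # torch.Size([1, 196, 768])
print(torch.allclose(linear_tokens, conv_tokens, atol=1e-4))
```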

In [None]:
output = conv2d(image.unsqueeze(0)) # add a single batch dimension
print(f"Input shape (2D image): {image.shape}\nOutput shape (patch embedding feature map): {output.shape}")

Input shape (2D image): torch.Size([3, 224, 224])
Output shape (patch embedding feature map): torch.Size([1, 768, 14, 14])


In [None]:
output[0, 0, :, :]

tensor([[-0.2617,  0.5375,  0.5653,  0.0818, -0.7199,  1.2233,  0.9079, -1.1796,
          1.3807,  0.7522,  0.1465, -0.1028, -0.4151, -0.9094],
        [ 0.4685,  1.0137,  0.7614,  0.4743,  0.8700,  0.6178, -0.5635,  0.2802,
         -0.5641,  0.2237, -0.1927,  0.8294, -0.3786,  0.1160],
        [ 0.2053, -0.7453, -0.3989, -0.1740,  0.3718, -0.7562,  0.4569, -0.0475,
          0.3184,  0.4324,  1.0421,  0.4952,  0.1185, -0.1388],
        [ 0.3724, -0.3266, -1.2162,  0.5713,  0.0296, -0.0210,  0.9155, -0.9453,
         -0.4576, -0.3443,  0.2580,  0.5521, -0.4516, -0.5196],
        [ 0.0482, -0.0529, -0.3017, -0.3553, -0.4866, -0.3150,  0.0473,  0.2393,
         -0.6070,  0.0756,  0.5224,  0.2691,  0.5255,  0.0338],
        [ 0.1768,  0.2348, -0.2976, -0.1317, -0.7279,  0.3120,  0.8582,  0.7593,
         -0.2002,  0.5486,  0.5005, -1.1663,  0.4659,  0.0623],
        [-0.7629, -0.5654,  0.0653,  0.2391, -0.2418, -0.5778,  0.7394,  0.7111,
          0.2660, -0.2707,  0.0930,  0.0960, -0.3

**Desired output (1D sequence of flattened 2D patches):** (196, 768) -> (number of patches, embedding dimension) -> $N \times\left(P^{2} \cdot C\right)$

In [None]:
# [2] Flattening the patch embedding

# Create flatten layer (14*14 == 196)
flatten = nn.Flatten(start_dim=2, # flatten feature map height (dimension 2)
                     end_dim=3) # flatten feature map width (dimension 3)

# (number of patches, embedding dimension)
print(f"Output shape (1D sequence of flattened 2D patches): {flatten(output).permute(0, 2, 1).shape}")

Output shape (1D sequence of flattened 2D patches): torch.Size([1, 196, 768])


In [None]:
flatten(output)[0, 0, :]

tensor([-0.2617,  0.5375,  0.5653,  0.0818, -0.7199,  1.2233,  0.9079, -1.1796,
         1.3807,  0.7522,  0.1465, -0.1028, -0.4151, -0.9094,  0.4685,  1.0137,
         0.7614,  0.4743,  0.8700,  0.6178, -0.5635,  0.2802, -0.5641,  0.2237,
        -0.1927,  0.8294, -0.3786,  0.1160,  0.2053, -0.7453, -0.3989, -0.1740,
         0.3718, -0.7562,  0.4569, -0.0475,  0.3184,  0.4324,  1.0421,  0.4952,
         0.1185, -0.1388,  0.3724, -0.3266, -1.2162,  0.5713,  0.0296, -0.0210,
         0.9155, -0.9453, -0.4576, -0.3443,  0.2580,  0.5521, -0.4516, -0.5196,
         0.0482, -0.0529, -0.3017, -0.3553, -0.4866, -0.3150,  0.0473,  0.2393,
        -0.6070,  0.0756,  0.5224,  0.2691,  0.5255,  0.0338,  0.1768,  0.2348,
        -0.2976, -0.1317, -0.7279,  0.3120,  0.8582,  0.7593, -0.2002,  0.5486,
         0.5005, -1.1663,  0.4659,  0.0623, -0.7629, -0.5654,  0.0653,  0.2391,
        -0.2418, -0.5778,  0.7394,  0.7111,  0.2660, -0.2707,  0.0930,  0.0960,
        -0.3198,  0.3457, -0.1577, -0.09

In [None]:
flatten(output).permute(0, 2, 1)

tensor([[[-0.2617, -0.9701,  0.1095,  ...,  0.0537,  1.0546,  0.1155],
         [ 0.5375,  0.4571,  0.8763,  ...,  0.0760, -0.2131, -0.1085],
         [ 0.5653, -0.5079,  0.6855,  ..., -0.1471,  0.3205, -0.5650],
         ...,
         [-0.1441, -1.3473,  0.1196,  ...,  1.0172,  0.4318,  0.8929],
         [-1.2477, -0.9158, -0.7643,  ...,  1.1499, -0.5391,  0.2153],
         [ 0.2955, -0.6296, -0.1625,  ..., -0.2562,  0.4158, -0.0259]]],
       grad_fn=<PermuteBackward0>)

### Turn the ViT patch embedding layer into a PyTorch module

1. Create a `PatchEmbedding` class that subclasses `nn.Module`.
2. Initialize the class with the parameters `in_channels=3`, `patch_size=16` and `embedding_dim=768`.
3. Create an `nn.Conv2d()` layer whose `kernel_size` and `stride` both equal `patch_size`.
4. Create an `nn.Flatten()` layer to flatten the spatial dimensions of the feature map into a 1D sequence of patch embeddings.
5. Define a `forward()` method that applies the convolution, flattens the result and permutes it to `(batch, number of patches, embedding dimension)`.
6. Ensure the output shape is consistent with that of the ViT architecture ($N \times\left(P^{2} \cdot C\right)$).


In [None]:
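class PatchEmbedding(nn.Module):
  """Turns a 2D input image into a 1D sequence of learnable patch embeddings.

  Args:
    in_channels (int): Number of color channels of the input images. Defaults to 3.
    patch_size (int): Size of the square patches the image is split into. Defaults to 16.
    embedding_dim (int): Size of each patch embedding (D). Defaults to 768.
  """
  def __init__(self,
               in_channels: int=3,
               patch_size: int=16,
               embedding_dim: int=768):
    super().__init__()
    self.patch_size = patch_size

    # kernel_size == stride == patch_size: each kernel application sees exactly one patch
    self.conv = nn.Conv2d(in_channels=in_channels,
                          out_channels=embedding_dim,
                          kernel_size=patch_size,
                          stride=patch_size,
                          padding=0)

    # Flatten the feature map's height and width into a single "number of patches" dimension
    self.flatten = nn.Flatten(start_dim=2,
                              end_dim=3)

  def forward(self, x):
    height, width = x.shape[-2:]
    assert height % self.patch_size == 0 and width % self.patch_size == 0, \
      f"Image size ({height}x{width}) must be divisible by patch size {self.patch_size}"
    # (B, C, H, W) -> (B, D, H/P, W/P) -> (B, D, N) -> (B, N, D)
    return self.flatten(self.conv(x)).permute(0, 2, 1)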

In [None]:
embedding = PatchEmbedding()
embedding(image.unsqueeze(0)), embedding(image.unsqueeze(0)).shape

(tensor([[[ 0.1902,  0.6967,  0.6295,  ...,  0.1795, -0.0255, -0.0537],
          [ 0.0256,  0.0761, -0.5711,  ..., -1.0734, -0.2553, -0.9181],
          [ 0.4979,  0.1233,  0.2194,  ..., -0.1984,  0.0070, -0.0337],
          ...,
          [-0.1950, -0.4613, -0.1825,  ..., -0.8430,  0.3619, -1.3511],
          [ 0.4200, -0.1371,  0.7837,  ..., -0.3174,  0.5432, -0.5316],
          [-0.2547,  0.0764,  0.3250,  ...,  1.0624,  0.2980,  0.7755]]],
        grad_fn=<PermuteBackward0>), torch.Size([1, 196, 768]))

In [None]:
# Get a summary of the input and outputs of PatchEmbedding
summary(PatchEmbedding(),
        input_size=(1, color_channels, height, width),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
PatchEmbedding (PatchEmbedding)          [1, 3, 224, 224]     [1, 196, 768]        --                   True
├─Conv2d (conv)                          [1, 3, 224, 224]     [1, 768, 14, 14]     590,592              True
├─Flatten (flatten)                      [1, 768, 14, 14]     [1, 768, 196]        --                   --
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
Total mult-adds (M): 115.76
Input size (MB): 0.60
Forward/backward pass size (MB): 1.20
Params size (MB): 2.36
Estimated Total Size (MB): 4.17

In [None]:
# [3] Creating the class token embedding

patch_embedded_image = embedding(image.unsqueeze(0))

# Get batch size and embedding dimension
batch_size = patch_embedded_image.shape[0]
embedding_dim = patch_embedded_image.shape[2] # D (embedding dimension)

# Create the class token embedding as a learnable parameter that shares the same size as the embedding dimension (D)
class_token = nn.Parameter(torch.ones(batch_size, 1, embedding_dim), # [batch size, number of tokens, embedding dimension]
                           requires_grad=True) # ensure embedding is learnable

print(f"{class_token.shape}")

torch.Size([1, 1, 768])


In [None]:
# [4] Prepend a learnable embedding to the sequence of embedded patches
patch_embedding_with_class_token = torch.cat((class_token, patch_embedded_image), 
                                             dim=1)

# (batch size, class token + number of patches, embedding dimension)
print(patch_embedding_with_class_token, patch_embedding_with_class_token.shape)

tensor([[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.1902,  0.6967,  0.6295,  ...,  0.1795, -0.0255, -0.0537],
         [ 0.0256,  0.0761, -0.5711,  ..., -1.0734, -0.2553, -0.9181],
         ...,
         [-0.1950, -0.4613, -0.1825,  ..., -0.8430,  0.3619, -1.3511],
         [ 0.4200, -0.1371,  0.7837,  ..., -0.3174,  0.5432, -0.5316],
         [-0.2547,  0.0764,  0.3250,  ...,  1.0624,  0.2980,  0.7755]]],
       grad_fn=<CatBackward0>) torch.Size([1, 197, 768])
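The cells above create the class token with a leading batch dimension. A common alternative, sketched here under the assumption that one class token should be shared by every image, stores it as a `(1, 1, D)` parameter and uses `torch.Tensor.expand()` to repeat it across the batch before concatenating. The batch of 4 below is purely illustrative.

```python
import torch
from torch import nn

B, N, D = 4, 196, 768  # illustrative batch size, number of patches, embedding dim
shared_class_token = nn.Parameter(torch.randn(1, 1, D))  # one token for all images
patch_tokens = torch.randn(B, N, D)  # stand-in for a batch of patch embeddings

# expand() returns a (B, 1, D) view without copying memory
batch_class_tokens = shared_class_token.expand(B, -1, -1)
tokens_with_class = torch.cat((batch_class_tokens, patch_tokens), dim=1)

print(tokens_with_class.shape)  # torch.Size([4, 197, 768])
```

Because `expand()` only creates a view, gradients from every image in the batch flow back into the same `(1, 1, D)` parameter.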


### Position Embedding Shape ($\mathbf{E}_{\text {pos }}$):

\begin{aligned}
\mathbf{E}_{\text {pos }} \in \mathbb{R}^{(N+1) \times D}
\end{aligned}

Where:

*   $N=H W / P^{2}$ is the number of patches.
*   $D$ is the **patch embedding** dimension.



In [None]:
# [5] Create the position embedding

# Create the learnable 1D position embedding
position_embedding = nn.Parameter(torch.ones(batch_size, 
                                             number_of_patches + 1, 
                                             embedding_dim),
                                  requires_grad=True)

# (batch size, number of patches + 1, embedding dimension)
print(f"Position Embedding shape: {position_embedding.shape}")

Position Embedding shape: torch.Size([1, 197, 768])


In [None]:
# [6] Add position embeddings to the patch embeddings to retain positional information

# Add the position embedding to the class token and patch embedding
patch_and_position_embedding = patch_embedding_with_class_token + position_embedding

print(patch_and_position_embedding)
print(f"Patch + Position Embedding shape: {patch_and_position_embedding.shape}")

tensor([[[ 2.0000,  2.0000,  2.0000,  ...,  2.0000,  2.0000,  2.0000],
         [ 1.1902,  1.6967,  1.6295,  ...,  1.1795,  0.9745,  0.9463],
         [ 1.0256,  1.0761,  0.4289,  ..., -0.0734,  0.7447,  0.0819],
         ...,
         [ 0.8050,  0.5387,  0.8175,  ...,  0.1570,  1.3619, -0.3511],
         [ 1.4200,  0.8629,  1.7837,  ...,  0.6826,  1.5432,  0.4684],
         [ 0.7453,  1.0764,  1.3250,  ...,  2.0624,  1.2980,  1.7755]]],
       grad_fn=<AddBackward0>)
Patch + Position Embedding shape: torch.Size([1, 197, 768])


### From Image to Patch + Position Embedding (Extra learnable [class] embedding)

\begin{aligned}
\mathbf{z}_{0} &=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{N} \mathbf{E}\right]+\mathbf{E}_{\text {pos }}, & & \mathbf{E} \in \mathbb{R}^{\left(P^{2} \cdot C\right) \times D}, \mathbf{E}_{\text {pos }} \in \mathbb{R}^{(N+1) \times D}
\end{aligned}

1. Set the patch size.
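2. Get a single image, print its shape and store its height and width.
3. Add a batch dimension to the single image for our `PatchEmbedding` layer.
4. Create a `PatchEmbedding` layer with `patch_size=16` and `embedding_dim=768`.
5. Pass the single image through the `PatchEmbedding` layer to create a sequence of patch embeddings.
6. Prepend a learnable class token embedding to the sequence of patch embeddings.
7. Create a learnable position embedding and add it to the class token and patch embeddings.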

In [None]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

patch_size = 16

print(f"Image shape: {image.shape}")

height = image.shape[1]
width = image.shape[2]
number_of_patches = (height * width) // patch_size**2 # N = HW/P^2

patch_embedding = PatchEmbedding() # patch_size=16 and embedding_dim=768 by default. 

embedded_patches = patch_embedding(image.unsqueeze(0)) # add batch dimension

print(f"Embedded Patches shape: {embedded_patches.shape}")

class_tokens = nn.Parameter(torch.ones(batch_size, 1, embedding_dim), 
                            requires_grad=True)

embedded_patches_with_tokens = torch.cat((class_tokens, embedded_patches), 
                             dim=1)

print(f"Embedded Patches with extra learnable [class] embedding shape: {embedded_patches_with_tokens.shape}")

position_embedding = nn.Parameter(torch.ones(batch_size, number_of_patches + 1, embedding_dim), 
                                  requires_grad=True)

patch_and_position_embedding_with_tokens = embedded_patches_with_tokens + position_embedding

print(f"Patch + Position Embedding shape: {patch_and_position_embedding_with_tokens.shape}")

Image shape: torch.Size([3, 224, 224])
Embedded Patches shape: torch.Size([1, 196, 768])
Embedded Patches with extra learnable [class] embedding shape: torch.Size([1, 197, 768])
Patch + Position Embedding shape: torch.Size([1, 197, 768])


## Equation 2: Multi-Head Self-Attention (MSA)

A Multi-Head Self-Attention (MSA) layer preceded by a LayerNorm (LN) layer, with a residual connection that adds the block's input back to its output.

\begin{aligned}
\mathbf{z}_{\ell}^{\prime} &=\operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{z}_{\ell-1}\right)\right)+\mathbf{z}_{\ell-1}, & & \ell=1 \ldots L
\end{aligned}

Where:

* `Layer Normalization` (LayerNorm) normalizes an input over the last dimension.
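A small check of what "normalizes over the last dimension" means: at initialization `nn.LayerNorm` has its learnable scale set to 1 and its shift set to 0, so its output should match a hand-written $(x - \mu) / \sqrt{\sigma^2 + \epsilon}$ computed per token over the $D=768$ features. The random input is only illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(1, 197, 768)  # (batch, tokens, D)
layer_norm = nn.LayerNorm(normalized_shape=768)  # weight=1, bias=0 at initialization

# Statistics are computed per token over the last (embedding) dimension
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, as LayerNorm uses
manual_norm = (x - mean) / torch.sqrt(var + layer_norm.eps)

print(torch.allclose(layer_norm(x), manual_norm, atol=1e-5))
```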




### Replicate Multi-Head Self-Attention (MSA) with PyTorch layers

1. Create a class called `MultiHeadSelfAttention` that inherits from `torch.nn.Module`.
2. Initialize the class with hyperparameters from Table 1 of the ViT paper for the ViT-Base model.
3. Create a layer normalization (LN) layer with `torch.nn.LayerNorm()`, setting the `normalized_shape` parameter to our embedding dimension ($D$ from Table 1).
4. Create a multi-head attention (MSA) layer with `torch.nn.MultiheadAttention()` and the appropriate `embed_dim`, `num_heads`, `dropout` and `batch_first` parameters.
5. Create a `forward()` method that passes the inputs through the LN layer and then the MSA layer.

In [None]:
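class MultiHeadSelfAttention(nn.Module):
  """LayerNorm followed by multi-head self-attention (the MSA(LN(z)) part of equation 2)."""
  def __init__(self,
               embedding_dim: int=768,
               num_heads: int=12,
               attn_dropout: float=0):
    super().__init__()

    # Normalize over the embedding dimension D (previously hard-coded to 768)
    self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                num_heads=num_heads,
                                                dropout=attn_dropout,
                                                batch_first=True) # inputs are (batch, sequence, features)

  def forward(self, xp):
    xp = self.layer_norm(xp)
    # Self-attention: query, key and value all come from the same normalized sequence
    attn_out, _ = self.multihead_attn(query=xp,
                                      key=xp,
                                      value=xp,
                                      need_weights=False)
    return attn_out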



In [None]:
# Initialize an instance of MultiHeadSelfAttention
attention = MultiHeadSelfAttention()

# Forward pass the patch + position embedding through the MSA block
print(f"Input shape (Embedded Patches): {patch_and_position_embedding_with_tokens.shape}")
print(f"Output shape (MSA Block): {attention(patch_and_position_embedding_with_tokens).shape}")

Input shape (Embedded Patches): torch.Size([1, 197, 768])
Output shape (MSA Block): torch.Size([1, 197, 768])
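To make the `nn.MultiheadAttention` call less of a black box, here is a sketch that recomputes multi-head self-attention by hand from the layer's own weights: project to queries, keys and values, split into 12 heads of size $768/12 = 64$, apply $\operatorname{softmax}(QK^\top/\sqrt{64})V$ per head, concatenate the heads, and apply the output projection. It assumes PyTorch's packed `in_proj_weight` layout (query, key and value stacked along dim 0), which the layer uses when query, key and value share the same embedding size, and no attention dropout.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
B, N, D, h = 2, 197, 768, 12  # batch, tokens, embedding dim, heads
x = torch.randn(B, N, D)
mha = nn.MultiheadAttention(embed_dim=D, num_heads=h, batch_first=True)  # dropout=0 by default

# Packed input projection: rows [0:D] -> query, [D:2D] -> key, [2D:3D] -> value
W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)

def split_heads(t):
    """(B, N, D) -> (B, h, N, D/h)"""
    return t.reshape(B, N, h, D // h).transpose(1, 2)

q = split_heads(x @ W_q.T + b_q)
k = split_heads(x @ W_k.T + b_k)
v = split_heads(x @ W_v.T + b_v)

# Scaled dot-product attention, computed for all heads at once
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(D // h), dim=-1)
heads = (attn_weights @ v).transpose(1, 2).reshape(B, N, D)  # concatenate the heads
manual_out = mha.out_proj(heads)

reference_out, _ = mha(x, x, x, need_weights=False)
print(torch.allclose(manual_out, reference_out, atol=1e-5))
```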


## Equation 3: Multilayer Perceptron (MLP)

\begin{aligned}
\mathbf{z}_{\ell} &=\operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{z}_{\ell}^{\prime}\right)\right)+\mathbf{z}_{\ell}^{\prime}, & & \ell=1 \ldots L \\
\end{aligned}

where:

* `Multilayer Perceptron` (MLP) contains two linear layers with a GELU (Gaussian Error Linear Units) non-linearity and Dropout layers.
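For reference, the GELU non-linearity is $\operatorname{GELU}(x) = x \cdot \Phi(x)$, where $\Phi$ is the standard normal CDF. The short check below, a sketch using arbitrary input points, compares PyTorch's default (exact, erf-based) `nn.GELU()` with that formula written out by hand.

```python
import math
import torch
from torch import nn

x = torch.linspace(-3, 3, steps=7)  # arbitrary test points
# Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF
manual_gelu = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

print(torch.allclose(nn.GELU()(x), manual_gelu, atol=1e-6))
```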




### Replicating Multilayer Perceptron (MLP) with PyTorch layers

1. Create a class called `MultilayerPerceptron` that inherits from `torch.nn.Module`.
2. Initialize the class with hyperparameters from Table 1 and Table 3 of the ViT paper for the ViT-Base model.
3. Create a layer normalization (LN) layer with `torch.nn.LayerNorm()`, setting the `normalized_shape` parameter to our embedding dimension ($D$ from Table 1).
4. Create a sequential series of MLP layers using `torch.nn.Linear()`, `torch.nn.Dropout()` and `torch.nn.GELU()` with appropriate hyperparameter values from Table 1 and Table 3.
5. Create a `forward()` method that passes the inputs through the LN layer and then the MLP layers.

In [None]:
class MultilayerPerceptron(nn.Module):
  def __init__(self, 
               embedding_dim: int=768, 
               mlp_size: int=3072, 
               dropout: float=0.1):
    super().__init__()

    self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    self.mlp = nn.Sequential(
        nn.Linear(in_features=embedding_dim,
                  out_features=mlp_size),
        nn.GELU(),
        nn.Dropout(p=dropout, 
                   inplace=True),
        nn.Linear(in_features=mlp_size,
                  out_features=embedding_dim),
        nn.Dropout(p=dropout,
                   inplace=True)
    )


  def forward(self, x):
    return self.mlp(self.layer_norm(x))

In [None]:
mlp = MultilayerPerceptron()

print(f"Input shape (MSA Block): {attention(patch_and_position_embedding_with_tokens).shape}")
print(f"Output shape (MLP Block): {mlp(attention(patch_and_position_embedding_with_tokens)).shape}")

Input shape (MSA Block): torch.Size([1, 197, 768])
Output shape (MLP Block): torch.Size([1, 197, 768])


## Setup Transformer Encoder

An encoder refers to a stack of layers that "encodes" an input, i.e. turns it into a learned numerical representation.

The Transformer Encoder encodes our patch embeddings into a learned representation using alternating layers of MSA blocks (equation 2) and MLP blocks (equation 3). LayerNorm (LN) is applied before every block, and a residual connection is added after every block.

One of the main ideas behind residual connections is that the identity shortcut gives gradients a direct path back to earlier layers. This keeps gradient updates from becoming vanishingly small, which allows deeper networks to be trained and, in turn, deeper representations to be learned.
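A toy experiment that illustrates this, assuming a deliberately contrived setup (30 bias-free linear layers with very small random weights, not a real Transformer): without shortcuts the gradient reaching the input shrinks roughly geometrically with depth, while with residual connections $h \leftarrow f(h) + h$ the identity path keeps it close to its original size.

```python
import torch
from torch import nn

torch.manual_seed(0)
depth, dim = 30, 64
# Small weights make every layer shrink its input (and the gradient flowing back through it)
layers = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(depth)])
for layer in layers:
    nn.init.normal_(layer.weight, std=0.01)

def input_grad_norm(use_residual: bool) -> float:
    """Norm of d(sum of outputs)/d(input) through the whole stack."""
    x = torch.randn(1, dim, requires_grad=True)
    h = x
    for layer in layers:
        h = (layer(h) + h) if use_residual else layer(h)
    h.sum().backward()
    return x.grad.norm().item()

plain = input_grad_norm(use_residual=False)
residual = input_grad_norm(use_residual=True)
print(f"gradient norm without residuals: {plain:.2e}")
print(f"gradient norm with residuals:    {residual:.2e}")
```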



### Replicate Transformer Encoder with MSA and MLP blocks

1. Create a class called `TransformerEncoder` that inherits from `torch.nn.Module`.
2. Initialize the class with hyperparameters from Table 1 and Table 3 of the ViT paper for the ViT-Base model.
3. Instantiate an MSA block for equation 2 using our `MultiHeadSelfAttention` class with the appropriate parameters.
4. Instantiate an MLP block for equation 3 using our `MultilayerPerceptron` class with the appropriate parameters.
5. Create a `forward()` method for our `TransformerEncoder` class.
6. Create a residual connection for the MSA block (equation 2).
7. Create a residual connection for the MLP block (equation 3).

In [None]:
class TransformerEncoder(nn.Module):
  def __init__(self, 
               embedding_dim: int=768,
               mlp_size: int=3072,
               num_heads: int=12,
               attn_dropout: float=0.0,
               mlp_dropout: float=0.1):
    super().__init__()
    
    self.multihead_attn = MultiHeadSelfAttention(embedding_dim=embedding_dim, 
                                                 num_heads=num_heads,
                                                 attn_dropout=attn_dropout)
    
    self.mlp = MultilayerPerceptron(embedding_dim=embedding_dim,
                                    mlp_size=mlp_size,
                                    dropout=mlp_dropout)
    

  def forward(self, x):
    x = self.multihead_attn(x) + x # equation 2: z'_l = MSA(LN(z_{l-1})) + z_{l-1}
    x = self.mlp(x) + x # equation 3: z_l = MLP(LN(z'_l)) + z'_l
    return x

In [None]:
encoder = TransformerEncoder()

print(f"Input shape (Embedded Patches): {patch_and_position_embedding_with_tokens.shape}")
print(f"Output shape (Transformer Encoder): {encoder(patch_and_position_embedding_with_tokens).shape}")

Input shape (Embedded Patches): torch.Size([1, 197, 768])
Output shape (Transformer Encoder): torch.Size([1, 197, 768])


In [None]:
summary(model=encoder,
        input_size=(1, 197, 768),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
TransformerEncoder (TransformerEncoder)            [1, 197, 768]        [1, 197, 768]        --                   True
├─MultiHeadSelfAttention (multihead_attn)          [1, 197, 768]        [1, 197, 768]        --                   True
│    └─LayerNorm (layer_norm)                      [1, 197, 768]        [1, 197, 768]        1,536                True
│    └─MultiheadAttention (multihead_attn)         --                   [1, 197, 768]        2,362,368            True
├─MultilayerPerceptron (mlp)                       [1, 197, 768]        [1, 197, 768]        --                   True
│    └─LayerNorm (layer_norm)                      [1, 197, 768]        [1, 197, 768]        1,536                True
│    └─Sequential (mlp)                            [1, 197, 768]        [1, 197, 768]        --                   True
│    │    └─Linear (0)                     

## Setup Transformer Encoder with PyTorch's `torch.nn.TransformerEncoderLayer()`

The ViT-Base architecture stacks 12 Transformer Encoder layers on top of each other. In PyTorch, this can be done by passing a `torch.nn.TransformerEncoderLayer()` to `torch.nn.TransformerEncoder(encoder_layer, num_layers)`, where:

* `encoder_layer` - The target Transformer Encoder layer created with `torch.nn.TransformerEncoderLayer()`.
* `num_layers` - The number of Transformer Encoder layers to stack together. 
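A minimal sketch of the stacking, assuming the same ViT-Base settings as the single layer created in the next cell. `nn.TransformerEncoder` deep-copies the given layer `num_layers` times, so each of the 12 layers gets its own weights. Depending on your PyTorch version, building it from a `norm_first=True` layer may print a warning about nested tensors; that does not affect the output shape.

```python
import torch
from torch import nn

# One pre-norm ViT-Base encoder layer (same hyperparameters as the next cell)
vit_encoder_layer = nn.TransformerEncoderLayer(d_model=768,
                                               nhead=12,
                                               dim_feedforward=3072,
                                               dropout=0.1,
                                               activation="gelu",
                                               batch_first=True,
                                               norm_first=True)

# Stack 12 independent copies of the layer
stacked_encoder = nn.TransformerEncoder(vit_encoder_layer, num_layers=12)

tokens = torch.randn(1, 197, 768)  # (batch, class token + patches, D)
print(stacked_encoder(tokens).shape)  # torch.Size([1, 197, 768])
print(len(stacked_encoder.layers))    # 12
```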

In [None]:
encoder_pytorch = nn.TransformerEncoderLayer(d_model=768,
                                             nhead=12,
                                             dim_feedforward=3072,
                                             dropout=0.1,
                                             activation="gelu",
                                             batch_first=True,
                                             norm_first=True)

print(encoder_pytorch)
print(encoder_pytorch(patch_and_position_embedding_with_tokens).shape)

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (linear1): Linear(in_features=768, out_features=3072, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=3072, out_features=768, bias=True)
  (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)
torch.Size([1, 197, 768])


In [None]:
summary(model=encoder_pytorch,
        input_size=(1, 197, 768),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
TransformerEncoderLayer (TransformerEncoderLayer)  [1, 197, 768]        [1, 197, 768]        7,087,872            True
Total params: 7,087,872
Trainable params: 7,087,872
Non-trainable params: 0
Total mult-adds (M): 0
Input size (MB): 0.61
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.61
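The 7,087,872 parameters in the summary can be checked by hand, and the same breakdown applies to our own `TransformerEncoder` block. Each pre-norm encoder layer contains the packed query/key/value projection ($3D^2 + 3D$), the attention output projection ($D^2 + D$), two MLP linear layers ($D \cdot 3072 + 3072$ and $3072 \cdot D + D$) and two LayerNorms ($2D$ each). A quick sketch of that arithmetic against PyTorch's own count:

```python
from torch import nn

D, mlp_size = 768, 3072
vit_layer = nn.TransformerEncoderLayer(d_model=D, nhead=12, dim_feedforward=mlp_size,
                                       activation="gelu", batch_first=True, norm_first=True)

attn_params = (3 * D * D + 3 * D) + (D * D + D)              # packed Q/K/V + output projection
mlp_params = (D * mlp_size + mlp_size) + (mlp_size * D + D)  # Linear(D, 3072) + Linear(3072, D)
norm_params = 2 * (2 * D)                                    # two LayerNorms, each with weight and bias
expected = attn_params + mlp_params + norm_params

actual = sum(p.numel() for p in vit_layer.parameters())
print(expected, actual)  # 7087872 7087872
```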

### Replicate the full ViT architecture

1. Create a class called `ViT` that inherits from `torch.nn.Module`.
2. Initialize the class with hyperparameters from Table 1 and Table 3 of the ViT paper for the ViT-Base model.
3. Make sure the image size is divisible by the patch size (the image should be split into even patches).
4. Calculate the number of patches using the formula $N = HW/P^2$, where $H$ is the image height, $W$ is the image width and $P$ is the patch size.
5. Create a learnable class embedding token (equation 1).
6. Create a learnable position embedding vector (equation 1).
7. Set up the embedding dropout layer.
   * *Dropout, when used, is applied after every dense layer except for the qkv-projections and directly after adding positional- to patch embeddings.*
8. Create the patch embedding layer using the `PatchEmbedding` class.
9. Create a series of Transformer Encoder blocks by passing a list of `TransformerEncoder` blocks to `torch.nn.Sequential()` (equations 2 & 3).
10. Create the MLP head (i.e. the classifier head, equation 4) by passing a `torch.nn.LayerNorm()` (LN) layer and a `torch.nn.Linear(out_features=num_classes)` layer (where `num_classes` is the target number of classes) to `torch.nn.Sequential()`.
11. Create a `forward()` method that accepts an input.
12. Get the batch size of the input (the first dimension of the shape).
13. Create the patch embedding using the layer created in step 8 (equation 1).
14. Expand the class token embedding created in step 5 across the batch size found in step 12 using `torch.Tensor.expand()` (equation 1).
15. Concatenate the class token embedding from step 14 to the patch embedding from step 13 along the sequence dimension (`dim=1`) using `torch.cat()` (equation 1).
16. Add the position embedding created in step 6 and apply the embedding dropout from step 7 (equation 1).
17. Pass the embeddings through the Transformer Encoder blocks from step 9 (equations 2 & 3).
18. Pass the class token output (index 0 of the sequence dimension) through the MLP head from step 10 (equation 4).

In [None]:
class ViT(nn.Module):
  """Vision Transformer (ViT-Base defaults) built from the blocks above."""
  def __init__(self,
               height: int=224,
               width: int=224,
               color_channels: int=3,
               patch_size: int=16,
               num_layers: int=12,
               embedding_dim: int=768,
               mlp_size: int=3072,
               num_heads: int=12,
               dropout: float=0.1,
               num_classes: int=1000):
    super().__init__()

    # Step 3: the image must split into a whole number of patches
    assert height % patch_size == 0 and width % patch_size == 0, \
      f"Image size ({height}x{width}) must be divisible by patch size {patch_size}"

    # Step 4: N = HW/P^2
    self.num_patches = (height * width) // patch_size**2

    # Steps 5-6: one class token and one position embedding shared by every image in the batch
    self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                        requires_grad=True)
    self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),
                                           requires_grad=True)
    self.dropout = nn.Dropout(p=dropout)

    self.embedded_patches = PatchEmbedding(in_channels=color_channels,
                                           patch_size=patch_size,
                                           embedding_dim=embedding_dim)

    self.transformer_encoder = nn.Sequential(*[TransformerEncoder(embedding_dim=embedding_dim,
                                                                  mlp_size=mlp_size,
                                                                  num_heads=num_heads,
                                                                  attn_dropout=0.0,
                                                                  mlp_dropout=dropout) for _ in range(num_layers)])

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(normalized_shape=embedding_dim),
        nn.Linear(in_features=embedding_dim,
                  out_features=num_classes)
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    batch_size = x.shape[0]

    xp = self.embedded_patches(x) # (B, N, D)

    # Expand the (1, 1, D) class token to (B, 1, D) and prepend it to the patch sequence
    class_token = self.class_embedding.expand(batch_size, -1, -1)
    xp = torch.cat((class_token, xp), dim=1) # (B, N+1, D)

    # The (1, N+1, D) position embedding broadcasts across the batch
    xp = xp + self.position_embedding

    xp = self.dropout(xp)

    out = self.transformer_encoder(xp)

    # Classify using the encoder output for the class token (index 0)
    return self.mlp_head(out[:, 0, :])



In [None]:
model = ViT()

In [None]:
out = model(image.unsqueeze(0))

In [None]:
out.shape

torch.Size([1, 1000])

## Setup Optimizer


The ViT paper trains all models, including ResNets, using Adam with $\beta_{1}=0.9, \beta_{2}=0.999$, a batch size of 4096 and a high weight decay of 0.1, which was found to be useful for transfer of all models.

In [None]:
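# Note: torch.optim.Adam applies weight_decay as an L2 penalty added to the gradient;
# torch.optim.AdamW implements decoupled weight decay instead.
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=0.003,
                             betas=(0.9, 0.999),
                             weight_decay=0.1)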
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.003
    maximize: False
    weight_decay: 0.1
)

## Setup Loss

Searching the ViT paper for "loss" or "loss function" or "criterion" returns no results.

Since the target problem is multi-class classification, we'll use `torch.nn.CrossEntropyLoss()`.

In [None]:
loss_fn = nn.CrossEntropyLoss()
loss_fn

CrossEntropyLoss()
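A short sketch of the inputs `nn.CrossEntropyLoss()` expects, assuming a batch of 4 images and the ViT's 1000-class head: raw logits of shape `(batch, num_classes)` (no softmax applied beforehand) and integer class indices of shape `(batch,)`. The targets below are made up; the loss equals the mean negative log-probability of each correct class.

```python
import torch
from torch import nn

torch.manual_seed(0)
ce_loss = nn.CrossEntropyLoss()
logits = torch.randn(4, 1000)            # raw scores, e.g. from a head with num_classes=1000
targets = torch.tensor([3, 17, 999, 0])  # made-up class indices, shape (batch,)

loss = ce_loss(logits, targets)

# Same value computed by hand: -log softmax(logits)[correct class], averaged over the batch
manual_loss = -torch.log_softmax(logits, dim=1)[torch.arange(4), targets].mean()
print(loss.item(), torch.allclose(loss, manual_loss))
```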