In [None]:
import torch
import torch.nn as nn
from torchinfo import summary

In [24]:
# This Python class is used to segment images into patches and convert them into embeddings.
class ImagePatchEmbedding(nn.Module):

    # The constructor initializes the class instance with image size, patch size, image channels, and the embedding dimension.
    def __init__(self, image_size, patch_size, image_channels, embedding_dimension=768):
        
        # Invoking the parent class's constructor
        super().__init__()

        # Assigning input parameters to class instance variables
        self.image_size = image_size
        self.patch_size = patch_size
        self.number_of_patches = (image_size // patch_size) ** 2 # Here we calculate the total number of patches
        
        # We use a convolutional layer to segment our image into patches.
        # The convolutional layer will have a kernel size and stride equal to the patch size.
        self.conv_layer_to_segment_image = nn.Conv2d(in_channels=image_channels, out_channels=embedding_dimension, kernel_size=patch_size, stride=patch_size)
        
        # A layer to transform the 2D patch feature maps into 1D vectors
        self.flatten_layer = nn.Flatten(start_dim=2, end_dim=3)
        
    # This method processes the image through our convolutional layer to generate the patches and convert them into embeddings.
    def forward(self, input_image):
        
        # Check that the size of the input image is divisible by the patch size
        image_resolution = input_image.shape[-1]
        assert image_resolution % self.patch_size == 0, "Input image size must be divisible by the patch size."
        
        # Segmenting the image into patches
        patches = self.conv_layer_to_segment_image(input_image)
        
        # Flatten the patches into 1D vectors
        flattened_patches = self.flatten_layer(patches)
        
        # Reorder the dimensions to have the batch size, number of patches and embedding dimension.
        reshaped_patches = flattened_patches.permute(0, 2, 1)
        
        return reshaped_patches




In [31]:
rand_image_tensor = torch.randn(32, 3, 224, 224)  # (batch_size, color_channels, height, width)
rand_image_tensor.shape

patch_embedding = ImagePatchEmbedding(224, 16, 3)
patch_embedding_output = patch_embedding.forward(rand_image_tensor)
print(f"Input shape: {rand_image_tensor.shape}")
print(f"Output shape: {patch_embedding_output.shape} -> (batch_size, num_patches, embedding_dim)")

# To summarize, you have 32 batches, each batch contains 196 patches, and each patch has a dimension of 764.


Input shape: torch.Size([32, 3, 224, 224])
Output shape: torch.Size([32, 196, 768]) -> (batch_size, num_patches, embedding_dim)


In [26]:

transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=768,
                                                       nhead=12,
                                                       dim_feedforward=3072,
                                                       dropout=0.1,
                                                       activation="gelu",
                                                       batch_first=True,
                                                       norm_first=True)
transformer_encoder_layer



TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (linear1): Linear(in_features=768, out_features=3072, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=3072, out_features=768, bias=True)
  (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)

In [27]:
# summary(model=transformer_encoder_layer, input_size=patch_embedding_output.shape)


In [28]:
# now we want to stack 12 transformer encoders layers to make up the entire transformer encoder block

# an encoder says "I'll turn your data into a numerical representation to try and find a pattern in that data"
# a decoder says "I'll take your learnable numerical representation and turn it back to human understandable"

transformer_encoder = nn.TransformerEncoder(
    encoder_layer=transformer_encoder_layer,
    num_layers=12
)


summary(model=transformer_encoder, input_size=patch_embedding_output.shape)

Layer (type:depth-idx)                   Output Shape              Param #
TransformerEncoder                       [32, 196, 768]            --
├─ModuleList: 1-1                        --                        --
│    └─TransformerEncoderLayer: 2-1      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-2      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-3      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-4      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-5      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-6      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-7      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-8      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-9      [32, 196, 768]            7,087,872
│    └─TransformerEncoderLayer: 2-10     [32, 196, 768]            7,087,872
│    └─Transfor

In [29]:
# Put it all together and create ViT
class VisionTransformer(nn.Module): # we want to subclass the VisionTransformer class
    def __init__(self, 
                 num_classes,
                 image_size=224, 
                 image_channels=3, 
                 patch_size=16, 
                 embed_dimension=768, 
                 dropout=0.1, 
                 mlp_size=3072, 
                 num_transformer_layers=12, 
                 num_heads=12, # number of multi-head self attention heads
                 ): # 1,000 for the the labels say imageNet
        super().__init__() # initialize the parent classes nn.Module's contructor before adding any intialization logic to VisionTransformer class
    
        assert image_size % patch_size == 0, f"image_size must divisible by patch_size"
    
    
    
        # 1. Create patch embedding
        self.patch_embedding = ImagePatchEmbedding(image_size=image_size, patch_size=patch_size, image_channels=image_channels) # remember, we create embeddings to form a representation of our data so it can be learned and updated over time
        
        # 2. Create class token for a single image
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dimension), requires_grad=True) # requires_grad means that the parameter is "learnable"
        
        # 3. Create positional embedding ("this is the order that the patches come in" -> keeps track of where the patches are positioned in a sequence because the spatial information is lost when it's flattened into a sequence)
        num_patches = (image_size * image_size) // patch_size**2 # N = HW / P^2
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dimension), requires_grad=True) # create positional embed for each patch in image and also the class token
        
        # 4. Create patch + position embedding dropout
        self.embed_dropout = nn.Dropout(p=dropout)
        
        # 5. Create transformer encoder layer
        self.transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dimension, 
                                                                    nhead=num_heads, 
                                                                    dim_feedforward=mlp_size, 
                                                                    activation="gelu", 
                                                                    batch_first=True, 
                                                                    norm_first=True)
        
        
        # 6. Create stack transformer encoder layers
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer=self.transformer_encoder_layer, num_layers=num_transformer_layers  )
        
        
        # 7. Create MLP head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dimension),
            nn.Linear(in_features=embed_dimension, out_features=num_classes)
        )
        
        
    def forward(self, x): # "Hey, what do you want me to do?" If I pass you some data in the form of 'x', I want you to take these steps
        # dims from x
        batch_size = x.shape[0]

        # create patch embedding
        x = self.patch_embedding(x)

        # Expand the class token across the batch dimension
        class_token = self.class_token.expand(batch_size, -1, -1) # copy class token 32 times or the size of the batch and then use -1, to infer the rest of the dimensions
        print(f"Shape of image with class token {x.shape}, it's 197 now instead of 196")


        # Prepend the class token to the patch embeddings
        x = torch.cat((class_token, x), dim=1) # Why dim1, we want to prepend it to our patches dim (batch_size, patches, embed_dim) from the patch_embedding class

        # Add positional embeddings to class token and patch embeddings
        x = self.positional_embedding + x
        print(f"pos + (class token, patch emd) {x.shape}")
        
        
        # Dropout on patch + positional embeddings
        x = self.embed_dropout(x)
        
        # Pass embedding through transformer encoder stack
        x = self.transformer_encoder(x)
        
        # Pass 0th index of x through MLP head, why? Only pass class token to the MLP head for classification
        x = self.mlp_head(x[:, 0]) # : means all of the batches, but only the 0th dim
        
        
        return x




In [30]:
class_names = ["apple", "car", "dog"]

demo_images = torch.randn(32, 3, 224, 224)
vit = VisionTransformer(num_classes=len(class_names))
vit(demo_images)


'''
Input Image Shape: torch.Size([32, 3, 224, 224]) | 32 = Batch Size | 3 = Image Channels | 224 = Image Height | 224 = Image Width
Patched Shape: torch.Size([32, 768, 14, 14]) | 32 = Batch Size | 768 = Embed Dimension | 14 = Patch Height | 14 = Patch Width
Flattened Shape: torch.Size([32, 768, 196]) | 32 = Batch Size | 768 = Embed Dimension | 196 = Flattened Patches
Permuted Shape: torch.Size([32, 196, 768]) | 32 = Batch Size | 196 = Flattened Patches | 768 = Embed Dimension
Shape of image with class token torch.Size([32, 196, 768]), it's 197 now instead of 196
pos + (class token, patch emd) torch.Size([32, 197, 768])
tensor([[ 0.4577,  0.4281, -0.0289],
        [ 0.3021,  0.6390, -0.2092],
        [ 0.3056,  0.6422,  0.0428],
        [ 0.2205,  0.6570, -0.1065],
        [ 0.0620,  0.9290, -0.3605],
        [ 0.3369,  0.7926, -0.0963],
        [ 0.2043,  0.4202,  0.2618],
        [ 0.0527,  0.5007, -0.1787],
        [ 0.0195,  0.6088, -0.0205],
        [ 0.0513,  0.6730, -0.1341],
        [ 0.1731,  0.7196, -0.0676],
        [ 0.0735,  0.5199, -0.1979],
        [ 0.2313,  0.3570, -0.0364],
        [ 0.2900,  0.6627, -0.1341],
        [ 0.0354,  0.8942, -0.2055],
        [ 0.1819,  0.6816, -0.1065],
        [ 0.0393,  0.4170, -0.1408],
        [-0.0083,  0.5982, -0.0175],
        [ 0.1635,  0.6088, -0.0194],
        [ 0.1536,  0.0958, -0.0981],
        [ 0.1868,  0.1048,  0.0895],
        [-0.1212,  0.5534,  0.1458],
        [ 0.4610,  0.6150, -0.1061],
        [ 0.4095,  0.6661, -0.1383],
        [ 0.3896,  0.7688,  0.0430],
        [ 0.0205,  0.5659,  0.0815],
        [ 0.3623,  0.5507, -0.1554],
        [ 0.4797,  0.7323, -0.0357],
        [ 0.2438,  0.4421, -0.0457],
        [ 0.3576,  0.2898, -0.0273],
        [-0.0353,  0.5208, -0.2585],
        [ 0.1440,  0.7174,  0.1166]], grad_fn=<AddmmBackward0>)
'''

Shape of image with class token torch.Size([32, 196, 768]), it's 197 now instead of 196
pos + (class token, patch emd) torch.Size([32, 197, 768])


"\nInput Image Shape: torch.Size([32, 3, 224, 224]) | 32 = Batch Size | 3 = Image Channels | 224 = Image Height | 224 = Image Width\nPatched Shape: torch.Size([32, 768, 14, 14]) | 32 = Batch Size | 768 = Embed Dimension | 14 = Patch Height | 14 = Patch Width\nFlattened Shape: torch.Size([32, 768, 196]) | 32 = Batch Size | 768 = Embed Dimension | 196 = Flattened Patches\nPermuted Shape: torch.Size([32, 196, 768]) | 32 = Batch Size | 196 = Flattened Patches | 768 = Embed Dimension\nShape of image with class token torch.Size([32, 196, 768]), it's 197 now instead of 196\npos + (class token, patch emd) torch.Size([32, 197, 768])\ntensor([[ 0.4577,  0.4281, -0.0289],\n        [ 0.3021,  0.6390, -0.2092],\n        [ 0.3056,  0.6422,  0.0428],\n        [ 0.2205,  0.6570, -0.1065],\n        [ 0.0620,  0.9290, -0.3605],\n        [ 0.3369,  0.7926, -0.0963],\n        [ 0.2043,  0.4202,  0.2618],\n        [ 0.0527,  0.5007, -0.1787],\n        [ 0.0195,  0.6088, -0.0205],\n        [ 0.0513,  0.6730

In [None]:
summary(model=vit, input_size=demo_image.shape)

In [6]:
import torch
import torch.nn as nn
from torchinfo import summary







tensor([[ 0.3584, -0.3832,  0.3969],
        [ 0.1811, -0.7862,  0.4041],
        [ 0.2111, -0.3494,  0.1924],
        [ 0.3677, -0.4780,  0.2955],
        [ 0.1036, -0.3134,  0.0314],
        [ 0.6630, -0.4446,  0.5569],
        [ 0.3351, -0.5870,  0.5440],
        [ 0.1930, -0.4504,  0.4822],
        [ 0.0944, -0.7169,  0.0551],
        [ 0.1774, -0.7897,  0.2435],
        [ 0.1660, -0.4201,  0.0363],
        [ 0.2953, -0.3338,  0.3320],
        [ 0.4666, -0.3435,  0.2051],
        [ 0.2962, -0.4594,  0.2008],
        [ 0.2610, -0.5089,  0.4607],
        [ 0.2954, -0.5357,  0.3910],
        [ 0.4862, -0.1919,  0.0420],
        [ 0.2963, -0.4240,  0.2787],
        [ 0.2013, -0.4501,  0.5937],
        [ 0.3435, -0.2963,  0.4042],
        [ 0.4576, -0.4095,  0.1598],
        [ 0.5392, -0.5060,  0.2855],
        [ 0.1561, -0.3194,  0.3333],
        [ 0.1777, -0.4671,  0.2991],
        [ 0.2925, -0.2399,  0.0908],
        [ 0.3274, -0.4938,  0.3047],
        [ 0.2222, -0.5913,  0.1426],
 

In [None]:

# This is a helper class for creating patches from the input image.
class ImagePatcher(nn.Module):
    def __init__(self, image_channels, patch_size, embedding_dimension):
        super().__init__()
        # This convolutional layer helps in breaking the image into patches and transforming each patch into embeddings.
        self.conv = nn.Conv2d(image_channels, embedding_dimension, kernel_size=patch_size, stride=patch_size)
        # Flatten the 2D patches into 1D vectors.
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    # Forward propagation for the image patcher. 
    def forward(self, input_image):
        # The image is passed through the convolution layer to create patches and then flattened.
        # The output tensor is then permuted to get the right shape for the transformer.
        return self.flatten(self.conv(input_image)).permute(0, 2, 1)


In [None]:

# This is the main Vision Transformer class.
class VisionTransformer(nn.Module): 
    def __init__(self, 
                 num_classes,
                 image_size=224, 
                 image_channels=3, 
                 patch_size=16, 
                 embed_dimension=768, 
                 dropout=0.1, 
                 mlp_size=3072, 
                 num_transformer_layers=12, 
                 num_heads=12,
                 ):
        super().__init__()
        # Checking if image_size is divisible by patch_size.
        assert image_size % patch_size == 0, f"image_size must be divisible by patch_size"

        # Initialize the ImagePatcher module.
        self.patch_embedding = ImagePatcher(image_channels, patch_size, embed_dimension)
        # Initialize the class token which is used as the first token in the sequence.
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dimension), requires_grad=True)
        # Calculate the number of patches.
        num_patches = (image_size // patch_size) ** 2
        # Initialize the positional embeddings for each patch and the class token.
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dimension), requires_grad=True)
        # Initialize the dropout layer.
        self.dropout = nn.Dropout(dropout)
        # Initialize the transformer encoder with the specified dimensions and layers.
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(embed_dimension, num_heads, mlp_size, batch_first=True, norm_first=True), num_transformer_layers)
        # Initialize the final classification head which consists of layer normalization and a linear layer.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dimension),
            nn.Linear(embed_dimension, num_classes)
        )
        
    def forward(self, x):
        # Determine the batch size.
        batch_size = x.shape[0]
        # Create patches from the input image and convert them into embeddings.
        x = self.patch_embedding(x)
        # Add the class token to the beginning of each sequence.
        x = torch.cat((self.class_token.expand(batch_size, -1, -1), x), dim=1)
        # Add the positional embeddings to the patch embeddings.
        x += self.positional_embedding
        # Apply dropout.
        x = self.dropout(x)
        # Pass the sequence through the transformer.
        x = self.transformer(x)
        # Pass the first token (class token) through the final classification head to get the output probabilities for each class.
        return self.mlp_head(x[:, 0])