In [2]:
import torch
from torch import nn
from torch.nn import functional as F

  cpu = _conversion_method_template(device=torch.device("cpu"))


## Converting Image to a Sequence of Patches

In [3]:
class PatchEmbeddings(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A single strided convolution whose kernel size and stride both equal
    ``patch_size`` maps every non-overlapping ``patch_size x patch_size`` tile
    of the image to one ``hidden_dim``-sized vector.

    Args:
        img_size: Side length of the (square) input image; only used to derive
            ``num_patches``.
        patch_size: Side length of each square patch.
        hidden_dim: Size of the embedding produced for each patch.
        in_channels: Number of channels of the input image. Defaults to 3
            (RGB), which matches the previously hard-coded value.
    """

    def __init__(
        self,
        img_size: int = 96,
        patch_size: int = 16,
        hidden_dim: int = 512,
        in_channels: int = 3,
    ) -> None:
        super().__init__()
        # Store the input image size, the patch size, the hidden dimension
        # and the number of input channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.in_channels = in_channels

        # Calculate the total number of patches (patches per side, squared)
        self.num_patches = (self.img_size // self.patch_size) ** 2

        # Create a convolution to extract patch embeddings
        # in_channels is configurable (3 by default, i.e. an RGB image)
        # out_channels=hidden_dim sets the number of output channels to match the hidden dimension
        # kernel_size=patch_size and stride=patch_size ensure each patch is embedded separately
        self.conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.hidden_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Embed the patches of ``X``.

        Args:
            X: Image batch of shape (batch_size, in_channels, height, width).

        Returns:
            Tensor of shape (batch_size, num_patches, hidden_dim).
        """
        # Extract patch embeddings from the input image
        # Output shape: (batch_size, hidden_dim, height // patch_size, width // patch_size)
        X = self.conv(X)

        # Flatten the two spatial dimensions into a single patch dimension
        # Output shape: (batch_size, hidden_dim, num_patches)
        X = X.flatten(2)

        # Move the patch dimension to the second position
        # Output shape: (batch_size, num_patches, hidden_dim)
        X = X.transpose(1, 2)

        return X

In [4]:
# Hyperparameters for the patch-embedding smoke test
patch_size = 16
hidden_dim = 512

# A random batch standing in for real images
B, C, H, W = 1, 3, 96, 96  # Batch size, Channels, Height, Width
X = torch.randn(size=(B, C, H, W))

# Build the embedding layer, then push the batch through it
patch_embeddings = PatchEmbeddings(img_size=H, patch_size=patch_size, hidden_dim=hidden_dim)
patches = patch_embeddings(X)
print(f"Shape of image patches: {patches.shape}")

Shape of image patches: torch.Size([1, 36, 512])


In [5]:
# Each side of the H x H image yields H // patch_size patches, so the grid holds that count squared.
num_patches = (H // patch_size) ** 2
# Sanity check: the embedded patches must be laid out as (batch, num_patches, hidden_dim).
assert patches.shape == (B, num_patches, hidden_dim), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Attention Mechanism 
The attention mechanism implemented below is shared by both the vision encoder and the language decoder.

### The implementation of the Attention Head

In [6]:
class Head(nn.Module):
    """A single head of scaled dot-product self-attention.

    Args:
        n_embed: Size of the last dimension of the input tensor.
        head_size: Size of the key/query/value projections (and of the output).
        dropout: Dropout probability applied to the attention probabilities.
        is_decoder: When True, a causal mask stops each position from
            attending to later positions.
    """

    def __init__(
        self,
        n_embed: int,
        head_size: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Linear layer for Key projection
        self.key = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Linear layer for Query projection
        self.query = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Linear layer for Value projection
        self.value = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Dropout layer for regularization to prevent overfitting
        self.dropout = nn.Dropout(p=dropout)

        # Flag indicating whether the head is used as a decoder
        self.is_decoder = is_decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, T, n_embed).

        Returns:
            Tensor of shape (B, T, head_size).
        """
        # Unpacking three values keeps the requirement that the input is 3-D: (B, T, n_embed)
        _, T, _ = x.shape

        # Compute Key, Query, and Value projections
        k = self.key(x)  # Shape: (B, T, head_size)
        q = self.query(x)  # Shape: (B, T, head_size)
        v = self.value(x)  # Shape: (B, T, head_size)

        # Scale by 1 / sqrt(head_size): the dot products are taken over the
        # head_size-dimensional projections, so that is the dimension that
        # controls their variance (not the input embedding width).
        head_size = k.shape[-1]
        wei = q @ k.transpose(-2, -1) * (head_size**-0.5)  # Shape: (B, T, T)

        if self.is_decoder:
            # Causal mask: True strictly above the diagonal marks "future"
            # positions, whose scores are set to -inf so softmax gives them 0.
            future = torch.triu(
                torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
            )
            wei = wei.masked_fill(mask=future, value=float("-inf"))

        # Apply softmax to the attention scores to obtain attention probabilities
        # Each row sums to 1
        wei = F.softmax(input=wei, dim=-1)  # Shape: (B, T, T)

        # Apply Dropout to the attention probabilities for regularization
        wei = self.dropout(wei)

        # Perform a weighted aggregation of values using the attention probabilities
        out = wei @ v  # Shape: (B, T, head_size)

        return out

In [7]:
# Unpack the patch-embedding shape. This rebinds B and C from the image cell
# above: C now holds the embedding dimension, not the image channel count.
B, T, C = patches.shape  # Batch size, Sequence length, Embedding dimension
head_size = 16  # Size of the attention head

# is_decoder is left at its default (False), so no causal mask is applied.
head = Head(n_embed=C, head_size=head_size)
# .eval() is never called, so the head's dropout is active for this forward pass.
output = head(patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 16])


In [8]:
# A single head keeps (batch, sequence) and projects the last dimension down to head_size.
assert output.shape == (B, T, head_size), "Output shape is incorrect"
print("Test passed!")

Test passed!


### The implementation of Multihead Attention

In [9]:
class MultiHeadAttention(nn.Module):
    """Run several attention heads side by side and merge their outputs.

    Every head projects the input to ``n_embed // num_heads`` features; the
    per-head results are concatenated back to ``n_embed`` features, passed
    through a linear projection and then through dropout.

    Args:
        n_embed: Size of the last dimension of the input (and of the output).
        num_heads: Number of heads; must divide ``n_embed`` evenly.
        dropout: Dropout probability, used inside each head and on the output.
        is_decoder: Forwarded to every head (enables its causal mask).
    """

    def __init__(
        self,
        n_embed: int,
        num_heads: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # The heads split n_embed evenly between them, so it must divide cleanly
        assert n_embed % num_heads == 0, "n_embed must be divisible by num_heads!"
        per_head_dim = n_embed // num_heads

        # Build the heads one at a time, then register them as a ModuleList
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                Head(
                    n_embed=n_embed,
                    head_size=per_head_dim,
                    dropout=dropout,
                    is_decoder=is_decoder,
                )
            )
        self.heads = nn.ModuleList(modules=head_modules)

        # Output projection applied to the concatenated heads
        self.proj = nn.Linear(in_features=n_embed, out_features=n_embed)

        # Regularization on the projected output
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the multi-head attention output for ``x`` of shape (B, T, n_embed)."""
        # One (B, T, head_size) tensor per head, joined on the feature axis -> (B, T, n_embed)
        merged = torch.cat(tensors=[head(x) for head in self.heads], dim=-1)

        # Project, then regularize; the shape stays (B, T, n_embed)
        return self.dropout(self.proj(merged))

In [10]:
# Two heads over the C-dimensional patch embeddings, i.e. C // 2 features per head.
num_heads = 2
dropout = 0.1
# is_decoder is left at its default (False), so the heads apply no causal mask.
mha = MultiHeadAttention(n_embed=C, num_heads=num_heads, dropout=dropout)

In [11]:
# Overwrites `output` from the single-head cell with the multi-head result.
output = mha(patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [12]:
# Multi-head attention must return the same shape as its input: (B, T, C).
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


### The Multilayer Perceptron

In [13]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, activate, project back, dropout.

    Args:
        n_embed: Size of the input and output features.
        dropout: Dropout probability applied after the second linear layer.
        is_decoder: Selects the activation: ReLU when True, GELU otherwise.
    """

    def __init__(
        self, n_embed: int, dropout: float = 0.1, is_decoder: bool = False
    ) -> None:
        super().__init__()

        # Width of the intermediate (expanded) representation
        inner_dim = 4 * n_embed

        # n_embed -> 4 * n_embed
        expand = nn.Linear(in_features=n_embed, out_features=inner_dim)

        # ReLU for the decoder variant, GELU otherwise
        if is_decoder:
            activation = nn.ReLU()
        else:
            activation = nn.GELU()

        # 4 * n_embed -> n_embed
        contract = nn.Linear(in_features=inner_dim, out_features=n_embed)

        # Chain the pieces; dropout comes last for regularization
        self.net = nn.Sequential(expand, activation, contract, nn.Dropout(p=dropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward stack to ``x``; the last dimension stays ``n_embed``."""
        return self.net(x)

In [14]:
# `dropout` already holds 0.1 from the multi-head attention cell; it is set again here
# so this cell shows the value it uses.
dropout = 0.1
# is_decoder is left at its default (False), which selects the GELU activation.
mlp = MLP(n_embed=C, dropout=dropout)

In [15]:
# NOTE(review): this cell reads and overwrites `output`, so re-running it feeds the MLP
# its own previous result instead of the attention output. The shape is unchanged either way.
output = mlp(output)  # Previous output of the Multihead Attention
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [16]:
# The MLP projects back to n_embed, so the shape must still be (B, T, C).
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


### Transformer Blocks

In [17]:
class Block(nn.Module):
    """Pre-norm transformer block: self-attention and a feed-forward network,
    each applied to a layer-normalized input and added back to the residual stream.

    Args:
        n_embed: Size of the last dimension of the input (and of the output).
        num_heads: Number of attention heads.
        dropout: Dropout probability forwarded to the attention module. The
            feed-forward network defined here contains no dropout layer.
        is_decoder: Forwarded to the attention module (enables causal masking).
    """

    def __init__(
        self,
        n_embed: int,
        num_heads: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Layer normalization for the input to the attention layer
        self.ln1 = nn.LayerNorm(normalized_shape=n_embed)

        # Multi-head attention module
        self.mhattn = MultiHeadAttention(
            n_embed=n_embed, num_heads=num_heads, dropout=dropout, is_decoder=is_decoder
        )

        # Layer normalization for the input to the FFN
        self.ln2 = nn.LayerNorm(normalized_shape=n_embed)

        # Feed-forward neural network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=n_embed, out_features=4 * n_embed),
            nn.GELU(),  # Activation function
            nn.Linear(
                in_features=4 * n_embed, out_features=n_embed
            ),  # Projection back to the original dimension
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape (B, T, n_embed); the shape is preserved."""
        # Attention sub-layer: normalize, attend, then add the result back onto
        # the un-normalized input (residual connection)
        x = x + self.mhattn(self.ln1(x))

        # FFN sub-layer: same pattern. The residual must be the un-normalized
        # stream; adding onto ln2(x) would discard the skip path from the input.
        x = x + self.ffn(self.ln2(x))

        return x

In [18]:
# Same settings as the multi-head attention cell, repeated so this cell shows the values it uses.
num_heads = 2
dropout = 0.1
# is_decoder is left at its default (False), so attention inside the block is unmasked.
block = Block(n_embed=C, num_heads=num_heads, dropout=dropout)

In [19]:
# Run the block on the patch embeddings; this overwrites `output` from the MLP cell.
output = block(patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [20]:
# A transformer block preserves its input shape: (B, T, C).
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Vision Encoder - Vision Transformer (ViT)

Combining the patchification logic and the attention blocks into a ViT.

In [21]:
class ViT(nn.Module):
    """Vision Transformer encoder that returns the normalized `[CLS]` representation.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        num_hiddens: Embedding size used throughout the model.
        num_heads: Number of attention heads in every block.
        num_blocks: Number of transformer blocks.
        emb_dropout: Dropout probability applied to the embeddings.
        block_dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        num_hiddens: int,
        num_heads: int,
        num_blocks: int,
        emb_dropout: float,
        block_dropout: float,
    ) -> None:
        super().__init__()

        # Image -> sequence of patch embeddings
        self.patch_embedding = PatchEmbeddings(
            img_size=img_size, patch_size=patch_size, hidden_dim=num_hiddens
        )

        # Learnable classification token, initialized to zeros
        self.cls_token = nn.Parameter(data=torch.zeros(size=(1, 1, num_hiddens)))

        # One position per patch, plus one for the classification token
        seq_len = (img_size // patch_size) ** 2 + 1

        # Learnable position embedding, initialized from a standard normal
        self.pos_embedding = nn.Parameter(
            data=torch.randn(size=(1, seq_len, num_hiddens))
        )

        # Dropout applied to the embeddings
        self.dropout = nn.Dropout(p=emb_dropout)

        # Encoder blocks (no causal masking), built one at a time
        encoder_blocks = []
        for _ in range(num_blocks):
            encoder_blocks.append(
                Block(
                    n_embed=num_hiddens,
                    num_heads=num_heads,
                    dropout=block_dropout,
                    is_decoder=False,
                )
            )
        self.blocks = nn.ModuleList(encoder_blocks)

        # Final layer normalization
        self.layer_norm = nn.LayerNorm(normalized_shape=num_hiddens)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Encode the image batch ``X`` and return a (B, num_hiddens) tensor."""
        # (B, num_patches, num_hiddens)
        patch_tokens = self.patch_embedding(X)

        # Broadcast the single classification token across the batch: (B, 1, num_hiddens)
        batch_size = patch_tokens.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # Prepend it to the patch sequence: (B, num_patches + 1, num_hiddens)
        x = torch.cat(tensors=(cls_tokens, patch_tokens), dim=1)

        # Add the position information in place, then apply embedding dropout
        x += self.pos_embedding
        x = self.dropout(x)

        # Run the sequence through every transformer block in order
        for encoder_block in self.blocks:
            x = encoder_block(x)

        # Keep only the `[CLS]` position and normalize it: (B, num_hiddens)
        cls_final = x[:, 0]
        return self.layer_norm(cls_final)

In [23]:
# Rebinds B and C (and X): C is the image channel count again, not the embedding
# dimension used by the attention cells above, so re-running those cells after this
# one would pick up C = 3.
B, C, H, W = 2, 3, 96, 96  # Batch size, Channels, Height, Width
X = torch.randn(B, C, H, W)
# A small ViT for the smoke test: 16x16 patches on a 96x96 image, i.e. (96 // 16) ** 2 = 36
# patches plus the classification token.
vit = ViT(
    img_size=H,
    patch_size=16,
    num_hiddens=64,
    num_heads=2,
    num_blocks=2,
    emb_dropout=0.1,
    block_dropout=0.1,
)

In [24]:
# The ViT returns only the normalized classification-token vector for each image.
output = vit(X)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([2, 64])


In [25]:
# 64 is the num_hiddens value passed to ViT in the setup cell above.
assert output.shape == (B, 64), "Output shape is incorrect"
print("Test passed!")

Test passed!
