### Coding exercise #1 

### Coding exercise #2 

### Coding exercise #3: Implementing Single-Head Attention Block

In [3]:
import torch
import torchvision
import torch.nn as nn
import math

In [6]:
from PIL import Image
import requests
from io import BytesIO

In [None]:
class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input,
    so every position attends over every position of its own sequence.
    """

    def __init__(self, embed_dim):
        """
        Initialize a single-head attention block.
        Args:
            embed_dim (int): The dimensionality of input embeddings
        """
        super().__init__()

        # Learnable projections producing queries, keys and values
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        # Scaling factor for dot product attention: 1/sqrt(embed_dim).
        # Keeps the dot products from growing with the embedding size,
        # which would otherwise push the softmax towards one-hot outputs.
        self.scaling = 1.0 / math.sqrt(embed_dim)

    def forward(self, x):
        """
        Compute single-head attention for the input.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_length, embed_dim)
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim)
        Raises:
            ValueError: If x is not a 3-dimensional tensor
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch_size, seq_length, embed_dim), got {tuple(x.shape)}"
            )

        # Project the input into query, key and value spaces
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Scaled dot-product scores: (batch, seq, dim) x (batch, dim, seq) -> (batch, seq, seq)
        attention_scores = torch.bmm(Q, K.transpose(1, 2)) * self.scaling

        # Normalize over the key axis so each query's weights sum to 1
        attention_weights = torch.softmax(attention_scores, dim=-1)

        # Weighted sum of the values: (batch, seq, seq) x (batch, seq, dim) -> (batch, seq, dim)
        output = torch.bmm(attention_weights, V)

        return output

In [None]:
# Smoke test for SingleHeadAttention: input and output shapes should match.
batch_size, seq_length, embed_dim = 2, 4, 8

# Random batch laid out as (batch_size, seq_length, embed_dim)
x = torch.randn(batch_size, seq_length, embed_dim)

# Build the attention block for this embedding width
attention = SingleHeadAttention(embed_dim=embed_dim)

# Run the block on the batch
output = attention(x)

# Report both shapes side by side
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

### Code-exercise #4.1 Implementing positional encodings on input embeddings

In [2]:
### Implement fixed sinusoidal positional encodings
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    The table is pre-computed once for ``max_len`` positions and stored as a
    non-trainable buffer of shape (1, max_len, embed_size); the leading axis
    broadcasts over the batch.
    """

    def __init__(self, embed_size, max_len=5000):
        """
        Initialize the positional encoding.
        
        Args:
            embed_size (int): Size of the embeddings
            max_len (int): Maximum sequence length to pre-compute
        """
        super(PositionalEncoding, self).__init__()
        
        # Encoding table of shape (max_len, embed_size)
        pe = torch.zeros(max_len, embed_size)
        
        # Column vector of positions 0..max_len-1, shape (max_len, 1)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        
        # One frequency per (even, odd) dimension pair
        div_term = 1 / (10000.0 ** (torch.arange(0, embed_size, 2).float() / embed_size))
        
        # Sine on even indices, cosine on odd indices.
        # For an odd embed_size there is one fewer odd column than even
        # columns, so the frequency vector is trimmed for the cosine half.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        
        # Register pe as a buffer (moves with the module, not trained),
        # with a leading axis of size 1 for the batch
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        """
        Add positional encoding to input tensor.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_length, embed_size)
        
        Returns:
            torch.Tensor: Input combined with positional encoding

        Raises:
            ValueError: If the sequence is longer than the pre-computed table
        """
        seq_length = x.size(1)
        if seq_length > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_length} exceeds max_len {self.pe.size(1)}"
            )
        
        # Slice the table to the sequence length and broadcast over the batch
        x = x + self.pe[:, :seq_length]
        
        return x

### Code-exercise #4.2 Implementing learnable positional encodings on input embeddings

In [None]:
### Implement learnable positional encodings
class LearnedPositionalEncoding(nn.Module):
    """Positional encoding whose per-position vectors are trained with the model."""

    def __init__(self, embed_size, max_len=5000):
        """
        Initialize the learnable positional encoding matrix.
        
        Args:
            embed_size (int): Dimension of each embedding vector.
            max_len (int): Maximum length of the sequence for positional encoding.
        """
        super(LearnedPositionalEncoding, self).__init__()
        
        # Trainable table of shape (max_len, embed_size); nn.Parameter makes
        # it show up in parameters() so the optimizer updates it.
        # Small random start so positions are distinguishable from step one
        # without dominating the token embeddings.
        self.positional_encoding = nn.Parameter(torch.randn(max_len, embed_size) * 0.02)
        
        
    def forward(self, x):
        """
        Adds the learnable positional encoding to the input tensor.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_length, embed_size).
        
        Returns:
            torch.Tensor: Input tensor with added learnable positional encoding.

        Raises:
            ValueError: If the sequence is longer than the learned table.
        """
        seq_length = x.size(1)
        if seq_length > self.positional_encoding.size(0):
            raise ValueError(
                f"sequence length {seq_length} exceeds max_len {self.positional_encoding.size(0)}"
            )
        
        # The (seq_length, embed_size) slice broadcasts over the batch axis
        x = x + self.positional_encoding[:seq_length]
        
        return x

##### Visualizing positional encoding

In [8]:
# Demo: run the sinusoidal encoder on random data, then plot its encoding table
batch_size, seq_length, embed_size = 32, 20, 512

# Random stand-in for a batch of token embeddings
x = torch.randn(batch_size, seq_length, embed_size)

# Build the encoder and push the batch through it
pos_encoder = PositionalEncoding(embed_size)
output = pos_encoder(x)

# The encoding is added to the input, so both shapes are printed for comparison
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

# Plotting is only needed from here on
import matplotlib.pyplot as plt

# Heat map of the table: first 100 positions (rows) by first 20 dimensions (columns)
heatmap_values = pos_encoder.pe[0, :100, :20]
plt.figure(figsize=(15, 15))
plt.imshow(heatmap_values)
plt.xlabel('Embedding Dimension')
plt.ylabel('Sequence Position')
plt.title('Positional Encodings')
plt.colorbar()
plt.show()

# Line plot of a few individual dimensions over the first 40 positions
plt.figure(figsize=(15, 5))
dims_to_plot = [0, 1, 4, 5]  # two even/odd pairs
for dim in dims_to_plot:
    parity = "even" if dim % 2 == 0 else "odd"
    curve = pos_encoder.pe[0, :40, dim].numpy()
    plt.plot(curve, label=f'dim {dim} ({parity})')
plt.legend()
plt.title('Positional Encoding Values')
plt.xlabel('Sequence Position')
plt.ylabel('Encoding Value')
plt.grid(True)
plt.show()

### Code review #5.1: Patch embeddings

In [None]:
# ref: https://github.com/FrancescoSaverioZuppichini/ViT

In [None]:
class _FlattenPatches(nn.Module):
    """Turn a [batch, emb, h, w] feature map into [batch, h * w, emb] patch tokens."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Merge the two spatial axes (row-major), then move the embedding axis last
        return x.flatten(2).transpose(1, 2)


class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    Args:
        in_channels: Number of channels in the input image.
        patch_size: Side length of each square patch; used as both the
            kernel size and the stride of the projecting convolution.
        emb_size: Length of the embedding vector produced per patch.
        img_size: Side length of the (square) input image. When omitted, the
            module returns the patch embeddings only. When given, a learnable
            [CLS] token is prepended at index 0 and learnable position
            embeddings are added, and inputs must be img_size x img_size.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768,
                 img_size: "int | None" = None):
        # nn.Module must be initialized before attributes are attached to it.
        super().__init__()
        # the size of each patch
        self.patch_size = patch_size

        # Define a sequential layer for projecting patches.
        self.projection = nn.Sequential(
            # Conv2d layer for projecting image patches into embeddings:
            # kernel_size and stride are both patch_size, so each output
            # location sees exactly one patch and patches never overlap.
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),

            # [batch, embedding_dim, h, w] -> [batch, num_patches, embedding_dim]
            _FlattenPatches(),
        )

        if img_size is None:
            # Patch-only mode: register the names so attribute access is
            # safe, but hold no parameters.
            self.register_parameter('cls_token', None)
            self.register_parameter('positions', None)
        else:
            num_patches = (img_size // patch_size) ** 2
            # One shared [CLS] vector, expanded over the batch in forward()
            self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
            # One position vector per token, including the [CLS] slot
            self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Image batch of shape [batch, in_channels, height, width].

        Returns:
            [batch, num_patches, emb_size] when built without img_size,
            otherwise [batch, num_patches + 1, emb_size] with the [CLS]
            token at index 0.
        """
        # Passes the input tensor through the projection layer to obtain patch embeddings.
        x = self.projection(x)
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls_tokens, x], dim=1) + self.positions
        return x


In [None]:
# A Conv2d layer with kernel_size = stride = patch_size moves across the input image in steps of patch_size,
# covering one patch_size x patch_size window at a time, just as traditional patch extraction would.
# Because the stride equals the kernel size, neighboring patches never overlap.

# The number of output channels is set to emb_size, so the layer emits one embedding vector per patch directly,
# combining patch extraction and the linear projection into a single step.
# This is more efficient than running separate patching and embedding operations.

### Code review #5.2: ViT architecture

In [None]:
class ViT(nn.Sequential):
    """Vision Transformer assembled as three sequential stages.

    Stages, in order: patch embedding, transformer encoder, classification
    head. Extra keyword arguments are forwarded to the encoder.
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        # Stages are built in the order they run.
        patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(patch_embedding, encoder, head)

In [None]:
class _SelectClsToken(nn.Module):
    """Pick the token at sequence index 0 from a [batch, tokens, emb] tensor."""

    def forward(self, x):
        return x[:, 0]


class ClassificationHead(nn.Sequential):
    """Map the first token of each sequence to class logits.

    Args:
        emb_size: Length of each token embedding.
        n_classes: Number of output classes.

    The head reads only the token at index 0, which is assumed to be the
    [CLS] token; the remaining tokens are ignored.
    """

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            # Select the [CLS] token (assuming it's at index 0). A small
            # parameter-free module keeps LayerNorm and Linear at Sequential
            # indices 1 and 2.
            _SelectClsToken(),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes),
        )

In [None]:
class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` encoder blocks applied one after another.

    Every block is built from the same keyword arguments.
    """

    def __init__(self, depth: int = 12, **kwargs):
        layers = []
        for _ in range(depth):
            layers.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*layers)

In [None]:
class TransformerEncoderBlock(nn.Sequential):
    """One encoder layer: an attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer is LayerNorm -> transform -> Dropout, wrapped in
    ResidualAdd. Remaining keyword arguments are forwarded to
    MultiHeadAttention.
    """

    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        # Attention sub-layer is built first, matching the order it runs in.
        attention_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        ))

        # Feed-forward sub-layer; its inner dropout rate is forward_drop_p,
        # while the trailing dropout shares drop_p with the attention path.
        feed_forward = FeedForwardBlock(
            emb_size, expansion=forward_expansion, drop_p=forward_drop_p)
        feed_forward_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            feed_forward,
            nn.Dropout(drop_p),
        ))

        super().__init__(attention_sublayer, feed_forward_sublayer)

In [None]:
class FeedForwardBlock(nn.Sequential):
    """Two-layer MLP: expand, GELU, dropout, project back to ``emb_size``.

    Args:
        emb_size: Input and output feature size.
        expansion: Multiplier giving the hidden width (expansion * emb_size).
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        expand = nn.Linear(emb_size, hidden_size)
        activation = nn.GELU()
        dropout = nn.Dropout(drop_p)
        contract = nn.Linear(hidden_size, emb_size)
        super().__init__(expand, activation, dropout, contract)