In [3]:
import torch
import torch.nn as nn
from einops import rearrange
from utils import print_tensor_info

  Referenced from: <CFED5F8E-EC3F-36FD-AAA3-2C6C7F8D3DD9> /Users/himanshubhenwal/miniforge3/envs/torch2/lib/python3.11/site-packages/torchvision/image.so
  warn(


In [4]:
# Toy (batch, targets, embed_dim) shape shared by predictions and targets.
shape = (2, 3, 4)
# Tuple elements are evaluated left to right, so the predictions are still
# drawn from the RNG before the targets.
pred_features, target_features = torch.rand(shape), torch.rand(shape)

In [5]:
# Display the uniformly random prediction features created above (shape (2, 3, 4)).
pred_features

tensor([[[0.8649, 0.5216, 0.6526, 0.9176],
         [0.8061, 0.0014, 0.2069, 0.8448],
         [0.9477, 0.2697, 0.5624, 0.7117]],

        [[0.6907, 0.2050, 0.5935, 0.4493],
         [0.8745, 0.7627, 0.5185, 0.7344],
         [0.1786, 0.0843, 0.0201, 0.8117]]])

In [6]:
# Display the uniformly random target features created above (shape (2, 3, 4)).
target_features

tensor([[[0.5644, 0.0294, 0.6311, 0.6940],
         [0.4751, 0.0606, 0.8409, 0.2881],
         [0.8621, 0.3122, 0.8568, 0.2087]],

        [[0.5362, 0.9477, 0.3087, 0.8687],
         [0.9727, 0.1991, 0.8457, 0.8976],
         [0.1155, 0.5781, 0.9644, 0.7395]]])

## Creating a mask

In [7]:
# Unpack the feature shape and start from an all-False (batch, target) mask.
batch_size, max_targets, embed_dim = shape
mask = torch.zeros((batch_size, max_targets), dtype=torch.bool)
print(mask, mask.shape, sep="\n")

tensor([[False, False, False],
        [False, False, False]])
torch.Size([2, 3])


In [8]:
# For each batch item, how many leading target slots get marked True in `mask`.
num_targets_per_batch = [2, 3]

In [9]:
# Mark the first `valid_count` target slots of every batch row as True.
for batch_idx in range(batch_size):
    valid_count = num_targets_per_batch[batch_idx]
    mask[batch_idx, :valid_count] = True

In [10]:
# Display the mask after the first num_targets_per_batch[b] slots of each row were set to True.
mask

tensor([[ True,  True, False],
        [ True,  True,  True]])

## Normalizing features

In [11]:
# Step 1 of the manual L2 normalisation: square every element.
pred_features_squared = pred_features.pow(2)
print(pred_features_squared, pred_features_squared.shape, sep="\n")

tensor([[[7.4799e-01, 2.7204e-01, 4.2586e-01, 8.4193e-01],
         [6.4973e-01, 1.8391e-06, 4.2828e-02, 7.1363e-01],
         [8.9809e-01, 7.2744e-02, 3.1631e-01, 5.0648e-01]],

        [[4.7701e-01, 4.2029e-02, 3.5221e-01, 2.0183e-01],
         [7.6468e-01, 5.8169e-01, 2.6884e-01, 5.3928e-01],
         [3.1893e-02, 7.1122e-03, 4.0231e-04, 6.5891e-01]]])
torch.Size([2, 3, 4])


In [12]:
# Step 2: sum the squares over the last (embedding) axis, keeping that axis
# as size 1 so the result broadcasts against the features later.
sum_squared = pred_features_squared.sum(dim=-1, keepdim=True)
print(sum_squared, sum_squared.shape, sep="\n")

tensor([[[2.2878],
         [1.4062],
         [1.7936]],

        [[1.0731],
         [2.1545],
         [0.6983]]])
torch.Size([2, 3, 1])


In [13]:
# Step 3: the square root of the summed squares is the L2 norm per target.
pred_norm = sum_squared.sqrt()
print(pred_norm, pred_norm.shape, sep="\n")

tensor([[[1.5126],
         [1.1858],
         [1.3393]],

        [[1.0359],
         [1.4678],
         [0.8357]]])
torch.Size([2, 3, 1])


In [14]:
# Step 4: divide by the (..., 1)-shaped norm; broadcasting scales each
# feature vector to unit L2 length.
normalized_pred_features = torch.div(pred_features, pred_norm)
print(normalized_pred_features, normalized_pred_features.shape, sep="\n")

tensor([[[0.5718, 0.3448, 0.4314, 0.6066],
         [0.6797, 0.0011, 0.1745, 0.7124],
         [0.7076, 0.2014, 0.4199, 0.5314]],

        [[0.6667, 0.1979, 0.5729, 0.4337],
         [0.5958, 0.5196, 0.3532, 0.5003],
         [0.2137, 0.1009, 0.0240, 0.9714]]])
torch.Size([2, 3, 4])


In [15]:
# Same four normalisation steps as for the predictions, applied to the targets.
target_features_squared = target_features.pow(2)
# NOTE: this rebinds `sum_squared`, which previously held the prediction sums.
sum_squared = target_features_squared.sum(dim=-1, keepdim=True)
target_norm = sum_squared.sqrt()
normalized_target_features = torch.div(target_features, target_norm)
print(normalized_target_features, normalized_target_features.shape, sep="\n")

tensor([[[0.5154, 0.0268, 0.5763, 0.6337],
         [0.4705, 0.0600, 0.8328, 0.2853],
         [0.6777, 0.2454, 0.6735, 0.1641]],

        [[0.3758, 0.6643, 0.2164, 0.6089],
         [0.6144, 0.1257, 0.5342, 0.5669],
         [0.0855, 0.4280, 0.7140, 0.5475]]])
torch.Size([2, 3, 4])


## Loss

In [16]:
import torch.nn as nn

# reduction="none" keeps one squared error per element instead of averaging,
# so the loss has the same shape as the normalised features.
loss_fn = nn.MSELoss(reduction="none")
loss_per_element = loss_fn(
    normalized_pred_features,
    normalized_target_features,
)
print(loss_per_element, loss_per_element.shape, sep="\n")

tensor([[[3.1827e-03, 1.0112e-01, 2.0974e-02, 7.3297e-04],
         [4.3774e-02, 3.4608e-03, 4.3337e-01, 1.8237e-01],
         [8.9553e-04, 1.9393e-03, 6.4286e-02, 1.3492e-01]],

        [[8.4644e-02, 2.1748e-01, 1.2712e-01, 3.0691e-02],
         [3.4589e-04, 1.5514e-01, 3.2733e-02, 4.4409e-03],
         [1.6440e-02, 1.0699e-01, 4.7610e-01, 1.7971e-01]]])
torch.Size([2, 3, 4])


## Vision Transformer (ViT) Sequential Flow

### 1. Image to Patches (PatchEmbed)
- **Input:** Image tensor (B, C, H, W)
- **Process:** 
  - Apply Conv2d with kernel_size=patch_size, stride=patch_size
  - (B, C, H, W) → (B, embed_dim, H//patch_size, W//patch_size)
  - Flatten spatial dimensions: (B, embed_dim, H//patch_size, W//patch_size) → (B, embed_dim, N) where N=(H×W)/(patch_size²)
  - Transpose: (B, embed_dim, N) → (B, N, embed_dim)
- **Output:** Patch tokens (B, N, embed_dim)

In [17]:
class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A Conv2d whose kernel size and stride both equal ``patch_size`` projects
    every non-overlapping patch to an ``embed_dim``-dimensional vector; the
    spatial grid is then flattened into a token sequence.

    Args:
        patch_size: Side length of each square patch (conv kernel and stride).
        img_size: Expected height and width of the (square) input image.
        in_chans: Number of input image channels.
        embed_dim: Dimension of each output patch token.

    Attributes:
        num_patches: Total number of patch tokens N produced per image,
            i.e. ``(img_size // patch_size) ** 2``.
    """

    def __init__(self, patch_size : int = 16, img_size : int = 224, in_chans : int = 3, embed_dim : int = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # The patch grid is (img_size // patch_size) per side, so the token
        # count is that value squared (e.g. 14 * 14 = 196 for 224 / 16).
        self.num_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Image tensor of shape (B, C, H, W) with H == W == ``img_size``.

        Returns:
            Patch tokens of shape (B, N, embed_dim) with N == ``num_patches``.

        Raises:
            AssertionError: If the input height/width differ from ``img_size``.
        """
        B, C, H, W = x.shape

        assert H == W == self.img_size, "Input image size doesn't match with model size."

        # (B, C, H, W) -> (B, embed_dim, H//patch_size, W//patch_size)
        # (B, embed_dim, H', W') -> (B, embed_dim, N) N = (HxW)/(patch_size)^2 = (H'xW')
        # (B, embed_dim, N) -> (B, N, embed_dim)
        x = rearrange(self.proj(x), 'b c h w -> b (h w) c')
        return x

In [18]:
# Instantiate PatchEmbed with its defaults (patch 16, image 224, 3 channels, 768-dim).
# NOTE(review): `p` is not applied to `t` in the visible cells.
p = PatchEmbed()
# Random stand-in for a single 3-channel 224x224 image.
t = torch.rand(1, 3, 224, 224)
# Helper imported from utils; prints a summary table for the tensor.
print_tensor_info(t)

+---------------+-----------------------+
| Property      | Value                 |
| Name          | tensor                |
+---------------+-----------------------+
| Shape         | (1, 3, 224, 224)      |
+---------------+-----------------------+
| Dimensions    | 4                     |
+---------------+-----------------------+
| Dtype         | torch.float32         |
+---------------+-----------------------+
| Device        | cpu                   |
+---------------+-----------------------+
| Min           | 1.138448715209961e-05 |
+---------------+-----------------------+
| Max           | 0.9999926090240479    |
+---------------+-----------------------+
| Mean          | 0.4993129372596741    |
+---------------+-----------------------+
| Std           | 0.2894613444805145    |
+---------------+-----------------------+
| Memory (MB)   | 0.57421875            |
+---------------+-----------------------+
| Requires Grad | False                 |
+---------------+-----------------

## Multi-head Self-Attention
- **Input:** Normalized tokens (B, N+1, embed_dim)
- **Process:**
  - QKV projection: (B, N+1, embed_dim) → (B, N+1, 3×embed_dim) → reshape to (B, N+1, 3, num_heads, head_dim)
  - Permute to (3, B, num_heads, N+1, head_dim) and split into Q, K, V
  - Calculate attention: (q @ k.transpose) * scale → (B, num_heads, N+1, N+1)
  - Apply softmax and dropout
  - Apply attention to values: (B, num_heads, N+1, head_dim)
  - Reshape and project: (B, N+1, embed_dim)
- **Output:** Self-attention output (B, N+1, embed_dim)

In [20]:
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: Token embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads : int = 8, qkv_bias : bool = True, attn_drop : float = 0., proj_drop : float = 0.):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})."
            )
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # per-head dimension, not the full embedding dimension.
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias = qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Apply self-attention.

        Args:
            x: Tokens of shape (B, N, dim).

        Returns:
            Tensor of shape (B, N, dim).
        """
        # (B, N, 3*dim) -> (3, B, heads, N, head_dim), then split into q, k, v.
        qkv = rearrange(self.qkv(x), 'b n (t h d) -> t b h n d', t = 3, h = self.num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # (B, heads, N, N) similarity scores, normalised over the key axis.
        attn = (q @ k.transpose(-2, -1))*self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Weighted sum of values, then merge heads back into the embedding axis.
        x = rearrange(attn @ v, 'b h n d -> b n (h d)')
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [21]:
# Step-by-step walkthrough of multi-head self-attention on a tiny input.
# Head count and head dimension are read from the module / tensors instead of
# being repeated as literals, so the walkthrough stays correct if `dim` or
# `num_heads` below are changed.

# Step 0: Initial input tensor
x = torch.randn(2, 4, 6)  # [batch_size, sequence_length, embedding_dim]
print("Input : ")
print(x)
print(f"Input shape: {x.shape}")  # [2, 4, 6]

# For our Attention module with 2 heads
attention = Attention(dim=6, num_heads=2)

# Step 1: QKV projection
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# Projects from [2, 4, 6] to [2, 4, 18] (6*3=18)
qkv_projected = attention.qkv(x)
print("QKV Projected : ")
print(qkv_projected)
print(f"After QKV projection: {qkv_projected.shape}")  # [2, 4, 18]

# Step 2: Reshape QKV to separate heads
# Rearrange from [2, 4, 18] to [3, 2, 2, 4, 3]
# [3, batch, heads, sequence, head_dim]
qkv = rearrange(
    qkv_projected,
    'b n (three h d) -> three b h n d',
    three=3,
    h=attention.num_heads,
)
print("QKV : ")
print(qkv)
print(f"After reshaping QKV: {qkv.shape}")  # [3, 2, 2, 4, 3]

# Step 3: Separate Q, K, V
q, k, v = qkv[0], qkv[1], qkv[2]
print("Q : ")
print(q)
print("K : ")
print(k)
print("V : ")
print(v)
print(f"Query shape: {q.shape}")  # [2, 2, 4, 3]
print(f"Key shape: {k.shape}")    # [2, 2, 4, 3]
print(f"Value shape: {v.shape}")  # [2, 2, 4, 3]

# Step 4: Calculate attention scores
# q @ k.transpose(-2, -1) means matrix multiplication of q and k transposed
# q: [2, 2, 4, 3] and k.transpose: [2, 2, 3, 4]
# Result: [2, 2, 4, 4] - each token attends to all other tokens
head_dim = q.shape[-1]  # 3 for dim=6, num_heads=2
scale = head_dim ** -0.5  # Scale factor: 1/sqrt(head_dim)
attn = (q @ k.transpose(-2, -1)) * scale
print(f"Attention scores shape: {attn.shape}")  # [2, 2, 4, 4]

# Step 5: Apply softmax to get attention weights
attn = attn.softmax(dim=-1)
print("Attention weights : ")
print(attn)
print(f"Attention weights shape: {attn.shape}")  # [2, 2, 4, 4]
# The weights now sum to 1 along the last dimension

# Step 6: Apply dropout (skipping actual dropout for clarity)
# attn = attention.attn_drop(attn)

# Step 7: Apply attention weights to values
# attn: [2, 2, 4, 4] and v: [2, 2, 4, 3]
# Result: [2, 2, 4, 3]
weighted_v = attn @ v
print("Weighted V : ")
print(weighted_v)
print(f"After applying attention: {weighted_v.shape}")  # [2, 2, 4, 3]

# Step 8: Combine heads
# Reshape from [2, 2, 4, 3] to [2, 4, 6]
combined = rearrange(weighted_v, 'b h n d -> b n (h d)')
print(f"After combining heads: {combined.shape}")  # [2, 4, 6]

# Step 9: Final projection
# self.proj = nn.Linear(dim, dim)
# Projects from [2, 4, 6] to [2, 4, 6]
output = attention.proj(combined)
print("Output : ")
print(output)
print(f"After final projection: {output.shape}")  # [2, 4, 6]

# Step 10: Apply dropout (skipping actual dropout for clarity)
# output = attention.proj_drop(output)

# Final output shape matches input shape
print(f"Final output shape: {output.shape}")  # [2, 4, 6]

Input : 
tensor([[[ 0.6379, -1.7139,  0.5090, -1.4304, -0.1238,  0.0871],
         [-1.5324, -0.4285,  0.8974, -0.0831,  1.1962,  0.7543],
         [-1.1553,  0.4374, -1.0537,  1.9426,  1.2371,  0.8994],
         [ 0.4621,  1.0332,  0.0703, -2.1739, -1.3538,  0.3809]],

        [[ 0.5076,  1.6884,  0.4170, -0.5509, -1.3361, -2.7045],
         [-1.4617, -0.5990,  0.9987, -0.8241, -0.3655,  0.8913],
         [-0.6579, -1.5810,  1.3470,  1.1436,  1.1443,  0.7243],
         [-0.4441,  0.2662,  1.1105, -0.3284,  0.7214, -1.1463]]])
Input shape: torch.Size([2, 4, 6])
QKV Projected : 
tensor([[[ 0.7425, -0.0707, -0.0828,  0.4573,  0.4544,  1.0330,  0.2471,
           0.0914,  0.7612, -0.0415,  1.4642,  0.0638, -1.0289,  0.5725,
           0.2907,  0.3723,  0.7210,  0.5929],
         [ 0.3572, -0.8977,  0.0735, -0.6721,  0.5857,  0.6755,  1.1203,
          -0.0960, -0.2443,  0.9797, -0.5174, -0.4123, -0.8497, -0.5494,
           1.0104,  0.8182, -0.1323,  0.0854],
         [-0.4196, -0.7100, -

## MLP Block
- **Input:** Normalized tokens (B, N+1, embed_dim)
- **Process:**
  - FC1: (B, N+1, embed_dim) → (B, N+1, hidden_dim) where hidden_dim = mlp_ratio * embed_dim
  - Activation (GELU)
  - Dropout
  - FC2: (B, N+1, hidden_dim) → (B, N+1, embed_dim)
  - Dropout
- **Output:** MLP output (B, N+1, embed_dim)

In [27]:
from typing import Optional
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when omitted or ``None``.
        out_features: Size of the last dimension of the output; falls back to
            ``in_features`` when omitted or ``None``.
        activation: Activation module class (called with no arguments to
            build the layer), e.g. ``nn.GELU``.
        drop: Dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self,
        in_features : int,
        hidden_features : Optional[int] = None,
        out_features : Optional[int] = None,
        activation : nn.Module = nn.GELU,
        drop : float = 0.
    ):
        super().__init__()
        self.in_features = in_features
        self.activation = activation()
        # Both widths are optional; default them to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``out_features``.
        """
        x = self.fc1(x)
        x = self.activation(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

In [28]:
# Smoke test: two 4-dimensional inputs through a 4 -> 8 -> 4 MLP.
# NOTE: rebinds `t`, which earlier held the random image tensor.
t = torch.randn(2, 4)
mlp = MLP(4, 8, 4)
print(mlp(t))

tensor([[ 0.3741,  0.3086, -0.2607, -0.4080],
        [-0.7085, -1.5666, -0.3569,  0.7105]], grad_fn=<AddmmBackward0>)


In [None]:
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    During training each sample in the batch is zeroed with probability
    ``drop_prob``; surviving samples are scaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob`` is 0, the input is returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample's path.
    """
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Randomly zero whole samples of ``x`` (first dimension = batch).

        Args:
            x: Tensor of any rank with the batch on dimension 0.

        Returns:
            ``x`` itself when inactive, otherwise a tensor of the same shape
            with dropped samples zeroed and kept samples rescaled.
        """
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if keep_prob > 0.:
            output = x.div(keep_prob) * random_tensor
        else:
            # drop_prob == 1: every path is dropped. Skip the rescale, since
            # dividing by zero would turn the zeroed output into NaN.
            output = x * random_tensor
        return output


# Transformer Block
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each in a residual branch.

    Args:
        dim: Token embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention QKV projection has a bias.
        drop: Dropout probability for the attention output projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
        drop_path: Stochastic-depth drop probability for both residual branches.
        act_layer: Activation module class used inside the MLP.
        norm_layer: Normalisation module class, called as ``norm_layer(dim)``.
    """

    def __init__(
            self,
            dim : int,
            num_heads : int = 8,
            mlp_ratio : float = 4,
            qkv_bias : bool = True,
            drop : float = 0.,
            attn_drop : float = 0.,
            drop_path : float = 0.,
            act_layer : nn.Module = nn.GELU,
            norm_layer : nn.Module = nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        # mlp_ratio only sizes the MLP; Attention has no such parameter.
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # MLP names its activation argument `activation`; out_features is
        # passed explicitly so the residual addition keeps width `dim`.
        self.mlp = MLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=dim,
            activation=act_layer,
            drop=drop,
        )

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Apply the block to tokens of shape (B, N, dim); the shape is preserved."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

### 4. Final Layer Normalization
- **Input:** Output from last encoder block (B, N+1, embed_dim)
- **Output:** Normalized representation (B, N+1, embed_dim)

### 5. Feature Extraction
- For classification: Use only class token (B, 1, embed_dim) → (B, embed_dim)
- For dense prediction: Use patch tokens (B, N, embed_dim)

## Architecture Organization
- **VisionTransformer (main class):**
  - Contains PatchEmbed, ClassToken, PositionalEmbedding, Blocks, and final normalization
  - Manages the overall forward pass through all components
  
- **PatchEmbed (component):**
  - Handles the initial tokenization of the image

- **Block (component):**
  - Contains one complete transformer unit (Attention + MLP)
  
- **Attention (component):**
  - Implements the multi-head self-attention mechanism
  
- **MLP (component):**
  - Implements the feed-forward network in each block

This sequence ensures the image is properly tokenized, enhanced with positional information, and processed through the self-attention mechanism to capture global relationships between patches.