## Imports
These are some preliminary imports that are needed. \
The Hugging Face imports are deferred until the end of the notebook, so that HF classes such as HFSiglipVisionConfig are not confused with our custom-defined SiglipVisionConfig.

In [None]:
import math 
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from dataclasses import dataclass
from PIL import Image

## Input Image + Preprocess Image
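The model cannot accept the image as is. It must first be preprocessed:
- Resize it to 224x224.
- Convert it to a tensor.
- Normalize it with the ImageNet per-channel mean and standard deviation, a common default for vision models. (The Hugging Face image processor for google/siglip-base-patch16-224 uses mean = std = 0.5 per channel instead; this matters once pretrained weights are loaded and you compare our outputs with the HF model's.)
- Unsqueeze the tensor to add the batch dimension the transformer expects (here the batch size is 1): (3,224,224) --> unsqueeze --> (1,3,224,224).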

In [None]:
def preprocess_image(image, image_size=224):
    # image_size is the size to which the image will be resized
    # define the preprocess operation
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std =[0.229, 0.224, 0.225]
        )
    ])

    # actually preprocess the image
    image_tensor = preprocess(image)
    #(3,224,224) --> unsqueeze to include batch dimension -->(1,3,224,224)
    image_tensor = image_tensor.unsqueeze(0)
    return image_tensor

img = Image.open("image.jpg").convert("RGB")  # ensure 3 channels (grayscale or RGBA files would otherwise give 1 or 4)
img

In [None]:
image_tensor = preprocess_image(img)

## Attention Formula, Single Head of Attention, Multi-Head Attention
See Readme.md

## Building Block 1: SiglipVisionConfig
The config values match the google/siglip-base-patch16-224 checkpoint loaded at the end of this notebook. (PaliGemma 2 also uses a SigLIP vision encoder, but the larger SigLIP-So400m variant with different sizes.)
- image_size = 224, i.e. the input image is 224x224 (you have to preprocess the image to this size; see preprocess_image).
- patch_size = 16, so each patch is 16x16 pixels.
- embedding_size = hidden_size = 768. Every image patch is converted to a 768-dimensional vector.
- Each image has 224/16 = 14 patches per row and per column, so total_num_patches = 14*14 = 196. Each of the 196 patches is converted to a 768-dimensional vector.
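The patch arithmetic above can be checked in a few lines. Note that for this config each patch holds 16*16*3 = 768 raw pixel values, which happens to equal hidden_size; that is a coincidence of these numbers, not a requirement of the architecture.

```python
image_size, patch_size, num_channels, hidden_size = 224, 16, 3, 768

patches_per_side = image_size // patch_size                     # 14
num_patches = patches_per_side ** 2                             # 196
raw_values_per_patch = patch_size * patch_size * num_channels   # 768

print(patches_per_side, num_patches, raw_values_per_patch)     # 14 196 768
# Each patch's raw pixel values are projected to one hidden_size vector,
# so one image becomes a (num_patches, hidden_size) matrix
print((num_patches, hidden_size))                               # (196, 768)
```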

In [None]:
'''
@dataclass is a decorator in Python that automatically generates special methods like __init__, __repr__, and __eq__ for a class. 
This simplifies the creation of classes primarily used for storing data. 
It reduces boilerplate code and improves readability, especially when dealing with objects that mainly hold data.
'''
@dataclass
class SiglipVisionConfig:   
    image_size: int = 224
    patch_size: int = 16
    hidden_size: int = 768 # same as embedding size
    intermediate_size: int = 3072
    num_channels: int = 3
    num_attention_heads: int = 12   
    num_hidden_layers: int = 12 # number of hidden/encoder layers in the encoder as in the paper
    attention_dropout: float = 0.0
    layer_norm_eps: float = 1e-6
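To illustrate what @dataclass generates, here is a minimal sketch with a hypothetical ToyConfig holding two of SiglipVisionConfig's fields; the override values 1152 and 16 are arbitrary, chosen only to show keyword overrides through the generated __init__.

```python
from dataclasses import dataclass, fields

@dataclass
class ToyConfig:  # hypothetical stand-in with a subset of SiglipVisionConfig's fields
    hidden_size: int = 768
    num_attention_heads: int = 12

default = ToyConfig()
custom = ToyConfig(hidden_size=1152, num_attention_heads=16)  # generated __init__ accepts keyword overrides

print(default)                    # ToyConfig(hidden_size=768, num_attention_heads=12)  (generated __repr__)
print(default == ToyConfig())     # True (generated __eq__ compares field values)
print([f.name for f in fields(custom)])
# Multi-head attention needs hidden_size to divide evenly by num_attention_heads
print(custom.hidden_size // custom.num_attention_heads)  # per-head dim: 72
```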

## Building Block 2: SiglipVisionEmbeddings
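The image is converted to embeddings in two steps: \
each 16x16 patch is linearly projected to a 768-dimensional vector (patch embedding), and a learned vector for the patch's position is added (position embedding), so the model knows both what each patch contains and where it sits in the image.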

In [None]:
class SiglipVisionEmbeddings(nn.Module):
    
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config    

        self.image_size   = config.image_size
        self.patch_size   = config.patch_size                
        self.embed_dim   = config.hidden_size # same as embedding size
        self.num_channels = config.num_channels
    
        # Patch embedding layer:
        # A convolution whose kernel_size and stride both equal patch_size, so the
        # 16x16 windows do not overlap and each patch is projected to embed_dim channels.
        # The input to this is an image tensor (see forward)
        self.patch_embedding = nn.Conv2d(
            in_channels  = self.num_channels,
            out_channels = self.embed_dim,
            kernel_size  = self.patch_size,
            stride=self.patch_size,
            padding = "valid" # same as no padding
        )
    
        # // is floor division, ** is exponentiation.
        # Square it because there are image_size // patch_size patches along both height and width
        self.num_patches   = (self.image_size// self.patch_size)**2
        self.num_positions = self.num_patches
    
        # Position Embedding Layer. This is a lookup table 
        # The input to this is a bunch of position_ids and not an image tensor, see forward        
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
    
        # Register Buffer/ Position ids
        # registers a non-trainable buffer called position_ids in a nn.Module subclass (so this creates self.position_ids)
        # self.position_ids, which will be a tensor of shape [1, num_patches]
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1,-1)),
            persistent=False, # not saved in state_dict(); as a buffer it is never updated by the optimizer
        )
        
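    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        # pixel_values: [B, C, H, W] = [batch_size, num_channels, height, width] = [1, 3, 224, 224]
        B, C, H, W = pixel_values.shape
        # Patch embeddings: conv -> [B, 768, 14, 14]
        patch_embeds = self.patch_embedding(pixel_values)
        # Flatten the 14x14 grid of patches -> [B, 768, 196]
        flattened_patch_embeds = patch_embeds.flatten(start_dim=2, end_dim=-1)
        # Swap so that each patch is a row -> [B, 196, 768]
        flattened_patch_embeds = flattened_patch_embeds.transpose(1, 2)
        # Position embeddings: look up position_ids [1, 196] -> [1, 196, 768]
        position_embeds = self.position_embedding(self.position_ids)
        # Total embeddings: broadcast-add over the batch -> [B, 196, 768]
        total_embeds = flattened_patch_embeds + position_embeds
        return total_embeds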
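The patch embedding Conv2d is equivalent to cutting the image into non-overlapping 16x16 patches, flattening each one to 3*16*16 = 768 numbers, and applying one shared linear layer. The sketch below checks this with random weights, using torch.nn.functional.unfold to extract the patches and reusing the conv kernel as the linear weight.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
patch, dim, channels = 16, 768, 3
conv = nn.Conv2d(channels, dim, kernel_size=patch, stride=patch)
x = torch.randn(1, channels, 224, 224)

with torch.no_grad():
    # Conv path (as in SiglipVisionEmbeddings): [1, 768, 14, 14] -> [1, 196, 768]
    conv_out = conv(x).flatten(2).transpose(1, 2)

    # Manual path: extract non-overlapping 16x16 patches, each flattened to 3*16*16 = 768 values
    # unfold -> [1, 768, 196], transpose -> [1, 196, 768]
    patches = F.unfold(x, kernel_size=patch, stride=patch).transpose(1, 2)
    # Reuse the conv kernel as a linear weight: [768, 3, 16, 16] -> [768, 768]
    weight = conv.weight.reshape(dim, -1)
    manual_out = patches @ weight.T + conv.bias

print(conv_out.shape)  # torch.Size([1, 196, 768])
print(torch.allclose(conv_out, manual_out, atol=1e-4))
```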

## Building Block 3: SiglipAttention
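- Vectorized implementation of multi-head attention, the same as Hugging Face's (different from the non-vectorized implementation in the vit_step3 Head and MultiHeadAttention classes).
- There are no separate single-head modules: all heads are computed in parallel by reshaping one large Q, K, V projection into 12 heads.
- This is more efficient: a few large batched matrix multiplications replace many small per-head ones.
- Because the parameter names and shapes (q_proj, k_proj, v_proj, out_proj, each 768x768) match Hugging Face's, we can load their pretrained weights into our model.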

#### Vectorized Implementation
The vectorized implementation may seem to involve a lot of transposes and shape changes. \
Here is a high-level overview:
- i)   q, k, v_states have the same shape as hidden_states: [1,196,768] = [batch, num_patches, embedding_dimension]
- ii)  split q, k, v_states into 12 attention heads along the embedding dimension (768 = 12*64): [1,196,12,64]
- iii) transpose q, k, v_states to [batch, num_heads, 196, 64], so the last two dimensions [196,64] are the ones being multiplied
- iv)  attention filter = q @ k_transpose is square: [1,12,196,196]
- v)   scaling, softmax and dropout keep the same square shape: [1,12,196,196]
- vi)  multiply attention_weights @ v, and the shape is back to [1,12,196,64]
- vii) transpose again so the [12,64] head dimensions are adjacent for the concatenation: [1,196,12,64]
- viii) concatenate: merge the 12 attention heads back into the 768 = 12*64 embedding dimension: [1,196,768]
- ix)  project with out_proj; the shape stays [1,196,768], matching the residual stream
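To make step ii) concrete, the sketch below checks that splitting one 768x768 projection into 12 heads of 64 gives the same result as a per-head loop in which head h uses rows h*64:(h+1)*64 of each weight matrix, followed by concatenation. Random weights, no biases and no dropout are simplifying assumptions.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, H = 1, 196, 768, 12
d = C // H  # 64
x = torch.randn(B, T, C)
Wq, Wk, Wv = (torch.randn(C, C) / math.sqrt(C) for _ in range(3))

def split(t):  # [B, T, C] -> [B, H, T, d]
    return t.view(B, T, H, d).transpose(1, 2)

# Vectorized: one big projection per Q/K/V, then split into heads
q, k, v = split(x @ Wq.T), split(x @ Wk.T), split(x @ Wv.T)
attn = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)
vectorized = (attn @ v).transpose(1, 2).reshape(B, T, C)

# Per-head loop: head h owns a 64-row slice of each weight matrix
outs = []
for h in range(H):
    rows = slice(h * d, (h + 1) * d)
    qh, kh, vh = x @ Wq[rows].T, x @ Wk[rows].T, x @ Wv[rows].T  # each [B, T, d]
    ah = F.softmax(qh @ kh.transpose(-2, -1) / math.sqrt(d), dim=-1)
    outs.append(ah @ vh)
looped = torch.cat(outs, dim=-1)  # concatenate heads -> [B, T, C]

print(vectorized.shape, torch.allclose(vectorized, looped, atol=1e-4))
```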

In [None]:
class SiglipAttention(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dropout = config.attention_dropout
        
        # The linear layers whose outputs will result in query, key, value respectively
        # This is one unified set of projections shared by the whole multi-head attention module,
        # i.e. there aren't separate projection layers for each attention head.
        # The outputs are reshaped into heads (see forward). This is more efficient and
        # matches the Hugging Face parameter layout, so their weights can be loaded.
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)   
        
        # The Final Linear Layer: To project to the desired output size/space
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    # Note: the hidden states are the embeddings
    def forward(self, hidden_states):
        # the hidden states are the embeddings of the patches, so (batch_size, num_patches, embed_dim)
        # Here B: batch_size = 1, T: num_tokens=num_patches = 196, C: embedding_dimension = 768
        B, T, C = hidden_states.shape
        
        # Query, Key and Values
        # These are the q, k, v values across all attention heads: 3 large tensors
        # vectorized mat-mul of hidden_states = [1, 196, 768] & q_proj weight = [768, 768]
        # q,k,v_states: shape = [196x768]*[768x768] = [196x768] for each batch element = [1,196,768]
        q_states = self.q_proj(hidden_states)        
        k_states = self.k_proj(hidden_states)    
        v_states = self.v_proj(hidden_states)
        
        # Reshape into 12 attention heads
        # Divide the Q, K, V tensors across the attention heads; this is what C // self.num_heads does.
        # We split along the embedding dimension 768 (not along the number of patches = 196),
        # so qualitatively each large embedding vector is cut into 12 smaller pieces, one per head.
        # This is what differs from the non-vectorized version: there is no explicit concatenation.
        # The projection output is already "concatenated" at 768; here we split it into 12*64.
        # so q,k,v_states: [1, 196, 768] --> become [1, 196, 12, 64]
        q_states = q_states.view(B, T, self.num_heads, C // self.num_heads)
        k_states = k_states.view(B, T, self.num_heads, C // self.num_heads)
        v_states = v_states.view(B, T, self.num_heads, C // self.num_heads)

        # Transpose the states so that the dot-product dimensions, i.e. [196,64], are the last 2.
        # The first 2 dimensions are batch_size and num_heads. This enables an easy vectorized mat-mul.
        # q_states.transpose(1, 2) swaps dimension-index-1 with dimension-index-2, i.e. the 2nd and 3rd dimensions
        # so q,k,v_states [1,196,12,64] ---> become [1,12,196,64]
        q_states = q_states.transpose(1, 2)
        k_states = k_states.transpose(1, 2)
        v_states = v_states.transpose(1, 2)

        # dk = per-head dimension = 64; scores are scaled by 1/sqrt(dk)
        dk = k_states.size(-1)

        # Attention filter = Q @ K_transpose (this is square)
        # k_states.transpose(-2, -1) swaps the last two dimensions
        # vectorized mat-mul of q = [1,12,196,64] & k_transpose = [1,12,64,196]
        # attn: shape = [196x64]*[64x196] = [196x196] for each of the 12 heads = [1,12,196,196]
        attn = q_states @ k_states.transpose(-2, -1)

        # Scaled attention-filter: attn shape = [1,12,196,196]  
        attn = attn/math.sqrt(dk)

        # Apply softmax to get a probability distribution: attn shape = [1,12,196,196]
        # Normalizing over the entire 196x196 square would make no sense. Instead apply softmax along
        # dim=-1 (the last dimension), i.e. row-wise: each query patch's weights over all 196 key patches sum to 1.
        attn = F.softmax(attn, dim=-1).to(q_states.dtype)

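        # Dropout: attn shape = [1,12,196,196]
        # Dropout is applied to the attention weights (before multiplying by v_states), as in the
        # Hugging Face implementation; during training it randomly removes some patch-to-patch connections.
        # With attention_dropout = 0.0 (and in eval mode) this is a no-op.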
        attn = F.dropout(attn, p=self.dropout, training=self.training)

        # Weighted sum: lets information flow between tokens (patches)
        # vectorized dot-product(mat-mul) of attn=[1,12,196,196] & v_states = [1,12,196,64]
        # attn: shape = [196x196]*[196x64] = [ 196x64] across all 12 heads = [1,12,196,64]
        attn = attn @ v_states

        # Transpose back to the q_states layout where [12,64] are at the end,
        # so that they can be fused back into 12*64 = 768
        # attn.transpose(1, 2) swaps dimension-index-1 with dimension-index-2, i.e. the 2nd and 3rd dimensions
        # attn [1,12,196,64] ---> becomes [1,196,12,64]
        attn = attn.transpose(1, 2)

        # i) This reshape concatenates the 12 heads' 64-dim outputs back into embedding_dim = 768
        # ii) transpose returns a non-contiguous view, so .view() would fail here;
        #     .reshape() copies into contiguous memory when needed (.contiguous() is then a no-op)
        # attn [1,196,12,64] ---> becomes [1,196,768]
        attn = attn.reshape(B, T, C).contiguous()
        
        # The Final Linear Layer: To project to the desired output size/space
        # vectorized dot-product(mat-mul) of attn =[1, 196, 768] & out_proj = [768,768]
        # attn: shape = [196x768]*[768x768] = [ 196x768] across all batches = [1,196,768]
        attn = self.out_proj(attn)
        
        return attn
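Steps iv) to vi) are exactly scaled dot-product attention. As a cross-check, assuming PyTorch 2.0 or later (where torch.nn.functional.scaled_dot_product_attention exists), the sketch below compares the manual computation with PyTorch's built-in on random [1,12,196,64] tensors with dropout off. The built-in's default scale is 1/sqrt of the last dimension, i.e. 1/sqrt(64) here.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(1, 12, 196, 64) for _ in range(3))

# Manual path: same steps as SiglipAttention.forward, without dropout
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
manual = F.softmax(scores, dim=-1) @ v

# PyTorch's built-in scaled dot-product attention (PyTorch >= 2.0)
fused = F.scaled_dot_product_attention(q, k, v)

print(manual.shape, torch.allclose(manual, fused, atol=1e-4))
```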

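## Building Block 4: SiglipMLP - Multi-Layer Perceptron
Two linear layers with a GELU non-linearity in between, applied to each patch independently: fc1 expands the hidden state from 768 to intermediate_size = 3072 dimensions, and fc2 projects it back to 768 so the result can be added to the residual stream. \
The large intermediate size (3072 = 4x768) gives the layer more capacity to learn complex, non-linear relations.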

In [None]:
class SiglipMLP(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = nn.functional.gelu(hidden_states, approximate="tanh")  # tanh-approximate GELU (HF's "gelu_pytorch_tanh")
        hidden_states = self.fc2(hidden_states)
        return hidden_states

mlp = SiglipMLP(SiglipVisionConfig(hidden_size=768, intermediate_size=3072))
mlp(torch.randn(1, 196, 768)).shape
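SiglipMLP uses the tanh approximation of GELU. The sketch below writes that formula out by hand, GELU(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))), checks it against F.gelu(..., approximate="tanh"), and measures how far it is from the exact erf-based GELU on [-4, 4].

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=101)
# tanh approximation of GELU, written out by hand
manual = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))
builtin = F.gelu(x, approximate="tanh")
exact = F.gelu(x)  # default: exact erf-based GELU

print(torch.allclose(manual, builtin, atol=1e-5))
print((builtin - exact).abs().max())  # small approximation error
```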

## Building Block 5: SiglipEncoderLayer
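Single layer of the encoder, with pre-norm residual connections: LayerNorm, then self-attention, then add the residual; LayerNorm, then MLP, then add the residual. Input and output shapes are both [1,196,768].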
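In equations, the forward pass of SiglipEncoderLayer computes the following (a restatement of the code, with x the [1,196,768] input and LN the two LayerNorms):

$$
\begin{aligned}
h &= x + \mathrm{SelfAttn}\big(\mathrm{LN}_1(x)\big) \\
y &= h + \mathrm{MLP}\big(\mathrm{LN}_2(h)\big)
\end{aligned}
$$

Because each sub-block only adds to the residual stream, the shape stays [1,196,768] throughout.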

In [None]:
class SiglipEncoderLayer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size        
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = SiglipAttention(config)        
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)


    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

encoder_layer = SiglipEncoderLayer(SiglipVisionConfig(hidden_size=768, intermediate_size=3072))
encoder_layer(torch.randn(1, 196, 768)).shape

## Building Block 6: SiglipEncoder
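The encoder: a stack of config.num_hidden_layers = 12 SiglipEncoderLayers, each with its own weights, applied in sequence (the output of one layer is the input to the next).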

In [None]:
class SiglipEncoder(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])


    def forward(self, hidden_states):
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
            
        # Adding this for better readability and understanding for first timers
        last_hidden_states = hidden_states
        return last_hidden_states

encoder = SiglipEncoder(SiglipVisionConfig(hidden_size=768, intermediate_size=3072))
encoder(torch.randn(1, 196, 768)).shape
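The list comprehension in SiglipEncoder matters: it calls the constructor once per layer, so each of the 12 layers gets its own weights. The sketch below, using nn.Linear(768, 768) as a stand-in layer, contrasts it with list multiplication, which would register the same module object 12 times and share its weights.

```python
from torch import nn

# Independent layers: the comprehension calls the constructor 12 times
independent = nn.ModuleList([nn.Linear(768, 768) for _ in range(12)])
# Shared layer: list multiplication repeats ONE module object 12 times
shared = nn.ModuleList([nn.Linear(768, 768)] * 12)

def count(module):
    # parameters() de-duplicates shared tensors, so shared weights are counted once
    return sum(p.numel() for p in module.parameters())

print(count(independent))  # 7087104 = 12 * (768*768 + 768)
print(count(shared))       # 590592  = 768*768 + 768
```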

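## Building Block 7: SiglipVisionTransformer
The full vision tower: patch + position embeddings, then the 12-layer encoder, then a final LayerNorm (post_layernorm). It maps a [1,3,224,224] image tensor to [1,196,768] patch features.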

In [None]:
class SiglipVisionTransformer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, pixel_values):
        hidden_states = self.embeddings(pixel_values)
        last_hidden_states = self.encoder(hidden_states)
        last_hidden_states = self.post_layernorm(last_hidden_states)
        return last_hidden_states

our_siglip_transformer = SiglipVisionTransformer(SiglipVisionConfig(hidden_size=768, intermediate_size=3072))
our_siglip_transformer(image_tensor).shape

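## The Grand Finale: SiglipVisionModel
A thin wrapper around SiglipVisionTransformer. Storing it under the attribute name vision_model gives every parameter the prefix "vision_model." (e.g. vision_model.encoder.layers.0.self_attn.q_proj.weight), matching the key names in Hugging Face's SiglipVisionModel state dict.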

In [None]:
class SiglipVisionModel(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.vision_model = SiglipVisionTransformer(config)

    def forward(self, pixel_values):
        return self.vision_model(pixel_values)

our_siglip_model = SiglipVisionModel(SiglipVisionConfig(hidden_size=768, intermediate_size=3072))
our_siglip_model(image_tensor).shape

our_siglip_model
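As a rough size check, the sketch below counts parameters independently by rebuilding one encoder layer from the same nn building blocks (two LayerNorms, four 768x768 Linear projections, fc1 and fc2), then adding the patch-embedding conv, the 196x768 position table and the post-layernorm. If the architecture matches, sum(p.numel() for p in our_siglip_model.parameters()) should give the same total (the position_ids buffer is not a parameter).

```python
from torch import nn

D, I, L, P = 768, 3072, 12, 196  # hidden size, intermediate size, layers, patches

# One encoder layer, built from the same modules SiglipEncoderLayer uses
layer = nn.ModuleDict({
    "layer_norm1": nn.LayerNorm(D),
    "q_proj": nn.Linear(D, D), "k_proj": nn.Linear(D, D),
    "v_proj": nn.Linear(D, D), "out_proj": nn.Linear(D, D),
    "layer_norm2": nn.LayerNorm(D),
    "fc1": nn.Linear(D, I), "fc2": nn.Linear(I, D),
})
per_layer = sum(p.numel() for p in layer.parameters())

# Patch-embedding conv (3 -> 768 channels, 16x16 kernel, stride 16) plus position table
embeddings = sum(p.numel() for p in nn.Conv2d(3, D, 16, 16).parameters()) + P * D
post_ln = 2 * D  # LayerNorm weight + bias
total = L * per_layer + embeddings + post_ln

print(per_layer)  # 7087872
print(total)      # 85797120, i.e. roughly 86M parameters
```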

## Import the Pretrained SiglipVisionModel from Hugging Face
- This model is imported as HFSiglipVisionModel.
- from_pretrained builds the model and loads its pretrained weights from the Hugging Face Hub.
- "google/siglip-base-patch16-224" is the model checkpoint: patch16 means 16x16 patches, and 224 means it takes a 224x224 image as input.

Print hf_vision_model at the end. It should contain all the layers in the SIGLIP: VISION TRANSFORMER ARCHITECTURE DIAGRAM in the Readme.md:
- **i) Embeddings:** with Patch Embeddings and Position Embeddings
- **ii) Encoder:** with 12 encoder layers. Each encoder layer has layer_norm1, self_attn, layer_norm2 and mlp. \
  Each self_attn (multi-head attention) block has q, k, v projections and an out_proj layer (the final linear layer after the concatenation). \
  Each MLP has fc1, GELU and fc2.
- **iii) Post Layer Norm**

In [None]:
from transformers import SiglipVisionModel as HFSiglipVisionModel
from transformers import SiglipVisionConfig as HFSiglipVisionConfig

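# HF model and HF state dictionary
# vision_use_head=False drops the attention-pooling head, which our model does not implement;
# the remaining HFSiglipVisionConfig defaults match this checkpoint (768 hidden, 12 layers, 12 heads, patch 16)
hf_vision_model = HFSiglipVisionModel.from_pretrained("google/siglip-base-patch16-224",
                                                    config=HFSiglipVisionConfig(vision_use_head=False))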
hf_vision_model

In [None]:
hf_state_dict = hf_vision_model.state_dict()

## Compare Our Output vs HF Output
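Strictly speaking, we are not comparing outputs here. load_state_dict (strict by default) checks that the keys of the Hugging Face state dict and of our model's state dict match exactly and that every tensor shape agrees; it raises an error otherwise. If it succeeds, the pretrained HF weights have been copied into our model.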

In [None]:
our_state_dict = our_siglip_model.state_dict()
our_siglip_model.load_state_dict(hf_state_dict)
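To see what the strict check guards against, here is a toy sketch in which the hypothetical Inner and Wrapped classes stand in for SiglipVisionTransformer and SiglipVisionModel. It shows how an extra attribute level changes key names, and how strict=False reports mismatches through missing_keys and unexpected_keys instead of raising. With the real models, a natural next step would be comparing our_siglip_model(image_tensor) with hf_vision_model(image_tensor).last_hidden_state via torch.allclose (both in eval mode); that comparison is not run here.

```python
import torch
from torch import nn

class Inner(nn.Module):      # hypothetical stand-in for SiglipVisionTransformer
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class Wrapped(nn.Module):    # hypothetical stand-in for SiglipVisionModel's extra "vision_model." level
    def __init__(self):
        super().__init__()
        self.vision_model = Inner()

source = Wrapped()
print(list(source.state_dict().keys()))   # ['vision_model.proj.weight', 'vision_model.proj.bias']
print(list(Inner().state_dict().keys()))  # ['proj.weight', 'proj.bias']

# Mismatched names: strict=False returns the problems instead of raising
result = Inner().load_state_dict(source.state_dict(), strict=False)
print(result.missing_keys)      # keys the target expected but did not receive
print(result.unexpected_keys)   # keys in the source that the target has no slot for

# Matching names: strict load succeeds and copies the weights
target = Wrapped()
target.load_state_dict(source.state_dict())
print(torch.equal(target.vision_model.proj.weight, source.vision_model.proj.weight))  # True
```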