## Imports

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from dataclasses import dataclass

## Attention Formula , Single Head of Attention, Multi Head Attention
See Readme.md

## Single Head Attention, Multi Head Attention
### Precursors to SiglipAttention
Simple introductory implementation of Single Head Attention, and Multi Head Attention that serve as a precursor to SiglipAttention. 

Note: This is not how Hugging Face implements it. To be able to use the Hugging Face weights, we have to vectorize the computation and process all heads in parallel. This is done in the subsequent SiglipAttention section.

In [None]:
import math

class Head(nn.Module):
    """Single Head Attention: one scaled dot-product attention head of the multi head attention module.

    Args:
        n_in: size of the last dimension of the incoming embeddings.
        head_size: width of the query/key/value projections, i.e. the output width of this head.
        context_length: accepted to keep the constructor signature; this head does not read it.
    """

    def __init__(self, n_in, head_size, context_length):
        super().__init__()
        self.head_size = head_size
        # Bias-free linear layers that turn the embeddings into query, key and value
        self.query_linear = nn.Linear(n_in, head_size, bias=False)
        self.key_linear = nn.Linear(n_in, head_size, bias=False)
        self.value_linear = nn.Linear(n_in, head_size, bias=False)

    def forward(self, x):
        """Return softmax(Q @ K^T / sqrt(head_size)) @ V, shape (B, T, head_size)."""
        # B: batch_size, T: num_tokens (same as number of embeddings = 196), C: embedding_size = hidden_size = 768
        # The three sizes are not used below (they are used in SiglipAttention), but the
        # unpacking itself raises unless x has exactly three dimensions.
        B, T, C = x.shape

        queries = self.query_linear(x)
        keys = self.key_linear(x)
        values = self.value_linear(x)

        # Attention filter: every query against every key, scaled by sqrt(dk) with dk = head_size
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        scores = scores / math.sqrt(self.head_size)

        # Normalise each row into a probability distribution, then take the weighted sum of the values
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, values)

class MultiHeadAttention(nn.Module):
    """ Multi Head Attention: Concatenate outputs of several single head attention modules

    Args:
        num_of_heads: number of independent `Head` modules to run on the same input.
        n_in: size of the last dimension of the input; also the output size after `out_proj`.
        head_size: output width of every individual head.
        context_length: forwarded unchanged to every `Head`.
    """

    def __init__(self, num_of_heads, n_in, head_size, context_length):
        super().__init__()
        # Individual Single Attention Heads.
        # They must live in an nn.ModuleList: modules kept in a plain Python list are not
        # registered, so their weights would be missing from .parameters() (never trained),
        # from .state_dict() (never saved) and would not follow .to(device).
        self.attn_heads = nn.ModuleList(
            [Head(n_in, head_size, context_length) for _ in range(num_of_heads)]
        )
        # The Final Linear Layer: To project to the desired output size/space.
        # Its input is the concatenation of all heads, i.e. num_of_heads * head_size wide
        # (equal to n_in in the usual head_size = n_in // num_of_heads setup).
        self.out_proj = nn.Linear(num_of_heads * head_size, n_in)

    def forward(self, x):
        """Run every head on `x`, concatenate along the last dimension and project back to n_in."""
        # Individual Single Attention Heads
        out = [h(x) for h in self.attn_heads]
        # Concatenate the individual attention heads
        out = torch.concat(out, -1)
        # The Final Linear Layer: project back to the size of the incoming hidden_states / residual
        # so that the result can be added to it as a residual connection
        out = self.out_proj(out)
        return out
        

## SiglipVisionConfig

In [None]:
'''
@dataclass is a decorator in Python that automatically generates special methods like __init__, __repr__, and __eq__ for a class. 
This simplifies the creation of classes primarily used for storing data. 
It reduces boilerplate code and improves readability, especially when dealing with objects that mainly hold data.
'''
@dataclass
class SiglipVisionConfig:   
    """Hyper-parameters shared by the Siglip vision modules in this notebook.

    SiglipAttention reads hidden_size, num_attention_heads and attention_dropout;
    the remaining fields are not read by the code visible in this notebook section.
    """
    # Presumably the side length of the square input image and of one square patch, in pixels
    # (224 / 16 = 14 patches per side -> 14 * 14 = 196 patches, the num_patches used below) -- TODO confirm
    image_size: int = 224
    patch_size: int = 16
    # Embedding dimension of every patch/token (SiglipAttention stores it as embed_dim)
    hidden_size: int = 768

    # Presumably the number of image channels (3 = RGB) -- not used by SiglipAttention
    num_channels: int = 3
    # Number of attention heads the embedding dimension is split into (768 / 12 = 64 per head)
    num_attention_heads: int = 12
    # Dropout probability applied to the attention weights after the softmax
    attention_dropout: float = 0.0

## SiglipAttention
- Vectorized implementation of multi-head attention. Same implementation as Hugging Face (different from the non-vectorized implementation above)
- There are no separate single attention heads. All heads are processed in parallel
- This is more memory efficient
- This enables using Hugging Face's pretrained weights in our model

#### Vectorized Implementation
In the vectorized implementation there might seem to be a lot of transposes, shape changes etc. \
Here is a high-level overview of it. 
- i)   q, k, v_states have the same dimensions as hidden_states = [1,196,768] = [batch, num_patches, embedding_dimension] 
- ii)  split q, k, v_states into 12 attention heads along the embedding_dimension (768) = [1,196,12,64]
- iii) transpose q, k, v_states so that [batch, num_heads] come first and [196,64] come last, for vectorized multiplication: [1,12,196,64]
- iv)  attention filter = q*k_transpose is a square: [1,12,196,196]
- v)   scaled attention filter, softmax and dropout all keep the same square shape = [1,12,196,196]
- vi)  multiply attention_weights*v and the dimension is back to: [1,12,196,64]
- vii) transpose attention again so that the per-head dimensions [12,64] are back together for the concatenation: [1,196,12,64]
- viii) Concatenate: merge the 12 attention heads to get back the embedding dimension 768=12*64: [1,196,768]
- ix)   Project back to the residual states' shape (which happens to be the same): [1,196,768]

In [None]:
class SiglipAttention(nn.Module):
    """Vectorized multi head self-attention: all heads are computed with one set of projections.

    Args:
        config: supplies hidden_size (embedding dimension), num_attention_heads and
            attention_dropout.

    Raises:
        ValueError: if num_attention_heads is not positive or hidden_size is not divisible by it.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dropout = config.attention_dropout

        # The embedding dimension is split evenly across the heads in forward().
        # Fail here with a clear message instead of with an opaque view() shape error later.
        if self.num_heads <= 0 or self.embed_dim % self.num_heads != 0:
            raise ValueError(
                "hidden_size must be divisible by a positive num_attention_heads "
                f"(got hidden_size={self.embed_dim}, num_attention_heads={self.num_heads})"
            )
        # Embedding dimension handled by one head (768 / 12 = 64 with the default config)
        self.head_dim = self.embed_dim // self.num_heads

        # The linear layers whose outputs will result in query, key, value respectively
        # This is just one unified set of projection layers across the multi head attention module
        # i.e. there aren't multiple projection layers defined for the single attention heads
        # The output of this will be reshaped (see def forward). This is more memory efficient and
        # enables the use of hugging face weights
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)

        # The Final Linear Layer: To project to the desired output size/space
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    # Note: the hidden states are the embeddings
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply multi head self-attention and print the tensor shape after every stage.

        Args:
            hidden_states: patch embeddings of shape (batch_size, num_patches, embed_dim).

        Returns:
            Tensor with the same shape as hidden_states.
        """
        # the hidden states are the embeddings of the patches, so (batch_size, num_patches, embed_dim)
        # Here B: batch_size = 1, T: num_tokens=num_patches = 196, C: embedding_dimension = 768
        B, T, C = hidden_states.shape
        print("\n-----------SiglipAttention: forward details-------------")
        print("----------------------------------------------------------------------------------------------------------")
        print("Notice the dimensions at each stage to understand the vectorized implementation of multi head attention")
        print("----------------------------------------------------------------------------------------------------------")
        print("B (batch)              :", B)
        print("T (#tokens = #patches) :", T)
        print("C (embed_dim)          :", C)

        # Query, Key and Values
        # These are the q, k, v values across all attention heads. There are 3 large vectors
        # vectorized dot-product(mat-mul) of hidden_states =[1, 196, 768] & q_proj = [768,768]
        # q,k,v_states: shape = [196x768]*[768x768] = [196x768] across all batches = [1,196,768]
        q_states = self.q_proj(hidden_states)
        k_states = self.k_proj(hidden_states)
        v_states = self.v_proj(hidden_states)
        print("\n\n-------------------------------------------------------------------------------------------------")
        print("q_states, k_states, v_states: output of the linear layers")
        print("-------------------------------------------------------------------------------------------------")
        print("q_states.shape =", q_states.shape)
        print("k_states.shape =", k_states.shape)
        print("v_states.shape =", v_states.shape) 

        # Reshape Multi Head Attention into 12 units
        # Divide the Q, K, V vectors across all attention heads: each head gets head_dim = C // num_heads values
        # We split along the embedding dimension 768 (and not along the number of patches = 196)
        # Hence qualitatively this is like splitting the large embedding vector into smaller pieces corresponding to the 12 individual heads
        # This step is what we do differently from the non vectorized version: the projections already
        # produce the "concatenated" 768 wide vectors, and here we split 768 into 12*64
        # so q,k,v_states: ([1, 196, 768]) --> becomes torch.Size([1, 196, 12, 64])
        q_states = q_states.view(B, T, self.num_heads, self.head_dim)
        k_states = k_states.view(B, T, self.num_heads, self.head_dim)
        v_states = v_states.view(B, T, self.num_heads, self.head_dim)
        print("\n\n-------------------------------------------------------------------------------------------------")
        print("q_states, k_states, v_states: After splitting 768 to 12 attention heads")
        print("-------------------------------------------------------------------------------------------------")
        print("q_states.shape =", q_states.shape)
        print("k_states.shape =", k_states.shape)
        print("v_states.shape =", v_states.shape) 

        # Transpose the states so that the dot product dimensions i.e. [196,64] are the last 2 dimensions.
        # first 2 dimensions are batch_size and num_heads. This enables easy vectorized mat-mul
        # q_states.transpose(1, 2) swaps dimension-index-1 with dimension-index-2 i.e. 2nd and 3rd dimension
        # so q,k,v_states [1,196,12,64] ---> become [1,12,196,64]
        q_states = q_states.transpose(1, 2)
        k_states = k_states.transpose(1, 2)
        v_states = v_states.transpose(1, 2)
        print("\n\n-------------------------------------------------------------------------------------------------")
        print("q_states, k_states, v_states: transpose for easy vectorized dotproduct(mat-mul)")
        print("so that first 2 dimensions are [batch_size, num_heads] = [1,12]")
        print("last 2 dimensions are [num_patches, embed_dim/num_heads]= [196,64]")
        print("-------------------------------------------------------------------------------------------------")
        print("q_states.shape =", q_states.shape)
        print("k_states.shape =", k_states.shape)
        print("v_states.shape =", v_states.shape) 

        # The scale dk = 64
        dk = k_states.size(-1)
        print("\n\n-------------------------------------------------------------------------------------------------")
        print("Scale is the embedding dimension per single attention head = 768/12 = 64")
        print("Note: in the dk scale diagram in the Readme.md (section 5.4) , \
        \n - you multiply [500,50] x[50,500] and the scale dk = 500. \
        \n   dk = 500 seems to be #of patches instead of the embedding size .(This could have been a mistake) \
        \n - But here in the code, you multiply [196,64] x[64,196] for the attention filter. \
        \n  dk = 64 = embedding dimension. I think embedding dimension makes more sense" )
        print("-------------------------------------------------------------------------------------------------")
        print("dk: scale = ", dk)

        # Attention-Filter = Q*Ktranspose (this is a square)
        # k_states.transpose(-2, -1) swaps dimension-index-2 with dimension-index-1 i.e. last two dimensions
        # vectorized dot-product(mat-mul) of q=[1,12,196,64] & k_transpose = [1,12,64,196]
        # attn: shape = [196x64]*[64x196] = [196x196] across all 12 heads = [1,12,196,196]
        attn = q_states @ k_states.transpose(-2, -1)

        # Scaled attention filter: attn shape = [1,12,196,196]
        attn = attn/math.sqrt(dk)
        print("\n\n-------------------------------------------------------------------------------------------------")
        print(" In all the below cases attn: shape is a square i.e 196*196)")
        print("-------------------------------------------------------------------------------------------------")
        print("i)  attn.shape: scaled attention filter    : ", attn.shape)

        # Softmax rowwise to get probability distribution: attn shape = [1,12,196,196]
        attn = F.softmax(attn, dim=-1).to(q_states.dtype)
        print("\n\n-------------------------------------------------------------------------------------------------")
        print("Apply softmax. Since attn is [196, 196] it makes no sense to normalize for the entire square \n \
        Apply softmax to get probability distribution along dimension -1 which means across the columns i.e. rowwise")
        print("-------------------------------------------------------------------------------------------------")
        print("ii)  attn.shape: after softmax              : ", attn.shape)

        # Dropout: attn shape = [1,12,196,196]
        # It acts on the attention weights, i.e. before they are multiplied with v_states, so in
        # training mode it randomly zeroes individual token-to-token weights.
        # With the default attention_dropout = 0.0 (or in eval mode) it leaves attn unchanged.
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        print("iii) attn.shape: after dropout              : ", attn.shape)

        # Weighted Sum: allows information flow between tokens(patches)
        # vectorized dot-product(mat-mul) of attn=[1,12,196,196] & v_states = [1,12,196,64]
        # attn: shape = [196x196]*[196x64] = [196x64] across all 12 heads = [1,12,196,64]
        attn = attn @ v_states
        print("\n\n-------------------------------------------------------------------------------------------------")
        print("Multiply attn with v_states. Attention is back to [196,64]")
        print("-------------------------------------------------------------------------------------------------")
        print("iv) attn.shape: mult with v_states         : ", attn.shape)

        # Transpose it back to the original q_states view where [12,64] are the end
        # so that they can be fused to make the original 12*64 = 768
        # attn.transpose(1, 2) swaps dimension-index-1 with dimension-index-2 i.e. 2nd and 3rd dimension
        # attn [1,12,196,64] ---> becomes [1,196,12,64]
        attn = attn.transpose(1, 2)
        print("\n\n-------------------------------------------------------------------------------------------------")
        print("Transpose it back to the original q_states view where [12,64] are the end")
        print("[12,64] can be easily multiplied to recover the original embed_dim = 768")
        print("-------------------------------------------------------------------------------------------------")
        print("v) attn.shape: after transpose            : ", attn.shape)

        # i) This reshaping concatenates the 12*64 outputs from 12 attention heads back to the embedding_dim = 768
        # ii) The transpose above only changes strides, so the tensor is generally no longer contiguous.
        #     reshape() copies the data whenever a plain view is impossible; the trailing
        #     .contiguous() is then a no-op safeguard that guarantees a contiguous layout.
        # attn [1,196,12,64] ---> becomes [1,196,768]
        attn = attn.reshape(B, T, C).contiguous()
        print("\n\n-------------------------------------------------------------------------------------------------")
        print("Concatenate 12 attention heads so that [12x64] = 768")
        print("-------------------------------------------------------------------------------------------------")
        print("vi) attn.shape: after reshaping to B, T, C : ", attn.shape)

        # The Final Linear Layer: To project to the desired output size/space
        # vectorized dot-product(mat-mul) of attn =[1, 196, 768] & out_proj = [768,768]
        # attn: shape = [196x768]*[768x768] = [196x768] across all batches = [1,196,768]
        attn = self.out_proj(attn)
        print("vii) attn.shape: after out_proj             : ", attn.shape)
        print("\n\n------- End of SiglipAttention: Forward -------------- \n")

        return attn

## Create random Hidden State/ Embeddings and pass through Attention Head

In [None]:
# Sizes of the random demo input: one image, 196 patch embeddings of width 768
batch_size, num_patches = 1, 196
hidden_size = 768  # same as embedding dimension
num_attention_heads = 12
attention_dropout = 0.0

config = SiglipVisionConfig(
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    attention_dropout=attention_dropout,
)

# Draw the random embeddings before building the module so the order of random draws
# (input first, weight initialisation second) stays the same
hidden_states = torch.randn(batch_size, num_patches, hidden_size)
attention = SiglipAttention(config)

# Bare expression: the notebook displays the module structure
attention

### Sanity Check #1
The input shape and output shape match, i.e. the incoming hidden states and the output hidden states/residual should have the same shape by design. See the Siglip Transformer Encoder Architecture for a better understanding.

In [None]:
# One forward pass over the random embeddings; the module is expected to keep the input shape
output = attention(hidden_states)

print(f"Input shape: {hidden_states.shape}", f"Output shape: {output.shape}", sep="\n")

## Compare Our Output vs HF Output
HF Output uses the default Vision Embeddings. \
Our Output uses our custom Vision Embeddings. \
There is also a mapping between hf_state_dict and our_state_dict.

In [None]:
from transformers import AutoProcessor
from transformers import SiglipVisionConfig as HFSiglipVisionConfig
from transformers import SiglipVisionModel as HFSiglipVisionModel

# Load the Hugging Face processor and vision model for the same checkpoint
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
hf_vision_model = HFSiglipVisionModel.from_pretrained(
    "google/siglip-base-patch16-224",
    config=HFSiglipVisionConfig(vision_use_head=False),
)

# Embedding entries of the HF state dict, with the "vision_model.embeddings." prefix stripped.
# NOTE(review): the next cell rebinds hf_state_dict to the full vision_model state dict.
hf_state_dict = {
    name.replace("vision_model.embeddings.", ""): tensor
    for name, tensor in hf_vision_model.state_dict().items()
    if "vision_model.embeddings." in name
}

# Bare expression: the notebook displays the HF model structure
hf_vision_model

In [None]:
# State dict of the inner vision model, so the keys start at "embeddings." / "encoder." ...
hf_state_dict = hf_vision_model.vision_model.state_dict()
print("\n --------all the key names in hf_state_dict ---------\n")
# Only the names are printed, so iterate over the keys instead of unpacking unused values
for k in hf_state_dict:
    print("key:", k)

In [None]:
# State dict of our SiglipAttention module (q/k/v/out projection weights and biases)
our_state_dict = attention.state_dict()
print("\n --------all the key names in our_state_dict ---------\n")
# Only the names are printed, so iterate over the keys instead of unpacking unused values
for k in our_state_dict:
    print("key:", k)

## Sanity Check #2
- We have not implemented the entire Siglip : Transformer Image encoder yet
- Hence compare the output from our SiglipAttention vs HF Vision Model's Zeroth Encoder's multi head attention module

If they match 
- it means that we were able to load the weights from the multi head attention module of Hugging Face Siglip's zeroth encoder layer successfully into our SiglipAttention module
- it means that our def forward implementation of SiglipAttention is correct / the same as the def forward of Hugging Face Siglip Attention

In [None]:
# This is the key mapping between our model(on the left) and the Hugging face model keys on the right
# Notice that our model just does not have the encoder.layers.0.self_attn. part in the name.
# This is because we have not yet implemented the full encoder. We have only implemented the SiglipAttention
# Key mapping: our model's key -> Hugging Face model's key, e.g.
#   'k_proj.weight' -> 'encoder.layers.0.self_attn.k_proj.weight'
# Our keys simply lack the 'encoder.layers.0.self_attn.' prefix, because we have not yet
# implemented the full encoder. We have only implemented the SiglipAttention
key_mapping = {
    f"{layer}.{param}": f"encoder.layers.0.self_attn.{layer}.{param}"
    for layer in ("k_proj", "v_proj", "q_proj", "out_proj")
    for param in ("weight", "bias")
}

# Overwrite our tensors in place with the HF weights, then load them into the module
for our_key, hf_key in key_mapping.items():
    our_state_dict[our_key].copy_(hf_state_dict[hf_key])

attention.load_state_dict(our_state_dict)

# Same input through both attention modules; no gradients are needed for the comparison
with torch.no_grad():
    our_output = attention(hidden_states)
    hf_output = hf_vision_model.vision_model.encoder.layers[0].self_attn(hidden_states)[0]
    max_diff = (our_output - hf_output).abs().max()
    print(f"\n Max difference between our output and HF output: {max_diff:.6f}")
    # Number of elements that are NOT close to the HF output (0 means a full match)
    print((~torch.isclose(our_output, hf_output, atol=1e-6)).sum())