In [None]:
'''
# Do this once. 
!curl -L http://i.imgur.com/8o9DXSj.jpeg --output image.jpg 
# Make sure to restart your runtime before running again
!pip install transformers
!pip install transformers[sentencepiece]
!pip install sentencepiece
'''

## Import the Pretrained SiglipVisionModel from Hugging Face
- This model will be imported as SiglipVisionModel
- from_pretrained means the entire Model with pretrained weights from Hugging Face
- "google/siglip-base-patch16-224": Is the model checkpoint. patch16 means 16x16 patches. 224 means it uses a 224x224 image as input

Print the hf_vision_model at the end. It should have all the layers in the SIGLIP : VISION TRANSFORMER ARCHITECTURE DIAGRAM specified in the Readme.Md
- **i) The Embeddings:** with Patch Embeddings and Position Embeddings
- **ii) Encoder :** with 12x Single Encoder layers. Each Encoder layer will have layer_norm1, self_attention, layer_norm2, mlp. \
  Each self_attn (multi-head attention) block will have K, Q, V and an out_proj layer (the final linear layer after the concatenation). \
  Each MLP will have fc1, Gelu and fc2
- **iii) Post Layer Norm**

In [None]:
from transformers import SiglipVisionModel, SiglipVisionConfig
# Hub identifier of the checkpoint to load.
model_checkpoint = "google/siglip-base-patch16-224"

# Build the vision config as its own step, then pass it to from_pretrained
# together with the checkpoint name.
vision_config = SiglipVisionConfig(vision_use_head=False)
vision_model = SiglipVisionModel.from_pretrained(
    model_checkpoint,
    config=vision_config,
)

# Bare last expression: the notebook renders the model's module tree.
vision_model

## Input Image + Preprocess Image
The model cannot accept the image as is. 
- It has to be resized to 224x224
- It has to be converted to a tensor
- It has to be normalized: these numbers come from the Imagenet dataset (industry standard)
- Unsqueeze the tensor to include the batch dimension so that the transformer model can use it (in this case batch dimension is 1). (3,224,224) --> unsqueeze -->(1,3,224,224)

In [None]:
from PIL import Image

# Open the sample picture that sits next to the notebook.
image_path = "image.jpg"
img = Image.open(image_path)

# Bare last expression: the notebook renders the picture inline.
img

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from dataclasses import dataclass
from torchvision import transforms

def preprocess_image(image, image_size=224,
                     mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225)):
    """Turn a PIL image into a normalized, batched tensor.

    Steps: force 3-channel RGB, resize to ``image_size`` x ``image_size``,
    convert to a tensor, normalize per channel, and add a leading batch
    dimension.

    Args:
        image: A PIL image (any mode; it is converted to RGB if needed).
        image_size: Side length, in pixels, of the square output image.
        mean: Per-channel means passed to ``transforms.Normalize``.
            Defaults to the ImageNet values used originally.
        std: Per-channel standard deviations passed to
            ``transforms.Normalize``. Defaults to the ImageNet values.

    Returns:
        A tensor of shape ``(1, 3, image_size, image_size)``.
    """
    # Normalize below is configured with three channel statistics, so a
    # grayscale / RGBA / CMYK input would fail there. Convert up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # define the preprocess operation
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])

    # actually preprocess the image
    image_tensor = preprocess(image)
    # (3, image_size, image_size) --> unsqueeze --> (1, 3, image_size, image_size)
    image_tensor = image_tensor.unsqueeze(0)
    return image_tensor

image_tensor = preprocess_image(img)

# Patches and embeddings

## Patch embeddings
details in Readme.md

In [None]:
# Embedding parameters. These are the parameters used by PaliGemma2 model
# embed_dim means that each patch will be converted to a vector of dimension = 768 (that is the embedding output)
embed_dim = 768
patch_size = 16
image_size = 224
# Number of patches along one side of the image; the grid is square.
patches_per_side = image_size // patch_size
num_patches = patches_per_side ** 2

## Patch embedding
# torch.no_grad means no gradients are tracked for the operations below:
# we are not going to update the weights of the convolution filter.
# A nn.Conv2d filter with random weights is created and the patch
# embeddings are calculated using this filter.
with torch.no_grad():
    # i) input = image_tensor (used directly; the builtin `input` is not shadowed)
    # ii) layer = patch_embedding_filter
    # This is like a mini __init__
    patch_embedding_filter = nn.Conv2d(in_channels=3,
                                       out_channels=embed_dim,
                                       kernel_size=patch_size,
                                       stride=patch_size)
    # iii) output = patch_embeddings
    # This one is like a mini forward
    patch_embeddings = patch_embedding_filter(image_tensor)

# Flatten the patches
# After flattening: (1, embed_dim, num_patches)
flattened_patch_embeddings = patch_embeddings.flatten(start_dim=2, end_dim=-1)
# (1, embed_dim, num_patches) -> (1, num_patches, embed_dim)
flattened_patch_embeddings = flattened_patch_embeddings.transpose(1, 2)

print("------ PATCH EMBEDDINGS -------")
# The summary is derived from the parameters above so it stays correct if
# image_size, patch_size or embed_dim are changed.
print(f" The following would show there are {patches_per_side} patches on the height "
      f"& {patches_per_side} on the width. "
      f"Each patch has been converted to a vector of {embed_dim}. "
      f"Total number of patches = {patches_per_side}x{patches_per_side} = {num_patches} \n")
print("num_patches   =", num_patches)
print("i)   input : image_tensor.shape : ", image_tensor.shape)
print("ii)  layer : patch_embedding_filter  :", patch_embedding_filter)
print("iii) output: patch_embeddings.shape : ", patch_embeddings.shape)
print("iii) output: flattened_patch_embeddings.shape : ", flattened_patch_embeddings.shape)


## Position Embeddings
details in Readme.md

In [None]:
## Position Embeddings
# Exercise for the reader: find out why torch.no_grad() is not used here.

# i) input = position_ids.
# Notice that there is no image_tensor involved for position_embeddings.
# It is just a lookup based on the input position ids.
# expand((1, -1)) adds a leading batch dimension of size 1: (num_patches,) -> (1, num_patches)
position_ids = torch.arange(num_patches).expand((1, -1))
# ii) layer = position_embedding_lookup (one embed_dim-sized row per position id)
position_embedding_lookup = nn.Embedding(num_patches, embed_dim)
# iii) output = position_embeddings
position_embeddings = position_embedding_lookup(position_ids)

print("\n------ POSITION EMBEDDINGS -------")
print("i)   input : position_ids.shape : ", position_ids.shape)
print("ii)  layer : position_embedding_lookup :", position_embedding_lookup)
print("iii) output: position_embeddings.shape : ", position_embeddings.shape)
print("\n")
print("i)   input : the list of position_ids \n", position_ids)

## Total Embeddings (patch & position embeddings)

In [None]:
## Total Embeddings
# Element-wise sum of the per-patch content vectors and the per-position vectors.
embeddings = torch.add(flattened_patch_embeddings, position_embeddings)

print("\n------ BOTH EMBEDDINGS : patch and position embeddings -------")
print("embeddings.shape :", embeddings.shape)

## Visualize Embeddings : Before Training
details in Readme.md

embeddings[0]: the index [0] selects along the batch dimension, i.e. it picks the first image in the batch. Since there is only one image in the batch, [0] is the only valid index.

In [None]:
import matplotlib.pyplot as plt

def visualize_embeddings(embeds_viz, title):
    """Show a 2-D array of embeddings as a heat-map.

    Args:
        embeds_viz: 2-D array-like; rows are drawn on the y-axis
            ("Patch number") and columns on the x-axis ("Embedding dimension").
        title: Text placed above the plot.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    heatmap = ax.imshow(embeds_viz, aspect='auto', cmap='viridis')
    fig.colorbar(heatmap, ax=ax)

    ax.set_title(title)
    ax.set_xlabel('Embedding dimension')
    ax.set_ylabel('Patch number')
    plt.show()

# Visualize the embeddings before training.
# Each entry: (heading, explanatory note, batched tensor, shape label, plot title).
before_training_views = [
    (
        "Flattened Patch Embeddings: Before Training",
        "They should look all random, since the weights of the conv2d filter are random at the initialization",
        flattened_patch_embeddings,
        "flattened_embeds_viz.shape =",
        "Flattened Patch Embeddings: Before Training",
    ),
    (
        "\n\nPosition Embeddings: Before Training",
        "They should look all random, since the weights of nn.Embedding lookup are random at the initialization",
        position_embeddings,
        "position_embeds_viz.shape =",
        "Position Embeddings: Before Training",
    ),
    (
        "\n\nEmbeddings(both, flattened_patch + position): Before Training",
        "They should look all random, since the weights are random at the initialization. \n"
        "Notice total_embeddings look pretty similar to patch_embeddings, despite adding position_embeddings.\n"
        "This is likely because position_embeddings are supposed to be small displacement vectors. \n"
        "And, that the change they have caused is not visible in such a visualization.",
        embeddings,
        "total_embeds_viz.shape =",
        "Embeddings(both,flattened_patch + position): Before Training",
    ),
]

for heading, remark, batched, shape_label, plot_title in before_training_views:
    print(heading)
    print(remark)
    # [0] drops the batch dimension; detach() is needed before numpy() for
    # tensors that track gradients.
    embeds_viz = batched[0].detach().numpy()
    print(shape_label, embeds_viz.shape)
    visualize_embeddings(embeds_viz, plot_title)

## Visualize Embeddings : After Training
- Note that the nn.Conv2d filter and the nn.Embedding lookup table used above to create the untrained patch and position embeddings have not been trained, i.e. there is no model training or embeddings training step in **vit_step1_img_prepocess_embeddings.ipynb**
- Instead, download the pre-trained SiglipVisionModel from Hugging Face and visualize the trained embeddings from this model.
- details in Readme.md

trained_total_embeddings[0]: the index [0] selects along the batch dimension, i.e. it picks the first image in the batch. Since there is only one image in the batch, [0] is the only valid index.

In [None]:
from transformers import AutoProcessor
# Load the processor that matches the checkpoint and use it to turn the PIL
# image into the pixel_values tensor the pretrained model expects.
processor = AutoProcessor.from_pretrained(model_checkpoint)

vision_model.eval()
inputs = processor(images=img, return_tensors="pt")

with torch.no_grad():
    # Raw output of the pretrained patch convolution: (1, embed_dim, H_patches, W_patches)
    trained_patch_embeddings = vision_model.vision_model.embeddings.patch_embedding(inputs.pixel_values)
    # Pretrained position lookup, queried with the position_ids built earlier in the notebook
    trained_position_embeddings = vision_model.vision_model.embeddings.position_embedding(position_ids)
    # Full embeddings module: patch + position, already (1, num_patches, embed_dim)
    trained_total_embeddings = vision_model.vision_model.embeddings(inputs.pixel_values)


# Flatten the patches: (1, embed_dim, H, W) -> (1, embed_dim, num_patches)
# then transpose    : (1, embed_dim, num_patches) -> (1, num_patches, embed_dim)
# The patch_embedding sub-layer is the bare convolution, so its output is NOT
# transposed for us; only the full embeddings module does the flatten + transpose
# internally. Without the transpose the heat-map below would be plotted with
# patches and embedding dimensions swapped relative to its axis labels.
trained_flattened_patch_embeddings = (
    trained_patch_embeddings.flatten(start_dim=2, end_dim=-1).transpose(1, 2)
)

print("trained_patch_embeddings.shape           : ", trained_patch_embeddings.shape)
print("trained_flattened_patch_embeddings.shape : ", trained_flattened_patch_embeddings.shape)
print("trained_position_embeddings.shape        : ", trained_position_embeddings.shape)
print("trained_total_embeddings.shape           : ", trained_total_embeddings.shape)

# Visualize the embeddings after training
print("\n\nFlattened Patch Embeddings: After Training i.e. from pretrained Hugging Face SiglipVision model")
embeds_viz = trained_flattened_patch_embeddings[0].detach().numpy()  # (num_patches, embed_dim)
print("flattened_embeds_viz.shape =", embeds_viz.shape)
visualize_embeddings(embeds_viz, "Flattened Patch Embeddings: After Training")

print("\n\nPosition Embeddings: After Training i.e. from pretrained Hugging Face SiglipVision model")
embeds_viz = trained_position_embeddings[0].detach().numpy()  # (num_patches, embed_dim)
print("position_embeds_viz.shape =", embeds_viz.shape)
visualize_embeddings(embeds_viz, "Position Embeddings: After Training")

print("\n\nEmbeddings(both, flattened_patch + position): After Training i.e. from pretrained Hugging Face SiglipVision model")
embeds_viz = trained_total_embeddings[0].detach().numpy()  # (num_patches, embed_dim)
# Label matches the "before training" cell so the two outputs can be compared line by line.
print("total_embeds_viz.shape =", embeds_viz.shape)
visualize_embeddings(embeds_viz, "Embeddings(both,flattened_patch + position): After Training")