# Creating a VLM
To build a __Visual Language Model__ (VLM), we need to handle both textual and visual inputs. Text is processed by a Large Language Model (LLM), and images by a CLIP model (specifically SigLIP, a CLIP variant). Note that CLIP also contains a text encoder, which we don't need.

## Load the components

```python
# Global constants for LLM
llm_model_id = "google/gemma-2b-it"
llm_model_folder = f"./models/{llm_model_id.split('/')[-1]}"

# Global constants for CLIP model
clip_model_id = "google/siglip-base-patch16-384"
clip_model_folder = f"./models/{clip_model_id.split('/')[-1]}"
```

```python
from transformers import AutoProcessor, AutoModel
from transformers import AutoTokenizer, AutoModelForCausalLM

llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_folder, local_files_only=True)
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_folder, local_files_only=True)

clip_model = AutoModel.from_pretrained(clip_model_folder, local_files_only=True)
clip_processor = AutoProcessor.from_pretrained(clip_model_folder, local_files_only=True)

# Get the vision model and the image processor from CLIP
clip_vision_model = clip_model.vision_model
image_processor = clip_processor.image_processor
```



## Adapt the components

CLIP produces a single embedding for the entire image. That is not what we want: we prefer to keep a sequence of image embeddings, one per patch, because it preserves much more detail than compressing the whole image into one vector.

In the vision encoder of SigLIP, an image goes through the following steps:

1. __Preprocessing of the Image__:
  - The image processor resizes the input image to `image_size` = 384x384 pixels and rescales and normalizes its pixel values. Conceptually, the image is then divided into non-overlapping patches of `patch_size` = 16x16 pixels, which gives a grid of 24x24 = 576 patches. In practice this split happens inside the patch embedding convolution, whose stride equals the patch size.

2. __Patch Embedding__:
  - A convolutional layer (`Conv2d`) with kernel size and stride both equal to 16 acts as the patch embedding layer. Because each kernel application covers exactly one patch and patches don't overlap, this is equivalent to flattening every 16x16x3 patch and applying the same linear projection to it. Each patch becomes a 768-dimensional vector (`hidden_size` = 768), a format suitable for the subsequent transformer layers.

3. __Position Embedding__:
  - Learned positional embeddings are added to the patch embeddings so that the model retains the original position of each patch within the image. The `position_embedding` component is a lookup table with one 768-dimensional vector per patch position (576 in total).

4. __SiglipEncoder Processing__:
  - The embedded patches, now augmented with positional information, are passed through the `SiglipEncoder`. This encoder stacks `num_hidden_layers` = 12 identical layers, each comprising:
    - __Self-Attention Mechanism__: a `SiglipAttention` module (`num_attention_heads` = 12) computes self-attention over all patch embeddings, allowing the model to weigh the importance of each patch based on its content and its relation to the other patches.
    - __Feed-Forward Network (MLP)__: a two-layer MLP (768 → 3072 → 768, `intermediate_size` = 3072) with a tanh-approximated GELU activation (`hidden_act = "gelu_pytorch_tanh"`) between the two linear layers, which processes each patch's information non-linearly.
    - __Layer Normalization and Residual Connections__: layer normalization (`layer_norm_eps = 1e-06`) is applied *before* the attention block and *before* the MLP (a pre-norm layout), and each block's output is added back to its input through a residual connection. This stabilizes training and improves convergence.

5. __Post Layer Normalization__:
  - After the encoder, a final layer normalization (`post_layernorm`) is applied to the output of the last encoder layer so that the features are normalized before they reach the pooling head.

6. __Pooling Head__:
  - The `SiglipMultiheadAttentionPoolingHead` condenses the features of all patches into a single vector representing the whole image. A learned query vector (the "probe") attends over all patch embeddings through multi-head attention, and the result then goes through a layer normalization and an MLP with a residual connection.
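The patch arithmetic of step 1 and the patch embedding of steps 2 and 3 can be checked with plain PyTorch. The sketch below uses a randomly initialized `Conv2d` and `Embedding` with the same shapes as SigLIP's `patch_embedding` and `position_embedding` (not the pretrained weights) and verifies that the convolution computes the same thing as cutting the image into 16x16 patches, flattening each one and applying a single linear projection. The `unfold`-based patch extraction is our own illustration, not how Transformers implements it.

```python
import torch
from torch import nn

# Shapes taken from the SigLIP config printed later in this notebook
image_size, patch_size, channels, hidden_size = 384, 16, 3, 768

grid = image_size // patch_size   # 24 patches per side (stride == patch size, no overlap)
num_patches = grid * grid         # 24 * 24 = 576 patches

torch.manual_seed(0)
patch_embedding = nn.Conv2d(channels, hidden_size, kernel_size=patch_size, stride=patch_size)
position_embedding = nn.Embedding(num_patches, hidden_size)  # one learned vector per position
pixel_values = torch.randn(1, channels, image_size, image_size)

with torch.no_grad():
    # Path 1: convolution, then flatten the 24x24 grid into a sequence of 576 vectors
    conv_out = patch_embedding(pixel_values).flatten(2).transpose(1, 2)  # (1, 576, 768)

    # Path 2: cut the image into patches, flatten each patch, apply one linear map
    patches = pixel_values.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    # (1, C, 24, 24, 16, 16) -> (1, 24, 24, C, 16, 16) -> (1, 576, C*16*16)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, num_patches, -1)
    weight = patch_embedding.weight.reshape(hidden_size, -1)   # (768, C*16*16)
    linear_out = patches @ weight.T + patch_embedding.bias     # (1, 576, 768)

    # Step 3: add the positional embeddings, indexed by patch position 0..575
    embeddings = conv_out + position_embedding(torch.arange(num_patches))

print(grid, num_patches, tuple(embeddings.shape))  # 24 576 (1, 576, 768)
```

The two paths agree up to floating-point rounding, which is why a strided convolution is the usual way to implement a patch embedding.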


After the input has passed through all the encoder layers, it has become a sequence of embeddings, one per patch of the original image. Each one is now a deeply processed, high-dimensional feature vector that encodes both the content of its patch and its contextual relationships with the other patches in the image.
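To make the layer structure of step 4 concrete, here is a simplified sketch of one pre-norm encoder layer with SigLIP-base's sizes. It is our own approximation, not the Transformers code: it uses PyTorch's `nn.MultiheadAttention` (a packed q/k/v projection) instead of `SiglipAttention`'s separate `q_proj`/`k_proj`/`v_proj`, which has the same number of parameters. Stacking such layers maps a `(batch, 576, 768)` sequence to a sequence of the same shape, which is exactly the per-patch output we want to keep.

```python
import torch
from torch import nn


class PreNormEncoderLayer(nn.Module):
    """Simplified sketch of one pre-norm transformer encoder layer (not the HF implementation)."""

    def __init__(self, hidden_size=768, num_heads=12, intermediate_size=3072, eps=1e-6):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=eps)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(approximate="tanh"),  # tanh approximation of GELU
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, x):
        # Attention block: normalize first, attend, then add the residual
        h = self.layer_norm1(x)
        attn_out, _ = self.self_attn(h, h, h, need_weights=False)
        x = x + attn_out
        # MLP block: same pre-norm + residual pattern
        return x + self.mlp(self.layer_norm2(x))


torch.manual_seed(0)
# Two layers keep the demo fast; SigLIP-base stacks 12
layers = nn.ModuleList([PreNormEncoderLayer() for _ in range(2)])
out = torch.randn(1, 576, 768)  # stand-in for the patch embeddings
with torch.no_grad():
    for layer in layers:
        out = layer(out)

print(tuple(out.shape))  # (1, 576, 768): one vector per patch, shape unchanged
```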

At the end of stage 5, we have a sequence of 576 embedding vectors of 768 dimensions each, which together describe the processed image. CLIP, however, reduces this sequence to a single embedding vector so that it can be compared with text embeddings. This happens in stage 6, where an attention layer followed by a Multilayer Perceptron (MLP) aggregates the sequence into one vector that captures the core attributes of the image. Because the attention uses a single learned query, its output contains only one vector, which is returned as the embedding of the whole image.
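The pooling of stage 6 can be sketched as follows. This is a simplified re-implementation modeled on `SiglipMultiheadAttentionPoolingHead` from the SigLIP modeling code in Transformers, using `nn.MultiheadAttention` and random weights; it is meant to show the shapes, not to reproduce pretrained outputs. A single learned probe vector is the only query, so the attention output has one position per image, and that position becomes the image embedding.

```python
import torch
from torch import nn


class AttentionPoolingHead(nn.Module):
    """Sketch of multi-head attention pooling with one learned probe query."""

    def __init__(self, hidden_size=768, num_heads=12, intermediate_size=3072, eps=1e-6):
        super().__init__()
        # One learned query vector shared by all images
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(approximate="tanh"),
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)  # (B, 1, D)
        # The probe attends over all patch embeddings -> a sequence of length 1
        pooled, _ = self.attention(probe, hidden_state, hidden_state, need_weights=False)
        pooled = pooled + self.mlp(self.layernorm(pooled))  # norm + MLP with residual
        return pooled[:, 0]  # (B, D): the single pooled vector per image


torch.manual_seed(0)
head = AttentionPoolingHead()
sequence = torch.randn(2, 576, 768)  # two images, 576 patch embeddings each
with torch.no_grad():
    pooled = head(sequence)

print(tuple(pooled.shape))  # (2, 768)
```

This is exactly the information loss we want to avoid: 576 vectors per image go in, one comes out.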

To keep the whole sequence of embeddings, we can ask the Hugging Face model for its intermediate outputs: just pass `output_hidden_states=True` when calling the model. It then also returns the hidden states at the output of the embedding layer and at the output of each encoder layer.

The catch is that the returned tuple doesn't say which module produced each entry. To find out, you have to read the model implementation in Hugging Face Transformers and trace `output_hidden_states` through all the submodules that use it. For SigLIP the code is [here](https://github.com/huggingface/transformers/blob/df1542581ee89107eea0569ee044fa8797b66ab0/src/transformers/models/siglip/modeling_siglip.py). It returns 13 hidden states: the output of the input embeddings plus the outputs of the 12 `SiglipEncoder` layers. These are taken before `post_layernorm`, so the last one is not the same tensor as `last_hidden_state`.
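The bookkeeping behind `output_hidden_states` can be reproduced with a toy encoder. `ToyEncoder` below is a hypothetical stand-in made of small `nn.Linear` layers; it follows the same pattern as `SiglipEncoder`, recording the input of every layer and then the output of the last one. With 12 layers this yields 13 entries, where index 0 is the embedding output and index `i` is the output of encoder layer `i`.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical encoder that collects hidden states the way SiglipEncoder does."""

    def __init__(self, num_layers=12, dim=8):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, inputs_embeds, output_hidden_states=False):
        hidden_states = inputs_embeds
        all_hidden_states = ()
        for layer in self.layers:
            if output_hidden_states:
                # Record the *input* of each layer (index 0 = embedding output)
                all_hidden_states += (hidden_states,)
            hidden_states = layer(hidden_states)
        if output_hidden_states:
            # ...and finally the output of the last layer
            all_hidden_states += (hidden_states,)
        return hidden_states, all_hidden_states


def describe(index):
    """Name the module that produced hidden_states[index]."""
    return "embeddings" if index == 0 else f"encoder layer {index}"


torch.manual_seed(0)
encoder = ToyEncoder()
embeds = torch.randn(1, 576, 8)
with torch.no_grad():
    last, states = encoder(embeds, output_hidden_states=True)

print(len(states), describe(0), describe(len(states) - 1))  # 13 embeddings encoder layer 12
```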

```python
import requests
import torch
from PIL import Image

# Download a test image from COCO and preprocess it (resize to 384x384, rescale, normalize).
# `padding` is a text-tokenizer option, so it is not passed to the image processor.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")

# Return the intermediate outputs with output_hidden_states=True
with torch.no_grad():
    clip_vision_model_output = clip_vision_model(**inputs, output_hidden_states=True)

print(f"Vision model output has: {list(clip_vision_model_output.keys())}")

# The intermediate outputs are stored in hidden_states
hidden_states = clip_vision_model_output.hidden_states
print(f"We have {len(hidden_states)} hidden states")

# hidden_states[0] is the embedding output; hidden_states[i] is the output of encoder layer i.
# The shape is always (batch size, num. patches, embedding size): 16x16 patches of a 384x384 image => 576 patches
for idx, hidden_output in enumerate(hidden_states):
    source = "embeddings" if idx == 0 else f"encoder layer {idx}"
    print(f"hidden_states[{idx}] ({source}): {tuple(hidden_output.shape)}")
```

```
Vision model output has: ['last_hidden_state', 'pooler_output', 'hidden_states']
We have 13 hidden states
hidden_states[0] (embeddings): (1, 576, 768)
hidden_states[1] (encoder layer 1): (1, 576, 768)
hidden_states[2] (encoder layer 2): (1, 576, 768)
hidden_states[3] (encoder layer 3): (1, 576, 768)
hidden_states[4] (encoder layer 4): (1, 576, 768)
hidden_states[5] (encoder layer 5): (1, 576, 768)
hidden_states[6] (encoder layer 6): (1, 576, 768)
hidden_states[7] (encoder layer 7): (1, 576, 768)
hidden_states[8] (encoder layer 8): (1, 576, 768)
hidden_states[9] (encoder layer 9): (1, 576, 768)
hidden_states[10] (encoder layer 10): (1, 576, 768)
hidden_states[11] (encoder layer 11): (1, 576, 768)
hidden_states[12] (encoder layer 12): (1, 576, 768)
```


The downside of `output_hidden_states=True` is that it keeps every intermediate hidden state in memory, even if we only need one of them. Instead, we can build a new model that returns only the output we need. Any computation after that layer is also wasted, so we can drop those modules, which reduces both the memory used by the model and the amount of computation it performs.
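A rough estimate shows what is at stake. Assuming float32 activations (4 bytes per value) and counting only the returned hidden-state tensors, keeping all 13 of them costs 13 times as much as keeping the single one we need. The helper name `hidden_state_bytes` is ours.

```python
def hidden_state_bytes(num_states, batch_size=1, num_patches=576, hidden_size=768, bytes_per_value=4):
    """Memory taken by `num_states` hidden-state tensors of shape (batch, patches, hidden)."""
    return num_states * batch_size * num_patches * hidden_size * bytes_per_value


MIB = 1024 ** 2
for batch_size in (1, 32):
    all_states = hidden_state_bytes(13, batch_size)  # output_hidden_states=True
    one_state = hidden_state_bytes(1, batch_size)    # only the layer we need
    print(f"batch {batch_size}: {all_states / MIB:.1f} MiB vs {one_state / MIB:.1f} MiB")
```

For a batch of 32 images this prints 702.0 MiB versus 54.0 MiB, before counting any of the model's other activations.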

We can inspect the implementation of the vision encoder of SigLIP (`SiglipVisionTransformer`):
```python
class SiglipVisionTransformer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.head = SiglipMultiheadAttentionPoolingHead(config)

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        pooled_output = self.head(last_hidden_state)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
```

As we can see, it is just a PyTorch module. It consists of four components:
- `embeddings`: converts the raw pixel values of the input image into a sequence of patch embeddings (patch embedding convolution plus positional embeddings).
- `encoder`: the core of the model. It processes the sequence of embedded patches through 12 layers of self-attention and feed-forward networks.
- `post_layernorm`: a layer normalization applied to the encoder output so that the features of every patch are normalized before any further processing.
- `head`: the component that takes the output of the transformer encoder and performs a specific task. In `SiglipVisionTransformer`, it is an attention pooling head that condenses the sequence into a single vector representation of the input image.

If we want the output of the last encoder layer, we can drop `post_layernorm` and `head` and build a model whose forward pass uses only `embeddings` and `encoder`.

```python
from torch import nn


class ModifiedSiglipVisionModel(nn.Module):
    """SigLIP vision tower without `post_layernorm` and the pooling `head`."""

    def __init__(self, ref_clip_vision_model: nn.Module):
        super().__init__()
        self.config = ref_clip_vision_model.config

        # Reuse (not copy) the embedding layer and the encoder of the original model
        self.embeddings = ref_clip_vision_model.embeddings
        self.encoder = ref_clip_vision_model.encoder

    def forward(self, pixel_values):
        # Convert pixel values to embeddings
        embeddings = self.embeddings(pixel_values)

        # Process embeddings through the encoder
        encoder_outputs = self.encoder(inputs_embeds=embeddings)

        # There is only one output: last_hidden_state
        sequence_output = encoder_outputs[0]  # `encoder_outputs.last_hidden_state` is the same

        return sequence_output


# Create the vision model
vision_model = ModifiedSiglipVisionModel(ref_clip_vision_model=clip_vision_model)

# Do inference with the new model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    vision_model_output = vision_model(**inputs)

# It's the same output as the last entry of hidden_states from before
assert torch.equal(vision_model_output, hidden_states[-1])
print(f"Vision model output shape: {tuple(vision_model_output.shape)}")
```

```
Vision model output shape: (1, 576, 768)
```
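The same idea works for an intermediate layer: since `SiglipEncoder` simply loops over its `layers`, keeping only the first `k` layers makes the encoder return what `hidden_states[k]` contained. The sketch below uses a hypothetical helper `keep_first_layers` and checks the idea on toy `nn.Linear` layers rather than on SigLIP itself.

```python
import torch
from torch import nn


def keep_first_layers(layers: nn.ModuleList, num_layers: int) -> nn.ModuleList:
    """Hypothetical helper: keep only the first `num_layers` layers (shared, not copied)."""
    if not 0 < num_layers <= len(layers):
        raise ValueError(f"num_layers must be between 1 and {len(layers)}, got {num_layers}")
    return layers[:num_layers]  # slicing a ModuleList returns a new ModuleList


# Toy check: 12 small linear "layers" instead of SigLIP encoder layers
torch.manual_seed(0)
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
x = torch.randn(1, 576, 8)

with torch.no_grad():
    # Full pass, recording every intermediate result like output_hidden_states=True
    states = [x]
    h = x
    for layer in layers:
        h = layer(h)
        states.append(h)  # states[k] is the output of layer k

    # Truncated pass: only the first 11 layers
    truncated = keep_first_layers(layers, 11)
    h_truncated = x
    for layer in truncated:
        h_truncated = layer(h_truncated)

print(len(truncated), torch.equal(h_truncated, states[11]))  # 11 True
```

Applied to the model above, this would be something like `vision_model.encoder.layers = keep_first_layers(vision_model.encoder.layers, 11)` to obtain the penultimate layer's output (`hidden_states[11]`, i.e. `hidden_states[-2]`). Be careful: `ModifiedSiglipVisionModel` reuses the encoder object of `clip_vision_model` instead of copying it, so this assignment would also truncate the original model; make a `copy.deepcopy` of the encoder first if you still need the full one.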


We can print the configuration and architecture of the model. __Note that the configuration does not correspond to this new module but to the original one__.

```python
# Model configuration
print(f"Vision model configuration:\n{vision_model.config}")

# Model architecture
print(f"Model architecture:\n{vision_model}")
```

```
Vision model configuration:
SiglipVisionConfig {
  "attention_dropout": 0.0,
  "hidden_act": "gelu_pytorch_tanh",
  "hidden_size": 768,
  "image_size": 384,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-06,
  "model_type": "siglip_vision_model",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "transformers_version": "4.38.2"
}

Model architecture:
ModifiedSiglipVisionModel(
  (embeddings): SiglipVisionEmbeddings(
    (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=valid)
    (position_embedding): Embedding(576, 768)
  )
  (encoder): SiglipEncoder(
    (layers): ModuleList(
      (0-11): 12 x SiglipEncoderLayer(
        (self_attn): SiglipAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_fe
```

We can compare the number of parameters before and after pruning the model.

```python
print(f"Parameters before pruning: {sum(p.numel() for p in clip_vision_model.parameters())}")
print(f"Parameters after pruning: {sum(p.numel() for p in vision_model.parameters())}")
```

```
Parameters before pruning: 93176064
Parameters after pruning: 86087424
```
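Both numbers can be reproduced by hand from the layer shapes in the printed architecture and configuration, which also shows exactly what was removed: `post_layernorm` plus the pooling head. The head's layout (a learned probe vector, one multi-head attention block, a layer norm and an MLP) is assumed from the SigLIP modeling code in Transformers; the helper functions `linear` and `layer_norm` are ours.

```python
def linear(n_in, n_out):
    """Parameters of nn.Linear with bias."""
    return n_in * n_out + n_out


def layer_norm(dim):
    """Parameters of nn.LayerNorm (weight and bias)."""
    return 2 * dim


hidden, intermediate = 768, 3072
patch, channels, num_patches, num_layers = 16, 3, 576, 12

attention = 4 * linear(hidden, hidden)                     # q, k, v and output projections
mlp = linear(hidden, intermediate) + linear(intermediate, hidden)
encoder_layer = attention + mlp + 2 * layer_norm(hidden)   # pre-attention and pre-MLP norms

patch_embedding = channels * patch * patch * hidden + hidden  # Conv2d(3, 768, 16x16) with bias
position_embedding = num_patches * hidden                     # Embedding(576, 768)
kept = patch_embedding + position_embedding + num_layers * encoder_layer

pooling_head = hidden + attention + layer_norm(hidden) + mlp  # probe + attention + norm + MLP
removed = layer_norm(hidden) + pooling_head                   # post_layernorm + head

print(f"kept: {kept:,}  removed: {removed:,}  total: {kept + removed:,}")
# kept: 86,087,424  removed: 7,088,640  total: 93,176,064
```

The removed part is only about 7.6% of the vision tower's parameters, but it is precisely the part that collapses the patch sequence into a single vector.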
