# Creating a VLM
To build a __Visual Language Model__ (VLM), we need to handle both textual and visual inputs. Text is processed by a Large Language Model (LLM), and images by a CLIP-style vision model (here SigLIP). Note, though, that CLIP also includes a text encoder, which we do not need.

## Load the components

In [1]:
# Global constants for LLM
llm_model_id = "google/gemma-2b-it"
llm_model_folder = f"./models/{llm_model_id.split('/')[-1]}"

# Global constants for CLIP model
clip_model_id = "google/siglip-base-patch16-384"
clip_model_folder = f"./models/{clip_model_id.split('/')[-1]}"

In [2]:
from transformers import AutoProcessor, AutoModel
from transformers import AutoTokenizer, AutoModelForCausalLM

llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_folder, local_files_only=True)
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_folder, local_files_only=True)

clip_model = AutoModel.from_pretrained(clip_model_folder, local_files_only=True)
clip_processor = AutoProcessor.from_pretrained(clip_model_folder, local_files_only=True)

# Get the vision model and the image processor from CLIP
clip_vision_model = clip_model.vision_model
image_processor = clip_processor.image_processor 



## Adapt the components

CLIP typically produces a single embedding for the entire image. This is not what we want: we prefer to keep a sequence of image embeddings (one per patch), which preserves much more of the image's detail than a single condensed vector.

In the CLIP model, an image goes through this whole process:

1. __Preprocessing of the Image__:
  - The image processor resizes the input image to the specified `image_size` of 384x384 pixels and normalizes its pixel values.
    Inside the model, the image is then treated as a grid of patches of `patch_size` 16x16 pixels. Because the stride matches the patch size, the patches do not overlap, which gives a grid of 24x24 = 576 patches for a 384x384 image.

2. __Patch Embedding__:
  - A convolutional layer (`Conv2d`) whose kernel size and stride both equal the patch size acts as the patch embedding layer: it projects each 16x16x3 patch to a 768-dimensional vector (`hidden_size` = 768). The resulting 24x24 grid of vectors is then flattened into a sequence of 576 embeddings that the subsequent transformer layers can process.

3. __Position Embedding__:
  - Positional embeddings are added to the patch embeddings to retain information about the original position of each patch within the image. The `position_embedding` layer is a learned table with one 768-dimensional vector per patch position (576 in total).

4. __SiglipEncoder Processing__:
  - The embedded patches, now augmented with positional information, are passed through the `SiglipEncoder`. This encoder consists of `num_hidden_layers` = 12 identical layers, each comprising:
    - __Self-Attention Mechanism__: a `SiglipAttention` module computes self-attention over the patch embeddings, allowing the model to weigh each patch according to its content and its relation to the other patches.
    - __Intermediate Feed-Forward Network (MLP)__: after attention, each patch embedding passes through a two-layer feed-forward network (768 → 3072 → 768, `intermediate_size` = 3072) with a tanh-approximated GELU activation (`hidden_act = "gelu_pytorch_tanh"`) between the layers, which adds non-linear processing of each patch's information.
    - __Layer Normalization__: layer normalization (`layer_norm_eps = 1e-06`) is applied before the attention block and before the MLP (a pre-norm layout), and each block is wrapped in a residual connection. This stabilizes training and improves convergence.

5. __Post Layer Normalization__:
  - After the encoder, a final layer normalization (`post_layernorm`) is applied to the output of the last encoder layer, so that the features are normalized before they reach the head.

6. __Pooling Head__:
  - The `SiglipMultiheadAttentionPoolingHead` combines the features from all patches into a single representation of the image. A single learned query vector (the "probe") attends to all the patch embeddings to pool their information; the result then goes through layer normalization and an MLP, which is added back through a residual connection.
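Stages 2 to 4 can be sketched with plain PyTorch modules. This is a minimal, randomly initialized sketch, not the Hugging Face implementation: the sizes come from the `SiglipVisionConfig` printed later in this notebook, `PreNormEncoderLayer` is a hypothetical name, and `nn.MultiheadAttention` stands in for `SiglipAttention`. It only checks that the shapes work out to 576 patch embeddings of 768 dimensions.

```python
import torch
from torch import nn

# Sizes from the SiglipVisionConfig of google/siglip-base-patch16-384
image_size, patch_size, num_channels = 384, 16, 3
hidden_size, intermediate_size, num_heads = 768, 3072, 12
num_patches = (image_size // patch_size) ** 2  # 24 * 24 = 576


class PreNormEncoderLayer(nn.Module):
    """Hypothetical stand-in for one SiglipEncoderLayer (pre-norm attention + MLP, both with residuals)."""

    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(approximate="tanh"),  # same approximation as "gelu_pytorch_tanh"
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, x):
        h = self.layer_norm1(x)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]  # attention block + residual
        x = x + self.mlp(self.layer_norm2(x))                    # MLP block + residual
        return x


# Stage 2: kernel size == stride, so every non-overlapping 16x16 patch is embedded exactly once
patch_embedding = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
# Stage 3: one learned position vector per patch index
position_embedding = nn.Embedding(num_patches, hidden_size)

pixel_values = torch.randn(1, num_channels, image_size, image_size)  # stand-in for a preprocessed image
with torch.no_grad():
    grid = patch_embedding(pixel_values)                # (1, 768, 24, 24)
    sequence = grid.flatten(2).transpose(1, 2)          # (1, 576, 768)
    sequence = sequence + position_embedding(torch.arange(num_patches))  # broadcast over the batch
    encoded = PreNormEncoderLayer()(sequence)           # Stage 4 (one of the 12 layers)

print(tuple(grid.shape), tuple(sequence.shape), tuple(encoded.shape))
```

Every encoder layer keeps the shape `(batch, 576, 768)`, which is why all the hidden states printed below have the same shape.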


By the time the input data has passed through all the encoder layers, it has been transformed into a sequence of embeddings. Each embedding in this sequence corresponds to a patch of the original image, but now represents a deeply processed, high-dimensional feature vector encapsulating both the intrinsic properties of the patch and its contextual relationships with other patches in the image. 

At the output of stage 5 we have a sequence of 576 embedding vectors, each with 768 dimensions, that together describe the processed image. CLIP, however, reduces this sequence to a single embedding vector so that it can be compared with text embeddings. This happens in stage 6: an attention layer followed by a Multilayer Perceptron (MLP) aggregates the detailed sequence into one vector that captures the core attributes of the image. Because the attention uses a single learned query, its output sequence has length one, and that first (and only) vector is kept as the embedding of the whole image.
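Stage 6 can be sketched as follows. This is a hypothetical, randomly initialized re-implementation written from the description above (a single learned query attending over all patches, then layer normalization and an MLP with a residual connection); the class name `ProbePoolingHead` is ours, not the library's.

```python
import torch
from torch import nn


class ProbePoolingHead(nn.Module):
    """Hypothetical attention-pooling head: one learned query ("probe") summarizes all patches."""

    def __init__(self, hidden_size=768, num_heads=12, intermediate_size=3072):
        super().__init__()
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(approximate="tanh"),
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, hidden_states):
        # One copy of the probe per image in the batch: (batch, 1, hidden)
        probe = self.probe.repeat(hidden_states.size(0), 1, 1)
        # The probe is the only query, so the attention output has sequence length 1
        pooled = self.attention(probe, hidden_states, hidden_states, need_weights=False)[0]
        pooled = pooled + self.mlp(self.layernorm(pooled))
        return pooled[:, 0]  # (batch, hidden): one vector per image


head = ProbePoolingHead()
with torch.no_grad():
    patch_sequence = torch.randn(2, 576, 768)  # e.g. post-layernorm outputs for two images
    image_embedding = head(patch_sequence)
print(tuple(patch_sequence.shape), "->", tuple(image_embedding.shape))
```

This is the step that discards the per-patch information we want to keep, which is why we need the output from before the head.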

To keep the whole sequence of embeddings, we can ask the HuggingFace model to return its intermediate outputs: just pass `output_hidden_states=True` when calling it. The model then returns the hidden states at the output of each layer, plus the output of the initial embedding layer.

The catch is that the returned hidden states are not labeled, so you don't know which module each one comes from. To find out, you have to read the model implementation in HuggingFace and trace `output_hidden_states` through all the submodules. For SigLIP the code is [here](https://github.com/huggingface/transformers/blob/df1542581ee89107eea0569ee044fa8797b66ab0/src/transformers/models/siglip/modeling_siglip.py). It returns 13 hidden states: the output of the input embeddings plus the output of each of the 12 `SiglipEncoder` layers.
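Tracing the code gives the mapping below, written as a small helper (the name `hidden_state_source` is hypothetical). Note that `hidden_states[-1]` is the output of the last encoder layer before `post_layernorm`; the `last_hidden_state` field of the output is `post_layernorm(hidden_states[-1])`, so the two are not equal.

```python
def hidden_state_source(index: int, num_hidden_layers: int = 12) -> str:
    """Hypothetical helper: which SigLIP vision module produced hidden_states[index]."""
    num_states = num_hidden_layers + 1  # embeddings + one output per encoder layer
    if -num_states <= index < 0:
        index += num_states  # support negative indices such as -1
    if index == 0:
        return "SiglipVisionEmbeddings (patch + position embeddings)"
    if 1 <= index < num_states:
        return f"SiglipEncoder.layers[{index - 1}]"
    raise IndexError(f"hidden_states has only {num_states} entries")


for i in (0, 1, 12, -1):
    print(f"hidden_states[{i}] <- {hidden_state_source(i)}")
```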

In [3]:
import requests
import torch
from PIL import Image


# Return the intermediate outputs with output_hidden_states=True
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
    clip_vision_model_output = clip_vision_model(**inputs, output_hidden_states=True)

print(f"Vision model output has: {list(clip_vision_model_output.keys())}")

# The intermediate outputs are stored at hidden_states
hidden_states = clip_vision_model_output.hidden_states
print(f"We have the output of {len(hidden_states)} layers")

# We have one output for the input embedding and one for each layer of the SiglipEncoder. The output shape is always
# (batch size, num. patches, embedding size). We use patches of 16x16 pixels for an image of 384x384 => 576 patches
for idx, hidden_output in enumerate(hidden_states):
    print(f"Output shape of layer {idx + 1}: {tuple(hidden_output.shape)}")

Vision model output has: ['last_hidden_state', 'pooler_output', 'hidden_states']
We have the output of 13 layers
Output shape of layer 1: (1, 576, 768)
Output shape of layer 2: (1, 576, 768)
Output shape of layer 3: (1, 576, 768)
Output shape of layer 4: (1, 576, 768)
Output shape of layer 5: (1, 576, 768)
Output shape of layer 6: (1, 576, 768)
Output shape of layer 7: (1, 576, 768)
Output shape of layer 8: (1, 576, 768)
Output shape of layer 9: (1, 576, 768)
Output shape of layer 10: (1, 576, 768)
Output shape of layer 11: (1, 576, 768)
Output shape of layer 12: (1, 576, 768)
Output shape of layer 13: (1, 576, 768)


The drawback of `output_hidden_states=True` is that it keeps every intermediate hidden state in memory, which is wasteful if we only need one of them. Instead, we can build a new model that returns only the output we need. Moreover, all the computation done after that layer is unnecessary, so we can remove those modules to reduce both the memory used by the model and the computation it performs.
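A rough estimate of that memory, counting only the hidden-state tensors themselves and assuming float32 (4 bytes per value) and the shapes printed above (13 tensors of 576 x 768 per image); the helper name is hypothetical:

```python
def hidden_states_bytes(batch_size, num_states=13, num_patches=576, hidden_size=768, bytes_per_element=4):
    """Bytes held by the tuple returned with output_hidden_states=True (float32 by default)."""
    return batch_size * num_states * num_patches * hidden_size * bytes_per_element


per_image = hidden_states_bytes(batch_size=1)
per_batch = hidden_states_bytes(batch_size=64)
print(f"{per_image / 2**20:.1f} MiB per image, {per_batch / 2**30:.2f} GiB for a batch of 64")
```

That is about 22 MiB per image, or roughly 1.37 GiB for a batch of 64, and 12 of the 13 tensors are thrown away when we only need one layer.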

We can inspect the implementation of the vision encoder of SigLIP (`SiglipVisionTransformer`):
```python
class SiglipVisionTransformer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.head = SiglipMultiheadAttentionPoolingHead(config)

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        pooled_output = self.head(last_hidden_state)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
```

As we can see, it is just a PyTorch module. It consists of four components:
- `embeddings`: converts the raw pixel values of the input image into a sequence of patch embeddings in a higher-dimensional vector space.
- `encoder`: the core of the model. It processes the sequence of embedded patches through several layers of self-attention and feed-forward networks.
- `post_layernorm`: normalizes the encoder output across patches before any further processing.
- `head`: in general, the "head" is the component that takes the output of the transformer encoder and performs a specific task, such as classification. In `SiglipVisionTransformer`, the head pools the transformer outputs into a single vector representation of the input image.

If we want the output of the last encoder layer, we can drop the `post_layernorm` and the `head` and build a model whose forward pass uses only the `embeddings` and the `encoder`.
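The same idea works for an intermediate layer, not only the last one: keep only the encoder layers up to the one we need. A minimal sketch on a toy encoder (`ToyEncoder` and `TruncatedEncoder` are hypothetical). Slicing an `nn.ModuleList` returns a new `ModuleList` that holds the same layer objects, so the truncated model reuses the original weights without copying them. For SigLIP this would mean slicing `clip_vision_model.encoder.layers`, which is a `ModuleList` (see the architecture printed below), but this toy does not touch that model.

```python
import torch
from torch import nn

torch.manual_seed(0)


class ToyEncoder(nn.Module):
    """Hypothetical encoder with the same structure as SiglipEncoder: a ModuleList of layers."""

    def __init__(self, num_layers=12, dim=8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x):
        hidden_states = [x]  # like output_hidden_states=True: input + one entry per layer
        for layer in self.layers:
            x = torch.tanh(layer(x))
            hidden_states.append(x)
        return x, hidden_states


class TruncatedEncoder(nn.Module):
    """Keeps only the first `keep` layers; the later layers are never executed."""

    def __init__(self, encoder: ToyEncoder, keep: int):
        super().__init__()
        self.layers = encoder.layers[:keep]  # ModuleList slice: shares the original layer objects

    def forward(self, x):
        for layer in self.layers:
            x = torch.tanh(layer(x))
        return x


encoder = ToyEncoder()
truncated = TruncatedEncoder(encoder, keep=11)  # e.g. stop at the penultimate layer
x = torch.randn(1, 576, 8)
with torch.no_grad():
    _, all_states = encoder(x)
    penultimate = truncated(x)

print(torch.equal(penultimate, all_states[11]))
print(sum(p.numel() for p in truncated.parameters()), "of", sum(p.numel() for p in encoder.parameters()))
```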

In [4]:
from torch import nn


class ModifiedSiglipVisionModel(nn.Module):
    def __init__(self, ref_clip_vision_model: nn.Module):
        super().__init__()
        self.config = ref_clip_vision_model.config

        self.embeddings = ref_clip_vision_model.embeddings
        self.encoder = ref_clip_vision_model.encoder

    def forward(self, pixel_values):

        # Convert pixel values to embeddings
        embeddings = self.embeddings(pixel_values)

        # Process embeddings through the encoder
        encoder_outputs = self.encoder(inputs_embeds=embeddings)

        # There is only one output: last_hidden_state
        sequence_output = encoder_outputs[0]  # `encoder_outputs.last_hidden_state` is the same

        return sequence_output


# Create the vision model
vision_model = ModifiedSiglipVisionModel(ref_clip_vision_model=clip_vision_model)

# Do inference with the new model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
    vision_model_output = vision_model(**inputs)

# It's the same output as before
assert torch.equal(vision_model_output, hidden_states[-1])
print(f"Vision model output shape: {tuple(vision_model_output.shape)}")

Vision model output shape: (1, 576, 768)


We can print the configuration and architecture of the model. __Note that the configuration does not correspond to this new module but to the original one__.

In [5]:
# Model configuration
print(f"Vision model configuration:\n{vision_model.config}")

# Model architecture
print(f"Model architecture:\n{vision_model}")

Vision model configuration:
SiglipVisionConfig {
  "attention_dropout": 0.0,
  "hidden_act": "gelu_pytorch_tanh",
  "hidden_size": 768,
  "image_size": 384,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-06,
  "model_type": "siglip_vision_model",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "transformers_version": "4.38.2"
}

Model architecture:
ModifiedSiglipVisionModel(
  (embeddings): SiglipVisionEmbeddings(
    (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=valid)
    (position_embedding): Embedding(576, 768)
  )
  (encoder): SiglipEncoder(
    (layers): ModuleList(
      (0-11): 12 x SiglipEncoderLayer(
        (self_attn): SiglipAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_fe

We can compare the number of parameters before and after pruning the model.

In [6]:
print(f"Parameters before pruning: {sum(p.numel() for p in clip_vision_model.parameters())}")
print(f"Parameters after pruning: {sum(p.numel() for p in vision_model.parameters())}")

Parameters before pruning: 93176064
Parameters after pruning: 86087424
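The difference, 7,088,640 parameters, is exactly what the removed `post_layernorm` and `head` contain. A quick check by hand, assuming the head is built from `nn.MultiheadAttention` (a packed query/key/value projection with biases plus an output projection), a learned probe vector, a `LayerNorm` and a two-layer MLP, as in the Hugging Face implementation linked above:

```python
hidden, intermediate = 768, 3072

post_layernorm = 2 * hidden                      # LayerNorm weight + bias
probe = hidden                                   # learned query of shape (1, 1, 768)
attn_in_proj = 3 * hidden * hidden + 3 * hidden  # packed q/k/v weights + biases
attn_out_proj = hidden * hidden + hidden         # output projection
head_layernorm = 2 * hidden
head_mlp = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden)

removed = post_layernorm + probe + attn_in_proj + attn_out_proj + head_layernorm + head_mlp
print(f"Removed parameters: {removed:,}")
print(f"Measured difference: {93176064 - 86087424:,}")
```

Most of the removed parameters (about 4.7M) sit in the head's MLP.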


## Visual and Text connection
Now we have a visual model that turns an image into a sequence of tokens. The LLM can consume this sequence, which lets us ask questions about the image: internally, the LLM's attention mechanism focuses on the image tokens that carry the information relevant to the question.

However, there is a dimensionality mismatch: our visual model outputs 768-dimensional embeddings, whereas the LLM's input embeddings have 2048 dimensions. To bridge this gap, we add a Multi-Layer Perceptron (MLP) or a linear layer that maps each 768-dimensional output of the visual model into the LLM's 2048-dimensional input embedding space. We call this module the __visual adapter__.
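Both adapter options can be written in a few lines. A sketch using the dimensions above; the two-layer MLP with a GELU in between is one common design (LLaVA-1.5 uses one, for example) and is an assumption here, not the adapter this notebook builds:

```python
import torch
from torch import nn

vision_dim, llm_dim = 768, 2048

# Option 1: a single linear projection
linear_adapter = nn.Linear(vision_dim, llm_dim)

# Option 2: a two-layer MLP (more capacity, more parameters)
mlp_adapter = nn.Sequential(
    nn.Linear(vision_dim, llm_dim),
    nn.GELU(),
    nn.Linear(llm_dim, llm_dim),
)

patch_embeddings = torch.randn(1, 576, vision_dim)  # stand-in for the vision model output
with torch.no_grad():
    for name, adapter in [("linear", linear_adapter), ("mlp", mlp_adapter)]:
        n_params = sum(p.numel() for p in adapter.parameters())
        print(name, tuple(adapter(patch_embeddings).shape), f"{n_params:,} parameters")
```

Either way the adapter is applied to each of the 576 patch embeddings independently, so the sequence length does not change; only the last dimension goes from 768 to 2048.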

In [7]:
# LLM input size
embedding = llm_model.model.embed_tokens(torch.LongTensor([1175]))
print(f"Embedding shape of 'the' token: {tuple(embedding.shape)}")

# Visual encoder output size
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
    vision_model_output = vision_model(**inputs)

print(f"Vision model output shape: {tuple(vision_model_output.shape)}")

Embedding shape of 'the' token: (1, 2048)
Vision model output shape: (1, 576, 768)


Now we are going to add the visual adapter to the visual model.

In [8]:
import transformers
from torch import nn


class ModifiedSiglipVisionModelWithAdapter(nn.Module):
    def __init__(self, ref_clip_vision_model: nn.Module, llm_model: transformers.PreTrainedModel):
        super().__init__()
        self.config = ref_clip_vision_model.config

        self.embeddings = ref_clip_vision_model.embeddings
        self.encoder = ref_clip_vision_model.encoder

        self.adapter = nn.Linear(in_features=self.config.hidden_size, out_features=llm_model.config.hidden_size)

    def forward(self, pixel_values):

        # Convert pixel values to embeddings
        embeddings = self.embeddings(pixel_values)

        # Process embeddings through the encoder
        encoder_outputs = self.encoder(inputs_embeds=embeddings)

        # There is only one output: last_hidden_state
        sequence_output = encoder_outputs[0]  # `encoder_outputs.last_hidden_state` is the same

        # Prepare the output for the LLM
        sequence_output = self.adapter(sequence_output)

        return sequence_output


# Create the vision model
vision_model_with_adapter = ModifiedSiglipVisionModelWithAdapter(
    ref_clip_vision_model=clip_vision_model, llm_model=llm_model
)

# Do inference with the new model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
    vision_model_with_adapter_output = vision_model_with_adapter(**inputs)

print(f"Vision model with adapter output shape: {tuple(vision_model_with_adapter_output.shape)}")

Vision model with adapter output shape: (1, 576, 2048)


Upon adapting the visual model's output to match the embedding size required by the LLM, one might assume we can straightforwardly feed these transformed tokens into the LLM and proceed with generating responses. However, a crucial nuance arises at this juncture: the tokens generated from the image, despite being in the correct dimensionality, represent a set of embeddings unfamiliar to the LLM. These embeddings do not correspond to any pre-existing token IDs within the LLM's vocabulary. Essentially, these new embeddings are akin to a foreign language to the LLM, containing potentially any value within the embedding space, without any predefined semantic association that the LLM can recognize.

This disparity prevents us from utilizing the standard approach of passing token IDs to the LLM for text generation. Token IDs serve as predefined pointers to specific embeddings within the LLM's vocabulary, acting as the basis for generating meaningful text responses. In the absence of corresponding token IDs for the new image-derived embeddings, this method is not viable.

However, there is a solution: bypass the token-ID input mechanism and inject the embeddings directly into the LLM through the `inputs_embeds` argument (instead of `input_ids`). Although these adapted embeddings are not part of the LLM's vocabulary, they have the dimensionality the LLM expects, so it can process them. The LLM then applies its attention and contextual understanding to the image-derived embeddings, which allows it to generate text responses based on the content of the image.
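A toy illustration of the difference, with a hypothetical 10-token vocabulary and 4-dimensional embeddings: a token ID is only an index into the embedding table, while an adapter output is an arbitrary vector that generally matches no row of that table, so it can only be passed in as an embedding.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, hidden_size = 10, 4
embed_tokens = nn.Embedding(vocab_size, hidden_size)  # toy stand-in for the LLM's embedding table

# A token ID simply selects a row of the embedding table
token_ids = torch.tensor([[3, 7]])
text_embeds = embed_tokens(token_ids)                  # (1, 2, 4)
assert torch.equal(text_embeds, embed_tokens.weight[token_ids])

# An "image token" coming out of the adapter is an arbitrary vector ...
image_embeds = torch.randn(1, 1, hidden_size)          # (1, 1, 4)
# ... that does not coincide with any row, so no token ID can represent it
distances = torch.cdist(image_embeds[0], embed_tokens.weight)  # (1, 10)
print(f"Distance to the closest vocabulary embedding: {distances.min().item():.3f}")

# What we can do instead: concatenate embeddings and pass them as `inputs_embeds`
inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)  # (1, 3, 4)
print(tuple(inputs_embeds.shape))
```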

In [11]:
# Do inference with the new model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
    vision_model_with_adapter_output = vision_model_with_adapter(**inputs)

# Get the embedding for the text
text_embedding = llm_model.model.embed_tokens(
    torch.LongTensor(llm_tokenizer.encode("Describe the image.", add_special_tokens=False))
)

# Add the batch dimension
text_embedding = text_embedding.unsqueeze(0)

# Combine the image and the text
embeddings = torch.concat([vision_model_with_adapter_output, text_embedding], dim=1)
print(f"Image embedding shape: {tuple(vision_model_with_adapter_output.shape)}")
print(f"Text embedding shape: {tuple(text_embedding.shape)}")
print(f"Image and text embedding shape: {tuple(embeddings.shape)}")

# Generate the output of the LLM
generate_ids = llm_model.generate(inputs_embeds=embeddings, max_new_tokens=256)
text_output_without_special_tokens = llm_tokenizer.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]

print(f"Raw model's output (id of the tokens): {generate_ids}")
print(f"Final output without special tokens: {repr(text_output_without_special_tokens)}")

Image embedding shape: (1, 576, 2048)
Text embedding shape: (1, 4, 2048)
Image and text embedding shape: (1, 580, 2048)
Raw model's output (id of the tokens): tensor([[   109,    109,    109,    109,    109,    109,    109,    109,    109,
            109, 235280,    578, 235279, 235269, 235248, 235248, 235248,    109,
            109,    109,    109,    109,    109,    109,    109,    109,    109,
            109, 235269, 235248, 235248, 235248, 235248,    109,    109,    109,
            109,    109,    109,    109,    109,    109,    109,    108,    688,
              1]])
Final output without special tokens: '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nA andT,   \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n,    \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n**'


Of course, the output makes no sense: the LLM has never been trained on these image embeddings, and the visual adapter is still randomly initialized.

To ensure our model can effectively learn from and respond to both visual and textual inputs, we need to develop a function that seamlessly merges image and text embeddings. This is crucial for maintaining consistency across training and inference phases, thereby preventing discrepancies that could negatively impact model performance.

At the heart of our approach is the introduction of a special token, designated as a placeholder, which signals the position within the text sequence where image embeddings should be integrated. This system allows us to dynamically insert visual information into the model's input stream, ensuring that both text and image data are processed in a unified manner.

Consider a scenario where we wish to query the model about the contents of an image. The input might be structured as follows: `"<bos><image> What is the content of the image?"`. Through our custom method, this input is transformed into a concatenated sequence of embeddings: `["EMBEDDING_FOR_TOKEN_<bos>", "EMBEDDING_FOR_IMAGE_PATCH_1", ..., "EMBEDDING_FOR_IMAGE_PATCH_N", "EMBEDDING_FOR_TOKEN_What", ..., "EMBEDDING_FOR_TOKEN_?"]`. Each `"EMBEDDING_FOR_IMAGE_PATCH_X"` represents an embedding for a segment of the image, allowing the model to consider discrete portions of the visual input alongside textual information.

**Detailed Process:**

1. **Token Identification**: The method first identifies the special placeholder token within the input sequence. This token acts as a marker for where the image embeddings will be inserted, replacing the placeholder.

2. **Embedding Concatenation**: The model then constructs a new input sequence by concatenating the appropriate embeddings. Text embeddings are taken directly from the pre-trained language model's vocabulary, while image embeddings are generated by processing the image through our visual encoder.

3. **Sequence Expansion**: Given that an image is represented by multiple embeddings (one for each patch or segment of the image), the input sequence is dynamically expanded to accommodate these additional embeddings. This ensures that the model receives a comprehensive representation of both the textual query and the visual content.

This process is applied identically during both training and inference, ensuring that the model consistently interprets and processes the combined input. Such consistency is vital for the model to accurately learn from its training data and to apply this knowledge effectively during inference.
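At the token level, the expansion can be sketched in pure Python. The function name and the `IMAGE_PATCH_k` labels are hypothetical placeholders for the actual embeddings, and the word-level split of the prompt is illustrative (the real Gemma tokenization differs):

```python
def expand_image_placeholder(tokens, num_patches, placeholder="<image>"):
    """Replace the single placeholder with one entry per image patch."""
    if tokens.count(placeholder) != 1:
        raise ValueError(f"Expected exactly one {placeholder!r} token, got {tokens.count(placeholder)}")
    pos = tokens.index(placeholder)
    patches = [f"IMAGE_PATCH_{k + 1}" for k in range(num_patches)]
    return tokens[:pos] + patches + tokens[pos + 1:]


tokens = ["<bos>", "<image>", "What", " is", " the", " content", " of", " the", " image", "?"]
expanded = expand_image_placeholder(tokens, num_patches=576)
print(len(tokens), "->", len(expanded))  # one placeholder removed, 576 patches added
print(expanded[:3], "...", expanded[-3:])
```

The final length is always `len(tokens) - 1 + num_patches`, which is the same rule `merge_input_ids_with_image_features` uses below to size its output.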

In [12]:
# First, you'll need to add the special token to your tokenizer.
image_placeholder_token = '<image>'
special_tokens_dict = {'additional_special_tokens': [image_placeholder_token]}
num_added_toks = llm_tokenizer.add_special_tokens(special_tokens_dict)
print(f"We have added {num_added_toks} tokens")

# After adding the special token to the tokenizer, you need to ensure that the model’s embedding layer is aware of this 
# new token. This involves resizing the embedding layer to accommodate the additional token(s).
llm_model.resize_token_embeddings(len(llm_tokenizer))

# Adjust configuration of the model
llm_model.config.vocab_size = len(llm_tokenizer)

print(f"LLM model config:\n{llm_model.config}")

We have added 1 tokens
LLM model config:
GemmaConfig {
  "_name_or_path": "./models/gemma-2b-it",
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 16384,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 8,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.2",
  "use_cache": true,
  "vocab_size": 256001
}



In [13]:
import torch

# Now we can create the function that will merge the text and image embeddings
def merge_input_ids_with_image_features(image_features, input_ids, attention_mask, llm_model, image_token_id):
    # Initialize variables for the new sequence dimensions
    batch_size, seq_length = input_ids.shape
    image_feature_dim = image_features.size(-1)
    image_seq_length = image_features.size(1)

    # Convert input_ids to embeddings
    inputs_embeds = llm_model.model.embed_tokens(input_ids)

    # Device adjustment for new tensors based on the device of inputs_embeds
    device = inputs_embeds.device
    
    # Find positions of the <image> tokens
    image_token_mask = input_ids == image_token_id
    num_image_tokens_per_seq = torch.sum(image_token_mask, dim=-1)
    image_token_positions = image_token_mask.nonzero()

    # There must be always one and only one image per sequence
    if not all(num_image_tokens_per_seq == 1):
        raise RuntimeError(
            f"Expecting one and only one image (token_id = {image_token_id}) per sequence in the batch: {input_ids}"
        )
    
    # All image tokens must be at the same position
    if not all(image_token_positions[0, 1] == image_token_positions[:, 1]):
        raise RuntimeError(
            f"Image token (token_id = {image_token_id}) is not at the same position in all sequences of the batch:"
            f"{input_ids}"
        )

    # Calculate the new sequence length after inserting image features
    new_seq_length =  image_seq_length + seq_length - 1  # - 1 because we remove the image token

    # Prepare containers for the new embeddings and attention mask (same dtypes as the inputs)
    new_inputs_embeds = torch.zeros(
        batch_size, new_seq_length, image_feature_dim, dtype=inputs_embeds.dtype, device=device
    )
    new_attention_mask = torch.zeros(batch_size, new_seq_length, dtype=attention_mask.dtype, device=device)

    # Copy the text and image embeddings into the new containers
    for batch_idx, pos in image_token_positions:
        # Copy text embeddings up to the image token (only for the current sequence)
        if pos > 0:
            new_inputs_embeds[batch_idx, :pos, :] = inputs_embeds[batch_idx, :pos, :]
            new_attention_mask[batch_idx, :pos] = attention_mask[batch_idx, :pos]

        # Insert image features
        image_start_idx = pos
        image_end_idx = image_start_idx + image_seq_length
        new_inputs_embeds[batch_idx, image_start_idx:image_end_idx, :] = image_features[batch_idx, :, :]
        new_attention_mask[batch_idx, image_start_idx:image_end_idx] = 1

        # Copy remaining text embeddings after the <image> token
        if pos < seq_length - 1:
            remaining_text_start_idx = image_end_idx
            new_inputs_embeds[batch_idx, remaining_text_start_idx:, :] = inputs_embeds[batch_idx, pos+1:, :]
            new_attention_mask[batch_idx, remaining_text_start_idx:] = attention_mask[batch_idx, pos+1:]

    return new_inputs_embeds, new_attention_mask


# Get the inputs
prompt = f"{image_placeholder_token}\nDescribe the image."
image_token_id = llm_tokenizer.convert_tokens_to_ids(image_placeholder_token)
tokenizer_output = llm_tokenizer(prompt, return_tensors="pt", padding=True)
print(f"Using prompt: {repr(prompt)}")
print(f"Tokenizer input_ids: {tokenizer_output.input_ids}")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = image_processor(images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
    image_features = vision_model_with_adapter(**inputs)

print(f"Image features shape: {tuple(image_features.shape)}")

new_inputs_embeds, new_attention_mask = merge_input_ids_with_image_features(
    image_features, tokenizer_output.input_ids, tokenizer_output.attention_mask, llm_model, image_token_id
)

num_text_tokens = tokenizer_output.input_ids.shape[1] - 1
print(f"The prompt {repr(prompt)} will have {num_text_tokens} text tokens "
      f"and 576 image tokens = {576 + num_text_tokens} tokens/embeddings"
)

print(
    f"After merging the image and text embeddings the final shape of the embeddings and attention mask is: "
    f"{tuple(new_inputs_embeds.shape)} - {tuple(new_attention_mask.shape)}"
)

Using prompt: '<image>\nDescribe the image.'
Tokenizer input_ids: tensor([[     2, 256000,    108,  50721,    573,   2416, 235265]])
Image features shape: (1, 576, 2048)
The prompt '<image>\nDescribe the image.' will have 6 text tokens and 576 image tokens = 582 tokens/embeddings
After merging the image and text embeddings the final shape of the embeddings and attention mask is: (1, 582, 2048) - (1, 582)
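`merge_input_ids_with_image_features` requires the `<image>` token to sit at the same position in every sequence of the batch, so a batch of prompts with different lengths and left padding would be rejected. A possible alternative, sketched under the assumption that the text embeddings are computed beforehand (the name `merge_with_cat` is hypothetical), builds each row with `torch.cat` so the image position can vary per sequence. All merged rows still have the same length, because each one loses one placeholder and gains the same number of image embeddings.

```python
import torch


def merge_with_cat(text_embeds, input_ids, attention_mask, image_features, image_token_id):
    """Hypothetical variant: insert image features at a per-sequence <image> position."""
    merged_embeds, merged_masks = [], []
    for b in range(input_ids.size(0)):
        positions = (input_ids[b] == image_token_id).nonzero(as_tuple=True)[0]
        if positions.numel() != 1:
            raise ValueError(f"Sequence {b} must contain exactly one image token")
        p = int(positions)
        # Image embeddings are always attended to
        image_mask = attention_mask.new_ones(image_features.size(1))
        merged_embeds.append(
            torch.cat([text_embeds[b, :p], image_features[b].to(text_embeds.dtype), text_embeds[b, p + 1:]])
        )
        merged_masks.append(torch.cat([attention_mask[b, :p], image_mask, attention_mask[b, p + 1:]]))
    # Every row has length seq_len - 1 + num_patches, so the rows can be stacked
    return torch.stack(merged_embeds), torch.stack(merged_masks)


# Toy batch: image token id 9 is at position 1 in the first row and 2 in the (left-padded) second row
input_ids = torch.tensor([[2, 9, 5, 6], [0, 2, 9, 5]])
attention_mask = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])
text_embeds = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
image_features = torch.full((2, 5, 3), -1.0)  # 5 "patches" per image

embeds, mask = merge_with_cat(text_embeds, input_ids, attention_mask, image_features, image_token_id=9)
print(tuple(embeds.shape), tuple(mask.shape))
print(mask.tolist())
```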


In [14]:
# With a batch of prompts
prompts = [f"{image_placeholder_token}\nDescribe the image.", f"{image_placeholder_token}\nCount the cats."]
image_token_id = llm_tokenizer.convert_tokens_to_ids(image_placeholder_token)
tokenizer_output = llm_tokenizer(prompts, return_tensors="pt", padding=True)
print(f"Using prompts: {repr(str(prompts))}")
print(f"Tokenizer input_ids: {tokenizer_output.input_ids}")

# Use the same image twice
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = image_processor(images=[image, image], padding="max_length", return_tensors="pt")
with torch.no_grad():
    image_features = vision_model_with_adapter(**inputs)

print(f"Image features shape: {tuple(image_features.shape)}")

new_inputs_embeds, new_attention_mask = merge_input_ids_with_image_features(
    image_features, tokenizer_output.input_ids, tokenizer_output.attention_mask, llm_model, image_token_id
)

# We can validate that both images end up with the same embedding at the same place
assert torch.equal(new_inputs_embeds[:, 1:577, :], image_features)  # First element of each sequence is <bos> embedding

# Some shape information
num_text_tokens = tokenizer_output.input_ids.shape[1] - 1
print(f"Each prompt will have {num_text_tokens} text tokens "
      f"and 576 image tokens = {576 + num_text_tokens} tokens/embeddings"
)

print(
    f"After merging the image and text embeddings the final shape of the embeddings and attention mask is: "
    f"{tuple(new_inputs_embeds.shape)} - {tuple(new_attention_mask.shape)}"
)

Using prompts: "['<image>\\nDescribe the image.', '<image>\\nCount the cats.']"
Tokenizer input_ids: tensor([[     2, 256000,    108,  50721,    573,   2416, 235265],
        [     2, 256000,    108,   3074,    573,  19493, 235265]])
Image features shape: (2, 576, 2048)
Each prompt will have 6 text tokens and 576 image tokens = 582 tokens/embeddings
After merging the image and text embeddings the final shape of the embeddings and attention mask is: (2, 582, 2048) - (2, 582)
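Because every sequence carries exactly one `<image>` token at the same index (right after `<bos>` in the batch above), the splice can also be written without a loop, as three slices concatenated along the sequence dimension. This is a minimal sketch on toy tensors, not the notebook's function: `splice_image_features`, the image token id 99 and the tiny sizes are hypothetical.

```python
import torch


def splice_image_features(inputs_embeds, attention_mask, image_features, input_ids, image_token_id):
    """Hypothetical vectorized splice: replace the single <image> token of every sequence with
    that sequence's image features. Assumes one <image> token per sequence, at the same index."""
    image_token_mask = input_ids == image_token_id
    if not torch.all(image_token_mask.sum(dim=1) == 1):
        raise ValueError("expected exactly one <image> token per sequence")
    positions = image_token_mask.nonzero()[:, 1]  # one column index per row, rows in order
    if not torch.all(positions == positions[0]):
        raise ValueError("the <image> token must be at the same index in every sequence")
    pos = int(positions[0])

    batch_size, num_image_embeds, _ = image_features.shape
    image_mask = torch.ones(batch_size, num_image_embeds, dtype=attention_mask.dtype, device=attention_mask.device)

    # Text before <image>, then the image features, then the text after <image>
    merged_embeds = torch.cat(
        [inputs_embeds[:, :pos], image_features.to(inputs_embeds.dtype), inputs_embeds[:, pos + 1:]], dim=1
    )
    merged_mask = torch.cat([attention_mask[:, :pos], image_mask, attention_mask[:, pos + 1:]], dim=1)
    return merged_embeds, merged_mask


# Toy batch: 2 sequences of 7 ids with a made-up <image> id (99) at index 1, like the batch above
torch.manual_seed(0)
toy_input_ids = torch.tensor([[2, 99, 5, 6, 7, 8, 9], [2, 99, 3, 4, 7, 6, 9]])
toy_attention_mask = torch.ones_like(toy_input_ids)
toy_inputs_embeds = torch.randn(2, 7, 4)    # stand-in for the LLM's token embeddings
toy_image_features = torch.randn(2, 3, 4)   # stand-in for the 576 adapted patch embeddings

merged_embeds, merged_mask = splice_image_features(
    toy_inputs_embeds, toy_attention_mask, toy_image_features, toy_input_ids, 99
)
print(tuple(merged_embeds.shape), tuple(merged_mask.shape))  # (2, 9, 4) (2, 9)
```

Slicing and concatenation give the same layout as the loop: the `<bos>` embedding first, then the image features (positions 1 to 576 in the real batch), then the remaining text.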


## Create one model with everything we need
We want a single model that holds every component needed for inference: the vision encoder (with the adapter) and the language model. Since we also want to train it with the Hugging Face `Trainer` and to save and load its weights, we implement it as a Hugging Face model, i.e. a subclass of `PreTrainedModel` with its own `PretrainedConfig`.
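As a minimal sketch of what subclassing `PreTrainedModel` provides, assuming a hypothetical one-layer model (`TinyConfig`, `TinyModel` and the `tiny-example` model type are made-up names): a config class plus a model whose layers are built from that config is enough for `save_pretrained` and `from_pretrained` to round-trip the weights. Note that `from_pretrained` rebuilds the model by calling its constructor with the config, so the layers must be constructible from the config alone.

```python
import tempfile

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class TinyConfig(PretrainedConfig):
    # Hypothetical config; `model_type` is stored in config.json
    model_type = "tiny-example"

    def __init__(self, hidden_size: int = 4, **kwargs):
        self.hidden_size = hidden_size
        super().__init__(**kwargs)


class TinyModel(PreTrainedModel):
    # Tells `from_pretrained` which config class to rebuild from config.json
    config_class = TinyConfig

    def __init__(self, config: TinyConfig):
        super().__init__(config)
        # All layers are built from the config alone, since `from_pretrained` calls cls(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_init()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyModel(TinyConfig(hidden_size=4))
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)                 # writes config.json plus the weights file
    reloaded = TinyModel.from_pretrained(tmp_dir)  # rebuilds TinyModel(config), then loads the weights

x = torch.randn(2, 4)
same = torch.allclose(model(x), reloaded(x))
print(f"Outputs match after reload: {same}")
```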

Inference with a Hugging Face model is done with its `generate` method. To use `generate` directly on a custom model, we would have to implement `prepare_inputs_for_generation` and `_reorder_cache`, because our model has an extra step: the image embeddings must first be computed by the vision encoder. It is simpler to let the language model handle generation: we merge the text and image embeddings ourselves and pass them to the language model's `generate` as `inputs_embeds`.
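Why is passing `inputs_embeds` enough? Generation only needs embeddings for the prompt; each new token is then looked up in the embedding table and appended. The sketch below shows that control flow with plain greedy decoding and no KV cache. `ToyCausalLM` and `greedy_generate_from_embeds` are hypothetical stand-ins, not the Transformers implementation (which also handles caching, sampling and stopping criteria). Since the prompt has no token ids to echo back, only the new ids are returned, which is consistent with the raw output at the end of the next cell not repeating the prompt ids.

```python
import torch
from torch import nn


class ToyCausalLM(nn.Module):
    """Hypothetical stand-in for the LLM: an embedding table, a crude causal mixer and an LM head."""

    def __init__(self, vocab_size: int = 10, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # Causal mean pooling: position t only sees embeddings 0..t (a crude stand-in for attention)
        steps = torch.arange(1, inputs_embeds.size(1) + 1, device=inputs_embeds.device).view(1, -1, 1)
        hidden = inputs_embeds.cumsum(dim=1) / steps
        return self.lm_head(hidden)  # (batch, seq, vocab)


@torch.no_grad()
def greedy_generate_from_embeds(model: ToyCausalLM, inputs_embeds: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    embeds = inputs_embeds
    new_tokens = []
    for _ in range(max_new_tokens):
        next_token = model(embeds)[:, -1, :].argmax(dim=-1)       # greedy pick, shape (batch,)
        new_tokens.append(next_token)
        next_embed = model.embed_tokens(next_token).unsqueeze(1)  # (batch, 1, hidden)
        embeds = torch.cat([embeds, next_embed], dim=1)           # no KV cache: re-run the full sequence
    return torch.stack(new_tokens, dim=1)                         # only the new ids, (batch, max_new_tokens)


torch.manual_seed(0)
toy_llm = ToyCausalLM(vocab_size=10, hidden_size=8)
prompt_embeds = torch.randn(2, 5, 8)  # stand-in for merged image + text embeddings
generated = greedy_generate_from_embeds(toy_llm, prompt_embeds, max_new_tokens=4)
print(generated.shape)  # torch.Size([2, 4])
```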

In [2]:
from dataclasses import dataclass
from typing import Optional
import requests
import torch
from PIL import Image
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_utils import ModelOutput
from transformers import AutoProcessor, AutoModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch import nn


class ModifiedSiglipVisionModelWithAdapter(nn.Module):
    def __init__(self, ref_clip_vision_model: nn.Module, llm_model: PreTrainedModel):
        super().__init__()
        self.config = ref_clip_vision_model.config

        self.embeddings = ref_clip_vision_model.embeddings
        self.encoder = ref_clip_vision_model.encoder

        self.adapter = nn.Linear(in_features=self.config.hidden_size, out_features=llm_model.config.hidden_size)

    def forward(self, pixel_values):

        # Convert pixel values to embeddings
        embeddings = self.embeddings(pixel_values)

        # Process embeddings through the encoder
        encoder_outputs = self.encoder(inputs_embeds=embeddings)

        # There is only one output: last_hidden_state
        sequence_output = encoder_outputs[0]  # `encoder_outputs.last_hidden_state` is the same

        # Prepare the output for the LLM
        sequence_output = self.adapter(sequence_output)

        return sequence_output


@dataclass
class VLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


class VLMConfig(PretrainedConfig):
    model_type = "vlm"
    is_composition = False

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)


class VLM(PreTrainedModel):
    config_class = VLMConfig

    def __init__(self, llm_model, clip_vision_model, image_placeholder_token, image_token_id, num_tokens):
        super().__init__(VLM.config_class())

        self.llm_model = llm_model
        self.vision_model_with_adapter = ModifiedSiglipVisionModelWithAdapter(clip_vision_model, llm_model)       

        self.image_placeholder_token = image_placeholder_token
        self.image_token_id = image_token_id

        # Resize the embedding layer to accommodate any added tokens (here, <image>)
        self.llm_model.resize_token_embeddings(num_tokens)

        # Keep the LLM configuration in sync with the new vocabulary size
        self.llm_model.config.vocab_size = num_tokens

    def _merge_input_ids_with_image_features(self, image_features, input_ids, attention_mask):
        # Initialize variables for the new sequence dimensions
        batch_size, seq_length = input_ids.shape
        image_feature_dim = image_features.size(-1)
        image_seq_length = image_features.size(1)

        # Convert input_ids to embeddings
        inputs_embeds = self.llm_model.model.embed_tokens(input_ids)

        # Device adjustment for new tensors based on the device of inputs_embeds
        device = inputs_embeds.device
        
        # Find positions of the <image> tokens
        image_token_mask = input_ids == self.image_token_id
        num_image_tokens_per_seq = torch.sum(image_token_mask, dim=-1)
        image_token_positions = image_token_mask.nonzero()

        # There must always be exactly one image per sequence
        if not all(num_image_tokens_per_seq == 1):
            raise RuntimeError(
                f"Expecting one and only one image (id={self.image_token_id}) per sequence in the batch: {input_ids}"
            )
        
        # All image tokens must be at the same position
        if not all(image_token_positions[0, 1] == image_token_positions[:, 1]):
            raise RuntimeError(
                f"Image token (id={self.image_token_id}) is not at the same position in all sequences of the batch:"
                f"{input_ids}"
            )

        # Calculate the new sequence length after inserting image features
        new_seq_length = image_seq_length + seq_length - 1  # - 1 because we remove the image token

        # Prepare containers for the new embeddings and attention mask, matching the dtypes of the inputs
        new_inputs_embeds = torch.zeros(
            batch_size, new_seq_length, image_feature_dim, dtype=inputs_embeds.dtype, device=device
        )
        new_attention_mask = torch.zeros(batch_size, new_seq_length, dtype=attention_mask.dtype, device=device)

        # Copy the text and image embeddings into the new containers
        for batch_idx, pos in image_token_positions:
            # Copy text embeddings up to the image token (only for the current sequence)
            if pos > 0:
                new_inputs_embeds[batch_idx, :pos, :] = inputs_embeds[batch_idx, :pos, :]
                new_attention_mask[batch_idx, :pos] = attention_mask[batch_idx, :pos]

            # Insert image features
            image_start_idx = pos
            image_end_idx = image_start_idx + image_seq_length
            new_inputs_embeds[batch_idx, image_start_idx:image_end_idx, :] = image_features[batch_idx, :, :]
            new_attention_mask[batch_idx, image_start_idx:image_end_idx] = 1

            # Copy remaining text embeddings after the <image> token
            if pos < seq_length - 1:
                remaining_text_start_idx = image_end_idx
                new_inputs_embeds[batch_idx, remaining_text_start_idx:, :] = inputs_embeds[batch_idx, pos+1:, :]
                new_attention_mask[batch_idx, remaining_text_start_idx:] = attention_mask[batch_idx, pos+1:]

        return new_inputs_embeds, new_attention_mask

    def create_prompt(self, instruction: str, expected_response: str = ""):
        # Build the chat-formatted prompt; apply it to the text before sending it to the tokenizer
        # `expected_response` is only used for training
        return (
            f"<start_of_turn>user\n{self.image_placeholder_token}\n{instruction}<end_of_turn>\n<start_of_turn>model\n"
            f"{expected_response}"
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):

        return_dict = return_dict if return_dict is not None else False
        
        # Get the embeddings for the image
        image_features = self.vision_model_with_adapter(pixel_values)

        # Merge the embeddings of the image and text
        inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
            image_features, input_ids, attention_mask
        )
        
        # Get the llm model's output
        outputs = self.llm_model(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=return_dict
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return VLMOutput(
            loss=loss,
            logits=logits
        )

    def generate(
        self,
        pixel_values: torch.FloatTensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        max_new_tokens: int = 256,
    ):
        # Get the embeddings for the image
        image_features = self.vision_model_with_adapter(pixel_values)

        # Merge the embeddings of the image and text
        inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
            image_features, input_ids, attention_mask
        )
        
        # Generate with the LLM
        return self.llm_model.generate(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens
        )


# Create each individual component
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_folder, local_files_only=True)
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_folder, local_files_only=True)

clip_model = AutoModel.from_pretrained(clip_model_folder, local_files_only=True)
clip_processor = AutoProcessor.from_pretrained(clip_model_folder, local_files_only=True)

# Get the vision model and the image processor from CLIP
clip_vision_model = clip_model.vision_model
image_processor = clip_processor.image_processor 

# Prepare the tokenizer
image_placeholder_token = '<image>'
special_tokens_dict = {'additional_special_tokens': [image_placeholder_token]}
num_added_toks = llm_tokenizer.add_special_tokens(special_tokens_dict)

image_token_id = llm_tokenizer.convert_tokens_to_ids(image_placeholder_token)
num_tokens = len(llm_tokenizer)

# Create the model
vlm = VLM(llm_model, clip_vision_model, image_placeholder_token, image_token_id, num_tokens)

# Prepare the inputs
prompt = vlm.create_prompt(instruction="Describe the image.")
text_inputs = llm_tokenizer(prompt, return_tensors="pt", padding=True)
print(f"Using prompt: {repr(prompt)}")
print(f"Tokenizer input_ids: {text_inputs.input_ids}")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
image_inputs = image_processor(images=image, padding="max_length", return_tensors="pt")

print(f"Image shape: {tuple(image_inputs.pixel_values.shape)}")

# Do inference
with torch.no_grad():
    generate_ids = vlm.generate(
        pixel_values=image_inputs.pixel_values, 
        input_ids=text_inputs.input_ids, 
        attention_mask=text_inputs.attention_mask,
        max_new_tokens=256
    )

text_output_without_special_tokens = llm_tokenizer.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]

print(f"Raw model's output (id of the tokens): {generate_ids}")
print(f"Final output without special tokens: {repr(text_output_without_special_tokens)}")

Using prompt: '<start_of_turn>user\n<image>\nDescribe the image.<end_of_turn>\n<start_of_turn>model\n'
Tokenizer input_ids: tensor([[     2,    106,   1645,    108, 256000,    108,  50721,    573,   2416,
         235265,    107,    108,    106,   2516,    108]])
Image shape: (1, 3, 384, 384)
Raw model's output (id of the tokens): tensor([[235248,    108, 235248,    108, 235248,    108, 235248,    108, 235295,
         235278,    483, 235303, 256000, 256000, 256000, 256000, 256000, 256000,
         256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
         256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
         256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
         256000, 256000, 256000, 256000, 256000,  16590,    675,   3591,  11615,
         256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
         256000, 256000, 256000, 256000, 235269,    665,   1412, 256000, 256000,
         256000, 25