<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/LlaVA_Img_Text_Embedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q transformers==4.36.0
!pip install -q bitsandbytes==0.41.3 accelerate==0.25.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.8/126.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
import torch
import torch.nn as nn
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import requests
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union

class LlaVA_Img_Text_Embedding(nn.Module):
    def __init__(self, ):
        super(LlaVA_Img_Text_Embedding, self).__init__()
        kwargs = {"device_map": "auto"}
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
        self.model_llava = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf",**kwargs)
        self.get_input_embeddings = self.model_llava.get_input_embeddings
        self.vision_tower = self.model_llava.vision_tower
        self.multi_modal_projector = self.model_llava.multi_modal_projector
        self._merge_input_ids_with_image_features = self.model_llava._merge_input_ids_with_image_features

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            pixel_values: torch.FloatTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,

        ):
        vision_feature_select_strategy = "default"
        # 1. Extra the input embeddings
        inputs_embeds = self.get_input_embeddings()(input_ids)
        # 2. Merge text and images
        if pixel_values is not None and input_ids.shape[1] != 1:
            selected_image_feature = self.vision_tower(pixel_values).last_hidden_state

            if vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
            elif vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            else:
                raise ValueError(
                    f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
                )

            image_features = self.multi_modal_projector(selected_image_feature)

            print('image_features:', image_features.shape, 'text input_ids:', input_ids.shape, input_ids)
            inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, input_ids, attention_mask, position_ids
            )


        return inputs_embeds, attention_mask, position_ids

prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
inputs = processor(text=prompt, images=image, return_tensors="pt")
# inputs['input_ids'] = inputs['input_ids'].cuda()
for key, val in inputs.items():
    inputs[key] = val.cuda()

img_text_embeder = LlaVA_Img_Text_Embedding().cuda()
inputs_embeds, attention_mask, position_ids = img_text_embeder(**inputs)

print(inputs_embeds.shape)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

image_features: torch.Size([1, 576, 4096]) text input_ids: torch.Size([1, 20])
torch.Size([1, 595, 4096])


In [3]:
inputs_embeds

tensor([[[ 0.0045, -0.0038,  0.0017,  ..., -0.0088,  0.0025, -0.0025],
         [ 0.0718,  0.7710, -0.4734,  ..., -0.2423,  0.6743, -0.1895],
         [ 0.5903,  0.4768, -0.6670,  ..., -0.4971,  0.5259, -0.3545],
         ...,
         [-0.0187, -0.0017,  0.0177,  ...,  0.0238,  0.0052,  0.0101],
         [ 0.0066, -0.0161,  0.0117,  ..., -0.0103,  0.0148,  0.0073],
         [ 0.0039,  0.0015,  0.0055,  ..., -0.0042,  0.0151,  0.0024]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<IndexPutBackward0>)

In [7]:
kwargs = {"device_map": "auto"}
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4'
)
model_llava = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf",**kwargs)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [19]:
prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
inputs = processor(text=prompt, images=image, return_tensors="pt")
# inputs['input_ids'] = inputs['input_ids'].cuda()
for key, val in inputs.items():
    inputs[key] = val.cuda()
    print(key)

print(inputs['pixel_values'].shape, inputs['input_ids'].shape, inputs['input_ids'])

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


input_ids
attention_mask
pixel_values
torch.Size([1, 3, 336, 336]) torch.Size([1, 20]) tensor([[    1, 32000, 29871,    13, 11889, 29901,  1724, 29915, 29879,   278,
          2793,   310,   278,  1967, 29973,    13, 22933,  9047, 13566, 29901]],
       device='cuda:0')


In [24]:
get_input_embeddings = model_llava.get_input_embeddings
vision_tower = model_llava.vision_tower
multi_modal_projector = model_llava.multi_modal_projector
_merge_input_ids_with_image_features = model_llava._merge_input_ids_with_image_features
inputs_embeds = get_input_embeddings()(inputs['input_ids'])
print(inputs_embeds.shape, inputs_embeds)

torch.Size([1, 20, 4096]) tensor([[[ 0.0045, -0.0038,  0.0017,  ..., -0.0088,  0.0025, -0.0025],
         [ 0.0007,  0.0006, -0.0005,  ..., -0.0006, -0.0001, -0.0005],
         [-0.0012,  0.0013, -0.0127,  ...,  0.0026, -0.0012, -0.0053],
         ...,
         [-0.0187, -0.0017,  0.0177,  ...,  0.0238,  0.0052,  0.0101],
         [ 0.0066, -0.0161,  0.0117,  ..., -0.0103,  0.0148,  0.0073],
         [ 0.0039,  0.0015,  0.0055,  ..., -0.0042,  0.0151,  0.0024]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<EmbeddingBackward0>)


In [25]:
selected_image_feature = vision_tower(inputs['pixel_values']).last_hidden_state
print(selected_image_feature.shape, selected_image_feature)

selected_image_feature = selected_image_feature[:, 1:]
print(selected_image_feature.shape, selected_image_feature)

torch.Size([1, 577, 1024]) tensor([[[-0.4695, -0.2588, -0.0212,  ..., -0.1227, -0.8413,  0.4075],
         [ 0.5752,  1.1084,  0.0914,  ...,  1.5645,  0.1642,  0.9097],
         [ 0.2900,  0.3425,  0.7036,  ...,  1.6426, -0.4028,  0.5298],
         ...,
         [ 1.8662,  1.0088, -0.2004,  ...,  1.8193, -0.8154,  0.8516],
         [ 0.3364,  0.2842,  0.9233,  ...,  0.4353,  0.3420, -0.4443],
         [ 0.0479,  0.2493,  0.5264,  ...,  0.0032,  0.2112,  0.1550]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
torch.Size([1, 576, 1024]) tensor([[[ 0.5752,  1.1084,  0.0914,  ...,  1.5645,  0.1642,  0.9097],
         [ 0.2900,  0.3425,  0.7036,  ...,  1.6426, -0.4028,  0.5298],
         [ 0.4731, -0.1261,  0.5146,  ..., -0.2083,  0.1995, -0.4390],
         ...,
         [ 1.8662,  1.0088, -0.2004,  ...,  1.8193, -0.8154,  0.8516],
         [ 0.3364,  0.2842,  0.9233,  ...,  0.4353,  0.3420, -0.4443],
         [ 0.0479,  0.2493,  0.5264,  ...,  0.0032,  0.2112,  0.15

In [26]:
image_features = multi_modal_projector(selected_image_feature)
print(image_features.shape, image_features)

torch.Size([1, 576, 4096]) tensor([[[ 0.0718,  0.7710, -0.4734,  ..., -0.2423,  0.6743, -0.1895],
         [ 0.5903,  0.4768, -0.6670,  ..., -0.4971,  0.5259, -0.3545],
         [ 0.5405,  0.4556, -0.1338,  ..., -0.3860,  1.0166, -0.4998],
         ...,
         [ 0.8389,  0.0048, -1.1475,  ..., -0.2019,  0.4473,  0.0886],
         [ 0.6997,  0.1176, -0.7114,  ..., -0.6050,  0.8491,  0.0121],
         [ 0.0582,  0.4683, -0.4680,  ..., -0.5923,  0.4304, -0.4587]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MatMul4BitBackward>)


In [27]:
print(image_features.shape, inputs_embeds.shape)
inputs_embeds, attention_mask, position_ids = _merge_input_ids_with_image_features(
                image_features, inputs_embeds, inputs['input_ids'], attention_mask, position_ids
            )

print(inputs_embeds.shape, inputs_embeds)

torch.Size([1, 576, 4096]) torch.Size([1, 20, 4096])
torch.Size([1, 595, 4096]) tensor([[[ 0.0045, -0.0038,  0.0017,  ..., -0.0088,  0.0025, -0.0025],
         [ 0.0718,  0.7710, -0.4734,  ..., -0.2423,  0.6743, -0.1895],
         [ 0.5903,  0.4768, -0.6670,  ..., -0.4971,  0.5259, -0.3545],
         ...,
         [-0.0187, -0.0017,  0.0177,  ...,  0.0238,  0.0052,  0.0101],
         [ 0.0066, -0.0161,  0.0117,  ..., -0.0103,  0.0148,  0.0073],
         [ 0.0039,  0.0015,  0.0055,  ..., -0.0042,  0.0151,  0.0024]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<IndexPutBackward0>)


In [36]:
def merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids, attention_mask, position_ids):
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(model_llava.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == model_llava.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != model_llava.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
    return final_embedding, final_attention_mask, position_ids

print(image_features.shape, inputs_embeds.shape)
inputs_embeds, attention_mask, position_ids = merge_input_ids_with_image_features(
                image_features, inputs_embeds, inputs['input_ids'], attention_mask, position_ids
            )

print(inputs_embeds.shape, inputs_embeds)


torch.Size([1, 576, 4096]) torch.Size([1, 595, 4096])
torch.Size([1, 595, 4096]) tensor([[[ 0.0045, -0.0038,  0.0017,  ..., -0.0088,  0.0025, -0.0025],
         [ 0.0718,  0.7710, -0.4734,  ..., -0.2423,  0.6743, -0.1895],
         [ 0.5903,  0.4768, -0.6670,  ..., -0.4971,  0.5259, -0.3545],
         ...,
         [ 1.4023,  0.5166, -0.0838,  ..., -0.8745,  0.1522, -0.6401],
         [ 0.4265,  0.7480, -1.0049,  ..., -0.1869,  0.7275,  0.0514],
         [-0.0573,  0.6553, -0.5039,  ..., -0.5781,  0.2234,  0.0302]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<IndexPutBackward0>)


In [34]:
model_llava.pad_token_id, model_llava.config.image_token_index

(32001, 32000)