# Load LLaVA and extract representation of image

In [39]:
import torch

# Check if CUDA is available and set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cpu


In [40]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

model_path = "liuhaotian/llava-v1.5-7b"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path)
)

model = model.to(dtype=torch.float32)

You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.31it/s]


In [41]:
# Access the vision tower (the model's visual encoder)
vision_tower = model.model.vision_tower

In [42]:
from PIL import Image

# Load and preprocess the image
image_path = "/homes/talbu/work/repos/2_data.png"
image = Image.open(image_path).convert('RGB')

image_inputs = image_processor(images=image, return_tensors='pt')
image_tensor = image_inputs['pixel_values'].to(device)


In [43]:
with torch.no_grad():
    # Pass the image through the vision tower
    vision_features = vision_tower(image_tensor)

    # Convert vision_features to match the model's dtype
    vision_features = vision_features.to(next(model.parameters()).dtype)

    # Check for the projection layer
    if hasattr(model.model, 'mm_projector'):
        image_embedding = model.model.mm_projector(vision_features)


In [44]:
# Print the shape and dtype of the image embedding
print("Image Embedding Shape:", image_embedding.shape)
print("Image Embedding Dtype:", image_embedding.dtype)

Image Embedding Shape: torch.Size([1, 576, 4096])
Image Embedding Dtype: torch.float32


# Compute cosine similarity between the representations

## Obtain the Text Embedding of the Entity Name

In [45]:
entity_name = "Stop Traffic sign"

# Tokenize the text
inputs = tokenizer(entity_name, return_tensors='pt')
input_ids = inputs['input_ids'].to(model.device)

# Get the text embeddings
with torch.no_grad():
    text_embeddings = model.get_input_embeddings()(input_ids)  # Shape: [1, seq_len, hidden_size]

# Pool the text embeddings
text_embedding = text_embeddings.mean(dim=1)  # Shape: [1, hidden_size]


## Process the Image Embedding

In [47]:
# Pool the image embeddings
image_embedding_pooled = image_embedding.mean(dim=1)  # Shape: [1, hidden_size]

# Ensure embeddings are on the same device and have the same dtype
text_embedding = text_embedding.to(image_embedding_pooled.device)
text_embedding = text_embedding.to(image_embedding_pooled.dtype)

## Compute the Cosine Similarity

In [48]:
from torch.nn.functional import cosine_similarity

# Compute cosine similarity
cos_sim = cosine_similarity(text_embedding, image_embedding_pooled, dim=1)
print("Cosine Similarity:", cos_sim.item())

Cosine Similarity: 0.05746552720665932
