In [2]:
import torch
from PIL import Image

from colpali_engine.models import ColQwen2, ColQwen2Processor

model = ColQwen2.from_pretrained(
        "vidore/colqwen2-v1.0",
        torch_dtype=torch.bfloat16,
        device_map=None,  # or "mps" if on Apple Silicon
    ).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")



Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.07s/it]


In [4]:
# Import necessary libraries
import os
import json
from PIL import Image
from transformers import AutoProcessor, Blip2ForImageTextRetrieval
import torch
from tqdm import tqdm

# Step 1: Define the paths
metadata_path = "meta_data_old_without_filtering/dataset_coco.json"  # Path to the metadata JSON file
image_dir = "/Users/doruktarhan/Desktop/MSCOCO_trial_images_small/train2014"   # Path to the directory containing images

# Step 2: Load the metadata
with open(metadata_path, "r") as file:
    metadata = json.load(file)

# Step 3: Extract image-caption pairs
image_caption_pairs = []

for image_data in tqdm(metadata["images"], desc="Processing metadata"):
    image_id = image_data["imgid"]
    filename = image_data["filename"]
    filepath = os.path.join(image_dir, filename)

    # Verify if the image file exists
    if os.path.exists(filepath):
        for sentence_data in image_data["sentences"]:
            caption_id = sentence_data["sentid"]
            caption = sentence_data["raw"]
            image_caption_pairs.append((image_id, caption_id, filepath, caption))


print(f"Total image-caption pairs extracted: {len(image_caption_pairs)}")
print("Example pair:", image_caption_pairs[0])

Processing metadata: 100%|██████████| 123287/123287 [00:00<00:00, 313151.97it/s]

Total image-caption pairs extracted: 160
Example pair: (42475, 40576, '/Users/doruktarhan/Desktop/MSCOCO_trial_images_small/train2014/COCO_train2014_000000000089.jpg', 'An oven with a stove on top of it in a kitchen.')





In [7]:
# Step 5: Process the images and captions
processed_data = []

for image_id, caption_id, filepath, caption in tqdm(image_caption_pairs, desc="Processing image-caption pairs"):
    
    # Open the image
    image = Image.open(filepath).convert("RGB")
    
    # Preprocess the image
    batch_images = processor.process_images(images=[image])
    batch_captions = processor.process_queries(queries=[caption])
    # Forward pass through the model
    with torch.no_grad():
        image_embeds = model(**batch_images)
        text_embeds = model(**batch_captions)
    
    # Save the embeddings
    processed_data.append({
        "image_id": image_id,
        "caption_id": caption_id,
        "image_filepath": filepath,
        "caption": caption,
        "image_embeds": image_embeds,
        "text_embeds": text_embeds
    })
    



Processing image-caption pairs: 100%|██████████| 160/160 [34:33<00:00, 12.96s/it] 


In [12]:
processed_data[0]['image_embeds'].shape, processed_data[0]['text_embeds'].shape

(torch.Size([1, 402, 128]), torch.Size([1, 25, 128]))

In [13]:

print(f"Total processed pairs: {len(processed_data)}")

# Step 6: Inspect results
print("Example Processed Pair Outputs:")
example = processed_data[0]
print("Image ID:", example["image_id"])
print("Caption ID:", example["caption_id"])
print("Caption:", example["caption"])

Total processed pairs: 160
Example Processed Pair Outputs:
Image ID: 42475
Caption ID: 40576
Caption: An oven with a stove on top of it in a kitchen.


In [25]:
from torch.nn.functional import normalize,cosine_similarity

# Process the embeddings
for data in processed_data:

    image_embeds = normalize(data["image_embeds"][:,0,:], p=2, dim=-1)
    text_embeds = normalize(data["text_embeds"][:,0,:], p=2, dim=-1)

    # Compute cosine similarity for each of the 32 image embeddings
    similarities = cosine_similarity(image_embeds.squeeze(0), text_embeds.squeeze(0), dim=-1)  # [32]
    
    # Get the maximum similarity score and the corresponding index
    max_similarity, max_index = similarities.max(dim=0)
    
    # Save the result back to the data for inspection
    data["max_similarity"] = max_similarity.item()
    data["max_index"] = max_index.item()

# Inspect results
for data in processed_data:
    print(f"Max Similarity: {data['max_similarity']}, Max Index: {data['max_index']}")


Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 0.99609375, Max Index: 0
Max Similarity: 