<a href="https://colab.research.google.com/github/devaru-ai/ContextVision/blob/main/ContextVision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install required packages

In [None]:
!pip install transformers faiss-cpu datasets gradio


## Import all necessary libraries


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
import faiss
import gradio as gr
from PIL import Image
import random


## Load Dataset

We use the Flickr8k dataset, which contains 8,000 images with five captions each. Only the `train` split is loaded below. For evaluation, we use a subset of captions as queries; the ground truth for each query is the image that the caption describes.


In [None]:
# Load only the training split of the Flickr8k mirror on the Hugging Face Hub.
dataset = load_dataset("jxie/flickr8k", split="train")


In [None]:
def collate_fn(batch):
    """Turn a batch of (caption, image) pairs into two parallel lists.

    Args:
        batch: Sequence of 2-tuples ``(caption, image)``.

    Returns:
        A tuple ``(captions, images)`` where both members are lists in the
        same order as ``batch``.
    """
    # Transpose the list of pairs into a captions column and an images column.
    caption_column, image_column = zip(*batch)
    return list(caption_column), list(image_column)


In [None]:
class Flickr8kPairDataset(Dataset):
    """Map-style dataset yielding one ``(caption, image)`` pair per row.

    Each row of the wrapped dataset is expected to expose the keys
    ``caption_0`` .. ``caption_4`` and ``image``; one of the five captions is
    picked at random on every access.
    """

    def __init__(self, hf_dataset):
        # Keep a reference to the underlying indexable dataset.
        self.data = hf_dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        # randint's upper bound is exclusive, so this selects caption_0..caption_4.
        caption_key = f'caption_{np.random.randint(0,5)}'
        return record[caption_key], record['image']

# Wrap the loaded dataset and serve it in shuffled mini-batches of 32 pairs;
# collate_fn keeps captions and images as plain Python lists.
pair_dataset = Flickr8kPairDataset(dataset)
pair_loader = DataLoader(
    pair_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)



## Load the CLIP model and processor (ViT-B/32) and move the model to the GPU if available


In [None]:
# Prefer CUDA when a GPU is visible, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor and the model are loaded from the same pretrained checkpoint.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Move the model weights to the selected device.
clip_model.to(device)


## Define the symmetric contrastive loss for CLIP (aligns image and text embeddings)


In [None]:
def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """Symmetric cross-entropy loss over image/text similarity logits.

    Both inputs are L2-normalised along ``dim=1``, so the logits are cosine
    similarities divided by ``temperature``. Row ``i`` of ``image_embeds`` is
    treated as the positive match of row ``i`` of ``text_embeds``.

    Args:
        image_embeds: Image embedding tensor (assumed ``[N, D]``).
        text_embeds: Text embedding tensor (assumed ``[N, D]``).
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar tensor: the mean of the image-to-text and text-to-image
        cross-entropy losses.
    """
    img_unit = F.normalize(image_embeds, p=2, dim=1)
    txt_unit = F.normalize(text_embeds, p=2, dim=1)

    # Pairwise similarity matrix, [N, N]; the diagonal holds the true pairs.
    similarity = (img_unit @ txt_unit.t()) / temperature
    targets = torch.arange(len(img_unit), device=img_unit.device)

    image_to_text = F.cross_entropy(similarity, targets)
    text_to_image = F.cross_entropy(similarity.t(), targets)
    return (image_to_text + text_to_image) / 2


## Fine-tune CLIP on (image, caption) pairs using a contrastive loss and the Adam optimizer


In [None]:
# Adam over every CLIP parameter (both towers are fine-tuned).
optimizer = torch.optim.Adam(clip_model.parameters(), lr=1e-5)
num_epochs = 6

clip_model.train()
for epoch in range(num_epochs):
    for captions, images in pair_loader:
        # Encode the captions; pad to the longest caption and truncate any
        # that exceed the tokenizer limit.
        text_inputs = clip_processor(
            text=captions, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        text_emb = clip_model.get_text_features(**text_inputs)

        # Encode the matching images.
        img_inputs = clip_processor(images=images, return_tensors="pt").to(device)
        img_emb = clip_model.get_image_features(**img_inputs)

        # One optimisation step on the symmetric contrastive loss.
        loss = clip_contrastive_loss(img_emb, text_emb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Only the final batch's loss is reported, not an epoch average.
    print(f"Epoch {epoch+1} completed. Last batch loss: {loss.item():.4f}")

# Switch back to inference mode for the embedding/retrieval cells.
clip_model.eval()


## Extract image embeddings and build a FAISS index for fast similarity search


In [None]:
batch_size = 32
image_embeddings = []  # one array of embeddings per batch, stacked below
image_indices = []     # dataset row index for each embedding row

with torch.no_grad():
    for start in range(0, len(dataset), batch_size):
        # The last batch may be shorter than batch_size.
        batch_rows = range(start, min(start + batch_size, len(dataset)))
        batch_imgs = [dataset[row]['image'] for row in batch_rows]
        inputs = clip_processor(images=batch_imgs, return_tensors="pt", padding=True).to(device)
        emb = clip_model.get_image_features(**inputs).cpu().numpy()
        image_embeddings.append(emb)
        image_indices.extend(batch_rows)

# Stack the per-batch arrays into one matrix and persist both artefacts.
image_embeddings = np.vstack(image_embeddings)
np.save("image_embeddings.npy", image_embeddings)
np.save("image_indices.npy", np.array(image_indices))


In [None]:
# CLIP is trained (and fine-tuned above) on cosine similarity, but the raw
# embeddings returned by the model are not unit length. Indexing them as-is in
# an L2 index lets each image's norm skew the ranking. With unit-length index
# vectors, squared L2 distance to a query q is ||q||^2 + 1 - 2*||q||*cos(q, x),
# so the L2 ranking equals the cosine ranking whatever the query's norm is.
embedding_norms = np.linalg.norm(image_embeddings, axis=1, keepdims=True)
unit_embeddings = image_embeddings / np.maximum(embedding_norms, 1e-12)  # avoid 0-division
# FAISS expects a C-contiguous float32 matrix.
unit_embeddings = np.ascontiguousarray(unit_embeddings, dtype=np.float32)

embedding_dim = unit_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(unit_embeddings)
faiss.write_index(index, "image_index.faiss")


## Define Image Retrieval Functions

In [None]:
def search_images_by_text_baseline(text_query, top_k=5, return_images=False):
    """Retrieve the images whose embeddings are closest to a text query.

    Args:
        text_query: Free-text query string.
        top_k: Maximum number of results to return.
        return_images: If True, return the dataset images themselves;
            otherwise return their integer dataset indices.

    Returns:
        A list of at most ``top_k`` images or indices, nearest first. An empty
        list is returned for an empty or whitespace-only query.
    """
    # A blank query (e.g. a cleared textbox) has nothing to search for.
    if not text_query or not text_query.strip():
        return []

    # Truncate over-long queries, matching how captions are tokenised during
    # fine-tuning; without it, text past the model's token limit raises.
    text_inputs = clip_processor(
        text=[text_query], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        text_emb = clip_model.get_text_features(**text_inputs).cpu().numpy()

    _, indices = index.search(text_emb, k=top_k)
    # FAISS pads with -1 when fewer than top_k vectors exist; dropping those
    # stops dataset[-1] from silently returning the last image.
    hits = [int(i) for i in indices[0] if i >= 0]
    if return_images:
        return [dataset[i]['image'] for i in hits]
    return hits


In [None]:
def search_images_by_image(query_image, top_k=5, return_images=False):
    """Retrieve the images whose embeddings are closest to a query image.

    Args:
        query_image: Query image (a PIL image when called from the Gradio
            app), or None when no image is provided.
        top_k: Maximum number of results to return.
        return_images: If True, return the dataset images themselves;
            otherwise return their integer dataset indices.

    Returns:
        A list of at most ``top_k`` images or indices, nearest first. An empty
        list is returned when ``query_image`` is None.
    """
    # A None query (e.g. the image input was cleared) cannot be processed.
    if query_image is None:
        return []

    img_inputs = clip_processor(images=[query_image], return_tensors="pt").to(device)
    with torch.no_grad():
        query_emb = clip_model.get_image_features(**img_inputs).cpu().numpy()

    _, indices = index.search(query_emb, k=top_k)
    # FAISS pads with -1 when fewer than top_k vectors exist; dropping those
    # stops dataset[-1] from silently returning the last image.
    hits = [int(i) for i in indices[0] if i >= 0]
    if return_images:
        return [dataset[i]['image'] for i in hits]
    return hits


## Build an interactive Gradio app for text- and image-based image search


In [None]:
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("# ContextVision: Multimodal Image Search")

    def _text_search_handler(q):
        """Return the top-5 images for the text currently in the textbox."""
        return search_images_by_text_baseline(q, top_k=5, return_images=True)

    def _image_search_handler(img):
        """Return the top-5 images most similar to the supplied image."""
        return search_images_by_image(img, top_k=5, return_images=True)

    # Tab 1: free-text query -> gallery of retrieved images.
    with gr.Tab("Text to Image Search"):
        text_input = gr.Textbox(label="Enter text to find similar images")
        text_output = gr.Gallery(label="Retrieved Images")
        # Re-run the search whenever the textbox value changes.
        text_input.change(fn=_text_search_handler, inputs=text_input, outputs=text_output)

    # Tab 2: query image (delivered as a PIL image) -> gallery of similar images.
    with gr.Tab("Image to Image Search"):
        image_input = gr.Image(type="pil", label="Query Image")
        image_output = gr.Gallery(label="Retrieved Images")
        # Re-run the search whenever the image value changes.
        image_input.change(fn=_image_search_handler, inputs=image_input, outputs=image_output)

demo.launch()


## Evaluation Metrics


In [None]:
import numpy as np

def precision_at_k(retrieved, relevant, k):
    """Fraction of the top-``k`` slots occupied by relevant items.

    Args:
        retrieved: Ranked list of retrieved item ids.
        relevant: Collection of relevant item ids.
        k: Cut-off rank; also the denominator, even when fewer than ``k``
            items were retrieved.

    Returns:
        Number of distinct relevant ids among the first ``k`` retrieved,
        divided by ``k``.
    """
    top_k = retrieved[:k]
    hits = set(relevant).intersection(top_k)
    return len(hits) / k

def recall_at_k(retrieved, relevant, k):
    """Fraction of the relevant items that appear in the top ``k`` results.

    Args:
        retrieved: Ranked list of retrieved item ids.
        relevant: Collection of relevant item ids.
        k: Cut-off rank.

    Returns:
        Number of distinct relevant ids among the first ``k`` retrieved,
        divided by ``len(relevant)``; 0.0 when ``relevant`` is empty.
    """
    # Nothing to recall: avoid dividing by zero.
    if len(relevant) == 0:
        return 0.0
    found = set(relevant).intersection(retrieved[:k])
    return len(found) / len(relevant)

def reciprocal_rank(retrieved, relevant):
    """Reciprocal of the 1-based rank of the first relevant result.

    Args:
        retrieved: Ranked list of retrieved item ids.
        relevant: Collection of relevant item ids.

    Returns:
        ``1.0 / rank`` of the first retrieved id found in ``relevant``, or
        0.0 when none of the retrieved ids is relevant.
    """
    first_hit = next(
        (position for position, idx in enumerate(retrieved, start=1) if idx in relevant),
        None,
    )
    if first_hit is None:
        return 0.0
    return 1.0 / first_hit

def dcg(relevances):
    """Discounted cumulative gain of a ranked list of relevance scores.

    Args:
        relevances: Sequence of gains in ranked order (position 1 first).

    Returns:
        Sum of ``relevance_i / log2(i + 1)`` over 1-based positions ``i``;
        0.0 for an empty sequence.
    """
    gains = np.asarray(relevances, dtype=float)
    discounts = np.log2(np.arange(2, len(gains) + 2))
    return np.sum(gains / discounts)


def ndcg(retrieved, relevant, k):
    """Normalised DCG at rank ``k`` with binary relevance.

    The ideal DCG is that of a ranking placing as many relevant items as
    possible (``min(len(relevant), k)``) at the top. Deriving it from the
    retrieved list instead would score any result containing at least one
    relevant item in first place as perfect, however many relevant items
    were missed.

    Args:
        retrieved: Ranked list of retrieved item ids.
        relevant: Collection of relevant item ids.
        k: Cut-off rank.

    Returns:
        ``DCG@k / IDCG@k``, or 0.0 when there are no relevant items.
    """
    relevant_set = set(relevant)
    relevances = [1 if idx in relevant_set else 0 for idx in retrieved[:k]]
    # Best achievable ranking: every relevant item (up to k) ranked first.
    ideal_hits = min(len(relevant_set), k)
    idcg = dcg([1] * ideal_hits)
    return dcg(relevances) / idcg if idcg > 0 else 0.0

# Rank cut-off shared by retrieval and every metric, so the two cannot drift
# apart when the value is changed.
k = 5

# --- Step 1: Prepare queries and ground truths ---
# Each query is the `caption_1` text of one of the first rows of `dataset`;
# its single relevant item is that row's own index.
# NOTE(review): these rows come from the same `dataset` object the model was
# fine-tuned on, so the scores below are not held-out estimates.
num_eval_queries = min(100, len(dataset))  # never index past the dataset end
queries = []
ground_truths = []
for i in range(num_eval_queries):
    item = dataset[i]
    queries.append(item['caption_1'])
    ground_truths.append([i])

# --- Step 2: Run retrieval/search function for each query ---
retrieved_results = [
    search_images_by_text_baseline(query, top_k=k) for query in queries
]

# --- Step 3: Compute metrics ---
precision_scores = []
recall_scores = []
mrr_scores = []
ndcg_scores = []

for relevant, retrieved in zip(ground_truths, retrieved_results):
    precision_scores.append(precision_at_k(retrieved, relevant, k))
    recall_scores.append(recall_at_k(retrieved, relevant, k))
    mrr_scores.append(reciprocal_rank(retrieved, relevant))
    ndcg_scores.append(ndcg(retrieved, relevant, k))

# Average each metric over all queries.
mean_precision = np.mean(precision_scores)
mean_recall = np.mean(recall_scores)
mean_mrr = np.mean(mrr_scores)
mean_ndcg = np.mean(ndcg_scores)

print(f"Precision@{k}: {mean_precision:.3f}")
print(f"Recall@{k}: {mean_recall:.3f}")
print(f"MRR: {mean_mrr:.3f}")
print(f"nDCG@{k}: {mean_ndcg:.3f}")
