In [None]:
# === MOUNT GOOGLE DRIVE ===
# The dataset archives and precomputed-embedding archives referenced later in
# this notebook are addressed under /content/drive/MyDrive, so the mount must
# succeed before any of those paths are read.
from google.colab import drive
drive.mount('/content/drive')

# === STANDARD IMPORTS ===
import os
import zipfile

# === EXTRACTION ROOT (created on first run) ===
EXTRACT_BASE = "/content/sample_data"
os.makedirs(EXTRACT_BASE, exist_ok=True)

# === DATASET ARCHIVES, KEYED BY DATASET NAME ===
ZIP_PATHS = {
    "flickr8k": "/content/drive/MyDrive/Flickr8k.zip",
    "flickr30k": "/content/drive/MyDrive/Flickr30k.zip"
}

# Each archive is unpacked into EXTRACT_BASE/<name>. The presence of an
# "Images" directory there is taken as proof of an earlier extraction.
for name, zip_path in ZIP_PATHS.items():
    extract_path = os.path.join(EXTRACT_BASE, name)
    if os.path.exists(os.path.join(extract_path, "Images")):
        print(f"✅ {name} dataset already extracted.")
        continue

    print(f"📦 Extracting {name} dataset...")
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

In [None]:
!pip install -q transformers torchvision scikit-learn

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The model and its processor are loaded from the same checkpoint name.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Put the model in evaluation mode; the notebook only uses it for inference.
model.eval()

In [None]:
# === HELPER FUNCTIONS ===
from PIL import Image
from collections import defaultdict
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torch

# Rank cutoffs (k) used by evaluate() when scoring retrieval and by report()
# when printing the per-k accuracy / precision / recall lines.
TOP_K_VALUES = [1, 3, 5]

def load_captions(filepath):
    """Read a comma-separated caption file into an image -> captions mapping.

    The first line of the file is always treated as a header and discarded.
    Every following line is split on its first comma into an image id and a
    caption; both parts are whitespace-stripped. Lines without a comma
    (including blank lines) are ignored.

    Args:
        filepath: Path to a UTF-8 text file with one ``image,caption`` pair
            per line.

    Returns:
        A ``defaultdict(list)`` mapping each image id to its captions, in the
        order they appear in the file. An empty file yields an empty mapping.
    """
    mapping = defaultdict(list)
    with open(filepath, 'r', encoding='utf-8') as f:
        # Discard the header. The default keeps an empty file from raising
        # StopIteration out of this function.
        next(f, None)
        for line in f:
            parts = line.strip().split(',', 1)
            if len(parts) == 2:
                image_id, caption = parts
                mapping[image_id.strip()].append(caption.strip())
    return mapping

def process_image(image_path):
    """Embed one image file with the notebook's global CLIP model.

    The file is opened with PIL, converted to RGB, run through the global
    ``processor`` and moved to ``device`` before the forward pass.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        The output of ``model.get_image_features`` after ``squeeze()``,
        moved to the CPU.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    model_inputs = processor(images=rgb_image, return_tensors="pt").to(device)
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        features = model.get_image_features(**model_inputs)
        return features.squeeze().cpu()

def process_caption(caption):
    """Embed one caption string with the notebook's global CLIP model.

    Args:
        caption: The caption text to embed.

    Returns:
        The output of ``model.get_text_features`` after ``squeeze()``,
        moved to the CPU.
    """
    # truncation=True asks the tokenizer to cut captions that exceed the
    # model's maximum text length. Without it an over-long caption fails in
    # the forward pass and the caller drops that caption's embedding.
    inputs = processor(text=[caption], return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        vec = model.get_text_features(**inputs).squeeze().cpu()
    return vec

def generate_embeddings(image_dir, caption_file, embedding_dir, max_captions=5):
    """Compute and save CLIP embeddings for every captioned image.

    For each image id listed in ``caption_file`` this writes
    ``<embedding_dir>/<stem>.pt`` for the image and
    ``<embedding_dir>/<stem>_cap_<n>.pt`` (n starting at 1) for each of its
    first ``max_captions`` captions. An image whose files all already exist
    is skipped, so an interrupted run can be resumed.

    Args:
        image_dir: Directory containing the image files named in the caption
            file.
        caption_file: Caption file readable by ``load_captions``.
        embedding_dir: Output directory; created if missing.
        max_captions: Upper bound on captions embedded per image.

    Failures are reported with ``print`` and do not abort the run: a failed
    image skips its captions, a failed caption skips only that caption.
    """
    os.makedirs(embedding_dir, exist_ok=True)
    img_to_caps = load_captions(caption_file)

    for img_file in tqdm(sorted(img_to_caps), desc=f"Embedding {os.path.basename(embedding_dir)}"):
        img_name = os.path.splitext(img_file)[0]
        img_out_path = os.path.join(embedding_dir, f"{img_name}.pt")

        # Expect one output file per caption that actually exists. A fixed
        # count would make the resume check fail forever for images that
        # have fewer captions than that count.
        captions = img_to_caps[img_file][:max_captions]
        cap_out_paths = [
            os.path.join(embedding_dir, f"{img_name}_cap_{i+1}.pt")
            for i in range(len(captions))
        ]

        if os.path.exists(img_out_path) and all(os.path.exists(p) for p in cap_out_paths):
            continue

        try:
            img_vec = process_image(os.path.join(image_dir, img_file))
            torch.save(img_vec, img_out_path)
        except Exception as e:
            print(f"Skipping image {img_file}: {e}")
            continue

        for i, cap in enumerate(captions):
            try:
                cap_vec = process_caption(cap)
                torch.save(cap_vec, cap_out_paths[i])
            except Exception as e:
                print(f"Caption {i+1} for {img_file} failed: {e}")

def load_embeddings(path):
    """Load every ``.pt`` file in ``path`` and sort it into image or caption.

    A file whose stem contains ``"_cap_"`` is a caption embedding; any other
    ``.pt`` file is an image embedding. Files with other extensions are
    ignored.

    Args:
        path: Directory holding the saved embeddings.

    Returns:
        A tuple ``(img_emb, cap_emb, cap2img)``:
        ``img_emb`` maps file stem -> loaded object for images,
        ``cap_emb`` maps file stem -> loaded object for captions, and
        ``cap2img`` is a ``defaultdict(list)`` mapping the part of a caption
        stem before its first ``"_cap_"`` to the caption stems sharing it.
    """
    extension = ".pt"
    img_emb = {}
    cap_emb = {}
    cap2img = defaultdict(list)

    for fname in os.listdir(path):
        if fname.endswith(extension):
            stem = fname[:-len(extension)]
            loaded = torch.load(os.path.join(path, fname))
            if "_cap_" not in stem:
                img_emb[stem] = loaded
            else:
                cap_emb[stem] = loaded
                owner = stem.split("_cap_")[0]
                cap2img[owner].append(stem)

    return img_emb, cap_emb, cap2img

def compute_top_k(similarities, ranked_keys, gt_keys, top_k):
    """Score one ranked result list against its ground truth at several cutoffs.

    Args:
        similarities: Not used by the computation; kept so existing callers
            keep working.
        ranked_keys: Candidate keys ordered from best to worst match.
        gt_keys: Keys counted as correct for this query. May be empty.
        top_k: Iterable of cutoff values k.

    Returns:
        ``{k: {"accuracy", "precision", "recall"}}`` where, for the first k
        ranked keys, accuracy is 1 if at least one is in ``gt_keys`` (else
        0), precision is ``hits / k`` and recall is ``hits / len(gt_keys)``.
        With an empty ``gt_keys`` recall is 0 rather than a division error.
    """
    truth = set(gt_keys)
    n_truth = len(gt_keys)
    metrics = {}
    for k in top_k:
        hits = len(set(ranked_keys[:k]) & truth)
        metrics[k] = {
            "accuracy": 1 if hits > 0 else 0,
            "precision": hits / k,
            # A query with no ground truth cannot recall anything.
            "recall": hits / n_truth if n_truth else 0,
        }
    return metrics

def aggregate(acc, m):
    """Add one query's metrics into the running totals, in place.

    Args:
        acc: Running totals, ``{k: {metric_name: number}}``; mutated.
        m: Metrics of a single query with the same nesting. Every k and
            metric name present in ``m`` must already exist in ``acc``.
    """
    for cutoff, scores in m.items():
        for metric_name, value in scores.items():
            acc[cutoff][metric_name] += value

def report(title, metrics, total):
    """Print the averaged metrics for one retrieval direction.

    Args:
        title: Heading printed above the table.
        metrics: Summed metrics, ``{k: {"accuracy", "precision", "recall"}}``,
            containing every k in ``TOP_K_VALUES``.
        total: Number of queries the sums cover; each sum is divided by it.
    """
    print(f"\n📌 {title} ({total} queries):")
    for k in TOP_K_VALUES:
        summed = metrics[k]
        a = summed["accuracy"] / total
        p = summed["precision"] / total
        r = summed["recall"] / total
        print(f"Top-{k}: Accuracy = {a:.4f} | Precision = {p:.4f} | Recall = {r:.4f}")

def _load_generic_captions(caption_file):
    captions = defaultdict(list)
    with open(caption_file, 'r', encoding='utf-8') as f:
        header_skipped = False
        for line in f:
            line = line.strip()
            if not line:
                continue
            if not header_skipped and ("image" in line.lower() and "caption" in line.lower()):
                header_skipped = True
                continue

            # Detect separator and split
            if '\t' in line:
                parts = line.split('\t', 1)
            else:
                parts = line.split(',', 1)

            if len(parts) != 2:
                continue  # skip malformed lines

            image_id, caption = parts
            image_id = image_id.strip()
            caption = caption.strip().strip('"')
            captions[image_id].append(caption)

    return dict(captions)

def load_flickr8k_captions(caption_file):
    """Load the Flickr8k caption file.

    Thin wrapper: delegates directly to ``_load_generic_captions`` and
    returns its result unchanged; no Flickr8k-specific parsing happens here.
    """
    return _load_generic_captions(caption_file)

def load_flickr30k_captions(caption_file):
    """Load the Flickr30k caption file.

    Thin wrapper: delegates directly to ``_load_generic_captions`` and
    returns its result unchanged; no Flickr30k-specific parsing happens here.
    """
    return _load_generic_captions(caption_file)

def evaluate(embedding_dir):
    """Print top-k retrieval metrics for the embeddings in ``embedding_dir``.

    Three retrieval directions are scored with cosine similarity and
    reported through ``report`` at the cutoffs in ``TOP_K_VALUES``:

    * Image -> Text: each image queries all captions; the ground truth is
      the captions recorded for that image.
    * Text -> Image: each caption queries all images; the ground truth is
      the image id encoded in the caption key.
    * Text -> Text: each caption queries all other captions; the ground
      truth is the remaining captions of the same image.

    Args:
        embedding_dir: Directory readable by ``load_embeddings``.
    """
    img_emb, cap_emb, cap2img = load_embeddings(embedding_dir)
    cap_keys, img_keys = sorted(cap_emb), sorted(img_emb)
    # Row i of each matrix corresponds to cap_keys[i] / img_keys[i].
    cap_matrix = torch.stack([cap_emb[k] for k in cap_keys]).numpy()
    img_matrix = torch.stack([img_emb[k] for k in img_keys]).numpy()

    # I2T
    i2t = {k: {"accuracy": 0, "precision": 0, "recall": 0} for k in TOP_K_VALUES}
    for img_id, img_vec in tqdm(img_emb.items(), desc="Image → Text"):
        sims = cosine_similarity(img_vec.unsqueeze(0).numpy(), cap_matrix)[0]
        ranked = [cap_keys[i] for i in np.argsort(sims)[::-1]]
        truth = cap2img[img_id]
        metrics = compute_top_k(sims, ranked, truth, TOP_K_VALUES)
        aggregate(i2t, metrics)
    report("Image → Text", i2t, len(img_emb))

    # T2I
    t2i = {k: {"accuracy": 0, "precision": 0, "recall": 0} for k in TOP_K_VALUES}
    for cap_id, cap_vec in tqdm(cap_emb.items(), desc="Text → Image"):
        sims = cosine_similarity(cap_vec.unsqueeze(0).numpy(), img_matrix)[0]
        ranked = [img_keys[i] for i in np.argsort(sims)[::-1]]
        truth = [cap_id.split("_cap_")[0]]
        metrics = compute_top_k(sims, ranked, truth, TOP_K_VALUES)
        aggregate(t2i, metrics)
    report("Text → Image", t2i, len(cap_emb))

    # T2T
    t2t = {k: {"accuracy": 0, "precision": 0, "recall": 0} for k in TOP_K_VALUES}
    # Iterate cap_keys, not cap_emb.items(): the self-match mask below must
    # index sims in the same (sorted) order that cap_matrix was built in.
    # The dict's own order comes from the directory listing and need not match.
    for i, cap_id in enumerate(tqdm(cap_keys, desc="Text → Text")):
        cap_vec = cap_emb[cap_id]
        sims = cosine_similarity(cap_vec.unsqueeze(0).numpy(), cap_matrix)[0]
        sims[i] = -1e9  # push the query caption itself to the bottom
        ranked = [cap_keys[j] for j in np.argsort(sims)[::-1]]
        truth = [k for k in cap2img[cap_id.split("_cap_")[0]] if k != cap_id]
        metrics = compute_top_k(sims, ranked, truth, TOP_K_VALUES)
        aggregate(t2t, metrics)
    report("Text → Text", t2t, len(cap_emb))

# ---------- tiny helper ----------
def show_image(path, title=""):
    """Draw an image file on the current matplotlib axes.

    The image is converted to RGB, the axes are hidden and ``title`` is set
    at font size 10. The caller decides when to call ``plt.show()``.

    Args:
        path: Path of the image file to draw.
        title: Text placed above the image.
    """
    picture = Image.open(path).convert("RGB")
    plt.imshow(picture)
    plt.axis("off")
    plt.title(title, fontsize=10)

# ---------- 1. IMAGE ➜ TEXT ----------
def image_to_text_retrieval(
    image_id, image_embeddings, caption_embeddings, caption_keys,
    top_k, image_dir, captions_dict
):
    """Show a query image and print the captions closest to it.

    Args:
        image_id: Key of the query image in ``image_embeddings``; also the
            stem of its ``.jpg`` file in ``image_dir``.
        image_embeddings: Mapping image key -> embedding tensor.
        caption_embeddings: Mapping caption key -> embedding tensor.
        caption_keys: Caption keys to rank, shaped ``<image>_cap_<n>``.
        top_k: Number of captions to print.
        image_dir: Directory holding the ``.jpg`` images.
        captions_dict: Mapping ``"<image>.jpg"`` -> list of caption texts,
            used to turn a caption key back into text.
    """
    print("\n================ IMAGE → TEXT RETRIEVAL ================")

    query_vec   = image_embeddings[image_id].unsqueeze(0).cpu().numpy()
    cap_matrix  = torch.stack([caption_embeddings[k] for k in caption_keys]).cpu().numpy()
    sims        = cosine_similarity(query_vec, cap_matrix)[0]

    ranked_idx  = np.argsort(sims)[::-1][:top_k]
    ranked_caps = [caption_keys[i] for i in ranked_idx]

    # show query image
    show_image(os.path.join(image_dir, f"{image_id}.jpg"), "Query Image")
    plt.show()

    print(f"\nTop‑{top_k} captions:")
    print(f"caption_list for {image_id}.jpg =", captions_dict.get(f"{image_id}.jpg"))
    for r, cap_key in enumerate(ranked_caps, 1):
        base, num = cap_key.rsplit("_cap_", 1)
        # The caption list may be absent or shorter than the index in the
        # key; show a placeholder rather than aborting the demo.
        try:
            caption = captions_dict[f"{base}.jpg"][int(num) - 1]
        except (IndexError, KeyError, ValueError):
            caption = "[caption missing]"
        print(f"{r}. {caption}")

# ---------- 2. TEXT ➜ IMAGE ----------
def text_to_image_retrieval(
    caption_key, image_embeddings, caption_embeddings, image_keys,
    top_k, image_dir, captions_dict
):
    """Print a query caption and display the images closest to it.

    Args:
        caption_key: Key of the query caption, shaped ``<image>_cap_<n>``.
        image_embeddings: Mapping image key -> embedding tensor.
        caption_embeddings: Mapping caption key -> embedding tensor.
        image_keys: Image keys to rank; each is the stem of a ``.jpg`` file.
        top_k: Number of images to display.
        image_dir: Directory holding the ``.jpg`` images.
        captions_dict: Mapping ``"<image>.jpg"`` -> list of caption texts.
    """
    print("\n================ TEXT → IMAGE RETRIEVAL ================")

    # Recover the caption text from the key; fall back to a placeholder
    # when the image or the caption index is not in captions_dict.
    source_image, caption_number = caption_key.rsplit("_cap_", 1)
    caption_text = "[Original caption not found]"
    try:
        caption_text = captions_dict[f"{source_image}.jpg"][int(caption_number) - 1]
    except (IndexError, KeyError, ValueError):
        pass

    print(f"\n📜 Query Caption: {caption_text}")

    query = caption_embeddings[caption_key].unsqueeze(0).cpu().numpy()
    candidates = torch.stack([image_embeddings[key] for key in image_keys]).cpu().numpy()
    scores = cosine_similarity(query, candidates)[0]

    # Highest similarity first, cut to the requested number of results.
    best_positions = np.argsort(scores)[::-1][:top_k]
    best_images = [image_keys[pos] for pos in best_positions]

    print(f"\nTop‑{top_k} images:")
    rank = 0
    for found_id in best_images:
        rank += 1
        plt.figure()
        show_image(os.path.join(image_dir, f"{found_id}.jpg"), f"Rank {rank}")
        plt.show()

# ---------- 3. TEXT ➜ TEXT ----------
def text_to_text_retrieval(
    caption_key, caption_embeddings, caption_keys,
    top_k, captions_dict
):
    """Print a query caption and the other captions closest to it.

    Args:
        caption_key: Key of the query caption, shaped ``<image>_cap_<n>``;
            must be present in ``caption_keys``.
        caption_embeddings: Mapping caption key -> embedding tensor.
        caption_keys: Caption keys to rank.
        top_k: Number of captions to print.
        captions_dict: Mapping ``"<image>.jpg"`` -> list of caption texts.
    """
    def _caption_text(key):
        # The caption list may be absent or shorter than the index in the
        # key; return a placeholder rather than aborting the demo.
        image_stem, number = key.rsplit("_cap_", 1)
        try:
            return captions_dict[f"{image_stem}.jpg"][int(number) - 1]
        except (IndexError, KeyError, ValueError):
            return "[caption missing]"

    print("\n================ TEXT → TEXT RETRIEVAL ================")

    idx         = caption_keys.index(caption_key)
    query_vec   = caption_embeddings[caption_key].unsqueeze(0).cpu().numpy()
    cap_matrix  = torch.stack([caption_embeddings[k] for k in caption_keys]).cpu().numpy()
    sims        = cosine_similarity(query_vec, cap_matrix)[0]
    sims[idx]   = -1e9                                # drop self‑match

    ranked_idx  = np.argsort(sims)[::-1][:top_k]
    ranked_keys = [caption_keys[i] for i in ranked_idx]

    print("Query Caption:", _caption_text(caption_key), "\n")

    print(f"Top‑{top_k} similar captions:")
    for r, key in enumerate(ranked_keys, 1):
        print(f"{r}. {_caption_text(key)}")

# ---------- 4. IMAGE ➜ IMAGE ----------
def image_to_image_retrieval(
    image_id, image_embeddings, image_keys,
    top_k, image_dir
):
    """Display a query image followed by the other images closest to it.

    Args:
        image_id: Key of the query image; must be present in ``image_keys``.
        image_embeddings: Mapping image key -> embedding tensor.
        image_keys: Image keys to rank; each is the stem of a ``.jpg`` file.
        top_k: Number of similar images to display.
        image_dir: Directory holding the ``.jpg`` images.
    """
    print("\n================ IMAGE → IMAGE RETRIEVAL ================")

    query = image_embeddings[image_id].unsqueeze(0).cpu().numpy()
    candidates = torch.stack([image_embeddings[key] for key in image_keys]).cpu().numpy()
    scores = cosine_similarity(query, candidates)[0]

    # The query is among the candidates; force it to the bottom of the ranking.
    own_position = image_keys.index(image_id)
    scores[own_position] = -1e9

    best_positions = np.argsort(scores)[::-1][:top_k]
    best_images = [image_keys[pos] for pos in best_positions]

    # Query image first, then one figure per result.
    show_image(os.path.join(image_dir, f"{image_id}.jpg"), "Query Image")
    plt.show()

    print(f"\nTop‑{top_k} similar images:")
    rank = 0
    for found_id in best_images:
        rank += 1
        plt.figure()
        show_image(os.path.join(image_dir, f"{found_id}.jpg"), f"Rank {rank}")
        plt.show()


In [None]:
from google.colab import files
import shutil

# === CONTROL FLAGS ===
evaluate_flag = False  # Set to True to compute and print the top-k accuracy/precision/recall metrics

# === DATASET CONFIGURATION ===
# One entry per dataset. Keys used by the pipeline loop below:
#   image_dir     - directory of the extracted .jpg images
#   caption_file  - caption text file for the dataset
#   embedding_dir - where .pt embeddings are written to / loaded from
#   zip_path      - archive of precomputed embeddings; when it exists it is
#                   extracted into embedding_dir instead of recomputing
datasets = {
    "flickr8k": {
        "image_dir": "/content/sample_data/flickr8k/Images",
        "caption_file": "/content/sample_data/flickr8k/captions.txt",
        "embedding_dir": "/content/sample_data/flickr8k/embeddings",
        "zip_path": "/content/drive/MyDrive/Flickr8k_embeddings_openai_clip.zip"
    },
    "flickr30k": {
        "image_dir": "/content/sample_data/flickr30k/Images",
        "caption_file": "/content/sample_data/flickr30k/captions.txt",
        "embedding_dir": "/content/sample_data/flickr30k/embeddings",
        "zip_path": "/content/drive/MyDrive/Flickr30k_embeddings_openai_clip.zip"
    }
}

# === CONDITIONAL EMBEDDING EXTRACTION ===
def maybe_extract_embeddings(zip_path, extract_to):
    """Unpack a precomputed-embeddings archive if it is present.

    Args:
        zip_path: Path of the zip archive to look for.
        extract_to: Directory to extract into; created when needed.

    Returns:
        True when the archive existed and was extracted, False when there is
        no file at ``zip_path`` (in which case nothing is created).
    """
    if not os.path.exists(zip_path):
        return False

    print(f"📦 Found precomputed embeddings: {os.path.basename(zip_path)}. Extracting...")
    os.makedirs(extract_to, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(extract_to)
    return True

# Per-dataset pipeline: obtain embeddings (precomputed archive or fresh
# computation), optionally evaluate them, then run one retrieval demo in
# each of the four directions on a sample image and one of its captions.
for name, cfg in datasets.items():
    print(f"\n🚀 Running pipeline for {name.upper()}")

    # Check for precomputed embeddings
    if not maybe_extract_embeddings(cfg["zip_path"], cfg["embedding_dir"]):
        print(f"🔄 Embeddings not found for {name}. Generating from scratch...")

        # === Generate embeddings ===
        generate_embeddings(cfg["image_dir"], cfg["caption_file"], cfg["embedding_dir"])

        # === Zip generated embeddings and download ===
        # make_archive appends ".zip" itself, so it is given the path without it.
        zip_output = f"/content/{name}_embeddings.zip"
        shutil.make_archive(base_name=zip_output.replace(".zip", ""), format='zip', root_dir=cfg["embedding_dir"])
        print(f"📦 Zipped embeddings to {zip_output}")

        # The download is a convenience; a failure must not stop the pipeline.
        try:
            files.download(zip_output)
            print("⬇️ Triggered download of generated embeddings.")
        except Exception as e:
            print(f"⚠️ Download failed: {e}")
    else:
        print(f"✅ Using precomputed embeddings for {name}.")

    # === Optional Evaluation ===
    if evaluate_flag:
        print(f"📊 Evaluating {name} embeddings...")
        evaluate(cfg["embedding_dir"])
    else:
        print(f"⏩ Skipping evaluation for {name} (flag is off)")

    # === Load embeddings & captions ===
    image_embeddings, caption_embeddings, caption_to_image = load_embeddings(cfg["embedding_dir"])
    caption_keys = sorted(caption_embeddings.keys())
    image_keys = sorted(image_embeddings.keys())

    if name == "flickr8k":
        captions_dict = load_flickr8k_captions(cfg["caption_file"])
    else:
        captions_dict = load_flickr30k_captions(cfg["caption_file"])

    # === Select sample and run retrieval ===
    sample_image_id = image_keys[0]
    # Match on the exact image id encoded in the caption key. A plain prefix
    # test would also accept captions of a different image whose id merely
    # starts with the sample id.
    sample_caption_key = [
        k for k in caption_keys
        if k.rsplit("_cap_", 1)[0] == sample_image_id
    ][0]

    print(f"\n🎯 Retrieval demo for {name.upper()} — Sample: {sample_image_id}")

    image_to_text_retrieval(sample_image_id, image_embeddings, caption_embeddings,
                              caption_keys, 3, cfg["image_dir"], captions_dict)

    text_to_image_retrieval(sample_caption_key, image_embeddings, caption_embeddings,
                              image_keys, 3, cfg["image_dir"], captions_dict)

    text_to_text_retrieval(sample_caption_key, caption_embeddings, caption_keys,
                           3, captions_dict)

    image_to_image_retrieval(sample_image_id, image_embeddings, image_keys,
                            3, cfg["image_dir"])