In [None]:
import csv
import os
from collections import defaultdict

import clip
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

In [None]:
# NOTE: Run download_datasets.ipynb (located in the base directory of the repo) first to download the datasets.

In [None]:
# Base directories for the datasets and for saved models/embeddings, given
# relative to the notebook's working directory.
# Built with os.path.join so the separator matches the host OS; the previous
# hard-coded "..\\..\\" form only resolves on Windows (on POSIX a backslash is
# an ordinary filename character).
# The final "" keeps a trailing separator, so both `datasets + "x"` and
# f"{datasets}/x" style usages keep resolving to the same locations.
datasets = os.path.join("..", "..", "datasets", "")
models = os.path.join("..", "..", "models", "")

In [None]:
# Choose the compute backend: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Load the CLIP ViT-B/16 checkpoint together with its matching image
# preprocessing transform, placing the model on the selected device.
model, preprocess = clip.load("ViT-B/16", device=device)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the returned module.
model.eval()

In [None]:
# Locations of the Flickr8k inputs and of the output directory.
# os.path.join is used instead of gluing "/" onto the base directory, so the
# result has no doubled or mixed separators when the base already ends in one.
IMAGES_PATH = os.path.join(datasets, "flickr8k", "Images")          # Folder containing the Flickr8k images
CAPTIONS_PATH = os.path.join(datasets, "flickr8k", "captions.txt")  # Caption file (CSV with an image,caption header)
MODEL_PATH = str(models)  # Directory where model artifacts are saved

In [None]:
# Parse the caption file into a list of (image_name, caption) tuples.
captions = []

# newline="" is what the csv module expects from the file object, so quoted
# fields containing line breaks are parsed correctly.
with open(CAPTIONS_PATH, "r", encoding="utf-8", newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # skip header: image,caption (default avoids StopIteration on an empty file)

    for row in reader:
        if len(row) < 2:
            continue  # blank or malformed line: nothing to pair up
        # A caption with an unquoted comma is split into several fields;
        # re-join everything after the image name instead of failing to unpack.
        img_name = row[0].strip()
        caption = ",".join(row[1:]).strip()
        captions.append((img_name, caption))

print("Total captions:", len(captions))
print("Sample:", captions[:5])

In [None]:
# Encode every readable image in IMAGES_PATH with the CLIP image encoder and
# keep one unit-length embedding per file name.
image_features = {}
skipped_files = []  # (file name, error repr) for entries that could not be read as images

print("Extracting image embeddings...")
# os.listdir returns entries in arbitrary order; sorting makes the processing
# order (and the key order of the saved dict) reproducible.
for img_name in tqdm(sorted(os.listdir(IMAGES_PATH))):
    img_path = os.path.join(IMAGES_PATH, img_name)

    try:
        # The context manager closes the underlying file handle right after
        # conversion instead of leaving it open until garbage collection.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
    except Exception as exc:
        # Still best-effort (unreadable entries are skipped), but the failure
        # is recorded and KeyboardInterrupt/SystemExit are no longer swallowed.
        skipped_files.append((img_name, repr(exc)))
        continue

    image_input = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        emb = model.encode_image(image_input)
        emb = emb / emb.norm(dim=-1, keepdim=True)  # scale to unit L2 norm along the last axis

    image_features[img_name] = emb.cpu()

if skipped_files:
    print(f"Skipped {len(skipped_files)} unreadable file(s), e.g. {skipped_files[:3]}")

torch.save(image_features, f"{models}/image_features_flickr8k.pt")
print("Saved image features!")

In [None]:
# Build one text embedding per image by ensembling (averaging) the CLIP text
# embeddings of all captions that belong to that image.
caption_features = {}

print("Extracting caption embeddings with ensembling...")

# Step 1: group captions by image
captions_by_image = defaultdict(list)
for img_name, caption in captions:
    captions_by_image[img_name].append(caption)

# Step 2: encode each image's captions and average them
for img_name, caps in tqdm(captions_by_image.items()):
    # Tokenize all captions of this image as a single batch (one forward pass
    # per image instead of one per caption). truncate=True clips captions that
    # exceed CLIP's context length instead of raising RuntimeError.
    text_input = clip.tokenize(caps, truncate=True).to(device)
    with torch.no_grad():
        embs = model.encode_text(text_input)
        embs = embs / embs.norm(dim=-1, keepdim=True)  # normalize each caption embedding
    embs = embs.cpu()

    # -----------------------------
    # Caption ensembling: average the embeddings of every caption available
    # for this image. keepdim=True keeps the leading dimension of size 1 that
    # the per-caption version produced, so the stored shape is unchanged.
    # -----------------------------
    mean_emb = embs.mean(dim=0, keepdim=True)
    caption_features[img_name] = mean_emb / mean_emb.norm(dim=-1, keepdim=True)  # normalize

torch.save(caption_features, f"{models}/caption_features_flickr8k.pt")
print("Saved caption features with ensembling!")