In [None]:
!pip install --quiet ftfy regex tqdm
!pip install --quiet git+https://github.com/openai/CLIP.git
!pip install --quiet pycocotools


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for clip (setup.py) ... [?25l[?25hdone


# Download the MS COCO 2014 Validation Images and Caption Annotations

In [None]:
import os
import requests
from zipfile import ZipFile
from tqdm import tqdm

# Root directory for the COCO 2014 files. Defaults to the Colab scratch path;
# set the COCO_DATA_DIR environment variable to store the data elsewhere.
data_dir = os.environ.get('COCO_DATA_DIR', '/content/coco2014')
os.makedirs(data_dir, exist_ok=True)

# Archives to fetch, as name -> URL. Only the 2014 validation images and the
# 2014 train/val annotations are downloaded; the evaluation below uses the val
# split only. To also fetch the training images, add:
#     "train2014": "http://images.cocodataset.org/zips/train2014.zip",
# NOTE: the evaluation cell later runs `from torchvision import datasets`,
# which rebinds this name to the torchvision module.
datasets = {
    "val2014": "http://images.cocodataset.org/zips/val2014.zip",
    "annotations_trainval2014": "http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
}

def download_file(url, dest_path, chunk_size=1024 * 1024, timeout=60):
    """Stream ``url`` to ``dest_path`` while showing a tqdm progress bar.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: Local file path to write; an existing file is overwritten.
        chunk_size: Bytes read per iteration. Large chunks keep per-chunk
            overhead low for multi-GB archives.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled connection raises instead of hanging indefinitely.

    Raises:
        requests.HTTPError: If the server answers with an error status. The
            check runs before ``dest_path`` is opened, so an error page is never
            saved (and later unzipped) as if it were the archive.
    """
    # The context manager closes the streamed connection even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # content-length may be absent; None gives tqdm an open-ended counter
        # instead of a bogus total of 0.
        total_size = int(response.headers.get('content-length', 0)) or None
        with open(dest_path, 'wb') as f, tqdm(
            desc=f"Downloading {os.path.basename(dest_path)}",
            total=total_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(chunk_size=chunk_size):
                f.write(data)
                bar.update(len(data))

# Download and extract each archive. Archives that a previous run already
# finished are skipped, so re-running this cell does not fetch ~6 GB again.
for name, url in datasets.items():
    zip_path = os.path.join(data_dir, f"{name}.zip")
    done_marker = os.path.join(data_dir, f".{name}.done")
    if os.path.exists(done_marker):
        print(f"{name} already downloaded and extracted; skipping.")
        continue

    print(f"Processing {name}...")

    # Download the dataset
    download_file(url, zip_path)

    # Unzip the dataset
    with ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)

    # Remove the zip file to save space, then record success. The marker is
    # written last, so an interrupted download or extraction is retried next run.
    os.remove(zip_path)
    open(done_marker, 'w').close()
    print(f"{name} downloaded and extracted.")

print("All datasets and annotations successfully downloaded and extracted!")


Processing val2014...


Downloading val2014.zip: 100%|██████████| 6.19G/6.19G [14:09<00:00, 7.83MB/s]


val2014 downloaded and extracted.
Processing annotations_trainval2014...


Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:17<00:00, 14.4MB/s]


annotations_trainval2014 downloaded and extracted.
All datasets and annotations successfully downloaded and extracted!


# Evaluate the Teacher Model (CLIP ResNet-50) on COCO Image–Text Retrieval

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import clip
import os
import numpy as np

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Teacher: OpenAI CLIP with the "RN50" image encoder, together with the image
# preprocessing transform returned alongside that checkpoint.
model, preprocess = clip.load("RN50", device)
model.eval()

# Input sizes taken from the loaded checkpoint.
context_length = model.context_length
input_resolution = model.visual.input_resolution

# Evaluation feeds images through CLIP's own preprocessing, unchanged.
eval_transform = preprocess

class CocoEvalDataset(Dataset):
    """COCO captions dataset yielding ``(image, caption)`` pairs for retrieval evaluation.

    Wraps ``torchvision.datasets.CocoCaptions`` and keeps only the first
    caption of each image, so item ``i`` pairs image ``i`` with exactly one text.

    NOTE(review): using only ``captions[0]`` differs from the usual
    multi-caption COCO retrieval protocol, so the resulting recall numbers are
    not directly comparable to published figures.

    Args:
        root: Directory containing the image files.
        annFile: Path to the COCO captions annotation JSON.
        transform: Optional transform applied to each image.
    """

    def __init__(self, root, annFile, transform=None):
        # Import here rather than rely on the module-level name `datasets`: the
        # download cell binds `datasets` to a dict of URLs, so re-running that
        # cell would otherwise break this constructor.
        from torchvision.datasets import CocoCaptions

        self.dataset = CocoCaptions(root=root, annFile=annFile, transform=transform)
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, captions = self.dataset[idx]
        # Keep only the first caption so every image has a single paired text.
        return image, captions[0]


# COCO 2014 validation split: image folder plus the caption annotations file.
val_img_dir = os.path.join(data_dir, 'val2014')
val_ann_file = os.path.join(data_dir, 'annotations', 'captions_val2014.json')

val_dataset = CocoEvalDataset(
    root=val_img_dir,
    annFile=val_ann_file,
    transform=eval_transform,
)

# shuffle=False keeps batch order equal to dataset order, which the feature
# extraction below relies on to pair row i with dataset item i.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

# Encode every validation pair once. The dataloader does not shuffle, so row i of
# each feature matrix (and entry i of all_captions) belongs to dataset item i;
# the diagonal of the similarity matrix built below is the ground-truth pairing.
all_image_features = []
all_text_features = []
all_captions = []  # Store captions for each image in order

with torch.no_grad():
    for images, captions in val_dataloader:
        images = images.to(device)
        # truncate=True clips captions longer than the context window instead of
        # letting clip.tokenize raise on them.
        texts = clip.tokenize(captions, context_length=context_length, truncate=True).to(device)

        # Encode images and texts using the teacher model, upcast to float32.
        # NOTE(review): clip.load keeps fp16 weights when loading onto CUDA, and
        # fp16 similarity scores are coarse enough to create many ranking ties.
        image_feats = model.encode_image(images).float()
        text_feats = model.encode_text(texts).float()

        # L2-normalise so the dot products below are cosine similarities.
        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)

        all_image_features.append(image_feats.cpu())
        all_text_features.append(text_feats.cpu())
        all_captions.extend(captions)

# (N, D): D is the CLIP embedding width (1024 for RN50).
all_image_features = torch.cat(all_image_features, dim=0)
all_text_features = torch.cat(all_text_features, dim=0)    # (N, D)

# Image-to-text similarity: entry [i, j] is cosine(image i, caption j).
# The matrix is N x N float32, so memory grows quadratically with N.
sim_matrix = all_image_features @ all_text_features.t()  # (N, N)

def compute_recall(sim_matrix, k=1, chunk_size=1024):
    """Recall@K for a similarity matrix whose ground-truth match for row i is column i.

    The rank of query row ``i`` is the number of columns scoring strictly
    higher than ``sim_matrix[i, i]``. The query is a hit when that rank is
    below ``k``, so ties with the ground-truth score count in its favour.

    Rows are compared in blocks of ``chunk_size``. This avoids a per-row Python
    sort and keeps the temporary comparison mask small, even when the matrix is
    tens of thousands of columns wide.

    Args:
        sim_matrix: 2-D tensor of scores, with queries as rows and candidates
            as columns. It needs at least as many columns as rows.
        k: Cut-off K.
        chunk_size: Number of query rows compared at once; must be positive.

    Returns:
        Fraction in [0, 1] of rows whose ground truth ranks within the top
        ``k``. Returns NaN for a matrix with no rows.
    """
    n = sim_matrix.size(0)
    ranks = torch.empty(n, dtype=torch.long)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        rows = sim_matrix[start:stop]
        local = torch.arange(stop - start, device=rows.device)
        # Ground-truth score of each query in this block, shaped (B, 1) so it
        # broadcasts across all candidate columns.
        target = rows[local, local + start].unsqueeze(1)
        ranks[start:stop] = (rows > target).sum(dim=1).cpu()
    return (ranks < k).float().mean().item()

def _report_recalls(title, matrix):
    """Compute Recall@1/5/10 for ``matrix``, print them under ``title``, and return them."""
    at1 = compute_recall(matrix, k=1)
    at5 = compute_recall(matrix, k=5)
    at10 = compute_recall(matrix, k=10)
    print(f"{title}:")
    print(f"Recall@1: {at1*100:.2f}%")
    print(f"Recall@5: {at5*100:.2f}%")
    print(f"Recall@10: {at10*100:.2f}%")
    return at1, at5, at10


# Image -> text: each image row is scored against every caption column.
r1, r5, r10 = _report_recalls("Image-to-Text Retrieval", sim_matrix)

# Text -> image reuses the same scores with the roles of rows and columns swapped.
sim_matrix_t2i = sim_matrix.t()  # (N, N)
r1_t2i, r5_t2i, r10_t2i = _report_recalls("Text-to-Image Retrieval", sim_matrix_t2i)


100%|███████████████████████████████████████| 244M/244M [00:05<00:00, 46.7MiB/s]


loading annotations into memory...
Done (t=0.26s)
creating index...
index created!
Image-to-Text Retrieval:
Recall@1: 15.27%
Recall@5: 30.73%
Recall@10: 39.05%
Text-to-Image Retrieval:
Recall@1: 11.68%
Recall@5: 25.52%
Recall@10: 33.50%
