In [None]:
!pip install --quiet ftfy regex tqdm
!pip install --quiet git+https://github.com/openai/CLIP.git
!pip install --quiet pycocotools


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for clip (setup.py) ... [?25l[?25hdone


# Download and Prepare the MS COCO Dataset

In [None]:
import os
import requests
from zipfile import ZipFile
from tqdm import tqdm

# Define paths
data_dir = '/content/coco2014'
os.makedirs(data_dir, exist_ok=True)

# URLs for datasets and annotations
datasets = {
    "train2014": "http://images.cocodataset.org/zips/train2014.zip",
    "val2014": "http://images.cocodataset.org/zips/val2014.zip",
    "annotations_trainval2014": "http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
}

# Download helper function with progress bar
def download_file(url, dest_path):
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    with open(dest_path, 'wb') as f, tqdm(
        desc=f"Downloading {os.path.basename(dest_path)}",
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024
    ) as bar:
        for data in response.iter_content(chunk_size=1024):
            f.write(data)
            bar.update(len(data))

# Download and extract datasets
for name, url in datasets.items():
    zip_path = os.path.join(data_dir, f"{name}.zip")
    print(f"Processing {name}...")

    # Download the dataset
    download_file(url, zip_path)

    # Unzip the dataset
    with ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)

    # Remove the zip file to save space
    os.remove(zip_path)
    print(f"{name} downloaded and extracted.")

print("All datasets and annotations successfully downloaded and extracted!")


Processing train2014...


Downloading train2014.zip: 100%|██████████| 12.6G/12.6G [14:25<00:00, 15.6MB/s]


train2014 downloaded and extracted.
Processing val2014...


Downloading val2014.zip: 100%|██████████| 6.19G/6.19G [07:49<00:00, 14.2MB/s]


val2014 downloaded and extracted.
Processing annotations_trainval2014...


Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:19<00:00, 12.8MB/s]


annotations_trainval2014 downloaded and extracted.
All datasets and annotations successfully downloaded and extracted!


# Load the Teacher Model: CLIP RN50 Model

In [None]:
import torch
import clip
import numpy as np


device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Load the CLIP model
model, preprocess = clip.load("RN50", device)
model.eval()

input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)


100%|███████████████████████████████████████| 244M/244M [00:08<00:00, 31.5MiB/s]


Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408


# Define the Student Model (ResNet-34)

In [None]:
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


# Student Image Encoder (ResNet-34)
class StudentImageEncoder(nn.Module):
    def __init__(self, output_dim):
        super(StudentImageEncoder, self).__init__()
        self.encoder = models.resnet34(pretrained=False)
        self.encoder.fc = nn.Linear(self.encoder.fc.in_features, output_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = x / x.norm(dim=-1, keepdim=True)  # Normalize
        return x


class StudentTextEncoder(nn.Module):
    def __init__(self, vocab_size, context_length, output_dim):
        super(StudentTextEncoder, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, output_dim)
        self.positional_embedding = nn.Parameter(torch.zeros(context_length, output_dim))
        nn.init.normal_(self.positional_embedding, std=0.01)
        encoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=8)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.ln_final = nn.LayerNorm(output_dim)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        x = self.token_embedding(x) + self.positional_embedding  # (batch_size, seq_len, output_dim)
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, output_dim)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # (batch_size, seq_len, output_dim)
        x = self.ln_final(x)
        x = x.mean(dim=1)  # Mean pooling over the sequence length
        x = x / x.norm(dim=-1, keepdim=True)  # Normalize to unit length
        return x  # (batch_size, output_dim)



# Prepare the MSCOCO Data Loaders

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset

# Preprocessing transforms
transform = transforms.Compose([
    transforms.Resize((input_resolution, input_resolution)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711))
])

# Custom dataset to select one caption per image
class CocoDataset(Dataset):
    def __init__(self, root, annFile, transform=None):
        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=transform)
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, captions = self.dataset[idx]
        # Select the first caption
        caption = captions[0]
        # Tokenize the caption
        text = clip.tokenize(caption, context_length=context_length)[0]
        return image, text

# Paths to images and annotations
train_img_dir = os.path.join(data_dir, 'train2014')
train_ann_file = os.path.join(data_dir, 'annotations', 'captions_train2014.json')

# Create the training dataset and dataloader
train_dataset = CocoDataset(root=train_img_dir, annFile=train_ann_file, transform=transform)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,  # batch size
    shuffle=True,
    num_workers=2,
    pin_memory=True
)


loading annotations into memory...
Done (t=0.66s)
creating index...
index created!


# Define the Loss Function Using MSE

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn


def loss_with_l2(
    student_image_features,
    student_text_features,
    teacher_image_features,
    teacher_text_features
):


    # Normalize features
    student_image_features = student_image_features / student_image_features.norm(dim=-1, keepdim=True)
    student_text_features = student_text_features / student_text_features.norm(dim=-1, keepdim=True)
    teacher_image_features = teacher_image_features / teacher_image_features.norm(dim=-1, keepdim=True)
    teacher_text_features = teacher_text_features / teacher_text_features.norm(dim=-1, keepdim=True)


    # L2 distance loss
    l2_img = F.mse_loss(student_image_features, teacher_image_features)
    l2_txt = F.mse_loss(student_text_features, teacher_text_features)
    l2_loss = (l2_img + l2_txt) / 2


    total_loss = l2_loss

    return total_loss


# Set Up the Training Loop

In [None]:
# Instantiate student models
student_image_encoder = StudentImageEncoder(output_dim=1024).to(device)
student_text_encoder = StudentTextEncoder(vocab_size, context_length, output_dim=1024).to(device)

# Define optimizer
optimizer = torch.optim.Adam(
    list(student_image_encoder.parameters()) + list(student_text_encoder.parameters()),
    lr=1e-4
)




# Train the Student Model

In [None]:
# Training Loop
num_epochs = 5  # the number of epochs

for epoch in range(num_epochs):
    student_image_encoder.train()
    student_text_encoder.train()
    total_loss = 0.0

    for batch_idx, (images, texts) in enumerate(train_dataloader):
        images = images.to(device)
        texts = texts.to(device)

        # Teacher outputs
        with torch.no_grad():
            teacher_image_features = model.encode_image(images)
            teacher_text_features = model.encode_text(texts)

        # Student outputs
        student_image_features = student_image_encoder(images).to(teacher_image_features.dtype)
        student_text_features = student_text_encoder(texts).to(teacher_text_features.dtype)

        # Compute Loss with detailed returns
        loss = loss_with_l2(
            student_image_features,
            student_text_features,
            teacher_image_features,
            teacher_text_features
        )

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_dataloader)}], "
                f"Loss: {loss.item():.10f}"
            )

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")


Epoch [1/5], Step [0/1294], Loss: 0.0006260872
Epoch [1/5], Step [100/1294], Loss: 0.0006008148
Epoch [1/5], Step [200/1294], Loss: 0.0005578995
Epoch [1/5], Step [300/1294], Loss: 0.0004940033
Epoch [1/5], Step [400/1294], Loss: 0.0004930496
Epoch [1/5], Step [500/1294], Loss: 0.0004553795
Epoch [1/5], Step [600/1294], Loss: 0.0004348755
Epoch [1/5], Step [700/1294], Loss: 0.0004620552
Epoch [1/5], Step [800/1294], Loss: 0.0004358292
Epoch [1/5], Step [900/1294], Loss: 0.0004110336
Epoch [1/5], Step [1000/1294], Loss: 0.0004281998
Epoch [1/5], Step [1100/1294], Loss: 0.0004146099
Epoch [1/5], Step [1200/1294], Loss: 0.0004065037
Epoch [1/5], Average Loss: 0.0005
Epoch [2/5], Step [0/1294], Loss: 0.0004043579
Epoch [2/5], Step [100/1294], Loss: 0.0004079342
Epoch [2/5], Step [200/1294], Loss: 0.0004165173
Epoch [2/5], Step [300/1294], Loss: 0.0003681183
Epoch [2/5], Step [400/1294], Loss: 0.0003802776
Epoch [2/5], Step [500/1294], Loss: 0.0003755093
Epoch [2/5], Step [600/1294], Loss: 

# Evaluate the Trained Student Model

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import clip
import os
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"
input_resolution = 224
context_length = 77

# Evaluation transforms (same as training)
eval_transform = transforms.Compose([
    transforms.Resize((input_resolution, input_resolution)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711))
])

class CocoEvalDataset(Dataset):
    def __init__(self, root, annFile, transform=None):
        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=transform)
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, captions = self.dataset[idx]
        # Return the full list of captions for each image
        return image, captions

def coco_collate_fn(batch):
    # batch is a list of (image, captions_list) tuples
    images = []
    captions = []
    for img, caps in batch:
        images.append(img)      # img is a Tensor
        captions.append(caps)   # caps is a list of strings
    images = torch.stack(images, dim=0)  # stack all images into a single tensor
    return images, captions


# Paths for validation
val_img_dir = os.path.join(data_dir, 'val2014')
val_ann_file = os.path.join(data_dir, 'annotations', 'captions_val2014.json')



val_dataset = CocoEvalDataset(root=val_img_dir, annFile=val_ann_file, transform=eval_transform)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2, collate_fn=coco_collate_fn)


student_image_encoder.eval()
student_text_encoder.eval()

all_image_features = []
all_text_features = []
image_to_text_indices = []  # For each image, store which text indices correspond to its captions
all_captions_flat = []  # We'll store all captions globally

with torch.no_grad():
    image_count = 0
    text_count = 0
    for images, batch_captions in val_dataloader:
        # images: (B, C, H, W)
        # batch_captions: list of length B, each item is a list of captions for that image

        images = images.to(device)

        # Encode images
        image_feats = student_image_encoder(images)
        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
        all_image_features.append(image_feats.cpu())

        # Flatten captions for this batch
        flat_captions = []
        image_to_text_map_for_batch = []
        for caps in batch_captions:
            start_idx = len(flat_captions)
            flat_captions.extend(caps)  # add all captions from this image
            end_idx = len(flat_captions)
            # This image's captions correspond to indices [start_idx+text_count, end_idx+text_count)
            image_to_text_map_for_batch.append((start_idx + text_count, end_idx + text_count))

        # Tokenize all captions in the batch at once
        texts = clip.tokenize(flat_captions, context_length=context_length).to(device)
        text_feats = student_text_encoder(texts)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)

        # Store text features globally
        all_text_features.append(text_feats.cpu())
        all_captions_flat.extend(flat_captions)

        # Update the global mapping
        for (start_idx, end_idx) in image_to_text_map_for_batch:
            image_to_text_indices.append(list(range(start_idx, end_idx)))

        image_count += images.size(0)
        text_count += len(flat_captions)

all_image_features = torch.cat(all_image_features, dim=0)  # (N_images, embed_dim)
all_text_features = torch.cat(all_text_features, dim=0)    # (N_captions_total, embed_dim)

# Compute similarity matrix: shape (N_images, N_captions_total)
sim_matrix = all_image_features @ all_text_features.t()

def compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=1):
    n = sim_matrix.size(0)
    successes = 0
    for i in range(n):
        scores = sim_matrix[i]
        sorted_indices = torch.argsort(scores, descending=True)

        correct_indices = set(image_to_text_indices[i])
        ranks_of_correct = []
        for cidx in correct_indices:
            pos = (sorted_indices == cidx).nonzero(as_tuple=True)
            if len(pos) > 0:
                ranks_of_correct.append(pos[0].item())

        if len(ranks_of_correct) > 0:
            min_rank = min(ranks_of_correct)
            if min_rank < k:
                successes += 1
    recall = successes / n
    return recall

# Image-to-Text Retrieval
r1 = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=1)
r5 = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=5)
r10 = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=10)

print("Image-to-Text Retrieval:")
print(f"Recall@1: {r1*100:.2f}%")
print(f"Recall@5: {r5*100:.2f}%")
print(f"Recall@10: {r10*100:.2f}%")

# Text-to-Image Retrieval
# Create reverse mapping from text index to image index
text_to_image = [None]*all_text_features.size(0)
for i, tinds in enumerate(image_to_text_indices):
    for t in tinds:
        text_to_image[t] = i

sim_matrix_t2i = sim_matrix.t()  # (N_captions_total, N_images)

def compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=1):
    m = sim_matrix_t2i.size(0)
    successes = 0
    for j in range(m):
        scores = sim_matrix_t2i[j]
        sorted_indices = torch.argsort(scores, descending=True)
        correct_image = text_to_image[j]
        rank = (sorted_indices == correct_image).nonzero(as_tuple=True)[0].item()
        if rank < k:
            successes += 1
    recall = successes / m
    return recall

r1_t2i = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=1)
r5_t2i = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=5)
r10_t2i = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=10)

print("Text-to-Image Retrieval:")
print(f"Recall@1: {r1_t2i*100:.2f}%")
print(f"Recall@5: {r5_t2i*100:.2f}%")
print(f"Recall@10: {r10_t2i*100:.2f}%")


loading annotations into memory...
Done (t=0.51s)
creating index...
index created!
Image-to-Text Retrieval:
Recall@1: 0.22%
Recall@5: 0.97%
Recall@10: 1.83%
Text-to-Image Retrieval:
Recall@1: 0.33%
Recall@5: 1.39%
Recall@10: 2.50%
