In [None]:
# Install helper packages (presumably CLIP's runtime dependencies — confirm),
# CLIP itself straight from its GitHub repository, and pycocotools.
# NOTE(review): none of these installs pin a version or commit, so a later
# re-run may pull different code — consider pinning.
!pip install --quiet ftfy regex tqdm
!pip install --quiet git+https://github.com/openai/CLIP.git
!pip install --quiet pycocotools


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for clip (setup.py) ... [?25l[?25hdone


# Download and Prepare the MS COCO Dataset

In [None]:
import os
import requests
from zipfile import ZipFile
from tqdm import tqdm

# Root directory that receives the downloaded and extracted COCO 2014 files.
data_dir = '/content/coco2014'

# Archive name -> download URL, in the order the archives are processed.
# NOTE(review): a later cell imports `datasets` from torchvision, which rebinds
# this name; the dict is only valid until that import runs.
datasets = dict([
    ("train2014", "http://images.cocodataset.org/zips/train2014.zip"),
    ("val2014", "http://images.cocodataset.org/zips/val2014.zip"),
    ("annotations_trainval2014", "http://images.cocodataset.org/annotations/annotations_trainval2014.zip"),
])

# Create the root directory up front; a no-op when it already exists.
os.makedirs(data_dir, exist_ok=True)

# Download helper function with progress bar
def download_file(url, dest_path, chunk_size=1024 * 1024, timeout=60):
    """Stream `url` into `dest_path` while showing a tqdm progress bar.

    Args:
        url: HTTP(S) URL to download.
        dest_path: Local file path to write; an existing file is overwritten.
        chunk_size: Number of bytes requested per read from the response
            stream. The 1 MiB default keeps Python-level loop overhead low for
            multi-gigabyte archives.
        timeout: Seconds to wait for the connection / for the next bytes
            before `requests` gives up instead of hanging forever.

    Raises:
        requests.HTTPError: if the server replies with a 4xx/5xx status, so an
            HTML error page is never silently saved as if it were the archive.
        requests.RequestException: on connection errors or timeouts.
    """
    # The context manager returns the connection to the pool even on errors.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()

        # A missing Content-Length header yields 0; pass None to tqdm in that
        # case so it renders an open-ended counter instead of a 0-byte bar.
        total_size = int(response.headers.get('content-length', 0))
        with open(dest_path, 'wb') as f, tqdm(
            desc=f"Downloading {os.path.basename(dest_path)}",
            total=total_size or None,
            unit='B',
            unit_scale=True,
            unit_divisor=1024
        ) as bar:
            for data in response.iter_content(chunk_size=chunk_size):
                f.write(data)
                bar.update(len(data))

# Download and extract datasets.
# The archives are multiple gigabytes each, so a hidden marker file is written
# after every successful extraction; re-running this cell then skips archives
# that are already in place instead of downloading them again.
for name, url in datasets.items():
    zip_path = os.path.join(data_dir, f"{name}.zip")
    done_marker = os.path.join(data_dir, f".{name}.extracted")
    print(f"Processing {name}...")

    if os.path.exists(done_marker):
        print(f"{name} already downloaded and extracted, skipping.")
        continue

    try:
        # Download the dataset
        download_file(url, zip_path)

        # Unzip the dataset
        with ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        # Remove the zip file to save space — also after a failed download or
        # extraction, so a partial/corrupt archive never lingers on disk.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    # Only reached when download and extraction both succeeded.
    with open(done_marker, 'w'):
        pass
    print(f"{name} downloaded and extracted.")

print("All datasets and annotations successfully downloaded and extracted!")


Processing train2014...


Downloading train2014.zip: 100%|██████████| 12.6G/12.6G [15:02<00:00, 15.0MB/s]


train2014 downloaded and extracted.
Processing val2014...


Downloading val2014.zip: 100%|██████████| 6.19G/6.19G [07:00<00:00, 15.8MB/s]


val2014 downloaded and extracted.
Processing annotations_trainval2014...


Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:17<00:00, 14.8MB/s]


annotations_trainval2014 downloaded and extracted.
All datasets and annotations successfully downloaded and extracted!


# Load the Teacher Model: CLIP RN50 Model

In [None]:
import torch
import clip
import numpy as np


# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


# Teacher: CLIP RN50, switched to inference mode.
model, preprocess = clip.load("RN50", device)
model.eval()

# Sizes read off the teacher; later cells reuse them for the student and data.
vocab_size = model.vocab_size
context_length = model.context_length
input_resolution = model.visual.input_resolution

# Total element count over all parameter tensors of the teacher.
teacher_param_count = sum(p.numel() for p in model.parameters())

print("Model parameters:", f"{teacher_param_count:,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)


100%|████████████████████████████████████████| 244M/244M [00:02<00:00, 125MiB/s]


Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408


# Define the Student Models (ResNet-34 Image Encoder and Transformer Text Encoder)

In [None]:
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


# Student Image Encoder (ResNet-34)
class StudentImageEncoder(nn.Module):
    """ResNet-34 image encoder that returns L2-normalised embeddings.

    A torchvision ResNet-34 is created without pretrained weights and its final
    fully connected layer is replaced by a linear projection to `output_dim`.

    Args:
        output_dim: Size of the embedding produced for each image.
    """

    def __init__(self, output_dim):
        super().__init__()
        # Calling the constructor without arguments gives randomly initialised
        # weights, the same as the deprecated `pretrained=False` keyword, and
        # works on torchvision releases from before and after the `weights=`
        # API was introduced.
        self.encoder = models.resnet34()
        self.encoder.fc = nn.Linear(self.encoder.fc.in_features, output_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch accepted by the ResNet (presumably
                (batch, 3, H, W) float tensors — confirm against the loader).

        Returns:
            Tensor of shape (batch, output_dim) with unit L2 norm per row.
        """
        x = self.encoder(x)
        # F.normalize clamps the norm with a small eps, so an all-zero
        # embedding yields zeros rather than NaNs from a 0/0 division.
        return F.normalize(x, dim=-1)


class StudentTextEncoder(nn.Module):
    """Small Transformer text encoder that returns L2-normalised embeddings.

    Token ids are embedded, summed with a learned positional embedding, passed
    through a 2-layer Transformer encoder (8 attention heads), layer-normalised
    and mean-pooled over the sequence.

    Args:
        vocab_size: Number of rows in the token embedding table.
        context_length: Maximum sequence length (rows of the positional table).
        output_dim: Model width and size of the returned embedding; it must be
            divisible by 8, the number of attention heads.
    """

    def __init__(self, vocab_size, context_length, output_dim):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, output_dim)
        self.positional_embedding = nn.Parameter(torch.zeros(context_length, output_dim))
        nn.init.normal_(self.positional_embedding, std=0.01)
        encoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=8)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.ln_final = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Encode a batch of token-id sequences.

        Args:
            x: Integer tensor of shape (batch_size, seq_len) with
                seq_len <= context_length.

        Returns:
            Tensor of shape (batch_size, output_dim) with unit L2 norm per row.

        Raises:
            ValueError: if seq_len exceeds the configured context length.
        """
        seq_len = x.size(1)
        max_len = self.positional_embedding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds context length {max_len}"
            )

        # Slice the positional table so sequences shorter than context_length
        # are accepted too; a full-length input uses the whole table as before.
        x = self.token_embedding(x) + self.positional_embedding[:seq_len]  # (batch_size, seq_len, output_dim)
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, output_dim) — layer is not batch_first
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # (batch_size, seq_len, output_dim)
        x = self.ln_final(x)
        # NOTE(review): the mean runs over every position, so any padding
        # positions in the input contribute to the pooled embedding.
        x = x.mean(dim=1)
        # eps-clamped normalisation: a zero vector stays zero instead of NaN.
        return F.normalize(x, dim=-1)  # (batch_size, output_dim)



# Prepare the MSCOCO Data Loaders

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset

# Image preprocessing: square resize to the teacher's input resolution,
# conversion to a tensor, then per-channel normalisation.
_norm_mean = (0.48145466, 0.4578275, 0.40821073)
_norm_std = (0.26862954, 0.26130258, 0.27577711)
_preprocess_steps = [
    transforms.Resize((input_resolution, input_resolution)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
transform = transforms.Compose(_preprocess_steps)

# Custom dataset to select one caption per image
class CocoDataset(Dataset):
    """COCO captions wrapper that yields one (image, token ids) pair per image.

    Args:
        root: Directory containing the images.
        annFile: Path to the captions annotation JSON file.
        transform: Optional transform forwarded to the wrapped
            `datasets.CocoCaptions`, which applies it to each image.
    """

    def __init__(self, root, annFile, transform=None):
        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=transform)
        # Stored for reference only; the wrapped dataset applies the transform.
        self.transform = transform

    def __len__(self):
        """Number of images in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the image at `idx` and the token ids of its first caption."""
        image, captions = self.dataset[idx]
        # Select the first caption
        caption = captions[0]
        # Tokenize the caption. truncate=True clips an over-long caption to
        # context_length instead of raising, which would otherwise abort the
        # DataLoader worker mid-epoch.
        text = clip.tokenize(caption, context_length=context_length, truncate=True)[0]
        return image, text

# Where the training images and their caption annotations live on disk.
train_img_dir, train_ann_file = (
    os.path.join(data_dir, 'train2014'),
    os.path.join(data_dir, 'annotations', 'captions_train2014.json'),
)

# Training split: one (image, tokenized caption) pair per image.
train_dataset = CocoDataset(train_img_dir, train_ann_file, transform=transform)

# Shuffled mini-batches, loaded by two workers into pinned host memory.
_train_loader_options = dict(
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
train_dataloader = DataLoader(train_dataset, **_train_loader_options)


loading annotations into memory...
Done (t=0.62s)
creating index...
index created!


# Define the Contrastive Loss Function Using Logits

In [None]:
import torch.nn.functional as F


def contrastive_loss(image_features, text_features, temperature=0.07):
    """Symmetric cross-entropy loss over pairwise cosine-similarity logits.

    Row i of `image_features` is treated as the positive match of row i of
    `text_features`; every other row in the batch acts as a negative.

    Args:
        image_features: Tensor of shape (batch, dim).
        text_features: Tensor of shape (batch, dim).
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar float32 tensor: the mean of the image->text and text->image
        cross-entropy losses.

    Raises:
        ValueError: if the two inputs have different batch sizes.
    """
    if image_features.size(0) != text_features.size(0):
        raise ValueError(
            "image_features and text_features must have the same batch size, "
            f"got {image_features.size(0)} and {text_features.size(0)}"
        )

    # Upcast first: half-precision inputs would otherwise run the matmul and
    # softmax in fp16, which loses precision at logits scaled by 1/temperature.
    # F.normalize clamps the norm with eps, so a zero vector cannot yield NaN.
    image_features = F.normalize(image_features.float(), dim=-1)
    text_features = F.normalize(text_features.float(), dim=-1)

    # Compute logits
    logits_per_image = image_features @ text_features.t() / temperature
    logits_per_text = logits_per_image.t()

    # Labels: the matching pair sits on the diagonal.
    batch_size = image_features.size(0)
    labels = torch.arange(batch_size, device=image_features.device)

    # Cross entropy loss in both retrieval directions, averaged.
    loss_image = F.cross_entropy(logits_per_image, labels)
    loss_text = F.cross_entropy(logits_per_text, labels)
    return (loss_image + loss_text) / 2



# Set Up the Training Loop

In [None]:
# Build both student encoders on the active device with a shared embedding width.
embed_dim = 1024
student_image_encoder = StudentImageEncoder(output_dim=embed_dim).to(device)
student_text_encoder = StudentTextEncoder(
    vocab_size, context_length, output_dim=embed_dim
).to(device)

# One Adam optimizer updates the parameters of both encoders together.
student_params = [
    *student_image_encoder.parameters(),
    *student_text_encoder.parameters(),
]
optimizer = torch.optim.Adam(student_params, lr=1e-4)




# Train the Student Model

In [None]:
# Training Loop
# Each step distils the frozen CLIP teacher into the students: the student
# image encoder is aligned with the teacher's text features and the student
# text encoder with the teacher's image features.
num_epochs = 5  # the number of epochs
num_batches = len(train_dataloader)

for epoch in range(num_epochs):
    student_image_encoder.train()
    student_text_encoder.train()
    total_loss = 0.0

    for batch_idx, (images, texts) in enumerate(train_dataloader):
        images = images.to(device)
        texts = texts.to(device)

        # Teacher outputs (no gradients). They are up-cast to float32 so the
        # loss is computed in full precision even when the teacher runs in
        # half precision.
        with torch.no_grad():
            teacher_image_features = model.encode_image(images).float()
            teacher_text_features = model.encode_text(texts).float()

        # Student outputs stay in their native dtype: down-casting them to the
        # teacher's dtype would push the loss and its gradients through fp16.
        student_image_features = student_image_encoder(images)
        student_text_features = student_text_encoder(texts)

        # Compute Contrastive Loss between student features and teacher features
        loss_image = contrastive_loss(student_image_features, teacher_text_features)
        loss_text = contrastive_loss(student_text_features, teacher_image_features)
        loss = (loss_image + loss_text) / 2

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{num_batches}], Loss: {loss.item():.4f}")

    # Guard against an empty loader so the summary never divides by zero.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")



Epoch [1/5], Step [0/1294], Loss: 4.1836
Epoch [1/5], Step [100/1294], Loss: 2.7109
Epoch [1/5], Step [200/1294], Loss: 2.2910
Epoch [1/5], Step [300/1294], Loss: 2.1309
Epoch [1/5], Step [400/1294], Loss: 2.0586
Epoch [1/5], Step [500/1294], Loss: 2.0000
Epoch [1/5], Step [600/1294], Loss: 1.7734
Epoch [1/5], Step [700/1294], Loss: 1.9160
Epoch [1/5], Step [800/1294], Loss: 1.7598
Epoch [1/5], Step [900/1294], Loss: 1.8320
Epoch [1/5], Step [1000/1294], Loss: 1.6387
Epoch [1/5], Step [1100/1294], Loss: 1.7070
Epoch [1/5], Step [1200/1294], Loss: 1.7334
Epoch [1/5], Average Loss: 2.0444
Epoch [2/5], Step [0/1294], Loss: 1.4980
Epoch [2/5], Step [100/1294], Loss: 1.6182
Epoch [2/5], Step [200/1294], Loss: 1.5879
Epoch [2/5], Step [300/1294], Loss: 1.6855
Epoch [2/5], Step [400/1294], Loss: 1.6748
Epoch [2/5], Step [500/1294], Loss: 1.5107
Epoch [2/5], Step [600/1294], Loss: 1.5361
Epoch [2/5], Step [700/1294], Loss: 1.6406
Epoch [2/5], Step [800/1294], Loss: 1.5654
Epoch [2/5], Step [90

# Evaluate the Trained Student Model

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import clip
import os
import numpy as np

# Evaluation settings, restated here so this cell does not rely on the teacher.
input_resolution, context_length = 224, 77
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Evaluation preprocessing: resize, tensor conversion, channel normalisation
# (the same steps and constants as the training transform).
_eval_mean = (0.48145466, 0.4578275, 0.40821073)
_eval_std = (0.26862954, 0.26130258, 0.27577711)
_eval_steps = [
    transforms.Resize((input_resolution, input_resolution)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_eval_mean, std=_eval_std),
]
eval_transform = transforms.Compose(_eval_steps)

class CocoEvalDataset(Dataset):
    """COCO captions wrapper returning each image with its first raw caption.

    Unlike the training dataset, the caption is returned as a string;
    tokenization is left to the caller.
    """

    def __init__(self, root, annFile, transform=None):
        self.transform = transform
        self.dataset = datasets.CocoCaptions(
            root=root,
            annFile=annFile,
            transform=transform,
        )

    def __len__(self):
        """Number of images in the wrapped dataset."""
        n_items = len(self.dataset)
        return n_items

    def __getitem__(self, idx):
        """Return (image, first caption) for the sample at `idx`."""
        image, caption_list = self.dataset[idx]
        # Only the first caption is kept so every sample has the same structure.
        return image, caption_list[0]


# Validation images and caption annotations on disk.
_annotations_dir = os.path.join(data_dir, 'annotations')
val_img_dir = os.path.join(data_dir, 'val2014')
val_ann_file = os.path.join(_annotations_dir, 'captions_val2014.json')

# Fixed-order batches so image i and caption i stay aligned across the run.
val_dataset = CocoEvalDataset(val_img_dir, val_ann_file, transform=eval_transform)
val_dataloader = DataLoader(
    dataset=val_dataset,
    shuffle=False,
    num_workers=2,
    batch_size=64,
)

# Switch both students to inference mode before extracting features.
student_image_encoder.eval()
student_text_encoder.eval()

all_image_features = []
all_text_features = []
all_captions = []  # Store captions for each image in order

with torch.no_grad():
    for images, captions in val_dataloader:
        images = images.to(device)
        # Tokenize captions here. truncate=True clips an over-long caption to
        # context_length instead of raising and aborting the whole evaluation.
        texts = clip.tokenize(captions, context_length=context_length, truncate=True).to(device)

        image_feats = student_image_encoder(images)
        text_feats = student_text_encoder(texts)

        # Normalize (the encoders already return unit-norm rows, so this is a
        # cheap safeguard that keeps the dot products below cosine similarities)
        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)

        # Accumulate on the CPU so GPU memory use stays bounded by one batch.
        all_image_features.append(image_feats.cpu())
        all_text_features.append(text_feats.cpu())
        all_captions.extend(captions)

# D is the encoders' output_dim (1024 as instantiated above).
all_image_features = torch.cat(all_image_features, dim=0)  # (N, D)
all_text_features = torch.cat(all_text_features, dim=0)    # (N, D)

# Compute similarity matrix
# image-to-text similarity: each image vs all texts
# NOTE(review): this is a dense N x N matrix — confirm it fits in host memory
# for the full validation split.
sim_matrix = all_image_features @ all_text_features.t()  # (N, N)

# Function to compute recall@K
def compute_recall(sim_matrix, k=1, chunk_size=1024):
    """Recall@k for retrieval where query i's correct candidate is column i.

    Args:
        sim_matrix: 2-D tensor; sim_matrix[i, j] is the similarity of query i
            (e.g. image i) and candidate j (e.g. text j). It needs at least as
            many columns as rows.
        k: A query counts as a hit when its correct candidate is ranked among
            the top k candidates.
        chunk_size: Number of rows ranked per vectorised step; bounds the size
            of the temporary comparison mask.

    Returns:
        Fraction of queries (Python float in [0, 1]) whose correct candidate
        ranks in the top k, or NaN when the matrix has no rows.

    Raises:
        ValueError: if `sim_matrix` is not 2-D or has fewer columns than rows.
    """
    if sim_matrix.dim() != 2 or sim_matrix.size(0) > sim_matrix.size(1):
        raise ValueError(
            "sim_matrix must be 2-D with at least as many columns as rows, "
            f"got shape {tuple(sim_matrix.shape)}"
        )

    n = sim_matrix.size(0)
    if n == 0:
        # Mirrors the mean of an empty set: recall is undefined.
        return float("nan")

    hits = 0
    for start in range(0, n, chunk_size):
        rows = sim_matrix[start:start + chunk_size]
        diag_idx = torch.arange(start, start + rows.size(0), device=sim_matrix.device)
        # Similarity of each query in the chunk with its own correct candidate.
        target = rows.gather(1, diag_idx.unsqueeze(1))  # (chunk, 1)
        # 0-based rank = number of *other* candidates scoring at least as high.
        # Counting ties (and NaNs, via the negated `<`) against the query makes
        # the metric deterministic and stops a collapsed model, whose scores
        # are all equal, from being credited with hits by sort-order luck.
        ranks = (~(rows < target)).sum(dim=1) - 1
        hits += int((ranks < k).sum().item())

    return hits / n

# Image -> text: rows are images, columns are candidate captions.
r1, r5, r10 = (compute_recall(sim_matrix, k=cutoff) for cutoff in (1, 5, 10))

print("Image-to-Text Retrieval:")
print(f"Recall@1: {r1*100:.2f}%")
print(f"Recall@5: {r5*100:.2f}%")
print(f"Recall@10: {r10*100:.2f}%")

# Text -> image: transposing puts captions on the rows, so caption i is now
# looked up against every image and its match is image i.
sim_matrix_t2i = sim_matrix.t()  # (N, N)

r1_t2i, r5_t2i, r10_t2i = (
    compute_recall(sim_matrix_t2i, k=cutoff) for cutoff in (1, 5, 10)
)

print("Text-to-Image Retrieval:")
print(f"Recall@1: {r1_t2i*100:.2f}%")
print(f"Recall@5: {r5_t2i*100:.2f}%")
print(f"Recall@10: {r10_t2i*100:.2f}%")


loading annotations into memory...
Done (t=0.30s)
creating index...
index created!
Image-to-Text Retrieval:
Recall@1: 0.39%
Recall@5: 1.77%
Recall@10: 3.16%
Text-to-Image Retrieval:
Recall@1: 0.67%
Recall@5: 2.55%
Recall@10: 4.28%
