In [1]:
!pip install numpy
!pip install matplotlib
!pip install torch
!pip install torchvision
!pip install transformers
!pip install tqdm



In [2]:
import os
import torch
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
import random
import numpy as np
import torch.optim as optim
from transformers import CLIPTextModel, CLIPTokenizer
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.datasets import CIFAR100
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


# Image Encoder

In [3]:
class ImageTransformerEncoder(nn.Module):
    def __init__(self, image_size=(32, 64), patch_size=16, embed_dim=512, mlp_dim=1024, num_layers=1, num_heads=1):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)
        self.patch_dim = 3 * patch_size * patch_size  # RGB channels
        self.num_layers = num_layers
        
        # Patch embedding
        self.patch_embedding = nn.Linear(self.patch_dim, embed_dim)
        
        # Positional encoding
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))
        
        # Class token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        
        blocks = []
        for _ in range(num_layers):
            blocks.append(nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=0.1,
            activation='relu',
            batch_first=True
        ))
        self.layers = nn.ModuleList(blocks)

        self.final_ln = nn.LayerNorm(embed_dim)

        
    def patchify(self, x):
        """Convert image to patches"""
        B, C, H, W = x.shape
        assert H % self.patch_size == 0
        assert W % self.patch_size == 0
        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        x = x.contiguous().view(B, C, -1, self.patch_size, self.patch_size)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(B, self.num_patches, -1)
        return x
    
    def forward(self, x):
        B = x.shape[0]
        
        # Convert to patches
        patches = self.patchify(x)  # (B, num_patches, patch_dim)
        
        # Patch embedding
        patch_embeddings = self.patch_embedding(patches)  # (B, num_patches, embed_dim)
        # Add positional encoding
        patch_embeddings = patch_embeddings + self.pos_embedding

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, patch_embeddings], dim=1)
                
        for block in self.layers:
            x = block(x)

        x = self.final_ln(x)
        # Return class token as feature vector
        cls = x[:, 0]  # (B, embed_dim)

        return cls

# Contrastive Loss

In [4]:
# Contrastive loss function
def contrastive_loss(image_features, text_features, temperature=0.01):
    # Normalize features
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    
    # Compute cosine similarity matrix
    logits = torch.matmul(image_features, text_features.t()) / temperature
    
    # Labels for contrastive learning (diagonal should be positive pairs)
    batch_size = image_features.shape[0]
    labels = torch.arange(batch_size, device=image_features.device)
    
    # Cross-entropy loss for both directions
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.t(), labels)
    
    return (loss_i2t + loss_t2i) / 2

# Dataset

In [5]:
class CIFAR100PairedWithCaption(Dataset):
    """
    Returns: (stacked_image, caption, (label_left, label_right))
    - stacked_image: torch.Tensor, shape (C, H, 2*W)
    - caption: str, e.g. "the photo on the left is apple and the photo on the right is bus"
    - labels: tuple of ints
    """
    def __init__(self, root="./data", train=True, transform=None, download=True):
        self.cifar_dataset = CIFAR100(
            root=root,
            train=train,
            transform=None,
            download=download
        )
        self.transform = transform
        self.class_names = self.cifar_dataset.classes

    def __len__(self):
        return len(self.cifar_dataset)

    def __getitem__(self, idx):
        # Pick first image
        img1, label1 = self.cifar_dataset[idx]
        # Pick a random second image (could be the same as first)
        idx2 = torch.randint(high=len(self.cifar_dataset), size=(1,)).item()
        img2, label2 = self.cifar_dataset[idx2]

        # Apply transforms if provided
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        else:
            # Convert PIL to tensor if no transform
            img1 = ToTensor()(img1)
            img2 = ToTensor()(img2)

        # Stack horizontally: shape (C, H, 2*W)
        stacked_img = torch.cat([img1, img2], dim=2)

        # Caption
        class1 = self.class_names[label1]
        class2 = self.class_names[label2]
        caption = f"the photo on the left is {class1} and the photo on the right is {class2}"

        return stacked_img, caption, (label1, label2)

# Set seed and device

In [6]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # For completely deterministic results, may impact performance
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def get_device() -> torch.device:
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")

    return torch.device("cpu")

# Dataloader

In [7]:
# Create the custom datasets
def create_caption_dataloaders(dataset_class, batch_size=128, root="./data", seed=42):
    """Create dataloaders that return image-caption pairs"""
    
    # Data transforms
    transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Create custom datasets
    train_dataset = dataset_class(
        root=root,
        train=True,
        transform=transform,
        download=True
    )

    # seed dataloaders    
    g = torch.Generator()
    g.manual_seed(seed)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0, 
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        generator=g
    )
    
    test_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        generator=g
    )
    
    return train_loader, test_loader, train_dataset.class_names

In [8]:
set_seed(seed=42)
my_device = get_device()

# get train and val dataloaders
bt_size = 8
train_loader, val_loader, class_names = create_caption_dataloaders(dataset_class=CIFAR100PairedWithCaption,
                                                                    batch_size=bt_size)

print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")
print(f"Classes: {len(class_names)}")

Files already downloaded and verified
Train loader: 6250 batches
Val loader: 6250 batches
Classes: 100


# Training and Validation functions

In [9]:
def encode_text(text_to_encode, text_encoder, tokenizer, device):
    inputs = tokenizer(text_to_encode, 
                       padding=True, 
                       truncation=True, 
                       return_tensors="pt")
    inputs = inputs.to(device)    
    outputs = text_encoder(**inputs)
    text_features = outputs.pooler_output
    
    return text_features


def train_step(images, texts, optimizer, image_encoder,
               text_encoder, tokenizer, device):
    optimizer.zero_grad()
    
    # Get image features
    images = images.to(device)
    image_features = image_encoder(images)
    
    # Get text features
    with torch.no_grad():
        text_features = encode_text(texts,  text_encoder, tokenizer, device)
    
    # Compute loss
    loss = contrastive_loss(image_features, text_features)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    return loss.item()

def train_epoch(image_encoder, text_encoder, tokenizer, 
                train_loader, epoch, optimizer, device):
    image_encoder.train()
    text_encoder.eval()
    total_loss = 0
    num_batches = 0
    
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')
    
    for batch_idx, (images, captions, label) in enumerate(progress_bar):
        
        # Forward pass and compute loss
        loss = train_step(images, captions, optimizer, image_encoder,
                          text_encoder, tokenizer, device)
        
        total_loss += loss
        num_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({'Loss': f'{loss:.4f}', 'Avg Loss': f'{total_loss/num_batches:.4f}'})
    
    return total_loss / num_batches


def validate_epoch(image_encoder, text_encoder, tokenizer, 
                              val_loader, device):
    image_encoder.eval()
    text_encoder.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        for batch_idx, (images, captions, labels) in enumerate(val_loader):            
            # Get features
            images = images.to(device)
            image_features = image_encoder(images)
            text_features = encode_text(captions, text_encoder, tokenizer, device)
            
            # Compute loss
            loss = contrastive_loss(image_features, text_features)

            
            total_loss += loss.item()
            num_batches += 1
    
    return total_loss / num_batches

# Evaluation

In [None]:
# DO NOT modify the evaluation function
def evaluate_topk(image_encoder, text_encoder, tokenizer, 
                val_loader, class_names, device):
    """
    Zero-shot evaluation: For each image, predict the closest caption from the prompts
    Returns top-1, top-10, top-100 recall.
    """
    image_encoder.eval()
    text_encoder.eval()
    # Prepare all combinations
    prompts = []
    for class_name_left in class_names: # 100x100
        for class_name_right in class_names:
            prompts.append(f"the photo on the left is {class_name_left} and the photo on the right is {class_name_right}")
    assert len(prompts) == len(set(prompts)), "Prompts must be unique!"
    # map prompt to index
    prompt_to_id = {}
    for i, prompt in enumerate(prompts):
        prompt_to_id[prompt] = i

    with torch.no_grad():
        # batch prompts to reduce peak memory
        batch_prompts = 256
        txt_features = []
        for index in range(0, len(prompts), batch_prompts):
            current_prompts = prompts[index: index + batch_prompts]
            text_features = encode_text(current_prompts, text_encoder, tokenizer, device)  
            # normalize 
            text_features = F.normalize(text_features, dim=-1)
            # aggregate 
            txt_features.append(text_features)
        # stack
        text_features = torch.concatenate(txt_features, dim=0)
        assert text_features.size(0) == len(prompts)

        top1, top10, top100, total = 0, 0, 0, 0

        for images, captions, labels in tqdm(val_loader, desc="Topk eval"):
            images = images.to(device)
            # Encode images
            image_features = image_encoder(images)
            image_features = F.normalize(image_features, dim=-1)

            # Compute similarity (batch_size, num_classes)
            logits = image_features @ text_features.T

            # Top-1 and Top-5 predictions
            top1_pred = logits.argmax(dim=-1)
            top10_pred = logits.topk(10, dim=-1).indices
            top100_pred = logits.topk(100, dim=-1).indices

            # get_labels
            idx_relative_to_prompt = [] 
            for cap in captions:
                idx_relative_to_prompt.append(prompt_to_id[cap])
            idx_relative_to_prompt = torch.tensor(idx_relative_to_prompt, device=device)

            top1 += (top1_pred == idx_relative_to_prompt).sum().item()
            top10 += sum([idx_relative_to_prompt[i] in top10_pred[i] for i in range(idx_relative_to_prompt.size(0))])
            top100 += sum([idx_relative_to_prompt[i] in top100_pred[i] for i in range(idx_relative_to_prompt.size(0))])

            total += idx_relative_to_prompt.size(0)

    top1_acc = 100 * top1 / total
    top10_acc = 100 * top10 / total
    top100_acc = 100 * top100 / total
    print(f"Zero-shot Top-1 Acc: {top1_acc:.2f}%, Top-10 Acc: {top10_acc:.2f}% Top-100 Acc: {top100_acc:.2f}%")
    return top1_acc, top10_acc, top100_acc

# Load Image and Text Encoders

In [11]:
# Load CLIP text encoder and tokenizer
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder.to(my_device)

# models
# Create an instance of the encoder
image_encoder = ImageTransformerEncoder()
image_encoder.to(my_device)

ImageTransformerEncoder(
  (patch_embedding): Linear(in_features=768, out_features=512, bias=True)
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=1024, out_features=512, bias=True)
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (final_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)

# Training loop

In [12]:
parameters = list(image_encoder.parameters())
op = optim.Adam(parameters, lr=1e-4)

# Training parameters
num_epochs = 4
best_val_loss = float('inf')

print("Starting training...")
print(f"Training on {len(train_loader)} batches per epoch")
print(f"Validating on {len(val_loader)} batches per epoch")

# Training loop
for epoch in range(num_epochs):
    # Train
    train_loss = train_epoch(image_encoder, text_encoder, tokenizer, 
                             train_loader, epoch + 1, op, my_device)
    
    # Validate
    val_loss = validate_epoch(image_encoder, text_encoder, tokenizer, 
                              val_loader, my_device)
    
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # # Evaluate retrieval
    top1_acc, top10_acc, top100_acc = evaluate_topk(image_encoder, text_encoder, tokenizer, 
                                                    val_loader, class_names, my_device)
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
    
    print("-" * 50)


Starting training...
Training on 6250 batches per epoch
Validating on 6250 batches per epoch


Epoch 1: 100%|██████████| 6250/6250 [02:12<00:00, 47.08it/s, Loss=1.5194, Avg Loss=1.8511]


Epoch 1/4
Train Loss: 1.8511, Val Loss: 1.6413


Topk eval: 100%|██████████| 6250/6250 [01:03<00:00, 98.67it/s] 


Zero-shot Top-1 Acc: 0.13%, Top-10 Acc: 1.16% Top-100 Acc: 7.81%
--------------------------------------------------


Epoch 2: 100%|██████████| 6250/6250 [02:18<00:00, 44.99it/s, Loss=1.4111, Avg Loss=1.6771]


Epoch 2/4
Train Loss: 1.6771, Val Loss: 1.5752


Topk eval: 100%|██████████| 6250/6250 [01:03<00:00, 97.78it/s] 


Zero-shot Top-1 Acc: 0.16%, Top-10 Acc: 1.31% Top-100 Acc: 8.83%
--------------------------------------------------


Epoch 3: 100%|██████████| 6250/6250 [02:17<00:00, 45.45it/s, Loss=1.4530, Avg Loss=1.6215]


Epoch 3/4
Train Loss: 1.6215, Val Loss: 1.4755


Topk eval: 100%|██████████| 6250/6250 [01:04<00:00, 97.58it/s] 


Zero-shot Top-1 Acc: 0.19%, Top-10 Acc: 1.49% Top-100 Acc: 10.46%
--------------------------------------------------


Epoch 4: 100%|██████████| 6250/6250 [02:18<00:00, 45.20it/s, Loss=2.1695, Avg Loss=1.5835]


Epoch 4/4
Train Loss: 1.5835, Val Loss: 1.4419


Topk eval: 100%|██████████| 6250/6250 [01:04<00:00, 97.54it/s] 

Zero-shot Top-1 Acc: 0.21%, Top-10 Acc: 1.68% Top-100 Acc: 10.98%
--------------------------------------------------





# Save Model

In [13]:
# Save image encoder (architecture + weights)
os.makedirs("output", exist_ok=True)  # Creates the directory if it doesn't exist
torch.save(image_encoder, "output/image_encoder.pth")

# Save text encoder and tokenizer (weights + config)
text_encoder.save_pretrained("output/clip_text_encoder")
tokenizer.save_pretrained("output/clip_text_encoder")

('output/clip_text_encoder/tokenizer_config.json',
 'output/clip_text_encoder/special_tokens_map.json',
 'output/clip_text_encoder/vocab.json',
 'output/clip_text_encoder/merges.txt',
 'output/clip_text_encoder/added_tokens.json')

# Load Model

In [14]:
# Load image encoder (architecture + weights)
image_encoder = torch.load("output/image_encoder.pth", weights_only=False)
image_encoder.to(my_device)

# Load text encoder and tokenizer (weights + config)
from transformers import CLIPTextModel, CLIPTokenizer
text_encoder = CLIPTextModel.from_pretrained("output/clip_text_encoder")
tokenizer = CLIPTokenizer.from_pretrained("output/clip_text_encoder")
text_encoder.to(my_device)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), ep

# Evaluate

In [15]:
# # Evaluate retrieval
top1_acc, top10_acc, top100_acc = evaluate_topk(image_encoder, text_encoder, tokenizer, 
                                                val_loader, class_names, my_device)

Topk eval: 100%|██████████| 6250/6250 [01:04<00:00, 97.64it/s] 

Zero-shot Top-1 Acc: 0.24%, Top-10 Acc: 1.69% Top-100 Acc: 10.97%



