In [ ]:
# 1) Install deps (ok to skip if already installed)
!pip install -U torch torchvision timm open_clip_torch tqdm

In [ ]:
# 2) Imports & seeds
import os, math, random, time
from dataclasses import dataclass
from typing import Tuple, List, Optional

import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, Dataset

import torchvision
from torchvision import datasets, transforms

import open_clip
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt

# Reproducibility
def set_seed(seed: int = 42, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Value handed to every random number generator.
        deterministic: When False (the default, identical to the previous
            hard-coded behaviour) cuDNN autotuning is enabled and
            deterministic kernels are not requested, so GPU runs may differ
            slightly from one another even with the same seed. Pass True to
            request deterministic cuDNN kernels and disable autotuning.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are opposites of each other: autotuning
    # (benchmark) may pick different kernels per run, which defeats
    # determinism.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

# Seed every RNG once up front so the following cells start from a known state.
set_seed(42)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

In [ ]:
# 3) Config (edit as needed for quick experiments)
@dataclass
class Config:
    """Settings shared by the notebook cells (model choice, data loading, optimisation)."""
    # Human-readable dataset label; the visible cells load CIFAR-10 directly and do not read this field.
    dataset_name: str = "CIFAR-10"
    # Architecture name passed to open_clip when creating the model and the tokenizer.
    model_name: str = "ViT-B-32"
    # Pretrained-weights tag passed to open_clip.create_model_and_transforms.
    pretrained: str = "laion2b_s34b_b79k"
    # Batch size used for the train/val/test loaders.
    batch_size: int = 256
    # Number of DataLoader worker processes for those loaders.
    num_workers: int = 2
    # Number of epochs for the full training run.
    max_epochs: int = 5
    # Number of learnable context tokens (passed as n_ctx to CoOpPrompt).
    prompt_len: int = 16
    # Standard deviation of the normal initialisation of the context tokens.
    init_scale: float = 0.02
    # Learning rate of the AdamW optimizer that updates the context tokens.
    lr_prompt: float = 5e-3
    # Weight decay of that optimizer.
    weight_decay: float = 0.0
    # Progress-bar status is refreshed every this many training steps.
    log_interval: int = 25
    # Factor multiplying the image/text cosine similarities to form logits.
    temperature: float = 100.0
    # Number of training images used by the overfit sanity check.
    overfit_subset: int = 64
    # Seed passed to set_seed before model loading and data setup.
    seed: int = 42

# Build the configuration with all defaults; the trailing bare expression is
# there so the notebook cell echoes the resulting values.
cfg = Config()
cfg

In [ ]:
# 4) Load base CLIP model + preprocess + tokenizer
# Re-seed so model creation starts from the configured seed.
set_seed(cfg.seed)

# create_model_and_transforms returns three values; the second one is discarded
# and the third one is kept as the preprocessing pipeline.
model, _, preprocess = open_clip.create_model_and_transforms(
    cfg.model_name,
    pretrained=cfg.pretrained,
)
tokenizer = open_clip.get_tokenizer(cfg.model_name)

# Move the weights to the selected device and switch to inference mode.
model = model.to(device)
model = model.eval()
print("Loaded:", cfg.model_name, "pretrained:", cfg.pretrained)

In [ ]:
# 5) CoOp module: learnable context (prompt) + wrapper that produces logits
class CoOpPrompt(nn.Module):
    def __init__(self, clip_model, tokenizer, classnames: List[str], n_ctx=16, init_scale=0.02, device="cpu"):
        """
        Learnable continuous prompt for CLIP text tower (CoOp-style).

        clip_model: open_clip model (frozen here: requires_grad is turned off
            on all of its parameters)
        tokenizer: tokenizer for the model; called once on `classnames`
        classnames: list of class strings
        n_ctx: number of soft tokens inserted right after the first token
        init_scale: std of normal init for soft tokens
        device: device on which the tokenized class names are first placed

        Shapes:
            - ctx: [n_ctx, W]
            - class_token_ids: [C, L]
            - token embeddings: [C, L, W]
        """
        super().__init__()
        self.model = clip_model
        self.tokenizer = tokenizer
        self.classnames = classnames
        self.device = device

        self.context_length = getattr(self.model, "context_length", 77)
        self.width = self.model.token_embedding.weight.shape[1]

        self.ctx = nn.Parameter(init_scale * torch.randn(n_ctx, self.width))

        # Registered as non-persistent buffers so they follow `.to(...)` with
        # the rest of the module while staying out of the state_dict.
        token_ids = tokenizer(classnames).to(device)  # [C, L]
        self.register_buffer("class_token_ids", token_ids, persistent=False)
        # The end-of-text position is located as the argmax of the token ids.
        self.register_buffer("eot_indices", token_ids.argmax(dim=-1), persistent=False)  # [C]

        # Inserting n_ctx tokens shifts the end-of-text token to the right; if
        # it falls outside the context window it is cut off by the truncation
        # in forward_text_features and a different position is pooled instead.
        if self.eot_indices.numel() > 0:
            import warnings

            shifted_eot = int(self.eot_indices.max()) + n_ctx
            if shifted_eot > self.context_length - 1:
                warnings.warn(
                    "n_ctx is too large for at least one class name: the end-of-text "
                    "token is truncated and the last position is pooled instead."
                )

        for p in self.model.parameters():
            p.requires_grad = False

    def train(self, mode: bool = True):
        """Set the training flag on this module but keep the frozen CLIP model in eval mode."""
        super().train(mode)
        # super().train() recurses into self.model; undo that so layers such
        # as Dropout/BatchNorm in the frozen model never run in training mode.
        self.model.eval()
        return self

    def forward_text_features(self) -> torch.Tensor:
        """
        Returns:
            text_emb: [C, D] normalized embeddings for each class using current ctx
        """
        token_ids = self.class_token_ids  # [C, L]
        n_cls = token_ids.size(0)

        # The embedding table is frozen, so no graph is needed for this lookup.
        with torch.no_grad():
            tok_emb = self.model.token_embedding(token_ids)  # [C, L, W]

        sos = tok_emb[:, :1, :]         # [C,1,W]
        class_part = tok_emb[:, 1:, :]  # [C,L-1,W]

        ctx = self.ctx.unsqueeze(0).expand(n_cls, -1, -1)  # [C, n_ctx, W]

        x = torch.cat([sos, ctx, class_part], dim=1)  # [C, 1+n_ctx+(L-1), W]

        # Bring the sequence back to the model's context length: cut the tail
        # when too long, zero-pad when too short.
        target_len = self.context_length
        if x.size(1) > target_len:
            x = x[:, :target_len, :]
        elif x.size(1) < target_len:
            pad = x.new_zeros(n_cls, target_len - x.size(1), x.size(2))
            x = torch.cat([x, pad], dim=1)
        seq_len = x.size(1)

        text_dtype = self.model.token_embedding.weight.dtype
        pos = self.model.positional_embedding[:seq_len].to(text_dtype)  # [L,W]
        attn = getattr(self.model, "attn_mask", None)
        if attn is not None:
            attn = attn[:seq_len, :seq_len].to(text_dtype)

        # The transformer is fed [C,L,W] when it declares batch_first,
        # otherwise [L,C,W]; either way x ends up as [C,L,W] afterwards.
        transformer = self.model.transformer
        if getattr(transformer, "batch_first", False):
            x = x.to(text_dtype) + pos.unsqueeze(0)        # [C,L,W]
            x = transformer(x, attn_mask=attn)             # [C,L,W]
        else:
            x = (x.to(text_dtype) + pos).permute(1, 0, 2)  # [L,C,W]
            x = transformer(x, attn_mask=attn)             # [L,C,W]
            x = x.permute(1, 0, 2)                         # [C,L,W]
        x = self.model.ln_final(x).to(text_dtype)          # [C,L,W]

        # Pool at the end-of-text position, shifted by the inserted context
        # tokens and clamped to the last valid index.
        eot = (self.eot_indices + self.ctx.shape[0]).clamp(max=seq_len - 1)
        rows = torch.arange(n_cls, device=x.device)
        text_emb = x[rows, eot] @ self.model.text_projection  # [C,D]
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        return text_emb


class PromptedCLIP(nn.Module):
    def __init__(self, clip_model, prompt_learner: "CoOpPrompt", temperature: float = 100.0):
        """
        Wraps a frozen CLIP image tower + learnable prompt to produce class logits.

        clip_model: model exposing `encode_image`; all its parameters are frozen here
        prompt_learner: module exposing `forward_text_features()` -> [C, D]
        temperature: factor multiplying the cosine similarities
        """
        super().__init__()
        self.clip = clip_model
        self.prompt_learner = prompt_learner
        self.temperature = temperature
        for p in self.clip.parameters():
            p.requires_grad = False

    def train(self, mode: bool = True):
        """Set the training flag on the wrapper but keep the frozen CLIP model in eval mode."""
        super().train(mode)
        # super().train() recurses into every child (including self.clip);
        # reset the frozen backbone afterwards so Dropout/BatchNorm layers in
        # it never run in training mode while only the prompt is being tuned.
        self.clip.eval()
        return self

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        images: [B, 3, H, W] -> logits: [B, C]
        """
        # The image tower is frozen, so its features are computed without a graph.
        with torch.no_grad():
            img_feat = self.clip.encode_image(images)  # [B,D]
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # Gradients flow only through the text features (i.e. the prompt).
        text_feat = self.prompt_learner.forward_text_features()  # [C,D]
        logits = self.temperature * img_feat @ text_feat.T       # [B,C]
        return logits

In [ ]:
# 6) Data: CIFAR-10 train/val/test + overfit subset helper
# Re-seed before any data objects are created.
set_seed(cfg.seed)

# The CLIP preprocessing pipeline is reused unchanged for training and evaluation.
tfm_train = tfm_eval = preprocess

# CIFAR-10 is downloaded into ./data on first use.
train_set = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=tfm_train,
)
test_set = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=tfm_eval,
)

# The official test split is cut in two by index: 0-4999 serves as the
# validation set and 5000-9999 as the held-out evaluation set.
val_set = Subset(test_set, list(range(5000)))
eval_set = Subset(test_set, list(range(5000, 10000)))

classnames = train_set.classes
print("Classes:", classnames)

def make_loaders(batch_size: int):
    """Build the (train, val, test) DataLoaders over train_set, val_set and eval_set.

    All three share the given batch size, cfg.num_workers workers and pinned
    memory; only the training loader shuffles.
    """
    shared = dict(batch_size=batch_size, num_workers=cfg.num_workers, pin_memory=True)
    return (
        DataLoader(train_set, shuffle=True, **shared),
        DataLoader(val_set, shuffle=False, **shared),
        DataLoader(eval_set, shuffle=False, **shared),
    )

def make_overfit_loader(n: int = 64):
    """Return a shuffling loader over the first `n` training images.

    Batches hold at most 64 samples and are loaded in the main process
    (num_workers=0).
    """
    first_n = Subset(train_set, list(range(n)))
    per_batch = min(64, n)
    return DataLoader(
        first_n,
        batch_size=per_batch,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

# Loaders for the full run (configured batch size) and for the tiny overfit sanity check.
train_loader, val_loader, test_loader = make_loaders(cfg.batch_size)
overfit_loader = make_overfit_loader(cfg.overfit_subset)

In [ ]:
# 7) Instantiate CoOp + optimizer (prompt params only)
# Prompt learner over the loaded CLIP model, wrapped so that images map to class logits.
prompt_learner = CoOpPrompt(model, tokenizer, classnames, n_ctx=cfg.prompt_len, init_scale=cfg.init_scale, device=device).to(device)
model_coop = PromptedCLIP(model, prompt_learner, temperature=cfg.temperature).to(device)

# Only the context tokens are handed to the optimizer.
opt = torch.optim.AdamW([prompt_learner.ctx], lr=cfg.lr_prompt, weight_decay=cfg.weight_decay)
loss_fn = nn.CrossEntropyLoss()

# Cell output: number of trainable parameters in the prompt learner and the shape of ctx.
sum(p.numel() for p in prompt_learner.parameters() if p.requires_grad), prompt_learner.ctx.shape

In [ ]:
# 8) Train loop (prompt-only) with grad norm debug
def train_prompt_only(model_coop: PromptedCLIP,
                      opt: torch.optim.Optimizer,
                      train_loader: DataLoader,
                      val_loader: DataLoader,
                      device: str,
                      epochs: int,
                      log_interval: int = 25,
                      ema_alpha: float = 0.1,
                      debug_grad: bool = True,
                      loss_fn: Optional[nn.Module] = None):
    """Train the prompt of `model_coop` and validate after every epoch.

    Args:
        model_coop: wrapper whose `prompt_learner.ctx` is being optimized.
        opt: optimizer stepping the prompt parameters.
        train_loader: yields (images, labels) batches for training.
        val_loader: yields (images, labels) batches for validation.
        device: device the model and the batches are moved to.
        epochs: number of passes over `train_loader`.
        log_interval: the progress-bar status is refreshed every this many
            steps (and on the last step); values <= 0 disable the periodic
            refresh.
        ema_alpha: smoothing factor of the running loss average shown in the
            progress bar; the average restarts every epoch.
        debug_grad: when True, show loss, running average and the gradient
            norm of the context tokens in the progress bar.
        loss_fn: criterion applied to (logits, labels) for both training and
            validation; defaults to `nn.CrossEntropyLoss()`.

    Returns:
        dict with per-epoch lists under 'train_loss', 'val_loss',
        'train_acc' and 'val_acc' (accuracies are fractions in [0, 1]).
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    prompt = model_coop.prompt_learner
    model_coop.to(device)

    for epoch in range(1, epochs + 1):
        model_coop.train()
        # train() recurses into the frozen CLIP model as well; put it back in
        # eval mode so layers such as Dropout/BatchNorm stay in inference
        # behaviour while only the prompt is tuned.
        model_coop.clip.eval()
        prompt.model.eval()

        train_sum, train_correct, train_count = 0.0, 0, 0
        ema = None

        iterator = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} | train", leave=False, dynamic_ncols=True, mininterval=0.5)
        for step, (images, labels) in enumerate(iterator, start=1):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            logits = model_coop(images)
            loss = loss_fn(logits, labels)
            loss.backward()

            # Read the scalar once; every .item() call copies the value back to the host.
            loss_val = loss.item()
            ema = loss_val if ema is None else (1 - ema_alpha) * ema + ema_alpha * loss_val

            periodic = log_interval > 0 and step % log_interval == 0
            if debug_grad and (periodic or step == len(iterator)):
                # Gradient norm is read after backward() and before step().
                grad = prompt.ctx.grad
                grad_norm = 0.0 if grad is None else float(grad.norm().item())
                iterator.set_postfix_str(f"loss={loss_val:.4f}  ema={ema:.4f}  grad|ctx|={grad_norm:.2e}")

            opt.step()

            # Sample-weighted accumulation so a smaller last batch is averaged correctly.
            bs = images.size(0)
            train_sum += loss_val * bs
            preds = logits.argmax(dim=1)
            train_correct += (preds == labels).sum().item()
            train_count += bs

        avg_train_loss = train_sum / max(1, train_count)
        avg_train_acc = train_correct / max(1, train_count)

        v_loss, v_acc = evaluate_prompt_only(model_coop, val_loader, device=device, loss_fn=loss_fn, pbar=True)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(v_loss)
        history['train_acc'].append(avg_train_acc)
        history['val_acc'].append(v_acc)

        print(f"Epoch {epoch:03d}: train_loss={avg_train_loss:.4f} train_acc={avg_train_acc*100:.2f}%  val_loss={v_loss:.4f} val_acc={v_acc*100:.2f}%")

    return history

In [ ]:
# 9) Evaluate + Plot
@torch.no_grad()
def evaluate_prompt_only(model_coop: PromptedCLIP, loader: DataLoader, device: str, loss_fn, desc="valid", pbar=True):
    """Run `model_coop` over `loader` in eval mode without gradients.

    Returns a `(mean_loss, accuracy)` pair, both averaged per sample;
    accuracy is a fraction in [0, 1]. When `pbar` is true the loader is
    wrapped in a tqdm bar labelled `desc` that shows the running numbers.
    An empty loader yields `(0.0, 0.0)`.
    """
    model_coop.eval()
    batches = tqdm(loader, desc=desc, leave=False) if pbar else loader

    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    for images, labels in batches:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model_coop(images)
        batch_loss = loss_fn(logits, labels)

        # Weight the batch loss by its size so the final mean is per sample.
        n_batch = images.size(0)
        total_loss += batch_loss.item() * n_batch
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += n_batch

        if not pbar:
            continue
        running_acc = (n_correct / max(1, n_seen)) * 100.0
        batches.set_postfix(loss=batch_loss.item(), acc=f"{running_acc:.2f}%")

    denom = max(1, n_seen)
    return total_loss / denom, n_correct / denom


def plot_history(history):
    """Plot per-epoch loss and accuracy curves side by side.

    `history` must hold equally long lists under 'train_loss' and 'val_loss'.
    If it also holds 'train_acc' and 'val_acc' (fractions in [0, 1]) they are
    drawn as percentages in the right panel; otherwise that panel shows a
    placeholder message.
    """
    epochs = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4), constrained_layout=True)

    ax = axes[0]
    ax.plot(epochs, history['train_loss'], label='train')
    ax.plot(epochs, history['val_loss'], label='val')
    ax.set_xlabel('Epoch'); ax.set_ylabel('Loss'); ax.set_title('Loss'); ax.legend()

    ax = axes[1]
    if 'train_acc' in history and 'val_acc' in history:
        ax.plot(epochs, [x * 100 for x in history['train_acc']], label='train')
        ax.plot(epochs, [x * 100 for x in history['val_acc']], label='val')
        ax.set_ylabel('Accuracy (%)')
        # A legend is only meaningful when labelled curves exist; requesting
        # one on the placeholder panel makes matplotlib emit a warning.
        ax.legend()
    else:
        ax.text(0.5, 0.5, 'No accuracy in history', ha='center', va='center', transform=ax.transAxes)
    ax.set_xlabel('Epoch'); ax.set_title('Accuracy')
    plt.show()

In [ ]:
# 10) Zero-shot baseline (optional quick check)
@torch.no_grad()
def zero_shot_acc(clip_model, tokenizer, classnames, loader, device="cpu", temperature=100.0):
    """Accuracy (in percent) of plain CLIP with the fixed prompt "a photo of a {class}".

    Args:
        clip_model: model exposing `encode_text` and `encode_image`.
        tokenizer: callable turning the list of prompt strings into a tensor.
        classnames: class strings; a label is the index of its class here.
        loader: iterable of (images, labels) batches.
        device: device the tokens and images are moved to.
        temperature: factor multiplying the cosine similarities.

    Returns:
        Percentage of correctly classified samples, or 0.0 when `loader`
        yields no samples.
    """
    prompts = [f"a photo of a {c}" for c in classnames]
    text_tokens = tokenizer(prompts).to(device)
    # Text features are computed once and reused for every image batch.
    text_features = clip_model.encode_text(text_tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    correct = total = 0
    for images, labels in loader:
        images = images.to(device)
        image_features = clip_model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits = temperature * image_features @ text_features.T
        # Compare on the CPU so labels never need to be moved to the device.
        preds = logits.argmax(dim=-1).cpu()
        correct += (preds == labels.cpu()).sum().item()
        total += labels.size(0)

    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return 100.0 * correct / total

# Loader over eval_set (test-set indices 5000-9999); batch size and worker
# count are fixed here rather than taken from cfg.
test_loader_full = DataLoader(eval_set, batch_size=128, shuffle=False, num_workers=2)
# Zero-shot accuracy of the unmodified CLIP model, as a baseline for the tuned prompt.
print("Zero-shot acc (eval split):", f"{zero_shot_acc(model, tokenizer, classnames, test_loader_full, device=device, temperature=cfg.temperature):.2f}%")

In [ ]:
# 11) Sanity check: overfit tiny subset
# Fresh prompt, wrapper and optimizer so this check does not touch the objects built earlier.
prompt_learner_oc = CoOpPrompt(model, tokenizer, classnames, n_ctx=cfg.prompt_len, init_scale=cfg.init_scale, device=device).to(device)
model_overfit = PromptedCLIP(model, prompt_learner_oc, temperature=cfg.temperature).to(device)
opt_overfit = torch.optim.AdamW([prompt_learner_oc.ctx], lr=cfg.lr_prompt, weight_decay=cfg.weight_decay)

# Train on the tiny training subset; validate on the first cfg.overfit_subset
# test images, served as a single batch.
tiny_loader = overfit_loader
val_tiny_loader = DataLoader(Subset(test_set, list(range(cfg.overfit_subset))), batch_size=cfg.overfit_subset, shuffle=False)

# Five epochs with the progress-bar status refreshed every 5 steps, then plot the curves.
hist_tiny = train_prompt_only(model_overfit, opt_overfit, tiny_loader, val_tiny_loader, device=device, epochs=5, log_interval=5, debug_grad=True)
plot_history(hist_tiny)

In [ ]:
# 12) Train on full train set, validate on val, evaluate on eval split
# Re-create prompt, wrapper and optimizer so the full run starts from a fresh
# prompt initialisation (this rebinds the names created in the earlier cell).
prompt_learner = CoOpPrompt(model, tokenizer, classnames, n_ctx=cfg.prompt_len, init_scale=cfg.init_scale, device=device).to(device)
model_coop = PromptedCLIP(model, prompt_learner, temperature=cfg.temperature).to(device)
opt = torch.optim.AdamW([prompt_learner.ctx], lr=cfg.lr_prompt, weight_decay=cfg.weight_decay)

# Train for cfg.max_epochs on the full training set, validating on val_set after each epoch.
history = train_prompt_only(model_coop, opt, train_loader, val_loader, device=device, epochs=cfg.max_epochs, log_interval=cfg.log_interval, debug_grad=True)
plot_history(history)

# Final numbers on the held-out eval split, using the loss_fn defined in the earlier cell.
test_loss, test_acc = evaluate_prompt_only(model_coop, test_loader, device=device, loss_fn=loss_fn, desc="test", pbar=True)
print(f"Test: loss={test_loss:.4f}, acc={test_acc*100:.2f}%")

In [ ]:
# 13) Extra analysis (example: ctx norms)
# Take a detached CPU copy of the learned context tokens for inspection.
with torch.no_grad():
    ctx = model_coop.prompt_learner.ctx.detach().cpu()
# Report the shape and the mean L2 norm across the context tokens (norm taken along dim 1).
print("Ctx param shape:", tuple(ctx.shape), "  mean|ctx|:", float(ctx.norm(dim=1).mean()))