In [None]:
# pip install -U torch torchvision timm open_clip_torch

import torch, open_clip
from torchvision import datasets
from torch.utils.data import DataLoader

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# === Choose a model from recent CLIP-family work ===
# Classic strong baseline:
MODEL_NAME   = "ViT-B-32"
PRETRAINED   = "laion2b_s34b_b79k"   # from OpenCLIP

# Tip: try very recent ones too (if available in your env):
# MODEL_NAME, PRETRAINED = "ViT-SO400M-14-SigLIP", "webli"        # SigLIP family
# MODEL_NAME, PRETRAINED = "EVA02-L-14", "laion2b_s9b_b144k"      # EVA-CLIP family
# (List available combos:)
# import pprint; pprint.pp(open_clip.list_pretrained())

# --- Load model + preprocess ---
# Only the first and third return values are kept; `preprocess` is reused below
# as the dataset transform. The tokenizer is looked up by the same MODEL_NAME.
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
# Move to the chosen device and switch to eval mode (inference only in this cell).
model = model.to(device).eval()

# --- Zero-shot on CIFAR-10 (tiny & quick) ---
# train=False selects the test split; download=True fetches it into ./data if missing.
cifar = datasets.CIFAR10(root="./data", train=False, download=True, transform=preprocess)
# shuffle=False: ordering is irrelevant for an accuracy count.
loader = DataLoader(cifar, batch_size=128, shuffle=False, num_workers=2)

# Class names as exposed by the dataset object; used to build the text prompts.
classnames = cifar.classes

def get_text_features(classnames):
    """Encode one "a photo of a <class>" prompt per class name.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        classnames: iterable of class-name strings.

    Returns:
        The output of ``model.encode_text`` for the prompts, with every row
        divided by its L2 norm along the last dimension. Computed under
        ``torch.no_grad()``.
    """
    prompt_list = []
    for class_name in classnames:
        prompt_list.append("a photo of a {}".format(class_name))

    with torch.no_grad():
        token_ids = tokenizer(prompt_list).to(device)
        embeddings = model.encode_text(token_ids)
        # Unit-normalise so a dot product with image features is a cosine similarity.
        row_norms = embeddings.norm(dim=-1, keepdim=True)
        embeddings = embeddings / row_norms
    return embeddings

# Normalised text embeddings for the class prompts, computed once and reused for every image batch.
text_features = get_text_features(classnames)

def test_model(model, text_features, test_loader):
    """Measure zero-shot top-1 accuracy and print it.

    Images are moved to the module-level ``device`` before encoding.

    Args:
        model: object providing ``eval()`` and ``encode_image(images)``.
        text_features: per-class text embeddings; its transpose is multiplied
            with the normalised image features to score each class.
        test_loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            feats = model.encode_image(batch_images.to(device))
            feats = feats / feats.norm(dim=-1, keepdim=True)
            # CLIP-style scaled cosine similarities (the scale does not change the argmax).
            scores = 100.0 * feats @ text_features.T
            # Labels stay on the CPU, so bring the predictions back for comparison.
            predicted = scores.argmax(dim=-1).cpu()
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    accuracy = n_correct / n_seen
    print(f"Zero-shot accuracy: {100*accuracy:.2f}%")
    return accuracy

# Zero-shot baseline on the CIFAR-10 test loader; prints the accuracy and returns it as a fraction.
test_model(model, text_features, loader)


In [None]:

# --- Few-shot split (16-shot/class) -----------------------------------------
from torchvision import transforms
from torch.utils.data import Subset
import random, math, torch.nn as nn, torch.optim as optim

# CIFAR-10 train set for few-shot
# The few-shot subset is drawn from the train split so the test split stays held out.
train_tf = preprocess  # you can make it stronger later; using same as test keeps things simple
cifar_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)

def build_fewshot_indices(dataset, k_per_class=16, num_classes=None):
    """Pick the first ``k_per_class`` sample indices of every class.

    Labels are read from ``dataset.targets`` when that attribute exists, so no
    sample is loaded or transformed just to learn its label. Otherwise the
    dataset is iterated as ``(sample, label)`` pairs.

    Args:
        dataset: object with a ``targets`` sequence, or an iterable of
            ``(sample, label)`` pairs. Labels must be integers (or 0-d
            tensors) in ``range(num_classes)``.
        k_per_class: maximum number of indices to keep per class.
        num_classes: number of classes. Defaults to ``len(dataset.classes)``
            when available, else 10.

    Returns:
        A flat list of indices grouped by class (class 0 first), each group in
        increasing index order. Classes with fewer than ``k_per_class``
        samples contribute all the samples they have.

    Raises:
        KeyError: if a label falls outside ``range(num_classes)``.
    """
    if k_per_class <= 0:
        return []

    if num_classes is None:
        class_list = getattr(dataset, "classes", None)
        num_classes = len(class_list) if class_list is not None else 10

    targets = getattr(dataset, "targets", None)
    labels = targets if targets is not None else (y for _, y in dataset)

    idxs_by_c = {c: [] for c in range(num_classes)}
    full_classes = 0  # running count avoids re-scanning every bucket per sample
    for i, y in enumerate(labels):
        bucket = idxs_by_c[int(y)]
        if len(bucket) >= k_per_class:
            continue
        bucket.append(i)
        if len(bucket) == k_per_class:
            full_classes += 1
            if full_classes == num_classes:
                break

    return [i for c in range(num_classes) for i in idxs_by_c[c]]

# Indices of the few-shot training samples (up to 16 per class).
few_idx = build_fewshot_indices(cifar_train, k_per_class=16)
# Shuffled loader over just those samples.
fewshot_loader = DataLoader(Subset(cifar_train, few_idx), batch_size=128, shuffle=True, num_workers=2)

# --- CoOp: learnable prompt context -----------------------------------------
@torch.no_grad()
def tokenize_classnames(classnames, ctx_prefix="", ctx_suffix=""):
    """Tokenize each class name wrapped in an optional prefix and suffix.

    Args:
        classnames: iterable of class names.
        ctx_prefix: text placed directly before every class name.
        ctx_suffix: text placed directly after every class name.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the list of
        ``prefix + name + suffix`` strings.
    """
    wrapped = []
    for cls_name in classnames:
        wrapped.append(f"{ctx_prefix}{cls_name}{ctx_suffix}")
    return tokenizer(wrapped)

class SimpleCoOp(nn.Module):
    """
    Learn n_ctx context tokens as free embeddings; keep CLIP frozen.

    Each class prompt is laid out as
    ``[SOS] [CTX x n_ctx] [classname tokens] [EOT] [PAD ...]``.
    The prompt is tokenized with ``n_ctx`` placeholder words in front of the
    class name, and only the placeholder embeddings are swapped for the learned
    context, so the class name and EOT tokens stay in the sequence.

    Args:
        clip_model: CLIP model exposing ``token_embedding``,
            ``positional_embedding``, ``transformer``, ``ln_final``,
            ``text_projection`` and ``attn_mask``. All of its parameters are
            frozen in place (``requires_grad = False``).
        classnames: sequence of class-name strings.
        n_ctx: number of learnable context vectors.

    Raises:
        ValueError: if the tokenizer does not produce exactly one token per
            placeholder word, or the prompt does not fit the context window.
    """
    def __init__(self, clip_model, classnames, n_ctx=8):
        super().__init__()
        self.clip = clip_model
        for p in self.clip.parameters():
            p.requires_grad = False

        self.classnames = classnames
        self.dtype = next(self.clip.parameters()).dtype
        self.ctx_len = n_ctx

        # CLIP text parts we need
        self.token_embedding = self.clip.token_embedding
        self.positional_embedding = self.clip.positional_embedding
        self.transformer = self.clip.transformer
        self.ln_final = self.clip.ln_final
        self.text_projection = self.clip.text_projection
        self.register_buffer("attn_mask", self.clip.attn_mask, persistent=False)

        # init context vectors (learnable)
        ctx_dim = self.token_embedding.weight.shape[1]
        self.ctx = nn.Parameter(torch.randn(self.ctx_len, ctx_dim) * 0.02)

        # Tokenize "X X ... X <classname>": the placeholders reserve the slots
        # the learned context will occupy, so the class name and EOT land
        # *after* them instead of being overwritten.
        tokens = tokenize_classnames(classnames, ctx_prefix="X " * self.ctx_len)
        placeholder_slots = tokens[:, 1:1 + self.ctx_len]
        if self.ctx_len > 0 and (
            placeholder_slots.size(1) != self.ctx_len
            or not bool((placeholder_slots == placeholder_slots[:1, :1]).all())
        ):
            raise ValueError(
                f"could not reserve {self.ctx_len} single-token context slots "
                "in the tokenized prompts"
            )

        # Buffers follow the module across devices; non-persistent keeps the
        # state_dict limited to learned/CLIP weights.
        self.register_buffer("classname_tokens", tokens, persistent=False)
        # In the CLIP vocabulary the end-of-text token has the largest id, so
        # argmax over token ids locates it (same convention CLIP's own text
        # encoder uses).
        self.register_buffer("eot_positions", tokens.argmax(dim=-1), persistent=False)
        self.eot_token = int(tokens.max())

    def train(self, mode=True):
        """Switch the prompt learner's mode while keeping the frozen CLIP in eval mode."""
        super().train(mode)
        # nn.Module.train() recurses into submodules; undo that for the frozen backbone.
        self.clip.eval()
        return self

    def _run_transformer(self, x):
        """Run the CLIP text transformer on a batch-first (N, L, D) tensor."""
        if hasattr(self.transformer, "batch_first"):
            # open_clip releases whose Transformer carries `batch_first` handle
            # the layout inside forward() and take (N, L, D) directly.
            return self.transformer(x, attn_mask=self.attn_mask)
        # Older releases expect sequence-first input.
        x = x.permute(1, 0, 2)                                                 # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        return x.permute(1, 0, 2)                                              # LND -> NLD

    def forward(self, device):
        """Return L2-normalised text features, one row per class.

        Args:
            device: device on which to assemble and encode the prompts.

        Returns:
            Tensor of shape ``(len(classnames), embed_dim)``; gradients flow
            only into ``self.ctx``.
        """
        num_classes = len(self.classnames)
        tokens = self.classname_tokens.to(device)                              # (B, seqlen)

        ctx = self.ctx.to(device, dtype=self.dtype)                            # (ctx_len, dim)
        ctx = ctx.unsqueeze(0).expand(num_classes, -1, -1)                     # (B, ctx_len, dim)

        with torch.no_grad():
            prompt_embs = self.token_embedding(tokens).to(dtype=self.dtype)    # (B, seqlen, dim)

        # Swap the placeholder embeddings for the learned context.
        sos = prompt_embs[:, :1, :]
        rest = prompt_embs[:, 1 + self.ctx_len:, :]                            # classname, EOT, padding

        x = torch.cat([sos, ctx, rest], dim=1)                                 # (B, T, dim)
        x = x + self.positional_embedding[:x.size(1)].to(x)                    # add pos
        x = self._run_transformer(x)
        x = self.ln_final(x)

        # take features at the EOT position of each sequence
        eot_positions = self.eot_positions.to(device)                          # (B,)
        feat = x[torch.arange(num_classes, device=device), eot_positions]
        if isinstance(self.text_projection, nn.Module):
            feat = self.text_projection(feat)
        elif self.text_projection is not None:
            feat = feat @ self.text_projection

        # normalize
        feat = feat / feat.norm(dim=-1, keepdim=True)
        return feat

# Instantiate CoOp head (learns only the context vectors)
# Constructing SimpleCoOp sets requires_grad=False on every parameter of `model`.
coop = SimpleCoOp(model, classnames, n_ctx=8).to(device)
# Only coop.ctx is handed to the optimizer, so nothing else is updated.
optim_coop = optim.AdamW([coop.ctx], lr=5e-3, weight_decay=0.0)
epochs = 5

# Precompute the model's logit scale (frozen)
# The stored parameter is exponentiated once here and reused as a constant
# multiplier for the logits in both training and evaluation.
logit_scale = model.logit_scale.exp().to(device)

@torch.no_grad()
def eval_with_coop(model, coop_head, loader):
    """Evaluate top-1 accuracy using class text features from a CoOp head.

    Uses the module-level ``device`` and ``logit_scale``. Puts ``coop_head``
    into eval mode and leaves it there.

    Args:
        model: object providing ``encode_image(images)``.
        coop_head: callable as ``coop_head(device)`` returning one normalised
            text feature row per class.
        loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        Fraction of samples classified correctly (mirrors ``test_model``).
    """
    coop_head.eval()
    correct = total = 0
    # Class text features are computed once; they do not depend on the images.
    class_text_features = coop_head(device)  # (num_classes, d)
    for images, labels in loader:
        images = images.to(device)
        img_f = model.encode_image(images)
        img_f = img_f / img_f.norm(dim=-1, keepdim=True)
        logits = (logit_scale * img_f @ class_text_features.T)
        # Labels stay on the CPU, so bring predictions back for the comparison.
        preds = logits.argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    print(f"CoOp few-shot eval: {100*correct/total:.2f}%")
    return correct / total

# Train CoOp with cross-entropy; image features come from the frozen CLIP
# encoder (computed under no_grad), so only the context vectors get gradients.
ce = nn.CrossEntropyLoss()
for ep in range(1, epochs+1):
    coop.train()
    # coop.train() recurses into its submodules, which include the shared CLIP
    # model; put the backbone back in eval mode so feature extraction stays
    # deterministic for architectures with dropout / batch-norm.
    model.eval()
    total_loss = 0.0
    for imgs, labels in fewshot_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.no_grad():
            img_f = model.encode_image(imgs)
            img_f = img_f / img_f.norm(dim=-1, keepdim=True)
        txt_f = coop(device)                          # (C, d), recompute every step (cheap)
        logits = (logit_scale * img_f @ txt_f.T)      # (B, C)
        loss = ce(logits, labels)

        optim_coop.zero_grad(set_to_none=True)
        loss.backward()
        optim_coop.step()
        # Weight by batch size so the epoch figure is a per-sample mean.
        total_loss += loss.item() * labels.size(0)

    print(f"[CoOp] epoch {ep} | loss {total_loss/len(few_idx):.4f}")
    # Evaluate on the held-out test loader after every epoch.
    eval_with_coop(model, coop, loader)

# After training, you can get improved text_features for future evals:
with torch.no_grad():
    text_features_coop = coop(device)  # use this instead of the original text_features
