In [None]:
!pip install -U torch torchvision timm open_clip_torch

In [2]:
import torch, open_clip
from torchvision import datasets
from torch.utils.data import DataLoader

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# === Choose a model from recent CLIP-family work ===
# Classic strong baseline (architecture name + OpenCLIP pretrained-weights tag):
MODEL_NAME = "ViT-B-32"
PRETRAINED = "laion2b_s34b_b79k"

# The factory returns three values; the middle one is not used in this notebook.
model, _, preprocess = open_clip.create_model_and_transforms(
    MODEL_NAME,
    pretrained=PRETRAINED,
)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

# Move the weights to the chosen device, then switch to inference mode.
model = model.to(device)
model.eval()


In [None]:
# ===== Few-shot Linear Probe on CLIP image features =====
from torchvision import datasets
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import torch.nn.functional as F
import random
import torch

# -- helper: get CLIP embed dim (shared by image/text) --
def get_embed_dim(clip_model):
    """Return the size of the embedding produced by ``clip_model``.

    The dimension is read from ``clip_model.text_projection`` when that
    attribute exists and is not ``None``. Otherwise one dummy image is pushed
    through ``clip_model.encode_image`` and the size of the last output axis
    is returned.

    Args:
        clip_model: Object exposing ``encode_image`` and, optionally, a
            ``text_projection`` that is either tensor-like (dimension taken
            from ``shape[1]``) or linear-layer-like (dimension taken from
            ``out_features``).

    Returns:
        The embedding dimension.
    """
    projection = getattr(clip_model, "text_projection", None)
    if projection is not None:
        # A linear layer has no ``.shape``; read its output width instead.
        if hasattr(projection, "out_features"):
            return projection.out_features
        if hasattr(projection, "shape"):
            return projection.shape[1]

    # Fallback via a tiny forward pass. Device and dtype are taken from the
    # model itself so the helper does not depend on a notebook-level global.
    reference = None
    if hasattr(clip_model, "parameters"):
        reference = next(iter(clip_model.parameters()), None)
    if reference is None:
        probe_device, probe_dtype = torch.device("cpu"), torch.float32
    else:
        probe_device = reference.device
        probe_dtype = reference.dtype if reference.is_floating_point() else torch.float32

    # Use the image tower's declared input size when it advertises one
    # (assumed to be an int or an (H, W) pair); default to 224 otherwise.
    size = getattr(getattr(clip_model, "visual", None), "image_size", None) or 224
    if isinstance(size, int):
        height = width = size
    else:
        height, width = size[0], size[1]

    with torch.no_grad():
        # Zeros instead of random values: only the output shape matters, and
        # this leaves torch's random number generator untouched.
        dummy = torch.zeros(1, 3, height, width, device=probe_device, dtype=probe_dtype)
        f = clip_model.encode_image(dummy)
    return f.shape[-1]

# -- build few-shot subset from CIFAR-10 train split --
def make_fewshot_subset(dataset, shots_per_class=1, seed=0):
    """Select the first ``shots_per_class`` samples of every class, in dataset order.

    Args:
        dataset: Dataset exposing ``classes`` and either a ``targets``
            sequence or iteration over ``(sample, label)`` pairs. Labels must
            be integer-like class indices in ``range(len(dataset.classes))``;
            plain ints and 0-dim tensors are both accepted.
        shots_per_class: Maximum number of samples kept per class. A class
            with fewer samples contributes all of them.
        seed: Kept for interface compatibility. The selection is deterministic
            (first samples encountered), so the value has no effect and the
            global ``random`` state is left untouched.

    Returns:
        A ``Subset`` of ``dataset`` whose indices are grouped by class index
        and, within a class, ordered by position in the dataset.
    """
    num_classes = len(dataset.classes)
    by_class = {c: [] for c in range(num_classes)}

    if shots_per_class > 0:
        # TorchVision CIFAR10 stores targets as a list attribute; fall back to
        # iterating the dataset when no such attribute exists.
        targets = dataset.targets if hasattr(dataset, "targets") else [y for _, y in dataset]
        unfilled = num_classes
        for idx, y in enumerate(targets):
            # int() lets tensor-valued labels hit the integer dict keys.
            bucket = by_class[int(y)]
            if len(bucket) < shots_per_class:
                bucket.append(idx)
                if len(bucket) == shots_per_class:
                    unfilled -= 1
                    if unfilled == 0:
                        break  # every class is full; stop scanning early
        # No padding step is needed: a class that is still short here has
        # already contributed every sample it has.

    indices = [i for c in range(num_classes) for i in by_class[c]]
    return Subset(dataset, indices)

# -- Linear Probe module (frozen image encoder + linear head) --
class LinearProbe(nn.Module):
    """Trainable linear classifier on top of a frozen CLIP image encoder.

    Args:
        clip_model: Model exposing ``encode_image``. Its parameters are frozen
            in place (``requires_grad = False``) and it is kept in eval mode.
        num_classes: Number of output logits.
        normalize_features: When true, image features are divided by their
            norm along the last axis before the linear layer.
    """

    def __init__(self, clip_model, num_classes, normalize_features=True):
        super().__init__()
        self.clip = clip_model

        self.normalize = normalize_features

        embed_dim = get_embed_dim(clip_model)

        self.classifier = nn.Linear(embed_dim, num_classes, bias=True)

        # Init to small values (optional but nice)
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.constant_(self.classifier.bias, 0.)

        # Freeze all clip params
        for p in self.clip.parameters():
            p.requires_grad = False
        # A frozen feature extractor must also run in inference mode.
        self.clip.eval()

    def train(self, mode=True):
        """Set the training mode of the head only; the encoder stays in eval mode.

        ``self.clip`` is a registered submodule, so the default ``train()``
        would switch it to training mode as well (dropout active, BatchNorm
        statistics updating) even though it is meant to be frozen.
        """
        super().train(mode)
        self.clip.eval()
        return self

    def forward(self, images):
        """Return class logits for a batch of preprocessed images."""
        with torch.no_grad():
            feats = self.clip.encode_image(images)
        if self.normalize:
            feats = feats / feats.norm(dim=-1, keepdim=True)
        # Match the head's dtype; a no-op when the encoder already emits it.
        feats = feats.to(self.classifier.weight.dtype)
        logits = self.classifier(feats)
        return logits

# -- accuracy helper --
@torch.no_grad()
def eval_accuracy(model_or_fn, loader, device):
    """Count top-1 hits of ``model_or_fn`` over every batch in ``loader``.

    Args:
        model_or_fn: Callable mapping a batch of images to logits.
        loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device both tensors are moved to before the call.

    Returns:
        ``(correct, total)``: number of correct argmax predictions and number
        of labels seen.
    """
    n_correct, n_seen = 0, 0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)
        # Predicted class = index of the largest logit along the last axis.
        predicted = torch.argmax(model_or_fn(batch_images), dim=-1)
        n_correct += int(torch.eq(predicted, batch_labels).sum())
        n_seen += batch_labels.shape[0]
    return n_correct, n_seen

# ===== Data: CIFAR-10 train (few-shot) + test =====
# Both splits go through the same CLIP `preprocess` transform.
train_cifar = datasets.CIFAR10(root="./data", train=True, download=True, transform=preprocess)
test_cifar  = datasets.CIFAR10(root="./data", train=False, download=True, transform=preprocess)

shots_per_class = 4   # <<< change this to your N
fewshot_subset = make_fewshot_subset(train_cifar, shots_per_class=shots_per_class, seed=0)

# Pinned host memory only benefits host-to-GPU copies, so request it only when
# the notebook is actually running on CUDA (the device cell allows CPU runs).
use_pinned_memory = str(device).startswith("cuda")

train_loader = DataLoader(fewshot_subset, batch_size=128, shuffle=True,  num_workers=2, pin_memory=use_pinned_memory)
test_loader  = DataLoader(test_cifar,    batch_size=256, shuffle=False, num_workers=2, pin_memory=use_pinned_memory)

num_classes = len(train_cifar.classes)

# ===== Zero-shot (reuse your prompts) =====
# One caption per CIFAR-10 class name, in class-index order.
prompts = [f"a photo of a {class_name}" for class_name in train_cifar.classes]

# Encode the captions once; gradients are not needed for the text tower.
with torch.no_grad():
    text_tokens = tokenizer(prompts)
    text_tokens = text_tokens.to(device)
    text_features = model.encode_text(text_tokens)
    # Scale every row to unit norm.
    text_features = text_features.div(text_features.norm(dim=-1, keepdim=True))

@torch.no_grad()
def zeroshot_logits(images):
    """Score a batch of images against the precomputed class text features.

    Reads the notebook-level ``model`` and ``text_features``. Image features
    are scaled to unit norm, multiplied by 100, and matrix-multiplied with the
    transposed text features.
    """
    encoded = model.encode_image(images)
    unit_features = encoded / encoded.norm(dim=-1, keepdim=True)
    return torch.matmul(100.0 * unit_features, text_features.T)

# Zero-shot baseline: score the text-prompt classifier on the CIFAR-10 test split.
zs_correct, zs_total = eval_accuracy(
    zeroshot_logits,
    loader=test_loader,
    device=device,
)
zs_acc = (100.0 * zs_correct) / zs_total
print(f"[Zero-shot] CIFAR-10 accuracy: {zs_acc:.2f}% with {MODEL_NAME} ({PRETRAINED})")

# ===== Linear Probe training =====
probe = LinearProbe(model, num_classes=num_classes, normalize_features=True).to(device)

epochs = 10
# Only the linear head is optimised; the CLIP encoder is frozen.
optimizer = torch.optim.AdamW(probe.classifier.parameters(), lr=1e-3, weight_decay=1e-4)
# Cosine schedule over a small number of steps is fine; or keep fixed LR.
# steps_per_epoch = max(1, len(train_loader))
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs*steps_per_epoch)

lp_acc = None  # set by the per-epoch evaluation below
for epoch in range(1, epochs + 1):
    # probe.train() recurses into the registered CLIP submodule, so put the
    # frozen encoder straight back into eval mode: only the head trains.
    probe.train()
    probe.clip.eval()

    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = probe(images)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # scheduler.step()  # if using the cosine schedule

        running_loss += loss.item()

    # quick val on test each epoch
    # NOTE: this runs the frozen encoder over the whole test set every epoch.
    probe.eval()
    lp_correct, lp_total = eval_accuracy(probe, test_loader, device)
    lp_acc = 100.0 * lp_correct / lp_total
    print(f"[LinearProbe] epoch {epoch:02d} | train_loss {running_loss/len(train_loader):.4f} | test_acc {lp_acc:.2f}%")

# Final comparison
probe.eval()
if lp_acc is None:
    # No epoch ran (epochs == 0): score the untrained head once.
    lp_correct, lp_total = eval_accuracy(probe, test_loader, device)
    lp_acc = 100.0 * lp_correct / lp_total
# Otherwise the last epoch's evaluation already reflects the final weights,
# so a second full pass over the test set is unnecessary.
print(f"\n=== CIFAR-10 results ({shots_per_class}-shot per class) ===")
print(f"Zero-shot accuracy : {zs_acc:.2f}%")
print(f"Linear-probe acc   : {lp_acc:.2f}%  (encoder frozen, linear head)")




KeyboardInterrupt: 