# MobileCLIP2 + CoOp (Context Optimization)

This notebook starts from **MobileCLIP2** via `open_clip_torch` and adds a simple **Context Optimization (CoOp)** prompt learner, following the project's **STYLEGUIDE**.

**Notebook flow** (from STYLEGUIDE):
1. `!pip install` (only if needed)
2. Common imports + seeds
3. Load the model / pretrained model (+ preprocessors / related utilities)
4. Create own model module with clear separation of base vs add-ons
5. Hyperparameters section
6. Optimizer + Schedule
7. Train
8. Plot History
9. Evaluate + Show final result
10. Extra analysis and visualisation

> **References**: MobileCLIP2 in OpenCLIP (`MobileCLIP2-S{0,2,3,4}, B, L-14` with `pretrained='dfndr2b'`) and CoOp (Zhou et al.).


In [14]:
# 1) (Optional) Installs — uncomment on first run
# !pip install -U matplotlib tqdm timm torchvision torch --quiet
!pip install -U scikit-learn open-clip-torch
print('If you need packages, uncomment the pip lines above and run this cell.')

If you need packages, uncomment the pip lines above and run this cell.


In [15]:
# 2) Common imports + seeds
import math, random, os, time
from dataclasses import dataclass
from typing import Tuple, List

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

import torchvision
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import open_clip

def set_seed(seed:int=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(123)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Runnin on {device}")

try:
    import multiprocessing as mp
    if mp.get_start_method(allow_none=True) != "spawn":
        mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass  # already set


Runnin on cuda


In [16]:
from dataclasses import dataclass
from typing import Tuple

# Pick a lightweight variant for quick experiments
# MODEL_NAME = 'MobileCLIP2-S0'   # alternatives: 'MobileCLIP2-S2', 'MobileCLIP2-B', 'MobileCLIP2-S3', 'MobileCLIP2-S4', 'MobileCLIP2-L-14'
# PRETRAINED = 'dfndr2b'          # per MobileCLIP2 release in OpenCLIP


@dataclass
class Config:
    model_name: str = "MobileCLIP2-S0"
    pretrained: str = "dfndr2b"
    image_size: int = 224
    batch_size: int = 64
    num_workers: int = 4
    max_epochs: int = 5

    # LR split: base vs prompt
    lr_base: float = 1e-5
    lr_prompt: float = 1e-3
    weight_decay: float = 0.05

    # Unfreeze (if you want a light finetune on top of prompt)
    unfreeze_layers: Tuple[str, ...] = tuple()  # e.g. ("visual.transformer.resblocks.11",)

    # CoOp prompt length
    prompt_len: int = 4

    # Overfit sanity tiny subset
    overfit_n_classes: int = 2
    overfit_k_per_class: int = 8
    overfit_epochs: int = 50

cfg = Config()
cfg


Config(model_name='MobileCLIP2-S0', pretrained='dfndr2b', image_size=224, batch_size=64, num_workers=4, max_epochs=5, lr_base=1e-05, lr_prompt=0.001, weight_decay=0.05, unfreeze_layers=(), prompt_len=4, overfit_n_classes=2, overfit_k_per_class=8, overfit_epochs=50)

In [17]:
# open_clip provides: model, preprocess
model, _, preprocess = open_clip.create_model_and_transforms(
    cfg.model_name, pretrained=cfg.pretrained, device=device
)
tokenizer = open_clip.get_tokenizer(cfg.model_name)

In [18]:
tfm_train = transforms.Compose([
    transforms.Resize((cfg.image_size, cfg.image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])
tfm_eval = transforms.Compose([
    transforms.Resize((cfg.image_size, cfg.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

train_ds = datasets.CIFAR10(root="data", train=True,  transform=tfm_train, download=True)
val_ds   = datasets.CIFAR10(root="data", train=False, transform=tfm_eval,  download=True)
classnames = train_ds.classes  # list[str]
num_classes = len(classnames)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=True)




## 4) Model module: base (frozen MobileCLIP2) + CoOp prompt learner

- We keep the CLIP backbone **frozen** and in **eval** mode (BatchNorm stability on MobileCLIP2).
- We learn **M** context vectors `ctx` ("soft words") that prepend the class name tokens.
- Shapes are annotated per STYLEGUIDE.

In [19]:
class PromptLearner(nn.Module):
    def __init__(self, text_encoder, tokenizer, classnames, n_ctx=16, prefix="a photo of", suffix=""):
        super().__init__()
        self.classnames = classnames
        self.n_cls = len(classnames)
        self.n_ctx = n_ctx
        self.dtype = text_encoder.text_projection.dtype
        self.token_embedding = text_encoder.token_embedding          # (vocab_size, width)
        self.positional_embedding = text_encoder.positional_embedding
        self.transformer = text_encoder.transformer
        self.ln_final = text_encoder.ln_final
        self.text_projection = text_encoder.text_projection
        self.attn_mask = text_encoder.attn_mask                      # may be None for some cfgs
        self.vocab_size = self.token_embedding.num_embeddings
        self.width = self.token_embedding.embedding_dim
        self.register_parameter("ctx", nn.Parameter(torch.randn(n_ctx, self.width, dtype=self.dtype) * 0.02))
        self.prefix = prefix
        self.suffix = suffix

        self.tokenizer = tokenizer

        # Pre-tokenize classnames once
        # Template is: [SOS] prefix, [CTX x n_ctx], class, suffix, [EOS], [PAD...]
        prompts = [f"{prefix} {name}{(' ' + suffix) if suffix else ''}" for name in classnames]
        self.prompts_tokenized = self.tokenizer(prompts)  # (n_cls, context_length)
        with torch.no_grad():
            # We'll also tokenize the classnames alone to find where they start
            self.cls_tokenized = self.tokenizer(classnames)  # helps locate class tokens if needed

        # Get CLIP context length
        self.context_length = model.context_length

    def forward(self):
        # Build per-class token embeddings with learned context
        prompts = self.prompts_tokenized.to(device)  # (n_cls, context_length)
        x = self.token_embedding(prompts).to(self.dtype)  # (n_cls, context_length, width)

        # Find the location to insert ctx: right after [SOS] and (prefix) tokens.
        # Heuristic: put ctx immediately after the first token (start-of-text, index 0) + the prefix length.
        # Compute prefix length from tokenized prefix (minus special tokens).
        prefix_tokens = self.tokenizer(self.prefix)
        # count non-zero tokens except the starting 0 and trailing 0s
        pref_len = int((torch.tensor(prefix_tokens) != 0).sum().item()) - 2  # rough; adjusts ok in practice

        # Insert learned context at positions [1 ... n_ctx] after prefix
        # NOTE: we’re replacing those slots, not increasing sequence length.
        # Ensure we don't overflow context_length.
        start = 1 + max(pref_len, 0)
        end = min(start + self.n_ctx, self.context_length - 2)  # keep room for [EOS]
        ctx = self.ctx[: end - start]                           # (use as many as fit)
        x[:, start:end, :] = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        # Standard CLIP text forward pass on our *embedded* tokens
        x = x + self.positional_embedding.to(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).to(self.dtype)

        # Take features at the EOS token (the position of highest token id == tokenizer.eot_token)
        # In OpenCLIP, EOT token id is usually tokenizer.eot_token
        eot_token = self.tokenizer.eot_token
        eot_inds = (self.prompts_tokenized == eot_token).int().argmax(dim=1)
        text_embeds = x[torch.arange(x.shape[0]), eot_inds] @ self.text_projection

        # L2 normalize for cosine similarities
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        return text_embeds  # (n_cls, d)


In [25]:
class PromptedMobileCLIP(nn.Module):
    def __init__(self, base_model, tokenizer, classnames, n_ctx=4):
        super().__init__()
        self.base = base_model
        self.text_encoder = base_model.text
        self.image_encoder = base_model.visual

        # --- freeze everything in the base, including logit_scale ---
        for p in self.base.parameters():
            p.requires_grad_(False)
        if hasattr(self.base, "logit_scale"):
            self.base.logit_scale.requires_grad_(False)

        self.prompt_learner = PromptLearner(self.text_encoder, tokenizer, classnames, n_ctx)

    def forward(self, images):
        image_features = self.base.encode_image(images)
        image_features = F.normalize(image_features, dim=-1)
        text_features  = self.prompt_learner()
        logit_scale = self.base.logit_scale.exp()
        return logit_scale * image_features @ text_features.t()


In [26]:
#model_coop = PromptedMobileCLIP(model, tokenizer, classnames, n_ctx=cfg.prompt_len).to(device)

# (Optional) lightly unfreeze some base layers
#for name, p in model_coop.base.named_parameters():
#    if any(un in name for un in cfg.unfreeze_layers):
#        p.requires_grad = True

# Optimizer: prompt (and any unfrozen base params) at different LRs
#prompt_params = [p for p in model_coop.prompt_learner.parameters() if p.requires_grad]
#base_params   = [p for n,p in model_coop.base.named_parameters() if p.requires_grad]

#param_groups = []
#if base_params:
#    param_groups.append({"params": base_params, "lr": cfg.lr_base, "weight_decay": cfg.weight_decay})
#param_groups.append({"params": prompt_params, "lr": cfg.lr_prompt, "weight_decay": cfg.weight_decay})

model_coop = PromptedMobileCLIP(model, tokenizer, classnames, n_ctx=cfg.prompt_len).to(device)

opt = torch.optim.AdamW(
    [{"params": model_coop.prompt_learner.parameters(), "lr": cfg.lr_prompt, "weight_decay": cfg.weight_decay}]
)
print("Trainable params:", sum(p.numel() for p in model_coop.parameters() if p.requires_grad))


Trainable params: 2048


In [27]:
def evaluate(model, loader, device, loss_fn, desc="eval", pbar=True):
    model.eval()
    loss_sum, correct, count = 0.0, 0, 0
    iterator = tqdm(loader, desc=desc, leave=False) if pbar else loader

    with torch.no_grad():
        for data, target in iterator:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(data)
            loss = loss_fn(output, target)
            bs = data.size(0)
            loss_sum += loss.item() * bs
            preds = output.argmax(dim=1)
            correct += (preds == target).sum().item()
            count += bs
            if pbar:
                acc = (correct / max(1, count)) * 100.0
                iterator.set_postfix(loss=f"{loss.item():.4f}", acc=f"{acc:.2f}%")

    return loss_sum / max(1, count), correct / max(1, count)

def train(model, train_loader, val_loader, optimizer, device, epochs: int,
          log_interval: int = 25, ema_alpha: float = 0.1):
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        model.train()
        train_sum, train_correct, train_count = 0.0, 0, 0
        ema = None
        iterator = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} | train", leave=False, dynamic_ncols=True, mininterval=0.5)
        for step, (data, target) in enumerate(iterator, start=1):
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()

            # --- DEBUG: grad norms ---
            with torch.no_grad():
                gnorm_ctx = 0.0
                for p in model.prompt_learner.parameters():
                    if p.grad is not None:
                        g = p.grad.detach()
                        gnorm_ctx += (g*g).sum().item()
                gnorm_ctx = math.sqrt(gnorm_ctx) if gnorm_ctx > 0 else 0.0
                if step % log_interval == 0:
                    tqdm.write(f"[dbg] ctx grad L2 = {gnorm_ctx:.6e}")

            optimizer.step()

            # stats
            bs = data.size(0)
            train_sum += (loss.item() * bs)
            preds = output.argmax(dim=1)
            train_correct += (preds == target).sum().item()
            train_count += bs
            ema = loss.item() if ema is None else (1 - ema_alpha) * ema + ema_alpha * loss.item()
            if (step % log_interval == 0) or (step == len(iterator)):
                acc_pct = 100.0 * train_correct / max(1, train_count)
                iterator.set_postfix_str(f"loss(ema)={ema:.4f} acc={acc_pct:.2f}%")

        avg_train_loss = train_sum / max(1, train_count)
        avg_train_acc  = train_correct / max(1, train_count)
        avg_val_loss, avg_val_acc = evaluate(model, val_loader, device, loss_fn=loss_fn, desc="valid", pbar=True)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(avg_train_acc)
        history['val_acc'].append(avg_val_acc)

        tqdm.write(f"Epoch {epoch:03d}: train_loss={avg_train_loss:.4f} train_acc={avg_train_acc*100:.2f}%  "
                   f"val_loss={avg_val_loss:.4f} val_acc={avg_val_acc*100:.2f}%")

    return history


In [28]:
def plot_history(history):
    epochs = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4), constrained_layout=True)
    ax = axes[0]
    ax.plot(epochs, history['train_loss'], label='train')
    ax.plot(epochs, history['val_loss'], label='val')
    ax.set_xlabel('Epoch') ; ax.set_ylabel('Loss') ; ax.set_title('Loss') ; ax.legend()
    ax = axes[1]
    if 'train_acc' in history and 'val_acc' in history:
        ax.plot(epochs, [x * 100 for x in history['train_acc']], label='train')
        ax.plot(epochs, [x * 100 for x in history['val_acc']], label='val')
        ax.set_ylabel('Accuracy (%)')
    else:
        ax.text(0.5, 0.5, 'No accuracy in history', ha='center', va='center', transform=ax.transAxes)
    ax.set_xlabel('Epoch') ; ax.set_title('Accuracy') ; ax.legend()
    plt.show()

In [29]:
def make_overfit_subset(dataset, n_classes=2, k_per_class=8):
    targets = np.array(dataset.targets)  # CIFAR10
    selected_indices = []
    classes = list(range(len(set(targets))))[:n_classes]
    for c in classes:
        idx = np.where(targets == c)[0].tolist()
        random.shuffle(idx)
        selected_indices.extend(idx[:k_per_class])
    random.shuffle(selected_indices)
    return Subset(dataset, selected_indices)

tiny_train = make_overfit_subset(train_ds, cfg.overfit_n_classes, cfg.overfit_k_per_class)
tiny_val   = make_overfit_subset(val_ds,   cfg.overfit_n_classes, cfg.overfit_k_per_class)

tiny_train_loader = DataLoader(tiny_train, batch_size=min(16, cfg.batch_size), shuffle=True,
                               num_workers=1, pin_memory=False, persistent_workers=True)
tiny_val_loader   = DataLoader(tiny_val, batch_size=min(32, cfg.batch_size), shuffle=False,
                               num_workers=1, pin_memory=False, persistent_workers=True)

# quick check: param counts
total = sum(p.numel() for p in model_coop.parameters() if p.requires_grad)
print("Trainable params:", total)

# Train a bit more for overfit
history_tiny = train(model_coop, tiny_train_loader, tiny_val_loader, opt, device, epochs=cfg.overfit_epochs)

plot_history(history_tiny)


Trainable params: 2048


Epoch 1/50 | train:   0%|          | 0/1 [00:04<?, ?it/s]

  pref_len = int((torch.tensor(prefix_tokens) != 0).sum().item()) - 2  # rough; adjusts ok in practice


AttributeError: 'SimpleTokenizer' object has no attribute 'eot_token'

In [None]:
images, y = next(iter(tiny_train_loader))  # or a full loader; any batch is fine
images = images.to(device)
model_coop.zero_grad(set_to_none=True)
logits = model_coop(images)
(logits.sum()).backward()
print("ctx grad L2:", model_coop.prompt_learner.ctx.grad.norm().item())


In [None]:
# Rebuild optimizer if you want to reset after overfit
opt_full = torch.optim.AdamW(
    [{"params": [p for n,p in model_coop.base.named_parameters() if p.requires_grad], "lr": cfg.lr_base,   "weight_decay": cfg.weight_decay},
     {"params": [p for p in model_coop.prompt_learner.parameters()],                 "lr": cfg.lr_prompt, "weight_decay": cfg.weight_decay}]
)
history = train(model_coop, train_loader, val_loader, opt, device, epochs=cfg.max_epochs)
plot_history(history)

In [None]:
# 10) Extra analysis: (optional) confusion matrix if sklearn is available
try:
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
    y_true, y_pred = [], []
    with torch.no_grad():
        for x, y in tqdm(test_loader, leave=False, desc='cm'):
            p = model_coop(x.to(device)).argmax(dim=1).cpu()
            y_true.extend(y.cpu().tolist())
            y_pred.extend(p.tolist())
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(cm, display_labels=classnames)
    fig, ax = plt.subplots(figsize=(8,8))
    disp.plot(ax=ax, include_values=False, xticks_rotation=90, cmap='Blues')
    plt.title('Confusion Matrix (test)')
    plt.tight_layout(); plt.show()
except Exception as e:
    print('Install scikit-learn to see confusion matrix. Skipping. Error:', e)

### Notes
- The backbone is **frozen**; only the prompt parameters `ctx` (shape `[M, D]`) are trained.
- For MobileCLIP(2), keeping the model in **eval()** avoids BN updates.
- You can change `cfg.prompt_len` to explore different context lengths (e.g., 4, 8, 16).
- To compare with classic prompts, see `zeroshot_classifier()` helper.