# MobileCLIP2 + CoOp (Context Optimization)

This notebook starts from **MobileCLIP2** via `open_clip_torch` and adds a simple **Context Optimization (CoOp)** prompt learner, following the project's **STYLEGUIDE**.

**Notebook flow** (from STYLEGUIDE):
1. `!pip install` (only if needed)
2. Common imports + seeds
3. Load the model / pretrained model (+ preprocessors / related utilities)
4. Create own model module with clear separation of base vs add-ons
5. Hyperparameters section
6. Optimizer + Schedule
7. Train
8. Plot History
9. Evaluate + Show final result
10. Extra analysis and visualisation

> **References**: MobileCLIP2 in OpenCLIP (`MobileCLIP2-S{0,2,3,4}, B, L-14` with `pretrained='dfndr2b'`) and CoOp (Zhou et al.).


In [None]:
# 1) (Optional) Installs — uncomment on first run
# Model/runtime packages (OpenCLIP, timm, torchvision, torch):
# !pip install -U open-clip-torch timm torchvision torch --quiet
# Utilities used by later cells: tqdm and matplotlib are imported unconditionally,
# scikit-learn is only needed for the optional confusion matrix at the end.
# !pip install -U scikit-learn tqdm matplotlib --quiet
# Reminder only: this cell installs nothing unless the pip lines above are uncommented.
print('If you need packages, uncomment the pip lines above and run this cell.')

In [None]:
# 2) Common imports + seeds
import math, random, os, time
from dataclasses import dataclass
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets

import matplotlib.pyplot as plt
from tqdm import tqdm

import open_clip

def set_seed(seed: int = 42):
    """Seed Python's and PyTorch's RNGs and request deterministic cuDNN behaviour.

    Args:
        seed: Value handed to `random`, `torch` (CPU) and every CUDA device.
    """
    # cuDNN: pick deterministic kernels and disable the autotuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Same seed for every generator the notebook relies on.
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


set_seed(123)

In [None]:
# 3) Load MobileCLIP2 from OpenCLIP + transforms
# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pick a lightweight variant for quick experiments
MODEL_NAME = 'MobileCLIP2-S0'   # alternatives: 'MobileCLIP2-S2', 'MobileCLIP2-B', 'MobileCLIP2-S3', 'MobileCLIP2-S4', 'MobileCLIP2-L-14'
PRETRAINED = 'dfndr2b'          # per MobileCLIP2 release in OpenCLIP

# Returns: (model, preprocess_train, preprocess_val). Both transforms are IMAGE transforms;
# the train-time one is discarded and the eval-time one is kept as `preprocess`.
# Text is handled separately by the tokenizer below.
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

# Important for MobileCLIP(2): keep eval for frozen BN layers; we will train only prompt params
model.eval()
model.to(device)

print('Loaded', MODEL_NAME, 'on', device)

In [None]:
# 4) Data: small ImageFolder or CIFAR-10 for a quick sanity check
#    Using preprocess from OpenCLIP to ensure normalization matches the backbone.

USE_CIFAR10 = True  # set False and edit DATA_ROOT to use your own ImageFolder
DATA_ROOT = './data'  # CIFAR-10 download/cache root when USE_CIFAR10=True; otherwise the ImageFolder root (expects train/ and val/ subfolders)
BATCH_SIZE = 64
NUM_WORKERS = 4

if USE_CIFAR10:
    # Raw datasets are created without a transform; the wrapper below applies `preprocess`.
    train_raw = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True)
    test_raw  = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True)

    # Wrap datasets to apply OpenCLIP preprocess lazily
    class Preprocessed(torch.utils.data.Dataset):
        """Dataset wrapper that runs `preprocess` on each image when it is fetched."""
        def __init__(self, base):
            self.base = base
            # Falls back to `base.targets` when the wrapped dataset exposes no `classes`.
            self.classes = base.classes if hasattr(base, 'classes') else base.targets
        def __len__(self):
            return len(self.base)
        def __getitem__(self, idx):
            img, y = self.base[idx]
            # preprocess: PIL -> tensor normalized to CLIP space
            return preprocess(img), y

    train_ds = Preprocessed(train_raw)
    # NOTE: val_ds and test_ds wrap the same split, so validation and test metrics coincide.
    val_ds   = Preprocessed(test_raw)  # use test as validation for demo
    test_ds  = Preprocessed(test_raw)
    classnames = train_raw.classes
else:
    # ImageFolder usage; structure:
    # DATA_ROOT/
    #   train/<class_name>/*.jpg
    #   val/<class_name>/*.jpg
    train_ds = datasets.ImageFolder(os.path.join(DATA_ROOT, 'train'), transform=preprocess)
    val_ds   = datasets.ImageFolder(os.path.join(DATA_ROOT, 'val'), transform=preprocess)
    # No separate test split here: the validation set doubles as the test set.
    test_ds  = val_ds
    classnames = train_ds.classes

# Only the training loader shuffles; val/test keep dataset order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print(f"Classes ({len(classnames)}):", classnames[:10], '...')

## 4) Model module: base (frozen MobileCLIP2) + CoOp prompt learner

- We keep the CLIP backbone **frozen** and in **eval** mode (BatchNorm stability on MobileCLIP2). 
- We learn **M** context vectors `ctx` ("soft words") that are prepended to the class-name tokens.
- Shapes are annotated per STYLEGUIDE.

In [None]:
class PromptLearner(nn.Module):
    """Context Optimization (CoOp) prompt learner with one context shared by all classes.

    The prompt for class ``c`` is, in embedding space::

        [SOS] ctx_1 ... ctx_M  tokens(classname_c) [EOS] pad ...

    Only ``ctx`` (shape ``[M, D]``) is a parameter created by this module. The
    class-name prompts are tokenized WITHOUT placeholders; ``forward`` inserts
    the M context vectors right after SOS and shifts the remaining embeddings
    to the right, so class-name and EOS embeddings are never overwritten.
    """

    def __init__(self, clip_model, tokenizer, classnames: List[str], n_ctx: int = 4, ctx_init: str = "a photo of"):
        """
        clip_model: OpenCLIP model; the text layers are read either from the model
            itself or from its ``text`` sub-module when one exists
        tokenizer: open_clip tokenizer for the chosen model (callable: list[str] -> [N, L] ids)
        classnames: list[str] of dataset class names
        n_ctx: number of learnable context tokens (M)
        ctx_init: optional string to initialize context from (e.g., 'a photo of')

        Raises:
            ValueError: if a class name plus M context tokens does not fit in the
                tokenizer's context length.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.clip = clip_model
        self.classnames = classnames
        self.n_ctx = n_ctx

        # Token ids of the bare class names: [C, L]
        self.register_buffer('prompts_tokens', self._tokenize_classnames(classnames), persistent=False)

        # Context vectors live in the same space as the token embeddings.
        embed = self._text_tower().token_embedding
        ctx_dim = embed.embedding_dim

        # Default init for every slot: small Gaussian noise.
        ctx_vecs = torch.randn(self.n_ctx, ctx_dim, device=embed.weight.device) * 0.02  # [M, D]

        if ctx_init:
            init_ids = self.tokenizer([ctx_init]).to(embed.weight.device)  # [1, L]
            # Real words sit strictly between SOS (index 0) and EOT. EOT is located with the
            # same argmax convention `forward` uses (assumes EOT has the largest token id).
            n_words = int(init_ids[0].argmax()) - 1
            take = max(0, min(self.n_ctx, n_words))
            if take > 0:
                with torch.no_grad():
                    init_emb = embed(init_ids[:, 1:1 + take])  # [1, take, D]
                # Copy only real word embeddings; slots beyond them keep the random init
                # instead of receiving EOT / padding embeddings.
                ctx_vecs[:take] = init_emb[0].to(ctx_vecs.dtype)

        self.ctx = nn.Parameter(ctx_vecs)  # [M, D]

    def _text_tower(self):
        """Return the module that owns the text layers (``clip.text`` if present, else ``clip``)."""
        return getattr(self.clip, 'text', self.clip)

    def _tokenize_classnames(self, classnames: List[str]):
        """Tokenize class names (underscores -> spaces) and check that M context slots still fit."""
        prompts = [name.replace('_', ' ') for name in classnames]
        tokenized = self.tokenizer(prompts)  # [C, L]
        if tokenized.numel() > 0:
            eot_max = int(tokenized.argmax(dim=-1).max())
            # After inserting M vectors the EOT moves to eot + M, which must stay inside [0, L).
            if eot_max + self.n_ctx >= tokenized.shape[1]:
                raise ValueError(
                    f"Prompt too long: EOT index {eot_max} + n_ctx {self.n_ctx} "
                    f"exceeds context length {tokenized.shape[1]}"
                )
        return tokenized

    @torch.no_grad()
    def set_classnames(self, classnames: List[str]):
        """Swap the class list (e.g. to evaluate on new classes) while keeping the learned context."""
        tokenized = self._tokenize_classnames(classnames)
        self.classnames = classnames
        self.prompts_tokens = tokenized.to(self.prompts_tokens.device)

    def forward(self):
        """
        Returns the normalized text features for all classes using current context.
        Output shape: [C, D_out] where C=#classes, D_out=text feature dim.
        """
        tower = self._text_tower()
        tokens = self.prompts_tokens.to(tower.positional_embedding.device)  # [C, L]
        cast_dtype = tower.transformer.get_cast_dtype()
        n_cls, seq_len = tokens.shape
        m = self.n_ctx

        # 1) Token embeddings of "[SOS] classname [EOS] pad..."
        x = tower.token_embedding(tokens).to(cast_dtype)  # [C, L, D]

        # 2) Insert the shared context after SOS and shift the rest right by M.
        #    The last M (padding) positions fall off the end, so the length stays L.
        ctx = self.ctx.to(device=x.device, dtype=cast_dtype).unsqueeze(0).expand(n_cls, -1, -1)  # [C, M, D]
        x = torch.cat([x[:, :1], ctx, x[:, 1:seq_len - m]], dim=1)  # [C, L, D]

        # 3) Add positional embedding and run transformer
        x = x + tower.positional_embedding[:seq_len].to(cast_dtype)  # [C, L, D]
        transformer = tower.transformer
        attn_mask = getattr(tower, 'attn_mask', None)
        if hasattr(transformer, 'batch_first'):
            # NOTE(review): open_clip transformers exposing `batch_first` are expected to take
            # [N, L, D] and transpose internally if needed — confirm for the installed version.
            x = transformer(x, attn_mask=attn_mask)  # [C, L, D]
        else:
            # Legacy layout: the transformer consumes and returns [L, N, D].
            x = transformer(x.permute(1, 0, 2), attn_mask=attn_mask).permute(1, 0, 2)  # [C, L, D]
        x = tower.ln_final(x)  # [C, L, D]

        # 4) Take features at the EOT position (argmax trick), shifted by the M inserted slots
        eot = tokens.argmax(dim=-1) + m  # [C]
        x = x[torch.arange(n_cls, device=x.device), eot]  # [C, D]

        # 5) Project to the joint space; the projection may be a matrix or a module.
        proj = getattr(tower, 'text_projection', None)
        if proj is not None:
            x = proj(x) if isinstance(proj, nn.Module) else x @ proj  # [C, D_out]
        return x / x.norm(dim=-1, keepdim=True)  # [C, D_out]

In [None]:
class PromptedMobileCLIP(nn.Module):
    """Frozen CLIP backbone (`base`) plus a trainable CoOp prompt learner (`prompt_learner`).

    The backbone's parameters are frozen and the backbone is kept in eval mode
    at all times, including after ``.train()`` is called on this wrapper.
    """

    def __init__(self, base_clip, tokenizer, classnames: List[str], n_ctx:int=4):
        """
        base_clip: OpenCLIP model providing encode_image / encode_text / logit_scale
        tokenizer: tokenizer matching `base_clip`
        classnames: list[str] of dataset class names
        n_ctx: number of learnable context tokens
        """
        super().__init__()
        self.base = base_clip
        self.tokenizer = tokenizer  # kept so helpers do not depend on a notebook global
        for p in self.base.parameters():
            p.requires_grad = False  # freeze
        self.base.eval()  # BN stability

        self.prompt_learner = PromptLearner(self.base, tokenizer, classnames, n_ctx=n_ctx)

    def train(self, mode: bool = True):
        """Set the training flag on the wrapper but keep the frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every child, which would switch the
        backbone's BatchNorm/Dropout layers back to training behaviour.
        """
        super().train(mode)
        self.base.eval()
        return self

    def forward(self, images):
        """
        images: [B, 3, H, W] preprocessed for CLIP
        returns logits over classes: [B, C]
        """
        # Image branch is frozen: no graph needed.
        with torch.no_grad():
            img_feat = self.base.encode_image(images)  # [B, D]
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # Text branch carries gradients back to the context vectors.
        txt_feat = self.prompt_learner()              # [C, D]
        logit_scale = self.base.logit_scale.exp()
        logits = logit_scale * img_feat @ txt_feat.t()  # [B, C]
        return logits

    @torch.no_grad()
    def zeroshot_classifier(self, classnames: List[str], template: str = "a photo of a {}"):
        """Return normalized text features [C, D] for hand-crafted prompts (CoOp baseline)."""
        prompts = [template.format(c.replace('_',' ')) for c in classnames]
        # `logit_scale` is a backbone tensor, so its device tells us where the model lives.
        tokens = self.tokenizer(prompts).to(self.base.logit_scale.device)
        feat = self.base.encode_text(tokens)
        return feat / feat.norm(dim=-1, keepdim=True)

In [None]:
# 5) Hyperparameters (Config)
from dataclasses import dataclass

@dataclass
class Config:
    """Hyperparameters for the CoOp experiment, collected in one place."""
    dataset_name: str = 'CIFAR-10'
    # batch_size / num_workers take the values of the BATCH_SIZE / NUM_WORKERS constants
    # at class-definition time; the data loaders themselves are built from those constants.
    batch_size: int = BATCH_SIZE
    num_workers: int = NUM_WORKERS
    max_epochs: int = 3
    lr_prompt: float = 5e-3   # typically higher LR for prompt params
    weight_decay: float = 0.0 # prompt params usually without WD
    prompt_len: int = 4       # CoOp: number of learnable context tokens

cfg = Config()
# Bare expression: shows the dataclass repr as the cell output.
cfg

In [None]:
# 6) Optimizer (prompt params only)
model_coop = PromptedMobileCLIP(model, tokenizer, classnames, n_ctx=cfg.prompt_len).to(device)
# `prompt_learner` holds the frozen backbone as a sub-module, so its .parameters() also
# yields every backbone weight. Hand the optimizer only what is actually trainable.
trainable_params = [p for p in model_coop.prompt_learner.parameters() if p.requires_grad]
opt = torch.optim.AdamW(trainable_params, lr=cfg.lr_prompt, weight_decay=cfg.weight_decay)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# 7) Train loop (EMA of loss in progress bar)
def evaluate(model, loader, device, loss_fn, desc="eval", pbar=True):
    """Run `model` over `loader` without gradients and report mean loss and accuracy.

    Args:
        model: module mapping a batch of inputs to class logits.
        loader: iterable of (data, target) batches.
        device: device the batches are moved to.
        loss_fn: callable(logits, target) -> scalar loss tensor.
        desc: progress-bar label.
        pbar: wrap `loader` in tqdm and show running loss/accuracy.

    Returns:
        (mean_loss, accuracy); the loss is weighted by batch size, accuracy is
        a fraction in [0, 1]. Both are 0.0 for an empty loader.
    """
    model.eval()
    batches = tqdm(loader, desc=desc, leave=False) if pbar else loader

    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for data, target in batches:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            logits = model(data)
            batch_loss = loss_fn(logits, target).item()
            batch_size = data.size(0)

            total_loss += batch_loss * batch_size
            n_correct += (logits.argmax(dim=1) == target).sum().item()
            n_seen += batch_size

            if not pbar:
                continue
            running_acc = (n_correct / max(1, n_seen)) * 100.0
            batches.set_postfix(loss=f"{batch_loss:.4f}", acc=f"{running_acc:.2f}%")

    denom = max(1, n_seen)  # guard against an empty loader
    return total_loss / denom, n_correct / denom

def train(model, train_loader, val_loader, optimizer, device, epochs:int, log_interval:int=25, ema_alpha:float=0.1, loss_fn=None):
    """Train `model` for `epochs` epochs, validating after each one.

    Args:
        model: module mapping a batch of inputs to class logits.
        train_loader: iterable of (data, target) training batches (must support len()).
        val_loader: iterable of (data, target) validation batches.
        optimizer: optimizer over the trainable parameters.
        device: device the batches are moved to.
        epochs: number of passes over `train_loader`.
        log_interval: refresh the progress-bar stats every this many steps.
        ema_alpha: smoothing factor of the loss EMA shown in the progress bar.
        loss_fn: callable(logits, target) -> scalar loss. Defaults to
            ``nn.CrossEntropyLoss()`` so the function no longer depends on a
            notebook-level global of the same name.

    Returns:
        dict with per-epoch lists: 'train_loss', 'val_loss', 'train_acc', 'val_acc'
        (accuracies are fractions in [0, 1]).
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    for epoch in range(1, epochs + 1):
        model.train()
        train_sum, train_correct, train_count = 0.0, 0, 0
        ema = None  # reset the smoothed loss at every epoch
        iterator = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} | train", leave=False, dynamic_ncols=True, mininterval=0.5)
        n_steps = len(iterator)
        for step, (data, target) in enumerate(iterator, start=1):
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            output = model(data)                     # [B, C]
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            # Stats
            bs = data.size(0)
            loss_value = loss.item()
            train_sum += loss_value * bs
            preds = output.argmax(dim=1)
            train_correct += (preds == target).sum().item()
            train_count += bs
            ema = loss_value if ema is None else (1-ema_alpha)*ema + ema_alpha*loss_value
            # Update the bar periodically and always on the last step of the epoch.
            if (step % log_interval == 0) or (step == n_steps):
                acc_pct = 100.0 * train_correct / max(1, train_count)
                iterator.set_postfix_str(f"loss(ema)={ema:.4f} acc={acc_pct:.2f}%")
        avg_train_loss = train_sum / max(1, train_count)
        avg_train_acc = train_correct / max(1, train_count)
        avg_val_loss, avg_val_acc = evaluate(model, val_loader, device, loss_fn=loss_fn, desc="valid", pbar=True)
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(avg_train_acc)
        history['val_acc'].append(avg_val_acc)
        tqdm.write(
            f"Epoch {epoch:03d}: train_loss={avg_train_loss:.4f} train_acc={avg_train_acc*100:.2f}% "
            f"val_loss={avg_val_loss:.4f} val_acc={avg_val_acc*100:.2f}%"
        )
    return history

history = train(model_coop, train_loader, val_loader, opt, device, epochs=cfg.max_epochs)

In [None]:
# 8) Plot history (loss + accuracy)
def plot_history(history):
    """Draw loss (left) and accuracy (right) curves per epoch from a training history dict.

    `history` must hold 'train_loss' and 'val_loss'; 'train_acc' / 'val_acc'
    (fractions) are plotted as percentages when both are present.
    """
    xs = range(1, len(history['train_loss']) + 1)
    fig, panels = plt.subplots(1, 2, figsize=(12, 4), constrained_layout=True)
    loss_ax, acc_ax = panels[0], panels[1]

    # Left panel: losses
    for split in ('train', 'val'):
        loss_ax.plot(xs, history[split + '_loss'], label=split)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Loss')
    loss_ax.legend()

    # Right panel: accuracies, or a placeholder message when they were not recorded
    has_acc = ('train_acc' in history) and ('val_acc' in history)
    if not has_acc:
        acc_ax.text(0.5, 0.5, 'No accuracy in history', ha='center', va='center', transform=acc_ax.transAxes)
    else:
        for split in ('train', 'val'):
            acc_ax.plot(xs, [x * 100 for x in history[split + '_acc']], label=split)
        acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_title('Accuracy')
    acc_ax.legend()

    plt.show()

plot_history(history)

In [None]:
# 9) Evaluate on test set & show a few predictions
test_loss, test_acc = evaluate(model_coop, test_loader, device, loss_fn, desc='test')
print(f"Test: loss={test_loss:.4f} acc={test_acc*100:.2f}%")

from itertools import islice
model_coop.eval()
# First batch of the (unshuffled) test loader: [images, labels]
samples = list(islice(iter(test_loader), 1))[0]
imgs, labels = samples[0].to(device), samples[1]
with torch.no_grad():
    logits = model_coop(imgs)
preds = logits.argmax(dim=1).cpu()

# Show at most 8 images; a small dataset or batch size may yield fewer.
n_show = min(8, imgs.size(0))
plt.figure(figsize=(10,3))
for i in range(n_show):
    plt.subplot(2,4,i+1)
    # un-normalize for visualization (approx): per-image min-max rescale to [0, 1]
    img = imgs[i].cpu().permute(1,2,0).float()
    img = (img - img.min()) / (img.max() - img.min() + 1e-6)
    plt.imshow(img)
    plt.title(f"gt:{classnames[int(labels[i])]}\npred:{classnames[int(preds[i])]}")
    plt.axis('off')
plt.tight_layout(); plt.show()

In [None]:
# 10) Extra analysis: (optional) confusion matrix if sklearn is available
try:
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
except ImportError as e:
    # Only a missing package is treated as "optional"; real errors below are not hidden.
    print('Install scikit-learn to see confusion matrix. Skipping. Error:', e)
else:
    model_coop.eval()  # make the cell safe to run on its own, e.g. right after training
    y_true, y_pred = [], []
    with torch.no_grad():
        for x, y in tqdm(test_loader, leave=False, desc='cm'):
            p = model_coop(x.to(device)).argmax(dim=1).cpu()
            y_true.extend(y.cpu().tolist())
            y_pred.extend(p.tolist())
    # Fix the label set so the matrix is always len(classnames) x len(classnames),
    # even if some class never appears in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classnames))))
    disp = ConfusionMatrixDisplay(cm, display_labels=classnames)
    fig, ax = plt.subplots(figsize=(8,8))
    disp.plot(ax=ax, include_values=False, xticks_rotation=90, cmap='Blues')
    plt.title('Confusion Matrix (test)')
    plt.tight_layout(); plt.show()

### Notes
- The backbone is **frozen**; only the prompt parameters `ctx` (shape `[M, D]`) are trained.
- For MobileCLIP(2), keeping the model in **eval()** avoids BN updates.
- You can change `cfg.prompt_len` to explore different context lengths (e.g., 4, 8, 16).
- To compare with classic hand-crafted prompts, see the `zeroshot_classifier()` helper.