# CLIP with Flowers!?!?!??!?

In [1]:
import os
# Prefer expandable segments to reduce fragmentation (restart kernel after changing)
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

import sys
import torch
import torchvision

# Ensure CLIP is installed in the current kernel; install if missing.
# Using subprocess with sys.executable to target the same Python interpreter.
try:
    import clip
except Exception:
    import subprocess, importlib
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "git+https://github.com/openai/CLIP.git"], stdout=subprocess.DEVNULL)
    importlib.invalidate_caches()
    import clip

from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler
from collections import OrderedDict


In [2]:
# Check environment and CLIP installation
import sys, importlib, torch

try:
    import clip
    print("CLIP: già installato")
except Exception:
    print("CLIP non trovato: eseguo installazione nel kernel corrente...")
    import importlib
    # Preferisci %pip per installare nel kernel Jupyter corrente; fallback a subprocess
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "git+https://github.com/openai/CLIP.git"])
    importlib.invalidate_caches()
    import clip
    print("CLIP installato correttamente nel kernel corrente")

print("python:", sys.executable)
print("torch:", torch.__version__, "cuda_available:", torch.cuda.is_available())
# su mac con Apple Silicon, controlla MPS
try:
    print("mps_available:", torch.backends.mps.is_available())
except Exception:
    pass

print('Riavvia il kernel se necessario, poi esegui questa cella e procedi con l\'allenamento CoCoOp.')


CLIP: già installato
python: /usr/bin/python3
torch: 2.9.0+cu126 cuda_available: True
mps_available: False
Riavvia il kernel se necessario, poi esegui questa cella e procedi con l'allenamento CoCoOp.


## Dataset Functions

We define utility functions for:
- **`get_data()`**: Load Flowers102 from torchvision
- **`base_novel_categories()`**: Split 102 classes into base (0-50) and novel (51-101)
- **`split_data()`**: Filter images for base/novel in each split

This simulates the real scenario: we have 51 seen classes during training (base) and 51 new ones (novel).


In [3]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets."""
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test


def base_novel_categories(dataset):
    """Return base and novel class id lists using the actual labels present
    in the dataset. Prefer public attributes (`targets` then `labels`) and
    only fall back to the dataset private attribute `_labels` if neither is
    available.
    """
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = getattr(dataset, "labels", None)

    if labels is None and hasattr(dataset, "_labels"):
        # FALLBACK: using private dataset internals. Flowers102 exposes
        # `_labels` but this is a private attribute; prefer public attributes
        # above so future datasets remain compatible.
        labels = dataset._labels

    if labels is None:
        raise ValueError("Could not find labels on dataset (checked 'targets','labels','_labels').")

    unique_labels = sorted(set(labels))
    num_classes = len(unique_labels)
    mid = num_classes // 2
    base_classes = unique_labels[:mid]
    novel_classes = unique_labels[mid:]
    return base_classes, novel_classes


def split_data(dataset, base_classes):
    base_categories_samples = []
    novel_categories_samples = []
    base_set = set(base_classes)

    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset

## Class Names and Dataset Loading

We load the names of 102 flower classes from Flowers102.

This is **critical** for CLIP:
- Creates prompts like "a photo of a **rose**, a type of flower"
- Each prompt is encoded by CLIP's text encoder
- Image features are compared against these text templates


In [4]:
_, _, tmp_test = get_data()
base_classes, novel_classes = base_novel_categories(tmp_test)

CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# Uncomment to see class names
# print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
# print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

In [5]:
# Load CLIP model and preprocessing
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)
print(f"Device: {device}")
print(f"Model: ViT-B/16")

Device: cuda
Model: ViT-B/16


## Load Flowers102 and Split Base/Novel

We load the 3 splits (train, val, test) and divide into base/novel.

**Statistics:**
- Train Base: 10 images × 51 classes = 510 images
- Val Base: 10 images × 51 classes = 510 images
- Test Base: ~10 images × 51 classes (from test split)
- Test Novel: Remaining (~10 per class)

**Note:** Train and val have ~10 images per class (few-shot setting).


In [6]:
# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# Few-shot: sample `shots_per_class` images per base class from the train split
shots_per_class = 16
import random
random.seed(42)

# Collect indices per class in the original train_set
indices_per_class = {c: [] for c in base_classes}
for idx, label in enumerate(train_set._labels):
    if label in indices_per_class:
        indices_per_class[label].append(idx)

selected = []
for c in base_classes:
    inds = indices_per_class.get(c, [])
    random.shuffle(inds)
    # take up to shots_per_class (if fewer available, take all)
    selected.extend(inds[:shots_per_class])

# Create the few-shot training subset
train_base = torch.utils.data.Subset(train_set, selected)

# validation and test splits remain full (or filtered by base classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

print(f"Train Base (few-shot): {len(train_base)} samples ({shots_per_class} shots per class)")
print(f"Val Base: {len(val_base)} samples")
print(f"Test Base: {len(test_base)} samples")
print(f"Test Novel: {len(test_novel)} samples")

Train Base (few-shot): 510 samples (16 shots per class)
Val Base: 510 samples
Test Base: 2473 samples
Test Novel: 3676 samples


## Harmonic Mean (HM)

Standard metric for few-shot adaptation papers.

Formula: HM = 2 / (1/base_acc + 1/novel_acc)

**Why HM instead of arithmetic mean?**
- HM heavily penalizes outliers
- If base=90% and novel=50%: arithmetic mean=70%, HM=64.3%
- Forces the model to balance both accuracies

**Obiettivo:** massimizzare l'HM tra `base_acc_cocoop` e `novel_acc_cocoop`.


In [None]:
def harmonic_mean(base_accuracy, novel_accuracy):
    # Guard against zero to avoid division-by-zero errors
    if base_accuracy <= 0 or novel_accuracy <= 0:
        return 0.0
    numerator = 2.0
    denominator = 1.0 / base_accuracy + 1.0 / novel_accuracy
    return numerator / denominator


## Text Encoder

In [8]:
class TextEncoder(nn.Module):
    """Encodes soft prompts through CLIP's text transformer."""
    
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """
        Args:
            prompts: (batch_size, n_tokens, 512) soft prompt embeddings
            tokenized_prompts: (n_cls, n_tokens) token indices for EOT detection
        
        Returns:
            text_features: (batch_size, 512) per-class text features
        """
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        
        # Extract EOT token
        x = x[torch.arange(x.shape), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        
        return x

## MetaNetwork: Conditional Token Generator

**Problem:** Fixed prompts don't adapt to each image.

**Solution:** A small neural network that transforms image features into a conditional token.

**Parameters:** ~256K (negligible vs. fine-tuning)

**Effect:** Each image gets a different prompt → instance-level adaptation


In [10]:
"""
MetaNetwork è una piccola rete neurale (MLP con 2 layer)
che trasforma le image_features (512-dim) in un token
condizionale (512-dim) usato in CoCoOp.

Questo token varia per ogni immagine, permettendo prompt
personalizzati per ogni input.
"""
\
class MetaNetwork(nn.Module):
    def __init__(self, ctx_dim=512, hidden_dim=256):
        """
        Args:
            ctx_dim: dimensione degli embeddings (512 per ViT-B/16)
            hidden_dim: dimensione dello strato nascosto
        """
        super().__init__()
        self.linear1 = nn.Linear(ctx_dim, hidden_dim)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(hidden_dim, ctx_dim)

    def forward(self, image_features):
        """
        Args:
            image_features: tensor (B, ctx_dim) dalle immagini encodate

        Returns:
            conditional_token: tensor (B, ctx_dim)
        """
        # Assicura il tipo corretto (importante per mixed precision)
        image_features = image_features.to(self.linear1.weight.dtype)

        out = self.linear1(image_features)
        out = self.relu(out)
        out = self.linear2(out)
        return out


## CoCoOpPromptLearner: Dynamic Prompts


**Components:**
1. **V1...VM:** 16 context vectors (learned via SGD)
   - Shape: (16, 512) tensors
   - Initialized randomly from N(0, 0.02²)
   - Optimized during training

2. **π(x):** Conditional token (generated per image)
   - Shape: (B, 512) from MetaNetwork output
   - Different for each image

3. **[CLASS]:** Class name embedding
   - Shape: (seq_len, 512) from CLIP's token embedding
   - Same for all images of the same class

**Forward Pass:**
- Input: image_features (B, 512)
- Output: prompts (B, num_classes, seq_len_total, 512)


In [None]:
class CoCoOpPromptLearner(nn.Module):
    """
    Learnable prompt module with meta-net for instance-conditioned adaptation.
    
    Based on CoCoOp: https://arxiv.org/abs/2203.05557
    """
    
    def __init__(self, clip_model, classnames, n_ctx=4, ctx_init=None, device='cuda'):
        super().__init__()
        
        n_cls = len(classnames)
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape  # 512 per ViT-B/16
        vis_dim = clip_model.visual.output_dim  # 512 per ViT-B/16
        
        # ✅ CONTEXT VECTORS (learnable)
        if ctx_init:
            # Initialize from provided text
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1:1+n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=torch.float32)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)
        
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words: {n_ctx}")
        
        self.ctx = nn.Parameter(ctx_vectors)
        
        # ✅ META NETWORK (corrected architecture da GitHub)
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
        ]))
        
        # ✅ CLASS EMBEDDINGS
        classnames = [name.replace("_", " ") for name in classnames]
        from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
        _tokenizer = _Tokenizer()
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
        
        # Register buffers for prefix and suffix tokens
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1+n_ctx:, :])  # CLS, EOS
        
        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.dtype = dtype
    
    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Construct prompt embeddings from context and class-specific tokens."""
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]
        
        prompts = torch.cat([
            prefix,  # (dim0, 1, dim)
            ctx,     # (dim0, n_ctx, dim)
            suffix,  # (dim0, *, dim)
        ], dim=1)
        
        return prompts
    
    def forward(self, im_features):
        """
        Args:
            im_features: (batch_size, 512) image features from CLIP's ViT
        
        Returns:
            prompts: (batch_size, n_cls, n_tokens, 512) soft prompts for all classes
        """
        prefix = self.token_prefix
        suffix = self.token_suffix
        
        # ✅ INSTANCE-CONDITIONED CONTEXT
        ctx = self.ctx.unsqueeze(0)  # (1, n_ctx, ctx_dim)
        bias = self.meta_net(im_features)  # (batch_size, ctx_dim)
        bias = bias.unsqueeze(1)  # (batch_size, 1, ctx_dim)
        
        # ✅ SHIFT CONTEXT per image
        ctx_shifted = ctx + bias  # (batch_size, n_ctx, ctx_dim)
        
        # Build prompts for each image and each class
        prompts = []
        for ctx_shifted_i in ctx_shifted:
            # Replicate across classes
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
            pts_i = self.construct_prompts(ctx_i, prefix, suffix)  # (n_cls, n_tokens, ctx_dim)
            prompts.append(pts_i)
        
        prompts = torch.stack(prompts)  # (batch_size, n_cls, n_tokens, ctx_dim)
        
        return prompts

## CoCoOpTrainer: Training and Evaluation

Class that manages:

**1. Initialization:**
- Create PromptLearner
- Freeze CLIP (`requires_grad=False`)
- Configure SGD optimizer for prompt learner only

**2. train_epoch():**
- Forward: Image encoder + PromptLearner + Text encoder
- **Critical step:** Encode soft prompts through text transformer
  - Add positional embeddings
  - Pass through CLIP's transformer
  - Extract first token
  - Apply final layer norm + projection
- Compute loss: Cross-entropy on base classes
- Backward: Backprop only in PromptLearner
- Return: Average loss of the epoch

**3. eval():**
- Same forward procedure as training
- Without backward pass
- Compute accuracy on any dataset (base or novel)

**Important note:** We don't use `model.encode_text()` on soft prompts
because that method expects integer tokens, not embeddings.
We manually forward through the text transformer.

In [25]:
class CoCoOpPromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=16, ctx_dim=512, device='cuda'):
        super().__init__()
        self.n_ctx = n_ctx
        self.ctx_dim = ctx_dim
        self.device = device

        # ✅ LEARNABLE CONTEXT
        self.ctx = nn.Parameter(
            torch.randn(n_ctx, ctx_dim).float() * 0.02
        )

        # ✅ META NETWORK
        self.meta_net = nn.Sequential(
            nn.Linear(ctx_dim, ctx_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(ctx_dim // 2, ctx_dim)
        )

        # ✅ CLASS EMBEDDINGS - SEMPLICE E DIRETTO
        # Costruisci gli embeddings SENZA tokenizer esterno
        class_embs = []
        for classname in classnames:
            # Crea il testo
            text = f"a photo of {classname}"

            # ✅ TOKENIZZA MANUALMENTE usando il tokenizer di CLIP interno
            # Importa il tokenizer
            from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
            _tokenizer = _Tokenizer()

            # Tokenizza il testo
            tokens = _tokenizer.encode(text)
            # Padding a 77 token
            tokens = tokens + [0] * (77 - len(tokens))
            tokens = torch.tensor([tokens[:77]]).to(device)

            # Estrai l'embedding
            with torch.no_grad():
                emb = clip_model.token_embedding(tokens).type(torch.float32)  # (1, 77, 512)
            class_embs.append(emb)

        # Stack: (num_classes, 77, 512)
        class_token_embeddings = torch.cat(class_embs, dim=0)

        self.register_buffer("class_token_embeddings", class_token_embeddings)

        self.clip_context_length = 77

    def forward(self, image_features):
        """
        image_features: (B, 512)
        Output: (B, num_classes, 77, 512)
        """
        batch_size = image_features.shape[0]
        num_classes = self.class_token_embeddings.shape[0]

        # Meta network - genera condizionamenti
        cond = self.meta_net(image_features)  # (B, 512)

        # Context base
        ctx = self.ctx.unsqueeze(0)  # (1, n_ctx, 512)
        ctx = ctx.repeat(batch_size, 1, 1)  # (B, n_ctx, 512)

        # Modula context con meta_net
        ctx = ctx + cond.unsqueeze(1) * 0.1  # (B, n_ctx, 512)

        # Class embeddings
        class_embed = self.class_token_embeddings.unsqueeze(0)  # (1, num_classes, 77, 512)
        class_embed = class_embed.repeat(batch_size, 1, 1, 1)  # (B, num_classes, 77, 512)

        # Build prompts per ogni classe
        prompts_list = []
        for i in range(num_classes):
            # Prendi i token significativi dalla class embedding
            prompt_i = torch.cat([
                ctx,  # (B, n_ctx, 512)
                class_embed[:, i, 1:self.clip_context_length-self.n_ctx, :]  # (B, 77-n_ctx-1, 512)
            ], dim=1)  # (B, 77, 512)

            # Garantisci che sia esattamente 77 token
            if prompt_i.shape[1] < self.clip_context_length:
                pad_len = self.clip_context_length - prompt_i.shape[1]
                prompt_i = torch.cat([
                    prompt_i,
                    torch.zeros(batch_size, pad_len, self.ctx_dim, device=image_features.device, dtype=torch.float32)
                ], dim=1)
            elif prompt_i.shape[1] > self.clip_context_length:
                prompt_i = prompt_i[:, :self.clip_context_length, :]

            prompts_list.append(prompt_i)

        prompts = torch.stack(prompts_list, dim=1)  # (B, num_classes, 77, 512)

        return prompts.to(dtype=torch.float32)


class CoCoOpTrainer:
    """
    Trainer for CoCoOp with corrected hyperparameters and scheduler.
    """
    
    def __init__(self, clip_model, classnames, base_classes, novel_classes, 
                 device='cuda', lr=0.002, n_ctx=4, num_epochs=10):
        """
        ✅ CORRECTED hyperparameters
        
        Args:
            lr: 0.002 (NOT 0.02)
            num_epochs: for scheduler initialization
        """
        self.clip_model = clip_model.float()
        self.classnames = classnames
        self.base_classes = base_classes
        self.novel_classes = novel_classes
        self.device = device
        self.num_epochs = num_epochs
        
        # Mapping from class ID to index
        self.contig_cat2idx = {cat: idx for idx, cat in enumerate(self.base_classes)}
        
        # Freeze CLIP
        for p in self.clip_model.parameters():
            p.requires_grad = False
        
        # Create model
        self.model = CustomCLIP(self.clip_model, classnames, n_ctx=n_ctx, 
                               device=device).to(device)
        
        # Only optimize prompt learner
        self.optimizer = torch.optim.SGD(
            self.model.prompt_learner.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=5e-4
        )
        
        # ✅ SCHEDULER INITIALIZED HERE
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs)
        
        print(f"\n{'='*70}")
        print(f"CoCoOpTrainer initialized:")
        print(f"  LR: {lr}")
        print(f"  Momentum: 0.9")
        print(f"  Weight decay: 5e-4")
        print(f"  Scheduler: CosineAnnealing (T_max={num_epochs})")
        print(f"  Trainable params: {sum(p.numel() for p in self.model.prompt_learner.parameters() if p.requires_grad):,}")
        print(f"{'='*70}\n")
    
    def train_epoch(self, train_dataset, batch_size=4):
        """Train for one epoch."""
        self.model.train()
        self.clip_model.eval()
        
        dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=False
        )
        
        total_loss = 0
        n_batches = 0
        
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader, desc="Training")):
            images = images.to(self.device).float()
            labels = labels.to(self.device)
            
            # ✅ Map class labels to contiguous indices
            labels_mapped = torch.tensor(
                [self.contig_cat2idx[l.item()] for l in labels],
                dtype=torch.long,
                device=self.device
            )
            
            # Forward pass
            loss = self.model(images, labels_mapped)
            
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            n_batches += 1
            
            # Debug: print gradients on first batch
            if batch_idx == 0:
                print(f"\n{'='*70}")
                print("GRADIENT CHECK (first batch):")
                for name, param in self.model.prompt_learner.named_parameters():
                    if param.grad is not None:
                        grad_norm = param.grad.norm().item()
                        print(f"  {name:40s} | grad_norm: {grad_norm:.8f}")
                print(f"{'='*70}\n")
        
        # Step scheduler
        self.scheduler.step()
        
        return total_loss / max(1, n_batches)
    
    @torch.no_grad()
    def eval(self, dataset, categories, batch_size=64):
        """
        ✅ CORRECTED SIGNATURE: removed 'classnames' parameter
        """
        self.model.eval()
        self.clip_model.eval()
        
        contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
        
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=False
        )
        
        correct = 0
        total = 0
        
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(self.device).float()
            labels = labels.to(self.device)
            
            # Get logits
            logits = self.model(images)  # (batch_size, n_cls)
            pred = logits.argmax(dim=1)
            
            # Map labels
            labels_mapped = torch.tensor(
                [contig_cat2idx[l.item()] for l in labels],
                dtype=torch.long,
                device=self.device
            )
            
            correct += (pred == labels_mapped).sum().item()
            total += labels.size(0)
        
        accuracy = correct / total if total > 0 else 0.0
        return accuracy

In [29]:
class CustomCLIP(nn.Module):
    """CLIP model with learnable prompts."""
    
    def __init__(self, clip_model, classnames, n_ctx=4, ctx_init=None, device='cuda'):
        super().__init__()
        
        self.prompt_learner = CoCoOpPromptLearner(clip_model, classnames, 
                                   n_ctx=n_ctx,
                                   device=device)

        
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
    
    def forward(self, image, label=None):
        """
        Args:
            image: (batch_size, 3, 224, 224)
            label: (batch_size,) class labels during training
        
        Returns:
            logits: (batch_size, n_cls) or loss if label provided
        """
        logit_scale = self.logit_scale.exp()
        
        # ✅ Encode images
        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        
        # ✅ Generate instance-conditioned prompts
        prompts = self.prompt_learner(image_features)  # (batch, n_cls, n_tokens, dim)
        
        # ✅ Encode prompts and compute logits
        batch_size = prompts.shape
        logits = []
        
        for i in range(batch_size):
            # Prompts for this image: (n_cls, n_tokens, dim)
            pts_i = prompts[i]
            # Image feature for this image: (dim,)
            imf_i = image_features[i]
            
            # Encode prompts through text transformer
            text_features = self.text_encoder(pts_i, self.tokenized_prompts)
            # text_features: (n_cls, dim)
            
            # ✅ Normalize and compute similarity
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            l_i = logit_scale * (imf_i @ text_features.t())  # (n_cls,)
            logits.append(l_i)
        
        logits = torch.stack(logits)  # (batch_size, n_cls)
        
        # Return loss during training, logits during evaluation
        if label is not None:
            return F.cross_entropy(logits, label)
        
        return logits


## Training CoCoOp

We will train the PromptLearner for **5 epochs** on **base classes only**.

**Hyperparameters:**
- Learning rate: 0.002 (SGD)
- Momentum: 0.9
- Weight decay: 5e-4
- Batch size: 1
- Epochs: 5

**What happens:**
- Context vectors V1...VM adapt to the Flowers102 dataset
- MetaNetwork learns to generate useful conditional tokens
- CLIP remains frozen (unchanged)

**Expected output:**
- Initial loss: ~3.0
- Final loss: ~1.3-1.5
- Training time: ~5-10 minutes on GPU

In [30]:
def harmonic_mean(base_acc, novel_acc):
    """Compute harmonic mean of accuracies."""
    if base_acc <= 0 or novel_acc <= 0:
        return 0.0
    return 2.0 / (1.0 / base_acc + 1.0 / novel_acc)


# Setup
base_classnames = [CLASS_NAMES[i] for i in base_classes]
novel_classnames = [CLASS_NAMES[i] for i in novel_classes]

# Initialize trainer
trainer = CoCoOpTrainer(
    clip_model=model,
    classnames=base_classnames,
    base_classes=base_classes,
    novel_classes=novel_classes,
    device=device,
    lr=0.002,  # ✅ CORRECTED (was 0.02)
    n_ctx=4,
    num_epochs=10
)

# Train
print("\n" + "="*70)
print("TRAINING CoCoOp")
print("="*70)

num_epochs = 10
for epoch in range(num_epochs):
    avg_loss = trainer.train_epoch(train_base, batch_size=4)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

print("\n" + "="*70)
print("TRAINING COMPLETED")
print("="*70)

TypeError: empty(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got torch.Size"

## Final Evaluation (CoCoOp only)

We'll evaluate the model with:
1. Test Base
2. Test Novel

Computing Harmonic Mean between them to evaluate the trade-off.


In [None]:
print("\n" + "="*70)
print("EVALUATION")
print("="*70)

# ✅ CORRECTED eval() call (no 'classnames' parameter)
base_acc = trainer.eval(test_base, base_classes, batch_size=64)
novel_acc = trainer.eval(test_novel, novel_classes, batch_size=64)
hm = harmonic_mean(base_acc, novel_acc)

print("\n" + "="*70)
print("RESULTS")
print("="*70)
print(f"  Base Accuracy:  {base_acc*100:6.2f}%")
print(f"  Novel Accuracy: {novel_acc*100:6.2f}%")
print(f"  Harmonic Mean:  {hm*100:6.2f}%")
print("="*70)

Base classnames (51): ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells']...
Novel classnames (51): ['wild pansy', 'primula', 'sunflower']...

EVALUATION


Eval: 100%|██████████| 39/39 [05:06<00:00,  7.87s/it]
Eval: 100%|██████████| 58/58 [07:39<00:00,  7.92s/it]


CoCoOp RESULTS

 Base Accuracy:    0.81%
 Novel Accuracy:   1.77%
 Harmonic Mean:    1.11%




