# CoCoOp Prompt Learning with CLIP on Flowers102

In [1]:
import os
# Prefer expandable segments to reduce fragmentation (restart kernel after changing)
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

import sys
import torch
import torchvision

# Ensure CLIP is importable in the current kernel; install it from GitHub if the import fails.
# The IPython `%pip` magic is tried first; if that raises (for example because
# `get_ipython` is not defined), fall back to running `pip` through
# `sys.executable` in a subprocess.
try:
    import clip
except Exception:
    import subprocess, importlib
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        # pip's stdout is discarded here; a non-zero exit status still raises CalledProcessError.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "git+https://github.com/openai/CLIP.git"], stdout=subprocess.DEVNULL)
    # Refresh the import system's caches before retrying the import of the freshly installed package.
    importlib.invalidate_caches()
    import clip

from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-hi6qq6w8
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-hi6qq6w8
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m44.8/44.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=311d8475af11f48b8839f67da3c0fc933775f22bde09b287b29ca4fc9a46ad7b
  Stored in 

In [2]:
# Check environment and CLIP installation
import sys, importlib, torch

# Report whether CLIP is already importable; otherwise install it and import again.
try:
    import clip
    print("CLIP: gi√† installato")
except Exception:
    print("CLIP non trovato: eseguo installazione nel kernel corrente...")
    import importlib
    # Prefer %pip so the install targets the current Jupyter kernel; fall back to subprocess
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "git+https://github.com/openai/CLIP.git"])
    importlib.invalidate_caches()
    import clip
    print("CLIP installato correttamente nel kernel corrente")

# Print the interpreter path, the torch version and whether CUDA is available.
print("python:", sys.executable)
print("torch:", torch.__version__, "cuda_available:", torch.cuda.is_available())
# On a Mac with Apple Silicon, also check MPS
try:
    print("mps_available:", torch.backends.mps.is_available())
except Exception:
    # Best effort: if the MPS query raises, the line is simply not printed.
    pass

print('Riavvia il kernel se necessario, poi esegui questa cella e procedi con l\'allenamento CoCoOp.')


CLIP: gi√† installato
python: /usr/bin/python3
torch: 2.9.0+cpu cuda_available: False
mps_available: False
Riavvia il kernel se necessario, poi esegui questa cella e procedi con l'allenamento CoCoOp.


## Dataset Functions

We define utility functions for:
- **`get_data()`**: Load Flowers102 from torchvision
- **`base_novel_categories()`**: Split 102 classes into base (0-50) and novel (51-101)
- **`split_data()`**: Filter images for base/novel in each split

This simulates the real scenario: we have 51 seen classes during training (base) and 51 new ones (novel).


In [3]:
def get_data(data_dir="./data", transform=None):
    """Return the Flowers102 ``(train, val, test)`` datasets.

    Args:
        data_dir: root directory handed to torchvision for storing the dataset.
        transform: optional transform forwarded unchanged to every split.

    Returns:
        A 3-tuple of ``torchvision.datasets.Flowers102`` objects in the order
        train, val, test. Each is created with ``download=True``.
    """
    train, val, test = (
        torchvision.datasets.Flowers102(
            root=data_dir, split=split_name, download=True, transform=transform
        )
        for split_name in ("train", "val", "test")
    )
    return train, val, test

def base_novel_categories(dataset):
    """Split the class ids ``0 .. num_classes-1`` into a base half and a novel half.

    ``num_classes`` is the number of distinct values in ``dataset._labels``.
    The first ``num_classes // 2`` ids form the base classes and the remaining
    ids form the novel classes.

    Returns:
        ``(base_classes, novel_classes)`` as two lists of ints.
    """
    num_classes = len(set(dataset._labels))
    half = num_classes // 2
    class_ids = list(range(num_classes))
    return class_ids[:half], class_ids[half:]

def split_data(dataset, base_classes):
    """Partition ``dataset`` into base-class and novel-class subsets.

    A sample goes into the base subset when its entry in ``dataset._labels``
    is contained in ``base_classes``; every other sample goes into the novel
    subset. Sample order is preserved inside each subset.

    Returns:
        ``(base_dataset, novel_dataset)`` as ``torch.utils.data.Subset`` views
        over ``dataset``.
    """
    base_lookup = set(base_classes)
    base_indices, novel_indices = [], []

    for index, label in enumerate(dataset._labels):
        target = base_indices if label in base_lookup else novel_indices
        target.append(index)

    return (
        torch.utils.data.Subset(dataset, base_indices),
        torch.utils.data.Subset(dataset, novel_indices),
    )

## Class Names and Dataset Loading

We load the names of 102 flower classes from Flowers102.

This is **critical** for CLIP:
- Creates prompts like "a photo of a **rose**, a type of flower"
- Each prompt is encoded by CLIP's text encoder
- Image features are compared against these text templates


In [4]:
# Load the splits without a transform; only the test split is kept, and it is
# used here solely to derive the base / novel class id lists.
_, _, tmp_test = get_data()
base_classes, novel_classes = base_novel_categories(tmp_test)

# Human-readable flower names, indexed by class id (the notebook looks names up as CLASS_NAMES[i]).
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# Uncomment to see class names
# print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
# print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 345M/345M [00:09<00:00, 36.8MB/s] 
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 502/502 [00:00<00:00, 1.11MB/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15.0k/15.0k [00:00<00:00, 22.8MB/s]


In [5]:
# Load the CLIP backbone together with its matching image preprocessing.
# The architecture name lives in a single constant so the log line below
# always reports exactly what was passed to `clip.load`.
CLIP_MODEL_NAME = "ViT-B/16"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(CLIP_MODEL_NAME, device=device)
print(f"Device: {device}")
print(f"Model: {CLIP_MODEL_NAME}")

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 335M/335M [00:04<00:00, 81.8MiB/s]


Device: cpu
Model: ViT-B/16


## Load Flowers102 and Split Base/Novel

We load the 3 splits (train, val, test) and divide into base/novel.

**Statistics:**
- Train Base: 10 images × 51 classes = 510 images
- Val Base: 10 images × 51 classes = 510 images
- Test Base: 2473 images spread over the 51 base classes
- Test Novel: 3676 images spread over the 51 novel classes

**Note:** Train and val have exactly 10 images per class (few-shot setting); the test split is much larger (see the counts printed below).


In [6]:
# Load the three Flowers102 splits with CLIP's preprocessing applied.
train_set, val_set, test_set = get_data(transform=preprocess)

# Derive the base / novel class id lists from the training split.
base_classes, novel_classes = base_novel_categories(train_set)

# Train and val keep only their base-class samples; test keeps both halves.
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

# Report the size of every subset used later on.
print(
    f"Train Base: {len(train_base)} samples",
    f"Val Base: {len(val_base)} samples",
    f"Test Base: {len(test_base)} samples",
    f"Test Novel: {len(test_novel)} samples",
    sep="\n",
)

Train Base: 510 samples
Val Base: 510 samples
Test Base: 2473 samples
Test Novel: 3676 samples


## (Section removed) Zero-Shot CLIP Evaluation

This section has been removed: the notebook now runs only the CoCoOp workflow (prompt learning + MetaNetwork).

If you want to re-enable zero-shot evaluation later, run a separate cell that builds the fixed prompts and calls `model.encode_text()`.


In [7]:
# Zero-shot evaluation removed — the notebook now runs the CoCoOp workflow only.
# If needed, re-enable zero-shot evaluation in a separate cell.
print("Zero-shot evaluation removed. Proceeding with CoCoOp-only workflow.")


Zero-shot evaluation removed. Proceeding with CoCoOp-only workflow.


## Harmonic Mean (HM)

Standard metric for few-shot adaptation papers.

Formula: HM = 2 / (1/base_acc + 1/novel_acc)

**Why HM instead of arithmetic mean?**
- HM heavily penalizes outliers
- If base=90% and novel=50%: arithmetic mean=70%, HM=64.3%
- Forces the model to balance both accuracies

**Goal:** maximize the HM between `base_acc_cocoop` and `novel_acc_cocoop`.


In [8]:
def harmonic_mean(base_accuracy, novel_accuracy):
    """Return the harmonic mean of the base and novel accuracies.

    Computes ``2 / (1/base_accuracy + 1/novel_accuracy)``, written as
    ``2*b*n / (b + n)`` so that no reciprocal of the inputs is taken.

    Args:
        base_accuracy: accuracy on the base classes (e.g. a fraction in [0, 1]).
        novel_accuracy: accuracy on the novel classes, on the same scale.

    Returns:
        The harmonic mean as a float. When either accuracy is zero (or
        negative) the result is ``0.0``: the harmonic mean tends to zero as
        either input does, and the reciprocal form would otherwise raise
        ``ZeroDivisionError``.
    """
    if base_accuracy <= 0 or novel_accuracy <= 0:
        return 0.0
    return 2 * base_accuracy * novel_accuracy / (base_accuracy + novel_accuracy)


## MetaNetwork: Conditional Token Generator

**Problem:** Fixed prompts don't adapt to each image.

**Solution:** A small neural network that transforms image features into a conditional token.

**Parameters:** ~263K with the default sizes (512×256 + 256 + 256×512 + 512 = 262,912; negligible vs. fine-tuning)

**Effect:** Each image gets a different prompt → instance-level adaptation


In [9]:
"""
MetaNetwork is a small neural network (a 2-layer MLP)
that turns the image features (512-dim) into a conditional
token (512-dim) used by CoCoOp.

This token changes with every image, which allows a
personalised prompt for each input.
"""

class MetaNetwork(nn.Module):
    """Two-layer MLP that maps image features to a conditional prompt token.

    The layout is ``Linear(ctx_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, ctx_dim)``,
    so the output has the same width as the input.
    """

    def __init__(self, ctx_dim=512, hidden_dim=256):
        """
        Args:
            ctx_dim: width of the input features and of the produced token
                (default 512).
            hidden_dim: width of the hidden layer.
        """
        super().__init__()
        self.linear1 = nn.Linear(ctx_dim, hidden_dim)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(hidden_dim, ctx_dim)

    def forward(self, image_features):
        """
        Args:
            image_features: tensor of shape (B, ctx_dim).

        Returns:
            Conditional token tensor of shape (B, ctx_dim), in the dtype of
            this module's weights.
        """
        # Match the input to the layer weights' dtype so that half-precision
        # features can be fed to full-precision weights (and vice versa).
        weight_dtype = self.linear1.weight.dtype
        hidden = self.relu(self.linear1(image_features.to(weight_dtype)))
        return self.linear2(hidden)


## CoCoOpPromptLearner: Dynamic Prompts


**Components:**
1. **V1...VM:** `n_ctx` context vectors (learned via SGD; the code below uses the default `n_ctx=4`)
   - Shape: (n_ctx, 512) tensor
   - Initialized randomly from N(0, 0.02²)
   - Optimized during training

2. **π(x):** Conditional token (generated per image)
   - Shape: (B, 512) from MetaNetwork output
   - Different for each image

3. **[CLASS]:** Class name embedding
   - Shape: (seq_len, 512) from CLIP's token embedding
   - Same for all images of the same class

**Forward Pass:**
- Input: image_features (B, 512)
- Output: prompts (B, num_classes, seq_len_total, 512)


In [10]:
class CoCoOpPromptLearner(nn.Module):
    """Builds image-conditioned soft prompts for every (image, class) pair.

    Each prompt is laid out as::

        [SOT] [ctx_1 .. ctx_n] [cond(image)] [class-name tokens] [EOT] [padding]

    where ``ctx_*`` are learned context vectors shared by all images and
    ``cond(image)`` is the token produced by ``MetaNetwork`` from the image
    features. ``clip.tokenize`` emits ``[SOT] tokens [EOT] padding``, so the
    learned tokens are spliced in right after the start-of-text embedding.
    This keeps the start token at position 0, as in CLIP's own prompts. The
    end-of-text token of class ``c`` ends up at position
    ``class_token_ids[c].argmax() + n_ctx + 1`` (before any trimming).

    The trainable state is ``ctx`` and ``meta_net`` only. The CLIP model is
    needed to embed class names, but it is deliberately NOT registered as a
    sub-module: this keeps ``parameters()``, ``state_dict()`` and ``train()``
    limited to the prompt learner.
    """

    def __init__(self, clip_model, classnames, n_ctx=4):
        """
        Args:
            clip_model: a loaded CLIP model; its ``token_embedding``,
                ``ln_final`` and ``context_length`` are read.
            classnames: list of class-name strings to build prompts for.
            n_ctx: number of learned context vectors.
        """
        super().__init__()
        self.n_ctx = n_ctx
        self.classnames = classnames
        dtype = torch.float32
        ctx_dim = int(clip_model.ln_final.weight.shape[0])
        self.clip_context_length = clip_model.context_length
        print(f"[CoCoOp] ctx_dim={ctx_dim}, max_len={self.clip_context_length}")

        device = next(clip_model.parameters()).device
        # Wrapping the model in a tuple stops nn.Module from registering it as
        # a child module; it stays reachable through the `clip_model` property.
        self._clip_model_ref = (clip_model,)
        self.dtype = dtype

        ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device=device)
        nn.init.normal_(ctx_vectors, std=0.02)
        self.ctx = nn.Parameter(ctx_vectors)

        self.meta_net = MetaNetwork(ctx_dim).to(device=device, dtype=dtype)

        token_ids, token_embeddings = self._embed_classnames(classnames, device)
        self.register_buffer("class_token_embeddings", token_embeddings)
        self.register_buffer("class_token_ids", token_ids)

    @property
    def clip_model(self):
        """The CLIP model passed at construction time (not a registered sub-module)."""
        return self._clip_model_ref[0]

    def _embed_classnames(self, classnames, device):
        """Tokenize ``classnames`` and embed the tokens with CLIP's token embedding.

        Returns:
            ``(token_ids, token_embeddings)``; the embeddings are computed
            without gradient tracking and cast to ``self.dtype``.
        """
        token_ids = clip.tokenize(classnames).to(device)
        with torch.no_grad():
            token_embeddings = self.clip_model.token_embedding(token_ids).to(dtype=self.dtype)
        return token_ids, token_embeddings

    def set_classnames(self, classnames):
        """Replace the internal class token embeddings with a new set.

        This is used at evaluation time to temporarily switch prompts
        to a different class set (e.g., novel classes). Only the two class
        buffers are replaced; ``self.classnames`` is left untouched so that a
        caller restoring the old buffers gets back a consistent object.
        """
        token_ids, token_embeddings = self._embed_classnames(classnames, self.ctx.device)
        # Re-registering an existing buffer name overwrites it, so the
        # state_dict keys stay the same.
        self.register_buffer("class_token_embeddings", token_embeddings)
        self.register_buffer("class_token_ids", token_ids)

    def forward(self, image_features):
        """Generate one prompt per (image, class) pair.

        Args:
            image_features: tensor of shape (B, ctx_dim).

        Returns:
            Float32 tensor of shape (B, num_classes, L, ctx_dim), where
            ``L = min(n_ctx + 1 + seq_len, clip_context_length)`` and
            ``seq_len`` is the tokenized class-name length.
        """
        batch_size = image_features.shape[0]
        num_classes = self.class_token_embeddings.shape[0]

        # All prompt arithmetic is done in float32, whatever dtype the
        # features arrive in.
        image_features = image_features.to(dtype=torch.float32)

        # Image-conditioned token: (B, ctx_dim) -> (B, num_classes, 1, ctx_dim).
        cond_token = self.meta_net(image_features).to(dtype=torch.float32)
        cond = cond_token[:, None, None, :].expand(-1, num_classes, -1, -1)

        # Shared context vectors: (n_ctx, ctx_dim) -> (B, num_classes, n_ctx, ctx_dim).
        # `expand` creates views; the only real allocation is the `cat` below.
        ctx = self.ctx.to(dtype=torch.float32)[None, None, :, :]
        ctx = ctx.expand(batch_size, num_classes, -1, -1)

        # Split every class sequence into its start-of-text embedding and the
        # remainder (class tokens, end-of-text, padding).
        class_embed = self.class_token_embeddings.to(dtype=torch.float32)
        sot = class_embed[None, :, :1, :].expand(batch_size, -1, -1, -1)
        rest = class_embed[None, :, 1:, :].expand(batch_size, -1, -1, -1)

        prompts = torch.cat([sot, ctx, cond, rest], dim=2)

        # Trim trailing padding so the sequence fits CLIP's context length.
        prompts = prompts[:, :, :self.clip_context_length, :]

        return prompts.to(dtype=torch.float32)


## CoCoOpTrainer: Training and Evaluation

Class that manages:

**1. Initialization:**
- Create PromptLearner
- Freeze CLIP (`requires_grad=False`)
- Configure SGD optimizer for prompt learner only

**2. train_epoch():**
- Forward: Image encoder + PromptLearner + Text encoder
- **Critical step:** Encode soft prompts through text transformer
  - Add positional embeddings
  - Pass through CLIP's transformer
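  - Extract the output at a single token position as the text feature (CLIP's own `encode_text` uses the end-of-text position)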
  - Apply final layer norm + projection
- Compute loss: Cross-entropy on base classes
- Backward: Backprop only in PromptLearner
- Return: Average loss of the epoch

**3. eval():**
- Same forward procedure as training
- Without backward pass
- Compute accuracy on any dataset (base or novel)

**Important note:** We don't use `model.encode_text()` on soft prompts
because that method expects integer tokens, not embeddings.
We manually forward through the text transformer.

In [11]:
class CoCoOpTrainer:
    """Trains and evaluates a CoCoOp prompt learner on top of a frozen CLIP model.

    Only the prompt learner (context vectors + meta network) is optimised;
    every CLIP parameter has ``requires_grad`` set to ``False``.
    """

    def __init__(self, clip_model, base_classnames, base_classes,
                 novel_classes, device, lr=0.002):
        """
        Args:
            clip_model: loaded CLIP model (frozen in place by this constructor).
            base_classnames: class-name strings for the base classes.
            base_classes: dataset label ids of the base classes, in the same
                order as ``base_classnames``.
            novel_classes: dataset label ids of the novel classes (stored only).
            device: device used for inputs and for the prompt learner.
            lr: SGD learning rate.
        """
        self.clip_model = clip_model
        self.base_classnames = base_classnames
        self.base_classes = base_classes
        self.novel_classes = novel_classes
        self.device = device

        # Dataset label id -> contiguous index into the prompt/logit dimension.
        self.contig_cat2idx = {cat: idx for idx, cat in enumerate(self.base_classes)}

        # Freeze CLIP
        for p in clip_model.parameters():
            p.requires_grad = False

        # The prompt learner stays in float32 even when CLIP runs in half
        # precision; prompts are cast to CLIP's dtype only at the transformer
        # boundary (see `_encode_prompts`).
        self.prompt_learner = CoCoOpPromptLearner(
            clip_model,
            base_classnames
        ).to(device=device)

        # Hand the optimizer only tensors that can actually receive gradients.
        trainable_params = [p for p in self.prompt_learner.parameters() if p.requires_grad]
        self.optimizer = torch.optim.SGD(
            trainable_params,
            lr=lr,
            momentum=0.9,
            weight_decay=5e-4
        )

    def _make_loader(self, dataset, batch_size, shuffle):
        """Create a single-process DataLoader over ``dataset``."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=False
        )

    def _release_cached_memory(self):
        """Return cached CUDA blocks to the driver; does nothing on other devices."""
        if torch.device(self.device).type == "cuda":
            torch.cuda.empty_cache()

    def _map_labels(self, labels, cat2idx):
        """Translate dataset label ids into contiguous class indices.

        Returns:
            ``(mapped, missing)``: the mapped indices for labels found in
            ``cat2idx`` and the raw label ids that were not found.
        """
        mapped, missing = [], []
        for key in labels.tolist():
            idx = cat2idx.get(key)
            if idx is None:
                missing.append(key)
            else:
                mapped.append(idx)
        return mapped, missing

    def _encode_prompts(self, prompts):
        """Run soft prompts through CLIP's text transformer.

        ``model.encode_text`` cannot be used because it expects integer token
        ids, so positional embedding, transformer, final layer norm and text
        projection are applied manually.

        The feature is read at each class's end-of-text (EOT) position, the
        same position CLIP's ``encode_text`` uses. CLIP's text transformer is
        built with a causal attention mask, so position 0 can attend only to
        itself and would yield the same feature for every class. The EOT index
        is the argmax of the class token ids (EOT is the largest id in CLIP's
        vocabulary) shifted by the ``n_ctx + 1`` learned tokens that the
        prompt learner inserts in front of the class tokens.

        Args:
            prompts: tensor of shape (B, N, L, D).

        Returns:
            L2-normalised float32 text features of shape (B, N, proj_dim).
        """
        B, N, L, D = prompts.shape
        clip_dtype = self.clip_model.dtype

        x = prompts.reshape(B * N, L, D).to(clip_dtype)
        x = x + self.clip_model.positional_embedding[:L].to(clip_dtype)
        x = x.permute(1, 0, 2)  # (L, B*N, D): CLIP's transformer is sequence-first
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # back to (B*N, L, D)

        token_ids = self.prompt_learner.class_token_ids.to(x.device)
        eot = token_ids.argmax(dim=-1) + self.prompt_learner.n_ctx + 1  # (N,)
        eot = eot.clamp(max=L - 1).repeat(B)  # (B*N,), batch-major like the reshape above
        x = x[torch.arange(B * N, device=x.device), eot]  # (B*N, D)

        # LayerNorm acts per token, so normalising after the gather is
        # equivalent to normalising the full sequence first, and cheaper.
        x = self.clip_model.ln_final(x).to(clip_dtype)
        text_feat = (x @ self.clip_model.text_projection.to(clip_dtype)).float()
        text_feat = text_feat.view(B, N, -1)
        # Out-of-place division: the tensor is still needed by autograd.
        return text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-8)

    def _forward_logits(self, images):
        """Compute the (B, N) image-to-class logits for a batch of images."""
        # The image encoder is frozen, so no graph is needed for it.
        with torch.no_grad():
            img_feat = self.clip_model.encode_image(images)
        img_feat = img_feat.float()
        img_feat = img_feat / (img_feat.norm(dim=-1, keepdim=True) + 1e-8)

        prompts = self.prompt_learner(img_feat)
        text_feat = self._encode_prompts(prompts)

        logit_scale = self.clip_model.logit_scale.exp().float()
        # Every image has its own set of class features, hence the
        # per-sample dot product instead of one matrix multiplication.
        return logit_scale * (img_feat.unsqueeze(1) * text_feat).sum(-1)

    def train_epoch(self, train_dataset, batch_size=1):
        """Run one training epoch over ``train_dataset``.

        Args:
            train_dataset: dataset yielding ``(image, label)`` pairs whose
                labels all belong to ``base_classes``.
            batch_size: DataLoader batch size.

        Returns:
            The mean cross-entropy loss over the epoch's batches (0.0 when the
            dataset yields no batch).

        Raises:
            ValueError: if a label is not one of the trainer's base classes.
        """
        self.prompt_learner.train()
        self.clip_model.eval()

        dataloader = self._make_loader(train_dataset, batch_size, shuffle=True)

        total_loss = 0.0
        n_batches = 0

        for images, labels in tqdm(dataloader, desc="CoCoOp"):
            # Validate the labels first so that bad data fails before the
            # expensive forward pass.
            mapped, missing = self._map_labels(labels, self.contig_cat2idx)
            if missing:
                raise ValueError(f"Found labels not in trainer's categories: {missing}.\nAvailable categories: {list(self.contig_cat2idx.keys())}")
            labels_mapped = torch.tensor(mapped, dtype=torch.long, device=self.device)

            images = images.to(self.device)
            logits = self._forward_logits(images)
            loss = F.cross_entropy(logits, labels_mapped)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            # Drop the graph before the next batch allocates its own.
            del logits, loss
            self._release_cached_memory()

        return total_loss / max(1, n_batches)

    @torch.no_grad()
    def eval(self, dataset, categories, batch_size=1, classnames=None):
        """Compute top-1 accuracy of the prompt learner on ``dataset``.

        Args:
            dataset: dataset to evaluate
            categories: list of dataset label ids; position ``i`` corresponds
                to logit column ``i``.
            batch_size: dataloader batch size
            classnames: optional list of class name strings corresponding to `categories`.
                        If provided, the prompt learner will be temporarily switched
                        to use these classnames when building prompts.

        Returns:
            Fraction of correctly classified samples, or 0.0 for an empty
            dataset.

        Raises:
            ValueError: if a label is not in ``categories``, or if the number
                of prompt classes differs from ``len(categories)``.
        """
        self.prompt_learner.eval()
        self.clip_model.eval()

        contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

        # If requested, swap prompt learner class embeddings to match `classnames`.
        swapped = classnames is not None
        if swapped:
            old_emb = self.prompt_learner.class_token_embeddings.clone()
            old_ids = self.prompt_learner.class_token_ids.clone()
            self.prompt_learner.set_classnames(classnames)

        correct = 0
        total = 0

        try:
            dataloader = self._make_loader(dataset, batch_size, shuffle=False)

            for images, labels in tqdm(dataloader, desc="Eval"):
                mapped, missing = self._map_labels(labels, contig_cat2idx)
                if missing:
                    raise ValueError(f"Found labels not in evaluation categories: {missing}.\nProvided categories: {categories}")
                labels_mapped = torch.tensor(mapped, dtype=torch.long, device=self.device)

                images = images.to(self.device)
                logits = self._forward_logits(images)
                if logits.shape[1] != len(categories):
                    raise ValueError(
                        f"Prompt learner produced {logits.shape[1]} classes but "
                        f"{len(categories)} categories were provided."
                    )

                pred = logits.argmax(dim=1)
                correct += (pred == labels_mapped).sum().item()
                total += labels.size(0)

                del logits, pred
                self._release_cached_memory()
        finally:
            # Restore the previous class buffers even when evaluation fails,
            # so the learner is never left pointing at the wrong class set.
            if swapped:
                self.prompt_learner.register_buffer("class_token_embeddings", old_emb)
                self.prompt_learner.register_buffer("class_token_ids", old_ids)

        return correct / total if total else 0.0


## Training CoCoOp

We will train the PromptLearner for **5 epochs** on **base classes only**.

**Hyperparameters:**
- Learning rate: 0.002 (SGD)
- Momentum: 0.9
- Weight decay: 5e-4
- Batch size: 1
- Epochs: 5

**What happens:**
- Context vectors V1...VM adapt to the Flowers102 dataset
- MetaNetwork learns to generate useful conditional tokens
- CLIP remains frozen (unchanged)

**Expected output:**
- Initial loss: ~3.0
- Final loss: ~1.3-1.5
- Training time: ~5-10 minutes on GPU

## How I fixed

| Problem           | Fix                                                  |
| ----------------- | ---------------------------------------------------- |
| Inplace /=        | text_feat = text_feat / (text_feat.norm(...) + 1e-8) |
| Dtype mismatch    | .float() everywhere before the transformer           |
| Memory leak       | del + torch.cuda.empty_cache() each batch            |
| Gradient tracking | NO .detach() in forward pass                        |

In [12]:
base_classnames = [CLASS_NAMES[i] for i in base_classes]
print(f"Base classnames ({len(base_classnames)}): {base_classnames[:5]}...\n")

# Seed PyTorch's global RNG before the trainer is built, so that the random
# context-vector initialisation and the shuffled batch order are reproducible
# across runs.
SEED = 42
torch.manual_seed(SEED)

# Training hyperparameters, gathered in one place so they are easy to tune.
learning_rate = 0.002
train_batch_size = 1  # one image per step keeps memory use low
num_epochs = 5

trainer = CoCoOpTrainer(
    clip_model=model,
    base_classnames=base_classnames,
    base_classes=base_classes,
    novel_classes=novel_classes,
    device=device,
    lr=learning_rate
)

print("\n" + "="*60)
print("Training CoCoOp")
print("="*60)

for epoch in range(num_epochs):
    avg_loss = trainer.train_epoch(train_base, batch_size=train_batch_size)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

print("\n Training completed!")

Base classnames (51): ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold']...

[CoCoOp] ctx_dim=512, max_len=77

Training CoCoOp


CoCoOp:   0%|          | 1/510 [00:18<2:35:56, 18.38s/it]


KeyboardInterrupt: 

## Final Evaluation (CoCoOp only)

We'll evaluate the model with:
1. Test Base
2. Test Novel

Computing Harmonic Mean between them to evaluate the trade-off.


In [None]:
print(f"\n{'=' * 60}")
print("EVALUATION AND COMPARISON (CoCoOp Only)")
print("=" * 60)

# Class-name lists aligned with the base / novel class id lists.
base_classnames = [CLASS_NAMES[class_id] for class_id in base_classes]
novel_classnames = [CLASS_NAMES[class_id] for class_id in novel_classes]

# Evaluate on the base test subset first, then on the novel one; each call
# switches the prompt learner to the matching class names.
base_acc_cocoop = trainer.eval(
    test_base, base_classes, batch_size=64, classnames=base_classnames
)
novel_acc_cocoop = trainer.eval(
    test_novel, novel_classes, batch_size=64, classnames=novel_classnames
)
hm_cocoop = harmonic_mean(base_acc_cocoop, novel_acc_cocoop)

# Report the two accuracies and their harmonic mean as percentages.
print(f"\n{'=' * 60}")
print("CoCoOp RESULTS")
print("=" * 60)

print(
    f"   Base Accuracy:  {100 * base_acc_cocoop:6.2f}%",
    f"   Novel Accuracy: {100 * novel_acc_cocoop:6.2f}%",
    f"   Harmonic Mean:  {100 * hm_cocoop:6.2f}%",
    sep="\n",
)

print(f"\n{'=' * 60}")
