# CLIP with Flowers!?!?!??!?

In [1]:
import os
# Prefer expandable segments to reduce fragmentation (restart kernel after changing)
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

import sys
import torch
import torchvision

# Ensure CLIP is installed in the current kernel; install if missing.
# Using subprocess with sys.executable to target the same Python interpreter.
try:
    import clip
except Exception:
    import subprocess, importlib
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "git+https://github.com/openai/CLIP.git"], stdout=subprocess.DEVNULL)
    importlib.invalidate_caches()
    import clip

from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-amr0i8uo
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-amr0i8uo
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m44.8/44.8 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=381848863840c09469a58534cf4401c02e2cb96411f7de224e12b990a1a1f72c
  Stored in 

In [2]:
# Check environment and CLIP installation
import sys, importlib, torch

try:
    import clip
    print("CLIP: gi√† installato")
except Exception:
    print("CLIP non trovato: eseguo installazione nel kernel corrente...")
    import importlib
    # Preferisci %pip per installare nel kernel Jupyter corrente; fallback a subprocess
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "git+https://github.com/openai/CLIP.git"])
    importlib.invalidate_caches()
    import clip
    print("CLIP installato correttamente nel kernel corrente")

print("python:", sys.executable)
print("torch:", torch.__version__, "cuda_available:", torch.cuda.is_available())
# su mac con Apple Silicon, controlla MPS
try:
    print("mps_available:", torch.backends.mps.is_available())
except Exception:
    pass

print('Riavvia il kernel se necessario, poi esegui questa cella e procedi con l\'allenamento CoCoOp.')


CLIP: gi√† installato
python: /usr/bin/python3
torch: 2.9.0+cu126 cuda_available: True
mps_available: False
Riavvia il kernel se necessario, poi esegui questa cella e procedi con l'allenamento CoCoOp.


## Dataset Functions

We define utility functions for:
- **`get_data()`**: Load Flowers102 from torchvision
- **`base_novel_categories()`**: Split 102 classes into base (0-50) and novel (51-101)
- **`split_data()`**: Filter images for base/novel in each split

This simulates the real scenario: we have 51 seen classes during training (base) and 51 new ones (novel).


In [3]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets."""
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test


def base_novel_categories(dataset):
    """Return base and novel class id lists using the actual labels present
    in the dataset. Prefer public attributes (`targets` then `labels`) and
    only fall back to the dataset private attribute `_labels` if neither is
    available.
    """
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = getattr(dataset, "labels", None)

    if labels is None and hasattr(dataset, "_labels"):
        # FALLBACK: using private dataset internals. Flowers102 exposes
        # `_labels` but this is a private attribute; prefer public attributes
        # above so future datasets remain compatible.
        labels = dataset._labels

    if labels is None:
        raise ValueError("Could not find labels on dataset (checked 'targets','labels','_labels').")

    unique_labels = sorted(set(labels))
    num_classes = len(unique_labels)
    mid = num_classes // 2
    base_classes = unique_labels[:mid]
    novel_classes = unique_labels[mid:]
    return base_classes, novel_classes


def split_data(dataset, base_classes):
    base_categories_samples = []
    novel_categories_samples = []
    base_set = set(base_classes)

    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset

## Class Names and Dataset Loading

We load the names of 102 flower classes from Flowers102.

This is **critical** for CLIP:
- Creates prompts like "a photo of a **rose**, a type of flower"
- Each prompt is encoded by CLIP's text encoder
- Image features are compared against these text templates


In [4]:
_, _, tmp_test = get_data()
base_classes, novel_classes = base_novel_categories(tmp_test)

CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# Uncomment to see class names
# print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
# print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 345M/345M [00:17<00:00, 20.3MB/s] 
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 502/502 [00:00<00:00, 1.92MB/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15.0k/15.0k [00:00<00:00, 42.6MB/s]


In [5]:
# Load CLIP model and preprocessing
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)
print(f"Device: {device}")
print(f"Model: ViT-B/16")

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 335M/335M [00:01<00:00, 325MiB/s]


Device: cuda
Model: ViT-B/16


## Load Flowers102 and Split Base/Novel

We load the 3 splits (train, val, test) and divide into base/novel.

**Statistics:**
- Train Base: 10 images √ó 51 classes = 510 images
- Val Base: 10 images √ó 51 classes = 510 images
- Test Base: ~10 images √ó 51 classes (from test split)
- Test Novel: Remaining (~10 per class)

**Note:** Train and val have ~10 images per class (few-shot setting).


In [6]:
# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

print(f"Train Base: {len(train_base)} samples")
print(f"Val Base: {len(val_base)} samples")
print(f"Test Base: {len(test_base)} samples")
print(f"Test Novel: {len(test_novel)} samples")

Train Base: 510 samples
Val Base: 510 samples
Test Base: 2473 samples
Test Novel: 3676 samples


## Harmonic Mean (HM)

Standard metric for few-shot adaptation papers.

Formula: HM = 2 / (1/base_acc + 1/novel_acc)

**Why HM instead of arithmetic mean?**
- HM heavily penalizes outliers
- If base=90% and novel=50%: arithmetic mean=70%, HM=64.3%
- Forces the model to balance both accuracies

**Obiettivo:** massimizzare l'HM tra `base_acc_cocoop` e `novel_acc_cocoop`.


In [7]:
def harmonic_mean(base_accuracy, novel_accuracy):
    # Guard against zero to avoid division-by-zero errors
    if base_accuracy <= 0 or novel_accuracy <= 0:
        return 0.0
    numerator = 2.0
    denominator = 1.0 / base_accuracy + 1.0 / novel_accuracy
    return numerator / denominator


## MetaNetwork: Conditional Token Generator

**Problem:** Fixed prompts don't adapt to each image.

**Solution:** A small neural network that transforms image features into a conditional token.

**Parameters:** ~256K (negligible vs. fine-tuning)

**Effect:** Each image gets a different prompt ‚Üí instance-level adaptation


In [8]:
"""
MetaNetwork √® una piccola rete neurale (MLP con 2 layer)
che trasforma le image_features (512-dim) in un token
condizionale (512-dim) usato in CoCoOp.

Questo token varia per ogni immagine, permettendo prompt
personalizzati per ogni input.
"""

class MetaNetwork(nn.Module):
    def __init__(self, ctx_dim=512, hidden_dim=256):
        """
        Args:
            ctx_dim: dimensione degli embeddings (512 per ViT-B/16)
            hidden_dim: dimensione dello strato nascosto
        """
        super().__init__()
        self.linear1 = nn.Linear(ctx_dim, hidden_dim)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(hidden_dim, ctx_dim)

    def forward(self, image_features):
        """
        Args:
            image_features: tensor (B, ctx_dim) dalle immagini encodate

        Returns:
            conditional_token: tensor (B, ctx_dim)
        """
        # Assicura il tipo corretto (importante per mixed precision)
        image_features = image_features.to(self.linear1.weight.dtype)

        out = self.linear1(image_features)
        out = self.relu(out)
        out = self.linear2(out)
        return out


## CoCoOpPromptLearner: Dynamic Prompts


**Components:**
1. **V1...VM:** 16 context vectors (learned via SGD)
   - Shape: (16, 512) tensors
   - Initialized randomly from N(0, 0.02¬≤)
   - Optimized during training

2. **œÄ(x):** Conditional token (generated per image)
   - Shape: (B, 512) from MetaNetwork output
   - Different for each image

3. **[CLASS]:** Class name embedding
   - Shape: (seq_len, 512) from CLIP's token embedding
   - Same for all images of the same class

**Forward Pass:**
- Input: image_features (B, 512)
- Output: prompts (B, num_classes, seq_len_total, 512)


In [15]:
class CoCoOpPromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=16, ctx_dim=512):
        super().__init__()
        self.n_ctx = n_ctx
        self.ctx_dim = ctx_dim
        
        # ‚úÖ LEARNABLE CONTEXT
        self.ctx = nn.Parameter(
            torch.randn(n_ctx, ctx_dim).float() * 0.02
        )
        
        # ‚úÖ META NETWORK - CONDIZIONA I CONTEXT
        self.meta_net = nn.Sequential(
            nn.Linear(ctx_dim, ctx_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(ctx_dim // 2, ctx_dim)
        )
        
        # ‚úÖ CLASS EMBEDDINGS (congelati)
        with torch.no_grad():
            class_tokens = clip_model.tokenize(classnames).to(clip_model.device)
            class_emb = clip_model.token_embedding(class_tokens).type(clip_model.dtype)
        
        self.register_buffer("class_token_embeddings", class_emb)
        
        self.clip_context_length = 77
    
    def forward(self, image_features):
        """
        ‚úÖ CORRECTED FORWARD
        
        image_features: (B, 512)
        Output: (B, num_classes, 77, 512)
        """
        batch_size = image_features.shape[0]
        num_classes = self.class_token_embeddings.shape[0]
        
        # ‚úÖ Step 1: Meta network GENERA condizionamenti dai features
        cond = self.meta_net(image_features)  # (B, 512)
        
        # ‚úÖ Step 2: Context PARAMETRI
        ctx = self.ctx.unsqueeze(0)  # (1, n_ctx, 512)
        ctx = ctx.repeat(batch_size, 1, 1)  # (B, n_ctx, 512)
        
        # ‚úÖ Step 3: MODULA il context usando meta_net output
        # Questo √® il VERO differenziale con il context fisso!
        ctx = ctx + cond.unsqueeze(1) * 0.1  # (B, n_ctx, 512)
        
        # ‚úÖ Step 4: Class embeddings
        class_embed = self.class_token_embeddings.unsqueeze(0)  # (1, num_classes, L, 512)
        class_embed = class_embed.repeat(batch_size, 1, 1, 1)  # (B, num_classes, L, 512)
        
        # ‚úÖ Step 5: Build prompts per ogni classe
        prompts_list = []
        for i in range(num_classes):
            # ctx: (B, n_ctx, 512)
            # class_embed[:, i, :, :]: (B, L, 512)
            prompt_i = torch.cat([
                ctx,  # (B, n_ctx, 512)
                class_embed[:, i, :, :]  # (B, L, 512)
            ], dim=1)  # (B, n_ctx + L, 512)
            
            # Trim a 77 tokens max (CLIP length)
            prompt_i = prompt_i[:, :self.clip_context_length, :]  # (B, 77, 512)
            prompts_list.append(prompt_i)
        
        # Stack: (B, num_classes, 77, 512)
        prompts = torch.stack(prompts_list, dim=1)
        
        return prompts.to(dtype=torch.float32)


## CoCoOpTrainer: Training and Evaluation

Class that manages:

**1. Initialization:**
- Create PromptLearner
- Freeze CLIP (`requires_grad=False`)
- Configure SGD optimizer for prompt learner only

**2. train_epoch():**
- Forward: Image encoder + PromptLearner + Text encoder
- **Critical step:** Encode soft prompts through text transformer
  - Add positional embeddings
  - Pass through CLIP's transformer
  - Extract first token
  - Apply final layer norm + projection
- Compute loss: Cross-entropy on base classes
- Backward: Backprop only in PromptLearner
- Return: Average loss of the epoch

**3. eval():**
- Same forward procedure as training
- Without backward pass
- Compute accuracy on any dataset (base or novel)

**Important note:** We don't use `model.encode_text()` on soft prompts
because that method expects integer tokens, not embeddings.
We manually forward through the text transformer.

In [None]:
class CoCoOpTrainer:
    def __init__(self, clip_model, base_classnames, base_classes,
                 novel_classes, device, lr=0.002):
        """Trainer ultra-light per memoria limitata"""
        # ‚ö†Ô∏è FORCE CLIP to float32 BEFORE doing anything else!
        self.clip_model = clip_model.float()

        self.base_classnames = base_classnames
        self.base_classes = base_classes
        self.novel_classes = novel_classes
        self.device = device

        # Contig mapping
        self.contig_cat2idx = {cat: idx for idx, cat in enumerate(self.base_classes)}

        # Freeze CLIP
        for p in self.clip_model.parameters():
            p.requires_grad = False

        # Prompt learner - use float32
        self.prompt_learner = CoCoOpPromptLearner(
            self.clip_model,
            base_classnames
        ).to(device=device, dtype=torch.float32)

        # Optimizer
        self.optimizer = torch.optim.SGD(
            self.prompt_learner.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=5e-4
        )

    def train_epoch(self, train_dataset, batch_size=1):
        """TRAINING LOOP CORRETTO"""
        self.prompt_learner.train()
        self.clip_model.eval()
        
        dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=False
        )
        
        total_loss = 0
        n_batches = 0
        
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader, desc="CoCoOp Training")):
            torch.cuda.empty_cache()
            
            images = images.to(self.device).float()
            labels = labels.to(self.device)
            
            # ‚úÖ Step 1: Encode images (congelato)
            with torch.no_grad():
                img_feat = self.clip_model.encode_image(images)  # (B, 512)
            img_feat = img_feat.float()
            
            # ‚úÖ Step 2: Generate adaptive prompts using meta_net
            prompts = self.prompt_learner(img_feat)  # (B, num_classes, 77, 512)
            B, N, L, D = prompts.shape
            prompts_flat = prompts.view(B * N, L, D)  # (B*N, 77, 512)
            
            # ‚úÖ Step 3: Pass through CLIP text transformer
            pos_emb = self.clip_model.positional_embedding[:L].float()
            x = prompts_flat.float()
            x = x + pos_emb
            x = x.permute(1, 0, 2).float()  # (77, B*N, 512)
            
            x = self.clip_model.transformer(x)
            x = x.permute(1, 0, 2).float()  # (B*N, 77, 512)
            x = x[:, 0, :].contiguous()  # (B*N, 512) - [CLS] token
            
            x = self.clip_model.ln_final(x.float())  # (B*N, 512)
            text_feat = x.view(B, N, -1)  # (B, N, 512)
            
            # ‚úÖ Step 4: Compute logits (similarity)
            text_feat_norm = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-8)
            img_feat_norm = img_feat / (img_feat.norm(dim=-1, keepdim=True) + 1e-8)
            
            logit_scale = self.clip_model.logit_scale.exp()
            logits = logit_scale * torch.matmul(
                img_feat_norm.unsqueeze(1),  # (B, 1, 512)
                text_feat_norm.permute(0, 2, 1)  # (B, 512, N)
            ).squeeze(1)  # (B, N)
            
            # ‚úÖ Step 5: Compute loss
            labels_mapped = torch.tensor(
                [self.contig_cat2idx[l.item()] for l in labels],
                dtype=torch.long,
                device=self.device
            )
            
            loss = F.cross_entropy(logits, labels_mapped)
            
            # ‚úÖ Step 6: Backward
            self.optimizer.zero_grad()
            loss.backward()
            
            # üîç DEBUG
            if batch_idx == 0:
                print("\n" + "="*70)
                print("GRADIENT CHECK:")
                for name, param in self.prompt_learner.named_parameters():
                    if param.grad is not None:
                        grad_norm = param.grad.norm().item()
                        print(f"{name:30s} | grad_norm: {grad_norm:.8f}")
                print("="*70 + "\n")
            
            self.optimizer.step()
            
            total_loss += loss.item()
            n_batches += 1
            
            del prompts, prompts_flat, x, logits, text_feat, img_feat
            torch.cuda.empty_cache()
        
        return total_loss / max(1, n_batches)


    @torch.no_grad()
    def eval(self, dataset, categories, batch_size=1, classnames=None):
        """Evaluation - IDENTICO AL TRAINING (ma senza backward)"""
        self.prompt_learner.eval()
        self.clip_model.eval()
        
        contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
        
        old_emb = None
        if classnames is not None:
            old_emb = self.prompt_learner.class_token_embeddings.clone()
            self.prompt_learner.class_token_embeddings = nn.Parameter(
                self.clip_model.token_embedding(
                    self.clip_model.tokenize(classnames).to(self.device)
                ).type(self.clip_model.dtype)
            )
        
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=False
        )
        
        correct = 0
        total = 0
        
        for images, labels in tqdm(dataloader, desc="Eval"):
            torch.cuda.empty_cache()
            
            images = images.to(self.device).float()
            labels = labels.to(self.device)
            
            # ‚úÖ IDENTICO AL TRAINING
            img_feat = self.clip_model.encode_image(images).float()  # (B, 512)
            
            # ‚úÖ Meta_net USATO anche in eval
            prompts = self.prompt_learner(img_feat)  # (B, num_classes, 77, 512)
            B, N, L, D = prompts.shape
            prompts_flat = prompts.view(B * N, L, D)
            
            pos_emb = self.clip_model.positional_embedding[:L].float()
            x = prompts_flat.float()
            x = x + pos_emb
            x = x.permute(1, 0, 2).float()
            
            x = self.clip_model.transformer(x)
            x = x.permute(1, 0, 2).float()
            x = x[:, 0, :].contiguous()
            
            x = self.clip_model.ln_final(x.float())
            text_feat = x.view(B, N, -1)  # (B, N, 512)
            
            # ‚úÖ IDENTICO AL TRAINING
            text_feat_norm = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-8)
            img_feat_norm = img_feat / (img_feat.norm(dim=-1, keepdim=True) + 1e-8)
            
            logit_scale = self.clip_model.logit_scale.exp()
            logits = logit_scale * torch.matmul(
                img_feat_norm.unsqueeze(1),  # (B, 1, 512)
                text_feat_norm.permute(0, 2, 1)  # (B, 512, N)
            ).squeeze(1)  # (B, N)
            
            pred = logits.argmax(dim=1)
            
            labels_mapped = torch.tensor(
                [contig_cat2idx[l.item()] for l in labels],
                dtype=torch.long,
                device=self.device
            )
            
            correct += (pred == labels_mapped).sum().item()
            total += labels.size(0)
            
            del prompts, prompts_flat, x, logits, text_feat, img_feat
            torch.cuda.empty_cache()
        
        if classnames is not None and old_emb is not None:
            self.prompt_learner.class_token_embeddings = old_emb
        
        return correct / total if total > 0 else 0.0


In [None]:
"""debug

# üîç TEST DIAGNOSTICO - Esegui una volta sola PRIMA del training
base_classnames = [CLASS_NAMES[i] for i in base_classes]

trainer = CoCoOpTrainer(
    clip_model=model,
    base_classnames=base_classnames,
    base_classes=base_classes,
    novel_classes=novel_classes,
    device=device,
    lr=0.002,  # more stable LR
)

# Test su UN SINGOLO BATCH
test_loader = torch.utils.data.DataLoader(train_base, batch_size=4, shuffle=False)
images, labels = next(iter(test_loader))

images = images.to(device).float()
labels = labels.to(device)

print("="*70)
print("TEST 1: I parametri sono TRAINABILI?")
print("="*70)

for name, param in trainer.prompt_learner.named_parameters():
    print(f"{name:40s} | requires_grad: {param.requires_grad} | shape: {param.shape}")

print("\n" + "="*70)
print("TEST 2: Forward pass funziona?")
print("="*70)

with torch.no_grad():
    img_feat = model.encode_image(images)
img_feat = img_feat.float()
print(f"img_feat shape: {img_feat.shape}, dtype: {img_feat.dtype}")

prompts = trainer.prompt_learner(img_feat)
print(f"prompts shape: {prompts.shape}, dtype: {prompts.dtype}")

B, N, L, D = prompts.shape
prompts_flat = prompts.view(B * N, L, D)
print(f"prompts_flat shape: {prompts_flat.shape}")

pos_emb = model.positional_embedding[:L].float()
x = prompts_flat.float()
x = x + pos_emb
x = x.permute(1, 0, 2).float()

print(f"x before transformer: {x.shape}, dtype: {x.dtype}")

x = model.transformer(x)
x = x.permute(1, 0, 2).float()
x = x[:, 0, :].contiguous()

print(f"x after transformer: {x.shape}, dtype: {x.dtype}")

x = model.ln_final(x.float())
text_feat = x.view(B, N, -1)

print(f"text_feat: {text_feat.shape}, dtype: {text_feat.dtype}")

print("\n" + "="*70)
print("TEST 3: Logits computation")
print("="*70)

text_feat_norm = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-8)
img_feat_norm = img_feat / (img_feat.norm(dim=-1, keepdim=True) + 1e-8)

print(f"text_feat_norm min: {text_feat_norm.min().item():.4f}, max: {text_feat_norm.max().item():.4f}")
print(f"img_feat_norm min: {img_feat_norm.min().item():.4f}, max: {img_feat_norm.max().item():.4f}")

logit_scale = model.logit_scale.exp()
print(f"logit_scale: {logit_scale.item():.4f}")

# Vecchio modo (SBAGLIATO)
logits_old = logit_scale * (img_feat.unsqueeze(1) * text_feat).sum(-1)
print(f"logits_old shape: {logits_old.shape}, min: {logits_old.min().item():.4f}, max: {logits_old.max().item():.4f}")

# Nuovo modo (CORRETTO)
logits_new = logit_scale * torch.matmul(
    img_feat_norm.unsqueeze(1),
    text_feat_norm.permute(0, 2, 1)
).squeeze(1)
print(f"logits_new shape: {logits_new.shape}, min: {logits_new.min().item():.4f}, max: {logits_new.max().item():.4f}")

print("\n" + "="*70)
print("TEST 4: Loss computation")
print("="*70)

labels_mapped = torch.tensor(
    [trainer.contig_cat2idx[l.item()] for l in labels],
    dtype=torch.long,
    device=device
)

loss_old = torch.nn.functional.cross_entropy(logits_old, labels_mapped)
loss_new = torch.nn.functional.cross_entropy(logits_new, labels_mapped)

print(f"loss_old: {loss_old.item():.4f}")
print(f"loss_new: {loss_new.item():.4f}")
print(f"logits_old requires_grad: {logits_old.requires_grad}")
print(f"logits_new requires_grad: {logits_new.requires_grad}")

print("\n" + "="*70)
print("TEST 5: Gradienti nel training")
print("="*70)

# Resetta gli optimizer
trainer.optimizer.zero_grad()

# Backward
loss_new.backward()

# Controlla i gradienti
print("Gradienti dopo backward():")
for name, param in trainer.prompt_learner.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm().item()
        print(f"{name:40s} | grad_norm: {grad_norm:.8f}")
    else:
        print(f"{name:40s} | grad: None (‚ùå PROBLEMA!)")

print("\n" + "="*70)
print("TEST 6: Aggiorna i parametri")
print("="*70)

# Salva i vecchi valori
old_ctx = trainer.prompt_learner.ctx.clone()

# Step
trainer.optimizer.step()

# Controlla se sono cambiati
new_ctx = trainer.prompt_learner.ctx.clone()
changed = not torch.allclose(old_ctx, new_ctx)
max_change = (new_ctx - old_ctx).abs().max().item()

print(f"ctx √® cambiato: {changed}")
print(f"Max change in ctx: {max_change:.8f}")

if changed:
    print("‚úÖ IL TRAINING FUNZIONA!")
else:
    print("‚ùå I PARAMETRI NON SI STANNO AGGIORNANDO!")
    
"""

## Training CoCoOp

We will train the PromptLearner for **5 epochs** on **base classes only**.

**Hyperparameters:**
- Learning rate: 0.002 (SGD)
- Momentum: 0.9
- Weight decay: 5e-4
- Batch size: 1
- Epochs: 5

**What happens:**
- Context vectors V1...VM adapt to the Flowers102 dataset
- MetaNetwork learns to generate useful conditional tokens
- CLIP remains frozen (unchanged)

**Expected output:**
- Initial loss: ~3.0
- Final loss: ~1.3-1.5
- Training time: ~5-10 minutes on GPU

## How I fixed

| Problem           | Fix                                                  |
| ----------------- | ---------------------------------------------------- |
| Inplace /=        | text_feat = text_feat / (text_feat.norm(...) + 1e-8) |
| Dtype mismatch    | .float() everywhere prima del transformer               |
| Memory leak       | del + torch.cuda.empty_cache() each batch            |
| Gradient tracking | NO .detach() in forward pass                        |

there were ther problems with gpu runtime that have been fixed by

| What            | Before                 | After                                    |
| --------------- | ---------------------- | ----------------------------------------- |
| __init__        | N/A                    | self.clip_model = self.clip_model.float() |
| images.to()     | images.to(device)      | images.to(device).float()                 |
| Prompt learner  | dtype=clip_model.dtype | dtype=torch.float32                       |
| x = x.permute() | No .float()            | x = x.permute(...).float()                |

In [None]:
base_classnames = [CLASS_NAMES[i] for i in base_classes]
print(f"Base classnames ({len(base_classnames)}): {base_classnames[:5]}...\n")

trainer = CoCoOpTrainer(
    clip_model=model,
    base_classnames=base_classnames,
    base_classes=base_classes,
    novel_classes=novel_classes,
    device=device,
    lr=0.002,  # more stable LR
)

# Scheduler (cosine) - step outside the epoch loop
trainer.scheduler = CosineAnnealingLR(trainer.optimizer, T_max=20)

print("\n" + "="*60)
print("Training CoCoOp") #better hyperparams
print("="*60 + "\n")

# Reproducibility
import random
torch.manual_seed(42)
random.seed(42)

num_epochs = 5 #should they be 20? -> fewer epochs for time
for epoch in range(num_epochs):
    avg_loss = trainer.train_epoch(train_base, batch_size=4)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")
    # step scheduler per epoch
    try:
        trainer.scheduler.step()
    except Exception:
        pass

print("\n Training completed!")


Base classnames (51): ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold']...

[CoCoOp] ctx_dim=512, n_ctx=16, max_len=77

Training CoCoOp



CoCoOp: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 128/128 [02:28<00:00,  1.16s/it]


Epoch 1/5 - Loss: 3.9318


CoCoOp: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 128/128 [02:27<00:00,  1.15s/it]


Epoch 2/5 - Loss: 3.9318


CoCoOp: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 128/128 [02:27<00:00,  1.15s/it]


Epoch 3/5 - Loss: 3.9318


CoCoOp: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 128/128 [02:27<00:00,  1.15s/it]


Epoch 4/5 - Loss: 3.9318


CoCoOp:  24%|‚ñà‚ñà‚ñç       | 31/128 [00:35<01:51,  1.15s/it]

## Final Evaluation (CoCoOp only)

We'll evaluate the model with:
1. Test Base
2. Test Novel

Computing Harmonic Mean between them to evaluate the trade-off.


In [None]:
# Prepare novel class names
novel_classnames = [CLASS_NAMES[i] for i in novel_classes]

print(f"Base classnames ({len(base_classnames)}): {base_classnames[:3]}...")
print(f"Novel classnames ({len(novel_classnames)}): {novel_classnames[:3]}...")

Base classnames (51): ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells']...
Novel classnames (51): ['wild pansy', 'primula', 'sunflower']...


In [None]:
# Preparation: Get novel classnames
novel_classnames = [CLASS_NAMES[i] for i in novel_classes]

print(f"Base classnames ({len(base_classnames)}): {base_classnames[:3]}...")
print(f"Novel classnames ({len(novel_classnames)}): {novel_classnames[:3]}...")

print("\n" + "="*60)
print("EVALUATION")
print("="*60)

# ‚ö†Ô∏è CORRECTED EVALUATION
base_acc_cocoop = trainer.eval(test_base, base_classes, batch_size=64, classnames=base_classnames)
novel_acc_cocoop = trainer.eval(test_novel, novel_classes, batch_size=64, classnames=novel_classnames)
hm_cocoop = harmonic_mean(base_acc_cocoop, novel_acc_cocoop)

print("\n" + "="*60)
print("CoCoOp RESULTS")
print("="*60 + "\n")

print(f" Base Accuracy:  {base_acc_cocoop*100:6.2f}%")
print(f" Novel Accuracy: {novel_acc_cocoop*100:6.2f}%")
print(f" Harmonic Mean:  {hm_cocoop*100:6.2f}%")

print("\n" + "="*60)

Base classnames (51): ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells']...
Novel classnames (51): ['wild pansy', 'primula', 'sunflower']...

EVALUATION


Eval: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 39/39 [05:12<00:00,  8.02s/it]
Eval: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 58/58 [07:45<00:00,  8.03s/it]


CoCoOp RESULTS

 Base Accuracy:    0.81%
 Novel Accuracy:   1.77%
 Harmonic Mean:    1.11%




