# CLIP with Flowers!?!?!??!?

In [18]:
import os
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

import sys
import torch
import torchvision
import random
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import OrderedDict

In [19]:
try:
    import clip
    print("✓ CLIP already installed")
except Exception:
    print("Installing CLIP...")
    import subprocess, importlib
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", 
                              "git+https://github.com/openai/CLIP.git"])
    importlib.invalidate_caches()
    import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
print(f"PyTorch: {torch.__version__}")

✓ CLIP already installed
Device: cuda
PyTorch: 2.9.0+cu126


## Dataset Functions

We define utility functions for:
- **`get_data()`**: Load Flowers102 from torchvision
- **`base_novel_categories()`**: Split 102 classes into base (0-50) and novel (51-101)
- **`split_data()`**: Filter images for base/novel in each split

This simulates the real scenario: we have 51 seen classes during training (base) and 51 new ones (novel).


In [20]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets."""
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test


def base_novel_categories(dataset):
    """Return base and novel class id lists using the actual labels present
    in the dataset. Prefer public attributes (`targets` then `labels`) and
    only fall back to the dataset private attribute `_labels` if neither is
    available.
    """
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = getattr(dataset, "labels", None)

    if labels is None and hasattr(dataset, "_labels"):
        # FALLBACK: using private dataset internals. Flowers102 exposes
        # `_labels` but this is a private attribute; prefer public attributes
        # above so future datasets remain compatible.
        labels = dataset._labels

    if labels is None:
        raise ValueError("Could not find labels on dataset (checked 'targets','labels','_labels').")

    unique_labels = sorted(set(labels))
    num_classes = len(unique_labels)
    mid = num_classes // 2
    base_classes = unique_labels[:mid]
    novel_classes = unique_labels[mid:]
    return base_classes, novel_classes


def split_data(dataset, base_classes):
    base_categories_samples = []
    novel_categories_samples = []
    base_set = set(base_classes)

    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset

## Class Names and Dataset Loading

We load the names of 102 flower classes from Flowers102.

This is **critical** for CLIP:
- Creates prompts like "a photo of a **rose**, a type of flower"
- Each prompt is encoded by CLIP's text encoder
- Image features are compared against these text templates


In [21]:
_, _, tmp_test = get_data()
base_classes, novel_classes = base_novel_categories(tmp_test)

CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# Uncomment to see class names
# print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
# print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

In [22]:
# Load CLIP model and preprocessing
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)
print(f"Device: {device}")
print(f"Model: ViT-B/16")

Device: cuda
Model: ViT-B/16


## Load Flowers102 and Split Base/Novel

We load the 3 splits (train, val, test) and divide into base/novel.

**Statistics:**
- Train Base: 10 images × 51 classes = 510 images
- Val Base: 10 images × 51 classes = 510 images
- Test Base: ~10 images × 51 classes (from test split)
- Test Novel: Remaining (~10 per class)

**Note:** Train and val have ~10 images per class (few-shot setting).


In [23]:
# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# Few-shot: sample `shots_per_class` images per base class from the train split
shots_per_class = 16
import random
random.seed(42)

# Collect indices per class in the original train_set
indices_per_class = {c: [] for c in base_classes}
for idx, label in enumerate(train_set._labels):
    if label in indices_per_class:
        indices_per_class[label].append(idx)

selected = []
for c in base_classes:
    inds = indices_per_class.get(c, [])
    random.shuffle(inds)
    # take up to shots_per_class (if fewer available, take all)
    selected.extend(inds[:shots_per_class])

# Create the few-shot training subset
train_base = torch.utils.data.Subset(train_set, selected)

# validation and test splits remain full (or filtered by base classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

print(f"Train Base (few-shot): {len(train_base)} samples ({shots_per_class} shots per class)")
print(f"Val Base: {len(val_base)} samples")
print(f"Test Base: {len(test_base)} samples")
print(f"Test Novel: {len(test_novel)} samples")

Train Base (few-shot): 510 samples (16 shots per class)
Val Base: 510 samples
Test Base: 2473 samples
Test Novel: 3676 samples


## Harmonic Mean (HM)

Standard metric for few-shot adaptation papers.

Formula: HM = 2 / (1/base_acc + 1/novel_acc)

**Why HM instead of arithmetic mean?**
- HM heavily penalizes outliers
- If base=90% and novel=50%: arithmetic mean=70%, HM=64.3%
- Forces the model to balance both accuracies

**Obiettivo:** massimizzare l'HM tra `base_acc_cocoop` e `novel_acc_cocoop`.


In [24]:
def harmonic_mean(base_accuracy, novel_accuracy):
    # Guard against zero to avoid division-by-zero errors
    if base_accuracy <= 0 or novel_accuracy <= 0:
        return 0.0
    numerator = 2.0
    denominator = 1.0 / base_accuracy + 1.0 / novel_accuracy
    return numerator / denominator


## Text Encoder

In [25]:
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.ln_final(x).type(self.dtype)
        x = x[torch.arange(int(x.shape[0])), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x

## CoCoOpPromptLearner: Dynamic Prompts


**Components:**
1. **V1...VM:** 16 context vectors (learned via SGD)
   - Shape: (16, 512) tensors
   - Initialized randomly from N(0, 0.02²)
   - Optimized during training

2. **π(x):** Conditional token (generated per image)
   - Shape: (B, 512) from MetaNetwork output
   - Different for each image

3. **[CLASS]:** Class name embedding
   - Shape: (seq_len, 512) from CLIP's token embedding
   - Same for all images of the same class

**Forward Pass:**
- Input: image_features (B, 512)
- Output: prompts (B, num_classes, seq_len_total, 512)


In [26]:
class PromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=4, ctx_init=None, device='cuda'):
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(torch.float32)
            ctx_vectors = embedding[0, 1:1+n_ctx, :]
            prompt_prefix = ctx_init
        else:
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=torch.float32)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)
        
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words: {n_ctx}")
        self.ctx = nn.Parameter(ctx_vectors)
        
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
        ]))
        
        classnames = [name.replace("_", " ") for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        tokenized_prompts = tokenized_prompts.to(device)  # FIX: Move to device
        
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(torch.float32)
        
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1+n_ctx:, :])
        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
    
    def construct_prompts(self, ctx, prefix, suffix, label=None):
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]
        prompts = torch.cat([prefix, ctx, suffix], dim=1)
        return prompts
    
    def forward(self, im_features):
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx.unsqueeze(0)
        bias = self.meta_net(im_features)
        bias = bias.unsqueeze(1)
        ctx_shifted = ctx + bias
        prompts = []
        for ctx_shifted_i in ctx_shifted:
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
            pts_i = self.construct_prompts(ctx_i, prefix, suffix)
            prompts.append(pts_i)
        prompts = torch.stack(prompts)
        return prompts

## CoCoOpTrainer: Training and Evaluation

Class that manages:

**1. Initialization:**
- Create PromptLearner
- Freeze CLIP (`requires_grad=False`)
- Configure SGD optimizer for prompt learner only

**2. train_epoch():**
- Forward: Image encoder + PromptLearner + Text encoder
- **Critical step:** Encode soft prompts through text transformer
  - Add positional embeddings
  - Pass through CLIP's transformer
  - Extract first token
  - Apply final layer norm + projection
- Compute loss: Cross-entropy on base classes
- Backward: Backprop only in PromptLearner
- Return: Average loss of the epoch

**3. eval():**
- Same forward procedure as training
- Without backward pass
- Compute accuracy on any dataset (base or novel)

**Important note:** We don't use `model.encode_text()` on soft prompts
because that method expects integer tokens, not embeddings.
We manually forward through the text transformer.

In [27]:
class CustomCLIP(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=4, ctx_init=None, device='cuda'):
        super().__init__()
        self.prompt_learner = PromptLearner(clip_model, classnames, n_ctx=n_ctx, ctx_init=ctx_init, device=device)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
    
    def forward(self, image, label=None): #now it's parallel -> more efficient
        # encode images
        logit_scale = self.logit_scale.exp()
        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # generate instance-conditioned prompts: (batch, n_cls, n_tokens, dim)
        prompts = self.prompt_learner(image_features)
        batch_size = int(prompts.shape[0])
        n_cls = int(self.prompt_learner.n_cls)

        # Flatten prompts to (batch*n_cls, n_tokens, dim) for parallel encoding
        n_tokens = prompts.shape[2]
        dim = prompts.shape[3]
        prompts_flat = prompts.reshape(batch_size * n_cls, n_tokens, dim).type(self.dtype)

        # Repeat tokenized prompts for each image in the batch
        tokenized = self.tokenized_prompts.to(prompts_flat.device)
        tokenized_expanded = tokenized.repeat(batch_size, 1)

        # Encode all prompts in parallel and reshape back: (batch, n_cls, dim)
        text_features = self.text_encoder(prompts_flat, tokenized_expanded)
        text_features = text_features.reshape(batch_size, n_cls, -1)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Compute logits: (batch, n_cls)
        image_features_expanded = image_features.unsqueeze(1)
        logits = logit_scale * (image_features_expanded @ text_features.transpose(1, 2)).squeeze(1)

        if label is not None:
            return F.cross_entropy(logits, label)
        return logits

## Training CoCoOp

We will train the PromptLearner for **10 epochs** on **base classes only**.

**Hyperparameters:**
- Learning rate: 0.002 (SGD)
- Momentum: 0.9
- Weight decay: 5e-4
- Batch size: 1
- Epochs: 10

**What happens:**
- Context vectors V1...VM adapt to the Flowers102 dataset
- MetaNetwork learns to generate useful conditional tokens
- CLIP remains frozen (unchanged)

**Expected output:**
- Initial loss: ~3.0
- Final loss: ~1.3-1.5
- Training time: ~5-10 minutes on GPU

In [28]:
class CoCoOpTrainer:
    def __init__(self, clip_model, classnames, base_classes, novel_classes, 
                 device='cuda', lr=0.002, n_ctx=4, num_epochs=10):
        
        self.clip_model = clip_model.float()
        self.classnames = classnames
        self.base_classes = base_classes
        self.device = device
        self.num_epochs = num_epochs
        self.contig_cat2idx = {cat: idx for idx, cat in enumerate(self.base_classes)}
        
        for p in self.clip_model.parameters():
            p.requires_grad = False
        
        self.model = CustomCLIP(self.clip_model, classnames, n_ctx=n_ctx, device=device).to(device)
        self.optimizer = torch.optim.SGD(self.model.prompt_learner.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs)
        
        trainable = sum(p.numel() for p in self.model.prompt_learner.parameters())
        print(f"\nCoCoOpTrainer initialized: {trainable:,} trainable params\n")
    
    def train_epoch(self, train_dataset, batch_size=4):
        self.model.train()
        self.clip_model.eval()
        
        dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False
        )
        
        total_loss = 0
        n_batches = 0
        
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader, desc="Training")):
            images = images.to(self.device).float()
            labels = labels.to(self.device)
            
            labels_mapped = torch.tensor(
                [self.contig_cat2idx[l.item()] for l in labels],
                dtype=torch.long,
                device=self.device
            )
            
            loss = self.model(images, labels_mapped)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            n_batches += 1
        
        self.scheduler.step()
        return total_loss / max(1, n_batches)
    
    @torch.no_grad()
    def eval(self, dataset, categories, batch_size=64):
        self.model.eval()
        self.clip_model.eval()
        
        contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False
        )
        
        correct = 0
        total = 0
        
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(self.device).float()
            labels = labels.to(self.device)
            
            logits = self.model(images)
            pred = logits.argmax(dim=1)
            
            labels_mapped = torch.tensor(
                [contig_cat2idx[l.item()] for l in labels],
                dtype=torch.long,
                device=self.device
            )
            
            correct += (pred == labels_mapped).sum().item()
            total += labels.size(0)
        
        return correct / total if total > 0 else 0.0

In [29]:
#Loading data and model
model, preprocess = clip.load("ViT-B/16", device=device)

train_set, val_set, test_set = get_data(transform=preprocess)
base_classes, novel_classes = base_novel_categories(train_set)

shots_per_class = 16
random.seed(42)
torch.manual_seed(42)

indices_per_class = {c: [] for c in base_classes}
for idx, label in enumerate(train_set._labels):
    if label in indices_per_class:
        indices_per_class[label].append(idx)

selected = []
for c in base_classes:
    inds = indices_per_class.get(c, [])
    random.shuffle(inds)
    selected.extend(inds[:shots_per_class])

train_base = torch.utils.data.Subset(train_set, selected)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

print(f"Train Base: {len(train_base)} | Test Base: {len(test_base)} | Test Novel: {len(test_novel)}")

Train Base: 510 | Test Base: 2473 | Test Novel: 3676


In [30]:
base_classnames = [CLASS_NAMES[i] for i in base_classes]
novel_classnames = [CLASS_NAMES[i] for i in novel_classes]

trainer = CoCoOpTrainer(
    clip_model=model,
    classnames=base_classnames,
    base_classes=base_classes,
    novel_classes=novel_classes,
    device=device,
    lr=0.002,
    n_ctx=4,
    num_epochs=10
)

print("\n" + "="*70)
print("TRAINING CoCoOp")
print("="*70)

for epoch in range(10):
    avg_loss = trainer.train_epoch(train_base, batch_size=4)
    print(f"Epoch {epoch+1}/10 - Loss: {avg_loss:.4f}")

Initial context: "X X X X"
Number of context words: 4

CoCoOpTrainer initialized: 35,360 trainable params


TRAINING CoCoOp


Training: 100%|██████████| 128/128 [02:11<00:00,  1.02s/it]


Epoch 1/10 - Loss: 1.0291


Training: 100%|██████████| 128/128 [02:10<00:00,  1.02s/it]


Epoch 2/10 - Loss: 0.5585


Training: 100%|██████████| 128/128 [02:10<00:00,  1.02s/it]


Epoch 3/10 - Loss: 0.4377


Training: 100%|██████████| 128/128 [02:10<00:00,  1.02s/it]


Epoch 4/10 - Loss: 0.3555


Training: 100%|██████████| 128/128 [02:09<00:00,  1.02s/it]


Epoch 5/10 - Loss: 0.3048


Training: 100%|██████████| 128/128 [02:10<00:00,  1.02s/it]


Epoch 6/10 - Loss: 0.2562


Training: 100%|██████████| 128/128 [02:09<00:00,  1.02s/it]


Epoch 7/10 - Loss: 0.2340


Training: 100%|██████████| 128/128 [02:10<00:00,  1.02s/it]


Epoch 8/10 - Loss: 0.2191


Training: 100%|██████████| 128/128 [02:09<00:00,  1.01s/it]


Epoch 9/10 - Loss: 0.2101


Training: 100%|██████████| 128/128 [02:10<00:00,  1.02s/it]

Epoch 10/10 - Loss: 0.2029





## Final Evaluation (CoCoOp only)

We'll evaluate the model with:
1. Test Base
2. Test Novel

Computing Harmonic Mean between them to evaluate the trade-off.


In [31]:
print("\n" + "="*70)
print("EVALUATION")
print("="*70)

base_acc = trainer.eval(test_base, base_classes, batch_size=64)


EVALUATION


Evaluating: 100%|██████████| 39/39 [05:02<00:00,  7.75s/it]


In [32]:
#Evaluation for novel classes with in-place class swapping
@torch.no_grad()
def evaluate_novel_inplace(trainer, test_dataset, novel_classnames, novel_classes_ids, device='cuda'):
    print(f"Swapping class definitions to {len(novel_classnames)} novel classes (In-Place)...")
    
    model = trainer.model
    prompt_learner = model.prompt_learner
    
    # 1. SALVIAMO LO STATO ORIGINALE (Base Classes)
    # Per poterlo ripristinare se servisse (opzionale, ma buona pratica)
    old_n_cls = prompt_learner.n_cls
    old_token_prefix = prompt_learner.token_prefix
    old_token_suffix = prompt_learner.token_suffix
    old_tokenized_prompts = model.tokenized_prompts
    
    # 2. GENERIAMO I NUOVI EMBEDDING DI TESTO (Novel Classes)
    # Tokenizziamo i nuovi nomi
    # Nota: la logica di replace("_", " ") è già nei nomi passati o la facciamo qui
    clean_names = [name.replace("_", " ") for name in novel_classnames]
    prompts = [prompt_learner.ctx_init + " " + name + "." if hasattr(prompt_learner, 'ctx_init') and prompt_learner.ctx_init else "X " * prompt_learner.n_ctx + name + "." for name in clean_names]
    
    # Ricostruiamo il prompt template standard usato in PromptLearner
    # PromptLearner usa: "X X X X classname."
    # Dobbiamo replicare la logica esatta per ottenere prefix e suffix corretti
    # Recuperiamo il prefisso "X X X X" (dummy) usato per l'inizializzazione
    n_ctx = prompt_learner.n_ctx
    dummy_ctx = " ".join(["X"] * n_ctx)
    prompts = [dummy_ctx + " " + name + "." for name in clean_names]
    
    new_tokenized = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
    
    # Otteniamo gli embedding dal Text Encoder di CLIP
    with torch.no_grad():
        embedding = trainer.clip_model.token_embedding(new_tokenized).type(trainer.clip_model.dtype)
    
    # 3. SOVRASCRIVIAMO I BUFFER DEL MODELLO
    # PromptLearner ha bisogno di prefix e suffix per "incastrare" i vettori appresi nel mezzo
    new_token_prefix = embedding[:, :1, :]           # [n_cls, 1, dim]
    new_token_suffix = embedding[:, 1+n_ctx:, :]     # [n_cls, len-1-n_ctx, dim]
    
    prompt_learner.register_buffer("token_prefix", new_token_prefix)
    prompt_learner.register_buffer("token_suffix", new_token_suffix)
    prompt_learner.n_cls = len(novel_classnames)
    prompt_learner.tokenized_prompts = new_tokenized
    model.tokenized_prompts = new_tokenized # Aggiorniamo anche il riferimento nel modello padre
    
    print("Class definitions swapped. Starting evaluation...")
    
    # 4. VALUTAZIONE
    model.eval()
    dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, shuffle=False, num_workers=0 # Batch basso per sicurezza
    )
    
    correct = 0
    total = 0
    
    # Mappa: ID Originale (51) -> Indice Locale (0)
    target_map = {original_id: idx for idx, original_id in enumerate(novel_classes_ids)}
    
    for images, labels in tqdm(dataloader, desc="Evaluating Novel"):
        images = images.to(device)
        
        logits = model(images) # Ora genera logits per le classi Novel!
        pred = logits.argmax(dim=1)
        
        # Mapping labels
        labels_cpu = labels.tolist()
        try:
            mapped_labels = torch.tensor([target_map[l] for l in labels_cpu], device=device)
            correct += (pred == mapped_labels).sum().item()
            total += labels.size(0)
        except KeyError:
            continue
            
    acc = correct / total if total > 0 else 0.0
    
    # 5. RIPRISTINO (Opzionale, se vuoi riusare il trainer per le base classes dopo)
    prompt_learner.register_buffer("token_prefix", old_token_prefix)
    prompt_learner.register_buffer("token_suffix", old_token_suffix)
    prompt_learner.n_cls = old_n_cls
    prompt_learner.tokenized_prompts = old_tokenized_prompts
    model.tokenized_prompts = old_tokenized_prompts
    
    return acc

# === ESECUZIONE ===
# Assicuriamoci di liberare memoria spazzatura prima di iniziare
import gc
torch.cuda.empty_cache()
gc.collect()

novel_acc = evaluate_novel_inplace(
    trainer, 
    test_novel, 
    novel_classnames, # Lista nomi ['rose', 'tulip'...]
    novel_classes,    # Lista ID [51, 52...] definita nel tuo notebook
    device=device
)

print(f"\nCorrected Novel Accuracy: {novel_acc*100:.2f}%")

Swapping class definitions to 51 novel classes (In-Place)...
Class definitions swapped. Starting evaluation...


Evaluating Novel: 100%|██████████| 230/230 [07:17<00:00,  1.90s/it]


Corrected Novel Accuracy: 73.50%





In [33]:
hm = harmonic_mean(base_acc, novel_acc)

print("\n" + "="*70)
print("RESULTS")
print("="*70)
print(f"  Base Accuracy:  {base_acc*100:6.2f}%")
print(f"  Novel Accuracy: {novel_acc*100:6.2f}%")
print(f"  Harmonic Mean:  {hm*100:6.2f}%")
print("="*70)


RESULTS
  Base Accuracy:   92.03%
  Novel Accuracy:  73.50%
  Harmonic Mean:   81.73%


## Test applicazione reale -> tutte le classi insieme

In [34]:
@torch.no_grad()
def evaluate_generalized(trainer, test_dataset, all_classnames, all_class_ids, device='cuda'):
    print(f"Evaluating on ALL {len(all_classnames)} classes simultaneously (Generalized Setting)...")
    
    # 1. Setup del modello con TUTTE le classi (Base + Novel)
    model = trainer.model
    prompt_learner = model.prompt_learner
    
    # Salviamo stato vecchio
    old_n_cls = prompt_learner.n_cls
    old_token_prefix = prompt_learner.token_prefix
    old_token_suffix = prompt_learner.token_suffix
    old_tokenized = model.tokenized_prompts

    # 2. Creiamo i prompt per TUTTE le classi (0..101)
    clean_names = [name.replace("_", " ") for name in all_classnames]
    
    # Ricostruiamo i prompt dummy
    n_ctx = prompt_learner.n_ctx
    dummy_ctx = " ".join(["X"] * n_ctx)
    prompts = [dummy_ctx + " " + name + "." for name in clean_names]
    
    new_tokenized = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
    
    with torch.no_grad():
        embedding = trainer.clip_model.token_embedding(new_tokenized).type(trainer.clip_model.dtype)
    
    # Aggiorniamo i buffer
    prompt_learner.register_buffer("token_prefix", embedding[:, :1, :])
    prompt_learner.register_buffer("token_suffix", embedding[:, 1+n_ctx:, :])
    prompt_learner.n_cls = len(all_classnames)
    prompt_learner.tokenized_prompts = new_tokenized
    model.tokenized_prompts = new_tokenized

    # 3. Valutazione
    model.eval()
    # Usiamo il dataset di test COMPLETO (Base + Novel) se possibile, o un subset
    dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
    
    correct = 0
    total = 0
    
    # Mappa: ID Originale (0..101) -> Indice nel modello (0..101)
    # Essendo "tutte" le classi ordinate, la mappa è solitamente identica (0->0, 101->101)
    # Ma per sicurezza usiamo gli ID passati
    target_map = {original_id: idx for idx, original_id in enumerate(all_class_ids)}
    
    for images, labels in tqdm(dataloader, desc="Eval Generalized"):
        images = images.to(device)
        labels = labels.to(device)
        
        logits = model(images) # Output shape: [Batch, 102]
        pred = logits.argmax(dim=1)
        
        # Le label qui arrivano come ID originali (es. 0, 55, 101)
        # Dobbiamo assicurarci che corrispondano agli indici del modello
        mapped_labels = torch.tensor([target_map[l.item()] for l in labels], device=device)
        
        correct += (pred == mapped_labels).sum().item()
        total += labels.size(0)
        
    # Ripristino
    prompt_learner.register_buffer("token_prefix", old_token_prefix)
    prompt_learner.register_buffer("token_suffix", old_token_suffix)
    prompt_learner.n_cls = old_n_cls
    prompt_learner.tokenized_prompts = old_tokenized
    model.tokenized_prompts = old_tokenized
    
    return correct / total

# === UTILIZZO ===
# Uniamo le liste
all_names = base_classnames + novel_classnames
all_ids = list(base_classes) + list(novel_classes)

# Uniamo i dataset di test (Base + Novel) per fare un test unico "Reale"
full_test_set = torch.utils.data.ConcatDataset([test_base, test_novel])

acc_generalized = evaluate_generalized(trainer, full_test_set, all_names, all_ids)
print(f"Generalized Accuracy (Base + Novel mixed): {acc_generalized*100:.2f}%")

Evaluating on ALL 102 classes simultaneously (Generalized Setting)...


Eval Generalized: 100%|██████████| 193/193 [23:14<00:00,  7.23s/it]

Generalized Accuracy (Base + Novel mixed): 72.39%



