# CLIP with Flowers!?!?!??!?

In [1]:
import os
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

import sys
import torch
import torchvision
import random
import gc
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import OrderedDict

In [2]:
try:
    import clip
    print("✓ CLIP already installed")
except Exception:
    print("Installing CLIP...")
    import subprocess, importlib
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", 
                              "git+https://github.com/openai/CLIP.git"])
    importlib.invalidate_caches()
    import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
print(f"PyTorch: {torch.__version__}")

✓ CLIP already installed
Device: cpu
PyTorch: 2.9.1+cpu


  from pkg_resources import packaging


## Dataset Functions

We define utility functions for:
- **`get_data()`**: Load Flowers102 from torchvision
- **`base_novel_categories()`**: Split 102 classes into base (0-50) and novel (51-101)
- **`split_data()`**: Filter images for base/novel in each split

This simulates the real scenario: we have 51 seen classes during training (base) and 51 new ones (novel).


In [3]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets."""
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test


def base_novel_categories(dataset):
    """Return base and novel class id lists using the actual labels present
    in the dataset. Prefer public attributes (`targets` then `labels`) and
    only fall back to the dataset private attribute `_labels` if neither is
    available.
    """
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = getattr(dataset, "labels", None)

    if labels is None and hasattr(dataset, "_labels"):
        labels = dataset._labels

    if labels is None:
        raise ValueError("Could not find labels on dataset (checked 'targets','labels','_labels').")

    unique_labels = sorted(set(labels))
    num_classes = len(unique_labels)
    mid = num_classes // 2
    base_classes = unique_labels[:mid]
    novel_classes = unique_labels[mid:]
    return base_classes, novel_classes


def split_data(dataset, base_classes):
    base_categories_samples = []
    novel_categories_samples = []
    base_set = set(base_classes)

    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset

## Class Names and Dataset Loading

We load the names of 102 flower classes from Flowers102.

This is **critical** for CLIP:
- Creates prompts like "a photo of a **rose**, a type of flower"
- Each prompt is encoded by CLIP's text encoder
- Image features are compared against these text templates


In [4]:
_, _, tmp_test = get_data()
base_classes, novel_classes = base_novel_categories(tmp_test)

CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# Uncomment to see class names
print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

Base Class Names: [(0, 'pink primrose'), (1, 'hard-leaved pocket orchid'), (2, 'canterbury bells'), (3, 'sweet pea'), (4, 'english marigold'), (5, 'tiger lily'), (6, 'moon orchid'), (7, 'bird of paradise'), (8, 'monkshood'), (9, 'globe thistle'), (10, 'snapdragon'), (11, "colt's foot"), (12, 'king protea'), (13, 'spear thistle'), (14, 'yellow iris'), (15, 'globe-flower'), (16, 'purple coneflower'), (17, 'peruvian lily'), (18, 'balloon flower'), (19, 'giant white arum lily'), (20, 'fire lily'), (21, 'pincushion flower'), (22, 'fritillary'), (23, 'red ginger'), (24, 'grape hyacinth'), (25, 'corn poppy'), (26, 'prince of wales feathers'), (27, 'stemless gentian'), (28, 'artichoke'), (29, 'sweet william'), (30, 'carnation'), (31, 'garden phlox'), (32, 'love in the mist'), (33, 'mexican aster'), (34, 'alpine sea holly'), (35, 'ruby-lipped cattleya'), (36, 'cape flower'), (37, 'great masterwort'), (38, 'siam tulip'), (39, 'lenten rose'), (40, 'barbeton daisy'), (41, 'daffodil'), (42, 'sword 

In [5]:
# Load CLIP model and preprocessing
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# --- Data transformation for Augmentation ---
#keeping CLIP normalization values
aug_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.RandomCrop(224),           # Random Crop
    torchvision.transforms.RandomHorizontalFlip(p=0.5), 
    torchvision.transforms.RandomRotation(15),        # smooth rotation
    torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # Color Jitter
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.481, 0.457, 0.408), (0.268, 0.261, 0.275)) # Mean/Std of CLIP
])

print(f"Device: {device}")
print(f"Model: ViT-B/16")
print("Augmentation pipeline defined.")

Device: cpu
Model: ViT-B/16
Augmentation pipeline defined.


## Prototype-Based Representations

We construct **class prototypes** from CLIP image embeddings of training samples.

**Key Design Choices:**
- Use **frozen CLIP** (not the adapted model) to preserve zero-shot knowledge
- Compute prototypes from **both normal and augmented samples** for better coverage
- **L2-normalize** embeddings before averaging and after

**At Inference:**
- Compute prototype similarity: $\text{sim}_{\text{proto}}(x, c) = \frac{f(x) \cdot p_c}{\|f(x)\| \|p_c\|}$
- Fuse with CoCoOp logits: $\text{logits}_{\text{final}} = \alpha \cdot \text{logits}_{\text{CoCoOp}} + (1-\alpha) \cdot \text{logits}_{\text{proto}}$
- The fusion weight $\alpha$ controls the trade-off between prompt-based and prototype-based predictions

In [6]:
@torch.no_grad()
def build_prototypes(clip_model, train_dataset, base_classes, device='cuda'):
    """
    Extract CLIP image embeddings and compute mean prototype per class.
    Uses the ORIGINAL frozen CLIP to preserve zero-shot knowledge.
    
    Args:
        clip_model: Frozen CLIP model
        train_dataset: Training dataset (can include augmented samples)
        base_classes: List of base class IDs
        device: Device to use
    
    Returns:
        prototypes: Dict mapping class_id -> normalized prototype tensor (shape: [dim])
        prototype_matrix: Tensor of shape [num_classes, dim] for efficient inference
    """
    clip_model.eval()
    
    # Collect embeddings per class
    embeddings_per_class = {c: [] for c in base_classes}
    
    dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, shuffle=False, num_workers=0
    )
    
    print(f"Extracting embeddings from {len(train_dataset)} samples...")
    
    for images, labels in tqdm(dataloader, desc="Building Prototypes"):
        images = images.to(device)
        
        # Get CLIP image features
        features = clip_model.encode_image(images)
        features = features / features.norm(dim=-1, keepdim=True)  # L2 normalize
        
        for feat, label in zip(features, labels):
            label_id = label.item()
            if label_id in embeddings_per_class:
                embeddings_per_class[label_id].append(feat.cpu())
    
    # Compute mean prototype per class
    prototypes = {}
    for cls_id in base_classes:
        if embeddings_per_class[cls_id]:
            class_embeddings = torch.stack(embeddings_per_class[cls_id])
            prototype = class_embeddings.mean(dim=0).to(device)
            prototype = prototype / prototype.norm()  # Re-normalize after averaging
            prototypes[cls_id] = prototype
    
    # Create matrix for efficient inference (ordered by base_classes)
    prototype_matrix = torch.stack([prototypes[c] for c in base_classes]).to(device)
    
    print(f"✓ Built {len(prototypes)} prototypes | Matrix shape: {prototype_matrix.shape}")
    
    return prototypes, prototype_matrix

## Load Flowers102 and Split Base/Novel

We load the 3 splits (train, val, test) and divide into base/novel.

**Statistics:**
- Train Base: 10 images × 51 classes = 510 images
- Val Base: 10 images × 51 classes = 510 images
- Test Base: ~10 images × 51 classes (from test split)
- Test Novel: Remaining (~10 per class)

**Note:** Train and val have ~10 images per class (few-shot setting).


In [None]:
# --- STEP 1: Raw Data Retrieval ---
# Note: Flowers102 has 10 imgs per class in train set. We use only train for 10 shots.
# Augmented samples are used ONLY for building prototypes, not for training.

# Helper to load with a specific transformation
def load_split(split, transform):
    return torchvision.datasets.Flowers102(root="./data", split=split, download=True, transform=transform)

# Load "Normal" sets (Standard CLIP preprocess)
train_set_norm = load_split("train", preprocess)
val_set_norm = load_split("val", preprocess)
test_set = load_split("test", preprocess)

# Load "Augmented" set (only for prototype building)
train_set_aug = load_split("train", aug_transform)

# --- STEP 2: Base/Novel Class Split ---
base_classes, novel_classes = base_novel_categories(train_set_norm)

# --- STEP 3: Few-Shot Selection (10 Real Shots from Train Set) ---
shots_per_class = 10
import random
random.seed(42)

# Collect all available indices for each base class in the train set
indices_per_class = {c: [] for c in base_classes}
for idx, label in enumerate(train_set_norm._labels):
    if label in base_classes:
        indices_per_class[label].append(idx)

selected_indices = []
for c in base_classes:
    inds = indices_per_class.get(c, [])
    random.shuffle(inds)
    # Take 10 shots from train set
    selected_indices.extend(inds[:shots_per_class])

# --- STEP 4: Create Datasets ---
# Training dataset: ONLY normal samples (no augmentation)
train_base = torch.utils.data.Subset(train_set_norm, selected_indices)

# Prototype dataset: Normal + Augmented samples (for richer prototype representations)
subset_normal = torch.utils.data.Subset(train_set_norm, selected_indices)
subset_augmented = torch.utils.data.Subset(train_set_aug, selected_indices)
prototype_dataset = torch.utils.data.ConcatDataset([subset_normal, subset_augmented])

# Validation and Test sets
val_base, _ = split_data(test_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

print(f"Dataset Created Successfully!")
print(f"Base Classes: {len(base_classes)} | Shots: {shots_per_class}")
print(f"-> Training Set: {len(train_base)} samples (normal only)")
print(f"-> Prototype Set: {len(prototype_dataset)} samples (normal + augmented)")
print(f"-> Val Base: {len(val_base)} | Test Base: {len(test_base)} | Test Novel: {len(test_novel)}")

Dataset Created Successfully!
Base Classes: 51 | Shots: 16
Normal Subset Size: 816
Augmented Subset Size: 816
-> TOTAL Train Base: 1632 samples (Should be 1632)


## Harmonic Mean (HM)

Standard metric for few-shot adaptation papers.

Formula: HM = 2 / (1/base_acc + 1/novel_acc)

**Why HM instead of arithmetic mean?**
- HM heavily penalizes outliers
- If base=90% and novel=50%: arithmetic mean=70%, HM=64.3%
- Forces the model to balance both accuracies

**Obiettivo:** massimizzare l'HM tra `base_acc_cocoop` e `novel_acc_cocoop`.


In [8]:
def harmonic_mean(base_accuracy, novel_accuracy):
    # Guard against zero to avoid division-by-zero errors
    if base_accuracy <= 0 or novel_accuracy <= 0:
        return 0.0
    numerator = 2.0
    denominator = 1.0 / base_accuracy + 1.0 / novel_accuracy
    return numerator / denominator


## Text Encoder

In [9]:
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.ln_final(x).type(self.dtype)
        x = x[torch.arange(int(x.shape[0])), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x

## CoCoOpPromptLearner: Dynamic Prompts (Optimized)

**Components:**
1. **Context Vectors (V):** 16 vectors (learnable).
   - Shape: `(16, 512)`
   - Initialized: Gaussian noise N(0, 0.02)
   - Function: Provide the base context for the prompt.

2. **Meta-Network (Bias Generator):**
   - Architecture: Linear(512->32) -> ReLU -> Linear(32->512)
   - Input: Image Features `(Batch, 512)`
   - Output: Bias `(Batch, 512)` added to Context Vectors.
   - **Note:** Unlike the paper's simplified notation "$\pi$", we implement this as an **additive bias** to the context vectors.

3. **Class Embeddings:**
   - Pre-computed embeddings for "[CLASS] + EOS".
   - Fixed during training.

**Forward Pass (Vectorized):**
Instead of looping through images, we broadcast tensors to shape `(Batch, Num_Classes, Sequence_Length, Dim)`:
1. **Compute Bias:** $Bias = MetaNet(Image)$
2. **Shift Context:** $Ctx_{new} = Ctx_{base} + Bias$ (Broadcasting over classes)
3. **Concatenate:** $[Prefix] + [Ctx_{new}] + [Suffix]$ (All in parallel)

In [10]:
class PromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=16, ctx_init=None, device='cuda'):
        super().__init__()
        n_cls = len(classnames)
        # Get embedding dimension from CLIP's final layer
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        
        # 1. Context Vectors Initialization
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            
            # --- FIX: Move tokenized prompt to correct device ---
            prompt = clip.tokenize(ctx_init).to(device) 
            
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(torch.float16)
            
            # embedding[0] because tokenize adds a batch dim
            ctx_vectors = embedding[0, 1:1+n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialization standard (Sigma=0.02)
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=torch.float16)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)
        
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words: {n_ctx}")
        self.ctx = nn.Parameter(ctx_vectors)
        
        # 2. Meta-Network (Less Aggressive Bottleneck for better Generalization)
        # Trying to increase hidden dim from 32 (//16) to 128 (//4) to prevent underfitting novel classes
        hidden_dim = vis_dim // 4
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, hidden_dim)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(hidden_dim, ctx_dim))
        ]))
        
        # 3. Pre-computing Class Names (Prefix/Suffix)
        classnames = [name.replace("_", " ") for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        
        # Tokenize and get embeddings
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        tokenized_prompts = tokenized_prompts.to(device)
        
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(torch.float16)
        
        # Save Prefix (SOS) and Suffix (Class Name + EOS) as fixed buffers
        self.register_buffer("token_prefix", embedding[:, :1, :])      # (n_cls, 1, dim)
        self.register_buffer("token_suffix", embedding[:, 1+n_ctx:, :]) # (n_cls, len, dim)
        
        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
    
    def forward(self, im_features):
        batch_size = im_features.shape[0]
        
        # 1. Calculate bias from image
        bias = self.meta_net(im_features)
        bias = bias.unsqueeze(1) # (Batch, 1, 512)
        
        # 2. Generate shifted context
        ctx = self.ctx.unsqueeze(0) # (1, n_ctx, dim)
        ctx_shifted = ctx + bias    # (Batch, n_ctx, dim)
        
        # 3. Parallel prompt construction
        prefix = self.token_prefix.unsqueeze(0).expand(batch_size, -1, -1, -1)
        ctx_expanded = ctx_shifted.unsqueeze(1).expand(-1, self.n_cls, -1, -1)
        suffix = self.token_suffix.unsqueeze(0).expand(batch_size, -1, -1, -1)
        
        prompts = torch.cat(
            [prefix, ctx_expanded, suffix],
            dim=2
        )
        
        return prompts

## CoCoOpTrainer: Training and Evaluation (with Prototype Fusion)

Class that manages:

**1. Initialization:**
- Create PromptLearner
- Freeze CLIP (`requires_grad=False`)
- Configure SGD optimizer for prompt learner only

**2. train_epoch():**
- Forward: Image encoder + PromptLearner + Text encoder
- **Critical step:** Encode soft prompts through text transformer
  - Add positional embeddings
  - Pass through CLIP's transformer
  - Extract first token
  - Apply final layer norm + projection
- Compute loss: Cross-entropy on base classes
- Backward: Backprop only in PromptLearner
- Return: Average loss of the epoch

**3. eval() with Prototype Fusion:**
- Same forward procedure as training
- **NEW:** Optionally fuse CoCoOp logits with prototype similarity scores
- Fusion formula: $\text{logits} = \alpha \cdot \text{logits}_{\text{CoCoOp}} + (1-\alpha) \cdot \text{logits}_{\text{prototype}}$
- Compute accuracy on any dataset (base or novel)

**Important note:** We don't use `model.encode_text()` on soft prompts
because that method expects integer tokens, not embeddings.
We manually forward through the text transformer.

In [11]:
class CustomCLIP(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=4, ctx_init=None, device='cuda'):
        super().__init__()
        self.prompt_learner = PromptLearner(clip_model, classnames, n_ctx=n_ctx, ctx_init=ctx_init, device=device)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        
        # Prototype fusion parameters
        self.prototype_matrix = None  # Will be set externally
        self.alpha = 0.5  # Fusion weight: alpha * CoCoOp + (1-alpha) * Prototype
    
    def set_prototypes(self, prototype_matrix, alpha=0.5):
        """Set prototype matrix for fusion at inference time."""
        self.prototype_matrix = prototype_matrix.type(self.dtype)
        self.alpha = alpha
        print(f"✓ Prototypes set | Alpha (CoCoOp weight): {alpha}")
    
    def forward(self, image, label=None, use_prototypes=False):
        """
        Forward pass with optional prototype fusion.
        
        Args:
            image: Input images
            label: Labels for training (returns loss if provided)
            use_prototypes: Whether to fuse prototype logits at inference
        """
        # encode images
        logit_scale = self.logit_scale.exp()
        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # generate instance-conditioned prompts: (batch, n_cls, n_tokens, dim)
        prompts = self.prompt_learner(image_features)
        batch_size = int(prompts.shape[0])
        n_cls = int(self.prompt_learner.n_cls)

        # Flatten prompts to (batch*n_cls, n_tokens, dim) for parallel encoding
        n_tokens = prompts.shape[2]
        dim = prompts.shape[3]
        prompts_flat = prompts.reshape(batch_size * n_cls, n_tokens, dim).type(self.dtype)

        # Repeat tokenized prompts for each image in the batch
        tokenized = self.tokenized_prompts.to(prompts_flat.device)
        tokenized_expanded = tokenized.repeat(batch_size, 1)

        # Encode all prompts in parallel and reshape back: (batch, n_cls, dim)
        text_features = self.text_encoder(prompts_flat, tokenized_expanded)
        text_features = text_features.reshape(batch_size, n_cls, -1)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Compute CoCoOp logits: (batch, n_cls)
        image_features_expanded = image_features.unsqueeze(1)
        cocoop_logits = logit_scale * (image_features_expanded @ text_features.transpose(1, 2)).squeeze(1)

        # Prototype fusion
        if use_prototypes and self.prototype_matrix is not None:
            # Compute prototype logits: (batch, n_cls)
            prototype_logits = logit_scale * (image_features @ self.prototype_matrix.T)
            
            # Fused logits
            logits = self.alpha * cocoop_logits + (1 - self.alpha) * prototype_logits
        else:
            logits = cocoop_logits

        if label is not None:
            return F.cross_entropy(logits, label)
        return logits

## Training CoCoOp (Optimized)

We will train the PromptLearner for **10 epochs** on **base classes only**.

**Hyperparameters (Optimized):**
- **Context Length (`n_ctx`):** 16 (Increased capacity for fine-grained details)
- **Batch size:** 4 (Increased from 1 thanks to parallelization)
- **Learning rate:** 0.002 (SGD)
- **Momentum:** 0.9
- **Weight decay:** 5e-4
- **Epochs:** 10

**What happens:**
- The `PromptLearner` adapts its 16 context vectors to the Flowers102 dataset.
- The `MetaNetwork` learns to inject image-specific bias efficiently.
- **Optimization:** We use a GPU-based label lookup table to speed up target mapping.

**Expected output:**
- Initial loss: ~2.5 - 3.5
- Final loss: ~0.5 - 1.0 (Lower than before due to better context capacity)
- Training time: ~2-4 minutes on GPU

In [12]:
class CoCoOpTrainer:
    def __init__(self, clip_model, classnames, base_classes, novel_classes, 
                 device='cuda', lr=0.002, n_ctx=16, num_epochs=10, ctx_init=None): # <--- NEW PARAMETER
        
        self.clip_model = clip_model.float()
        self.classnames = classnames
        self.base_classes = base_classes
        self.device = device
        self.num_epochs = num_epochs
        
        # --- LABEL MAPPING OPTIMIZATION (GPU Lookup) ---
        # We need to map original dataset labels (e.g., 0, 55, 101) to local indices (0..N)
        max_label_id = max(max(base_classes), max(novel_classes)) + 1
        self.label_map = torch.full((max_label_id,), -1, dtype=torch.long, device=device)
        
        base_ids_tensor = torch.tensor(base_classes, device=device, dtype=torch.long)
        target_indices = torch.arange(len(base_classes), device=device, dtype=torch.long)
        self.label_map[base_ids_tensor] = target_indices
        
        # Freeze CLIP
        for p in self.clip_model.parameters():
            p.requires_grad = False
        
        # Initialize Custom Model (Passing ctx_init)
        print(f"Initializing CustomCLIP with ctx_init='{ctx_init}'...")
        self.model = CustomCLIP(self.clip_model, classnames, n_ctx=n_ctx, ctx_init=ctx_init, device=device).to(device)
        
        # Optimizer
        self.optimizer = torch.optim.SGD(
            self.model.prompt_learner.parameters(), 
            lr=lr, 
            momentum=0.9, 
            weight_decay=5e-4
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs)
        
        trainable = sum(p.numel() for p in self.model.prompt_learner.parameters())
        print(f"\nCoCoOpTrainer initialized: {trainable:,} trainable params")
        print(f"Context: {n_ctx} | Gradient Accumulation Enabled | Init: {ctx_init if ctx_init else 'Random'}")
    
    def train_epoch(self, train_dataset, batch_size=4, accumulation_steps=4):
        """
        Runs a training epoch with Gradient Accumulation.
        -> Effective Batch Size = batch_size * accumulation_steps (e.g., 4 * 4 = 16)
        """
        self.model.train()
        self.clip_model.eval()
        
        dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False
        )
        
        total_loss = 0
        n_batches = 0
        
        self.optimizer.zero_grad()
        
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader, desc="Training (Grad Accum)")):
            images = images.to(self.device).float()
            labels = labels.to(self.device)
            
            # Fast mapping using pre-computed GPU tensor
            labels_mapped = self.label_map[labels]
            
            # Forward pass
            loss = self.model(images, labels_mapped)
            
            # --- GRADIENT ACCUMULATION LOGIC ---
            loss = loss / accumulation_steps
            loss.backward()
            
            if (batch_idx + 1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            
            total_loss += loss.item() * accumulation_steps
            n_batches += 1
        
        # Process remaining gradients if dataloader size is not divisible by accumulation_steps
        if len(dataloader) % accumulation_steps != 0:
             self.optimizer.step()
             self.optimizer.zero_grad()
        
        self.scheduler.step()
        return total_loss / max(1, n_batches)
    
    @torch.no_grad()
    def eval(self, dataset, categories, batch_size=128, use_prototypes=False):
        """
        Evaluate model on dataset.
        
        Args:
            dataset: Dataset to evaluate on
            categories: List of category IDs
            batch_size: Batch size for evaluation
            use_prototypes: Whether to use prototype fusion (requires set_prototypes called first)
        """
        self.model.eval()
        self.clip_model.eval()
        
        local_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=0
        )
        
        correct = 0
        total = 0
    
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(self.device).float()
            logits = self.model(images, use_prototypes=use_prototypes)
            pred = logits.argmax(dim=1)
            
            labels_cpu = labels.tolist()
            try:
                mapped_targets = torch.tensor(
                    [local_cat2idx[l] for l in labels_cpu], 
                    device=self.device
                )
                correct += (pred == mapped_targets).sum().item()
                total += labels.size(0)
            except KeyError:
                continue
        
        return correct / total if total > 0 else 0.0

In [None]:
#Loading data and model
model, preprocess = clip.load("ViT-B/16", device=device)

# Note: train_base (normal only) and prototype_dataset (normal + augmented) 
# are already created in Cell 10

train_set, val_set, test_set = get_data(transform=preprocess)
base_classes, novel_classes = base_novel_categories(train_set)

val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

print(f"Train Base: {len(train_base)} samples (normal only, for training)")
print(f"Prototype Dataset: {len(prototype_dataset)} samples (normal + augmented, for prototypes)")
print(f"Val Base: {len(val_base)} | Test Base: {len(test_base)} | Test Novel: {len(test_novel)}")

Train Base: 510 | Test Base: 2473 | Test Novel: 3676


In [None]:
# Build class prototypes from frozen CLIP using normal + augmented samples
print("\n" + "="*70)
print("BUILDING CLASS PROTOTYPES (using normal + augmented samples)")
print("="*70)

prototypes, prototype_matrix = build_prototypes(
    clip_model=model,  # Frozen original CLIP
    train_dataset=prototype_dataset,  # Uses both normal and augmented samples
    base_classes=base_classes,
    device=device
)

print(f"Prototype matrix dtype: {prototype_matrix.dtype}")


BUILDING CLASS PROTOTYPES
Extracting embeddings from 510 samples...


Building Prototypes: 100%|██████████| 8/8 [01:34<00:00, 11.78s/it]

✓ Built 51 prototypes | Matrix shape: torch.Size([51, 512])
Prototype matrix dtype: torch.float32





In [15]:
# 1. Configurazione
base_classnames = [CLASS_NAMES[i] for i in base_classes]
novel_classnames = [CLASS_NAMES[i] for i in novel_classes]

os.makedirs("checkpoints", exist_ok=True)

trainer = CoCoOpTrainer(
    clip_model=model,
    classnames=base_classnames,
    base_classes=base_classes,
    novel_classes=novel_classes,
    device=device,
    lr=0.002,          
    n_ctx=16,           #if ctx_init is used, this will be ignored and set to 5
    num_epochs=50,      
    ctx_init="a photo of a flower" 
)

# INTRODUCING EARLY STOPPING
patience = 5         
counter = 0        
best_acc = 0.0
best_epoch = 0

print("\n" + "="*70)
print(f"TRAINING CoCoOp (Smart Init + Early Stopping @ Patience {patience})")
print("="*70)

for epoch in range(trainer.num_epochs):
    # Training
    avg_loss = trainer.train_epoch(train_base, batch_size=4, accumulation_steps=4)
    
    # Validation
    val_acc = trainer.eval(val_base, base_classes, batch_size=64)
    
    print(f"Epoch {epoch+1}/{trainer.num_epochs} - Loss: {avg_loss:.4f} | Val Acc: {val_acc*100:.2f}%", end="")
    
    # EARLY STOPPING and CHECKPOINTING
    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch + 1
        counter = 0
        
        # saving best model
        state_dict = trainer.model.prompt_learner.state_dict()
        torch.save(state_dict, "checkpoints/best_model.pth")
        print(f"  [★ BEST SAVED] - Counter reset")
    else:
        counter += 1
        print(f"  [No Improv. {counter}/{patience}]")
        
        if counter >= patience:
            print(f"\n⏹ EARLY STOPPING TRIGGERED at epoch {epoch+1}!")
            print(f"La validation accuracy non migliora da {patience} epoche.")
            break

print("="*70)
print(f"Training terminated. Best Val Acc: {best_acc*100:.2f}% at epoch {best_epoch}")

# Reload Best Model
print("\nReloading best model weights for final evaluation...")
best_checkpoint = torch.load("checkpoints/best_model.pth")
trainer.model.prompt_learner.load_state_dict(best_checkpoint)
print("Best model loaded.")

# Set prototypes for inference-time fusion
trainer.model.set_prototypes(prototype_matrix, alpha=0.5)  # Try different alpha values: 0.3, 0.5, 0.7

Initializing CustomCLIP with ctx_init='a photo of a flower'...
Initial context: "a photo of a flower"
Number of context words: 5

CoCoOpTrainer initialized: 134,272 trainable params
Context: 16 | Gradient Accumulation Enabled | Init: a photo of a flower

TRAINING CoCoOp (Smart Init + Early Stopping @ Patience 5)


Training (Grad Accum): 100%|██████████| 128/128 [33:28<00:00, 15.69s/it]
Evaluating:  50%|█████     | 4/8 [11:54<11:54, 178.75s/it]


KeyboardInterrupt: 

## Final Evaluation (CoCoOp + Prototype Fusion)

We'll evaluate the model with:
1. **Test Base** - CoCoOp only vs CoCoOp + Prototypes
2. **Test Novel** - CoCoOp only (no prototypes for novel classes)

Computing Harmonic Mean between them to evaluate the trade-off.

**Note:** Prototypes are only available for base classes (built from training data).
For novel classes, we rely solely on CoCoOp's generalization.


In [None]:
print("\n" + "="*70)
print("EVALUATION")
print("="*70)

base_acc = trainer.eval(test_base, base_classes, batch_size=64, use_prototypes=True)
print(f"Base Accuracy: {base_acc*100:.2f}%")


EVALUATION


Evaluating: 100%|██████████| 39/39 [04:53<00:00,  7.52s/it]


In [30]:
# Evaluation for novel classes with in-place class swapping
@torch.no_grad()
def evaluate_novel_inplace(trainer, test_dataset, novel_classnames, novel_classes_ids, device='cuda'):
    print(f"Swapping class definitions to {len(novel_classnames)} novel classes (In-Place)...")
    
    model = trainer.model
    prompt_learner = model.prompt_learner
    
    # 1. SAVE ORIGINAL STATE (Base Classes)
    old_n_cls = prompt_learner.n_cls
    old_token_prefix = prompt_learner.token_prefix
    old_token_suffix = prompt_learner.token_suffix
    old_tokenized_prompts = model.tokenized_prompts
    
    # 2. GENERATE NEW TEXT EMBEDDINGS (Novel Classes)
    # Tokenize new names
    clean_names = [name.replace("_", " ") for name in novel_classnames]
    
    # Reconstruct the standard prompt template used in PromptLearner
    # PromptLearner uses: "X X X X classname."
    # We must replicate the exact logic to get correct prefix and suffix
    n_ctx = prompt_learner.n_ctx
    dummy_ctx = " ".join(["X"] * n_ctx)
    prompts = [dummy_ctx + " " + name + "." for name in clean_names]
    
    new_tokenized = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
    
    # Get embeddings from CLIP's Text Encoder
    with torch.no_grad():
        embedding = trainer.clip_model.token_embedding(new_tokenized).type(trainer.clip_model.dtype)
    
    # 3. OVERWRITE MODEL BUFFERS
    # PromptLearner needs prefix and suffix to "sandwich" the learned vectors in between
    new_token_prefix = embedding[:, :1, :]           # [n_cls, 1, dim]
    new_token_suffix = embedding[:, 1+n_ctx:, :]     # [n_cls, len-1-n_ctx, dim]
    
    prompt_learner.register_buffer("token_prefix", new_token_prefix)
    prompt_learner.register_buffer("token_suffix", new_token_suffix)
    prompt_learner.n_cls = len(novel_classnames)
    prompt_learner.tokenized_prompts = new_tokenized
    model.tokenized_prompts = new_tokenized 
    
    print("Class definitions swapped. Starting evaluation...")
    
    # 4. EVALUATION
    model.eval()
    dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=32, shuffle=False, num_workers=0 # trying a higher batch size
    )
    
    correct = 0
    total = 0
    
    # Mapping from original novel class IDs to local indices (0..N-1)
    target_map = {original_id: idx for idx, original_id in enumerate(novel_classes_ids)}
    
    for images, labels in tqdm(dataloader, desc="Evaluating Novel"):
        images = images.to(device)
        
        logits = model(images) # generating logits for novel classes
        pred = logits.argmax(dim=1)
        
        # Mapping labels
        labels_cpu = labels.tolist()
        try:
            mapped_labels = torch.tensor([target_map[l] for l in labels_cpu], device=device)
            correct += (pred == mapped_labels).sum().item()
            total += labels.size(0)
        except KeyError:
            continue
            
    acc = correct / total if total > 0 else 0.0
    
    # 5. RESTORE ORIGINAL STATE
    prompt_learner.register_buffer("token_prefix", old_token_prefix)
    prompt_learner.register_buffer("token_suffix", old_token_suffix)
    prompt_learner.n_cls = old_n_cls
    prompt_learner.tokenized_prompts = old_tokenized_prompts
    model.tokenized_prompts = old_tokenized_prompts
    
    return acc

# EXECUTION
torch.cuda.empty_cache()
gc.collect()

novel_acc = evaluate_novel_inplace(
    trainer, 
    test_novel, 
    novel_classnames, # Name list ['rose', 'tulip'...]
    novel_classes,    # ID list [51, 52...]
    device=device
)

print(f"\nCorrected Novel Accuracy: {novel_acc*100:.2f}%")

Swapping class definitions to 51 novel classes (In-Place)...
Class definitions swapped. Starting evaluation...


Evaluating Novel: 100%|██████████| 115/115 [06:49<00:00,  3.56s/it]


Corrected Novel Accuracy: 73.86%





In [31]:
hm = harmonic_mean(base_acc, novel_acc)

print("\n" + "="*70)
print("RESULTS")
print("="*70)
print(f"  Base Accuracy:  {base_acc*100:6.2f}%")
print(f"  Novel Accuracy: {novel_acc*100:6.2f}%")
print(f"  Harmonic Mean:  {hm*100:6.2f}%")
print("="*70)


RESULTS
  Base Accuracy:   91.87%
  Novel Accuracy:  73.86%
  Harmonic Mean:   81.89%


## Real-World-Scenario Testing -> base + novel classes at the same time

In [33]:
@torch.no_grad()
def evaluate_generalized(trainer, test_dataset, all_classnames, all_class_ids, device='cuda'):
    print(f"Evaluating on ALL {len(all_classnames)} classes simultaneously (Generalized Setting)...")
    
    # 1. Model Setup with ALL classes (Base + Novel)
    model = trainer.model
    prompt_learner = model.prompt_learner
    
    # Save old state to restore later
    old_n_cls = prompt_learner.n_cls
    old_token_prefix = prompt_learner.token_prefix
    old_token_suffix = prompt_learner.token_suffix
    old_tokenized = model.tokenized_prompts

    # 2. Create prompts for ALL classes (0..101)
    clean_names = [name.replace("_", " ") for name in all_classnames]
    
    # Reconstruct dummy prompts
    n_ctx = prompt_learner.n_ctx
    dummy_ctx = " ".join(["X"] * n_ctx)
    prompts = [dummy_ctx + " " + name + "." for name in clean_names]
    
    new_tokenized = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
    
    with torch.no_grad():
        embedding = trainer.clip_model.token_embedding(new_tokenized).type(trainer.clip_model.dtype)
    
    # Update buffers
    prompt_learner.register_buffer("token_prefix", embedding[:, :1, :])
    prompt_learner.register_buffer("token_suffix", embedding[:, 1+n_ctx:, :])
    prompt_learner.n_cls = len(all_classnames)
    prompt_learner.tokenized_prompts = new_tokenized
    model.tokenized_prompts = new_tokenized

    # 3. Evaluation
    model.eval()
    # Use the COMPLETE test dataset (Base + Novel) if possible, or a subset
    dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
    
    correct = 0
    total = 0
    
    # Map: Original ID (0..101) -> Model Index (0..101)
    # Since "all" classes are ordered, the map is usually identical (0->0, 101->101)
    # But for safety, we use the passed IDs
    target_map = {original_id: idx for idx, original_id in enumerate(all_class_ids)}
    
    for images, labels in tqdm(dataloader, desc="Eval Generalized"):
        images = images.to(device)
        labels = labels.to(device)
        
        logits = model(images) # Output shape: [Batch, 102]
        pred = logits.argmax(dim=1)
        
        # Labels here arrive as Original IDs (e.g. 0, 55, 101)
        # We must ensure they correspond to the model indices
        mapped_labels = torch.tensor([target_map[l.item()] for l in labels], device=device)
        
        correct += (pred == mapped_labels).sum().item()
        total += labels.size(0)
        
    # Restore state
    prompt_learner.register_buffer("token_prefix", old_token_prefix)
    prompt_learner.register_buffer("token_suffix", old_token_suffix)
    prompt_learner.n_cls = old_n_cls
    prompt_learner.tokenized_prompts = old_tokenized
    model.tokenized_prompts = old_tokenized
    
    return correct / total

# Merge lists
all_names = base_classnames + novel_classnames
all_ids = list(base_classes) + list(novel_classes)

# Merge test datasets (Base + Novel) to perform a unique "Real" test
full_test_set = torch.utils.data.ConcatDataset([test_base, test_novel])

acc_generalized = evaluate_generalized(trainer, full_test_set, all_names, all_ids)
print(f"Generalized Accuracy (Base + Novel mixed): {acc_generalized*100:.2f}%")

Evaluating on ALL 102 classes simultaneously (Generalized Setting)...


Eval Generalized:   0%|          | 0/193 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.44 GiB. GPU 0 has a total capacity of 14.74 GiB of which 1.28 GiB is free. Process 4116 has 13.46 GiB memory in use. Of the allocated memory 13.22 GiB is allocated by PyTorch, and 106.70 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)