# ProtoCoCoOp: Few-shot adaptation of CLIP

Deep Learning Course Project - a.y. 2024/2025

Authors:
- Andrea Giampietro - xxxxxx
- Marco Gandolfi - 258017
- Stefano Camposilvan - 257848

# TODO:
- TESTING OF FINAL MODEL
    - with respect to baseline CLIP and base CoCoOp

- ABLATION STUDY/PERFORMANCE COMPARISON
    - train and eval baseline (CLIP)
    - train and eval CoCoOp alone 
    - train and eval CoCoOp + proto
    - train and eval CoCoOp + KD 
    - train and eval full model (CoCoOp + proto + KD)

    note: ADD IMAGES, GRAPHS, etc...

- PROPERLY COMMENT CODE

- PROPERLY EXPLAIN THEORY/RATIONALE/IDEAS


## Table of contents
- Introduction
- Setup
- The Baseline: CLIP
- Our Approach
    - Overview
    - CoCoOp
    - Knowledge distillation
    - Prototypes Generation
    - Implementation
- Results and Discussion
- Conclusions
- References

## Introduction

## Setup

### Initialization

In [2]:
# Import necessary packages
import os
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

import sys
import torch
import torchvision
import numpy as np
import random
import gc
from matplotlib import pyplot as plt
import csv
from shutil import copy
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader

# Install CLIP if not already installed
try:
    import clip
    print("✓ CLIP already installed")
except Exception:
    print("Installing CLIP...")
    import subprocess, importlib
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", 
                              "git+https://github.com/openai/CLIP.git"])
    importlib.invalidate_caches()
    import clip

✓ CLIP already installed


### Paths and constants definition

In [3]:
# -- PATHS DEFINITION --
# Directory for dataset
data_path = "data"
os.makedirs(data_path, exist_ok=True) # for dataset storage

# Directories for saving results
models_path = "results/models"
os.makedirs(models_path, exist_ok=True) # for saving models
logs_path = "results/logs"
os.makedirs(logs_path, exist_ok=True) # for saving logs
plots_path = "results/plots"
os.makedirs(plots_path, exist_ok=True) # for saving plots

# -- CONSTANTS DEFINITION --
# Class names for Flowers102 dataset
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed for reproducibility
SEED = 42

# -- REPRODUCIBILITY SETUP --
# Function to set random seed for reproducibility
def set_seed(seed):
    """Set random seed for reproducibility
    Args:
        seed (int): The seed value to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed for reproducibility
set_seed(SEED)

# Worker initialization function for DataLoader
def worker_init_fn(worker_id):
    np.random.seed(SEED + worker_id)
    random.seed(SEED + worker_id)

### Data preparation
We define utility functions for:
- **`load_split()`** and **`get_data()`**: Load the Flowers102 train, validation and test splits from torchvision
- **`split_classes()`**: Split the 102 classes into base (ids 0-50) and novel (ids 51-101)
- **`split_data()`**: Partition the samples of a split into base and novel subsets

This simulates the base-to-novel setting: the model is adapted on 51 base classes and must also generalize to 51 novel classes that it never sees during training.
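
As a quick illustration of the split logic, here is a minimal sketch on made-up labels. The helper names `halve_classes` and `partition_indices` and the toy label list are assumptions for this example only; the notebook itself uses `split_classes()` and `split_data()` on the real Flowers102 objects.

```python
def halve_classes(labels):
    """Return (base, novel): first and second half of the sorted unique class ids."""
    unique = sorted(set(labels))
    mid = len(unique) // 2
    return unique[:mid], unique[mid:]


def partition_indices(labels, base_classes):
    """Return the sample indices of base-class samples and of all remaining (novel) samples."""
    base = set(base_classes)
    base_idx = [i for i, y in enumerate(labels) if y in base]
    novel_idx = [i for i, y in enumerate(labels) if y not in base]
    return base_idx, novel_idx


# Hypothetical toy labels: 6 classes, 8 samples
toy_labels = [0, 3, 1, 5, 2, 4, 0, 5]
toy_base, toy_novel = halve_classes(toy_labels)
toy_base_idx, toy_novel_idx = partition_indices(toy_labels, toy_base)

print(toy_base, toy_novel)          # [0, 1, 2] [3, 4, 5]
print(toy_base_idx, toy_novel_idx)  # [0, 2, 4, 6] [1, 3, 5, 7]
```

The index lists are what one would pass to `torch.utils.data.Subset` to obtain the base and novel datasets.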

In [4]:
# -- DATA PREPARATION FUNCTIONS --
# Load specific split of Flowers102 dataset, with given transformation
def load_split(split, transform):
    """Load Flowers102 dataset split with given transformation.
    Args:
        split (str): One of "train", "val", or "test".
        transform (callable): Transformation to apply to the images.
    Returns:
        torchvision.datasets.Flowers102: The requested dataset split.
    """
    return torchvision.datasets.Flowers102(root=data_path, split=split, download=True, transform=transform)

# Load Flowers102 dataset and return train, val, test sets
def get_data(transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        transform (callable, optional): Transformation to apply to the images. Defaults to None.
    Returns:
        tuple: (train_set, val_set, test_set) as torchvision.datasets.Flowers102 instances.
    """
    train = load_split("train", transform)
    val = load_split("val", transform)
    test = load_split("test", transform)

    return train, val, test

# Split dataset classes into base and novel classes
def split_classes(dataset):
    """Return base and novel class id lists using the actual labels present in the dataset.
    Args:
        dataset (torchvision.datasets.Flowers102): The dataset to split classes from.
    Returns:
        tuple: (base_classes, novel_classes) as lists of class ids.
    """
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = getattr(dataset, "labels", None)

    if labels is None and hasattr(dataset, "_labels"):
        labels = dataset._labels

    if labels is None:
        raise ValueError("Could not find labels on dataset (checked 'targets','labels','_labels').")

    unique_labels = sorted(set(labels))
    num_classes = len(unique_labels)
    mid = num_classes // 2

    # Split classes into base and novel (first half and second half)
    base_classes = unique_labels[:mid]
    novel_classes = unique_labels[mid:]

    return base_classes, novel_classes

# Split dataset into base and novel datasets
def split_data(dataset, base_classes):
    """Split dataset into base and novel datasets based on provided base classes.
    Args:
        dataset (torchvision.datasets.Flowers102): The dataset to split.
        base_classes (list): List of class ids considered as base classes.
    Returns:
        tuple: (base_dataset, novel_dataset) as torch.utils.data.Subset instances.
    """
    base_categories_samples = []
    novel_categories_samples = []
    base_set = set(base_classes)

    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)

    return base_dataset, novel_dataset

In [1]:
# Load CLIP model and preprocessing
model, preprocess = clip.load("ViT-B/16", device=device)

# Load dataset and split into base and novel datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# Get base and novel classes from the test set
base_classes, novel_classes = split_classes(test_set)
classes = base_classes + novel_classes

# Get class names
base_class_names = [CLASS_NAMES[i] for i in base_classes]
print(f"Base classes ({len(base_classes)}): {base_class_names}")
novel_class_names = [CLASS_NAMES[i] for i in novel_classes]
print(f"Novel classes ({len(novel_classes)}): {novel_class_names}")
print(f"All classes ({len(classes)}): {[CLASS_NAMES[i] for i in classes]}")

# Create base and novel datasets
base_train_set, _ = split_data(train_set, base_classes)
base_val_set, _ = split_data(val_set, base_classes)
base_test_set, novel_test_set = split_data(test_set, base_classes)


## Harmonic Mean (HM)

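The harmonic mean of base and novel accuracy is the standard metric in few-shot adaptation papers.

Formula: $\text{HM} = \frac{2}{1/\text{acc}_{\text{base}} + 1/\text{acc}_{\text{novel}}} = \frac{2 \cdot \text{acc}_{\text{base}} \cdot \text{acc}_{\text{novel}}}{\text{acc}_{\text{base}} + \text{acc}_{\text{novel}}}$

**Why HM instead of arithmetic mean?**
- HM is pulled toward the lower of the two values, so a weak accuracy on one side cannot be hidden by a strong one on the other
- If base=90% and novel=50%: arithmetic mean=70%, HM=64.3%
- Maximizing HM therefore forces the model to balance both accuracies

**Goal:** maximize HM between `base_acc_cocoop` and `novel_acc_cocoop`.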
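
The worked numbers above can be checked with a few lines. This is a minimal sketch: the function `hm` mirrors the notebook's `harmonic_mean`, and the accuracy pairs are illustrative values, not measured results.

```python
def hm(a, b):
    """Harmonic mean of two non-negative numbers (0.0 if both are zero)."""
    return 2 * a * b / (a + b) if (a + b) > 0 else 0.0


# Two hypothetical models with the same arithmetic mean (0.70)
unbalanced = (0.90, 0.50)
balanced = (0.70, 0.70)

for base_acc_example, novel_acc_example in (unbalanced, balanced):
    arithmetic = (base_acc_example + novel_acc_example) / 2
    harmonic = hm(base_acc_example, novel_acc_example)
    print(f"base={base_acc_example:.2f} novel={novel_acc_example:.2f} "
          f"mean={arithmetic:.3f} HM={harmonic:.3f}")
# base=0.90 novel=0.50 mean=0.700 HM=0.643
# base=0.70 novel=0.70 mean=0.700 HM=0.700
```

Both pairs have the same arithmetic mean, but only the balanced one keeps its harmonic mean at 0.700.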


In [5]:
# Harmonic Mean Calculation
def harmonic_mean(a, b):
    return 2 * a * b / (a + b) if (a + b) > 0 else 0.0

## The Baseline: CLIP

In [None]:
@torch.no_grad()
def test(model, dataset, classes, batch_size, device, label=""):
    """Evaluate CLIP model on given dataset and classes.
    Args:
        model (torch.nn.Module): The CLIP model.
        dataset (torch.utils.data.Dataset): The dataset to evaluate on.
        classes (list): List of class ids to consider.
        batch_size (int): Batch size for DataLoader.
        device (str): Device to run the evaluation on.
        label (str, optional): Label for progress bar. Defaults to "".
    Returns:
        float: Accuracy of the model on the given dataset and classes.
    """
    # Set model to evaluation mode
    model.eval()

    # Remap original class ids to contiguous ids starting from zero
    class_ids = {cls: id for id, cls in enumerate(classes)}

    # Apply and tokenize standard clip sentences
    text_inputs = clip.tokenize([f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in classes]).to(device)

    # Encode text features and normalize
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Create dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2, worker_init_fn=worker_init_fn)

    # Compute accuracy of the model
    correct_predictions = 0
    for image, target in tqdm(dataloader, desc=f"Evaluating on {label}", leave=False):
        target = torch.Tensor([class_ids[t.item()] for t in target]).long()
        
        image = image.to(device)
        target = target.to(device)

        # Encode image features and normalize
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Predict class by finding the text feature with highest similarity
        predicted_class = (image_features @ text_features.T).argmax(dim=-1)
        correct_predictions += (predicted_class == target).sum().item()

    accuracy = correct_predictions/len(dataset)

    return accuracy

print("\nComputing CLIP zero-shot accuracy on base and novel classes...")
base_acc = test(model=model, dataset=base_test_set, classes=base_classes, batch_size=128, device=device, label="base classes")
novel_acc = test(model=model, dataset=novel_test_set, classes=novel_classes, batch_size=128, device=device, label="novel classes")
hm = harmonic_mean(base_acc, novel_acc)
print("\nComputation done.\n")

print(f"Zero-shot accuracy on base classes: {base_acc*100:.2f}%")
print(f"Zero-shot accuracy on novel classes: {novel_acc*100:.2f}%")
print(f"Harmonic Mean: {hm*100:.2f}%")

## Our Approach: Proto-guided CoCoOp with Knowledge Distillation

### Overview

### CoCoOp

### Knowledge Distillation

### Prototypes Generation

We construct **class prototypes** from CLIP image embeddings of training samples.

**Key Design Choices:**
- Use **frozen CLIP** (not the adapted model) to preserve zero-shot knowledge
- Compute prototypes from **both normal and augmented samples** for better coverage
- **L2-normalize** embeddings before averaging and after

**At Inference:**
- Compute prototype similarity: $\text{sim}_{\text{proto}}(x, c) = \frac{f(x) \cdot p_c}{\|f(x)\| \|p_c\|}$
- Fuse with CoCoOp logits on the base classes: $\text{logits}_{\text{final}} = \text{logits}_{\text{CoCoOp}} + \alpha \cdot \text{logits}_{\text{proto}}$
- The fusion weight $\alpha$ controls how strongly the prototype-based prediction corrects the prompt-based one; novel classes have no prototypes and keep their CoCoOp logits
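
The prototype pipeline can be sketched on toy data as follows. This is only an illustration under stated assumptions: it uses NumPy instead of PyTorch, random vectors instead of CLIP embeddings, hypothetical helper names (`l2_normalize`, `build_prototype_matrix`, `fuse_logits`), and the additive fusion that `ProtoCoCoOp.forward` applies (`logits + alpha * proto_logits`).

```python
import numpy as np


def l2_normalize(x):
    """Scale vectors along the last axis to unit L2 norm."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def build_prototype_matrix(features, labels, class_ids):
    """Normalize embeddings, average them per class, then normalize the means."""
    feats = l2_normalize(features)
    protos = [feats[labels == c].mean(axis=0) for c in class_ids]
    return l2_normalize(np.stack(protos))


def fuse_logits(prompt_logits, image_features, prototypes, alpha, scale=1.0):
    """Add alpha-weighted cosine similarities to the prototypes onto the prompt logits."""
    proto_logits = scale * l2_normalize(image_features) @ prototypes.T
    return prompt_logits + alpha * proto_logits


rng = np.random.default_rng(0)
centers = np.eye(3, 8)                       # three orthogonal toy "class directions"
toy_proto_labels = np.repeat(np.arange(3), 5)
toy_features = centers[toy_proto_labels] + 0.05 * rng.normal(size=(15, 8))

toy_prototypes = build_prototype_matrix(toy_features, toy_proto_labels, class_ids=[0, 1, 2])

query = centers[[1]]                         # one query embedding of class 1
uninformative_logits = np.zeros((1, 3))      # prompt logits that cannot decide
fused = fuse_logits(uninformative_logits, query, toy_prototypes, alpha=0.2)
print(toy_prototypes.shape, fused.argmax(axis=-1))  # (3, 8) [1]
```

With uninformative prompt logits, the prototype term alone decides the prediction, which is the role the residual is meant to play for base classes.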

In [None]:
# Data augmentation transform for prototype construction
aug_view_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.Lambda(lambda im: im.convert("RGB")),
    torchvision.transforms.RandomCrop(224),  # output is already 224x224, so the former CenterCrop(224) was a no-op
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomRotation(30),
    torchvision.transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                     (0.26862954, 0.26130258, 0.27577711)),
])

# Class to apply transform to an element of the dataset
class TransformView(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __len__(self):
        return len(self.subset)
    def __getitem__(self, idx):
        img, y = self.subset[idx]
        img = self.transform(img)
        
        return img, y

# Build prototypes from augmented dataset
@torch.no_grad()
def build_prototypes(model, dataset, base_classes, device='cuda'):
    model.eval()
    
    # Collect embeddings per class
    embeddings_per_class = {c: [] for c in base_classes}
    
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=False, num_workers=0
    )
    
    print(f"Extracting embeddings from {len(dataset)} samples...")
    
    for images, labels in tqdm(dataloader, desc="Building Prototypes"):
        images = images.to(device)
        
        # Get CLIP image features
        features = model.encode_image(images)
        features = features / features.norm(dim=-1, keepdim=True)  # L2 normalize
        
        for feat, label in zip(features, labels):
            label_id = label.item()
            if label_id in embeddings_per_class:
                embeddings_per_class[label_id].append(feat.cpu())
    
    # Compute mean prototype per class
    prototypes = {}

    for cls_id in base_classes:
        if len(embeddings_per_class[cls_id]) == 0:
            print(f"Warning: no samples for class {cls_id}")
            continue

        class_embeddings = torch.stack(embeddings_per_class[cls_id])
        prototype = class_embeddings.mean(dim=0).to(device)
        prototype = prototype / prototype.norm()

        prototypes[cls_id] = prototype

    # Create matrix for efficient inference (ordered by base_classes)
    prototype_matrix = torch.stack([prototypes[c] for c in base_classes]).to(device)
    
    print(f"Built {len(prototypes)} prototypes | Matrix shape: {prototype_matrix.shape}")
    
    return prototypes, prototype_matrix  # matrix of shape (num_base_classes, feature_dim)

In [16]:
# Load raw train dataset (PIL images)
train_raw = load_split("train", transform=None)

# Build base subset indices on the same object (= avoid mismatched _labels across dataset instances)
base_set = set(base_classes)
base_idx = [i for i, y in enumerate(train_raw._labels) if y in base_set]  # uses Flowers102._labels
base_train_raw = torch.utils.data.Subset(train_raw, base_idx)

# Define transforms for original and augmented views
orig_view = TransformView(base_train_raw, preprocess)

num_samples = 10  # number of augmented views per original image
views = [orig_view] + [TransformView(base_train_raw, aug_view_transform) for _ in range(num_samples)]

# Create the prototype pool by concatenating all views
proto_pool = torch.utils.data.ConcatDataset(views)

print("N =", len(orig_view), "pool =", len(proto_pool))

# Build prototypes using frozen CLIP
prototypes, prototype_matrix = build_prototypes(
    model=model,
    dataset=proto_pool,
    base_classes=base_classes,
    device=device
)

N = 510 pool = 5610
Extracting embeddings from 5610 samples...


Building Prototypes: 100%|██████████| 88/88 [00:46<00:00,  1.88it/s]

Built 51 prototypes | Matrix shape: torch.Size([51, 512])





### Implementation

**Components:**
1. **Context Vectors (V):** `n_ctx` learnable vectors (8 in the configuration used below).
   - Shape: `(n_ctx, 512)`
   - Initialized: Gaussian noise with mean 0 and standard deviation 0.02, or the token embeddings of `ctx_init` when an initialization string is given
   - Function: Provide the base context for the prompt.

2. **Meta-Network (Bias Generator):**
   - Architecture: Linear(512->32) -> ReLU -> Linear(32->512)
   - Input: Image Features `(Batch, 512)`
   - Output: Bias `(Batch, 512)` added to Context Vectors.
   - **Note:** The CoCoOp paper denotes the meta-network output by $\pi$; here it is implemented as an **additive bias** shared by all context vectors of an image.

3. **Class Embeddings:**
   - Pre-computed token embeddings of each prompt: the start-of-text prefix and the suffix holding the class name, the final period, the end-of-text token and padding.
   - Fixed during training.

**Forward Pass (Vectorized):**
Instead of looping through images, we broadcast tensors to shape `(Batch, Num_Classes, Sequence_Length, Dim)`:
1. **Compute Bias:** $Bias = MetaNet(Image)$
2. **Shift Context:** $Ctx_{new} = Ctx_{base} + Bias$ (Broadcasting over context positions and classes)
3. **Concatenate:** $[Prefix] + [Ctx_{new}] + [Suffix]$ (All in parallel)
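
The three broadcasting steps can be sketched with NumPy on toy shapes. All sizes and the random tensors below are assumptions for illustration; in the notebook the same operations are done with PyTorch tensors inside `PromptLearner.forward`.

```python
import numpy as np

# Toy sizes (assumptions): 2 images, 3 classes, 4 context tokens, width 8, 5 suffix tokens
batch, n_cls, n_ctx, dim, n_suffix = 2, 3, 4, 8, 5
rng = np.random.default_rng(0)

ctx = rng.normal(scale=0.02, size=(n_ctx, dim))    # shared learnable context vectors
bias = rng.normal(size=(batch, dim))               # stand-in for MetaNet(image features)
prefix = rng.normal(size=(n_cls, 1, dim))          # start-of-text embedding per class
suffix = rng.normal(size=(n_cls, n_suffix, dim))   # class tokens, EOT and padding

# Steps 1-2: one bias per image, added to every context position
ctx_shifted = ctx[None, :, :] + bias[:, None, :]   # (batch, n_ctx, dim)

# Step 3: expand over classes / images and concatenate along the sequence axis
ctx_expanded = np.broadcast_to(ctx_shifted[:, None], (batch, n_cls, n_ctx, dim))
prefix_expanded = np.broadcast_to(prefix[None], (batch, n_cls, 1, dim))
suffix_expanded = np.broadcast_to(suffix[None], (batch, n_cls, n_suffix, dim))
toy_prompts = np.concatenate([prefix_expanded, ctx_expanded, suffix_expanded], axis=2)

print(toy_prompts.shape)  # (2, 3, 10, 8)
```

Every image gets its own set of `n_cls` prompts, which is why the text encoder later processes `Batch * Num_Classes` sequences per step.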

In [None]:
# Text Encoder module definition
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.ln_final(x).type(self.dtype)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x

# Prompt Learner module definition
class PromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=4, ctx_init=None, device='cuda'):
        super().__init__()
        self.dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0] # Dimension of context vectors
        vis_dim = clip_model.visual.output_dim # Dimension of visual features
        self.n_cls = len(classnames)
        self.n_ctx = n_ctx
        self.device = device

        # Meta network to generate context bias from visual features
        hidden_dim = vis_dim // 16
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, hidden_dim)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(hidden_dim, ctx_dim))
        ])).to(device)
        
        # Context Initialization
        if ctx_init: # If context initialization is provided (i.e. a string)
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            self.n_ctx = n_ctx  # keep the attribute in sync with the initialization string
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).to(self.dtype)
            ctx_vectors = embedding[0, 1:1+n_ctx, :]
            prompt_prefix = ctx_init
        else:
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=torch.float32)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)
        
        # Context vectors as learnable parameters
        self.ctx = nn.Parameter(ctx_vectors)
        
        ref_classnames = [name.replace("_", " ") for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in ref_classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).to(self.dtype)
            
        self.register_buffer("token_prefix", embedding[:, :1, :]) # Register prefix tokens as buffer (non-learnable)
        self.register_buffer("token_suffix", embedding[:, 1+n_ctx:, :]) # Register suffix tokens as buffer (non-learnable)
        self.tokenized_prompts = tokenized_prompts

    def forward(self, im_features):
        batch_size = im_features.shape[0]
        ctx = self.ctx.to(self.dtype).unsqueeze(0)
        bias = self.meta_net(im_features).unsqueeze(1)
        
        ctx_shifted = ctx + bias  # Shift context by adding bias from meta network
        
        prefix = self.token_prefix.unsqueeze(0).expand(batch_size, -1, -1, -1)
        suffix = self.token_suffix.unsqueeze(0).expand(batch_size, -1, -1, -1)
        ctx_expanded = ctx_shifted.unsqueeze(1).expand(-1, self.n_cls, -1, -1)
        
        return torch.cat([prefix, ctx_expanded, suffix], dim=2) # (batch, n_cls, n_tokens, dim)

# ProtoCoCoOp model definition, extending CoCoOp with optional prototype residuals at inference time
class ProtoCoCoOp(nn.Module):
    def __init__(self, clip_model, classnames, base_ids, n_ctx=4, ctx_init=None, device='cuda'):
        super().__init__()
        self.logit_scale = clip_model.logit_scale
        self.clip_model = clip_model
        self.dtype = self.clip_model.dtype
        self.base_ids = torch.tensor(base_ids, device=device)
        self.device = device

        self.image_encoder = self.clip_model.visual
        self.text_encoder = TextEncoder(self.clip_model)
        self.prompt_learner = PromptLearner(self.clip_model, classnames, n_ctx, ctx_init, device)

        self.tokenized_prompts = self.prompt_learner.tokenized_prompts

        self.prototype_matrix = None
        self.alpha = None            

    def set_prototypes(self, prototype_matrix, alpha=0.2):
        self.prototype_matrix = prototype_matrix.to(self.device).type(self.dtype)
        self.alpha = alpha

    def forward(self, image, use_prototypes=False):
        image = image.to(self.device).type(self.dtype)
        image_features = self.image_encoder(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        prompts = self.prompt_learner(image_features)
        B, C, T, D = prompts.shape
        prompts = prompts.reshape(B * C, T, D).type(self.dtype)

        tokenized = self.tokenized_prompts.to(prompts.device).repeat(B, 1)

        text_features = self.text_encoder(prompts, tokenized)
        text_features = text_features.reshape(B, C, -1)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logits = self.logit_scale.exp() * (image_features.unsqueeze(1) @ text_features.transpose(1, 2)).squeeze(1)

        # Prototype fusion (inference only)
        if use_prototypes and self.prototype_matrix is not None:
            # Compute prototype logits
            proto_logits = self.logit_scale.exp() * (image_features @ self.prototype_matrix.T)

            # Fuse only base classes
            logits_base = logits[:, self.base_ids]
            logits[:, self.base_ids] = logits_base + self.alpha * proto_logits

        return logits

## Training and Evaluation

The `CoCoOpTrainer` class manages:

**1. Initialization:**
- Freeze CLIP (`requires_grad=False`)
- Precompute the zero-shot CLIP text features used as the teacher for knowledge distillation
- Build the `ProtoCoCoOp` model (image encoder, `PromptLearner`, text encoder)
- Configure the SGD optimizer (prompt learner parameters only) and a cosine learning-rate schedule

**2. train():**
- Forward: Image encoder + PromptLearner + Text encoder
- **Critical step:** Encode soft prompts through the text transformer
  - Add positional embeddings
  - Pass through CLIP's transformer
  - Apply the final layer norm
  - Extract the feature at the end-of-text (EOT) token position
  - Apply the text projection
- Compute loss: Cross-entropy on base classes, plus the distillation loss against zero-shot CLIP when KD is enabled
- Backward: Gradients flow only into the PromptLearner
- Return: Average loss of the epoch

**3. test() with Prototype Fusion:**
- Same forward procedure as training
- Optionally fuse the CoCoOp logits of the base classes with prototype similarity scores
- Fusion formula: $\text{logits} = \text{logits}_{\text{CoCoOp}} + \alpha \cdot \text{logits}_{\text{prototype}}$
- Compute accuracy and average loss on any dataset (base or novel)

**Important note:** We don't use `model.encode_text()` on soft prompts
because that method expects integer tokens, not embeddings.
We manually forward through the text transformer.
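
When forwarding manually, the sentence feature has to be read at the right position. `TextEncoder.forward` does this with `tokenized_prompts.argmax(dim=-1)`, which works because CLIP's end-of-text token has the largest id in the vocabulary. A minimal NumPy sketch of that indexing follows; the word token ids and the random hidden states are made up for illustration.

```python
import numpy as np

SOT_ID, EOT_ID = 49406, 49407  # CLIP start-of-text / end-of-text ids (EOT is the largest id)

# Two toy tokenized prompts, zero-padded to length 8 (word ids are placeholders)
toy_tokenized = np.array([
    [SOT_ID, 320, 1125, EOT_ID, 0, 0, 0, 0],
    [SOT_ID, 320, 1125, 539, 2368, EOT_ID, 0, 0],
])

# argmax over the sequence finds the EOT position of each prompt
eot_positions = toy_tokenized.argmax(axis=-1)
print(eot_positions)  # [3 5]

# Stand-in for the transformer output after the final layer norm: (n_prompts, seq_len, width)
toy_hidden = np.random.default_rng(0).normal(size=(2, 8, 4))

# Pick one vector per prompt, at its own EOT position
eot_features = toy_hidden[np.arange(toy_hidden.shape[0]), eot_positions]
print(eot_features.shape)  # (2, 4)
```

In the notebook the selected vectors are then multiplied by `text_projection` to obtain the text features.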

In [None]:
class CoCoOpTrainer:
    def __init__(self, clip_model, classnames, base_classes, config, params, device="cuda"):
        """
        CoCoOp Trainer class for training and evaluation.

        Args:
            clip_model: Pretrained CLIP model.
            classnames: List of all class names.
            base_classes: List of base class ids.
            config: Configuration dictionary for CoCoOp.
                    Contains 'mode', 'n_ctx', 'ctx_init'.
            params: Training parameters dictionary.
                    Contains 'lr', 'momentum', 'weight_decay', 'kd_alpha', 'temperature', 'num_epochs', 'tr_batch_size', 'ts_batch_size'.
            device: Device to run the model on (default: "cuda").
        """
        self.mode = config["mode"].lower()
        if self.mode == "standard":
            self.use_proto = False
            self.use_kd = False
        elif self.mode == "kd":
            self.use_proto = False
            self.use_kd = True
        elif self.mode == "proto":
            self.use_proto = True
            self.use_kd = False
        elif self.mode == "proto_kd":
            self.use_proto = True
            self.use_kd = True
        else:
            raise ValueError(f"Invalid mode: {self.mode}. Choose from 'standard', 'kd', 'proto', 'proto_kd'.")
        
        print(f"Initialized CoCoOpTrainer in '{self.mode}' mode | use_proto={self.use_proto} | use_kd={self.use_kd}")

        self.kd_alpha = params["kd_alpha"]
        self.temperature = params["temperature"]
        self.num_epochs = params["num_epochs"]
        self.tr_batch_size = params["tr_batch_size"]
        self.ts_batch_size = params["ts_batch_size"]
        self.device = device

        # Freeze CLIP model
        self.clip_model = clip_model.float().to(device).eval()
        for p in self.clip_model.parameters():
            p.requires_grad = False

        # Precompute CLIP model's text features
        with torch.no_grad():
            prompts = [f"a photo of a {c}" for c in classnames]
            tokens = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
            text_features = self.clip_model.encode_text(tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)

        self.clip_text_features = text_features

        # Model initialization
        self.model = ProtoCoCoOp(
            self.clip_model,
            classnames,
            base_ids=base_classes,
            n_ctx=config["n_ctx"],
            ctx_init=config["ctx_init"],
            device=device
        ).to(device)

        # Optimizer definition
        self.optimizer = torch.optim.SGD(
            self.model.prompt_learner.parameters(),
            lr=params["lr"],
            momentum=params["momentum"],
            weight_decay=params["weight_decay"]
        )

        # Learning rate scheduler definition
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.num_epochs)

        # Base class ids tensor
        self.base_ids = torch.tensor(base_classes, device=device)

        # Safe label mapping
        num_total_classes = len(classnames)
        self.label_map = torch.full((num_total_classes,), -1, dtype=torch.long, device=device)
        self.label_map[self.base_ids] = torch.arange(len(base_classes), device=device)

    # Knowledge Distillation Loss computation
    def compute_kd_loss(self, student_logits, teacher_logits):
        T = self.temperature

        student_log_probs = F.log_softmax(student_logits / T, dim=-1)
        teacher_probs = F.softmax(teacher_logits / T, dim=-1)

        return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (T ** 2)
    
    # Training function
    def train(self, dataset):
        """
        Trains the model for one epoch.

        Args:
            dataset: Training dataset.
        Returns:
            Average training loss over the epoch.
        """
        self.model.train()

        total_loss = 0.0
        total_samples = 0

        # Initialize Dataloader
        train_loader = DataLoader(dataset, batch_size=self.tr_batch_size, shuffle=True, num_workers=1, worker_init_fn=worker_init_fn)

        for images, labels in tqdm(train_loader, desc=f"Training [{self.mode}]"):
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass through the model without prototype fusion
            logits = self.model(images, use_prototypes=False)

            # Compute Cross-Entropy Loss on base classes
            base_logits = logits[:, self.base_ids]
            targets = self.label_map[labels]

            loss_ce = F.cross_entropy(base_logits, targets)

            # Compute Knowledge Distillation Loss if enabled
            if self.use_kd:
                with torch.no_grad():
                    img_feat = self.model.clip_model.encode_image(images)
                    img_feat /= img_feat.norm(dim=-1, keepdim=True)

                    teacher_logits = (self.model.clip_model.logit_scale.exp() * img_feat @ self.clip_text_features.T)

                loss_kd = self.compute_kd_loss(logits, teacher_logits)

                loss = (1 - self.kd_alpha) * loss_ce + self.kd_alpha * loss_kd
            else:
                loss = loss_ce

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * images.size(0)
            total_samples += images.size(0)

        self.scheduler.step()

        return total_loss/total_samples
    
    # Evaluation function
    @torch.no_grad()
    def test(self, dataset, class_ids, use_prototypes=False):
        """
        Evaluates the model on the given dataset. 

        Args:
            dataset: Dataset to evaluate on.
            class_ids: List of class ids to consider during evaluation.
            use_prototypes: If True, fuse prototype logits into the base-class logits (default: False).

        Returns:
            Tuple of (accuracy, average loss) over the dataset.
        """
        # Set model to evaluation mode
        self.model.eval()

        # Build mapping
        class_ids = torch.tensor(class_ids, device=self.device)
        mapping = torch.full((len(self.clip_text_features),), -1, dtype=torch.long, device=self.device)
        mapping[class_ids] = torch.arange(len(class_ids), device=self.device)

        # Initialize Dataloader
        test_loader = DataLoader(dataset, batch_size=self.ts_batch_size, shuffle=False, num_workers=2, worker_init_fn=worker_init_fn)

        correct_predictions = 0
        predictions = 0
        total_loss = 0.0

        # Evaluation loop
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass through the model with optional prototype fusion
            logits = self.model(images, use_prototypes=use_prototypes)
                
            logits = logits[:, class_ids]

            targets = mapping[labels]

            preds = logits.argmax(dim=1)
            loss = F.cross_entropy(logits, targets)

            correct_predictions += (preds == targets).sum().item()
            predictions += images.size(0)
            total_loss += loss.item() * images.size(0)

        return (correct_predictions/predictions, total_loss/predictions)
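The label remapping used in `test` can be looked at in isolation. The snippet below is a minimal sketch with hypothetical sizes (10 classes, 4 of them evaluated) and random stand-in logits. It only illustrates how a lookup tensor turns global class ids into contiguous targets that match the sliced logit columns.

```python
import torch

# Hypothetical setup: 10 classes overall, 4 of them evaluated
num_classes = 10
class_ids = torch.tensor([2, 5, 7, 9])

# Lookup table: global id -> local index, -1 for classes outside the subset
mapping = torch.full((num_classes,), -1, dtype=torch.long)
mapping[class_ids] = torch.arange(len(class_ids))

labels = torch.tensor([5, 9, 2, 7, 5])   # global labels from the dataset
targets = mapping[labels]                # local targets for cross-entropy

logits = torch.randn(len(labels), num_classes)  # stand-in for model output
subset_logits = logits[:, class_ids]            # keep only evaluated columns

# A -1 target would mean a label outside class_ids reached the loss
assert (targets >= 0).all(), "a label falls outside class_ids"
print(targets.tolist())  # [1, 3, 0, 2, 1]
```

Because the table lives on the same device as the labels, the remapping is a single indexing operation per batch rather than a Python loop.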

### Training

We train the `PromptLearner` on **base classes only** for up to **15 epochs**, with early stopping on validation accuracy (patience of 5 epochs).

**Hyperparameters** (as set in the configuration cell below):
- **Context length (`n_ctx`):** 8
- **Training batch size:** 1
- **Testing batch size:** 32
- **Learning rate:** 0.002 (SGD)
- **Momentum:** 0.9
- **Weight decay:** 5e-4
- **Epochs:** up to 15

**What happens:**
- The `PromptLearner` adapts its `n_ctx` context vectors to the Flowers102 dataset.
- The `MetaNetwork` learns to inject an image-specific bias into the context.
- **Optimization:** A GPU-resident label lookup table speeds up target mapping.

**Expected output** (approximate; exact values depend on the configuration):
- Initial loss: ~2.5 - 3.5
- Final loss: ~0.5 - 1.0
- Training time: ~2-4 minutes on GPU
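In the `kd` and `proto_kd` modes, the training loss blends cross-entropy with a distillation term weighted by `kd_alpha`. The sketch below assumes the common temperature-scaled KL-divergence form of the distillation loss. The function `kd_loss_sketch` is a hypothetical stand-in, not the notebook's `compute_kd_loss`, and the logits are random placeholders.

```python
import torch
import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_logits, temperature=2.0):
    """Temperature-scaled KL divergence; an assumed form of the KD term."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures
    kl = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
    return kl * temperature ** 2

torch.manual_seed(0)
student = torch.randn(4, 5, requires_grad=True)  # stand-in for CoCoOp logits
teacher = torch.randn(4, 5)                      # stand-in for zero-shot CLIP logits
targets = torch.tensor([0, 1, 2, 3])

kd_alpha = 0.3
loss_ce = F.cross_entropy(student, targets)
loss_kd = kd_loss_sketch(student, teacher, temperature=2.0)

# Same convex combination as in the training step
loss = (1 - kd_alpha) * loss_ce + kd_alpha * loss_kd
loss.backward()
```

With `kd_alpha = 0` the objective reduces to plain cross-entropy, which matches the `standard` mode.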

In [None]:
# Configuration
CURRENT_MODE = "standard" # Select one of: "standard", "proto", "kd", "proto_kd"

config = {
    "mode": CURRENT_MODE,
    "n_ctx": 8,
    "ctx_init": None
}

# Define training hyperparameters
params = {
    "lr": 0.002,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "tr_batch_size": 1,    # Training batch size
    "ts_batch_size": 32,   # Testing batch size
    "patience_init": 5,
    "num_epochs": 15,
    "proto_alpha": 0.2,    # Weight for prototype logits in ProtoCoCoOp
    "kd_alpha": 0.3,       # Weight for KD Loss
    "temperature": 2.0     # Softmax temperature for KD
}

# Initialization
trainer = CoCoOpTrainer(
    clip_model=model,            # pretrained CLIP model
    classnames=CLASS_NAMES,      # all class names
    base_classes=base_classes,   # base class ids
    config=config,
    params=params,
    device=device,
)

# Results storage
results = {
    "mode": config["mode"],
    "sampled_epochs": [],
    "val_accs": [],
    "best_val_acc": 0.0,
    "losses_train": [],
    "losses_val": [],
}

# Initialize early stopping patience
patience = params["patience_init"]

# Training loop
print("\n" + "="*70)
print(f"TRAINING LOOP (Patience: {params['patience_init']}) | Mode: {config['mode'].upper()}")
print("="*70)

for epoch in range(trainer.num_epochs):
    results["sampled_epochs"].append(epoch)

    # Training Step
    train_loss = trainer.train(base_train_set)
    print(f"\nEpoch {epoch+1}/{trainer.num_epochs} | Train Loss: {train_loss:.4f}")

    results["losses_train"].append(np.asarray(train_loss).mean())

    # Evaluation Step
    val_acc, val_loss = trainer.test(base_val_set, base_classes, use_prototypes=False)
    print(f" Validation Acc: {val_acc*100:.2f}% | Val Loss: {np.asarray(val_loss).mean():.4f}")
   
    results["val_accs"].append(val_acc)
    results["losses_val"].append(np.asarray(val_loss).mean())

    # Early Stopping and checkpointing
    if val_acc > results["best_val_acc"]:
        results["best_val_acc"] = val_acc
        patience = params["patience_init"] # Reset patience
        
        # Save model data
        save_path = os.path.join(models_path, f"best_model_{config['mode']}.pth")
        model_data = {
            "model_state_dict": trainer.model.state_dict(),
            "optimizer_state_dict": trainer.optimizer.state_dict(),
            "epoch": epoch,
            "config": config,
            "params": params,
            "results": results
        }
        torch.save(model_data, save_path)
        print(f"[BEST MODEL SAVED] Acc: {val_acc*100:.2f}%")
    else:
        patience -= 1
        print(f" [No Improvement | Patience left: {patience}]")
        if patience == 0:
            print(f"\nEARLY STOPPING TRIGGERED at epoch {epoch+1}!")
            break

print("="*70)
print(f"Training complete. Best Val Acc: {results['best_val_acc']*100:.2f}%")
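The patience logic in the loop above can be checked without training anything. The following sketch replays a made-up sequence of validation accuracies through the same rule: reset patience on improvement, decrement it otherwise, stop when it runs out. The helper name `run_with_early_stopping` and the accuracy values are illustrative only.

```python
def run_with_early_stopping(val_accs, patience_init=5):
    """Replay validation accuracies; return (best_acc, best_epoch, last_epoch_run)."""
    best_acc, best_epoch = 0.0, None
    patience = patience_init
    for epoch, acc in enumerate(val_accs):
        if acc > best_acc:
            # Strict improvement: record it and reset patience
            best_acc, best_epoch = acc, epoch
            patience = patience_init
        else:
            patience -= 1
            if patience <= 0:
                return best_acc, best_epoch, epoch
    return best_acc, best_epoch, len(val_accs) - 1

# Illustrative values: the late 0.90 is never reached
accs = [0.70, 0.75, 0.74, 0.75, 0.73, 0.72, 0.71, 0.90]
print(run_with_early_stopping(accs, patience_init=5))  # (0.75, 1, 6)
```

Note that a tie with the best accuracy counts as no improvement, since the comparison is strict.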

### Training results logging and plotting

In [None]:
# Plots training results
def plot_results(results, plots_path):
    """Plot training and validation loss per epoch and save the figure."""
    plt.figure()
    plt.plot(results["sampled_epochs"], results["losses_train"], label="Training Loss", marker="o")
    plt.plot(results["sampled_epochs"], results["losses_val"], label="Validation Loss", marker="x")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    # Use the mode stored in results rather than the global config
    filename = f"{results['mode']}_training_plot.png"
    filepath = os.path.join(plots_path, filename)
    plt.savefig(filepath)
    plt.close()

# Logs training results
def log_results(params, config, results, log_path):
    # Log and save training results
    log_fields = [
        "model_type",
        "num_epochs",
        "lr",
        "tr_batch_size",
        "ts_batch_size",
        "momentum",
        "weight_decay",
        "kd_alpha",
        "proto_alpha",
        "temperature",
        "n_ctx",
        "base_accuracy"
    ]
    if not os.path.exists(log_path):
        with open(log_path, mode="w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=log_fields)
            writer.writeheader()
    with open(log_path, mode="a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=log_fields)
        writer.writerow({
            "model_type": results["mode"],
            "num_epochs": len(results["sampled_epochs"]),
            "lr": params["lr"],
            "tr_batch_size": params["tr_batch_size"],
            "ts_batch_size": params["ts_batch_size"],
            "momentum": params["momentum"],
            "weight_decay": params["weight_decay"],
            "kd_alpha": params.get("kd_alpha"),
            "proto_alpha": params.get("proto_alpha"),
            "temperature": params.get("temperature"),
            "n_ctx": config["n_ctx"],
            "base_accuracy": f"{results['best_val_acc']*100:.2f}"
        })

# Plot and log results
plot_results(results, plots_path)

log_filepath = os.path.join(logs_path, "training_log.csv")
log_results(params, config, results, log_filepath)
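Once several modes have been logged, the CSV can be read back to compare them. The sketch below parses an in-memory file with the same column names as the log and keeps the best `base_accuracy` per `model_type`. The rows are placeholders with made-up numbers, not measured results.

```python
import csv
import io

# Placeholder rows with made-up numbers, only to exercise the parsing logic
sample_log = io.StringIO(
    "model_type,num_epochs,base_accuracy\n"
    "standard,10,1.00\n"
    "kd,12,3.00\n"
    "standard,15,2.00\n"
)

best_per_mode = {}
for row in csv.DictReader(sample_log):
    acc = float(row["base_accuracy"])  # logged as a formatted string
    mode = row["model_type"]
    # Keep the highest accuracy seen so far for each mode
    if acc > best_per_mode.get(mode, float("-inf")):
        best_per_mode[mode] = acc

print(best_per_mode)  # {'standard': 2.0, 'kd': 3.0}
```

For the real log, `sample_log` would be replaced by an open handle on the file at `log_filepath`.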

### Testing

We evaluate the best checkpoint on two splits:
1. **Test Base** - CoCoOp, with prototype fusion enabled when the selected mode uses prototypes
2. **Test Novel** - CoCoOp only (no prototypes for novel classes)

The harmonic mean of the two accuracies summarizes the trade-off between base and novel performance.

**Note:** Prototypes are only available for base classes (built from training data).
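The harmonic mean of base accuracy B and novel accuracy N is 2BN / (B + N), so it is pulled toward the lower of the two values. The sketch below is one way to compute it. The name `harmonic_mean_sketch` is hypothetical (the notebook's own `harmonic_mean` helper is defined elsewhere), and the zero-sum guard is an added assumption.

```python
def harmonic_mean_sketch(base_acc, novel_acc):
    """Harmonic mean of two accuracies in [0, 1]."""
    if base_acc + novel_acc == 0:
        # Assumed convention: avoid division by zero when both are 0
        return 0.0
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)

# Illustrative inputs: the arithmetic mean would be 0.75, the harmonic mean is lower
print(round(harmonic_mean_sketch(0.9, 0.6), 4))  # 0.72
```

A model that scores well on base classes but poorly on novel ones is therefore penalized more than it would be by a plain average.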


In [None]:
# Load best model for final evaluation
best_model_path = os.path.join(models_path, f"best_model_{config['mode']}.pth")

if os.path.exists(best_model_path):
    print(f"\nLoading best model from {best_model_path}...")
    # Load model data
    # map_location lets a checkpoint saved on GPU be loaded on the current device
    model_data = torch.load(best_model_path, map_location=device, weights_only=False)

    # Re-initialize trainer and load best model state
    config = model_data["config"]
    params = model_data["params"] 

    trainer = CoCoOpTrainer(
        clip_model=model,
        classnames=CLASS_NAMES,
        base_classes=base_classes,
        config=config,
        params=params,
        device=device,
    )
    trainer.model.load_state_dict(model_data["model_state_dict"])

    print("Best model loaded successfully.")
else:
    print("Warning: Best model checkpoint not found! Using current model state.")

if trainer.use_proto:
    print("Setting prototypes for inference...")
    trainer.model.set_prototypes(prototype_matrix, alpha=params["proto_alpha"])

base_acc, _ = trainer.test(base_test_set, base_classes, use_prototypes=trainer.use_proto)
novel_acc, _ = trainer.test(novel_test_set, novel_classes, use_prototypes=False)
hm = harmonic_mean(base_acc, novel_acc)

print("\n" + "="*70)
print(f"RESULTS for MODE: {config['mode'].upper()}")
print("="*70)
print(f"  Base Accuracy:  {base_acc*100:6.2f}%")
print(f"  Novel Accuracy: {novel_acc*100:6.2f}%")
print(f"  Harmonic Mean:  {hm*100:6.2f}%")
print("="*70)

## Results and Discussion

## Conclusions

## References

- CLIP

    - Radford et al., 2021. Learning Transferable Visual Models From Natural Language Supervision

- CoOp / CoCoOp

    - Zhou et al., 2022. Learning to Prompt for Vision-Language Models (CoOp)

    - Zhou et al., 2022. Conditional Prompt Learning for Vision-Language Models (CoCoOp)

- Tip-Adapter

    - Zhang et al., 2022. Tip-Adapter: Training-Free Adaption of CLIP for Few-Shot Classification

- Proto-based CLIP adaptation

    - Zhang et al., 2022. Tip-Adapter-F (fine-tuned version)

    - Some works refer to this direction as "cache-based adaptation" or "prototype adaptation"