# Deep Learning Course Project - a.y. 2024/2025

This notebook represents our work for the 24/25 Deep Learning course project offered by the University of Trento. 

The task was few-shot adaptation of CLIP on the Flowers102 dataset.

**Authors**:
- Andrea Giampietro - 258237
- Marco Gandolfi - 258017
- Stefano Camposilvan - 257848

## Table of contents

1. [**Introduction**](#1-introduction) 

2. [**Setup**](#2-setup)

3. [**The Baseline: CLIP**](#3-the-baseline-clip)

4. [**Our Approach: ProtoCoCoOp**](#4-our-approach-protococoop)

5. [**Results and Discussion**](#5-results-and-discussion)

6. [**Conclusions**](#6-conclusions)

7. [**References**](#7-references)

## 1. Introduction

### 1.1. The Context

Vision–Language Models (VLMs) are a class of models that integrate natural language processing and computer vision to perform a wide range of tasks, including image captioning, visual question answering, and text-to-image generation. When trained on large-scale datasets, these models achieve remarkable performance across diverse domains. </br>
Among them, Contrastive Language–Image Pre-training (CLIP) in particular has demonstrated strong zero-shot capabilities, enabling it to recognize and classify images without explicit task-specific training.

However, in many real-world scenarios, large labeled datasets are unavailable: data may be scarce, expensive to obtain, or highly specialized. Furthermore, fine-grained classification tasks, where categories differ only by subtle details, pose additional challenges. Thus, zero-shot performance may not be sufficient, and task-specific adaptation becomes necessary.

In this context, Few-Shot Adaptation aims to improve generalization when only a limited number of labeled examples per class is available. The goal of such methods is to leverage prior knowledge learned during pretraining and adapt the model to new tasks with minimal supervision, mimicking the human ability to learn new concepts from only a handful of examples. Importantly, the few-shot setting requires the model not only to specialize on the Base classes using limited supervision, but also to preserve its original zero-shot generalization on the Novel ones.


### 1.2. Our Proposal

In this project, we tackle the Few-Shot Adaptation problem with the goal of improving over CLIP's zero-shot performance.

To do so, we use the Oxford Flowers102 dataset as a benchmark for fine-grained visual recognition. This dataset contains 102 flower species, many of which exhibit only subtle inter-class differences, making classification more challenging. To simulate a realistic few-shot scenario, we adopt a base–novel split in which only a small number of labeled samples (specifically, 10 shots per class) are available for the Base categories during adaptation, while the remaining classes are treated as Novel and remain unseen during training.

We then build upon CoCoOp, a prompt-learning method designed for few-shot adaptation, and propose a strategy to better balance the base–novel trade-off. Specifically, we combine:

1. **Knowledge Distillation (KD)**

We treat the original frozen CLIP model as a teacher and regularize our adapted model (the student) to remain close to CLIP’s zero-shot predictions. This is achieved through a distillation loss that aligns the student’s logits with those of the teacher. The goal is to prevent overfitting to Base classes and preserve generalization on Novel classes.

2. **Prototype-Based Residual Fusion (Inference-Time)**

For each Base class, we compute a visual prototype by averaging normalized image embeddings extracted from the few-shot training samples (including augmented views). These prototypes act as compact representations of Base-class visual structure.

At inference time, we compute the similarity between the input image embedding and each Base-class prototype, and use this similarity to add a residual logit contribution exclusively to Base classes. This mechanism strengthens discrimination among Base categories while leaving Novel predictions unaffected.

Our final approach therefore explicitly separates:

- Training-time regularization, via Knowledge Distillation to preserve zero-shot behavior.

- Inference-time enhancement, via prototype-based residual fusion to strengthen Base discrimination.

By combining these two mechanisms, we aim to improve the harmonic mean between Base and Novel accuracy, achieving a more balanced base-to-novel generalization on the Flowers102 benchmark.

## 2. Setup

This section serves as the foundation of our project and ensures reproducibility and correct execution of all experiments. </br>
The corresponding code cells include environment preparation, dataset handling, and evaluation protocol definition.

More specifically, this section covers:

- Installation of required packages and dependencies;

- Definition of directories for data storage, model checkpoints, logs, and plots;

- Initialization of constants and global parameters used throughout training and evaluation;

- Dataset loading, splitting, and preprocessing;

- Definition of evaluation metrics.

Additional information is provided in the comments of each cell.

### 2.1. Initialization

In [2]:
# Import necessary packages
import os
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

import sys
import torch
import torchvision
import numpy as np
import random
import gc
from matplotlib import pyplot as plt
import csv
from shutil import copy
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader

# Install CLIP if not already installed
try:
    import clip
    print("✓ CLIP already installed")
except Exception:
    print("Installing CLIP...")
    import subprocess, importlib
    try:
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", 
                              "git+https://github.com/openai/CLIP.git"])
    importlib.invalidate_caches()
    import clip

✓ CLIP already installed


### 2.2. Paths and Constants definition

In [None]:
# -- PATHS DEFINITION --
# Directory for dataset
data_path = "data"
os.makedirs(data_path, exist_ok=True) 

# Directories for saving results
models_path = "results/models"
os.makedirs(models_path, exist_ok=True) 
logs_path = "results/logs"
os.makedirs(logs_path, exist_ok=True) 
plots_path = "results/plots"
os.makedirs(plots_path, exist_ok=True) 

# -- CONSTANTS DEFINITION --
# Class names for Flowers102 dataset
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed for reproducibility
SEED = 42

### 2.3. Reproducibility Settings

In [None]:
# -- REPRODUCIBILITY SETUP --
# Function to set random seed for reproducibility
def set_seed(seed):
    """Set random seed for reproducibility
    Args:
        seed (int): The seed value to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed for reproducibility
set_seed(SEED)

# Worker initialization function for DataLoader
def worker_init_fn(worker_id):
    """Initialize random seed for each worker in DataLoader
    Args:
        worker_id (int): The ID of the worker.
    """
    np.random.seed(SEED + worker_id)
    random.seed(SEED + worker_id)

### 2.4. Data preparation

This section downloads and prepares the Oxford Flowers102 dataset.

The preparation consists of dividing the 102 classes of the dataset into two blocks: Base and Novel categories. </br>
The Base classes are the first 51, while the Novel classes are the remaining 51. </br>
Only 10 labeled training samples per Base class are available. </br>

In addition, training, validation, and test splits are created.

This setup simulates the few-shot adaptation protocol commonly adopted in the literature:

- The model is adapted using only the given samples from the known classes (**Base**), while **Novel categories** remain unseen during training.
- Evaluation is performed separately on Base and Novel categories to measure base-to-novel generalization.



In [None]:
# -- DATA PREPARATION FUNCTIONS --
# Load specific split of Flowers102 dataset, with given transformation
def load_split(split, transform):
    """Load Flowers102 dataset split with given transformation.
    Args:
        split (str): One of "train", "val", or "test".
        transform (callable): Transformation to apply to the images.
    Returns:
        torchvision.datasets.Flowers102: The requested dataset split.
    """
    return torchvision.datasets.Flowers102(root=data_path, split=split, download=True, transform=transform)

# Load Flowers102 dataset and return train, val, test sets
def get_data(transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        transform (callable, optional): Transformation to apply to the images. Defaults to None.
    Returns:
        tuple: (train_set, val_set, test_set) as torchvision.datasets.Flowers102 instances.
    """
    train = load_split("train", transform)
    val = load_split("val", transform)
    test = load_split("test", transform)

    return train, val, test

# Split dataset classes into base and novel classes
def split_classes(dataset):
    """Return base and novel class id lists using the actual labels present in the dataset.
    Args:
        dataset (torchvision.datasets.Flowers102): The dataset to split classes from.
    Returns:
        tuple: (base_classes, novel_classes) as lists of class ids.
    """
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = getattr(dataset, "labels", None)

    if labels is None and hasattr(dataset, "_labels"):
        labels = dataset._labels

    if labels is None:
        raise ValueError("Could not find labels on dataset (checked 'targets','labels','_labels').")

    unique_labels = sorted(set(labels))
    num_classes = len(unique_labels)
    mid = num_classes // 2

    # Split classes into base and novel (first half and second half)
    base_classes = unique_labels[:mid]
    novel_classes = unique_labels[mid:]

    return base_classes, novel_classes

# Split dataset into base and novel datasets
def split_data(dataset, base_classes):
    """Split dataset into base and novel datasets based on provided base classes.
    Args:
        dataset (torchvision.datasets.Flowers102): The dataset to split.
        base_classes (list): List of class ids considered as base classes.
    Returns:
        tuple: (base_dataset, novel_dataset) as torch.utils.data.Subset instances.
    """
    base_categories_samples = []
    novel_categories_samples = []
    base_set = set(base_classes)

    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)

    return base_dataset, novel_dataset


# -- DATA PREPARATION --
# Load CLIP model and preprocessing
model, preprocess = clip.load("ViT-B/16", device=device)

# Load dataset and split into base and novel datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# Get base and novel classes from the test set
base_classes, novel_classes = split_classes(test_set)
classes = base_classes + novel_classes

# Get class names
base_class_names = [CLASS_NAMES[i] for i in base_classes]
print(f"Base classes ({len(base_classes)}): {base_class_names}")
novel_class_names = [CLASS_NAMES[i] for i in novel_classes]
print(f"Novel classes ({len(novel_classes)}): {novel_class_names}")
print(f"All classes ({len(classes)}): {[CLASS_NAMES[i] for i in classes]}")

# Create base and novel datasets
base_train_set, _ = split_data(train_set, base_classes)
base_val_set, _ = split_data(val_set, base_classes)
base_test_set, novel_test_set = split_data(test_set, base_classes)

### 2.5. Evaluation metric: Harmonic Mean (HM)

To evaluate the model's performance correctly, it is important to consider accuracy on both the Base and the Novel categories. </br>
To account for this trade-off, we use the **Harmonic Mean (HM)**, a standard metric in the few-shot adaptation literature.

The harmonic mean between Base accuracy and Novel accuracy is defined as:

$$
    HM = \frac{2}{\frac{1}{\text{BaseAcc}} + \frac{1}{\text{NovelAcc}}}
$$

where $\text{BaseAcc}$ is the accuracy on the Base classes and $\text{NovelAcc}$ is the accuracy on the Novel ones.

The motivation behind this choice is that harmonic mean strongly penalizes imbalanced performance, as it decreases significantly when one of the two accuracies is low. </br> By considering HM as the principal evaluation metric, it is possible to lean towards balanced performance rather than optimizing one split at the expense of the other, which is particularly important in base-to-novel generalization.
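As a small worked illustration of this penalty (the accuracy values are made up for the example and are not results from this project), consider two models with the same arithmetic mean of 0.70:

$$
    HM(0.90,\ 0.50) = \frac{2 \cdot 0.90 \cdot 0.50}{0.90 + 0.50} = \frac{0.90}{1.40} \approx 0.643,
    \qquad
    HM(0.70,\ 0.70) = \frac{2 \cdot 0.70 \cdot 0.70}{0.70 + 0.70} = 0.70
$$

The imbalanced model scores lower on HM even though both average to the same value, which is the behavior we want from the metric.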

In [None]:
# Harmonic Mean Calculation
def harmonic_mean(a, b):
    """Compute the harmonic mean of two accuracies."""
    # Guard against division by zero when both a and b are zero
    return 2 * a * b / (a + b) if (a + b) > 0 else 0.0

## 3. The Baseline: CLIP

### 3.1. Contrastive Language–Image Pre-training

Contrastive Language–Image Pre-training (CLIP) [Radford et al., 2021] is a Vision–Language Model trained on large-scale image–text pairs collected from the web. </br>
Instead of learning to classify images into a fixed set of predefined categories, CLIP is trained using a contrastive objective that aligns images and their corresponding textual descriptions in a shared embedding space.

The model consists of two main components:

- An **image encoder**, which extracts visual features from an image.
- A **text encoder**, which extracts semantic features from a textual description.

Both encoders project their inputs into the same high-dimensional space. During training, CLIP learns to bring matching image–text pairs closer together in this space while pushing mismatched pairs apart.  

As a result, the model does not memorize class labels. Instead, it learns a semantic alignment between visual concepts and language, enabling it to reason about images through textual descriptions.
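The contrastive objective described above can be sketched as a symmetric cross-entropy over a batch similarity matrix. This is a minimal illustration rather than CLIP's actual training code: the function name `clip_contrastive_loss` is hypothetical, the features are random stand-ins for encoder outputs, and the logit scale is fixed at 100 as an assumption (in CLIP it is a learned parameter).

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_features, text_features, logit_scale=100.0):
    """Symmetric cross-entropy over an (N, N) image-text similarity matrix."""
    # L2-normalize so that the dot product equals cosine similarity
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # Entry (i, j) scores image i against text j
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T

    # The matching text for image i sits at index i, so targets are the diagonal
    targets = torch.arange(image_features.shape[0], device=image_features.device)
    loss_images = F.cross_entropy(logits_per_image, targets)
    loss_texts = F.cross_entropy(logits_per_text, targets)
    return (loss_images + loss_texts) / 2

torch.manual_seed(0)
feats = torch.randn(8, 512)  # random stand-ins for encoder outputs

aligned_loss = clip_contrastive_loss(feats, feats)                 # every pair matches
shuffled_loss = clip_contrastive_loss(feats, feats.roll(1, dims=0))  # every pair is mismatched
print(f"aligned: {aligned_loss.item():.4f} | shuffled: {shuffled_loss.item():.4f}")
```

When the paired embeddings coincide the loss is close to zero, while deliberately mismatched pairs produce a much larger loss, which is the signal that pulls matching pairs together and pushes mismatched pairs apart.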


### 3.2. Zero-Shot Classification with CLIP

One of CLIP’s most powerful properties is its ability to perform **zero-shot classification**, meaning it can classify images without being explicitly trained on the target dataset.

Given a set of class names $C = \{c_1, \dots, c_K\}$, each class is converted into a textual prompt, for example:

> "a photo of a {class_name}"

These prompts are encoded by the text encoder to obtain a representation for each category.  
At inference time, an input image is encoded by the image encoder, and classification is performed by computing the similarity between the image embedding and each class embedding.

The predicted label corresponds to the class whose textual representation is most similar to the image representation.

Importantly, this procedure requires **no additional training** on the downstream dataset. As long as a category can be described in natural language, CLIP can attempt to recognize it.  

This flexibility makes CLIP a powerful baseline for few-shot adaptation, as it already provides strong generalization to unseen classes.


In [None]:
@torch.no_grad()
def test(model, dataset, classes, batch_size, device, label=""):
    """Evaluate CLIP model on given dataset and classes.
    Args:
        model (torch.nn.Module): The CLIP model.
        dataset (torch.utils.data.Dataset): The dataset to evaluate on.
        classes (list): List of class ids to consider.
        batch_size (int): Batch size for DataLoader.
        device (str): Device to run the evaluation on.
        label (str, optional): Label for progress bar. Defaults to "".
    Returns:
        float: Accuracy of the model on the given dataset and classes.
    """
    # Set model to evaluation mode
    model.eval()

    # Remap original class ids to contiguous ids starting from zero
    class_ids = {cls: idx for idx, cls in enumerate(classes)}

    # Apply and tokenize standard clip sentences
    text_inputs = clip.tokenize([f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in classes]).to(device)

    # Encode text features and normalize
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Create dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2, worker_init_fn=worker_init_fn)

    # Compute accuracy of the model
    correct_predictions = 0
    for image, target in tqdm(dataloader, desc=f"Evaluating on {label}", leave=False):
        target = torch.tensor([class_ids[t.item()] for t in target], dtype=torch.long)
        
        image = image.to(device)
        target = target.to(device)

        # Encode image features and normalize
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Predict class by finding the text feature with highest similarity
        predicted_class = (image_features @ text_features.T).argmax(dim=-1)
        correct_predictions += (predicted_class == target).sum().item()

    accuracy = correct_predictions/len(dataset)

    return accuracy

print("\nComputing CLIP zero-shot accuracy on base and novel classes...")
base_acc = test(model=model, dataset=base_test_set, classes=base_classes, batch_size=128, device=device, label="base classes")
novel_acc = test(model=model, dataset=novel_test_set, classes=novel_classes, batch_size=128, device=device, label="novel classes")
hm = harmonic_mean(base_acc, novel_acc)
print("\nComputation done.\n")

print(f"Zero-shot accuracy on base classes: {base_acc*100:.2f}%")
print(f"Zero-shot accuracy on novel classes: {novel_acc*100:.2f}%")
print(f"Harmonic Mean: {hm*100:.2f}%")

## 4. Our Approach: ProtoCoCoOp

### CoCoOp: Conditional Context Optimization

Our approach builds upon **CoCoOp** (Conditional Context Optimization) [Zhou et al., 2022], which addresses a critical limitation of its predecessor, CoOp.

To avoid manually searching for the prompt that maximizes CLIP's zero-shot performance (i.e., finding a better prompt than "a photo of a $\text{class\_name}$"), **CoOp** [Zhou et al., 2022] automates the process by learning a set of context vectors (i.e., prompts) from a few annotated samples. Formally, let $P \in \mathbb{R}^{M \times d}$ be the set of $M$ learnable tokens of dimension $d$. Then the input to the text encoder for class $\text{class\_name}$ becomes $c = [P_1, \dots, P_M, \text{class\_name}]$. By forwarding images and the corresponding text prompts through the pre-trained VLM, CoOp tunes $P$ by minimizing the cross-entropy loss:
$$
\mathcal{L}_{CE}(x_i, y_i)
= - \log
\frac{
\exp\left(\langle f_\theta^v(x_i), f_\theta^t(y_i)\rangle\right)
}{
\sum_{c \in C}
\exp\left(\langle f_\theta^v(x_i), f_\theta^t(c)\rangle\right)
}
$$
where $f_\theta^v$ and $f_\theta^t$ denote the image and text encoders, and both $y_i$ and $c$ are built by prepending the learnable context to the category name. However, CoOp generalizes poorly to Novel classes of the same dataset because it overfits to the Base classes during adaptation.

To address this problem, **CoCoOp** improves over its predecessor by combining the context vectors with an image-conditioned token, which shifts the focus away from a specific set of classes to a specific input instance, reducing overfitting. To generate the image-conditioned token, CoCoOp trains a lightweight MLP called **Meta-Net** alongside the learnable prompts. Formally, let $h_\phi$ be the Meta-Net; then each conditional token is obtained as $P_m(x_i) = P_m + h_\phi(x_i)$,
where $P_m$ is the $m$-th learnable context vector in the sequence and $P_m(x_i)$ is the $m$-th conditional token. Both the Meta-Net and the context vectors are trained end-to-end using the cross-entropy loss. Combining context vectors with an input-conditioned token has the joint benefit of adapting the VLM to Base classes while maintaining high accuracy on Novel ones.
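The conditional-token construction above can be sketched with plain tensors. This is a simplified illustration under stated assumptions: the sizes (`M`, `d`, `n_cls`, `batch`) are arbitrary, the image features and class-name embeddings are random stand-ins for CLIP outputs, and the start-of-sequence, end-of-sequence, and padding tokens of a real prompt are omitted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): M context tokens of width d
M, d, n_cls, batch = 4, 512, 3, 2

# Learnable context vectors P, shared across classes and images
ctx = nn.Parameter(torch.randn(M, d) * 0.02)

# Meta-Net h_phi: a small bottleneck MLP mapping image features to one token-sized bias
meta_net = nn.Sequential(
    nn.Linear(d, d // 16),
    nn.ReLU(),
    nn.Linear(d // 16, d),
)

image_features = torch.randn(batch, d)  # stand-in for CLIP image embeddings
bias = meta_net(image_features)         # (batch, d)

# P_m(x_i) = P_m + h_phi(x_i): the same bias shifts every context token of image i
ctx_shifted = ctx.unsqueeze(0) + bias.unsqueeze(1)  # (batch, M, d)

# Stand-in class-name embeddings, one token per class for simplicity
class_tokens = torch.randn(n_cls, 1, d)

# One prompt per (image, class) pair: [P_1(x), ..., P_M(x), class_name]
prompts = torch.cat(
    [
        ctx_shifted.unsqueeze(1).expand(-1, n_cls, -1, -1),   # (batch, n_cls, M, d)
        class_tokens.unsqueeze(0).expand(batch, -1, -1, -1),  # (batch, n_cls, 1, d)
    ],
    dim=2,
)
print(prompts.shape)  # torch.Size([2, 3, 5, 512])
```

Because the bias depends on the image, each image gets its own set of prompts for all classes, which is why CoCoOp needs one text-encoder pass per (image, class) pair rather than one per class as in CoOp.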

### Knowledge Distillation

A key challenge in prompt learning is that while adapting prompts to base classes improves performance on those classes, it can **degrade generalization to novel classes**. The learned prompts may drift away from CLIP's original semantic alignment, losing the zero-shot transfer capabilities that make CLIP powerful.

We employ **knowledge distillation (KD)** to regularize the training process, using the frozen CLIP model as a *teacher* and the CoCoOp model as a *student*. The goal is to encourage the student's predictions to remain close to the teacher's zero-shot predictions, preserving CLIP's general knowledge while still adapting to the task.

During training, for each input image $x$:

1. **Student logits:** The CoCoOp model produces logits $z^{(s)}$ using learned instance-conditional prompts.

2. **Teacher logits:** The frozen CLIP model computes logits $z^{(t)}$ using standard hand-crafted prompts (`"a photo of a [CLASS]"`).

3. **Soft probability matching:** We convert logits to soft probability distributions using temperature scaling $T$:
   $$p^{(s)}_i = \frac{\exp(z^{(s)}_i / T)}{\sum_j \exp(z^{(s)}_j / T)}, \quad p^{(t)}_i = \frac{\exp(z^{(t)}_i / T)}{\sum_j \exp(z^{(t)}_j / T)}$$

4. **KL divergence loss:** The distillation loss minimizes the KL divergence between student and teacher distributions:
   $$\mathcal{L}_{\text{KD}} = T^2 \cdot \text{KL}(p^{(t)} \| p^{(s)})$$

   The $T^2$ factor compensates for the reduced gradient magnitude when using temperature scaling.

5. **Combined loss:** The final training objective balances task-specific learning with knowledge preservation:
   $$\mathcal{L} = (1 - \alpha) \cdot \mathcal{L}_{\text{CE}} + \alpha \cdot \mathcal{L}_{\text{KD}}$$

   where $\alpha$ controls the trade-off between cross-entropy on base classes and distillation from CLIP.

By regularizing toward CLIP's predictions over *all* classes (not just base classes), we expect the model to retain better generalization to unseen categories. The soft targets from the teacher should provide richer supervision than hard labels alone, acting as a regularizer that reduces overfitting. As a result, the trade-off between base and novel accuracy is expected to improve, leading to higher harmonic mean scores.
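The steps above can be sketched in PyTorch as follows. This is a minimal illustration, not the training loop used in this notebook: the function names, the temperature `T=2.0`, and the weight `alpha=0.5` are assumptions chosen for the example, the logits are random stand-ins, and both loss terms are applied to the same set of classes for simplicity.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """T^2 * KL(teacher || student) on temperature-softened distributions."""
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # F.kl_div expects log-probabilities as input and probabilities as target
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T ** 2)

def combined_loss(student_logits, teacher_logits, labels, alpha=0.5, T=2.0):
    """(1 - alpha) * CE + alpha * KD."""
    ce = F.cross_entropy(student_logits, labels)
    # The teacher is frozen, so no gradient flows through its logits
    kd = distillation_loss(student_logits, teacher_logits.detach(), T)
    return (1 - alpha) * ce + alpha * kd

torch.manual_seed(0)
student = torch.randn(4, 10, requires_grad=True)  # stand-in for CoCoOp logits
teacher = torch.randn(4, 10)                      # stand-in for frozen CLIP logits
labels = torch.randint(0, 10, (4,))

loss = combined_loss(student, teacher, labels)
loss.backward()  # gradients reach only the student
print(f"combined loss: {loss.item():.4f}")
```

With `reduction="batchmean"` the KL term is averaged over the batch, which matches the mathematical definition of the divergence per sample; the term vanishes when the student reproduces the teacher's distribution exactly.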

### Prototype Generation and Fusion

While CoCoOp provides instance-conditional prompts, we further enhance base class performance by incorporating **class prototypes**—representative embeddings that capture the visual characteristics of each class.

#### Prototype Construction

For each base class $c$, we construct a prototype $\mathbf{p}_c$ by aggregating CLIP image embeddings from training samples:

1. **Extract embeddings:** For each training image $x_i$ of class $c$, compute the CLIP image embedding $f(x_i)$ using the **frozen** CLIP encoder.

2. **Data augmentation:** To build more robust prototypes, we extract embeddings from both original images and multiple augmented views (random crops, flips, rotations, color jitter). This increases the effective sample size and captures intra-class variation.

3. **L2 normalization:** All embeddings are L2-normalized before aggregation:
   $$\hat{f}(x_i) = \frac{f(x_i)}{\|f(x_i)\|}$$

4. **Mean aggregation:** The prototype is the normalized mean of all class embeddings:
   $$\mathbf{p}_c = \frac{1}{|\mathcal{X}_c|} \sum_{x_i \in \mathcal{X}_c} \hat{f}(x_i), \quad \mathbf{p}_c \leftarrow \frac{\mathbf{p}_c}{\|\mathbf{p}_c\|}$$

**Key design choice:** We use the frozen CLIP encoder (not the adapted CoCoOp model) to preserve CLIP's zero-shot semantic structure in the prototype space.

#### Inference-Time Fusion

At inference, we combine CoCoOp's prompt-based predictions with prototype-based similarity scores:

1. **CoCoOp logits:** Compute logits $z_{\text{CoCoOp}}$ using instance-conditional prompts (as described in the CoCoOp section).

2. **Prototype logits:** Compute similarity between the test image embedding and each class prototype:
   $$z_{\text{proto}}(c) = \tau \cdot \cos(f(x), \mathbf{p}_c)$$
   where $\tau$ is CLIP's learned logit scale (the inverse of its softmax temperature).

3. **Additive fusion:** The final logits for base classes are:
   $$z_{\text{final}}(c) = z_{\text{CoCoOp}}(c) + \alpha \cdot z_{\text{proto}}(c)$$
   where $\alpha$ controls the contribution of prototype information.

**Note:** Prototype fusion is applied **only to base classes** since prototypes are constructed from training data. Novel classes rely solely on CoCoOp's generalization.
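The additive fusion can be sketched as below. This is an illustration under assumptions rather than the exact inference code: `fuse_with_prototypes` and `base_columns` (the positions of the Base classes in the logit vector) are hypothetical names, `alpha=0.5` and `logit_scale=100.0` are placeholder values, and all tensors are random stand-ins.

```python
import torch
import torch.nn.functional as F

def fuse_with_prototypes(logits, image_features, prototype_matrix,
                         base_columns, alpha=0.5, logit_scale=100.0):
    """Add alpha * tau * cos(f(x), p_c) to the logits of Base classes only."""
    image_features = F.normalize(image_features, dim=-1)
    prototype_matrix = F.normalize(prototype_matrix, dim=-1)
    proto_logits = logit_scale * image_features @ prototype_matrix.T  # (batch, n_base)

    fused = logits.clone()  # leave the input logits untouched
    fused[:, base_columns] += alpha * proto_logits
    return fused

torch.manual_seed(0)
n_base, n_novel, d = 3, 2, 512
logits = torch.randn(4, n_base + n_novel)  # stand-in for CoCoOp logits
image_features = torch.randn(4, d)         # stand-in for CLIP image embeddings
prototypes = torch.randn(n_base, d)        # stand-in for the prototype matrix
base_columns = torch.arange(n_base)        # Base classes occupy the first columns here

fused = fuse_with_prototypes(logits, image_features, prototypes, base_columns)
print(torch.equal(fused[:, n_base:], logits[:, n_base:]))  # Novel logits are unchanged
```

Only the Base columns receive the residual term, so the relative ordering among Novel classes is exactly what the prompt-based model produced.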

#### Expected Benefits

- **Improved base class accuracy:** Prototypes provide an additional source of class-specific information, complementing the prompt-based predictions.
- **Robustness to few-shot settings:** By aggregating multiple augmented views, prototypes capture richer class representations even with limited training samples.
- **Preserved novel class performance:** Since prototypes don't affect novel class predictions, the generalization benefits of CoCoOp are retained.
- **Better harmonic mean:** The combination of improved base accuracy and maintained novel accuracy leads to higher overall HM scores.

The code below implements prototype construction with the following components:

- **`aug_view_transform`**: A data augmentation pipeline (resize, random crop, flip, rotation, color jitter) used to generate diverse views of each training image, improving prototype robustness.

- **`TransformView`**: A lightweight dataset wrapper that applies a given transform to images from a subset, enabling us to create multiple augmented views of the same underlying data.

- **`build_prototypes`**: Extracts CLIP image embeddings from all samples (original + augmented), groups them by class, computes the L2-normalized mean for each class, and returns both a dictionary of prototypes and a stacked matrix for efficient batch inference.

In [None]:
# Data augmentation transform for prototype construction
aug_view_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.Lambda(lambda im: im.convert("RGB")),
    torchvision.transforms.RandomCrop(224),  # output is already 224x224, so no further crop is needed
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomRotation(30),
    torchvision.transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                     (0.26862954, 0.26130258, 0.27577711)),
])

# Class to apply transform to an element of the dataset
class TransformView(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __len__(self):
        return len(self.subset)
    def __getitem__(self, idx):
        img, y = self.subset[idx]
        img = self.transform(img)
        
        return img, y

# Build prototypes from augmented dataset
@torch.no_grad()
def build_prototypes(model, dataset, base_classes, device='cuda'):
    """
    Build class prototypes from image embeddings extracted using frozen CLIP.

    Args:
        model: Frozen CLIP model used to extract image features.
        dataset: Dataset containing images and labels.
        base_classes: List of base class ids to build prototypes for.
        device: Device to run computations on (default: 'cuda').

    Returns:
        Tuple of (prototypes_dict, prototype_matrix):
        - prototypes_dict: Dictionary mapping class_id -> prototype tensor (feature_dim,)
        - prototype_matrix: Stacked tensor of shape (num_base_classes, feature_dim) for efficient inference
    """
    model.eval()
    
    # Collect embeddings per class
    embeddings_per_class = {c: [] for c in base_classes}
    
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=False, num_workers=0
    )
    
    print(f"Extracting embeddings from {len(dataset)} samples...")
    
    for images, labels in tqdm(dataloader, desc="Building Prototypes"):
        images = images.to(device)
        
        # Get CLIP image features
        features = model.encode_image(images)
        features = features / features.norm(dim=-1, keepdim=True)  # L2 normalize
        
        for feat, label in zip(features, labels):
            label_id = label.item()
            if label_id in embeddings_per_class:
                embeddings_per_class[label_id].append(feat.cpu())
    
    # Compute mean prototype per class
    prototypes = {}

    for cls_id in base_classes:
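        if len(embeddings_per_class[cls_id]) == 0:
            # Fail early: a missing prototype would misalign prototype_matrix with base_classes
            raise ValueError(f"No samples found for class {cls_id}; cannot build its prototype.")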

        class_embeddings = torch.stack(embeddings_per_class[cls_id])
        prototype = class_embeddings.mean(dim=0).to(device)
        prototype = prototype / prototype.norm()

        prototypes[cls_id] = prototype

    # Create matrix for efficient inference (ordered by base_classes)
    prototype_matrix = torch.stack([prototypes[c] for c in base_classes]).to(device)
    
    print(f"Built {len(prototypes)} prototypes | Matrix shape: {prototype_matrix.shape}")
    
    return prototypes, prototype_matrix  # matrix of shape (num_base_classes, feature_dim)

In [16]:
# Load raw train dataset (PIL images)
train_raw = load_split("train", transform=None)

# Build base subset indices on the same object (= avoid mismatched _labels across dataset instances)
base_set = set(base_classes)
base_idx = [i for i, y in enumerate(train_raw._labels) if y in base_set]  # uses Flowers102._labels
base_train_raw = torch.utils.data.Subset(train_raw, base_idx)

# Define transforms for original and augmented views
orig_view = TransformView(base_train_raw, preprocess)

num_samples = 10  # number of augmented views per original image
views = [orig_view] + [TransformView(base_train_raw, aug_view_transform) for _ in range(num_samples)]

# Create the prototype pool by concatenating all views
proto_pool = torch.utils.data.ConcatDataset(views)

print("N =", len(orig_view), "pool =", len(proto_pool))

# Build prototypes using frozen CLIP
prototypes, prototype_matrix = build_prototypes(
    model=model,
    dataset=proto_pool,
    base_classes=base_classes,
    device=device
)

N = 510 pool = 5610
Extracting embeddings from 5610 samples...


Building Prototypes: 100%|██████████| 88/88 [00:46<00:00,  1.88it/s]

Built 51 prototypes | Matrix shape: torch.Size([51, 512])
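
The sizes printed above can be cross-checked with a little arithmetic. This is only a sanity-check sketch: the batch size of 64 is an assumption, since the prototype `DataLoader` is not shown in this part of the notebook; it is inferred from the 88 progress-bar steps.

```python
import math

n_original = 510      # base-class training images reported above (51 classes x 10 images)
num_aug_views = 10    # augmented views per original image (num_samples)
batch_size = 64       # assumption: not shown above, inferred from the 88 progress-bar steps

# One original view plus num_aug_views augmented views of every image
pool_size = n_original * (1 + num_aug_views)

# Number of batches when the last, partially filled batch is kept
num_batches = math.ceil(pool_size / batch_size)

print(pool_size, num_batches)
```

Each prototype is therefore the normalized mean of 110 embeddings (10 images, 11 views each).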





### Implementation

**Components:**
1. **Context Vectors (V):** `n_ctx` learnable vectors (8 in the training configuration used below).
   - Shape: `(n_ctx, 512)`
   - Initialized: Gaussian noise with mean 0 and standard deviation 0.02
   - Function: Provide the base context for the prompt.

2. **Meta-Network (Bias Generator):**
   - Architecture: Linear(512->32) -> ReLU -> Linear(32->512)
   - Input: Image Features `(Batch, 512)`
   - Output: Bias `(Batch, 512)` added to Context Vectors.
   - **Note:** Unlike the paper's simplified notation "$\pi$", we implement this as an **additive bias** to the context vectors.

3. **Class Embeddings:**
   - Pre-computed embeddings for "[CLASS] + EOS".
   - Fixed during training.

**Forward Pass (Vectorized):**
Instead of looping through images, we broadcast tensors to shape `(Batch, Num_Classes, Sequence_Length, Dim)`:
1. **Compute Bias:** $Bias = MetaNet(Image)$
2. **Shift Context:** $Ctx_{new} = Ctx_{base} + Bias$ (Broadcasting over classes)
3. **Concatenate:** $[Prefix] + [Ctx_{new}] + [Suffix]$ (All in parallel)
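
The three steps above can be illustrated on toy tensors. This is a minimal sketch with made-up sizes and random stand-ins for the Meta-Net output and the token embeddings; it only shows how broadcasting produces the `(Batch, Num_Classes, Sequence_Length, Dim)` layout.

```python
import torch

# Toy sizes for illustration only (not the real model dimensions)
batch, n_cls, n_ctx, dim, suffix_len = 2, 3, 4, 8, 5

ctx = torch.randn(n_ctx, dim)                  # shared learnable context vectors
bias = torch.randn(batch, dim)                 # stand-in for MetaNet(image_features)
prefix = torch.randn(n_cls, 1, dim)            # stand-in for the SOS embedding of each class prompt
suffix = torch.randn(n_cls, suffix_len, dim)   # stand-in for class name + EOS (+ padding)

# Shift context: (1, n_ctx, dim) + (batch, 1, dim) -> (batch, n_ctx, dim)
ctx_shifted = ctx.unsqueeze(0) + bias.unsqueeze(1)

# Broadcast over batch and classes, then concatenate along the token axis
prompts = torch.cat([
    prefix.unsqueeze(0).expand(batch, -1, -1, -1),        # (batch, n_cls, 1, dim)
    ctx_shifted.unsqueeze(1).expand(-1, n_cls, -1, -1),   # (batch, n_cls, n_ctx, dim)
    suffix.unsqueeze(0).expand(batch, -1, -1, -1),        # (batch, n_cls, suffix_len, dim)
], dim=2)

print(prompts.shape)  # (batch, n_cls, 1 + n_ctx + suffix_len, dim)
```

Because `expand` creates views rather than copies, the only newly allocated tensor of full size is the concatenated result.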

The code below defines three key modules that work together to implement our ProtoCoCoOp model:

**`TextEncoder`**: A wrapper around CLIP's text transformer that processes *continuous prompt embeddings* rather than discrete tokens. Given a prompt tensor of shape `(Batch × Num_Classes, Sequence_Length, Dim)`, it adds positional embeddings, permutes to sequence-first format for the transformer, passes through CLIP's frozen transformer layers, applies layer normalization, extracts the embedding at the EOS token position, and projects it to the joint image-text embedding space via the text projection matrix. This module is necessary because CLIP's built-in `encode_text()` expects tokenized integers, not learnable embedding vectors.
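
The EOS extraction step relies on the fact that CLIP's end-of-text token has the largest id in its vocabulary, so `argmax` over the token ids returns its position. A small sketch of that indexing trick, using placeholder ids for the words and random values in place of the transformer output:

```python
import torch

SOT_ID = 49406  # CLIP start-of-text token id
EOT_ID = 49407  # CLIP end-of-text token id, the largest id in the vocabulary

# Two zero-padded token sequences; the ids between SOT and EOT are arbitrary placeholders
tokenized = torch.tensor([
    [SOT_ID, 11, 22, EOT_ID, 0, 0],
    [SOT_ID, 11, 22, 33, 44, EOT_ID],
])
hidden = torch.randn(2, 6, 8)  # stand-in for transformer outputs (batch, seq_len, dim)

# argmax over ids gives the EOT position of every sequence
eot_pos = tokenized.argmax(dim=-1)

# Pick one hidden state per sequence: (batch, dim)
pooled = hidden[torch.arange(hidden.shape[0]), eot_pos]

print(eot_pos.tolist(), pooled.shape)
```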

**`PromptLearner`**: Implements the CoCoOp prompt learning mechanism with the following components:
- **Context Vectors**: $M$ learnable vectors of dimension $d$ (8 × 512 in the training configuration below), initialized with zero-mean Gaussian noise of standard deviation 0.02, providing the base context for all prompts.
- **Meta-Net**: A lightweight MLP with architecture `Linear(512→32) → ReLU → Linear(32→512)` that takes image features as input and produces an instance-specific bias vector. Unlike the paper's notation $\pi$, we implement this as an additive shift to the context vectors.
- **Fixed Token Embeddings**: Pre-computed embeddings for the SOS prefix token and the class name + EOS suffix, which remain frozen during training.

During the forward pass, the `PromptLearner` computes the instance-conditional context as $\mathbf{v}(x) = \mathbf{v} + h_\theta(f(x))$, then constructs the full prompt sequence by concatenating `[Prefix] + [Shifted Context] + [Suffix]` for each class. To avoid looping over images, the implementation broadcasts tensors to shape `(Batch, Num_Classes, Sequence_Length, Dim)` and performs all operations in parallel.

**`ProtoCoCoOp`**: The main model that orchestrates the full forward pipeline. Given an input batch of images, it first encodes them using CLIP's frozen ViT image encoder and L2-normalizes the resulting features. These image features are passed to `PromptLearner`, which generates instance-conditional prompt embeddings for all classes simultaneously. The prompts are reshaped to `(Batch × Num_Classes, Sequence_Length, Dim)` and fed through `TextEncoder` to obtain text features, which are also L2-normalized. The model then computes cosine similarity logits between image and text features, scaled by CLIP's learned temperature parameter $\tau$. At inference time, when prototype fusion is enabled, the model additionally computes similarity between the image features and pre-computed class prototypes, and additively fuses these prototype logits with the CoCoOp logits for base classes only (controlled by the $\alpha$ hyperparameter).
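
The fusion rule can be sketched in isolation. The example below uses random features, hypothetical base-class indices, and a fixed scale of 100 as a stand-in for `logit_scale.exp()`; it only demonstrates that base-class logits are shifted by $\alpha$ times the prototype similarity while novel-class logits are left untouched.

```python
import torch

torch.manual_seed(0)
n_cls, dim, alpha = 6, 8, 0.2
base_ids = torch.tensor([0, 2, 4])   # hypothetical base-class indices
scale = 100.0                        # stand-in for logit_scale.exp()


def l2_normalize(x):
    """Scale every row to unit L2 norm."""
    return x / x.norm(dim=-1, keepdim=True)


image_features = l2_normalize(torch.randn(5, dim))
# One prototype per base class, stored in the same order as base_ids
prototypes = l2_normalize(torch.randn(len(base_ids), dim))
cocoop_logits = torch.randn(5, n_cls)   # stand-in for the prompt-based logits

# Scaled cosine similarity between each image and each prototype: (5, 3)
proto_logits = scale * image_features @ prototypes.T

# Additive fusion on base-class columns only
fused = cocoop_logits.clone()
fused[:, base_ids] = cocoop_logits[:, base_ids] + alpha * proto_logits

print(fused.shape)
```

Working on a clone keeps the original logits available for comparison, whereas the model's `forward` updates the logits tensor in place.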

In [None]:
# Text Encoder module adapts CLIP's text transformer for batched prompt embeddings.
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        # Reuse components from the loaded CLIP text encoder
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """
        Encode batched prompt embeddings using CLIP's text transformer.
        Args:
            prompts: (batch_size, seq_len, dim) tensor of prompt embeddings
            tokenized_prompts: (batch_size, seq_len) tensor of token ids
        Returns:
            (batch_size, proj_dim) tensor of encoded text features
        """
        # prompts: (batch * n_cls, seq_len, dim) prompt embeddings; positional embeddings are added below
        # tokenized_prompts: token ids, used only to locate the end-of-text position
        x = prompts + self.positional_embedding.type(self.dtype)  # add positional embeddings
        x = x.permute(1, 0, 2)  # transformer expects (seq_len, batch, dim)
        x = self.transformer(x)  # run through CLIP transformer
        x = x.permute(1, 0, 2)  # back to (batch, seq_len, dim)
        x = self.ln_final(x).type(self.dtype)  # layer norm and cast
        # select the embedding at the end-of-text token for each sequence, then project
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x  # (batch, proj_dim)


# Prompt Learner generates per-class prompt embeddings, optionally conditioned on image features.
class PromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=4, ctx_init=None, device='cuda'):
        """
        Initialize the PromptLearner.
        Args:
            clip_model: Pretrained CLIP model.
            classnames: List of class names for the dataset.
            n_ctx: Number of context tokens to learn.
            ctx_init: Optional string to initialize context tokens.
            device: Device to run the model on.
        """
        super().__init__()
        self.dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]  # dimensionality of token embeddings
        vis_dim = clip_model.visual.output_dim  # dimensionality of visual features
        self.n_cls = len(classnames)
        self.n_ctx = n_ctx
        self.device = device

        # Meta network: maps image features -> additive bias for context vectors.
        hidden_dim = vis_dim // 16
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, hidden_dim)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(hidden_dim, ctx_dim))
        ])).to(device)
        
        # Context initialization: either from provided text or random normal.
        if ctx_init:  # if a string is provided, initialize context from its token embeddings
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).to(self.dtype)
            # use tokens after the special start token (1:1+n_ctx)
            ctx_vectors = embedding[0, 1:1+n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # learnable context vectors initialized from N(0, 0.02)
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=torch.float32)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)
        
        # Make context vectors learnable parameters
        self.ctx = nn.Parameter(ctx_vectors)
        
        # Prepare tokenized prompts for all classes using the prefix and class names
        ref_classnames = [name.replace("_", " ") for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in ref_classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        
        # Obtain static token embeddings for prefix and suffix parts (non-learnable buffers)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).to(self.dtype)
            
        # token_prefix: the special start token (e.g., [SOS]) for each class
        self.register_buffer("token_prefix", embedding[:, :1, :])
        # token_suffix: the remaining tokens after the learnable context tokens
        self.register_buffer("token_suffix", embedding[:, 1+n_ctx:, :])
        self.tokenized_prompts = tokenized_prompts

    def forward(self, im_features):
        # im_features: (batch, vis_dim)
        batch_size = im_features.shape[0]
        ctx = self.ctx.to(self.dtype).unsqueeze(0)  # (1, n_ctx, dim)
        bias = self.meta_net(im_features).unsqueeze(1)  # (batch, 1, dim)
        
        # Add image-conditioned bias to the base context vectors
        ctx_shifted = ctx + bias  # (batch, n_ctx, dim)
        
        # Expand prefix and suffix for batch and classes
        prefix = self.token_prefix.unsqueeze(0).expand(batch_size, -1, -1, -1)  # (batch, n_cls, 1, dim)
        suffix = self.token_suffix.unsqueeze(0).expand(batch_size, -1, -1, -1)  # (batch, n_cls, suffix_len, dim)
        ctx_expanded = ctx_shifted.unsqueeze(1).expand(-1, self.n_cls, -1, -1)  # (batch, n_cls, n_ctx, dim)
        
        # Concatenate tokens into full prompt embeddings per class per batch
        return torch.cat([prefix, ctx_expanded, suffix], dim=2)  # (batch, n_cls, n_tokens, dim)


# ProtoCoCoOp model: builds on CoCoOp-style prompt learning and supports prototype fusion at inference.
class ProtoCoCoOp(nn.Module):
    def __init__(self, clip_model, classnames, base_ids, n_ctx=4, ctx_init=None, device='cuda'):
        """
        Initialize the ProtoCoCoOp model.
        Args:
            clip_model: Pretrained CLIP model.
            classnames: List of class names for the dataset.
            base_ids: List of indices of base classes for prototype fusion.
            n_ctx: Number of context tokens to learn.
            ctx_init: Optional string to initialize context tokens.
            device: Device to run the model on.
        """
        super().__init__()
        # CLIP logit scale and model references
        self.logit_scale = clip_model.logit_scale
        self.clip_model = clip_model
        self.dtype = self.clip_model.dtype
        self.base_ids = torch.tensor(base_ids, device=device)  # indices of base classes for prototype fusion
        self.device = device

        # Encoders and prompt learner
        self.image_encoder = self.clip_model.visual
        self.text_encoder = TextEncoder(self.clip_model)
        self.prompt_learner = PromptLearner(self.clip_model, classnames, n_ctx, ctx_init, device)

        # Tokenized prompts for selecting projected text outputs
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts

        # Prototype matrix (num_prototypes x dim) and fusion weight alpha set at inference
        self.prototype_matrix = None
        self.alpha = None            

    def set_prototypes(self, prototype_matrix, alpha=0.2):
        # Store prototypes and fusion coefficient
        self.prototype_matrix = prototype_matrix.to(self.device).type(self.dtype)
        self.alpha = alpha

    def forward(self, image, use_prototypes=False):
        # image: (batch, 3, H, W)
        image = image.to(self.device).type(self.dtype)
        image_features = self.image_encoder(image)  # visual embedding
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)  # normalize

        # Generate per-class prompt embeddings conditioned on image features
        prompts = self.prompt_learner(image_features)  # (batch, n_cls, n_tokens, dim)
        B, C, T, D = prompts.shape
        prompts = prompts.reshape(B * C, T, D).type(self.dtype)  # flatten for text encoder

        # Repeat the stored tokenized prompt ids for each batch instance
        tokenized = self.tokenized_prompts.to(prompts.device).repeat(B, 1)

        # Encode text prompts and normalize
        text_features = self.text_encoder(prompts, tokenized)  # (B*C, proj_dim)
        text_features = text_features.reshape(B, C, -1)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Compute CLIP-style logits: scaled cosine similarity
        logits = self.logit_scale.exp() * (image_features.unsqueeze(1) @ text_features.transpose(1, 2)).squeeze(1)

        # Optional prototype fusion during inference: add prototype logits to base-class logits
        if use_prototypes and self.prototype_matrix is not None:
            # proto_logits: (batch, num_base_classes), scaled cosine similarity to each base-class prototype
            proto_logits = self.logit_scale.exp() * (image_features @ self.prototype_matrix.T)

            # Fuse prototypes into logits for base classes only
            logits_base = logits[:, self.base_ids]
            logits[:, self.base_ids] = logits_base + self.alpha * proto_logits

        return logits  # (batch, n_cls)

## Training and Evaluation

Class that manages:

**1. Initialization:**
- Create PromptLearner
- Freeze CLIP (`requires_grad=False`)
- Configure SGD optimizer for prompt learner only

**2. train():**
- Forward: Image encoder + PromptLearner + Text encoder
- **Critical step:** Encode soft prompts through text transformer
  - Add positional embeddings
  - Pass through CLIP's transformer
  - Apply final layer norm
  - Extract the embedding at the EOS token position
  - Apply the text projection
- Compute loss: Cross-entropy on base classes (plus a distillation term when KD is enabled)
- Backward: Gradients pass through the frozen text encoder, but only the PromptLearner parameters are updated
- Return: Average loss of the epoch

**3. test() with Prototype Fusion:**
- Same forward procedure as training
- **NEW:** Optionally fuse CoCoOp logits with prototype similarity scores
- Fusion formula (base classes only): $\text{logits} = \text{logits}_{\text{CoCoOp}} + \alpha \cdot \text{logits}_{\text{prototype}}$
- Compute accuracy on any dataset (base or novel)

**Important note:** We don't use `model.encode_text()` on soft prompts
because that method expects integer tokens, not embeddings.
We manually forward through the text transformer.


The code below defines `CoCoOpTrainer`, a high-level class that manages the complete training and evaluation workflow for our ProtoCoCoOp model.

**Initialization**: The trainer accepts a pretrained CLIP model, class names, base class IDs, and configuration dictionaries for both the model architecture and training hyperparameters. It freezes all CLIP parameters to preserve the pretrained representations, initializes a `ProtoCoCoOp` model with the specified context length and optional initialization string, and configures an SGD optimizer that updates only the `PromptLearner` parameters (context vectors and Meta-Net weights). A cosine annealing learning rate scheduler is used to smoothly decay the learning rate over training. The trainer also precomputes CLIP's zero-shot text features for all classes, which serve as the teacher signal when knowledge distillation is enabled. Based on the selected mode (`standard`, `kd`, `proto`, or `proto_kd`), the trainer determines which components to activate during training and inference.
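
The freeze-then-optimize pattern described above can be checked on a toy model. In this sketch two small linear layers stand in for the frozen CLIP encoders and the trainable `PromptLearner`; the layer sizes and data are made up, and only the optimizer settings mirror the hyperparameters reported in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
backbone = nn.Linear(8, 8)         # stand-in for the frozen CLIP encoders
prompt_learner = nn.Linear(8, 4)   # stand-in for the trainable PromptLearner

# Freeze the backbone: no gradients are stored for its parameters
for p in backbone.parameters():
    p.requires_grad = False

# The optimizer only sees the trainable module
optimizer = torch.optim.SGD(
    prompt_learner.parameters(), lr=0.002, momentum=0.9, weight_decay=5e-4
)

frozen_before = backbone.weight.clone()
learner_before = prompt_learner.weight.clone()

# One optimization step on random data
x = torch.randn(16, 8)
targets = torch.randint(0, 4, (16,))
loss = F.cross_entropy(prompt_learner(backbone(x)), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(backbone.weight.grad is None)  # frozen parameters receive no gradient
```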

**Training**: The `train()` method performs one epoch of optimization. For each batch, it runs the forward pass through `ProtoCoCoOp` (without prototype fusion) to obtain student logits, then computes the cross-entropy loss on the base-class logits only. When knowledge distillation is enabled, it additionally computes teacher logits over all classes using frozen CLIP with hand-crafted prompts, applies temperature-scaled softmax to both distributions, and minimizes the KL divergence between them. The final loss is $(1 - \lambda)\,\mathcal{L}_{CE} + \lambda\,\mathcal{L}_{KD}$, where $\lambda$ is the `kd_alpha` hyperparameter (distinct from the prototype fusion weight $\alpha$). Only the `PromptLearner` parameters are updated; CLIP's encoders remain frozen.
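
The distillation term can be sketched on random logits. The function below follows the same temperature-scaled KL formulation as `compute_kd_loss`; the logits, targets, and batch size are made up, and for brevity the cross-entropy here is computed over all ten toy classes rather than over a base-class subset.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, T=2.0):
    """Temperature-scaled KL(teacher || student), averaged over the batch."""
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # T^2 compensates for the 1/T^2 scaling of the soft-target gradients
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T ** 2)


torch.manual_seed(0)
student = torch.randn(4, 10)               # toy student logits
teacher = torch.randn(4, 10)               # toy teacher logits
targets = torch.randint(0, 10, (4,))       # toy ground-truth labels

ce = F.cross_entropy(student, targets)
kd = kd_loss(student, teacher)

kd_alpha = 0.3                             # KD weight used in this notebook
total = (1 - kd_alpha) * ce + kd_alpha * kd

print(ce.item(), kd.item(), total.item())
```

When student and teacher logits coincide, the KD term vanishes and the total loss reduces to the down-weighted cross-entropy.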

**Evaluation**: The `test()` method evaluates the model on any dataset (base or novel classes). It computes accuracy by comparing predicted class indices against ground truth labels. When prototype fusion is enabled and prototypes have been set, the model additively combines CoCoOp logits with prototype similarity scores for base classes, while novel classes rely solely on the learned prompts.
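
Evaluating on a class subset requires translating global labels into indices of the restricted logit matrix. The sketch below reproduces that lookup on a toy problem with six classes; the class ids, labels, and logits are invented for illustration.

```python
import torch

num_total_classes = 6
class_ids = torch.tensor([1, 3, 4])   # hypothetical subset to evaluate on

# Lookup table: global class id -> compact index, -1 for classes outside the subset
mapping = torch.full((num_total_classes,), -1, dtype=torch.long)
mapping[class_ids] = torch.arange(len(class_ids))

labels = torch.tensor([3, 1, 4, 3])   # global labels, all inside class_ids
targets = mapping[labels]             # compact targets

# Toy logits over all six classes
logits = torch.tensor([
    [0.1, 0.2, 2.0, 0.0, 0.3, 0.1],
    [0.0, 1.5, 0.1, 0.2, 0.3, 0.0],
    [0.0, 0.9, 0.0, 0.1, 0.2, 0.0],
    [3.0, 0.1, 0.0, 0.8, 0.2, 0.0],
])

# Keep only the columns of the evaluated classes, then predict among them
restricted = logits[:, class_ids]
preds = restricted.argmax(dim=1)
accuracy = (preds == targets).float().mean().item()

print(targets.tolist(), preds.tolist(), accuracy)
```

Note that the large logits assigned to classes 2 and 0 in the first and last rows play no role, because those columns are dropped before the `argmax`.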

In [None]:
class CoCoOpTrainer:
    def __init__(self, clip_model, classnames, base_classes, config, params, device="cuda"):
        """
        CoCoOp Trainer for training and evaluation.

        Args:
            clip_model: Pretrained CLIP model.
            classnames: List of all class names.
            base_classes: List of base class ids.
            config: Configuration dictionary for CoCoOp. Contains 'mode', 'n_ctx', 'ctx_init'.
            params: Training parameters dictionary. Contains 'lr', 'momentum', 'weight_decay',
                    'kd_alpha', 'temperature', 'num_epochs', 'tr_batch_size', 'ts_batch_size'.
            device: Device to run the model on (default: "cuda").
        """
        self.mode = config["mode"].lower()
        if self.mode == "standard":
            self.use_proto = False
            self.use_kd = False
        elif self.mode == "kd":
            self.use_proto = False
            self.use_kd = True
        elif self.mode == "proto":
            self.use_proto = True
            self.use_kd = False
        elif self.mode == "proto_kd":
            self.use_proto = True
            self.use_kd = True
        else:
            raise ValueError(f"Invalid mode: {self.mode}. Choose from 'standard', 'kd', 'proto', 'proto_kd'.")
        
        print(f"Initialized CoCoOpTrainer in '{self.mode}' mode | use_proto={self.use_proto} | use_kd={self.use_kd}")

        self.kd_alpha = params["kd_alpha"]
        self.temperature = params["temperature"]
        self.num_epochs = params["num_epochs"]
        self.tr_batch_size = params["tr_batch_size"]
        self.ts_batch_size = params["ts_batch_size"]
        self.device = device

        # Freeze CLIP model parameters (no fine-tuning of CLIP itself).
        self.clip_model = clip_model.float().to(device).eval()
        for p in self.clip_model.parameters():
            p.requires_grad = False

        # Precompute normalized CLIP text features for all class prompts.
        with torch.no_grad():
            prompts = [f"a photo of a {c}" for c in classnames]
            tokens = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
            text_features = self.clip_model.encode_text(tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)

        self.clip_text_features = text_features

        # Initialize CoCoOp model (prompt learner + optional prototype components).
        self.model = ProtoCoCoOp(
            self.clip_model,
            classnames,
            base_ids=base_classes,
            n_ctx=config["n_ctx"],
            ctx_init=config["ctx_init"],
            device=device
        ).to(device)

        # Optimize only the prompt learner parameters.
        self.optimizer = torch.optim.SGD(
            self.model.prompt_learner.parameters(),
            lr=params["lr"],
            momentum=params["momentum"],
            weight_decay=params["weight_decay"]
        )

        # Cosine annealing LR scheduler over epochs.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.num_epochs)

        # Base class ids tensor on device.
        self.base_ids = torch.tensor(base_classes, device=device)

        # Map global class indices -> compact base-class indices; -1 for non-base classes.
        num_total_classes = len(classnames)
        self.label_map = torch.full((num_total_classes,), -1, dtype=torch.long, device=device)
        self.label_map[self.base_ids] = torch.arange(len(base_classes), device=device)

    # Knowledge Distillation Loss computation (KL between teacher probs and student log-probs).
    def compute_kd_loss(self, student_logits, teacher_logits):
        """
        Compute the knowledge distillation loss between student and teacher logits.
        Args:
            student_logits: Logits from the student model.
            teacher_logits: Logits from the teacher model.
        Returns:
            KL divergence loss value.
        """
        T = self.temperature

        student_log_probs = F.log_softmax(student_logits / T, dim=-1)
        teacher_probs = F.softmax(teacher_logits / T, dim=-1)

        # Multiply by T^2 as in temperature-scaled KD formulation.
        return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (T ** 2)
    
    # Training function
    def train(self, dataset):
        """
        Trains the model for one epoch.

        Args:
            dataset: Training dataset.
        Returns:
            Average training loss over the epoch.
        """
        self.model.train()

        total_loss = 0.0
        total_samples = 0

        # DataLoader for training.
        train_loader = DataLoader(dataset, batch_size=self.tr_batch_size, shuffle=True, num_workers=1, worker_init_fn=worker_init_fn)

        for images, labels in tqdm(train_loader, desc=f"Training [{self.mode}]"):
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass (no prototype fusion during training here).
            logits = self.model(images, use_prototypes=False)

            # Compute CE loss restricted to base classes.
            base_logits = logits[:, self.base_ids]
            targets = self.label_map[labels]

            loss_ce = F.cross_entropy(base_logits, targets)

            # Optionally compute KD loss using frozen CLIP as teacher.
            if self.use_kd:
                with torch.no_grad():
                    img_feat = self.model.clip_model.encode_image(images)
                    img_feat /= img_feat.norm(dim=-1, keepdim=True)

                    teacher_logits = (self.model.clip_model.logit_scale.exp() * img_feat @ self.clip_text_features.T)

                loss_kd = self.compute_kd_loss(logits, teacher_logits)

                # Weighted combination of CE and KD losses.
                loss = (1 - self.kd_alpha) * loss_ce + self.kd_alpha * loss_kd
            else:
                loss = loss_ce

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * images.size(0)
            total_samples += images.size(0)

        self.scheduler.step()

        return total_loss/total_samples
    
    # Evaluation function
    @torch.no_grad()
    def test(self, dataset, class_ids, use_prototypes=False):
        """
        Evaluates the model on the given dataset. 

        Args:
            dataset: Dataset to evaluate on.
            class_ids: List of class ids to consider during evaluation.
            use_prototypes: Whether to apply prototype fusion at inference.

        Returns:
            Tuple of (accuracy, average loss) over the dataset.        
        """
        self.model.eval()

        # Build mapping from global class index -> compact evaluation index.
        class_ids = torch.tensor(class_ids, device=self.device)
        mapping = torch.full((len(self.clip_text_features),), -1, dtype=torch.long, device=self.device)
        mapping[class_ids] = torch.arange(len(class_ids), device=self.device)

        # DataLoader for testing.
        test_loader = DataLoader(dataset, batch_size=self.ts_batch_size, shuffle=False, num_workers=2, worker_init_fn=worker_init_fn)

        correct_predictions = 0
        predictions = 0
        total_loss = 0.0

        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass with optional prototype fusion.
            logits = self.model(images, use_prototypes=use_prototypes)
                
            # Restrict logits to requested class subset.
            logits = logits[:, class_ids]

            targets = mapping[labels]

            preds = logits.argmax(dim=1)
            loss = F.cross_entropy(logits, targets)

            correct_predictions += (preds == targets).sum().item()
            predictions += images.size(0)
            total_loss += loss.item() * images.size(0)

        return (correct_predictions/predictions, total_loss/predictions)

### Training

The following cell implements the complete training pipeline for our ProtoCoCoOp model. It initializes the trainer with the specified configuration, sets up the training loop with early stopping based on validation accuracy, and saves the best model checkpoint during training. The training process runs for up to 15 epochs on base classes only, with early stopping triggered if no improvement is observed for 5 consecutive epochs.

We will train the PromptLearner for **up to 15 epochs** on **base classes only**.

**Hyperparameters (Optimized):**
- **Context Length (`n_ctx`):** 8 (Balanced capacity for prompt learning)
- **Batch size:** 1 (Training batch size for memory efficiency)
- **Learning rate:** 0.002 (SGD)
- **Momentum:** 0.9
- **Weight decay:** 5e-4
- **Epochs:** 15
- **Early stopping patience:** 5 epochs
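
To give a feel for how the learning rate evolves under cosine annealing with these settings, the sketch below evaluates the closed-form schedule documented for `CosineAnnealingLR`, assuming a minimum learning rate of 0 (the PyTorch default) and one scheduler step per epoch.

```python
import math

base_lr = 0.002   # initial learning rate from the list above
t_max = 15        # number of epochs (T_max of the scheduler)
eta_min = 0.0     # assumption: default minimum learning rate


def cosine_lr(epoch):
    """Closed-form cosine-annealed learning rate used during the given 0-indexed epoch."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


schedule = [cosine_lr(e) for e in range(t_max)]

for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

The rate starts at 0.002, decreases slowly at first, and approaches zero in the final epochs; if early stopping triggers, the tail of the schedule is simply never reached.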

**What happens:**
The training pipeline initializes a `CoCoOpTrainer` instance with the pre-trained CLIP model and configures it according to the selected mode (`standard`, `proto`, `kd`, or `proto_kd`). During each epoch, the `PromptLearner` adapts its 8 context vectors to the Flowers102 base classes while the Meta-Net learns to produce an image-specific bias. Global labels are mapped to compact base-class indices through a lookup tensor kept on the GPU, and early stopping monitors validation accuracy on the base classes to limit overfitting.

**Expected output:**
The training process begins with an initial loss typically ranging from 2.5 to 3.5, which progressively decreases to a final loss between 0.5 and 1.0 as the model learns better prompt representations. Training time is expected to be approximately 2-4 minutes on GPU, with the best model checkpoint automatically saved when validation accuracy improves. The early stopping mechanism ensures efficient training by halting the process when no further improvements are observed.
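
The early stopping rule can be traced on a made-up sequence of validation accuracies. This is a simplified sketch of the patience logic only (no model, no checkpointing); the helper name `run_with_patience` and the accuracy values are hypothetical.

```python
def run_with_patience(val_accs, patience_init=5):
    """Return (best accuracy, epoch of best accuracy, last epoch executed)."""
    best, best_epoch = 0.0, None
    patience = patience_init
    for epoch, acc in enumerate(val_accs):
        if acc > best:
            # Strict improvement: remember it and reset the counter
            best, best_epoch = acc, epoch
            patience = patience_init
        else:
            # No improvement (ties count as no improvement)
            patience -= 1
            if patience == 0:
                return best, best_epoch, epoch
    return best, best_epoch, len(val_accs) - 1


# Made-up validation accuracies, one per epoch
accs = [0.90, 0.93, 0.92, 0.93, 0.91, 0.92, 0.93, 0.90, 0.95]
best, best_epoch, stopped_at = run_with_patience(accs)

print(best, best_epoch, stopped_at)
```

In this trace the run stops after five consecutive non-improving epochs, so the later value of 0.95 is never evaluated; this is the usual trade-off of a patience-based criterion.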

In [None]:
# Configuration: choose training mode and prompt settings
CURRENT_MODE = "proto_kd"  # options: "standard", "proto", "kd", "proto_kd"

config = {
    "mode": CURRENT_MODE,
    "n_ctx": 8,      # number of learnable context tokens
    "ctx_init": None # optional string to initialize context from text (None -> random init)
}

# Training hyperparameters and algorithm switches
params = {
    "lr": 0.002,            # SGD LR for prompt parameters
    "momentum": 0.9,        # accelerate convergence and smooth updates
    "weight_decay": 5e-4,   # small weight decay to regularize learned context vectors slightly
    "tr_batch_size": 1,     # training batch size (small to fit prompt learner + meta-net)
    "ts_batch_size": 32,    # evaluation batch size
    "patience_init": 5,     # stop after N non-improving epochs
    "num_epochs": 15,       # maximum epochs to allow sufficient prompt adaptation
    "proto_alpha": 0.2,     # prototype fusion weight 
    "kd_alpha": 0.3,        # KD loss weight  
    "temperature": 2.0      # KD temperature: >1 for stable distillation
}

# Initialize trainer with frozen CLIP and prompt learner
trainer = CoCoOpTrainer(
    clip_model=model,            # pretrained CLIP model object
    classnames=CLASS_NAMES,      # list of class name strings
    base_classes=base_classes,   # list of base-class integer ids
    config=config,
    params=params,
    device=device,
)

# Container to record training progress and best model
results = {
    "mode": config["mode"],
    "sampled_epochs": [],   # epochs visited
    "val_accs": [],         # validation accuracies per sampled epoch
    "best_val_acc": 0.0,    # best validation accuracy seen so far
    "losses_train": [],     # training losses per epoch
    "losses_val": [],       # validation losses per epoch
}

# Initialize early stopping counter
patience = params["patience_init"]

# Training loop with periodic evaluation and model checkpointing on improvement
print("\n" + "="*70)
print(f"TRAINING LOOP (Patience: {params['patience_init']}) | Mode: {config['mode'].upper()}")
print("="*70)

for epoch in range(trainer.num_epochs):
    results["sampled_epochs"].append(epoch)

    # Training step: updates only the prompt learner parameters
    train_loss = trainer.train(base_train_set)
    print(f"\nEpoch {epoch+1}/{trainer.num_epochs} | Train Loss: {train_loss:.4f}")

    results["losses_train"].append(np.asarray(train_loss).mean())

    # Evaluation step: measure performance on base validation set (no prototype fusion here)
    val_acc, val_loss = trainer.test(base_val_set, base_classes, use_prototypes=False)
    print(f" Validation Acc: {val_acc*100:.2f}% | Val Loss: {np.asarray(val_loss).mean():.4f}")

    results["val_accs"].append(val_acc)
    results["losses_val"].append(np.asarray(val_loss).mean())

    # If validation improves, save checkpoint and reset patience
    if val_acc > results["best_val_acc"]:
        results["best_val_acc"] = val_acc
        patience = params["patience_init"]  # reset patience

        save_path = os.path.join(models_path, f"best_model_{config['mode']}.pth")
        model_data = {
            "model_state_dict": trainer.model.state_dict(),
            "optimizer_state_dict": trainer.optimizer.state_dict(),
            "epoch": epoch,
            "config": config,
            "params": params,
            "results": results
        }
        torch.save(model_data, save_path)
        print(f"[BEST MODEL SAVED] Acc: {val_acc*100:.2f}%")
    else:
        # No improvement: decrement patience and possibly stop early
        patience -= 1
        print(f" [No Improvement | Patience left: {patience}]")
        if patience == 0:
            print(f"\nEARLY STOPPING TRIGGERED at epoch {epoch+1}!")
            break

print("="*70)
print(f"Training complete. Best Val Acc: {results['best_val_acc']*100:.2f}%")

### Training results logging and plotting

In [None]:
# Plot and log utilities for training experiments

def plot_results(results, plots_path):
    """
    Save training and validation loss curves.

    Args:
        results (dict): Dictionary containing training stats:
                        - "sampled_epochs": list of epoch indices
                        - "losses_train": list of training losses per epoch
                        - "losses_val": list of validation losses per epoch
        plots_path (str): Directory where the plot image will be saved.
    """
    plt.figure()
    # Plot training and validation loss with markers for readability
    plt.plot(results["sampled_epochs"], results["losses_train"], label="Training Loss", marker="o")
    plt.plot(results["sampled_epochs"], results["losses_val"], label="Validation Loss", marker="x")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Training and Validation Loss ({results['mode']})")  # read mode from the argument, not a global
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    
    # Save plot to disk using experiment mode in filename
    filename = f"{results['mode']}_training_plot.png"  # read mode from the argument, not a global
    filepath = os.path.join(plots_path, filename)
    plt.savefig(filepath)
    plt.close()


def log_results(params, config, results, log_path):
    """
    Append a CSV row summarizing experiment settings and best validation result.

    Args:
        params (dict): Training hyperparameters and settings.
        config (dict): Prompt/trainer configuration.
        results (dict): Collected training results including "best_val_acc".
        log_path (str): Path to the CSV log file.
    """
    # Fields recorded for each experiment run
    log_fields = [
        "model_type",
        "num_epochs",
        "lr",
        "tr_batch_size",
        "ts_batch_size",
        "momentum",
        "weight_decay",
        "kd_alpha",
        "proto_alpha",
        "temperature",
        "n_ctx",
        "base_accuracy"
    ]

    # Write the header only when the CSV file does not exist yet
    write_header = not os.path.exists(log_path)

    # Open once in append mode: create the file if needed, then add this run's row
    with open(log_path, mode="a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=log_fields)
        if write_header:
            writer.writeheader()
        writer.writerow({
            "model_type": results.get("mode", config.get("mode")),
            "num_epochs": len(results.get("sampled_epochs", [])),
            "lr": params.get("lr"),
            "tr_batch_size": params.get("tr_batch_size"),
            "ts_batch_size": params.get("ts_batch_size"),
            "momentum": params.get("momentum"),
            "weight_decay": params.get("weight_decay"),
            "kd_alpha": params.get("kd_alpha"),
            "proto_alpha": params.get("proto_alpha"),
            "temperature": params.get("temperature"),
            "n_ctx": config.get("n_ctx"),
            "base_accuracy": f"{results.get('best_val_acc', 0.0)*100:.2f}"
        })


# Generate and persist training plot and CSV log for this run
plot_results(results, plots_path)

log_filepath = os.path.join(logs_path, "training_log.csv")
log_results(params, config, results, log_filepath)

### Testing

The following cell runs the final evaluation of the trained ProtoCoCoOp model. It loads the best checkpoint saved during training, re-creates the trainer with the configuration stored in that checkpoint, and tests on both base and novel classes to measure how well the model generalizes.

The model is evaluated in two scenarios. On base classes, prototype fusion is applied when the selected mode enables it. On novel classes, prototype fusion is always disabled. The harmonic mean of the base and novel accuracies then summarizes the trade-off between retaining performance on seen classes and generalizing to unseen ones.

Prototypes exist only for base classes because they are built from the training data. Novel classes have no training examples by definition, so they cannot benefit from prototype fusion at inference.
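
To make the role of the harmonic mean concrete, here is a minimal sketch of the metric. The name `harmonic_mean_sketch` is hypothetical and chosen so that it does not shadow the notebook's own `harmonic_mean` helper, whose implementation may differ. The accuracy values are purely illustrative and are not results from our experiments.

```python
def harmonic_mean_sketch(base_acc: float, novel_acc: float) -> float:
    """Harmonic mean of two accuracies in [0, 1] (illustrative helper)."""
    total = base_acc + novel_acc
    # Guard against division by zero when both accuracies are 0
    if total == 0:
        return 0.0
    return 2 * base_acc * novel_acc / total


# Two made-up scenarios with the same arithmetic mean (0.80)
balanced = harmonic_mean_sketch(0.80, 0.80)
skewed = harmonic_mean_sketch(0.95, 0.65)

print(f"balanced: {balanced:.4f} | skewed: {skewed:.4f}")
```

Both scenarios share the same arithmetic mean, yet the skewed one scores lower. This is why the harmonic mean suits the base-to-novel setting: a model cannot compensate for poor novel-class accuracy simply by overfitting the base classes.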


In [None]:
# Load best trained checkpoint (if available) and evaluate final model.
# The checkpoint file contains the saved model state, optimizer state, epoch and config/params used.
best_model_path = os.path.join(models_path, f"best_model_{config['mode']}.pth")

if os.path.exists(best_model_path):
    print(f"\nLoading best model from {best_model_path}...")
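    # Load checkpoint dictionary (contains model_state_dict, optimizer_state_dict, config, params, results).
    # map_location lets a checkpoint saved on GPU load on the current device.
    # weights_only=False unpickles arbitrary objects, so only load checkpoints you trust.
    model_data = torch.load(best_model_path, map_location=device, weights_only=False)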

    # Restore training configuration and hyperparameters used for the saved model
    config = model_data["config"]
    params = model_data["params"] 

    # Re-create trainer with the exact config/params to ensure compatibility, then load weights
    trainer = CoCoOpTrainer(
        clip_model=model,
        classnames=CLASS_NAMES,
        base_classes=base_classes,
        config=config,
        params=params,
        device=device,
    )
    trainer.model.load_state_dict(model_data["model_state_dict"])

    print("Best model loaded successfully.")
else:
    # If no checkpoint is found, continue with the current in-memory trainer/model
    print("Warning: Best model checkpoint not found! Using current model state.")

# If prototype fusion mode was used, attach the precomputed prototype matrix and fusion weight.
# Prototypes were computed only for base classes and will be fused into base-class logits at inference.
if trainer.use_proto:
    print("Setting prototypes for inference...")
    trainer.model.set_prototypes(prototype_matrix, alpha=params["proto_alpha"])

# Evaluate on base and novel test splits. Prototype fusion applied only when enabled for the trainer.
base_acc, _ = trainer.test(base_test_set, base_classes, use_prototypes=trainer.use_proto)
novel_acc, _ = trainer.test(novel_test_set, novel_classes, use_prototypes=False)

# Compute harmonic mean to assess trade-off between base and novel performance.
hm = harmonic_mean(base_acc, novel_acc)

# Nicely formatted summary of final results
print("\n" + "="*70)
print(f"RESULTS for MODE: {config['mode'].upper()}")
print("="*70)
print(f"  Base Accuracy:  {base_acc*100:6.2f}%")
print(f"  Novel Accuracy: {novel_acc*100:6.2f}%")
print(f"  Harmonic Mean:  {hm*100:6.2f}%")
print("="*70)

## Results and Discussion

## Conclusions

## References

- Radford et al. (2021). Learning Transferable Visual Models From Natural Language Supervision.

- Zhou et al. (2022). Learning to Prompt for Vision-Language Models (CoOp).

- Zhou et al. (2022). Conditional Prompt Learning for Vision-Language Models (CoCoOp).

- Zhang et al. (2022). Tip-Adapter: Training-Free Adaption of CLIP for Few-Shot Classification.

- Hinton et al. (2015). Distilling the Knowledge in a Neural Network.