# Deep Learning Course Project

## Table of contents
- Introduction
- The Baseline: CLIP
- Our Approach
    - Overview
    - CoCoOp
    - Prototypes Generation
    - Knowledge distillation
    - Implementation
- Results and Discussion
- Conclusions
- Authors
- References

## Introduction

## Harmonic Mean (HM)

Standard metric for few-shot adaptation papers.

Formula: HM = 2 / (1/base_acc + 1/novel_acc)

**Why HM instead of arithmetic mean?**
- HM is dominated by the smaller of the two accuracies, so it heavily penalizes an imbalance between base and novel accuracy
- If base=90% and novel=50%: arithmetic mean=70%, HM=64.3%
- Forces the model to balance both accuracies

**Goal:** maximize the HM of the final base and novel test accuracies (`base_acc` and `novel_acc` in the Testing section).


In [None]:
def harmonic_mean(base_accuracy, novel_accuracy):
    """Return the harmonic mean 2 / (1/base + 1/novel) of two accuracies.

    The reciprocals are undefined for a zero accuracy (and meaningless for a
    negative one), so 0.0 is returned whenever either value is <= 0.
    """
    if any(acc <= 0 for acc in (base_accuracy, novel_accuracy)):
        return 0.0
    reciprocal_sum = 1.0 / base_accuracy + 1.0 / novel_accuracy
    return 2.0 / reciprocal_sum


## The Baseline: CLIP

### Initialization

In [None]:
# Import necessary packages
import os
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

import sys
import torch
import torchvision
import numpy as np
import random
import gc
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import OrderedDict
from torch.utils.data import Dataset

# Import OpenAI CLIP, installing it from GitHub on the fly if the import fails.
try:
    import clip
    print("✓ CLIP already installed")
except Exception:
    print("Installing CLIP...")
    import subprocess
    import importlib
    try:
        # Inside Jupyter/IPython use the %pip magic so the package lands in the
        # kernel's own environment.
        get_ipython().run_line_magic('pip', 'install --upgrade git+https://github.com/openai/CLIP.git')
    except Exception:
        # Not under IPython (get_ipython is undefined) or the magic failed:
        # fall back to pip for the running interpreter. `Exception` (not a bare
        # except) so Ctrl-C / SystemExit are not swallowed here.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade",
                               "git+https://github.com/openai/CLIP.git"])
    # Make the freshly installed package visible to the import system
    importlib.invalidate_caches()
    import clip

# Run on the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Folder used for saved models (no error if it already exists)
os.makedirs("models", exist_ok=True)

# Global seed shared by all RNGs and by the DataLoader worker init function
SEED = 42


def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Also switches cuDNN to deterministic kernels and turns off its autotuner
    so repeated runs select the same algorithms.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed everything once at start-up
set_seed(SEED)

# Worker init function: reseed NumPy and Python's `random` inside each DataLoader worker
def worker_init_fn(worker_id):
    """Seed ``numpy`` and ``random`` inside a DataLoader worker process.

    PyTorch seeds each worker's torch RNG with ``base_seed + worker_id``.
    ``base_seed`` is drawn from the main-process RNG whenever a new iterator
    is created, but NumPy and ``random`` are left untouched. Deriving their
    seed from ``torch.initial_seed()`` keeps runs reproducible after
    ``set_seed``. It also gives every worker, and every epoch, its own stream,
    instead of replaying the same ``SEED + worker_id`` sequence each time the
    workers are recreated.

    Args:
        worker_id: index of the worker. Not needed explicitly, because
            ``torch.initial_seed()`` already incorporates it.
    """
    worker_seed = torch.initial_seed() % 2**32  # NumPy seeds must fit in 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)

### Data preparation
We define utility functions for:
- **`get_data()`**: Load the Flowers102 train/val/test splits from torchvision
- **`split_classes()`**: Split the 102 classes into base (0-50) and novel (51-101)
- **`split_data()`**: Filter the images of each split into base and novel subsets

This simulates the real scenario: 51 classes are seen during training (base), and 51 unseen classes (novel) appear only at test time.

In [None]:
# -- DATA PREPARATION FUNCTIONS --
# Load one split ("train", "val" or "test") of Flowers102, applying `transform`
def load_split(split, transform):
    """Return the requested Flowers102 split, downloading it into ./data if needed."""
    dataset_cls = torchvision.datasets.Flowers102
    return dataset_cls(
        root="./data",
        split=split,
        download=True,
        transform=transform,
    )

# Load Flowers102 dataset and return train, val, test sets
def get_data(data_dir="./data", transform=None):
    """Load the Flowers102 train, validation and test splits.

    Args:
        data_dir: directory where the dataset is stored (downloaded there if
            missing). Defaults to "./data".
        transform: transform applied to every image (e.g. CLIP's preprocess).

    Returns:
        ``(train, val, test)`` tuple of ``torchvision.datasets.Flowers102``.
    """
    def _split(name):
        # One split rooted at `data_dir`, so the argument is actually honoured
        return torchvision.datasets.Flowers102(
            root=data_dir, split=name, download=True, transform=transform
        )

    train = _split("train")
    val = _split("val")
    test = _split("test")

    return train, val, test

# Split dataset classes into base and novel classes
def split_classes(dataset):
    """Split the sorted unique labels of *dataset* into ``(base, novel)`` halves.

    Labels are taken from the first non-None attribute among ``targets``,
    ``labels`` and ``_labels``. The lower half of the sorted unique labels is
    the base split; the rest (one extra label when the count is odd) is the
    novel split.

    Raises:
        ValueError: if none of those attributes provides labels.
    """
    labels = None
    for attr in ("targets", "labels", "_labels"):
        labels = getattr(dataset, attr, None)
        if labels is not None:
            break

    if labels is None:
        raise ValueError("Could not find labels on dataset (checked 'targets','labels','_labels').")

    ordered = sorted(set(labels))
    half = len(ordered) // 2
    return ordered[:half], ordered[half:]

# Split dataset into base and novel datasets
def split_data(dataset, base_classes):
    """Partition *dataset* into ``(base, novel)`` ``torch.utils.data.Subset`` objects.

    A sample is base when its label (read from ``dataset._labels``) is in
    *base_classes*; every other sample is novel. Both subsets keep the
    original sample order.
    """
    wanted = set(base_classes)
    base_idx, novel_idx = [], []
    for idx, label in enumerate(dataset._labels):
        (base_idx if label in wanted else novel_idx).append(idx)

    return (
        torch.utils.data.Subset(dataset, base_idx),
        torch.utils.data.Subset(dataset, novel_idx),
    )

In [None]:
# Load CLIP model and preprocessing
# `preprocess` is the image transform returned by clip.load; it is reused below
# as the transform of every dataset split.
model, preprocess = clip.load("ViT-B/16", device=device)

# Define class names for Flowers102 dataset
# Indexed by the dataset's integer label: CLASS_NAMES[label] is the name put into prompts.
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# Load dataset and split into base and novel datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# Get base and novel classes from the test set
# (split_classes puts the lower half of the sorted label ids in base_classes
#  and the remaining ids in novel_classes)
base_classes, novel_classes = split_classes(test_set)
# All class ids: base ids first, then novel ids
classes = base_classes + novel_classes

# Get class names
base_class_names = [CLASS_NAMES[i] for i in base_classes]
novel_class_names = [CLASS_NAMES[i] for i in novel_classes]
class_names = [CLASS_NAMES[i] for i in classes]

# Create base and novel datasets
# Train/val keep only their base-class samples (the novel halves are discarded
# via `_`); the test split is kept in both halves for base and novel evaluation.
base_train_set, _ = split_data(train_set, base_classes)
base_val_set, _ = split_data(val_set, base_classes)
base_test_set, novel_test_set = split_data(test_set, base_classes)

### Evaluation

In [None]:
@torch.no_grad()
def eval(model, dataset, classes, batch_size, device):
    """Zero-shot accuracy of a CLIP *model* on *dataset*, predicting among *classes*.

    Every class id ``c`` becomes the prompt
    "a photo of a {CLASS_NAMES[c]}, a type of flower."; an image is assigned to
    the class whose L2-normalized text embedding has the highest cosine
    similarity with its L2-normalized image embedding.

    Note: the name shadows the builtin ``eval``; it is kept so callers still work.

    Args:
        model: CLIP model exposing ``encode_image`` and ``encode_text``.
        dataset: yields ``(image_tensor, label)`` pairs; every label must be in *classes*.
        classes: original dataset class ids; prediction index ``i`` means ``classes[i]``.
        batch_size: DataLoader batch size.
        device: device the tokens and images are moved to.

    Returns:
        Fraction of correctly classified samples (0.0 for an empty dataset).
    """
    # Set model to evaluation mode
    model.eval()

    # Map original class ids to contiguous ids starting from zero
    class_map = {cat: idx for idx, cat in enumerate(classes)}

    # Tokenize one standard CLIP prompt per class, encode and L2-normalize
    text_inputs = clip.tokenize(
        [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in classes]
    ).to(device)
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        worker_init_fn=worker_init_fn,
    )

    correct = 0
    for image, target in tqdm(dataloader):
        # Remapped targets built directly as int64 on the target device
        target = torch.tensor(
            [class_map[t.item()] for t in target], dtype=torch.long, device=device
        )
        image = image.to(device)

        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Cosine similarity between image and text features; keep the argmax for every image
        predicted_class = (image_features @ text_features.T).argmax(dim=-1)

        correct += (predicted_class == target).sum().item()

    # An empty dataset would otherwise raise ZeroDivisionError
    n_samples = len(dataset)
    return correct / n_samples if n_samples > 0 else 0.0

# Zero-shot CLIP baseline: each split is evaluated only against its own classes
print("Computing Zero-shot accuracy on both base and novel classes...")
zero_shot_base_accuracy = eval(model=model, dataset=base_test_set, classes=base_classes, batch_size=128, device=device)
zero_shot_novel_accuracy = eval(model=model, dataset=novel_test_set, classes=novel_classes, batch_size=128, device=device)
print("Computation done.\n")

print(f"Zero-shot accuracy on base classes: {zero_shot_base_accuracy*100:.2f}%")
print(f"Zero-shot accuracy on novel classes: {zero_shot_novel_accuracy*100:.2f}%")
# Same metric as the final CoCoOp results, for a direct comparison
print(f"Zero-shot harmonic mean: {harmonic_mean(zero_shot_base_accuracy, zero_shot_novel_accuracy)*100:.2f}%")

## Our Approach: Proto-guided CoCoOp with Knowledge Distillation

### Prototypes Generation

We construct **class prototypes** from CLIP image embeddings of training samples.

**Key Design Choices:**
- Use **frozen CLIP** (not the adapted model) to preserve zero-shot knowledge
- Compute prototypes from **both the original images and augmented views** for better coverage
- **L2-normalize** each embedding before averaging, and re-normalize the averaged prototype

**At Inference:**
- Compute prototype similarity: $\text{sim}_{\text{proto}}(x, c) = \frac{f(x) \cdot p_c}{\|f(x)\| \|p_c\|}$
- Fuse with CoCoOp logits: $\text{logits}_{\text{final}} = \alpha \cdot \text{logits}_{\text{CoCoOp}} + (1-\alpha) \cdot \text{logits}_{\text{proto}}$
- The fusion weight $\alpha$ controls the trade-off between prompt-based and prototype-based predictions

In [None]:
# Data augmentation transform for prototype construction.
# Every call produces a different random 224x224 view of a PIL image.
aug_view_transform = torchvision.transforms.Compose([
    # Resize so the shorter side is 224 px (bicubic)
    torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    # Force 3 channels in case an image is grayscale / palette-based
    torchvision.transforms.Lambda(lambda im: im.convert("RGB")),
    # The shorter side is already 224, so this yields exactly 224x224
    # (a subsequent CenterCrop(224) would be a no-op and is not needed)
    torchvision.transforms.RandomCrop(224),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomRotation(30),
    torchvision.transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    torchvision.transforms.ToTensor(),
    # NOTE(review): these mean/std values are intended to match CLIP's `preprocess`; confirm
    torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                     (0.26862954, 0.26130258, 0.27577711)),
])

# Dataset wrapper: applies `transform` to each image of `subset` and returns (tensor, label)
class TransformView(Dataset):
    """Lazily apply *transform* to the images of an ``(image, label)`` dataset.

    ``subset`` is expected to yield raw PIL images (no transform of its own).
    The image is transformed on every access, so a random transform gives a
    fresh view each time. The label is passed through unchanged.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        return self.transform(image), label

# Build prototypes from augmented dataset
@torch.no_grad()
def build_prototypes(model, dataset, base_classes, device='cuda'):
    """Build one L2-normalized mean image embedding ("prototype") per base class.

    Args:
        model: CLIP-like model exposing ``encode_image``; used in eval mode, no gradients.
        dataset: yields ``(image_tensor, label)`` pairs; samples whose label is
            not in *base_classes* are ignored.
        base_classes: ordered class ids; row ``i`` of the returned matrix is
            the prototype of ``base_classes[i]``.
        device: device used for encoding and for the returned tensors.

    Returns:
        ``(prototypes, prototype_matrix)``:
          * ``prototypes``: dict ``class_id -> prototype`` for the classes that
            had at least one sample;
          * ``prototype_matrix``: one row per entry of *base_classes*, in that
            order. A class without samples gets an all-zero row (prototype
            similarity 0), so rows always line up with the prompt classes.

    Raises:
        ValueError: if no base class has any sample.
    """
    model.eval()

    # Collect L2-normalized embeddings per class
    embeddings_per_class = {c: [] for c in base_classes}

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=False, num_workers=0
    )

    print(f"Extracting embeddings from {len(dataset)} samples...")

    for images, labels in tqdm(dataloader, desc="Building Prototypes"):
        images = images.to(device)

        # Get CLIP image features and L2-normalize each embedding
        features = model.encode_image(images)
        features = features / features.norm(dim=-1, keepdim=True)

        for feat, label in zip(features, labels):
            bucket = embeddings_per_class.get(label.item())
            if bucket is not None:
                bucket.append(feat.cpu())

    # Mean prototype per class
    prototypes = {}
    for cls_id in base_classes:
        class_embeddings = embeddings_per_class[cls_id]
        if not class_embeddings:
            print(f"Warning: no samples for class {cls_id}")
            continue

        stacked = torch.stack(class_embeddings)
        # Average and re-normalize in float32 (float16 reductions lose precision
        # and may be unsupported on CPU in older PyTorch), then cast back.
        prototype = stacked.float().mean(dim=0)
        prototype = prototype / prototype.norm()
        prototypes[cls_id] = prototype.to(device=device, dtype=stacked.dtype)

    if not prototypes:
        raise ValueError("No samples found for any base class; cannot build prototypes.")

    # Matrix for efficient inference, ordered exactly like base_classes
    reference = next(iter(prototypes.values()))
    prototype_matrix = torch.stack([
        prototypes[c] if c in prototypes else torch.zeros_like(reference)
        for c in base_classes
    ])

    print(f"✓ Built {len(prototypes)} prototypes | Matrix shape: {prototype_matrix.shape}")

    return prototypes, prototype_matrix

In [None]:
# Raw Flowers102 train split with no transform: samples stay PIL images, so
# TransformView can produce both the plain and the augmented views from them.
train_raw = load_split("train", transform=None)

# Select base-class indices on this very dataset object, so the indices and the
# `_labels` they were computed from come from the same instance.
base_set = set(base_classes)
base_idx = [idx for idx, label in enumerate(train_raw._labels) if label in base_set]
base_train_raw = torch.utils.data.Subset(train_raw, base_idx)

# One plain view (CLIP preprocess) plus `num_samples` randomly augmented views
num_samples = 10  # number of augmented views per original image
orig_view = TransformView(base_train_raw, preprocess)
views = [orig_view, *(TransformView(base_train_raw, aug_view_transform) for _ in range(num_samples))]

# Prototype pool = all views concatenated
proto_pool = torch.utils.data.ConcatDataset(views)
print("N =", len(orig_view), "pool =", len(proto_pool))  # should be N*(1+num_samples)

# Prototypes are computed with the frozen CLIP image encoder
prototypes, prototype_matrix = build_prototypes(
    model=model,
    dataset=proto_pool,
    base_classes=base_classes,
    device=device,
)

### Model Definition

**Components:**
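1. **Context Vectors (V):** `n_ctx` learnable vectors.
   - Shape: `(n_ctx, 512)`. Note: `PromptLearner` defaults to `n_ctx=16`, but `CustomCLIP` passes its own default `n_ctx=4` and the trainer does not override it, so 4 context vectors are actually learned.
   - Initialized: Gaussian noise with standard deviation 0.02, i.e. $\mathcal{N}(0, 0.02^2)$
   - Function: Provide the base context for the prompt.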

2. **Meta-Network (Bias Generator):**
   - Architecture: Linear(512->32) -> ReLU -> Linear(32->512)
   - Input: Image Features `(Batch, 512)`
   - Output: Bias `(Batch, 512)` added to Context Vectors.
   - **Note:** As in the CoCoOp paper, where the Meta-Net output $\pi$ shifts every context vector ($v_m(x) = v_m + \pi$), we implement it as an **additive bias** to the context vectors.

3. **Class Embeddings:**
   - Pre-computed token embeddings of the first (start) token (prefix) and of "[CLASS]." followed by the end token and padding (suffix).
   - Fixed during training.

**Forward Pass (Vectorized):**
Instead of looping through images, we broadcast tensors to shape `(Batch, Num_Classes, Sequence_Length, Dim)`:
1. **Compute Bias:** $Bias = MetaNet(Image)$
2. **Shift Context:** $Ctx_{new} = Ctx_{base} + Bias$ (Broadcasting over classes)
3. **Concatenate:** $[Prefix] + [Ctx_{new}] + [Suffix]$ (All in parallel)

In [None]:
class TextEncoder(nn.Module):
    """Run CLIP's text transformer on pre-computed token embeddings.

    ``clip_model.encode_text`` starts from integer token ids. This module
    starts from already-embedded prompts, so learnable context vectors can be
    inserted. It reuses (shares) the transformer, positional embedding, final
    LayerNorm and text projection of *clip_model*.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Return one projected feature vector per prompt.

        Args:
            prompts: batch-first embedded prompts, assumed (N, L, D) to match
                the positional embedding.
            tokenized_prompts: (N, L) token ids; the position of the largest id
                in each row is used as the [EOS] position.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        # The transformer is called sequence-first: (N, L, D) -> (L, N, D) and back
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)
        eos_positions = tokenized_prompts.argmax(dim=-1)
        rows = torch.arange(hidden.shape[0])
        return hidden[rows, eos_positions] @ self.text_projection

class PromptLearner(nn.Module):
    """CoCoOp prompt learner: learnable context vectors shifted by an image-conditioned bias.

    Each class prompt is ``[prefix token] + [context vectors + bias(image)] + [suffix tokens]``.
    The prefix and suffix are the frozen token embeddings of
    ``"<context placeholder> <class name>."`` outside the context slots.

    Args:
        clip_model: CLIP model providing ``token_embedding``, ``ln_final``,
            ``visual.output_dim`` and ``dtype``.
        classnames: one name per class; prompt ``i`` is built for
            ``classnames[i]`` (underscores become spaces).
        n_ctx: number of learnable context vectors; replaced by the word count
            of *ctx_init* when that is given.
        ctx_init: optional phrase used to initialize the context vectors.
        device: device for the tokenized prompts and their embeddings.
    """

    def __init__(self, clip_model, classnames, n_ctx=16, ctx_init=None, device='cuda'):
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        # Frozen token embeddings keep CLIP's own dtype (no forced float16
        # rounding when CLIP runs in float32).
        embed_dtype = clip_model.dtype

        # Context Initialization
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(embed_dtype)
            # clone(): the parameter must own its storage instead of viewing `embedding`
            ctx_vectors = embedding[0, 1:1 + n_ctx, :].clone()
            prompt_prefix = ctx_init
        else:
            ctx_vectors = torch.empty(n_ctx, ctx_dim)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        # Learnable context is stored in float32: with plain SGD (no fp32 master
        # weights) small updates to float16 parameters can be rounded away.
        self.ctx = nn.Parameter(ctx_vectors.float())
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(vis_dim // 16, ctx_dim)),
        ]))

        classnames = [name.replace("_", " ") for name in classnames]
        prompts = [f"{prompt_prefix} {name}." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(embed_dtype)

        # First token (prefix) and everything after the context slots (suffix), frozen
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])
        self.n_cls, self.n_ctx = n_cls, n_ctx
        self.tokenized_prompts = tokenized_prompts

    def forward(self, im_features):
        """Build per-image prompt embeddings for every class.

        Args:
            im_features: image features, assumed (B, vis_dim).

        Returns:
            (B, n_cls, L, ctx_dim) prompt embeddings in the token-embedding dtype.
        """
        batch_size = im_features.shape[0]
        # Match the Meta-Net's weight dtype (e.g. float16 CLIP features vs float32 Linear)
        meta_dtype = self.meta_net.linear1.weight.dtype
        bias = self.meta_net(im_features.to(meta_dtype)).unsqueeze(1)  # (B, 1, ctx_dim)
        ctx_shifted = self.ctx.unsqueeze(0) + bias                      # (B, n_ctx, ctx_dim)

        prefix = self.token_prefix.unsqueeze(0).expand(batch_size, -1, -1, -1)
        suffix = self.token_suffix.unsqueeze(0).expand(batch_size, -1, -1, -1)
        ctx_expanded = ctx_shifted.unsqueeze(1).expand(-1, self.n_cls, -1, -1).to(prefix.dtype)

        return torch.cat([prefix, ctx_expanded, suffix], dim=2)

class CustomCLIP(nn.Module):
    """CoCoOp model: CLIP encoders plus a PromptLearner, with optional prototype fusion.

    ``forward`` returns scaled cosine-similarity logits between every image and
    the image-conditioned prompt of every class. When a prototype matrix has
    been registered with ``set_prototypes`` and ``use_prototypes=True``, the
    output is ``alpha * prompt_logits + (1 - alpha) * prototype_logits``.
    """

    def __init__(self, clip_model, classnames, n_ctx=4, ctx_init=None, device='cuda'):
        super().__init__()
        self.prompt_learner = PromptLearner(clip_model, classnames, n_ctx, ctx_init, device)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        # Prototype fusion stays disabled until set_prototypes() is called
        self.prototype_matrix = None
        self.alpha = 0.5

    def set_prototypes(self, prototype_matrix, alpha=0.5):
        """Register a prototype matrix (cast to the model dtype) and the fusion weight.

        The rows are assumed to follow the same class order as the prompts.
        """
        self.prototype_matrix = prototype_matrix.type(self.dtype)
        self.alpha = alpha

    def forward(self, image, use_prototypes=False):
        scale = self.logit_scale.exp()

        img_feat = self.image_encoder(image.type(self.dtype))
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        # Per-image prompts for every class: (B, C, L, D)
        prompts = self.prompt_learner(img_feat)
        n_img, n_cls, seq_len, width = prompts.shape

        # Encode all B*C prompts in one pass, then restore the (B, C, D) layout
        flat_prompts = prompts.reshape(n_img * n_cls, seq_len, width).type(self.dtype)
        flat_tokens = self.tokenized_prompts.repeat(n_img, 1)
        txt_feat = self.text_encoder(flat_prompts, flat_tokens).reshape(n_img, n_cls, -1)
        txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

        # Batched similarity: (B, 1, D) x (B, D, C) -> (B, C)
        logits = scale * torch.matmul(img_feat.unsqueeze(1), txt_feat.transpose(1, 2)).squeeze(1)

        if not (use_prototypes and self.prototype_matrix is not None):
            return logits

        proto_logits = scale * (img_feat @ self.prototype_matrix.T)
        return self.alpha * logits + (1 - self.alpha) * proto_logits

## Training and Evaluation

Class that manages:

**1. Initialization:**
- Create PromptLearner
- Freeze CLIP (`requires_grad=False`)
- Configure SGD optimizer for prompt learner only

**2. train_epoch():**
- Forward: Image encoder + PromptLearner + Text encoder
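- **Critical step:** Encode soft prompts through text transformer
  - Add positional embeddings
  - Pass through CLIP's transformer
  - Apply the final layer norm
  - Extract the features at the [EOS] token position
  - Apply the text projection
- Compute loss: Cross-entropy on base classes (in `train_epoch_with_kd()`, combined with a knowledge-distillation loss towards frozen zero-shot CLIP)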
- Backward: Backprop only in PromptLearner
- Return: Average loss of the epoch

**3. eval() with Prototype Fusion:**
- Same forward procedure as training
- **NEW:** Optionally fuse CoCoOp logits with prototype similarity scores
- Fusion formula: $\text{logits} = \alpha \cdot \text{logits}_{\text{CoCoOp}} + (1-\alpha) \cdot \text{logits}_{\text{prototype}}$
- Compute accuracy on any dataset (base or novel)

**Important note:** We don't use `model.encode_text()` on soft prompts
because that method expects integer tokens, not embeddings.
We manually forward through the text transformer.

In [None]:
class CoCoOpTrainer:
    """Train a CoCoOp ``CustomCLIP`` on the base classes and evaluate it on any class set.

    The CLIP model passed in is converted to float32 *in place* (``.float()``),
    frozen, and kept as ``self.teacher``. The student ``self.model`` is built from
    the same CLIP object, so it shares the frozen encoders. Only its
    ``prompt_learner`` (context vectors + Meta-Net) is optimized.

    Args:
        clip_model: CLIP model (modified in place, see above).
        classnames: names of the training (base) classes, ordered like *base_classes*.
        base_classes: dataset class ids of the training classes.
        device: torch device string.
        lr: SGD learning rate.
        num_epochs: number of epochs; also the cosine schedule length ``T_max``.
    """

    def __init__(self, clip_model, classnames, base_classes, device='cuda', lr=0.002, num_epochs=10):
        self.device = device
        self.classnames = classnames
        self.base_classes = base_classes
        self.base_indices = torch.arange(len(self.base_classes), device=self.device)
        self.num_epochs = num_epochs

        # Teacher CLIP model (frozen). Note: .float() converts `clip_model` itself.
        self.teacher = clip_model.float().to(device).eval()
        for p in self.teacher.parameters():
            p.requires_grad = False

        # Pre-compute teacher text features for the training classes
        with torch.no_grad():
            tokens = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classnames]).to(self.device)
            text_features = self.teacher.encode_text(tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        self.teacher_text_features = text_features

        # Student model: its prompts are built for the training classes only
        self.model = CustomCLIP(clip_model, classnames, device=device).to(device)

        # Prompt models for other class sets (e.g. the novel classes), keyed by
        # tuple(categories); created lazily by eval().
        self._eval_models = {}

        # Label mapping: global dataset label -> local base-class index (for CE)
        max_label_id = max(base_classes) + 1
        self.label_map = torch.full((max_label_id,), -1, dtype=torch.long, device=device)
        base_ids_tensor = torch.tensor(base_classes, device=device)
        self.label_map[base_ids_tensor] = torch.arange(len(base_classes), device=device)

        self.optimizer = torch.optim.SGD(
            self.model.prompt_learner.parameters(),
            lr=lr, momentum=0.9, weight_decay=5e-4
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=num_epochs)

    def compute_kd_loss(self, student_logits, teacher_logits, temperature=2.0):
        """Temperature-softened KL(teacher || student), scaled by ``temperature ** 2``.

        ``reduction='batchmean'`` averages the KL divergence over the batch.
        """
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

    def _run_epoch(self, dataloader, batch_loss, accumulation_steps, desc):
        """Shared training loop with gradient accumulation; returns the mean per-batch loss.

        ``batch_loss(images, labels)`` returns the un-scaled loss of one batch.
        The optimizer steps every ``accumulation_steps`` batches, plus once more
        for a trailing partial group. The LR scheduler steps once per epoch.
        """
        self.model.train()
        total_loss, n_batches = 0.0, 0
        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(tqdm(dataloader, desc=desc)):
            images = images.to(self.device).float()
            labels = labels.to(self.device)

            loss = batch_loss(images, labels) / accumulation_steps
            loss.backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            total_loss += loss.item() * accumulation_steps
            n_batches += 1

        # Apply the gradients of a trailing, incomplete accumulation group
        if len(dataloader) % accumulation_steps != 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

        self.scheduler.step()
        return total_loss / max(1, n_batches)

    def train_epoch_with_kd(self, dataloader, kd_alpha=0.5, temperature=2.0, accumulation_steps=4):
        """One epoch of ``(1 - kd_alpha) * CE + kd_alpha * KD`` with gradient accumulation.

        The KD target is frozen zero-shot CLIP ("a photo of a {class}") over the
        same training classes the student predicts.
        """
        def batch_loss(images, labels):
            # Student logits over the training (base) classes
            student_logits = self.model(images, use_prototypes=False)

            # Teacher (frozen zero-shot CLIP) logits over the same classes
            with torch.no_grad():
                img_features = self.teacher.encode_image(images)
                img_features /= img_features.norm(dim=-1, keepdim=True)
                teacher_logits = self.teacher.logit_scale.exp() * (img_features @ self.teacher_text_features.T)

            # CE against labels remapped to 0..N-1
            loss_ce = F.cross_entropy(student_logits[:, self.base_indices], self.label_map[labels])
            loss_kd = self.compute_kd_loss(student_logits, teacher_logits, temperature)
            return (1 - kd_alpha) * loss_ce + kd_alpha * loss_kd

        return self._run_epoch(dataloader, batch_loss, accumulation_steps, "Training with KD (Grad Accum)")

    def train_epoch(self, dataloader, accumulation_steps=4):
        """One epoch of cross-entropy training on the base classes with gradient accumulation."""
        def batch_loss(images, labels):
            logits = self.model(images, use_prototypes=False)
            # Labels remapped to 0..N-1 through the pre-computed GPU lookup table
            return F.cross_entropy(logits[:, self.base_indices], self.label_map[labels])

        return self._run_epoch(dataloader, batch_loss, accumulation_steps, "Training (Grad Accum)")

    def _model_for(self, categories, classnames=None):
        """Return a CustomCLIP whose prompts cover exactly *categories*.

        For the training classes this is ``self.model``. For any other class set,
        a separate CustomCLIP is built once (prompts for those class names, taken
        from *classnames* or from ``CLASS_NAMES``). On every call it is loaded
        with the current context vectors and Meta-Net weights of ``self.model``.
        """
        categories = list(categories)
        if categories == list(self.base_classes):
            return self.model

        key = tuple(categories)
        eval_model = self._eval_models.get(key)
        if eval_model is None:
            names = list(classnames) if classnames is not None else [CLASS_NAMES[c] for c in categories]
            # The new model's own (overwritten) parameter init must not shift the
            # global CPU RNG, or training shuffles would change.
            with torch.random.fork_rng(devices=[]):
                eval_model = CustomCLIP(
                    self.teacher, names,
                    n_ctx=self.model.prompt_learner.n_ctx,
                    device=self.device,
                ).to(self.device)
            self._eval_models[key] = eval_model

        src, dst = self.model.prompt_learner, eval_model.prompt_learner
        dst.ctx.data.copy_(src.ctx.data)
        dst.meta_net.load_state_dict(src.meta_net.state_dict())
        return eval_model

    @torch.no_grad()
    def eval(self, dataset, categories, batch_size=64, use_prototypes=False, classnames=None):
        """Accuracy on *dataset* when predicting among *categories*.

        Prompts are built for exactly the classes in *categories*. This matters
        for novel classes, whose names the training model has never encoded.
        Prototype fusion only applies when *categories* are the training
        classes; other class sets have no prototypes, so the flag is ignored.

        Args:
            dataset: yields ``(image_tensor, label)``; every label must be in *categories*.
            categories: dataset class ids; prediction index ``i`` means ``categories[i]``.
            batch_size: DataLoader batch size.
            use_prototypes: fuse prompt logits with the registered prototypes.
            classnames: optional names for *categories* (default: ``CLASS_NAMES`` lookup).

        Returns:
            Fraction of correct predictions (0.0 for an empty dataset).
        """
        model = self._model_for(categories, classnames)
        model.eval()
        local_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

        correct, total = 0, 0
        for images, labels in tqdm(dataloader, desc="Validating"):
            images = images.to(self.device).float()
            logits = model(images, use_prototypes=use_prototypes)
            preds = logits.argmax(dim=1)
            targets = torch.tensor([local_cat2idx[l.item()] for l in labels], device=self.device)
            correct += (preds == targets).sum().item()
            total += labels.size(0)

        return correct / total if total > 0 else 0.0


### Training

We train the PromptLearner for up to **50 epochs** (early stopping with patience 5 on base-class validation accuracy) on **base classes only**.

**Hyperparameters (Optimized):**
- **Context Length (`n_ctx`):** 4 (`CustomCLIP`'s default; the trainer does not override it)
- **Batch size:** 4, with gradient accumulation over 4 steps (effective batch size 16)
- **Learning rate:** 0.002 (SGD)
- **Momentum:** 0.9
- **Weight decay:** 5e-4
- **Epochs:** up to 50 (cosine learning-rate schedule over 50 epochs)

**What happens:**
- The `PromptLearner` adapts its context vectors to the Flowers102 dataset.
- The `MetaNetwork` learns to inject image-specific bias efficiently.
- **Optimization:** We use a GPU-based label lookup table to speed up target mapping.

**Expected output:**
- Initial loss: ~2.5 - 3.5
- Final loss: ~0.5 - 1.0 (Lower than before due to better context capacity)
- Training time: ~2-4 minutes on GPU

In [None]:
import os
from torch.utils.data import DataLoader

# Make sure checkpoint folder exists
os.makedirs("checkpoints", exist_ok=True)

# Initialize trainer (prompts are learned on the base classes only)
trainer = CoCoOpTrainer(
    clip_model=model,
    classnames=base_class_names,
    base_classes=base_classes,
    device=device,
    lr=0.002,
    num_epochs=50,
)

# Register the base-class prototypes before training, so the base validation
# below (use_prototypes=True) really fuses them. CustomCLIP ignores the flag
# while no prototype matrix is set. Training itself never uses prototypes.
trainer.model.set_prototypes(prototype_matrix, alpha=0.5)

# Training parameters
patience = 5
mode = "standard"  # "standard" = CE only, "kd" = CE + knowledge distillation
accumulation_steps = 4
if mode not in ("standard", "kd"):
    raise ValueError(f"Unknown training mode: {mode!r} (expected 'standard' or 'kd')")

best_val_acc = 0.0
best_epoch = 0
counter = 0

# Base-class training dataloader; shuffle=True reshuffles at every epoch
train_loader = DataLoader(base_train_set, batch_size=4, shuffle=True)

print("\n" + "="*70)
print(f"TRAINING CoCoOp (Patience: {patience})")
print("="*70)

for epoch in range(trainer.num_epochs):
    # --- TRAINING STEP ---
    if mode == "standard":
        avg_loss = trainer.train_epoch(train_loader, accumulation_steps=accumulation_steps)
    else:
        avg_loss = trainer.train_epoch_with_kd(
            train_loader,
            kd_alpha=0.5,
            temperature=2.0,
            accumulation_steps=accumulation_steps
        )

    # --- EVALUATION STEP ---
    # 1. Base-class validation: prompt logits fused with the base prototypes
    val_acc_base = trainer.eval(base_val_set, base_classes, batch_size=64, use_prototypes=True)

    # 2. Novel-class monitoring (zero-shot) on the *test* split: printed only,
    #    never used for model selection
    val_acc_novel = trainer.eval(novel_test_set, novel_classes, batch_size=64, use_prototypes=False)

    print(f"Epoch {epoch+1:02d}/{trainer.num_epochs} - "
          f"Loss: {avg_loss:.4f} | Base Acc: {val_acc_base*100:.2f}% | "
          f"Novel Acc: {val_acc_novel*100:.2f}%", end="")

    # --- EARLY STOPPING & CHECKPOINT (selected on base validation accuracy) ---
    if val_acc_base > best_val_acc:
        best_val_acc = val_acc_base
        best_epoch = epoch + 1
        counter = 0
        torch.save(trainer.model.prompt_learner.state_dict(), "checkpoints/best_model.pth")
        print("  [BEST MODEL SAVED]")
    else:
        counter += 1
        print(f"  [No Improvement {counter}/{patience}]")
        if counter >= patience:
            print(f"\nEARLY STOPPING TRIGGERED at epoch {epoch+1}!")
            break

print("="*70)
print(f"Training complete. Best Base Acc: {best_val_acc*100:.2f}% at epoch {best_epoch}")


### Testing

We'll test the model with:
1. **Test Base** - CoCoOp only vs CoCoOp + Prototypes
2. **Test Novel** - CoCoOp only (no prototypes for novel classes)

We compute the harmonic mean of the two accuracies to evaluate the trade-off.

**Note:** Prototypes are only available for base classes (built from training data).


In [None]:
# --- LOAD BEST MODEL FOR FINAL EVALUATION ---
# map_location restores the checkpoint onto the current device, regardless of
# the device it was saved from.
trainer.model.prompt_learner.load_state_dict(
    torch.load("checkpoints/best_model.pth", map_location=device)
)

# --- INJECT PROTOTYPE MATRIX FOR HYBRID INFERENCE (BASE CLASSES ONLY) ---
trainer.model.set_prototypes(prototype_matrix, alpha=0.5)
print("Prototype Matrix injected for hybrid inference.")

# --- FINAL EVALUATION ---
# Base: prompt + prototype fusion. Novel: prompts only (no prototypes exist for them).
base_acc = trainer.eval(base_test_set, base_classes, batch_size=64, use_prototypes=True)
novel_acc = trainer.eval(novel_test_set, novel_classes, batch_size=64, use_prototypes=False)
hm = harmonic_mean(base_acc, novel_acc)

print("\n" + "="*70)
print("RESULTS")
print("="*70)
print(f"  Base Accuracy:  {base_acc*100:6.2f}%")
print(f"  Novel Accuracy: {novel_acc*100:6.2f}%")
print(f"  Harmonic Mean:  {hm*100:6.2f}%")
print("="*70)

## Results and Discussion

## Conclusions

## References