# Disclaimer on Notebook Version

This notebook contains a simplified, self-contained version of the codebase developed for our project on zero-shot generalization in CLIP. The original project code is modular and includes features that are not compatible with notebook constraints.

For the full implementation and all functionality, refer to:
- GitHub repository: https://github.com/bacobax/DeepL-project  
- Project report (PDF): https://drive.google.com/file/d/1Mi2i5793zcRzsxmeUoWpJyaR2JdbhuhQ/view?usp=sharing

### Limitations of this Notebook Version

To comply with the self-contained nature of Jupyter notebooks, some features have been removed or modified:

- **YAML-based configuration**: Disabled to avoid external file dependencies. An example configuration is instead defined directly in the notebook.
- **Clustering caching**: Saving and loading of clustering results is disabled. Clustering is recomputed every time the program runs.
- **Log-based result plots**: Because TensorBoard logs are awkward to manage inside a notebook, result plots are still generated with the original repository.

This notebook aims to demonstrate the core components and workflow but is not a full replacement for the original repository.


In [None]:

# Install the project requirements

%pip install annotated-types==0.7.0 git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 contourpy==1.3.1 cycler==0.12.1 docker-pycreds==0.4.0 easydict==1.13 gitdb==4.0.12 GitPython==3.1.44 git+https://github.com/greydanus/mnist1d@7878d96082abd200c546a07a4101fa90b30fdf7e pooch==1.8.2 pydantic==2.10.6 pydantic_core==2.27.2 PyWavelets==1.8.0 regex==2024.11.6 sentry-sdk==2.23.1 setproctitle==1.3.5 smmap==5.0.2 tensorboard-data-server==0.7.2 wandb==0.19.8 zstandard==0.23.0


# Introduction

This project focuses on improving few-shot learning performance on a fixed dataset and experimental setup chosen by the instructor. The main constraint was to build upon the CoCoOp prompt learning method.

Experiments were conducted on the Oxford 102 Flowers dataset, divided into base and novel classes as specified by the course guidelines. The model is trained using few-shot samples from the base classes and evaluated on both base and novel classes to assess generalization.

Our approach explores modular training techniques such as KL divergence-based knowledge distillation and adversarial training to enhance CoCoOp’s ability to generalize to novel classes. The methods are designed to be flexible, allowing various combinations and settings to be tested.
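As a point of reference, a standard temperature-scaled KL distillation loss compares the softened class distributions of a teacher and a student. The plain-Python sketch below is a generic illustration, not necessarily the exact loss used in this project: the helper names, the temperature `T=2.0`, the `T**2` rescaling, and the choice of teacher are all assumptions.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of logits (numerically stable)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_distillation_loss(student_logits, teacher_logits, T=2.0):
    """KL(teacher || student) between softened distributions, rescaled by T**2.

    Hypothetical helper: the teacher could be, for example, frozen zero-shot CLIP
    and the student the prompt-tuned model; T=2.0 is an arbitrary example value.
    """
    p_teacher = softmax(teacher_logits, T)
    p_student = softmax(student_logits, T)
    kl = sum(p * (math.log(p) - math.log(q)) for p, q in zip(p_teacher, p_student) if p > 0)
    return (T ** 2) * kl

teacher = [2.0, 1.0, 0.1]
print(kl_distillation_loss(teacher, teacher))          # identical logits -> 0.0
print(kl_distillation_loss([0.1, 1.0, 2.0], teacher))  # mismatched logits -> positive
```

In PyTorch the same quantity is usually computed as `F.kl_div(F.log_softmax(s / T, dim=-1), F.softmax(t / T, dim=-1), reduction="batchmean") * T * T`, where `s` and `t` are the student and teacher logits.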

This report details our implementation, experiments, and results within the context of the course project.


# Training Systems

### CoCoOp Infrastructure

**CoOp** [^4] and **CoCoOp** [^3] are prompt learning methods that adapt pretrained vision-language models by optimizing learnable context vectors (prompts) to improve performance on downstream tasks. While CoOp learns static prompts shared across all inputs, CoCoOp generates dynamic, image-conditioned prompts via a meta-network to better generalize to unseen classes and domains.
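The difference can be illustrated with a toy example in plain Python. Everything here is an assumption made for illustration: 2-dimensional vectors, and a hypothetical `meta_net` standing in for CoCoOp's Meta-Net (a small bottleneck MLP in the original paper). CoOp returns the same context for every image, while CoCoOp adds an image-dependent bias π(x) to each context token.

```python
# Toy comparison of CoOp (static) and CoCoOp (image-conditioned) context vectors.
# Vectors are plain lists; `meta_net` is a hypothetical stand-in for the Meta-Net.

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def meta_net(image_feature):
    """Hypothetical Meta-Net: maps an image feature to a bias of size ctx_dim
    (here a fixed linear map followed by ReLU)."""
    w = [[0.5, 0.0], [0.0, -0.5]]
    h = [sum(wi * xi for wi, xi in zip(row, image_feature)) for row in w]
    return [max(0.0, x) for x in h]

ctx = [[0.1, 0.2], [0.3, 0.4]]  # n_ctx = 2 learnable context tokens, ctx_dim = 2

def coop_context(image_feature):
    # CoOp: the image is ignored; every input shares the same context.
    return ctx

def cocoop_context(image_feature):
    # CoCoOp: the same bias pi(x) is added to every context token.
    pi = meta_net(image_feature)
    return [add(v, pi) for v in ctx]

img_a, img_b = [1.0, 0.0], [0.0, -1.0]
print(coop_context(img_a) == coop_context(img_b))      # True
print(cocoop_context(img_a) == cocoop_context(img_b))  # False
```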

In our implementation, we first reproduced CoOp [^4] to obtain learned context vectors (`ctx`), which we then used to initialize CoCoOp [^3] rather than training it from scratch. This approach offers several benefits:

- **Improved Generalization:** CoOp’s context vectors, learned from limited few-shot data, provide a strong discriminative foundation that helps CoCoOp generalize better across domains and unseen classes.

- **Stabilized Meta-Network Training:** Initializing CoCoOp’s prompt learner close to a working solution improves training stability and convergence speed.

- **Better Semantic Initialization:** CoOp’s learned context vectors, which are shared across all classes, already encode task-relevant semantic information, giving CoCoOp’s dynamic prompts a head start.

This modular design leverages existing learned knowledge to improve efficiency and robustness in few-shot and domain generalization settings, aligning with CoCoOp’s goal of adaptive prompt generation based on visual context [^3].
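A minimal sketch of the warm start, assuming the checkpoint format written by `save_prompt_learner` further down (a dict whose `"ctx"` entry has shape `[n_ctx, ctx_dim]`). Nested lists stand in for tensors and `warm_start_ctx` is a hypothetical helper; the point it illustrates is that the CoOp context length must match CoCoOp's `n_ctx`, consistent with how `load_from_checkpoint` picks `ctx_4_coop` or `ctx_8_coop` based on `n_ctx`.

```python
import copy

def warm_start_ctx(coop_state, n_ctx, ctx_dim):
    """Return a copy of CoOp's learned context to initialize CoCoOp's `ctx`.

    `coop_state` is assumed to look like {"ctx": [[...], ...]} with shape
    [n_ctx, ctx_dim]; plain lists stand in for tensors in this sketch.
    """
    ctx = coop_state["ctx"]
    if len(ctx) != n_ctx or any(len(row) != ctx_dim for row in ctx):
        raise ValueError(f"CoOp ctx does not match the expected shape [{n_ctx}, {ctx_dim}]")
    # Deep copy so later CoCoOp updates do not modify the stored CoOp context.
    return copy.deepcopy(ctx)

coop_state = {"ctx": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]}
init = warm_start_ctx(coop_state, n_ctx=2, ctx_dim=3)
init[0][0] = 9.9
print(coop_state["ctx"][0][0])  # 0.1: the source checkpoint is untouched
```

With tensors, the copy step would be something like `nn.Parameter(state["ctx"].clone())`.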



## Model CoOp

### Custom CLIP CoOp

In [None]:
# From: ./model/coop/custom_clip.py
import clip
import torch
from torch import nn
from torch.nn import functional as F
from contextlib import contextmanager



class CustomCLIPCoOp(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = PromptLearnerCoOp(cfg, classnames, clip_model)
        self.clip_model = clip_model
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.cfg = cfg
        self.dtype = clip_model.dtype

    def forward(self, image, label=None):
        logit_scale = self.logit_scale.exp()

        #print("Raw logit_scale:", self.logit_scale.item())
        #print("Exp logit_scale:", logit_scale)

        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        prompts = self.prompt_learner()
        text_features = self.text_encoder(prompts.type(self.dtype), self.tokenized_prompts)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        if torch.isnan(image_features).any():
            print("⚠️ NaNs in image_features!")

        if torch.isnan(text_features).any():
            print("⚠️ NaNs in text_features!")

        if torch.isinf(image_features).any():
            print("⚠️ Infs in image_features!")

        if torch.isinf(text_features).any():
            print("⚠️ Infs in text_features!")

        #print("Image feature norm:", image_features.norm(dim=-1).mean().item())
        #print("Text feature norm:", text_features.norm(dim=-1).mean().item())

        logits = logit_scale * image_features @ text_features.t()

        #print(f"is training: {self.prompt_learner.training}, label: {label}")

        if label is not None:

            return F.cross_entropy(logits, label), logits

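        return logits

    @contextmanager
    def temporary_classnames(self, new_classnames):
        """Temporarily swap the class-dependent prompt buffers (e.g. to evaluate on novel classes).

        The learned context vectors `ctx` are shared and left untouched; only the
        tokenized prompts, prefix/suffix embeddings and class count are replaced,
        then restored on exit.
        """
        # --- Save original state ---
        original_classnames = self.prompt_learner.n_cls  # number of classes, not the names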
        original_tokenized_prompts = self.tokenized_prompts
        original_token_prefix = self.prompt_learner.token_prefix
        original_token_suffix = self.prompt_learner.token_suffix

        # --- Apply temporary state ---
        temp_prompt_learner = PromptLearnerCoOp(
            cfg=self.cfg,
            classnames=new_classnames,
            clip_model=self.clip_model
        )

        self.tokenized_prompts = temp_prompt_learner.tokenized_prompts
        self.prompt_learner.tokenized_prompts = temp_prompt_learner.tokenized_prompts
        self.prompt_learner.token_prefix = temp_prompt_learner.token_prefix
        self.prompt_learner.token_suffix = temp_prompt_learner.token_suffix
        self.prompt_learner.n_cls = len(new_classnames)

        try:
            yield
        finally:
            # --- Restore original state ---
            self.tokenized_prompts = original_tokenized_prompts
            self.prompt_learner.tokenized_prompts = original_tokenized_prompts
            self.prompt_learner.token_prefix = original_token_prefix
            self.prompt_learner.token_suffix = original_token_suffix
            self.prompt_learner.n_cls = original_classnames
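`temporary_classnames` is an instance of the save/apply/restore pattern built on `contextlib.contextmanager`: the `finally` clause restores the original class-dependent state even if the code inside the `with` block raises. The sketch below isolates that pattern in plain Python; `ToyPromptLearner` and its attributes are hypothetical stand-ins for the prompt learner's `n_cls`, `token_prefix`, and `token_suffix`.

```python
from contextlib import contextmanager

class ToyPromptLearner:
    """Hypothetical stand-in holding class-dependent state."""
    def __init__(self, classnames):
        self.classnames = list(classnames)
        self.n_cls = len(classnames)

@contextmanager
def temporary_classnames(learner, new_classnames):
    # Save the original state.
    saved = (learner.classnames, learner.n_cls)
    # Apply the temporary state.
    learner.classnames = list(new_classnames)
    learner.n_cls = len(new_classnames)
    try:
        yield learner
    finally:
        # Restore, even if the body raised an exception.
        learner.classnames, learner.n_cls = saved

learner = ToyPromptLearner(["rose", "tulip"])
with temporary_classnames(learner, ["lily", "iris", "daisy"]):
    print(learner.n_cls)  # 3 while evaluating the other classes
print(learner.n_cls)      # 2 again afterwards
```

In the real method only the class-dependent buffers are swapped; the learned `ctx` is never replaced, which is what lets a model trained on base classes be evaluated on novel ones.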

### Text Encoder

In [None]:
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND: [batch_size, seq_len, width] -> [seq_len, batch_size, width] (seq_len = 77 CLIP tokens)
        x = self.transformer(x.to(dtype=self.transformer.resblocks[0].mlp.c_fc.weight.dtype))
        x = x.permute(1, 0, 2)  # LND -> NLD: [seq_len, batch_size, width] -> [batch_size, seq_len, width]
        x = self.ln_final(x)

        # Take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection



        return x
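The pooling line works because, in CLIP's tokenizer, the end-of-text token has the largest id (49407; start-of-text is 49406) and padding is 0, so `argmax` over each row of `tokenized_prompts` returns the EOT position. The plain-Python sketch below reproduces that gather; the word-token ids are made-up placeholders and the hidden states are fake.

```python
SOT, EOT = 49406, 49407  # CLIP's start/end-of-text token ids

# Two tokenized prompts padded with 0; the middle ids are placeholder word tokens.
tokenized = [
    [SOT, 343, 343, 1234, 269, EOT, 0, 0],
    [SOT, 343, 343, 5678, 910, 269, EOT, 0],
]

def eot_positions(batch):
    # Equivalent of tokenized_prompts.argmax(dim=-1): EOT is the largest id in each row.
    return [max(range(len(seq)), key=seq.__getitem__) for seq in batch]

def pool_eot(hidden_states, batch):
    # Equivalent of x[torch.arange(B), eot_idx]: one feature vector per prompt.
    return [hidden_states[i][j] for i, j in enumerate(eot_positions(batch))]

# Fake per-token features: the feature of token t in prompt i is [i, t].
hidden = [[[i, t] for t in range(8)] for i in range(2)]
print(eot_positions(tokenized))     # [5, 6]
print(pool_eot(hidden, tokenized))  # [[0, 5], [1, 6]]
```

The pooled vectors are then mapped into the joint embedding space by `text_projection`.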

### Prompt Learner CoOp

In [None]:
# From: ./model/coop/custom_clip.py

class PromptLearnerCoOp(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.COOP.N_CTX
        ctx_init = cfg.TRAINER.COOP.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"

        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
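            # Move the token ids to the embedding's device (as done for the class prompts below)
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)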
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        #print(f'Initial context: "{prompt_prefix}"')
        #print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)

        classnames = [name.replace("_", " ") for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts.to(clip_model.token_embedding.weight.device))
            embedding = embedding.type(dtype)

        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])

        self.tokenized_prompts = tokenized_prompts
        self.n_cls = n_cls
        self.n_ctx = n_ctx

    def construct_prompts(self, ctx, prefix, suffix):
        return torch.cat([prefix, ctx, suffix], dim=1)

    def forward(self):
        ctx = self.ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        return self.construct_prompts(ctx, self.token_prefix, self.token_suffix)
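To see which slices become `token_prefix` and `token_suffix`, the sketch below mirrors the construction with strings instead of embeddings. It assumes a hypothetical whitespace tokenizer with one token per word (real CLIP BPE may split a class name into several tokens, which is harmless because everything after the context goes into the suffix). With random initialization the placeholder prompt is `"X X X X <class>."`: position 0 holds SOT, positions `1 .. n_ctx` are replaced by the learned `ctx`, and the rest (class name, period, EOT, padding) stays frozen.

```python
n_ctx = 4
prompt_prefix = " ".join(["X"] * n_ctx)  # "X X X X"
classnames = ["pink_primrose", "globe_thistle"]

def toy_tokenize(text, length=12):
    # Hypothetical whitespace tokenizer: SOT, words, EOT, then padding.
    tokens = ["<sot>"] + text.replace(".", " .").split() + ["<eot>"]
    return tokens + ["<pad>"] * (length - len(tokens))

prompts = [prompt_prefix + " " + name.replace("_", " ") + "." for name in classnames]
tokenized = [toy_tokenize(p) for p in prompts]

token_prefix = [t[:1] for t in tokenized]          # SOT only
token_suffix = [t[1 + n_ctx:] for t in tokenized]  # class name, ".", EOT, padding
ctx = [f"ctx{i}" for i in range(n_ctx)]            # shared learned context

# construct_prompts: concatenate prefix, ctx and suffix along the token axis.
rebuilt = [p + ctx + s for p, s in zip(token_prefix, token_suffix)]
print(rebuilt[0][:9])
```

The original code makes the same one-token-per-word assumption for `ctx_init` (`n_ctx = len(ctx_init.split(" "))`), which should hold for common initializations such as "a photo of a" but may not for words that BPE splits into several tokens.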


## Training System CoOp

In [None]:
# From: ./training_systems/coop.py
"""
This module defines the CoOpSystem class for training and evaluating a prompt-tuned CLIP model using the CoOp method.
It includes data loading, training with early stopping, evaluation, model saving/loading, and logging to TensorBoard.
"""

import os

import torch
from easydict import EasyDict
from tqdm import tqdm

# from model.coop.custom_clip import CustomCLIPCoOp
from utils.datasets import get_data, base_novel_categories, split_data, CLASS_NAMES
from utils.training_coop import test_step, training_step, eval_step
import clip
from torch.utils.tensorboard.writer import SummaryWriter
from torch.optim import SGD, Adam
from torch import nn


class CoOpSystem:
    """
    Implements the CoOp prompt tuning system for training and evaluating CLIP-based models.

    Attributes:
        batch_size (int): Number of samples per training batch.
        device (str): Device identifier (e.g., "cuda:0") to run training on.
        learning_rate (float): Learning rate for the optimizer.
        weight_decay (float): Weight decay used in the optimizer.
        momentum (float): Momentum term (if applicable).
        epochs (int): Number of training epochs.
        run_name (str): Identifier for the experiment run (used for logging and file naming).
        n_ctx (int): Number of context tokens for prompt tuning.
        ctx_init (str): Initialization string for context tokens.
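        class_token_position (str): Position of the class token in the prompt (stored only; PromptLearnerCoOp in this notebook always appends the class name after the context).
        csc (bool): Whether to use class-specific context (stored only; PromptLearnerCoOp in this notebook always learns a single context shared by all classes).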
    """
    def __init__(self,
                 batch_size=16,
                 device="cuda:0",
                 learning_rate=0.002,
                 weight_decay=0.0005,
                 momentum=0.9,
                 epochs=2,
                 run_name="exp1",
                 n_ctx=4,
                 ctx_init="",
                 class_token_position="end",
                 csc=False,
                 ):
        self.batch_size = batch_size
        self.device = device
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.epochs = epochs
        self.run_name = run_name
        self.n_ctx = n_ctx
        self.ctx_init = ctx_init
        self.class_token_position = class_token_position
        self.csc = csc

        # Create a logger for the experiment
        self.writer = SummaryWriter(log_dir=f"runs/CoOp/{run_name}")
        self.writer.add_scalar(f"lr", self.learning_rate, 0)
        self.writer.add_scalar(f"momentum", self.momentum, 0)

        # Load the CLIP backbone (ViT-B/16) on the target device and cast it to fp32
        self.clip_model, _ = clip.load("ViT-B/16", device=self.device)
        self.clip_model = self.clip_model.float()

        # Get the datasets at CLIP's input resolution
        resolution = self.clip_model.visual.input_resolution
        self.train_set, self.val_set, self.test_set = get_data(resolution=resolution)

        # split classes into base and novel
        self.base_classes, self.novel_classes = base_novel_categories(self.train_set)

        # split the three datasets
        self.train_base, _ = split_data(self.train_set, self.base_classes)
        self.val_base, self.val_novel = split_data(self.val_set, self.base_classes)
        self.test_base, self.test_novel = split_data(self.test_set, self.base_classes)

        cfg = EasyDict()
        # Training configuration
        cfg.TRAINER = EasyDict()
        cfg.TRAINER.COOP = EasyDict()
        cfg.TRAINER.COOP.N_CTX = self.n_ctx  # Number of context tokens
        cfg.TRAINER.COOP.CTX_INIT = self.ctx_init  # Leave empty for random initialization
        cfg.INPUT = EasyDict()
        cfg.INPUT.SIZE = [resolution, resolution]  # Must match CLIP model's input resolution

        # Instantiate the network and move it to the chosen device (GPU)
        self.model = CustomCLIPCoOp(
            classnames=[CLASS_NAMES[idx] for idx in self.base_classes],
            cfg=cfg,
            clip_model=self.clip_model,
        ).to(device)

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            if "prompt_learner" in name:
                param.requires_grad_(True)
            else:
                param.requires_grad_(False)

        print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"Total trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")

        self.optimizer = self.get_optimizer(self.learning_rate, self.weight_decay, self.momentum)
        self.cost_function = nn.CrossEntropyLoss()

    def train(self):
        """
        Trains the CoOp model on the base classes with early stopping and logs performance metrics.
        Saves the best model and computes evaluation at the end of training.
        """
        print("Before training:")
        print("Training the model...")
        print_epoch_interval = 1
        pbar = tqdm(total=self.epochs, desc="OVERALL TRAINING", position=0, leave=True)

        best_val_acc = 0.0
        patience = 3
        counter = 0
        best_model_state = None

        for e in range(self.epochs):
            base_train_loss, base_train_accuracy = training_step(
                model=self.model,
                dataset=self.train_base,
                optimizer=self.optimizer,
                batch_size=self.batch_size,
                classnames=self.base_classes,
                device=self.device,
            )

            if e % print_epoch_interval == 0:
                base_val_loss, base_val_accuracy = eval_step(
                    model=self.model,
                    dataset=self.val_base,
                    cost_function=self.cost_function,
                    new_classnames=self.base_classes,
                    device=self.device,
                    batch_size=self.batch_size,
                )

                self.log_values(e, base_train_loss, base_train_accuracy, "train_base")
                self.log_values(e, base_val_loss, base_val_accuracy, "validation_base")

                pbar.set_postfix(train_acc=base_train_accuracy, val_acc=base_val_accuracy)

                # Early stopping check
                if base_val_accuracy > best_val_acc:
                    best_val_acc = base_val_accuracy
                    counter = 0
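                    # Clone the tensors: state_dict() returns references that the optimizer keeps updating in place
                    best_model_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}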
                else:
                    counter += 1
                    if counter >= patience:
                        print(f"Early stopping at epoch {e}, best validation accuracy: {best_val_acc:.4f}")
                        break

            pbar.update(1)

        # Restore best model if early stopped
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)

        print("After training:")
        self.compute_evaluation(self.epochs)
        self.writer.close()

        self.save_model()
        self.save_prompt_learner()

    def save_model(self, path="./bin/coop"):
        """
        Saves the entire model's state dictionary to disk under the specified path.

        Args:
            path (str): Directory path where the model checkpoint will be saved.
        """
        # Create the folder if it does not exist
        os.makedirs(path, exist_ok=True)
        # Save the model
        torch.save(self.model.state_dict(), os.path.join(path, f"{self.run_name}.pth"))

    def save_prompt_learner(self, path="./bin/coop"):
        """
        Saves only the prompt learner component of the model to disk.

        Args:
            path (str): Directory path where the prompt learner checkpoint will be saved.
        """
        # Create the folder if it does not exist
        os.makedirs(path, exist_ok=True)
        # Save only the self.ctx parameter of the prompt learner
        ctx_state = {"ctx": self.model.prompt_learner.ctx.detach().cpu()}
        torch.save(ctx_state, os.path.join(path, f"{self.run_name}_prompt_learner.pth"))

    def load_model(self, path="./bin/coop"):
        """
        Loads a saved model checkpoint from disk and sets the model to evaluation mode.

        Args:
            path (str): Directory path from which the model checkpoint will be loaded.
        """
        # Load the model
        self.model.load_state_dict(torch.load(os.path.join(path, f"{self.run_name}.pth"), map_location=self.device))
        self.model.eval()
        print(f"Model loaded from {path}")

    def compute_evaluation(self, epoch_idx, base=False):
        """
        Evaluates the model (or zero-shot CLIP if base=True) on the base test set and logs accuracy.

        Args:
            epoch_idx (int): Index of the current epoch for logging.
            base (bool): If True, use zero-shot CLIP model for evaluation instead of the trained model.

        Returns:
            float: Accuracy on the base test set.
        """
        base_accuracy = test_step(
            self.model if not base else self.clip_model,
            self.test_base,
            self.batch_size,
            self.device,
            self.base_classes,
            label="test",
            base=base
        )
        # Log to TensorBoard
        self.log_value(epoch_idx, base_accuracy, "base_classes")

        return base_accuracy

    def get_optimizer(self, lr, wd, momentum):
        """
        Instantiates and returns the optimizer for the model parameters.

        Args:
            lr (float): Learning rate.
            wd (float): Weight decay.
            momentum (float): Momentum term (unused for Adam optimizer).

        Returns:
            torch.optim.Optimizer: Configured Adam optimizer instance.
        """
        # Optimize only the trainable parameters (the prompt learner's context vectors)
        optimizer = Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=lr,
            weight_decay=wd,
        )

        return optimizer

    def log_value(self, step, accuracy, prefix):
        """
        Logs a single scalar value (accuracy) to TensorBoard.

        Args:
            step (int): Training step or epoch index.
            accuracy (float): Accuracy value to log.
            prefix (str): Tag prefix to categorize the metric in TensorBoard.
        """
        self.writer.add_scalar(f"{prefix}/accuracy", accuracy, step)

    def log_values(self, step, loss, accuracy, prefix):
        """
        Logs both loss and accuracy values to TensorBoard.

        Args:
            step (int): Training step or epoch index.
            loss (float): Loss value to log.
            accuracy (float): Accuracy value to log.
            prefix (str): Tag prefix to categorize the metrics in TensorBoard.
        """
        self.writer.add_scalar(f"{prefix}/loss", loss, step)
        self.writer.add_scalar(f"{prefix}/accuracy", accuracy, step)
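One subtlety in the early-stopping loop: `nn.Module.state_dict()` returns references to the live parameter tensors, so a snapshot taken without cloning keeps changing as the optimizer updates the model in place, and "restoring the best model" can silently restore the latest one. The plain-Python sketch below reproduces the patience logic with a proper copy; `train_with_early_stopping`, the dict-of-lists "model", and the accuracy sequence are hypothetical.

```python
import copy

def train_with_early_stopping(val_accuracies, patience=3):
    """Simulate training; `val_accuracies[e]` is the validation accuracy after epoch e."""
    model_state = {"ctx": [0.0]}  # stands in for model.state_dict()
    best_acc, best_state, best_epoch, counter = 0.0, None, None, 0
    for epoch, acc in enumerate(val_accuracies):
        model_state["ctx"][0] += 1.0  # in-place "optimizer step"
        if acc > best_acc:
            best_acc, best_epoch, counter = acc, epoch, 0
            best_state = copy.deepcopy(model_state)  # snapshot, not a reference
        else:
            counter += 1
            if counter >= patience:
                break  # the 0.90 at the last epoch is never reached
    return best_epoch, best_acc, best_state, model_state

best_epoch, best_acc, best_state, last_state = train_with_early_stopping(
    [0.60, 0.72, 0.70, 0.69, 0.68, 0.90]
)
print(best_epoch, best_acc, best_state, last_state)  # 1 0.72 {'ctx': [2.0]} {'ctx': [5.0]}
```

Without the `deepcopy`, `best_state` would be the same object as `model_state` and would end up holding `[5.0]`. With PyTorch the same effect is obtained with `copy.deepcopy(model.state_dict())` or `{k: v.detach().clone() for k, v in model.state_dict().items()}`.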

## Model CoCoOp

### Custom CLIP CoCoOp

In [None]:
# From: ./model/cocoop/custom_clip.py
import os.path as osp
from collections import OrderedDict
import math
from contextlib import contextmanager

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
# from model.cocoop.prompt_learner import PromptLearner
# from model.cocoop.mlp_adversary import GradientReversalLayer, AdversarialMLP
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from easydict import EasyDict

_tokenizer = _Tokenizer()


class CustomCLIP(nn.Module):
    """
    CustomCLIP is a modified CLIP wrapper supporting CoCoOp prompt learning and adversarial training extensions.
    It enables image-text matching via learned prompts and exposes additional methods for class manipulation.
    """
    def __init__(self, cfg, classnames, clip_model):
        """
        Initialize the CustomCLIP model with the given configuration, classnames, and base CLIP model.
        Sets up the image and text encoders and prepares the prompt learner.
        """
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.clip_model = clip_model
        self.cfg = cfg


    def change_classnames(self, new_classnames):
        """
        Permanently replaces the class-dependent prompt buffers (tokenized prompts,
        prefix/suffix embeddings and class count) for a new list of classnames.
        The learned context vectors are kept.
        """
        temp_prompt_learner = PromptLearner(
            cfg=self.cfg,
            classnames=new_classnames,
            clip_model=self.clip_model
        )

        self.tokenized_prompts = temp_prompt_learner.tokenized_prompts
        self.prompt_learner.tokenized_prompts = temp_prompt_learner.tokenized_prompts
        self.prompt_learner.token_prefix = temp_prompt_learner.token_prefix
        self.prompt_learner.token_suffix = temp_prompt_learner.token_suffix
        self.prompt_learner.n_cls = len(new_classnames)



    @contextmanager
    def temporary_classnames(self, new_classnames):
        """
        A context manager that temporarily sets new classnames for inference, restoring the original ones afterward.
        """
        # --- Save original state ---
        original_classnames = self.prompt_learner.n_cls
        original_tokenized_prompts = self.tokenized_prompts
        original_token_prefix = self.prompt_learner.token_prefix
        original_token_suffix = self.prompt_learner.token_suffix

        # --- Apply temporary state ---
        temp_prompt_learner = PromptLearner(
            cfg=self.cfg,
            classnames=new_classnames,
            clip_model=self.clip_model
        )

        self.tokenized_prompts = temp_prompt_learner.tokenized_prompts
        self.prompt_learner.tokenized_prompts = temp_prompt_learner.tokenized_prompts
        self.prompt_learner.token_prefix = temp_prompt_learner.token_prefix
        self.prompt_learner.token_suffix = temp_prompt_learner.token_suffix
        self.prompt_learner.n_cls = len(new_classnames)

        try:
            yield
        finally:
            # --- Restore original state ---
            self.tokenized_prompts = original_tokenized_prompts
            self.prompt_learner.tokenized_prompts = original_tokenized_prompts
            self.prompt_learner.token_prefix = original_token_prefix
            self.prompt_learner.token_suffix = original_token_suffix
            self.prompt_learner.n_cls = original_classnames


    def forward(self, image, label=None, get_image_features=False):
        """
        Computes similarity logits between image features and image-conditioned text features.

        In training mode returns `(logits, loss)`, or, if `get_image_features=True`,
        `(logits, loss, image_features, ctx, bias, avg_text_features, selected_text_features)`.
        In evaluation mode returns only `logits` of shape [B, n_cls].
        """
        # tokenized_prompts: [num_classes, context_length] (e.g., [10, 77])
        tokenized_prompts = self.tokenized_prompts

        # logit_scale: scalar (e.g., initialized as a learnable parameter like torch.tensor(1/0.07).log())
        logit_scale = self.logit_scale.exp()

        # image: [B, 3, H, W]
        # image_features: [B, D] where D = CLIP's joint embedding dim (visual.output_dim, 512 for ViT-B/16)
        image_features = self.image_encoder(image.type(self.dtype))
        if image_features.isnan().any():
            raise ValueError("NaN detected in image_features.")
        # Normalize image features: each row to unit length
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # prompts: [B, n_cls, context_length, ctx_dim], one set of class prompts per image,
        # each conditioned on that image's feature vector
        prompts, ctx, bias = self.prompt_learner(image_features)
        if prompts.isnan().any():
            raise ValueError("NaN detected in prompts.")
        # No [B * n_cls, ...] reshape: the text encoder is run once per image in the loop below
        logits = []
        all_text_features = []
        selected_text_features = []
        # Iterate over batch
        for pts_i, imf_i in zip(prompts, image_features):
            # pts_i: [num_classes, context_length, D]
            # tokenized_prompts: [num_classes, context_length]
            # text_features: [num_classes, D]
            text_features = self.text_encoder(pts_i, tokenized_prompts)
            if text_features.isnan().any():
                raise ValueError("NaN detected in text ft.")
            all_text_features.append(text_features)
            # Normalize text features
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # imf_i: [D], text_features.T: [D, num_classes]
            # l_i: [num_classes], similarity scores between image and all class prompts
            l_i = logit_scale * imf_i @ text_features.t()
            # Append l_i (1D tensor) to logits list
            logits.append(l_i)
            best_idx = l_i.argmax()
            best_text_feat = text_features[best_idx]  # [D]
            selected_text_features.append(best_text_feat)

        all_text_features = torch.stack(all_text_features)  # [B, num_classes, D]
        # Average over classes (uses the text features before normalization)
        avg_text_features = all_text_features.mean(dim=1)

        # Shape: [B, D]
        selected_text_features = torch.stack(selected_text_features)
        # logits: list of B tensors each of shape [num_classes]
        # stacked into a tensor of shape [B, num_classes]
        logits = torch.stack(logits)
        if logits.isnan().any():
            raise ValueError("NaN detected in logits")

        # If in training mode, compute and return cross-entropy loss
        if self.prompt_learner.training:
            # logits: [B, num_classes], label: [B]
            if get_image_features:
                # If get_image_features is True, return logits and image features
                return logits, F.cross_entropy(logits, label), image_features, ctx, bias, avg_text_features, selected_text_features
            else:
                return logits, F.cross_entropy(logits, label)

        # Otherwise, return logits for evaluation: [B, num_classes]
        return logits

    def print_all_dtypes(self):
        """
        Utility function to print the data types of all parameters and buffers in each model component for debugging.
        """
        print(f"CustomCLIP dtype: {self.dtype}")
        print(f"  logit_scale dtype: {getattr(self.logit_scale, 'dtype', type(self.logit_scale))}")
        print(f"  tokenized_prompts dtype: {getattr(self.tokenized_prompts, 'dtype', type(self.tokenized_prompts))}")
        print(f"  image_encoder: {type(self.image_encoder)}")
        for name, param in self.image_encoder.named_parameters():
            print(f"    image_encoder param {name}: {param.dtype}")
        for name, buf in self.image_encoder.named_buffers():
            print(f"    image_encoder buffer {name}: {buf.dtype}")
        print(f"  text_encoder: {type(self.text_encoder)}")
        for name, param in self.text_encoder.named_parameters():
            print(f"    text_encoder param {name}: {param.dtype}")
        for name, buf in self.text_encoder.named_buffers():
            print(f"    text_encoder buffer {name}: {buf.dtype}")
        print(f"  prompt_learner: {type(self.prompt_learner)}")
        for name, param in self.prompt_learner.named_parameters():
            print(f"    prompt_learner param {name}: {param.dtype}")
        for name, buf in self.prompt_learner.named_buffers():
            print(f"    prompt_learner buffer {name}: {buf.dtype}")
        # Also print dtype for ctx, token_prefix, token_suffix if present
        if hasattr(self.prompt_learner, 'ctx'):
            print(f"    prompt_learner.ctx dtype: {self.prompt_learner.ctx.dtype}")
        if hasattr(self.prompt_learner, 'token_prefix'):
            print(f"    prompt_learner.token_prefix dtype: {self.prompt_learner.token_prefix.dtype}")
        if hasattr(self.prompt_learner, 'token_suffix'):
            print(f"    prompt_learner.token_suffix dtype: {self.prompt_learner.token_suffix.dtype}")
        print(f"  clip_model: {type(self.clip_model)}")
        for name, param in self.clip_model.named_parameters():
            print(f"    clip_model param {name}: {param.dtype}")
        for name, buf in self.clip_model.named_buffers():
            print(f"    clip_model buffer {name}: {buf.dtype}")

    @staticmethod
    def load_from_checkpoint(classnames, checkpoint_path, device, n_ctx, clip_model, ctx_8_coop, ctx_4_coop, ctx_init=""):
        """
        Loads a CustomCLIP instance from a saved checkpoint.
        Sets up model configuration, restores state dict, and reassigns classnames.
        Note: a fresh ViT-B/16 CLIP model is loaded internally; the `clip_model`
        argument is only used to read the input resolution.
        """
        ctx_load = (
            ctx_4_coop
            if n_ctx == 4
            else ctx_8_coop
        )
        resolution = clip_model.visual.input_resolution
        cfg = EasyDict(
            {
                "TRAINER": {
                    "COCOOP": {
                        "CTX_LOAD": ctx_load,
                        "N_CTX": n_ctx,
                        "CTX_INIT": ctx_init,
                        "PREC": "fp16",
                    }
                },
                "INPUT": {"SIZE": [resolution, resolution]},
            }
        )
        state_dict = torch.load(checkpoint_path, map_location=device)
        n_cls = state_dict["prompt_learner.token_prefix"].shape[0]

        clip_model, _ = clip.load("ViT-B/16", device=device)

        # Placeholder class names only fix the buffer shapes; the real names are set below
        model = CustomCLIP(cfg, ["X"] * n_cls, clip_model).to(device)
        model.load_state_dict(state_dict)
        model.change_classnames(classnames)
        return model

### Prompt Learner CoCoOp

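Each prompt in this section has the form `X X X X <classname>.` (with `n_ctx` placeholder words when no `ctx_init` is given). After CLIP tokenization the sequence is `[SOS]`, the context positions, the class-name tokens, `.`, and `[EOS]`, padded to 77 tokens. The learner keeps the `[SOS]` embedding as `token_prefix`, replaces the next `n_ctx` positions with learnable vectors, and keeps everything after them as `token_suffix`. The sketch below reproduces that slicing on token strings instead of embeddings. It is only an illustration: the whitespace tokenizer, the `<sos>`/`<eos>`/`<pad>` strings, the helper names, and the class name are hypothetical stand-ins, and CLIP's real BPE tokenizer may split a class name into several tokens.

```python
CONTEXT_LENGTH = 77  # default context length of clip.tokenize


def toy_tokenize(text, context_length=CONTEXT_LENGTH):
    """Hypothetical whitespace tokenizer standing in for CLIP's BPE tokenizer."""
    words = text.replace(".", " .").split()
    tokens = ["<sos>"] + words + ["<eos>"]
    # Pad to a fixed length, as clip.tokenize does with zeros
    return tokens + ["<pad>"] * (context_length - len(tokens))


def split_prompt(tokens, n_ctx):
    """Mirror PromptLearner's slicing: prefix = [:1], ctx = [1:1+n_ctx], suffix = [1+n_ctx:]."""
    return tokens[:1], tokens[1:1 + n_ctx], tokens[1 + n_ctx:]


n_ctx = 4
prompt = " ".join(["X"] * n_ctx) + " " + "sunflower" + "."  # hypothetical class name
tokens = toy_tokenize(prompt)
prefix, ctx, suffix = split_prompt(tokens, n_ctx)

print(prefix)      # ['<sos>']
print(ctx)         # ['X', 'X', 'X', 'X']
print(suffix[:3])  # ['sunflower', '.', '<eos>']

# construct_prompts concatenates the pieces back along the token axis,
# with the placeholder positions replaced by (instance-shifted) learned vectors.
learned_ctx = [f"ctx_{i}" for i in range(n_ctx)]
rebuilt = prefix + learned_ctx + suffix
assert len(rebuilt) == CONTEXT_LENGTH
```

Because the class-name tokens live entirely in the suffix, changing the class names only requires recomputing `token_prefix`/`token_suffix`, while the learned context stays shared across classes.
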
In [None]:
# From: ./model/cocoop/prompt_learner.py
from collections import OrderedDict
import os
import torch
import torch.nn as nn
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()


class PromptLearner(nn.Module):
    """
    The PromptLearner dynamically generates context-conditioned prompts for each class using a meta-network and learned context vectors.
    This module supports both random and pre-initialized context tokens, and computes prompt embeddings compatible with CLIP.
    """
    def __init__(self, cfg, classnames, clip_model):
        """
        Initialize the PromptLearner with configuration parameters, class names, and a CLIP model.
        Loads or initializes context vectors and builds the meta-network for instance-conditioned prompt adaptation.
        """
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.COCOOP.N_CTX
        ctx_init = cfg.TRAINER.COCOOP.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)


        self.ctx = nn.Parameter(ctx_vectors)

        # Optional: Load pre-trained ctx from a file
        if hasattr(cfg.TRAINER.COCOOP, "CTX_LOAD") and cfg.TRAINER.COCOOP.CTX_LOAD:
            ctx_path = cfg.TRAINER.COCOOP.CTX_LOAD
            if os.path.isfile(ctx_path):
                #print(f"🔁 Loading ctx from: {ctx_path}")
                state_dict = torch.load(ctx_path, map_location="cpu")
                if "ctx" in state_dict:
                    with torch.no_grad():
                        self.ctx.copy_(state_dict["ctx"])
                else:
                    raise KeyError(f"'ctx' not found in {ctx_path}")
            else:
                raise FileNotFoundError(f"CTX_LOAD path not found: {ctx_path}")
            #print("PROMPT LEARNER LOADED FROM A COOP PRETRAINED ONE")

        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
        ]))

        if cfg.TRAINER.COCOOP.PREC == "fp16":
            if torch.backends.mps.is_available():
                # fp16 is not fully supported on MPS, so keep the meta-net in float32
                print("⚠️ Using float32 for meta_net due to MPS")
            else:
                self.meta_net.half()

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])  # (n_cls, n_tkn)
        with torch.no_grad():
            device = clip_model.token_embedding.weight.device

            # Ensure tokenized_prompts is on the right device BEFORE embedding
            tokenized_prompts = tokenized_prompts.to(device)

            embedding = clip_model.token_embedding(tokenized_prompts)

            # Do not convert to fp16 on MPS (Apple doesn't support it fully)
            if device.type == "mps" and dtype == torch.float16:
                print("⚠️ fp16 not fully supported on MPS; using float32 instead")
                dtype = torch.float32

            embedding = embedding.to(dtype)
        # These token vectors are saved by save_model(), but load_model() should
        # ignore them and use the ones computed from the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """
        Construct the full tokenized prompt from context tokens, prefix (SOS), and suffix (CLS, EOS).
        Optionally uses label indexing for training-time class selection.
        """
        # dim0 is either batch_size (during training) or n_cls (during testing)
        # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
        # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
        # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)

        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        prompts = torch.cat(
            [
                prefix,  # (dim0, 1, dim)
                ctx,     # (dim0, n_ctx, dim)
                suffix,  # (dim0, *, dim)
            ],
            dim=1,
        )

        return prompts

    def forward(self, im_features):
        """
        Generate the context-conditioned prompts for each class based on input image features.
        Applies the meta-network to compute per-instance bias vectors, which shift the context tokens.
        Returns the generated prompts along with the original context and computed bias.
        """
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx                     # (n_ctx, ctx_dim)
        if im_features.isnan().any():
            raise ValueError("NaN in im_features before meta_net")

        #print("im_features stats", im_features.min().item(), im_features.max().item(), im_features.norm(dim=1).mean().item())

        meta_net_dtype = next(self.meta_net.parameters()).dtype
        # print(f"meta_net_dtype: {meta_net_dtype}, im_features dtype: {im_features.dtype}")
        bias = self.meta_net(im_features.to(meta_net_dtype))  # (batch, ctx_dim)
        if bias.isnan().any():
            raise ValueError("NaN detected in bias")
        bias = bias.unsqueeze(1)           # (batch, 1, ctx_dim)
        ctx = ctx.unsqueeze(0)             # (1, n_ctx, ctx_dim)
        ctx_shifted = ctx + bias           # (batch, n_ctx, ctx_dim)
        if ctx_shifted.isnan().any():
            raise ValueError("NaN detected in ctx_shifted")
        # Use instance-conditioned context tokens for all classes
        prompts = []
        for ctx_shifted_i in ctx_shifted:
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1) # (n_cls, n_ctx, ctx_dim)
            pts_i = self.construct_prompts(ctx_i, prefix, suffix)  # (n_cls, n_tkn, ctx_dim)
            if pts_i.isnan().any():
                raise ValueError("NaN detected in pts_i")
            prompts.append(pts_i)
        prompts = torch.stack(prompts)

        return prompts, ctx, bias
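
The shape bookkeeping in `PromptLearner.forward` can be checked without a GPU. The sketch below applies NumPy/PyTorch broadcasting rules to shape tuples: the bias `(batch, 1, ctx_dim)` added to `ctx` `(1, n_ctx, ctx_dim)` gives `(batch, n_ctx, ctx_dim)`, and stacking one `(n_cls, n_tkn, ctx_dim)` prompt tensor per image gives `(batch, n_cls, n_tkn, ctx_dim)`. It also counts the meta-net parameters for `vis_dim = ctx_dim = 512`, the widths of CLIP ViT-B/16. The batch size and class count are illustrative assumptions, and `broadcast_shape`/`linear_params` are hypothetical helpers, not part of the project.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Broadcast two shape tuples using NumPy/PyTorch rules; raise if incompatible."""
    result = []
    # Compare trailing dimensions first; a missing leading dimension counts as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(max(x, y))
    return tuple(reversed(result))


def linear_params(in_features, out_features):
    """Number of weights plus biases in a fully connected (nn.Linear) layer."""
    return in_features * out_features + out_features


# Illustrative sizes (assumptions): 8 images, 10 classes, 4 context tokens.
batch, n_cls, n_ctx = 8, 10, 4
n_tkn = 77                    # length of clip.tokenize output
vis_dim = ctx_dim = 512       # ViT-B/16 image embedding and text width

bias_shape = (batch, 1, ctx_dim)   # meta_net(im_features).unsqueeze(1)
ctx_shape = (1, n_ctx, ctx_dim)    # self.ctx.unsqueeze(0)
ctx_shifted = broadcast_shape(ctx_shape, bias_shape)
print(ctx_shifted)                 # (8, 4, 512)

# Per image: prefix (1 token) + shifted ctx (n_ctx tokens) + suffix, for every class.
suffix_len = n_tkn - 1 - n_ctx
per_image_prompt = (n_cls, 1 + n_ctx + suffix_len, ctx_dim)
prompts_shape = (batch,) + per_image_prompt  # result of torch.stack(prompts)
print(prompts_shape)               # (8, 10, 77, 512)

hidden = vis_dim // 16             # bottleneck width of meta_net
meta_net_params = linear_params(vis_dim, hidden) + linear_params(hidden, ctx_dim)
print(meta_net_params)             # 33312
```

The `// 16` bottleneck keeps the meta-net small (about 33k parameters at this width), so the per-image conditioning adds little to the number of trainable parameters.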

## Training System CoCoOp
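
In `_train_base_phase`, the pseudo-base/pseudo-novel split is rotated only when KL training uses a `DoubleDatasetTrainingMethod` (presumably `KLCoCoOpV2`, selected by `double_datasets_kl`). The rotation epochs depend on `rotation_period`: `"relative"` rotates every `int(patience * 3/4)` epochs (6 with the hard-coded patience of 8), an integer rotates every that many epochs, and `"random"` waits a uniformly sampled 1 to 4 epochs between rotations. The sketch below restates that rule as a stand-alone function so the schedule can be inspected for a given config; `rotation_epochs_schedule` is a hypothetical helper, not part of the project.

```python
import random


def rotation_epochs_schedule(max_epoch, rotation_period="relative", patience=8, rng=None):
    """Return the epochs at which the pseudo-base/pseudo-novel split would rotate.

    Restates the rule used in CoCoOpSystem._train_base_phase (sketch only).
    """
    if rng is None:
        rng = random.Random(0)
    if rotation_period == "random":
        epochs = []
        next_rotation = rng.randint(1, 4)  # first rotation after 1 to 4 epochs
        for e in range(max_epoch):
            if e == next_rotation:
                epochs.append(e)
                next_rotation = e + rng.randint(1, 4)
        return epochs
    # Fixed or relative period: rotate whenever the epoch is a multiple of the period
    period = int(patience * (3 / 4)) if rotation_period == "relative" else rotation_period
    return [e for e in range(max_epoch) if e % period == 0]


print(rotation_epochs_schedule(20))                     # [0, 6, 12, 18]
print(rotation_epochs_schedule(20, rotation_period=5))  # [0, 5, 10, 15]
print(rotation_epochs_schedule(20, "random"))           # depends on the RNG seed
```

Note that the fixed and relative modes also rotate at epoch 0, so the initial split drawn in `__init__` is replaced before the first training step, whereas the random mode keeps it for at least one epoch.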

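When `warmup_lambda_kl > 0` and the base method is `KLCoCoOp` or `KLCoCoOpV2`, `_train_base_phase` ramps the KL weight linearly from a hard-coded 0.1 up to `lambda_kl[0]`, reaching the target at epoch `warmup_lambda_kl - 1`. The function below restates that rule so the schedule can be printed for a given config. `lambda_kl_at_epoch` is a hypothetical helper, and the target of 0.5 with a 4-epoch warmup is only an example.

```python
def lambda_kl_at_epoch(epoch, target, warmup_epochs, start=0.1):
    """Linear lambda_kl warmup as used in _train_base_phase (start=0.1 is hard-coded there)."""
    if warmup_epochs <= 0:
        # No warmup: the configured lambda_kl is used from the first epoch
        return target
    progress = min((epoch + 1) / warmup_epochs, 1.0)
    return start + (target - start) * progress


schedule = [round(lambda_kl_at_epoch(e, target=0.5, warmup_epochs=4), 4) for e in range(6)]
print(schedule)  # [0.2, 0.3, 0.4, 0.5, 0.5, 0.5]
```

Because the ramp starts from 0.1 rather than 0, a configured `lambda_kl[0]` below 0.1 makes the "warmup" decrease the weight over time.
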
In [None]:
# From: ./training_systems/cocoop.py
"""
Main module for training the CoCoOp system, supporting both base and adversarial training phases.
Includes configuration, data loading, model preparation, and training logic for zero-shot learning with CLIP.
"""

import os
import math
from copy import deepcopy
from statistics import harmonic_mean
import torch
from easydict import EasyDict
from tqdm import tqdm
from torch.utils.tensorboard.writer import SummaryWriter
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
import clip
import numpy as np
import random
import hashlib

# from model.cocoop.custom_clip import CustomCLIP
# from model.cocoop.mlp_adversary import GradientReversalLayer, AdversarialMLP
# from utils import (
#     conditional_clustering,
#     random_clustering,
#     rotating_cluster_generator_shift,
#     get_data,
#     base_novel_categories,
#     split_data,
#     TensorboardLogger,
#     CLASS_NAMES
# )

# from training_systems.training_methods import (
#     Adversarial,
#     KLCoCoOp,
#     BaseCoCoOp,
#     KLCoCoOpV2,
# )
# from training_systems.evaluation_methods import (
#     ZeroShotTestStep,
#     FineTunedTestStep,
#     EvalStep,
# )
# from training_systems.core import DoubleDatasetTrainingMethod

# --- Reproducibility helper ---
def set_global_seed(seed):
    """
    Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN/cuBLAS behavior.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Required for deterministic cuBLAS on CUDA >= 10.2; set it before any CUDA work starts
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Available in PyTorch >= 1.8
    if hasattr(torch, "use_deterministic_algorithms"):
        torch.use_deterministic_algorithms(True)


def checksum(model):
    """
    Generate an MD5 checksum of the model's parameters to track changes across training.

    Args:
        model (torch.nn.Module): The model to hash.

    Returns:
        str: MD5 hash string.
    """
    with torch.no_grad():
        all_params = torch.cat(
            [p.view(-1).cpu() for p in model.parameters() if p.requires_grad]
        )
        return hashlib.md5(all_params.numpy().tobytes()).hexdigest()


class CoCoOpSystem:
    """
    Manages the full training process of the CoCoOp model, including configuration, training, evaluation,
    checkpointing, and logging. Supports both base and adversarial training.
    """

    def __init__(
        self,
        *,
        test_batch_size=16,
        pseudo_base_ratio=0.7,
        seed=42,
        device="cuda",
        run_name="exp1",
        cnn_model="ViT-B/16",
        hparams_file,
        optimizer_configs=None,
        skip_tests=None,
        train_base_checkpoint_path=None,
        debug=False,
        prompt_learner_opt=None,
        kl_loss_opt=None,
        adv_training_opt=None,
        base_training_opt=None,
        clustering_opt=None,
        report=False,
        pat=True,
    ):
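        """
        Initialize the CoCoOp system: load data, set up the model, loss functions, optimizers, and logger.

        Args:
            test_batch_size (int): Batch size used for evaluation.
            pseudo_base_ratio (float): Fraction of base classes assigned to the pseudo-base split.
            seed (int): Random seed (defaults to 42 when None).
            device (str): Device identifier (e.g., 'cuda' or 'cpu').
            run_name (str): Unique name for the training run.
            cnn_model (str): Backbone CLIP model name.
            hparams_file (str): Hyperparameter configuration, logged as text to TensorBoard.
            optimizer_configs (list): Two optimizer settings, for base and adversarial training.
            skip_tests (list): Three booleans to skip the zero-shot, post-base, and post-adversarial tests.
            train_base_checkpoint_path (str): Optional path to a pre-trained base model; if given, base training is skipped.
            debug (bool): Enables logging of additional debug information.
            prompt_learner_opt, kl_loss_opt, adv_training_opt, base_training_opt, clustering_opt: Configuration dictionaries.
            report (bool): If True, write logs under runs/report (or runs/report_no_pat when patience is disabled).
            pat (bool): Enables patience and reloads the best checkpoint after base training.
        """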

        print(f"run_name: {run_name}, using {cnn_model}, pat: {pat}")
        # --- Set global seed for reproducibility ---
        self.seed = seed if seed is not None else 42
        # set_global_seed(self.seed)
        # --- End reproducibility ---
        assert prompt_learner_opt is not None, "prompt_learner_opt must be provided"
        assert kl_loss_opt is not None, "kl_loss_opt must be provided"
        assert adv_training_opt is not None, "adv_training_opt must be provided"
        assert base_training_opt is not None, "base_training_opt must be provided"
        assert clustering_opt is not None, "clustering_opt must be provided"
        assert (
            optimizer_configs is not None and len(optimizer_configs) == 2
        ), "Two optimizer configs must be provided"

        # --- Pseudo-base/novel split parameters ---
        self.pseudo_base_ratio = pseudo_base_ratio
        self.pseudo_split_seed = self.seed

        self.test_batch_size = test_batch_size
        self.device = device
        self.epochs = base_training_opt["epochs"]
        self.run_name = run_name
        self.n_ctx = prompt_learner_opt["n_ctx"]
        self.ctx_init = prompt_learner_opt["ctx_init"]
        self.class_token_position = prompt_learner_opt["class_token_position"]
        self.csc = prompt_learner_opt["csc"]
        self.lambda_kl = kl_loss_opt["lambda_kl"]
        self.double_datasets_kl = kl_loss_opt.get("double_datasets_kl", False)
        self.rotation_period = kl_loss_opt.get("rotation_period", "relative")
        self.warmup_lambda_kl = kl_loss_opt.get("warmup_lambda_kl", 0)  # epochs of lambda_kl warmup (0 disables it)
        self.lambda_adv = adv_training_opt["lambda_adv"]
        self.gaussian_noise = adv_training_opt.get("gaussian_noise", 0.0)
        self.use_bias_ctx = adv_training_opt.get("use_bias_ctx", False)
        self.adv_training_epochs = adv_training_opt["adv_training_epochs"]
        self.cnn_model = cnn_model
        self.warmup_epoch = base_training_opt["warmup_epoch"]
        self.warmup_cons_lr = base_training_opt["warmup_cons_lr"]
        self.using_kl = kl_loss_opt["using_kl"]
        self.grl_lambda = adv_training_opt["grl_lambda"]
        self.mlp_opt = EasyDict(adv_training_opt["mlp_opt"])
        self.skip_tests = (
            skip_tests if skip_tests is not None else [False, False, False]
        )
        self.train_base_checkpoint_path = train_base_checkpoint_path
        self.debug = debug
        self.max_epoch = self.epochs
        self.optimizer_configs = [EasyDict(conf) for conf in optimizer_configs]
        self.warmup_lambda_adv = adv_training_opt["warmup_lambda_adv"]
        self.base_batch_size = base_training_opt["batch_size"]
        self.adv_batch_size = adv_training_opt["batch_size"]
        self.adv_accumulation_steps = adv_training_opt.get("accumulation_steps", 1)
        self.base_accumulation_steps = base_training_opt.get("accumulation_steps", 1)
        self.prompt_learner_warmup_epochs = adv_training_opt.get("prompt_learner_warmup_epochs", 0)
        self.pat = pat
        print(
            "BATCH SIZES: ",
            self.test_batch_size,
            self.base_batch_size,
            self.adv_batch_size,
        )

        self.ignore_no_improvement = adv_training_opt.get("ignore_no_improvement", False)
        if not report:
            self.log_dir = f"runs/CoCoOp/{self.run_name}"
        elif self.pat:
            self.log_dir = f"runs/report/{self.run_name}"
        else:
            self.log_dir = f"runs/report_no_pat/{self.run_name}"

        self.writer = SummaryWriter(log_dir=self.log_dir)
        self.writer.add_text("Hparams yaml file", hparams_file)
        self.logger = TensorboardLogger(self.writer)

        self.logger.log_hparams(
            {
                "batch_size_test": self.test_batch_size,
                "base_batch_size": self.base_batch_size,
                "adv_batch_size": self.adv_batch_size,
                "epochs": self.epochs,
                "n_ctx": self.n_ctx,
                "ctx_init": self.ctx_init,
                "class_token_position": self.class_token_position,
                "csc": self.csc,
                "lambda_kl_first": self.lambda_kl[0],
                "lambda_kl_second": self.lambda_kl[1],
                "warmup_epoch": self.warmup_epoch,
                "warmup_cons_lr": self.warmup_cons_lr,
                "lambda_adv": self.lambda_adv,
                "cnn_model": self.cnn_model,
                "grl_lambda": self.grl_lambda,
                "prompt_learner_warmup_epochs" : self.prompt_learner_warmup_epochs,
                "double_datasets_kl": self.double_datasets_kl,
                "pseudo_base_ratio": self.pseudo_base_ratio,
                "pseudo_split_seed": self.seed,
                "rotation_period": self.rotation_period,
                "warmup_lambda_kl": self.warmup_lambda_kl,
            }
        )


        print(f"patience {'disabled' if not self.pat else 'enabled'}")
        # Load model
        self.clip_model, preprocess = clip.load(self.cnn_model)
        self.clip_model = self.clip_model.to(self.device)
        resolution = self.clip_model.visual.input_resolution
        self.train_set, self.val_set, self.test_set = get_data(resolution=resolution, eval_transform=preprocess)
        self.base_classes, self.novel_classes = base_novel_categories(self.train_set)
        # Split the train/val/test sets into base and novel classes
        self.train_base, _ = split_data(self.train_set, self.base_classes)
        self.val_base, self.val_novel = split_data(self.val_set, self.base_classes)
        self.test_base, self.test_novel = split_data(self.test_set, self.base_classes)
        print(f"Base classes length: {len(self.base_classes)}, Novel classes length: {len(self.novel_classes)}")
        print(f"Train base length: {len(self.train_base)}, Val base length: {len(self.val_base)}, Test base length: {len(self.test_base)}")
        print(f"Val novel length: {len(self.val_novel)}, Test novel length: {len(self.test_novel)}")


        self.rotation_steps = int(len(self.base_classes)*(1-self.pseudo_base_ratio))

        self.cluster_generator = rotating_cluster_generator_shift(
            self.base_classes,
            self.pseudo_base_ratio,
            steps=self.rotation_steps,
            seed=self.seed
        )

        _, self.pseudo_base_classes, self.pseudo_novel_classes = next(self.cluster_generator)

        self.train_pseudo_base = self.split_by_classes(self.train_base, self.pseudo_base_classes)
        self.train_pseudo_novel = self.split_by_classes(self.train_base, self.pseudo_novel_classes)
        self.val_pseudo_base = self.split_by_classes(self.val_base, self.pseudo_base_classes)
        self.val_pseudo_novel = self.split_by_classes(self.val_base, self.pseudo_novel_classes)

        # --- Model/classnames: only pseudo_base for first phase ---
        ctx_load = (
            "./bin/coop/coop_ctx_4_VIT16.pth"
            if self.n_ctx == 4
            else "./bin/coop/coop_ctx_8_VIT16.pth"
        )
        cfg = EasyDict(
            {
                "TRAINER": {
                    "COCOOP": {
                        "CTX_LOAD": ctx_load,
                        "N_CTX": self.n_ctx,
                        "CTX_INIT": self.ctx_init,
                        "PREC": "fp16",
                    }
                },
                "INPUT": {"SIZE": [resolution, resolution]},
            }
        )
        self.model = CustomCLIP(
            classnames=[CLASS_NAMES[idx] for idx in self.pseudo_base_classes],
            cfg=cfg,
            clip_model=self.clip_model,
        ).to(self.device)

        # Print all dtypes of every component inside CustomCLIP
        # self.model.print_all_dtypes()

        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)
            else:
                param.requires_grad_(True)

        self.cost_function = nn.CrossEntropyLoss()
        self.grl = GradientReversalLayer(lambda_=self.grl_lambda)

        clip_dim = self.clip_model.visual.output_dim

        self.mlp_adversary = AdversarialMLP(
            input_dim=clip_dim + len(self.base_classes),
            opt=self.mlp_opt,
            output_dim=clustering_opt["n_clusters"],
            use_bias_ctx=self.use_bias_ctx,
            n_ctx=self.n_ctx,
        ).to(self.device)

        print("mlp adversary struct: ", self.mlp_adversary)
        self.optimizer = self.get_optimizer(self.model, None, self.optimizer_configs[0])
        self.lr_scheduler = LambdaLR(self.optimizer, self._lr_lambda)

        # --- Class-to-cluster assignment used in the adversarial phase ---

        clustering_type = clustering_opt["clustering_type"]

        if clustering_type == "random":
            # Use random clustering
            self.cls_cluster_dict, _ = random_clustering(
                n_cluster=clustering_opt["n_clusters"],
                seed=self.seed,
                distribution="uniform",
            )
        elif clustering_type == "semantic":
            # Compute semantic clustering (recomputed on every run; caching is disabled in this notebook)
            self.cls_cluster_dict, _ = conditional_clustering(
                n_cluster=clustering_opt["n_clusters"],
                variance=clustering_opt["variance"],
                cnn=clustering_opt["vision_encoder"],
                device=self.device,
            )
        elif clustering_type == "default":
            # Pseudo_base: cluster 0, pseudo_novel: cluster 1
            self.pseudo_cls_cluster_dict = {c: 0 for c in self.pseudo_base_classes}
            self.pseudo_cls_cluster_dict.update({c: 1 for c in self.pseudo_novel_classes})
            # For adversarial phase, use this dict
            self.cls_cluster_dict = self.pseudo_cls_cluster_dict
        else:
            raise ValueError(f"Unknown clustering type: {clustering_type}")

    def _set_test_methods(self):
        """
        Initializes the zero-shot and fine-tuned test step evaluators for both base and novel class splits.
        """
        self.zero_shot_base_classes_test_method = ZeroShotTestStep(
            model=self.clip_model,
            batch_size=self.test_batch_size,
            categories=self.base_classes,
        )
        self.zero_shot_novel_classes_test_method = ZeroShotTestStep(
            model=self.clip_model,
            batch_size=self.test_batch_size,
            categories=self.novel_classes,
        )
        self.zero_shot_pseudo_base_test_method = ZeroShotTestStep(
            model=self.clip_model,
            batch_size=self.test_batch_size,
            categories=self.pseudo_base_classes,
        )
        self.zero_shot_pseudo_novel_test_method = ZeroShotTestStep(
            model=self.clip_model,
            batch_size=self.test_batch_size,
            categories=self.pseudo_novel_classes,
        )
        self.finetuned_test_method = FineTunedTestStep(
            model=self.model,
            batch_size=self.test_batch_size,
        )

    def _set_eval_method(self):
        """
        Initializes the evaluation method used during validation.
        """
        self.eval_method = EvalStep(
            model=self.model,
            batch_size=self.test_batch_size,
        )

    def _set_train_methods(self):
        """
        Initializes the training methods: the adversarial method, and the base-phase method, which is
        KL-regularized (KLCoCoOp, or KLCoCoOpV2 when double_datasets_kl is set) if KL loss is enabled
        and BaseCoCoOp otherwise.
        """
        self.adversarial_method = Adversarial(
            lambda_adv=self.lambda_adv,
            model=self.model,
            optimizer=self.optimizer,
            cls_cluster_dict=self.cls_cluster_dict,
            grl=self.grl,
            mlp_adversary=self.mlp_adversary,
            debug=self.debug,
            tmp_classes=self.base_classes,
            gaussian_noise=self.gaussian_noise,
            use_bias_ctx=self.use_bias_ctx,
        )

        if self.using_kl[0]:
            if self.double_datasets_kl:
                self.basic_train_method = KLCoCoOpV2(
                    model=self.model,
                    optimizer=self.optimizer,
                    debug=self.debug,
                    lambda_kl=self.lambda_kl[0],
                )
            else:
                self.basic_train_method = KLCoCoOp(
                    model=self.model,
                    optimizer=self.optimizer,
                    debug=self.debug,
                    lambda_kl=self.lambda_kl[0],
                )
        else:
            self.basic_train_method = BaseCoCoOp(
                model=self.model,
                optimizer=self.optimizer,
                debug=self.debug,
            )



    def train(self):
        """
        Execute the full training pipeline: base phase, optionally followed by adversarial training and evaluation.
        """
        self._set_test_methods()
        self._set_eval_method()
        self._set_train_methods()
        if not self.skip_tests[0]:
            print("Doing CLIP zero-shot accuracy test")
            base_acc, novel_acc = self.compute_evaluation(-1, base=True)
            self._log_final_metrics(
                "Final metrics - CLIP ZERO SHOT",
                base_acc,
                novel_acc,
                -1,
            )

        best_model_path = os.path.join(self.log_dir, "best_model.pth")

        if self.train_base_checkpoint_path is None:
            # Base training phase
            base_end_epoch, _, best_base_epoch = self._train_base_phase(best_model_path)
            self.best_base_epoch = best_base_epoch  # Store for later use if needed
            # Log the best epoch to TensorBoard and logger
            self.writer.add_scalar("Best Base Epoch", best_base_epoch, base_end_epoch)

            # Also print for visibility
            print(f"Best base model found at epoch: {best_base_epoch}")
            if self.epochs != 0:
                if self.pat:
                    print(f"[DEBUG] Loading best model state dict after base phase from: {best_model_path}")
                    self.model.load_state_dict(torch.load(best_model_path))
                    print(f"[DEBUG] Loaded best model with classnames: {self.model.prompt_learner.n_cls} classes")
                else:
                    print(f"[DEBUG] Using last model state from base phase (patience disabled)")
                    print(f"[DEBUG] Model has classnames: {self.model.prompt_learner.n_cls} classes")
                self.save_model(path="./bin/cocoop", prefix="after_first_train_")

            if not self.skip_tests[1]:
                print("Doing base accuracy test")
                base_acc, novel_acc = self.compute_evaluation(base_end_epoch)
                self._log_final_metrics(
                    "Final metrics - After Base Training",
                    base_acc,
                    novel_acc,
                    base_end_epoch,
                )
        else:
            base_end_epoch = 0
            print("Skipping base training")
            print(f"[DEBUG] Loading model state dict from: {self.train_base_checkpoint_path}")
            with self.model.temporary_classnames([CLASS_NAMES[c] for c in self.base_classes]):
                self.model.load_state_dict(torch.load(self.train_base_checkpoint_path))
            print(f"[DEBUG] Loaded model with classnames: {self.model.prompt_learner.n_cls} classes")

        # Re-initialize test/eval/train methods after loading/training model and before adversarial phase
        self._set_eval_method()
        self._set_test_methods()

        self.optimizer = self.get_optimizer(
            self.model, self.mlp_adversary, self.optimizer_configs[1]
        )
        # After changing optimizer, ensure train methods use the new optimizer
        self._set_train_methods()

        checksum1 = None
        if self.debug:
            checksum1 = checksum(self.model)
            print("Before adv training:", checksum1)

        # Adversarial phase
        adv_end_epoch = self._train_adversarial_phase(base_end_epoch, best_model_path)

        if self.debug and checksum1:
            checksum2 = checksum(self.model)
            print("After adv training:", checksum2)
            print(f"checksum1: {checksum1}, checksum2: {checksum2}")
            if checksum1 != checksum2:
                print("Model parameters have changed after adversarial training.")

        if not self.skip_tests[2]:
            print("Doing post-adv. accuracy test")
            base_acc, novel_acc = self.compute_evaluation(adv_end_epoch)
            self._log_final_metrics(
                "Final metrics - After Adversarial Training",
                base_acc,
                novel_acc,
                adv_end_epoch,
            )

        self.logger.close()
        if self.adv_training_epochs != 0:
            self.save_model(path="./bin/cocoop", prefix="after_adv_train_")
            self.save_mlp_adversary()

    def get_next_rotation(self):
        """
        Get the next rotation of the pseudo base and pseudo novel classes.
        """
        _, self.pseudo_base_classes, self.pseudo_novel_classes = next(self.cluster_generator)
        self.train_pseudo_base = self.split_by_classes(self.train_base, self.pseudo_base_classes)
        self.train_pseudo_novel = self.split_by_classes(self.train_base, self.pseudo_novel_classes)
        self.val_pseudo_base = self.split_by_classes(self.val_base, self.pseudo_base_classes)
        self.val_pseudo_novel = self.split_by_classes(self.val_base, self.pseudo_novel_classes)
        return self.pseudo_base_classes, self.pseudo_novel_classes, self.train_pseudo_base, self.train_pseudo_novel, self.val_pseudo_base, self.val_pseudo_novel

    def _train_base_phase(self, best_model_path):
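        """
        Train the prompt learner without the adversarial objective: KL-regularized (KLCoCoOp/KLCoCoOpV2)
        when KL loss is enabled, otherwise with the standard CoCoOp objective (BaseCoCoOp).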

        Args:
            best_model_path (str): Path to store the best base model.

        Returns:
            Tuple[int, float, int]: Final epoch index, best validation score, and best epoch index.
        """
        best_score = 0.0
        patience = 8
        patience_counter = 0
        c = 0
        method = self.basic_train_method
        evaluation_period = 1 if self.pat else 2



        # Initialize rotation tracking
        if self.rotation_period == "random":
            # For random rotation, we'll track when the next rotation should happen
            next_rotation_epoch = random.randint(1, 4)  # Sample from 1 to 4
            rotation_epochs = None  # Not used for random
        else:
            rotation_epochs = int(patience * (3/4)) if self.rotation_period == "relative" else self.rotation_period
            next_rotation_epoch = None  # Not used for fixed/relative

        pbar = tqdm(total=self.max_epoch, desc="Base Training")
        best_epoch = -1  # Track the epoch of the best model
        best_epoch_path = best_model_path + ".best_epoch.txt"  # Path to store best epoch

        for e in range(self.max_epoch):

            # Apply lambda_kl warmup if enabled
            if self.warmup_lambda_kl > 0 and isinstance(method, (KLCoCoOp, KLCoCoOpV2)):
                progress = (e + 1) / self.warmup_lambda_kl
                current_lambda_kl = 0.1 + (self.lambda_kl[0] - 0.1) * min(progress, 1)
                # Only update lambda_kl for methods that support it (KLCoCoOp and KLCoCoOpV2)
                method.update_lambda_kl(current_lambda_kl)

            if self.using_kl[0]:
                if isinstance(method, DoubleDatasetTrainingMethod):
                    # Check if rotation should happen
                    should_rotate = False
                    if self.rotation_period == "random":
                        if e == next_rotation_epoch:
                            should_rotate = True
                            # Sample next rotation epoch (1 to 4 epochs from now)
                            next_rotation_epoch = e + random.randint(1, 4)
                    else:
                        # Fixed or relative rotation
                        if rotation_epochs is not None and e % rotation_epochs == 0:
                            should_rotate = True

                    if should_rotate:
                        if self.rotation_period == "random":
                            print(f"[DEBUG] Random rotation at epoch {e}, next rotation at epoch {next_rotation_epoch}")
                        (
                            self.pseudo_base_classes,
                            self.pseudo_novel_classes,
                            self.train_pseudo_base,
                            self.train_pseudo_novel,
                            self.val_pseudo_base,
                            self.val_pseudo_novel
                        ) = self.get_next_rotation()
                    kl_loss, ce_loss, acc = method.double_datasets_train_step(
                        self.train_pseudo_base,
                        self.train_pseudo_novel,
                        self.base_batch_size,
                        ["pseudo_base", "pseudo_novel KL"],
                        self.pseudo_base_classes,
                        self.pseudo_novel_classes,
                    )
                    total_loss = ce_loss + kl_loss
                else:
                    total_loss, acc, ce_loss, kl_loss = method.train_step(
                        self.train_pseudo_base,
                        self.base_batch_size,
                        classnames=self.pseudo_base_classes,
                        accumulation_steps=self.base_accumulation_steps
                    )
            elif isinstance(method, BaseCoCoOp):
                total_loss, acc = method.train_step(
                    self.train_pseudo_base,
                    self.base_batch_size,
                    classnames=self.pseudo_base_classes,
                    accumulation_steps=self.base_accumulation_steps
                )
                kl_loss = None
                ce_loss = total_loss
            else:
                raise TypeError(f"Unsupported base training method: {type(method).__name__}")

            self.logger.log_training_base(
                e,
                self.optimizer.param_groups[0]["lr"],
                ce_loss,
                acc,
                kl_loss,
                total_loss,
            )
            postfix_dict = {
                'lr': self.optimizer.param_groups[0]["lr"],
                'ce_L': ce_loss,
                'kl_L': kl_loss,
                'pat_c': patience_counter,
            }
            if e % evaluation_period == 0:
                base_val_acc, novel_val_acc = self._evaluate_and_log(e)

                score = novel_val_acc

                # With pat=False the current state is always saved, so the checkpoint holds the last epoch
                if (score > best_score) or (not self.pat):
                    best_score = score
                    patience_counter = 0
                    torch.save(self.model.state_dict(), best_model_path)
                    best_epoch = e  # Save the epoch of the best model
                    # Store the best epoch to disk
                    with open(best_epoch_path, "w") as f:
                        f.write(str(best_epoch))
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Early stopping at epoch {e}")
                        break

                postfix_dict["B_val_acc"] = base_val_acc
                postfix_dict["N_val_acc"] = novel_val_acc
                postfix_dict["score"] = score

            self.lr_scheduler.step()
            pbar.set_postfix(**postfix_dict)
            pbar.update(1)
            c += 1

        pbar.close()
        return c, best_score, best_epoch

    def _train_adversarial_phase(self, start_epoch, best_model_path):
        """
        Train the model adversarially with dynamic lambda scheduling and early stopping.

        Args:
            start_epoch (int): Starting epoch index.
            best_model_path (str): Path to save the best adversarial model.

        Returns:
            int: `start_epoch + adv_training_epochs` (also returned when early stopping triggers).
        """
        best_novel_accuracy = 0.0
        patience = 5
        patience_counter = 0
        at_least_one_improving = False
        warmup_epochs = self.warmup_lambda_adv
        lambda_adv_max = self.lambda_adv
        initial_lambda_adv = 0.05
        pbar = tqdm(total=self.adv_training_epochs, desc="Adversarial Training")

        last_model_state = None  # store last model state

        method = self.adversarial_method

        for e in range(start_epoch, start_epoch + self.adv_training_epochs):
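            # Linearly ramp lambda_adv from initial_lambda_adv to lambda_adv_max (no ramp if warmup_epochs <= 0)
            progress = (e - start_epoch + 1) / warmup_epochs if warmup_epochs > 0 else 1.0
            new_lambda_adv = initial_lambda_adv + (
                lambda_adv_max - initial_lambda_adv
            ) * min(progress, 1)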

            method.update_lambda_adv(new_lambda_adv)

            # Keep the prompt learner frozen for the first prompt_learner_warmup_epochs adversarial epochs
            adv_epoch = e - start_epoch
            if adv_epoch < self.prompt_learner_warmup_epochs:
                print(f"[DEBUG] Prompt learner FROZEN at adv epoch {adv_epoch}")
                for name, param in self.model.named_parameters():
                    if "prompt_learner" in name:
                        param.requires_grad_(False)
            elif adv_epoch == self.prompt_learner_warmup_epochs:
                print(f"[DEBUG] Prompt learner UNFROZEN at adv epoch {adv_epoch}")
                for name, param in self.model.named_parameters():
                    if "prompt_learner" in name:
                        param.requires_grad_(True)

            if self.using_kl[1]:
                total_loss, acc, ce_loss, kl_loss, adv_loss = method.train_step(
                    self.train_base,
                    self.base_batch_size,
                    classnames=self.base_classes,
                    accumulation_steps=self.base_accumulation_steps
                )
            else:
                total_loss, acc, ce_loss, adv_loss = method.train_step(
                    self.train_base,
                    self.adv_batch_size,
                    classnames=self.base_classes,
                    accumulation_steps=self.adv_accumulation_steps
                )
                kl_loss = None

            self.logger.log_training_adv(
                e,
                method.lambda_adv,
                ce_loss,
                acc,
                adv_loss,
                ce_loss + adv_loss + (kl_loss if kl_loss is not None else 0.0),
                kl_loss=kl_loss,
            )
            if (e-start_epoch) >= self.prompt_learner_warmup_epochs:

                # Always update last_model_state to track the current state
                last_model_state = deepcopy(self.model.state_dict())

                base_val_acc, novel_val_acc = self._evaluate_and_log(
                    e,
                    is_adv=True,
                )
                if novel_val_acc > best_novel_accuracy or self.ignore_no_improvement or (not self.pat):
                    best_novel_accuracy = novel_val_acc
                    print(f"[DEBUG] Saving model with classnames: {self.model.prompt_learner.n_cls} classes")
                    torch.save(self.model.state_dict(), best_model_path)
                    at_least_one_improving = True
                    patience_counter = 0
                    # When pat=False, we always save the current model state (last epoch)
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Early stopping adversarial at epoch {e}")
                        break
                pbar.set_postfix(
                    PB_val_acc=base_val_acc,
                    PN_val_acc=novel_val_acc,
                    ce_L=ce_loss,
                    kl_L=kl_loss,
                    adv_L=adv_loss,
                    lr=self.optimizer.param_groups[0]["lr"],
                    pat_c=patience_counter,
                )
            else:

                pbar.set_postfix(
                    adv_loss=adv_loss,
                )
            pbar.update(1)

        # When pat=False, we want to keep the last model state (no loading from checkpoint)
        # When pat=True, we load the best model if there was improvement
        if self.pat and ((at_least_one_improving and self.epochs != 0) or self.ignore_no_improvement):
            print(f"[DEBUG] Loading best model state dict after adversarial phase from: {best_model_path}")
            self.model.load_state_dict(torch.load(best_model_path))
            print(f"[DEBUG] Loaded model with classnames: {self.model.prompt_learner.n_cls} classes")
            print("Loaded best model from adversarial checkpoint.")
        else:
            print(
                "Using model from last adversarial epoch (patience disabled or no improvement)."
            )
            if last_model_state is not None:
                self.model.load_state_dict(last_model_state)
                print("Loaded last adversarial model state.")
            else:
                # If last_model_state is None (e.g., still in warmup), the model already has the last state
                # But to be safe, we can load from the saved checkpoint which should be the last state when pat=False
                if not self.pat:
                    print("[DEBUG] Loading last model state from checkpoint (patience disabled)")
                    self.model.load_state_dict(torch.load(best_model_path))
                    print("Loaded last adversarial model state from checkpoint.")
                else:
                    print("Model already has the last adversarial state (no loading needed).")

        pbar.close()
        return start_epoch + self.adv_training_epochs

    def _evaluate_and_log(self, epoch, is_adv=False):
        """
        Run validation and log results for both base and novel splits.

        Args:
            epoch (int): Current training epoch.
            is_adv (bool): Whether evaluation is during adversarial training.

        Returns:
            Tuple[float, float]: Accuracy for base and novel classes.
        """

        # Both phases evaluate the base and novel validation splits identically; is_adv only affects logging
        metrics_base = self.eval_method.evaluate(
            dataset=self.val_base,
            classnames=self.base_classes,
            desc_add=" - Base",
        )
        base_val_loss = metrics_base["loss"]
        base_val_acc = metrics_base["accuracy"]

        metrics_novel = self.eval_method.evaluate(
            dataset=self.val_novel,
            classnames=self.novel_classes,
            desc_add=" - Novel",
        )
        novel_val_loss = metrics_novel["loss"]
        novel_val_acc = metrics_novel["accuracy"]

        self.logger.log_validation(
            epoch,
            base_val_loss,
            base_val_acc,
            novel_val_loss,
            novel_val_acc,
            is_adv=is_adv,
        )

        return base_val_acc, novel_val_acc

    def _log_final_metrics(self, tag, base_acc, novel_acc, step):
        """
        Log final test results to TensorBoard.

        Args:
            tag (str): Descriptive tag for the log.
            base_acc (float): Accuracy on base classes.
            novel_acc (float): Accuracy on novel classes.
            step (int): Epoch or step index for this log.
        """
        self.logger.log_final_metrics(tag, base_acc, novel_acc, step)

    def _lr_lambda(self, current_epoch):
        """
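        Learning-rate multiplier with constant warm-up and cosine annealing: returns
        `warmup_cons_lr / prompt_lr` for the first `warmup_epoch` epochs, then decays
        from 1 towards 0 at `max_epoch`.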

        Args:
            current_epoch (int): Epoch index.

        Returns:
            float: Learning rate multiplier.
        """
        if current_epoch < self.warmup_epoch:
            return self.warmup_cons_lr / self.optimizer_configs[0].prompt_lr # type: ignore
        return 0.5 * (
            1
            + math.cos(
                math.pi
                * (current_epoch - self.warmup_epoch)
                / (self.max_epoch - self.warmup_epoch + 1e-7)
            )
        )

    def compute_evaluation(self, epoch_idx, base=False):
        """
        Run evaluation on the test split for both base and novel classes.

        Args:
            epoch_idx (int): Epoch index for logging.
            base (bool): If True, evaluate zero-shot CLIP instead of the fine-tuned model.

        Returns:
            Tuple[float, float]: Base and novel class test accuracy.
        """
        if base:
            base_metrics = self.zero_shot_base_classes_test_method.evaluate(
                dataset=self.test_base,
                desc_add=" - Base Zero Shot",
            )
            novel_metrics = self.zero_shot_novel_classes_test_method.evaluate(
                dataset=self.test_novel,
                desc_add=" - Novel Zero Shot",
            )
        else:
            base_metrics = self.finetuned_test_method.evaluate(
                dataset=self.test_base,
                classnames=self.base_classes,
                desc_add=" - Base Fine Tuned",
            )
            novel_metrics = self.finetuned_test_method.evaluate(
                dataset=self.test_novel,
                classnames=self.novel_classes,
                desc_add=" - Novel Fine Tuned",
            )

        base_accuracy = base_metrics["accuracy"]
        novel_accuracy = novel_metrics["accuracy"]
        self.logger.log_test_accuracy(epoch_idx, base_accuracy, "base_classes")
        self.logger.log_test_accuracy(epoch_idx, novel_accuracy, "novel_classes")
        return base_accuracy, novel_accuracy

    def save_model(self, path="./bin/cocoop", prefix=""):
        """
        Save model weights to `{path}/{prefix}{run_name}.pth`, with the model's class names
        temporarily set to the base classes.

        Args:
            path (str): Directory to save the model to.
            prefix (str): Filename prefix to distinguish models.
        """
        os.makedirs(path, exist_ok=True)
        print(f"[DEBUG] Saving model with classnames: {self.model.prompt_learner.n_cls} classes")
        with self.model.temporary_classnames([CLASS_NAMES[idx] for idx in self.base_classes]):
            torch.save(
                self.model.state_dict(), os.path.join(path, f"{prefix}{self.run_name}.pth")
            )

    def save_mlp_adversary(self, path="./bin/cocoop", prefix=""):
        """
        Save the MLP adversary weights to disk.

        Args:
            path (str): Directory to save the MLP adversary model to.
            prefix (str): Filename prefix to distinguish models.
        """
        os.makedirs(path, exist_ok=True)
        torch.save(
            self.mlp_adversary.state_dict(),
            os.path.join(path, f"{prefix}{self.run_name}_mlp_adversary.pth"),
        )

    def get_optimizer(self, model, mlp_adversary, config):
        """
        Build an SGD optimizer with separate learning rates for different parameter groups.

        Args:
            model (torch.nn.Module): Main model.
            mlp_adversary (torch.nn.Module): Optional adversarial MLP.
            config: Optimizer configuration namespace.

        Returns:
            torch.optim.Optimizer: Configured optimizer.
        """
        params = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "prompt_learner" in n and p.requires_grad
                ],
                "lr": config.prompt_lr,
            }
        ]
        if mlp_adversary is not None:
            params.append(
                {
                    "params": mlp_adversary.parameters(),
                    "lr": config.mlp_lr,
                }
            )
        return torch.optim.SGD(
            params, weight_decay=config.weight_decay, momentum=config.momentum
        )

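    def split_by_classes(self, dataset, class_list):
        """
        Return the subset of `dataset` whose labels are in `class_list`.

        Note: this calls `__getitem__` for every sample of `dataset` once.
        """
        class_set = set(class_list)  # O(1) membership checks
        idxs = [i for i, (_, label) in enumerate(dataset) if label in class_set]
        return torch.utils.data.Subset(dataset, idxs)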
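The trainer combines three schedules: `_lr_lambda` returns a learning-rate multiplier (a constant warm-up factor, then cosine decay), while `lambda_kl` and `lambda_adv` are ramped linearly from a small starting value (0.1 and 0.05 in the code above). The pure-Python sketch below reproduces these formulas outside the class so they can be inspected; `lr_multiplier` and `linear_warmup` are hypothetical helpers, not functions from the repository, and the example values are arbitrary.

```python
import math


def lr_multiplier(epoch: int, warmup_epoch: int, max_epoch: int,
                  warmup_cons_lr: float, base_lr: float) -> float:
    """Standalone copy of the `_lr_lambda` formula (hypothetical helper).

    During warm-up the factor is warmup_cons_lr / base_lr, so a parameter group
    whose initial lr is base_lr runs at exactly warmup_cons_lr. Afterwards the
    factor follows a half cosine from 1 towards 0 at max_epoch.
    """
    if epoch < warmup_epoch:
        return warmup_cons_lr / base_lr
    t = (epoch - warmup_epoch) / (max_epoch - warmup_epoch + 1e-7)
    return 0.5 * (1.0 + math.cos(math.pi * t))


def linear_warmup(epoch_in_phase: int, warmup_epochs: int,
                  start: float, target: float) -> float:
    """Ramp a loss weight linearly from `start` to `target`, then hold it.

    Same form as the lambda_kl / lambda_adv updates, with
    progress = (epoch_in_phase + 1) / warmup_epochs clipped at 1.
    A non-positive warmup_epochs returns `target` directly (no ramp).
    """
    if warmup_epochs <= 0:
        return target
    progress = (epoch_in_phase + 1) / warmup_epochs
    return start + (target - start) * min(progress, 1.0)


if __name__ == "__main__":
    # Hypothetical values, for illustration only
    print([round(lr_multiplier(e, 1, 10, 1e-5, 2e-3), 4) for e in range(11)])
    print([round(linear_warmup(e, 4, 0.05, 1.0), 4) for e in range(6)])
```

If `self.lr_scheduler` is a `torch.optim.lr_scheduler.LambdaLR` built from `_lr_lambda` (its construction is not shown in this section), the same factor multiplies the initial lr of every parameter group. During warm-up the MLP adversary's group would then run at `mlp_lr * warmup_cons_lr / prompt_lr` rather than at `warmup_cons_lr`.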

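`_train_base_phase` re-draws the pseudo-base/pseudo-novel split via `get_next_rotation()` only when KL training is enabled (`self.using_kl[0]`) and the method is a `DoubleDatasetTrainingMethod`. A standalone sketch of the rotation decision follows, assuming the three modes in the code: an integer period, `"relative"` (a period of `int(patience * 3/4)` epochs), or `"random"` (gaps sampled uniformly from 1 to 4 epochs). `RotationSchedule` is a hypothetical helper for illustration.

```python
import random
from typing import Optional, Union


class RotationSchedule:
    """Hypothetical helper mirroring the rotation checks in `_train_base_phase`.

    should_rotate must be called once per epoch, in increasing epoch order.
    """

    def __init__(self, rotation_period: Union[int, str], patience: int = 8,
                 rng: Optional[random.Random] = None) -> None:
        self.mode = rotation_period
        self.rng = rng if rng is not None else random.Random()
        self.period: Optional[int] = None
        self.next_rotation_epoch: Optional[int] = None
        if rotation_period == "random":
            # First rotation 1-4 epochs after the start, never at epoch 0
            self.next_rotation_epoch = self.rng.randint(1, 4)
        elif rotation_period == "relative":
            # Period tied to the early-stopping patience
            self.period = int(patience * (3 / 4))
        else:
            self.period = int(rotation_period)

    def should_rotate(self, epoch: int) -> bool:
        if self.mode == "random":
            if epoch == self.next_rotation_epoch:
                # Schedule the following rotation 1-4 epochs from now
                self.next_rotation_epoch = epoch + self.rng.randint(1, 4)
                return True
            return False
        # Fixed or relative period: rotate at epochs 0, period, 2 * period, ...
        # (the period > 0 guard avoids a modulo by zero)
        return self.period > 0 and epoch % self.period == 0


fixed = RotationSchedule(3)
print([e for e in range(10) if fixed.should_rotate(e)])
```

One asymmetry worth noting: the fixed and relative modes rotate immediately at epoch 0, whereas the random mode never rotates before epoch 1.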
## Training System Core

### Default Training Method

In [None]:
# From: ./training_systems/core/TrainingMethod.py
"""
Abstract base class for training methods. Subclasses must implement `forward_backward` and the other abstract hooks.
"""

from abc import ABC, abstractmethod
from typing import Dict, Any, Callable

from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

from utils import AverageMeter, ContiguousLabelDataset


class TrainingMethod(ABC):
    """
    Abstract base class for training methods.

    TrainingMethod standardizes the structure of different training strategies.
    Subclasses must implement:
     - `get_metrics`: initialize and return the performance metrics.
     - `get_data_loader`: build the DataLoader for the specific training strategy.
     - `forward_backward`: run the forward and backward passes for one batch.
     - `debug_metrics_to_pbar_args`: format debug information for the progress bar.
     - `training_step_return`: return summary metrics after a training epoch.

    Shared functionality (`start_training`, `optimizer_step`, `train_step`) is reused across concrete training methods.
    """

    def __init__(self, model: Any, optimizer: Any, title: str, debug) -> None:
        """
        Initialize the training method.

        Args:
            model (Any): The model to train.
            optimizer (Any): The optimizer to use for training.
            title (str): Title for identification (e.g., for progress display).
            debug (bool): Flag to enable debug mode for extra logging.
        """
        self.model = model
        self.optimizer = optimizer
        self.title = title
        self.device = next(self.model.parameters()).device
        self.debug = debug
        self.mlp_adversary = None

    @abstractmethod
    def get_metrics(self) -> Dict[str, AverageMeter]:
        """
        Initialize and return a dictionary of performance metrics.

        Returns:
            Dict[str, AverageMeter]: Dictionary mapping metric names to metric trackers.
        """
        pass

    @abstractmethod
    def get_data_loader(self, dataset, batch_size) -> DataLoader:
        """
        Create a data loader for the training dataset.

        Args:
            dataset: The training dataset.
            batch_size (int): Number of samples per batch.

        Returns:
            DataLoader: PyTorch data loader instance.
        """
        pass

    @abstractmethod
    def forward_backward(
            self, sample, batch_idx, metrics: Dict[str, AverageMeter], dataset: ContiguousLabelDataset, accumulation_steps: int = 1, step: int = 0
    ) -> Dict[str, float]:
        """
        Execute forward and backward pass, compute loss, and update metrics.

        Args:
            sample: Current batch sample from data loader.
            batch_idx (int): Index of the current batch.
            metrics (Dict[str, AverageMeter]): Metric trackers.
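            dataset (ContiguousLabelDataset): Dataset wrapper for label mapping.
            accumulation_steps (int): Number of batches over which gradients are accumulated.
            step (int): Current step index (`train_step` passes the batch index).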

        Returns:
            Dict[str, float]: Dictionary of current debug metric values.
        """
        pass

    @abstractmethod
    def debug_metrics_to_pbar_args(self, debug_metrics: Dict[str, float]) -> Dict[str, float]:
        """
        Format debug metrics for display in a progress bar.

        Args:
            debug_metrics (Dict[str, float]): Metrics from the current step.

        Returns:
            Dict[str, float]: Formatted metrics for tqdm display.
        """
        pass

    @abstractmethod
    def training_step_return(self, metrics: Dict[str, AverageMeter]) -> list[float]:
        """
        Generate summary metrics after completing a training epoch.

        Args:
            metrics (Dict[str, AverageMeter]): Collected training metrics.

        Returns:
            List[float]: List of averaged metric values for reporting.
        """

    def optimizer_step(self) -> None:
        """
        Apply the optimizer's step and zero the gradients.
        """
        self.optimizer.step()
        self.optimizer.zero_grad()

    def start_training(self) -> None:
        """
        Set the model in training mode.
        """
        self.model.train()

    def train_step(self, dataset, batch_size, classnames, accumulation_steps: int = 1):
        """
        Perform a complete training epoch, including data loading, training, and metric collection.

        Args:
            dataset: Dataset used for training.
            batch_size (int): Number of samples per training batch.
            classnames: Classes passed to `ContiguousLabelDataset`, which maps their labels to a contiguous range.
            accumulation_steps (int): Number of batches over which gradients are accumulated.

        Returns:
            List[float]: Averaged values for each tracked metric after the epoch.
        """
        metrics = self.get_metrics()
        self.start_training()
        tmp_dataset = ContiguousLabelDataset(dataset, classnames)
        dataloader = self.get_data_loader(tmp_dataset, batch_size)
        pbar = tqdm(dataloader, desc=f"Training-{self.title}", position=1, leave=False)
        batch_idx = -1  # stays -1 if the dataloader is empty
        for batch_idx, sample in enumerate(pbar):
            debug_metrics = self.forward_backward(sample, batch_idx, metrics, tmp_dataset, accumulation_steps=accumulation_steps, step=batch_idx)
            pbar.set_postfix(
                self.debug_metrics_to_pbar_args(debug_metrics)
            )
        pbar.close()
        # Flush gradients accumulated by an incomplete final window
        if accumulation_steps > 1 and batch_idx >= 0 and (batch_idx + 1) % accumulation_steps != 0:
            if self.mlp_adversary is not None:
                torch.nn.utils.clip_grad_norm_(
                    list(self.model.parameters()) + list(self.mlp_adversary.parameters()),
                    max_norm=1.0,
                    norm_type=2.0,
                    error_if_nonfinite=True
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

        return self.training_step_return(metrics)

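`train_step` delegates the per-batch optimizer step to `forward_backward` (which receives `accumulation_steps`) and, after the loop, flushes gradients left over from an incomplete final window. The sketch below only enumerates when optimizer steps occur, assuming that subclasses step whenever `(batch_idx + 1) % accumulation_steps == 0`; that convention matches the flush condition above, but the subclasses' code is not shown here, and `optimizer_step_batches` is a hypothetical helper.

```python
from typing import List


def optimizer_step_batches(num_batches: int, accumulation_steps: int) -> List[int]:
    """Batch indices after which an optimizer step happens during one epoch.

    Assumes (hypothetically) that forward_backward steps the optimizer whenever
    (batch_idx + 1) % accumulation_steps == 0, and that train_step flushes the
    gradients of an incomplete final window.
    """
    if num_batches <= 0:
        return []
    if accumulation_steps <= 1:
        # No accumulation: one optimizer step per batch
        return list(range(num_batches))
    steps = [i for i in range(num_batches) if (i + 1) % accumulation_steps == 0]
    last = num_batches - 1
    if (last + 1) % accumulation_steps != 0:
        # Leftover flush at the end of train_step
        steps.append(last)
    return steps


print(optimizer_step_batches(10, 4))
```

With 10 batches and `accumulation_steps=4`, the final update aggregates gradients from only two batches. If a subclass scales each loss by `1 / accumulation_steps` (not shown in this section), that last update is correspondingly smaller than the others.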







### DoubleDataset Training Method

In [None]:
# From: ./training_systems/core/DoubleDatasetTrainingMethod.py
"""
Base class for training methods that, in each epoch, train on two datasets one after the other.
"""

from abc import ABC, abstractmethod
from typing import Dict, Any, Callable

from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import AverageMeter, ContiguousLabelDataset

class DoubleDatasetTrainingMethod:
    """
    Base class for training methods that train on two datasets per epoch.

    Subclasses implement one set of hooks per dataset:
     - `get_metrics`: initialize and return the performance metrics (shared by both passes).
     - `get_data_loader1` / `get_data_loader2`: build the DataLoader for each dataset.
     - `forward_backward1` / `forward_backward2`: run the forward and backward passes for one batch of each dataset.
     - `debug_metrics_to_pbar_args1` / `debug_metrics_to_pbar_args2`: format debug information for each progress bar.
     - `training_step_return`: return summary metrics after a training epoch.

    `double_datasets_train_step` runs a full pass over the first dataset followed by a full pass over the second.
    Note: the class does not inherit from `ABC`, so the `@abstractmethod` decorators are not enforced at instantiation.
    """

    def __init__(self, model: Any, optimizer: Any, title: str, debug) -> None:
        """
        Initialize the training method.

        Args:
            model (Any): The model to train.
            optimizer (Any): The optimizer to use for training.
            title (str): Title for identification (e.g., for progress display).
            debug (bool): Flag to enable debug mode for extra logging.
        """
        self.model = model
        self.optimizer = optimizer
        self.title = title
        self.device = next(self.model.parameters()).device
        self.debug = debug

    @abstractmethod
    def get_metrics(self) -> Dict[str, AverageMeter]:
        """
        Initialize and return a dictionary of performance metrics.

        Returns:
            Dict[str, AverageMeter]: Dictionary mapping metric names to metric trackers.
        """
        pass

    @abstractmethod
    def get_data_loader1(self, dataset, batch_size) -> DataLoader:
        """
        Create a data loader for the training dataset.

        Args:
            dataset: The training dataset.
            batch_size (int): Number of samples per batch.

        Returns:
            DataLoader: PyTorch data loader instance.
        """
        pass
    @abstractmethod
    def get_data_loader2(self, dataset, batch_size) -> DataLoader:
        """
        Create a data loader for the training dataset.

        Args:
            dataset: The training dataset.
            batch_size (int): Number of samples per batch.

        Returns:
            DataLoader: PyTorch data loader instance.
        """
        pass

    @abstractmethod
    def debug_metrics_to_pbar_args1(self, debug_metrics: Dict[str, float]) -> Dict[str, float]:
        """
        Format debug metrics for display in a progress bar.

        Args:
            debug_metrics (Dict[str, float]): Metrics from the current step.

        Returns:
            Dict[str, float]: Formatted metrics for tqdm display.
        """
        pass

    @abstractmethod
    def debug_metrics_to_pbar_args2(self, debug_metrics: Dict[str, float]) -> Dict[str, float]:
        """
        Format debug metrics for display in a progress bar.
        """
        pass

    @abstractmethod
    def training_step_return(self, metrics: Dict[str, AverageMeter]) -> list[float]:
        """
        Generate summary metrics after completing a training epoch.

        Args:
            metrics (Dict[str, AverageMeter]): Collected training metrics.

        Returns:
            List[float]: List of averaged metric values for reporting.
        """
        pass

    def optimizer_step(self) -> None:
        """
        Apply the optimizer's step and zero the gradients.
        """
        self.optimizer.step()
        self.optimizer.zero_grad()

    def start_training(self) -> None:
        """
        Set the model in training mode.
        """
        self.model.train()

    @abstractmethod
    def forward_backward1(
            self, sample, batch_idx, metrics: Dict[str, AverageMeter], dataset: ContiguousLabelDataset, classes: list[int]
    ) -> Dict[str, float]:
        """
        Execute the forward and backward pass for a batch of the first dataset, compute the loss, and update metrics.

        Args:
            sample: Current batch sample from data loader.
            batch_idx (int): Index of the current batch.
            metrics (Dict[str, AverageMeter]): Metric trackers.
            dataset (ContiguousLabelDataset): Dataset wrapper for label mapping.
            classes (list[int]): Classes of the first dataset.

        Returns:
            Dict[str, float]: Dictionary of current debug metric values.
        """
        pass

    @abstractmethod
    def forward_backward2(
            self, sample, batch_idx, metrics: Dict[str, AverageMeter], dataset: ContiguousLabelDataset, classes: list[int]
    ) -> Dict[str, float]:
        """
        Execute the forward and backward pass for a batch of the second dataset, compute the loss, and update metrics.

        Same arguments and return value as `forward_backward1`, applied to the second dataset.
        """
        pass

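    def double_datasets_train_step(self, dataset1, dataset2, batch_size, names: list[str], classes1, classes2):
        """
        Run one training epoch: a full pass over `dataset1`, then a full pass over `dataset2`.

        Args:
            dataset1: First training dataset.
            dataset2: Second training dataset.
            batch_size (int): Number of samples per batch for both data loaders.
            names (list[str]): Two labels shown in the progress bars.
            classes1: Classes of `dataset1` (remapped by `ContiguousLabelDataset`).
            classes2: Classes of `dataset2` (remapped by `ContiguousLabelDataset`).

        Returns:
            List[float]: Averaged values for each tracked metric after the epoch.
        """
        assert len(names) == 2, "Number of names must be 2"
        metrics = self.get_metrics()
        self.start_training()
        tmp_dataset1 = ContiguousLabelDataset(dataset1, classes1)
        tmp_dataset2 = ContiguousLabelDataset(dataset2, classes2)

        dataloader1 = self.get_data_loader1(tmp_dataset1, batch_size)
        dataloader2 = self.get_data_loader2(tmp_dataset2, batch_size)

        # First pass: every batch of dataset1
        pbar = tqdm(dataloader1, desc=f"Training-{self.title}/{names[0]}", position=1, leave=False)
        for batch_idx, sample in enumerate(pbar):
            debug_metrics = self.forward_backward1(sample, batch_idx, metrics, tmp_dataset1, classes1)
            pbar.set_postfix(self.debug_metrics_to_pbar_args1(debug_metrics))
        pbar.close()

        # Second pass: every batch of dataset2, sharing the same metrics dict
        pbar = tqdm(dataloader2, desc=f"Training-{self.title}/{names[1]}", position=1, leave=False)
        for batch_idx, sample in enumerate(pbar):
            debug_metrics = self.forward_backward2(sample, batch_idx, metrics, tmp_dataset2, classes2)
            pbar.set_postfix(self.debug_metrics_to_pbar_args2(debug_metrics))
        pbar.close()

        return self.training_step_return(metrics)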

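`double_datasets_train_step` does not interleave the two loaders: it finishes the whole first dataset before touching the second, and both passes write into a single metrics dict. The plain-Python sketch below reproduces only that control flow; `double_dataset_epoch`, the toy step functions and the metric names are hypothetical stand-ins for `forward_backward1`/`forward_backward2` and `training_step_return`.

```python
from typing import Callable, Dict, Iterable, List, Tuple

Batch = Tuple[list, list]
StepFn = Callable[[Batch, int, Dict[str, List[float]]], None]


def double_dataset_epoch(loader1: Iterable[Batch], loader2: Iterable[Batch],
                         step1: StepFn, step2: StepFn) -> Dict[str, float]:
    """One epoch in the style of double_datasets_train_step (hypothetical sketch)."""
    metrics: Dict[str, List[float]] = {}
    # Full pass over the first loader ...
    for batch_idx, batch in enumerate(loader1):
        step1(batch, batch_idx, metrics)
    # ... then a full pass over the second, sharing the same metrics dict
    for batch_idx, batch in enumerate(loader2):
        step2(batch, batch_idx, metrics)
    # Average every tracked metric (stand-in for training_step_return)
    return {name: sum(vals) / len(vals) for name, vals in metrics.items() if vals}


order = []


def toy_step1(batch, i, m):
    order.append(("first", i))
    m.setdefault("ce", []).append(1.0)


def toy_step2(batch, i, m):
    order.append(("second", i))
    m.setdefault("kl", []).append(0.5)


summary = double_dataset_epoch([([0], [0])] * 2, [([1], [1])] * 3, toy_step1, toy_step2)
print(order, summary)
```

In the trainer the progress-bar labels are `"pseudo_base"` and `"pseudo_novel KL"`, which suggests the second pass carries the KL term. Because the passes are sequential, the last updates of every epoch always come from the second dataset.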





### Evaluation Method

In [None]:
# From: ./training_systems/core/EvaluationMethod.py
from abc import ABC, abstractmethod
from typing import Dict, Any
from torch.utils.data import DataLoader
from utils.metrics import AverageMeter


class EvaluationMethod(ABC):
    """
    Abstract base class for evaluation methods. Standardizes evaluation interface for different strategies.
    """

    def __init__(self, model, batch_size: int = 32, device=None):
        """
        Initialize evaluation method.

        Args:
            model: The model to evaluate.
            batch_size (int): Evaluation batch size.
            device: Device to run on; defaults to the device of the model's parameters.
        """
        self.model = model
        self.device = next(self.model.parameters()).device if device is None else device
        self.batch_size = batch_size

    @abstractmethod
    def evaluate(self, dataset, classnames, desc_add="") -> Dict[str, float]:
        """
        Perform the evaluation.

        Args:
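            dataset: The dataset to evaluate on.
            classnames: Classes to evaluate against (zero-shot subclasses are called without it).
            desc_add (str): Suffix appended to the progress-bar description.

        Returns:
            Dict[str, float]: Evaluation results; the trainer reads the "loss" and "accuracy" keys.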
        """
        pass
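
A concrete `evaluate` has to return a dict with `"loss"` and `"accuracy"`, since `_evaluate_and_log` reads both keys. The framework-free sketch below shows one way to compute them, assuming cross-entropy as the loss and top-1 accuracy returned as a fraction; the project's actual subclasses are not shown in this section and may use a different scale. `evaluate_logits` and `predict` are hypothetical: `predict` stands in for the model's forward pass and returns one row of logits per input.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple


def evaluate_logits(predict: Callable[[List[object]], List[List[float]]],
                    dataset: Sequence[Tuple[object, int]],
                    batch_size: int = 32) -> Dict[str, float]:
    """Average cross-entropy and top-1 accuracy over a labelled dataset."""
    total_loss, correct, seen = 0.0, 0, 0
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]
        inputs = [x for x, _ in batch]
        labels = [y for _, y in batch]
        for logits, y in zip(predict(inputs), labels):
            # Numerically stable log-sum-exp for the softmax normalizer
            m = max(logits)
            log_z = m + math.log(sum(math.exp(l - m) for l in logits))
            total_loss += log_z - logits[y]
            predicted = max(range(len(logits)), key=logits.__getitem__)
            correct += int(predicted == y)
            seen += 1
    # Per-sample averages under the keys the trainer reads
    return {"loss": total_loss / max(seen, 1), "accuracy": correct / max(seen, 1)}


def toy_predict(xs):
    # Hypothetical 2-class model: input 2 is deliberately misclassified
    return [[2.0, 0.0] if x == 0 else [0.0, 2.0] if x == 1 else [3.0, 0.0] for x in xs]


result = evaluate_logits(toy_predict, [(0, 0), (1, 1), (2, 1)], batch_size=2)
print(result)
```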

## Training Methods

## Base CoCoOp Training

**Key Aspects:** Dynamic prompt generation conditioned on image features; cross-entropy loss on base classes.  
**Constraints:** Trains only on seen/base classes; no explicit mechanisms for handling novel categories.

Base CoCoOp learns a set of context vectors (the prompt) together with a small MLP, called MetaNet, that maps the visual features of each input image to image-conditioned offsets added to those context vectors. For each training sample, the model builds one prompt per base class by combining the shifted context vectors with the class name, then encodes these prompts with the text encoder.

Similarity scores between image and text embeddings are computed and optimized using a cross-entropy loss over the base classes. The training updates both the parameters of MetaNet and the context vectors through backpropagation until convergence.

This approach allows the prompts to adapt based on image content, improving recognition on base categories and setting the stage for further improvements like knowledge distillation or adversarial regularization.

---

### Training Procedure (repeat until convergence):

1. Sample an image and its label:  
   $$
   (\mathbf{x}, \hat{y}) \sim \mathcal{D}_{\text{seen}}^{\text{train}}
   $$

2. Extract visual features from the image:  
   $$
   \mathbf{v} = f(\mathbf{x})
   $$

3. Generate image-conditioned shifts (meta-tokens) for the context vectors:  
   $$
   [\pi_1, \ldots, \pi_M] = \text{MetaNet}_\theta(\mathbf{v})
   $$

4. For each base class $c \in \mathcal{C}_{\text{base}}$:  
   - Construct the prompt by combining the learned context vectors, shifted by the meta-tokens, with the class name:  
     $$
     \mathbf{p}_c = [p_1 + \pi_1, \ldots, p_M + \pi_M, \texttt{"class } c \texttt{"}]
     $$  
   - Encode the prompt into a text embedding:  
     $$
     \mathbf{t}_c = g(\mathbf{p}_c)
     $$

5. Compute similarity scores between the image feature and each class text embedding:  
   $$
   \ell_c = \text{sim}(\mathbf{v}, \mathbf{t}_c)
   $$

6. Compute cross-entropy loss over the similarity scores with respect to the ground truth label:  
   $$
   \mathcal{L}_{\text{CE}} = \text{CE}(\{\ell_c\}, \hat{y})
   $$

7. Update MetaNet parameters $\theta$ and context vectors $p_1, \ldots, p_M$ via backpropagation on $\mathcal{L}_{\text{CE}}$, keeping the CLIP encoders $f$ and $g$ frozen.
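The procedure above can be sketched end to end with toy modules. This is a minimal illustration under stated assumptions, not the project's implementation: `f` and `g` are frozen random linear layers standing in for CLIP's image and text encoders, the toy "text encoder" simply mean-pools the prompt tokens, the class-name embeddings are random, and MetaNet is assumed to be a small Linear-ReLU-Linear network emitting one shift per context vector, as in steps 3 and 4. The temperature value is also an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, M, D, C, IMG = 4, 4, 16, 5, 32  # batch, context length, embed dim, base classes, toy image size

# Frozen toy stand-ins for CLIP's image encoder f and text encoder g (hypothetical)
f = nn.Linear(IMG, D).requires_grad_(False)
g = nn.Linear(D, D).requires_grad_(False)
class_tokens = torch.randn(C, D)  # toy embeddings of the class names

# Learnable parts: context vectors p_1..p_M and MetaNet (one shift per context slot)
ctx = nn.Parameter(0.02 * torch.randn(M, D))
meta_net = nn.Sequential(nn.Linear(D, D // 4), nn.ReLU(), nn.Linear(D // 4, M * D))
optimizer = torch.optim.SGD([ctx, *meta_net.parameters()], lr=0.01)
temperature = 0.07  # assumed logit temperature


def cocoop_step(x, y):
    v = f(x)                                   # step 2: image features [B, D]
    shifts = meta_net(v).view(-1, M, D)        # step 3: image-conditioned meta-tokens [B, M, D]
    shifted = ctx.unsqueeze(0) + shifts        # context vectors plus their shifts [B, M, D]
    # step 4: prompt for every base class = shifted context + class-name token -> [B, C, M+1, D]
    prompts = torch.cat(
        [shifted.unsqueeze(1).expand(-1, C, -1, -1),
         class_tokens.view(1, C, 1, D).expand(x.size(0), -1, -1, -1)],
        dim=2,
    )
    t = g(prompts.mean(dim=2))                 # toy text encoder: pool tokens, project -> [B, C, D]
    # step 5: cosine similarity between v and every t_c, divided by the temperature
    v_n = F.normalize(v, dim=-1)
    t_n = F.normalize(t, dim=-1)
    logits = torch.einsum("bd,bcd->bc", v_n, t_n) / temperature
    loss = F.cross_entropy(logits, y)          # step 6
    optimizer.zero_grad()
    loss.backward()                            # step 7: only ctx and MetaNet receive gradients
    optimizer.step()
    return logits, loss


x = torch.randn(B, IMG)
y = torch.randint(0, C, (B,))
logits, loss = cocoop_step(x, y)
```

Because `f` and `g` have `requires_grad` disabled, gradients still flow *through* them to the context vectors and MetaNet, but their own weights are never updated, mirroring the frozen CLIP backbone.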


In [None]:
# From: ./training_systems/training_methods/BaseCoCoOp.py
"""
This module defines the BaseCoCoOp training method, a baseline implementation for CoCoOp-based optimization.
It includes metric tracking, data loading, and training step execution using standard cross-entropy loss.
"""
import random
from typing import Dict, Any

import torch
from torch.utils.data import DataLoader

# from training_systems.core import TrainingMethod
# from utils import AverageMeter, ContiguousLabelDataset


class BaseCoCoOp(TrainingMethod):
    """
    BaseCoCoOp implements a simple training routine based on cross-entropy loss without adversarial components.
    This class inherits from TrainingMethod and provides basic training loop functionalities.
    """

    def __init__(
            self,
            model: Any,
            optimizer: Any,
            debug: bool = False
    ) -> None:
        super().__init__(model, optimizer, "Base CoCoOp", debug)

    def get_metrics(self) -> Dict[str, AverageMeter]:
        """
        Initializes and returns the performance metrics used during training.

        Returns:
            Dict[str, AverageMeter]: A dictionary with average meters for loss and accuracy.
        """


        return {
            "loss_metric": AverageMeter(),
            "accuracy_metric": AverageMeter(),
        }

    def get_data_loader(self, dataset: ContiguousLabelDataset, batch_size: int) -> DataLoader:
        """
        Creates and returns a DataLoader for the given dataset.

        Args:
            dataset (ContiguousLabelDataset): Dataset to load samples from.
            batch_size (int): Number of samples per batch.

        Returns:
            DataLoader: PyTorch DataLoader configured with shuffle and multiple workers.
        """
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
        )

    def forward_backward(
            self,
            sample,
            batch_idx,
            metrics: Dict[str, AverageMeter],
            dataset: ContiguousLabelDataset,
            accumulation_steps: int = 1,
            step: int = 0
    ) -> Dict[str, float]:
        """
        Executes the forward and backward pass for the BaseCoCoOp training method.

        Args:
            sample (Tuple[Tensor, Tensor]): Batch of input data and targets.
            batch_idx (int): Index of the current batch.
            metrics (Dict[str, AverageMeter]): Dictionary of metrics to be updated.
            dataset (ContiguousLabelDataset): Dataset object (unused in this implementation).
            accumulation_steps (int): Number of batches over which gradients are accumulated before an optimizer step.
            step (int): Step counter used to decide when to step the optimizer.

        Returns:
            Dict[str, float]: Dictionary with current loss and accuracy values.
        """
        # Move data to the model's device
        inputs, targets = sample
        inputs_base = inputs.to(self.device)
        targets_base = targets.to(self.device)

        # Forward pass: the model returns the logits and the cross-entropy loss on base classes
        logits_base, loss_ce = self.model(inputs_base, targets_base)
        if self.debug and batch_idx == 0:
            print("SHAPES LOGITS: ", logits_base.shape, targets_base.shape)

        # Scale the loss so that accumulated gradients match a full-batch update
        (loss_ce / accumulation_steps).backward()

        # Optimizer step (and gradient reset) only every `accumulation_steps` batches
        if (step + 1) % accumulation_steps == 0:
            self.optimizer_step()
            self.optimizer.zero_grad()

        # Log the unscaled loss so values are comparable across accumulation settings
        batch_size_total = inputs_base.size(0)
        metrics["loss_metric"].update(loss_ce.item(), n=batch_size_total)

        _, predicted = logits_base.max(dim=1)
        correct = (predicted == targets_base).sum().item()
        metrics["accuracy_metric"].update(correct, n=batch_size_total, raw=True)

        return {
            "loss": loss_ce.item(),
            "accuracy": correct / batch_size_total,
        }

    def debug_metrics_to_pbar_args(self, debug_metrics: Dict[str, float]) -> Dict[str, float]:
        """
        Passes debug metrics directly to be displayed in the progress bar.

        Args:
            debug_metrics (Dict[str, float]): Metrics from the current step.

        Returns:
            Dict[str, float]: Unmodified metrics suitable for display.
        """
        return debug_metrics

    def training_step_return(self, metrics: Dict[str, AverageMeter]) -> list[float]:
        """
        Extracts and returns the average loss and accuracy metrics after a training step.

        Args:
            metrics (Dict[str, AverageMeter]): Dictionary containing tracked metrics.

        Returns:
            List[float]: A list containing the average loss and accuracy values.
        """
        return [
            metrics["loss_metric"].avg,
            metrics["accuracy_metric"].avg,
        ]






### MLP Adversarial Training

One of our proposed modifications adds an adversarial training scheme to CoCoOp: an MLP discriminator is trained to predict predefined image clusters, while the prompt learner is trained to fool it. The goal is to encourage the prompt learner to produce features that are invariant to these clusters, thereby improving robustness.

#### Loss Formulation  
From the prompt learner's perspective, the total CoCoOp training loss combines the original cross-entropy classification loss and an adversarial loss weighted by a hyperparameter $\gamma$ (`lambda_adv` in the code):

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{CE}} + \gamma \cdot (-\mathcal{L}_{\text{MLP-adv}})
$$

where $\mathcal{L}_{\text{CE}}$ is the base CoCoOp cross-entropy loss and $\mathcal{L}_{\text{MLP-adv}}$ is the adversarial loss of the MLP. The adversarial component is implemented with a Gradient Reversal Layer (GRL) [4], which acts as the identity in the forward pass and multiplies the incoming gradient by $-\lambda$ in the backward pass. As a result, $\mathcal{L}_{\text{MLP-adv}}$ is effectively maximized with respect to the prompt learner during CoCoOp training.

The MLP itself is trained to minimize $\mathcal{L}_{\text{MLP-adv}}$, a classification loss for predicting the cluster membership of each input image.
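The sign convention can be checked numerically. The sketch below uses toy linear modules (hypothetical stand-ins for the CoCoOp classification path and the MLP adversary) and a gradient reversal function equivalent to the `GradientReversalFunction` defined later in this notebook, with reversal factor 1. It verifies that backpropagating $\mathcal{L}_{\text{CE}} + \gamma \mathcal{L}_{\text{MLP-adv}}$ through the GRL gives the features the same gradient as $\mathcal{L}_{\text{CE}} - \gamma \mathcal{L}_{\text{MLP-adv}}$ without it. The adversary's parameters sit after the GRL, so they still receive the ordinary gradient and minimize $\mathcal{L}_{\text{MLP-adv}}$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; multiplies the gradient by -lambda_ in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_ * grad_output, None


torch.manual_seed(0)
gamma = 0.5
z = torch.randn(8, 6, requires_grad=True)     # stand-in for prompt-learner features
y = torch.randint(0, 3, (8,))                 # class labels
cluster = torch.randint(0, 2, (8,)).float()   # binary cluster labels
head = nn.Linear(6, 3)                        # stand-in for the classification path
adversary = nn.Linear(6, 1)                   # stand-in for the MLP adversary


def losses(features, reverse):
    ce = F.cross_entropy(head(features), y)
    adv_in = GradReverse.apply(features, 1.0) if reverse else features
    adv = F.binary_cross_entropy_with_logits(adversary(adv_in).squeeze(-1), cluster)
    return ce, adv


# Gradient w.r.t. the features when the GRL is used: L_CE + gamma * L_adv
ce, adv = losses(z, reverse=True)
grad_grl, = torch.autograd.grad(ce + gamma * adv, z)

# Reference: gradient of L_CE - gamma * L_adv computed without any reversal
ce_ref, adv_ref = losses(z, reverse=False)
grad_ref, = torch.autograd.grad(ce_ref - gamma * adv_ref, z)

print(torch.allclose(grad_grl, grad_ref, atol=1e-6))  # True
```

With a GRL factor $\lambda \neq 1$, the features would instead see $-\gamma\lambda\,\nabla\mathcal{L}_{\text{MLP-adv}}$.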

#### Cluster Definition

Clusters are precomputed before training and represent groupings of classes used to guide adversarial training in the MLP prompt learner (see [Clustering](#clustering)). Different clustering strategies and parameters can be employed depending on the experiment.

During training, the MLP classifies the cluster membership of each input image. With two clusters it outputs a single logit trained with binary cross-entropy; with more clusters it outputs one logit per cluster trained with categorical cross-entropy. These clusters encourage the prompt learner to produce representations invariant to certain groupings, potentially improving robustness and generalization.
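A minimal sketch of how the cluster targets and the adversarial loss could be derived, assuming (as in the `Adversarial` method below) that contiguous dataset labels are first mapped to original category ids and then to cluster ids. The helper names, category ids and cluster assignment are made up for illustration.

```python
import torch
import torch.nn.functional as F


def cluster_targets(labels, idx2cat, cls_cluster_dict):
    """Map contiguous dataset labels -> original category ids -> cluster ids."""
    return torch.tensor([cls_cluster_dict[idx2cat[int(l)]] for l in labels])


def cluster_loss(cluster_logits, targets):
    """BCE for a single-logit (two-cluster) head, categorical CE otherwise."""
    if cluster_logits.shape[1] == 1:
        return F.binary_cross_entropy_with_logits(cluster_logits.squeeze(-1), targets.float())
    return F.cross_entropy(cluster_logits, targets.long())


# Hypothetical example: 4 base categories (original ids 10..13) split into 2 clusters
idx2cat = {0: 10, 1: 11, 2: 12, 3: 13}
cls_cluster_dict = {10: 0, 11: 1, 12: 1, 13: 0}
labels = torch.tensor([0, 2, 3, 1])
targets = cluster_targets(labels, idx2cat, cls_cluster_dict)
print(targets.tolist())  # [0, 1, 0, 1]
```

An uninformative adversary (all-zero logits) yields a loss of $\log 2$ with the binary head and $\log K$ with a $K$-cluster head, which is a useful reference value when monitoring $\mathcal{L}_{\text{MLP-adv}}$.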

#### Intuition  
The key property we try to exploit is that the latent representations of images from one category differ from those of images from other categories.

We adversarially train the CoCoOp prompt learner against a cluster-based MLP classifier to prevent it from encoding cluster-specific features. This forces the model to focus on broader, more transferable representations.

We hypothesize that if the prompt learner degrades the discriminator's accuracy, it also reduces the influence of semantic traits shared by the categories within each cluster, leading to logits that are less biased by intra-domain similarities. This is key in our setting, where classes are visually similar and overfitting to fine-grained cues of the base classes harms generalization to novel ones.

[4]:


#### Adversarial Method

In [None]:
# From: ./training_systems/training_methods/Adversarial.py
"""
This module implements the Adversarial training method, which incorporates a gradient reversal layer
and an adversarial MLP to encourage domain-invariant feature learning.
"""
from typing import Dict, Any, List

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

# from training_systems.core import TrainingMethod
# from utils import AverageMeter, ContiguousLabelDataset, CLASS_NAMES
# from model.cocoop.mlp_adversary import AdversarialMLP, GradientReversalLayer


class Adversarial(TrainingMethod):
    """
    Adversarial training method using a Gradient Reversal Layer and an MLP adversary.

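    Attributes:
        cls_cluster_dict (Dict[int, Any]): Maps class indices to cluster labels.
        grl (GradientReversalLayer): The gradient reversal layer instance.
        mlp_adversary (AdversarialMLP): The adversarial MLP that predicts the cluster of each sample.
        lambda_adv (float): Weight of the adversarial loss term.
        tmp_classes (list): Class indices whose names are used during the adversarial phase.
        gaussian_noise (float): Standard deviation of the Gaussian noise added to the logits fed to the adversary (0 disables it).
        use_bias_ctx (bool): If True, the adversary receives the selected text features and the averaged shifted context instead of the averaged text features and the logits.
        debug (bool): If True, print debug information.
    """

    def __init__(
            self,
            model: Any,
            optimizer: Any,
            cls_cluster_dict: Dict[int, Any],
            grl: GradientReversalLayer,
            mlp_adversary: AdversarialMLP,
            lambda_adv: float,
            tmp_classes: list,
            debug: bool = False,
            gaussian_noise: float = 0.0,
            use_bias_ctx: bool = False
    ) -> None:
        """
        Args:
            model (Any): The main model being trained.
            optimizer (Any): Optimizer for updating model parameters.
            cls_cluster_dict (Dict[int, Any]): Mapping from class labels to clusters.
            grl (GradientReversalLayer): The gradient reversal layer.
            mlp_adversary (AdversarialMLP): The adversarial network module.
            lambda_adv (float): Weight for the adversarial loss.
            tmp_classes (list): Class indices whose names are used during the adversarial phase.
            debug (bool, optional): Enables debug mode. Defaults to False.
            gaussian_noise (float, optional): Std of the noise added to the logits fed to the adversary. Defaults to 0.0.
            use_bias_ctx (bool, optional): Feed the adversary with text features and shifted context instead of logits. Defaults to False.
        """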
        super().__init__(model, optimizer, "Adv.", debug)
        self.cls_cluster_dict = cls_cluster_dict
        self.grl = grl
        self.mlp_adversary = mlp_adversary
        self.lambda_adv = lambda_adv
        self.tmp_classes = tmp_classes
        self.gaussian_noise = gaussian_noise
        self.use_bias_ctx = use_bias_ctx

    def get_metrics(self) -> Dict[str, AverageMeter]:
        """
        Returns:
            Dict[str, AverageMeter]: Dictionary containing initialized metrics for training.
        """
        return {
            "total_loss_metric": AverageMeter(),
            "ce_loss_metric": AverageMeter(),
            "adv_loss_metric": AverageMeter(),
            "accuracy_metric": AverageMeter(),
        }

    def get_data_loader(self, dataset: ContiguousLabelDataset, batch_size: int) -> DataLoader:
        """
        Args:
            dataset (ContiguousLabelDataset): Dataset to be used.
            batch_size (int): Size of each batch.

        Returns:
            DataLoader: Configured PyTorch DataLoader instance.
        """
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
        )

    def forward_backward(
            self,
            sample,
            batch_idx,
            metrics: Dict[str, AverageMeter],
            dataset: ContiguousLabelDataset,
            accumulation_steps: int = 1,
            step: int = 0
    ) -> Dict[str, float]:
        """
        Executes the forward and backward pass.

        Args:
            sample (Tuple[Tensor, Tensor]): Batch of input data and targets.
            batch_idx (int): Index of the current batch.
            metrics (Dict[str, AverageMeter]): Dictionary of metrics to be updated.
            dataset (ContiguousLabelDataset): Dataset for cluster label lookup.

        Returns:
            Dict[str, float]: Dictionary with loss and accuracy metrics.
        """
        # Load data into GPU
        inputs, targets = sample
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        targets_real_category = [dataset.idx2cat[c.item()] for c in targets]
        cluster_target = [int(self.cls_cluster_dict[int(tl)]) for tl in targets_real_category]
        cluster_target = torch.tensor(
            cluster_target,
            device=targets.device,
            dtype=torch.float32
        )

        # Use all tmp_classes for adversarial phase
        with self.model.temporary_classnames([CLASS_NAMES[idx] for idx in self.tmp_classes]):
            logits, ce_loss, img_features, ctx, bias, avg_txt_features, selected_text_features = self.model(inputs, targets, get_image_features=True)

            if self.gaussian_noise > 0:
                noise = torch.randn_like(logits) * self.gaussian_noise
                noisy_logits = logits + noise
            else:
                noisy_logits = logits

            if self.use_bias_ctx:
                if ctx.shape[0] == 1:
                    ctx = ctx.expand(bias.shape[0], -1, -1)

                ctx_shifted = ctx + bias  # shape: [B, L, D]
                # concat = ctx_shifted.view(ctx_shifted.size(0), -1).to(dtype=torch.float32)
                concat = torch.cat([selected_text_features, ctx_shifted.mean(dim=1)], dim=1).to(dtype=torch.float32)

            else:
                concat = torch.cat([avg_txt_features, noisy_logits], dim=1).to(dtype=torch.float32)
            # print(f"concat shape: {concat.shape}, bias shape: {bias.shape}, ctx shape: {ctx.shape}, avg_txt_features shape: {avg_txt_features.shape}, logits shape: {logits.shape}")
            reversed_concat = self.grl(concat)
            # print(f"reversed_concat shape: {reversed_concat.shape}")
            cluster_logits = self.mlp_adversary(reversed_concat)

            if cluster_logits.shape[1] == 1:
                # Binary classification
                cluster_logits = cluster_logits.squeeze(-1)
                loss_adv = F.binary_cross_entropy_with_logits(cluster_logits, cluster_target.float())
            else:
                # Multi-class classification
                loss_adv = F.cross_entropy(cluster_logits, cluster_target.long())

            # Combined objective: the adversary minimizes loss_adv, while the GRL flips the sign of
            # its gradient for everything upstream (the prompt learner), which therefore maximizes it
            total_loss = ce_loss + self.lambda_adv * loss_adv

            # Scale the loss so that accumulated gradients match a full-batch update
            (total_loss / accumulation_steps).backward()

            # --- accumulate grads and update every `accumulation_steps` batches ---
            if (step + 1) % accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    list(self.model.parameters()) + list(self.mlp_adversary.parameters()),
                    max_norm=1.0,
                    norm_type=2.0,
                    error_if_nonfinite=True
                )
                self.optimizer_step()
                self.optimizer.zero_grad()

            if batch_idx < 3 and self.debug:  # Log only a few batches for performance
                print(f"[Batch {batch_idx}] CE Loss: {ce_loss.item():.4f} | "
                      f"Adv Loss: {loss_adv.item():.4f} | "
                      f"Total Loss: {total_loss.item():.4f} | "
                      f"lambda_adv: {self.lambda_adv:.4f}")

            # Metrics are logged with the unscaled losses
            batch_size = inputs.shape[0]
            metrics["total_loss_metric"].update(total_loss.item(), n=batch_size)
            metrics["ce_loss_metric"].update(ce_loss.item(), n=batch_size)
            metrics["adv_loss_metric"].update(loss_adv.item(), n=batch_size)
            _, predicted = logits.max(dim=1)
            correct = predicted.eq(targets).sum().item()
            metrics["accuracy_metric"].update(correct, n=batch_size, raw=True)
            return {
                "total_loss": total_loss.item(),
                "ce_loss": ce_loss.item(),
                "adv_loss": loss_adv.item(),
                "accuracy": correct / batch_size,
            }

    def print_grads_norms(self, bce_grads, ce_grads):
        """
        Args:
            bce_grads (Dict[str, Tensor]): Gradients from BCE loss.
            ce_grads (Dict[str, Tensor]): Gradients from CE loss.
        """
        for name in ce_grads:
            ce_norm = ce_grads[name].norm().item()
            bce_norm = bce_grads[name].norm().item()
            print(f"{name}: CE grad norm = {ce_norm:.4e}, BCE grad norm = {bce_norm:.4e}")

    def get_grads(self, loss):
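        """
        Backpropagates `loss` (retaining the graph), collects the gradients of the prompt-learner
        parameters, and then zeroes all gradients through the optimizer.

        Args:
            loss (Tensor): Loss tensor to backpropagate.

        Returns:
            Dict[str, Tensor]: Dictionary of prompt-learner gradients, keyed by parameter name.
        """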
        loss.backward(retain_graph=True)
        ce_grads = {}
        for name, p in self.model.named_parameters():
            if p.grad is not None and "prompt_learner" in name:
                ce_grads[name] = p.grad.detach().clone()
        # --- Zero gradients ---
        self.optimizer.zero_grad()
        return ce_grads

    def debug_metrics_to_pbar_args(self, debug_metrics: Dict[str, float]) -> Dict[str, float]:
        """
        Args:
            debug_metrics (Dict[str, float]): Metrics from current training step.

        Returns:
            Dict[str, float]: Same metrics, passed to progress bar.
        """
        return debug_metrics

    def training_step_return(self, metrics: Dict[str, AverageMeter]) -> list[float]:
        """
        Args:
            metrics (Dict[str, AverageMeter]): Collected training metrics.

        Returns:
            List[float]: Average total loss, accuracy, CE loss, and adversarial loss (in this order).
        """
        return [
            metrics["total_loss_metric"].avg,
            metrics["accuracy_metric"].avg,
            metrics["ce_loss_metric"].avg,
            metrics["adv_loss_metric"].avg,
        ]

    def update_lambda_adv(self, lambda_adv) -> None:
        """
        Args:
            lambda_adv (float): New value to set for lambda_adv.
        """
        self.lambda_adv = lambda_adv

#### MLP Adversarial + GRL

In [None]:
# From: ./model/cocoop/mlp_adversary.py
import torch
import torch.nn as nn
from torch.autograd import Function


class ResidualBlock(nn.Module):
    """
    A residual block with a linear layer, layer normalization, ReLU activation, and dropout, plus a shortcut that is the identity when dimensions match and a linear projection otherwise.
    """
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out)
        self.norm = nn.LayerNorm(dim_out)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.2)
        self.residual = (
            nn.Identity() if dim_in == dim_out else nn.Linear(dim_in, dim_out)
        )

    def forward(self, x):
        """
        Forward pass through the residual block with normalization, activation, dropout, and residual connection.
        """
        out = self.linear(x)
        out = self.norm(out)
        out = self.relu(out)
        out = self.drop(out)
        return out + self.residual(x)


class GradientReversalFunction(Function):
    """
    Implements a gradient reversal layer as a custom autograd function, useful in adversarial training.
    """
    @staticmethod
    def forward(ctx, x, lambda_):
        """
        Forward pass that returns the input as-is and stores the lambda factor for use in the backward pass.
        """
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass that reverses the gradient by multiplying it by -lambda.
        """
        return -ctx.lambda_ * grad_output, None


class GradientReversalLayer(nn.Module):
    """
    A wrapper for the GradientReversalFunction to integrate it into a standard nn.Module.
    """
    def __init__(self, lambda_=1.0):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        """
        Applies the gradient reversal function with the stored lambda parameter.
        """
        return GradientReversalFunction.apply(x, self.lambda_)


class AdversarialMLP(nn.Module):
    """
    A multi-layer perceptron with optional bias context support, designed for adversarial learning.
    Uses residual blocks for intermediate layers and configurable final output dimension.
    """
    def __init__(self, input_dim, opt, output_dim=1, use_bias_ctx=False, n_ctx=4):
        """
        Initializes the adversarial MLP with the given structure and optional bias context input.
        Builds the network from residual blocks and applies Xavier initialization to the weights.

        Note: `input_dim` and `n_ctx` are not used when building the layers. Without bias context,
        the input width must equal `opt.hidden_structure[0]`; with bias context, the first layer
        expects a 512*2-dimensional input (text features concatenated with the averaged context).
        """
        super().__init__()
        hidden_dims = opt.hidden_structure

        layers = []

        if use_bias_ctx:
            # Project the concatenated [text features, averaged shifted context] to the first hidden size
            layers.append(nn.Linear(512 * 2, hidden_dims[0]))
            layers.append(nn.ReLU())

        # Residual blocks between consecutive hidden sizes
        for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(ResidualBlock(in_dim, out_dim))

        # Final output layer: 1 logit for two clusters (BCE), otherwise one logit per cluster
        layers.append(nn.Linear(hidden_dims[-1], output_dim))
        self.model = nn.Sequential(*layers)

        # Xavier-initialize every linear layer
        self.model.apply(self.init_weights)

    def forward(self, x):
        """
        Forward pass through the adversarial MLP model.
        """
        return self.model(x)

    def predict(self, x):
        """
        Performs a forward pass and converts the logits to probabilities: sigmoid for a
        single-logit (binary) head, with the last dimension squeezed, and softmax over clusters otherwise.
        """
        out = self.forward(x)
        return torch.sigmoid(out).squeeze(-1) if out.shape[-1] == 1 else torch.softmax(out, dim=-1)

    def init_weights(self, m):
        """
        Initializes weights of linear layers using Xavier uniform distribution and biases to zero.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

#### Clustering
To support adversarial training and related experiments, we use several methods to assign categories to clusters. These assignments partition the classes according to semantic, visual, or arbitrary criteria, allowing flexible experimentation.

- **Default Clustering:** The simplest form, which groups classes into base (seen) and novel (unseen) categories.

- **Visual Feature Clustering:** CLIP visual embeddings are extracted for the training images of each base class and averaged into one feature vector per class, which is then projected into a lower-dimensional space with PCA (preserving 95% of the variance). We compute cosine distances between the class features and apply Agglomerative Clustering (average linkage) to group visually similar categories. The number of clusters can exceed two, enabling the MLP discriminator to distinguish among multiple groups.

- **Random Clustering:** Synthetic cluster assignments are generated with one of several schemes: uniform (a random permutation of the classes dealt round-robin into balanced clusters), fully random (each class draws its cluster independently), or sequential (round-robin in the original class order). This tests the adversarial module's robustness to arbitrary partitions and to noise in the clustering criterion.

##### Consistency Across Datasets
Cluster assignments are consistently applied to both training and validation datasets.

##### Caching Mechanism
To improve efficiency, cluster assignments (both as indices and category names) are cached. If cached clusters exist for a given configuration, they are loaded automatically; otherwise, they are computed and saved for reuse. In this notebook version, caching is disabled and clusters are recomputed on every run.
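The caching behaviour described above can be sketched as follows. The helper `load_or_compute_clusters` is hypothetical; the two pickle file names and the directory naming mirror the commented-out paths in `conditional_clustering` below, and `compute_fn` stands in for the actual clustering call.

```python
import os
import pickle
import tempfile


def load_or_compute_clusters(cache_dir, compute_fn):
    """Return (int_clusters, text_clusters, loaded_from_cache), computing and saving on a cache miss."""
    int_path = os.path.join(cache_dir, "int_categories.pkl")
    text_path = os.path.join(cache_dir, "text_categories.pkl")
    if os.path.exists(int_path) and os.path.exists(text_path):
        with open(int_path, "rb") as f:
            int_clusters = pickle.load(f)
        with open(text_path, "rb") as f:
            text_clusters = pickle.load(f)
        return int_clusters, text_clusters, True
    int_clusters, text_clusters = compute_fn()
    os.makedirs(cache_dir, exist_ok=True)
    with open(int_path, "wb") as f:
        pickle.dump(int_clusters, f)
    with open(text_path, "wb") as f:
        pickle.dump(text_clusters, f)
    return int_clusters, text_clusters, False


calls = []


def fake_clustering():
    # Hypothetical stand-in for the real (expensive) clustering
    calls.append(1)
    return {0: 0, 1: 1}, {"rose": 0, "tulip": 1}


with tempfile.TemporaryDirectory() as tmp:
    cache_dir = os.path.join(tmp, "cluster_labels_2_0.95_ViT-B_32")
    first = load_or_compute_clusters(cache_dir, fake_clustering)
    second = load_or_compute_clusters(cache_dir, fake_clustering)
print(first[2], second[2], len(calls))  # False True 1
```

Keying the directory on every parameter that affects the result (number of clusters, PCA variance, backbone) keeps different configurations from overwriting each other's cache.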

##### Configurability via YAML Files
In the original repository, all clustering parameters, including the number of clusters, the assignment criterion, and caching options, are configurable via the experiment `.yaml` files; in this notebook, an example configuration is defined directly in code instead. This modular setup facilitates systematic evaluation of how different clusterings affect adversarial training and overall model generalization.


In [None]:
# From: ./utils/clustering.py
import os
import pickle
import random
from collections import Counter, deque

import numpy as np
import torch
from tqdm import tqdm
import clip
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_distances
# from utils.datasets import get_data, base_novel_categories, split_data, CLASS_NAMES


def create_cluster_from_ordered_list(ordered_categories, split_ratio):
    """
    Converts an ordered list of categories into a cluster dict: the first int(len * split_ratio)
    categories go to cluster 0 and the rest to cluster 1.
    """
    n = len(ordered_categories)
    n_zeros = int(n * split_ratio)
    return {
        cat: 0 if i < n_zeros else 1
        for i, cat in enumerate(ordered_categories)
    }

def rotating_cluster_generator_shift(categories, split_ratio, steps=1, seed=None):
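    """
    Yields bipartite clusters by cyclically rotating a shuffled category list.

    Args:
        categories (list): List of category identifiers.
        split_ratio (float): Fraction of categories assigned to cluster 0.
        steps (int): Number of positions the list is rotated left between yields.
        seed (int, optional): Random seed for reproducibility of the initial shuffle.

    Yields:
        Tuple[dict, list, list]: Cluster mapping, categories in cluster 0, categories in cluster 1.
    """
    cat_list = list(categories)
    # Use a local RNG so the global `random` state is not modified
    random.Random(seed).shuffle(cat_list)
    cat_deque = deque(cat_list)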

    while True:
        cluster = create_cluster_from_ordered_list(cat_deque, split_ratio)
        cat0 = [cat for cat, c in cluster.items() if c == 0]
        cat1 = [cat for cat, c in cluster.items() if c == 1]
        yield cluster, cat0, cat1
        cat_deque.rotate(-steps)  # rotate left

def cluster_categories(device, cnn, n_clusters=2, variance=0.95, data_dir="../data"):
    """
    Clusters base classes using visual features extracted from a CLIP model. Applies PCA to reduce dimensionality
    and Agglomerative Clustering on cosine distances to group the categories.

    Args:
        device (torch.device): The device to run computations on (CPU/GPU).
        cnn (str): CLIP model architecture name (e.g., "ViT-B/32").
        n_clusters (int): Number of clusters to generate.
        variance (float): Variance ratio to preserve during PCA.
        data_dir (str): Directory where the dataset is stored.

    Returns:
        Tuple[Dict[int, int], Dict[str, int]]: Two dictionaries mapping class indices and class names to cluster IDs.
    """

    # Initialize the CLIP model directly on the requested device
    clip_model, _ = clip.load(cnn, device=device)
    resolution = clip_model.visual.input_resolution
    train_set, _, _ = get_data(data_dir=data_dir, resolution=resolution)

    # Split classes into base and novel
    base_classes, _ = base_novel_categories(train_set)

    # Keep only the training samples of the base classes
    train_base, _ = split_data(train_set, base_classes)

    # Encode all base-class images in a single pass and group the features by class
    features_per_class = {c: [] for c in base_classes}
    dataloader = torch.utils.data.DataLoader(train_base, batch_size=64, shuffle=False)
    with torch.no_grad():
        for imgs, labels in tqdm(dataloader, desc="Encoding base-class images"):
            feats = clip_model.encode_image(imgs.to(device)).float().cpu().numpy()
            for feat, label in zip(feats, labels.tolist()):
                if label in features_per_class:
                    features_per_class[label].append(feat)

    # One averaged feature vector per base class, stacked in base-class order
    cat2idx = {}
    idx2cat = {}
    class_ft_array = []
    for i, c in enumerate(base_classes):
        cat2idx[c] = i
        idx2cat[i] = c
        class_ft_array.append(np.mean(features_per_class[c], axis=0))

    pca = PCA(n_components=variance)
    X_reduced = pca.fit_transform(class_ft_array)

    print(
        f"Reduced feature shape: {X_reduced.shape}, Variance explained: {pca.explained_variance_ratio_.sum()}"
    )

    cosine_dist = cosine_distances(X_reduced)
    # Agglomerative clustering on the precomputed cosine distances
    agglo = AgglomerativeClustering(
        n_clusters=n_clusters, metric="precomputed", linkage="average"
    )
    cluster_labels = agglo.fit_predict(cosine_dist)

    # Map each base-class id to its cluster
    cluster_labels = {idx2cat[i]: int(cluster) for i, cluster in enumerate(cluster_labels)}

    # Same mapping, keyed by class name
    cluster_labels_text = {
        CLASS_NAMES[base_class]: int(cluster)
        for base_class, cluster in cluster_labels.items()
    }

    torch.cuda.empty_cache()

    return cluster_labels, cluster_labels_text


def random_clustering(
    n_cluster,
    seed=42,
    data_dir="../data",
    distribution="uniform",
    split_ratio=0.7,  # Only for bipartite
):
    """
    Generates random cluster assignments for the base classes.

    Args:
        n_cluster (int): Number of clusters to generate.
        seed (int): Random seed for reproducibility.
        data_dir (str): Directory where the dataset is stored.
        distribution (str): How classes are assigned to clusters:
            "uniform" (random permutation, then round-robin, giving balanced clusters),
            "random" (each class draws its cluster independently), or
            "sequential" (round-robin in the original class order).
        split_ratio (float): Not used by the distributions implemented here.

    Returns:
        Tuple[
            Dict[int, int],        # class_id -> cluster_id
            Dict[str, int],        # class_name -> cluster_id
        ]
    """

    np.random.seed(seed)
    train_set, _, _ = get_data(data_dir=data_dir)
    base_classes, _ = base_novel_categories(train_set)

    cluster_labels = {}

    if distribution == "uniform":
        shuffled = np.random.permutation(base_classes)
        for i, cls in enumerate(shuffled):
            cluster_id = i % n_cluster
            cluster_labels[cls] = cluster_id

    elif distribution == "random":
        for cls in base_classes:
            cluster_id = np.random.choice(range(n_cluster))
            cluster_labels[cls] = cluster_id

    elif distribution == "sequential":
        for i, cls in enumerate(base_classes):
            cluster_id = i % n_cluster
            cluster_labels[cls] = cluster_id

    else:
        raise ValueError(f"Unknown distribution: {distribution!r}")

    cluster_labels_text = {
        CLASS_NAMES[cls]: int(cluster_labels[cls])
        for cls in cluster_labels
    }

    cluster_dict_int = {int(k): v for k, v in cluster_labels.items()}

    return cluster_dict_int, cluster_labels_text


def conditional_clustering(n_cluster, variance, cnn, device, data_dir="../data"):
    """
    Returns cluster assignments for the base classes. The original repository loads them from disk when
    available and otherwise computes and saves them; in this notebook version, loading and saving are
    disabled, so the assignments are always recomputed.

    Args:
        n_cluster (int): Number of clusters to generate.
        variance (float): Variance ratio to preserve during PCA.
        cnn (str): CLIP model architecture name (also used to name the cache directory).
        device (torch.device): The device to run computations on.
        data_dir (str): Directory where the dataset is stored.

    Returns:
        Tuple[Dict[int, int], Dict[str, int]]: Dictionaries for integer-labeled and text-labeled cluster assignments.
    """
    cnn_sanitized = cnn.replace("/", "_")
    save_dir = f"clustering_split/cluster_labels_{n_cluster}_{variance}_{cnn_sanitized}"
    os.makedirs(save_dir, exist_ok=True)

    int_categories_path = os.path.join(save_dir, "int_categories.pkl")
    text_categories_path = os.path.join(save_dir, "text_categories.pkl")


    # The following lines are commented out because the Jupyter version does not save or load cluster labels

    # if os.path.exists(int_categories_path) and os.path.exists(text_categories_path):
    #     print("🟩 CLUSTERS FILES FOUND. Loading existing cluster labels...")
    #     with open(int_categories_path, "rb") as f:
    #         cluster_labels = pickle.load(f)
    #         cluster_dict_int = {int(k): v for k, v in cluster_labels.items()}
    #     with open(text_categories_path, "rb") as f:
    #         cluster_labels_text = pickle.load(f)

    # else:
    print("🟧 NO CLUSTERS FILES FOUND (jupyter version without saves). Creating cluster labels...")
    # cluster the base classes
    cluster_labels, cluster_labels_text = cluster_categories(
        device, n_clusters=n_cluster, variance=variance, cnn=cnn, data_dir=data_dir
    )
    cluster_dict_int = {int(k): v for k, v in cluster_labels.items()}

    # The following lines are commented out because the Jupyter version does not save cluster labels

    # with open(int_categories_path, "wb") as f:
    #     pickle.dump(cluster_labels, f)
    # with open(text_categories_path, "wb") as f:
    #     pickle.dump(cluster_labels_text, f)

    # Count the classes assigned to each cluster
    cluster_counts = Counter(cluster_dict_int.values())
    for cluster_id in range(n_cluster):
        print(f"Cluster {cluster_id} count: {cluster_counts.get(cluster_id, 0)}")

    return cluster_dict_int, cluster_labels_text



### KL Divergence Teacher Guidance (v1)

In this modification, we introduce a form of teacher-student training in which CLIP's zero-shot predictions [^3] serve as soft targets to guide CoCoOp. The motivation is that zero-shot CLIP can outperform CoCoOp on novel classes for some datasets, such as Oxford Flowers, which is our case study. Moreover, because both models share the same CLIP backbone, the teacher and the student are architecturally consistent.

#### Loss Formulation  
The training objective augments the standard CoCoOp cross-entropy loss with a Kullback–Leibler (KL) divergence term between zero-shot CLIP predictions and CoCoOp's output:

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{CE}} + \lambda_{\text{KL}} \cdot \mathcal{L}_{\text{KL}}
$$

where $\mathcal{L}_{\text{KL}} = D_{\text{KL}}\left(p_{\text{CLIP}} \,\|\, p_{\text{CoCoOp}}\right)$ is the KL divergence between the softmax distribution of the CLIP zero-shot classifier (teacher) and that of CoCoOp (student), and $\lambda_{\text{KL}}$ is a scaling hyperparameter (`lambda_kl` in the code).
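
The objective can be sketched in PyTorch as follows. This is a minimal illustration, assuming the student and teacher logits are computed over the same classes in the same order; `kl_teacher_student` and `total_loss` are hypothetical helper names, and `lambda_kl` stands for the scaling hyperparameter. Passing the student log-probabilities first and the teacher probabilities second to `torch.nn.functional.kl_div` is what yields the teacher-to-student direction $D_{\text{KL}}(p_{\text{teacher}} \| p_{\text{student}})$.

```python
import torch
import torch.nn.functional as F


def kl_teacher_student(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """KL(p_teacher || p_student), averaged over the batch (hypothetical helper)."""
    # F.kl_div expects log-probabilities as input (student) and probabilities as target (teacher)
    return F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",  # sum over classes and samples, divided by the batch size
    )


def total_loss(ce_logits, ce_targets, student_logits, teacher_logits, lambda_kl: float) -> torch.Tensor:
    """Cross-entropy plus weighted teacher-student KL (hypothetical helper)."""
    ce = F.cross_entropy(ce_logits, ce_targets)
    return ce + lambda_kl * kl_teacher_student(student_logits, teacher_logits)


torch.manual_seed(0)
s = torch.randn(4, 5)  # student logits: 4 samples, 5 classes
t = torch.randn(4, 5)  # teacher logits over the same 5 classes

# Manual KL(p_t || p_s): sum over classes, then mean over the batch
p_t = F.softmax(t, dim=-1)
manual = (p_t * (p_t.log() - F.log_softmax(s, dim=-1))).sum(dim=-1).mean()
print(torch.allclose(kl_teacher_student(s, t), manual, atol=1e-6))  # True
```

With `reduction="batchmean"`, the summed divergence is divided by the batch size, which matches the per-sample KL averaged over the batch; PyTorch's default `"mean"` would instead also divide by the number of classes.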

#### Batch Splitting Strategy  
To apply the KL loss meaningfully, we split each training batch into two subsets:

- **Pseudo-base:** Samples from a random 70% of the categories present in the batch are used to compute the standard cross-entropy loss.
- **Pseudo-novel:** Samples from the remaining 30% of the categories are used to compute the KL divergence loss.

This split is re-drawn for every batch, so CoCoOp is trained on a blend of direct supervision and knowledge distillation from CLIP zero-shot predictions [^3].
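
The class-level split can be sketched in plain Python. `split_batch_by_class` is a hypothetical helper that mirrors the idea of the collate function used below, with one deliberate difference: it uses a seeded `random.Random` instead of the global RNG so that the result is reproducible. Note that the split is over distinct classes, so the number of samples in each subset depends on how often each class appears in the batch.

```python
import random
from typing import List, Tuple


def split_batch_by_class(
    labels: List[int], base_ratio: float = 0.7, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Return sample indices of the pseudo-base and pseudo-novel subsets (hypothetical helper)."""
    rng = random.Random(seed)
    classes = sorted(set(labels))           # distinct classes present in the batch
    rng.shuffle(classes)
    split_idx = int(base_ratio * len(classes))
    base_classes = set(classes[:split_idx])  # first 70% of the shuffled classes
    base_idx = [i for i, y in enumerate(labels) if y in base_classes]
    novel_idx = [i for i, y in enumerate(labels) if y not in base_classes]
    return base_idx, novel_idx


labels = [3, 3, 7, 1, 1, 1, 9, 4, 7, 2]  # 6 distinct classes
base_idx, novel_idx = split_batch_by_class(labels)
print(len({labels[i] for i in base_idx}), len({labels[i] for i in novel_idx}))  # 4 2
```

With 6 distinct classes, `int(0.7 * 6) = 4` classes become pseudo-base and 2 become pseudo-novel; no class ever appears in both subsets.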

#### Intuition  
The KL divergence term encourages CoCoOp to preserve the generalization strengths of CLIP’s zero-shot model, particularly for novel or underrepresented classes. By mimicking CLIP's soft predictions on part of the data, the prompt learner is gently nudged toward representations that maintain strong performance outside the training distribution.



### KL Regularization (v1) - Implementation

**Key Aspects:** Pseudo-novel splitting within batch; loss fusion in a single backward pass.  
**Constraints:** No memory separation; subject to class imbalance due to random sampling.

This variant augments CoCoOp training with a KL divergence term between frozen CLIP zero-shot outputs and the model's predictions. It uses a single dataloader that returns mini-batches sampled from $\mathcal{D}_{\text{seen}}^{\text{train}}$; the collate function splits each batch into pseudo-base and pseudo-novel subsets by class, assigning a fixed 70/30 share of the classes present in the batch.

Both losses, cross-entropy on the pseudo-base subset and KL divergence on the pseudo-novel subset, are computed jointly and combined before a single backward pass. This is efficient, but because the class split is random, the composition of each subset varies from batch to batch and the pseudo-novel subset can be small or imbalanced. These effects can reduce training stability and reproducibility compared to more structured variants.

---

### Training Procedure (repeat until convergence):

1. Sample an image batch:  
   $$
   \mathbf{X}_{\text{b}} \sim \mathcal{D}_{\text{seen}}^{\text{train}}
   $$

2. Divide the batch into pseudo-novel and pseudo-base subsets:  
   $$
   (\mathbf{x}_{\text{bn}}, y_{\text{bn}}),\; (\mathbf{x}_{\text{bb}}, y_{\text{bb}}) \sim \mathbf{X}_{\text{b}}
   $$

3. Compute CoCoOp cross-entropy loss on the pseudo-base subset:  
   $$
   \mathcal{L}_{\text{CE}} = \text{CE} \left( \text{CoCoOp}_{\theta, p_1, \ldots, p_M}(\mathbf{x}_{\text{bb}}), y_{\text{bb}} \right)
   $$

4. Compute CoCoOp softmax predictions on the pseudo-novel subset:  
   $$
   y_{\text{student}} = \text{CoCoOp}_{\theta, p_1, \ldots, p_M}(\mathbf{x}_{\text{bn}})
   $$

5. Compute frozen CLIP softmax predictions on the pseudo-novel subset:  
   $$
   y_{\text{teacher}} = \text{CLIP}(\mathbf{x}_{\text{bn}})
   $$

6. Compute student-teacher KL divergence:  
   $$
   \mathcal{L}_{\text{KL}} = D_{\text{KL}} \left( y_{\text{teacher}} \| y_{\text{student}} \right)
   $$

7. Compute total loss combining cross-entropy and weighted KL divergence:  
   $$
   \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{CE}} + \lambda_{\text{KL}} \cdot \mathcal{L}_{\text{KL}}
   $$

8. Update parameters $\theta$ and context vectors $p_1, \ldots, p_M$ via backpropagation on $\mathcal{L}_{\text{total}}$.
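
A toy version of one v1 step (steps 1 to 8) is sketched below. It is only an illustration under strong assumptions: CoCoOp and CLIP are replaced by hypothetical linear "text heads" (`student_text`, `teacher_text`) applied to random image features, and the class split is fixed rather than random. What it shows is that the two terms are scored against different class subsets, yet are summed into one loss and backpropagated once.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, feat_dim = 6, 8

# Hypothetical stand-ins: a learnable student head and a frozen teacher head (not the real CoCoOp/CLIP)
student_text = nn.Parameter(torch.randn(num_classes, feat_dim))
teacher_text = torch.randn(num_classes, feat_dim)
optimizer = torch.optim.SGD([student_text], lr=0.1)
lambda_kl = 1.0

x = torch.randn(12, feat_dim)                           # step 1: image features of one batch
y = torch.tensor([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5])  # original class ids
base_classes = torch.tensor([0, 1, 2, 3])               # step 2: pseudo-base classes (sorted)
novel_classes = torch.tensor([4, 5])                    # pseudo-novel classes (sorted)
is_base = torch.isin(y, base_classes)

# Step 3: CE on pseudo-base samples, scored only against pseudo-base classes;
# searchsorted maps original ids to 0..K-1 because base_classes is sorted
ce_logits = x[is_base] @ student_text[base_classes].T
loss_ce = F.cross_entropy(ce_logits, torch.searchsorted(base_classes, y[is_base]))

# Steps 4-6: student and frozen teacher on pseudo-novel samples, scored against pseudo-novel classes
x_novel = x[~is_base]
student_logits = x_novel @ student_text[novel_classes].T
with torch.no_grad():
    teacher_logits = x_novel @ teacher_text[novel_classes].T
loss_kl = F.kl_div(
    F.log_softmax(student_logits, dim=-1),
    F.softmax(teacher_logits, dim=-1),
    reduction="batchmean",
)

# Steps 7-8: one combined loss, one backward pass, one optimizer step
total = loss_ce + lambda_kl * loss_kl
before = student_text.detach().clone()
optimizer.zero_grad()
total.backward()
optimizer.step()
```

Because the CE term only involves the pseudo-base rows and the KL term only the pseudo-novel rows, a single backward pass updates both groups at once.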



In [None]:
# From: ./training_systems/training_methods/KLCoCoOp.py
"""
This module defines the KLCoCoOp training method, which combines standard cross-entropy loss with KL divergence
between a model's predictions and a frozen teacher (e.g., CLIP) to enhance generalization to novel classes.
"""
import random
from typing import Dict, Any

import torch
from torch.utils.data import DataLoader

# from training_systems.core.TrainingMethod import TrainingMethod

# from utils import AverageMeter,ContiguousLabelDataset, get_kl_loss
# from utils.datasets import CLASS_NAMES


class KLCoCoOp(TrainingMethod):
    """
    KLCoCoOp training method combining cross-entropy classification on pseudo-base samples and KL divergence
    loss on pseudo-novel samples. It encourages transferability and generalization by mixing base and novel categories.

    Attributes:
        lambda_kl (float): Weight for the KL divergence loss component.
    """

    def __init__(
            self,
            model: Any,
            optimizer: Any,
            lambda_kl,
            debug: bool = False
    ) -> None:
        super().__init__(model, optimizer, "Base CoCoOp + KL", debug)
        self.lambda_kl = lambda_kl

    def get_metrics(self) -> Dict[str, AverageMeter]:
        """
        Initializes training metrics including total, cross-entropy, KL divergence losses, and classification accuracy.

        Returns:
            Dict[str, AverageMeter]: Dictionary of metric names mapped to their respective AverageMeter instances.
        """
        return {
            "total_loss_metric": AverageMeter(),
            "ce_loss_metric": AverageMeter(),
            "kl_loss_metric": AverageMeter(),
            "accuracy_metric": AverageMeter(),
        }

    def get_data_loader(self, dataset: ContiguousLabelDataset, batch_size: int) -> DataLoader:
        """
        Returns a DataLoader that splits each batch into pseudo-base and pseudo-novel subsets.

        Args:
            dataset (ContiguousLabelDataset): The dataset used for training.
            batch_size (int): Number of samples per batch.

        Returns:
            DataLoader: PyTorch DataLoader with a custom collate function to separate base and novel samples.
        """

        def custom_collate(batch):
            base_samples = []
            novel_samples = []
            targets_in_batch = list(set([target for _, target in batch]))
            random.shuffle(targets_in_batch)
            split_idx = int(0.7 * len(targets_in_batch))
            pseudo_base_ids = targets_in_batch[:split_idx]
            pseudo_novel_ids = targets_in_batch[split_idx:]
            for img, label in batch:
                if label in pseudo_base_ids:
                    base_samples.append((img, label))
                elif label in pseudo_novel_ids:
                    novel_samples.append((img, label))
            return base_samples, novel_samples

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            collate_fn=custom_collate,
        )

    def forward_backward(
            self,
            sample,
            batch_idx,
            metrics: Dict[str, AverageMeter],
            dataset: ContiguousLabelDataset,
            accumulation_steps: int = 1,
            step: int = 0
    ) -> Dict[str, float]:
        """
        Performs forward and backward passes, computing CE loss on pseudo-base and KL loss on pseudo-novel samples.

        Args:
            sample (Tuple[List[Tuple[Tensor, int]], List[Tuple[Tensor, int]]]): Tuple of pseudo-base and pseudo-novel batches.
            batch_idx (int): Current batch index during training.
            metrics (Dict[str, AverageMeter]): Dictionary to update with the training metrics.
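            dataset (ContiguousLabelDataset): Dataset object used for KL divergence lookup.
            accumulation_steps (int): Number of steps over which gradients are accumulated before an optimizer step.
            step (int): Global step counter used to decide when to call the optimizer.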

        Returns:
            Dict[str, float]: Scalar values of total loss, CE loss, KL loss, and CE accuracy.
        """
        # Load data into GPU
        base_batch, novel_batch = sample

        if not base_batch or not novel_batch:
            return {
                metric: val.avg
                for metric, val in metrics.items()
            }

        # === Pseudo-base: cross-entropy ===
        inputs_base = torch.stack([img for img, _ in base_batch]).to(self.device)
        targets_base = torch.tensor([lbl for _, lbl in base_batch]).to(self.device)


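        # Distinct labels in this batch: .tolist() yields Python ints, so set() really deduplicates
        # (0-d tensors hash by identity, so set(targets_base) would keep every element)
        categories_base_tensor = [dataset.idx2cat[c] for c in sorted(set(targets_base.tolist()))]
        remapped_class_names = [CLASS_NAMES[c] for c in categories_base_tensor]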

        target_remapping = {cat:idx for idx, cat in enumerate(categories_base_tensor)}
        target_original = [dataset.idx2cat[int(c.item())] for c in targets_base]
        target_remapped = torch.tensor([target_remapping[c] for c in target_original]).to(self.device)

        with self.model.temporary_classnames(remapped_class_names):
            self.model.train()
            logits_base, loss_ce = self.model(inputs_base, target_remapped)

        # === Pseudo-novel: KL divergence with frozen CLIP ===
        self.model.eval()  # eval mode for the frozen CLIP teacher pass; get_kl_loss switches back to train() for the student
        inputs_novel = torch.stack([img for img, _ in novel_batch]).to(self.device)
        targets_novel = [lbl for _, lbl in novel_batch]

        kl_loss = get_kl_loss(self.device, inputs_novel, self.model, targets_novel, dataset)

        # === Combine losses ===
        total_loss = loss_ce + self.lambda_kl * kl_loss

        # Scale only the backpropagated loss, so the logged total stays on the same scale as the CE and KL terms
        (total_loss / accumulation_steps).backward()

        # optimizer step every `accumulation_steps` steps
        if (step + 1) % accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

        batch_size_total = inputs_base.size(0) + inputs_novel.size(0)

        metrics["total_loss_metric"].update(total_loss.item(), n=batch_size_total)
        metrics["ce_loss_metric"].update(loss_ce.item(), n=inputs_base.size(0))
        metrics["kl_loss_metric"].update(kl_loss.item(), n=inputs_novel.size(0))

        _, predicted = logits_base.max(dim=1)
        correct = (predicted == target_remapped).sum().item()
        total = target_remapped.size(0)
        metrics["accuracy_metric"].update(correct, n=total, raw=True)

        return {
            "total_loss": total_loss.item(),
            "ce_loss": loss_ce.item(),
            "ce_accuracy": correct / target_remapped.size(0),
            "kl_loss": kl_loss.item(),
        }

    def debug_metrics_to_pbar_args(self, debug_metrics: Dict[str, float]) -> Dict[str, float]:
        """
        Prepares debug metrics for visualization in progress bars or logs.

        Args:
            debug_metrics (Dict[str, float]): Dictionary of debug metrics from the current step.

        Returns:
            Dict[str, float]: Same metrics for direct display.
        """
        return debug_metrics

    def training_step_return(self, metrics: Dict[str, AverageMeter]) -> list[float]:
        """
        Returns the average values of tracked metrics after a training step.

        Args:
            metrics (Dict[str, AverageMeter]): Metric dictionary to extract averages from.

        Returns:
            List[float]: Averages of total loss, accuracy, CE loss, and KL loss.
        """
        return [
            metrics["total_loss_metric"].avg,
            metrics["accuracy_metric"].avg,
            metrics["ce_loss_metric"].avg,
            metrics["kl_loss_metric"].avg,
        ]

    def update_lambda_kl(self, lambda_kl):
        self.lambda_kl = lambda_kl





#### KL (v1) Loss Calculation

In [None]:
# From: ./utils/kl.py
import clip
import torch

# from utils.datasets import CLASS_NAMES


def get_kl_loss(device, inputs_novel, model, targets_novel, tmp_dataset):
    """
    Computes the KL divergence between the student model's predictions and the CLIP model's predictions
    for a batch of novel class images.

    Args:
        device (torch.device): The device (CPU or CUDA) to perform computation on.
        inputs_novel (Tensor): A batch of input images from novel classes.
        model (nn.Module): The student model that includes a CLIP backbone and prompt learner.
        targets_novel (List[int]): Target labels corresponding to the novel class inputs.
        tmp_dataset (ContiguousLabelDataset): Dataset wrapper with label-to-category mappings.

    Returns:
        Tensor: A scalar tensor representing the KL divergence loss.
    """
    targets_novel_tensor = torch.tensor(targets_novel).to(device) if isinstance(targets_novel, list) else targets_novel
    # .tolist() gives Python ints, so set() deduplicates (0-d tensors hash by identity)
    categories_novel_tensor = [tmp_dataset.idx2cat[c] for c in sorted(set(targets_novel_tensor.tolist()))]
    # print(f"input novel shape: {inputs_novel.shape} novel base: {targets_novel_tensor.shape}")
    with torch.no_grad():
        image_features_clip = model.clip_model.encode_image(inputs_novel)
        image_features_clip = image_features_clip / image_features_clip.norm(dim=-1, keepdim=True)


        text_inputs = clip.tokenize(
            [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories_novel_tensor]
        ).to(device)

        text_features_clip = model.clip_model.encode_text(text_inputs)
        text_features_clip = text_features_clip / text_features_clip.norm(dim=-1, keepdim=True)

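        # Cosine similarities without CLIP's logit_scale, so this teacher distribution is softer
        # than CLIP's standard zero-shot softmax (which multiplies by logit_scale.exp())
        clip_logits = image_features_clip @ text_features_clip.T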

    remapped_class_names = [CLASS_NAMES[c] for c in categories_novel_tensor]

    # targets_novel_tensor holds the dataset's contiguous indices; remap them to 0..K-1 over the
    # classes present in this batch, in the same order as remapped_class_names and the teacher prompts
    target_remapping = {cat: idx for idx, cat in enumerate(categories_novel_tensor)}
    target_original = [tmp_dataset.idx2cat[c.item()] for c in targets_novel_tensor]
    target_remapped = torch.tensor([target_remapping[c] for c in target_original]).to(device)

    with model.temporary_classnames(remapped_class_names):
        model.train()
        student_logits, student_loss = model(inputs_novel, target_remapped)  # [B, num_classes]
        kl_loss = torch.nn.functional.kl_div(
            torch.nn.functional.log_softmax(student_logits, dim=-1),
            torch.nn.functional.softmax(clip_logits, dim=-1),
            reduction="batchmean"
        )
    return kl_loss

### KL Divergence with Rotating Class Splits (v2)

This variant extends the KLv1 method by introducing a two-stage training strategy that separates the dataset into two disjoint subsets and optimizes them sequentially. The goal is to better mimic a base/novel split and train CoCoOp with a curriculum-like structure using CLIP zero-shot outputs as guidance [^3].

#### Split-Based Training  
We partition the base training dataset into two subsets:

- **pseudo-base-klv2 ($\mathcal{D}_{\text{seen}}^{\text{train}}$):** Used for supervised learning with standard cross-entropy loss.
- **pseudo-novel-klv2 ($\mathcal{D}_{\text{unseen}}^{\text{train}}$):** Used for knowledge distillation via KL divergence from CLIP zero-shot predictions.

During each epoch, the optimizer first performs updates on pseudo-base-klv2 using only the cross-entropy loss. After completing this pass, the model is trained on pseudo-novel-klv2 using only the KL divergence loss:

$$
\mathcal{L}_{\text{CE-step}} = \mathcal{L}_{\text{CE}}, \quad \mathcal{L}_{\text{KL-step}} = \lambda_{\text{KL}} \cdot \mathcal{L}_{\text{KL}}
$$

We deliberately avoid combining CE and KL in a single backward pass, since keeping both computation graphs alive at the same time increases memory consumption. The separation also simplifies optimization, but because the two losses are never optimized jointly, the alternating updates may pull the shared parameters in different directions. Nevertheless, in practice this sequential training proved more stable under memory constraints.

#### Rotating Class Assignments  
To prevent the model from overfitting to a fixed base/novel class partition, we periodically reshuffle the class assignments across epochs using a cyclic rotation strategy. The base class list is shuffled and split using a predefined ratio (e.g., 70% pseudo-base, 30% pseudo-novel), and then rotated after each remixing step.

The number of rotation steps is calculated as:

$$
\texttt{rotation\_steps} = |\mathcal{C}_\text{unseen}| = \left\lfloor |\mathcal{C}_\text{base}| \cdot (1 - \texttt{pseudo\_base\_ratio}) \right\rfloor
$$

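Since the pseudo-novel split covers less than half of the base classes (e.g., 30%), rotating by this many positions ensures that after each rotation: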

$$
\mathcal{C}'_\text{unseen} \cap \mathcal{C}_\text{unseen} = \emptyset
$$

where $\mathcal{C}'_\text{unseen}$ are the new `pseudo-novel-klv2` classes after rotation.
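
The rotation can be checked with a short sketch based on `collections.deque`. It assumes, hypothetically, that the pseudo-novel split is the last `rotation_steps` entries of the shuffled class list and that the list is rotated to the right; the original implementation may order or rotate differently. As long as the pseudo-novel split is at most half of the classes, consecutive pseudo-novel sets are disjoint.

```python
import math
import random
from collections import deque


def rotating_splits(base_classes, pseudo_base_ratio=0.7, n_rounds=3, seed=0):
    """Yield (pseudo_base, pseudo_novel) class lists, shifting the pseudo-novel window each round."""
    classes = list(base_classes)
    random.Random(seed).shuffle(classes)
    # rotation_steps = |C_unseen| = floor(|C_base| * (1 - pseudo_base_ratio))
    rotation_steps = math.floor(len(classes) * (1 - pseudo_base_ratio))
    if rotation_steps == 0:
        raise ValueError("pseudo-novel split is empty; lower pseudo_base_ratio")
    ring = deque(classes)
    for _ in range(n_rounds):
        current = list(ring)
        # the last rotation_steps classes are pseudo-novel, the rest pseudo-base
        yield current[:-rotation_steps], current[-rotation_steps:]
        ring.rotate(rotation_steps)  # shift right: a fresh block of classes enters the novel window


for base, novel in rotating_splits(range(20), n_rounds=3):
    print(len(base), sorted(novel))
```

With 20 classes and a 0.7 ratio, each round has 14 pseudo-base and 6 pseudo-novel classes. Because `floor` is applied to a floating-point product, ratios that are not exactly representable can occasionally round down by one class.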

#### Intuition  
This strategy gradually exposes the model to different subsets of classes as pseudo-novel categories, encouraging it to generalize beyond any fixed training split. By rotating which classes are treated as "novel" across epochs, the model is repeatedly challenged to adapt to new class combinations. This setup mimics the conditions of generalization to unseen classes, which is the core goal of CoCoOp.

The KL divergence loss acts as a soft guidance signal from CLIP’s zero-shot predictions [^3], helping the model stay aligned with more generalizable representations. Meanwhile, the rotation strategy prevents overfitting to a specific partition, ensuring balanced coverage and making training more robust over time.

[^1]: Radford, A., et al. (2021). *Learning Transferable Visual Models From Natural Language Supervision*. [arXiv:2103.00020](https://arxiv.org/abs/2103.00020)


### KL (v2) - Implementation

**Key Aspects:** Dual dataloaders; controlled class remixing with rotation; label remapping for KL.  
**Constraints:** Increased implementation complexity; KL and CE losses optimized separately.

This version improves control over the pseudo-base and pseudo-novel composition by using two distinct dataloaders. One loader provides samples for cross-entropy loss, while the other supplies examples for KL divergence. A rotation-based mechanism (implemented via a `deque`) periodically shifts the set of classes assigned to the pseudo-novel split, ensuring coverage over time.

To keep the KL computation consistent, pseudo-novel labels are remapped to contiguous indices, and the corresponding CLIP logits are computed on the fly over the same class order. Since the CE and KL losses operate on independent batches, each is backpropagated in its own backward pass, followed by its own optimizer step. This setup allows better balancing, interpretability, and reproducibility at the cost of increased computational and memory overhead.
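
The label remapping can be sketched as below. `remap_to_contiguous`, the category ids, and `class_names` are hypothetical; the point is that the targets, the student's temporary class names, and the teacher prompts all follow the same class order, so column $i$ of both logit matrices refers to the same class.

```python
import torch


def remap_to_contiguous(original_labels, class_subset):
    """Map original category ids to 0..K-1 following the order of `class_subset` (hypothetical helper)."""
    cat2idx = {cat: i for i, cat in enumerate(class_subset)}
    return torch.tensor([cat2idx[int(c)] for c in original_labels], dtype=torch.long)


pseudo_novel = [42, 7, 88]                               # hypothetical original category ids
class_names = {7: "rose", 42: "tulip", 88: "daisy"}      # hypothetical id -> name mapping

labels = torch.tensor([88, 7, 7, 42])
targets = remap_to_contiguous(labels, pseudo_novel)
print(targets)  # tensor([2, 1, 1, 0])

# Teacher prompts built in the same order, so prompt i matches target index i
prompts = [f"a photo of a {class_names[c]}, a type of flower." for c in pseudo_novel]
```

A label outside the current pseudo-novel split raises a `KeyError`, which surfaces a split/dataset mismatch instead of silently producing wrong targets.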

---

### Training Procedure (repeat until convergence):

1. **Phase 1: Standard CoCoOp training on seen batch**

   Sample batch:  
   $$
   (\mathbf{x}_{\text{bb}}, \mathbf{y}_{\text{bb}}) \sim \mathcal{D}_{\text{seen}}^{\text{train}}
   $$

   Apply the base CoCoOp training procedure (see previous section) on $(\mathbf{x}_{\text{bb}}, \mathbf{y}_{\text{bb}})$.

2. **Phase 2: KL divergence on unseen batch**

   Sample batch:  
   $$
   (\mathbf{x}_{\text{bn}}, \mathbf{y}_{\text{bn}}) \sim \mathcal{D}_{\text{unseen}}^{\text{train}}
   $$

   Compute frozen CLIP teacher logits:  
   $$
   y_{\text{teacher}} = \text{CLIP}(\mathbf{x}_{\text{bn}})
   $$

   Compute CoCoOp student logits:  
   $$
   y_{\text{student}} = \text{CoCoOp}_{\theta, p_1, \ldots, p_M}(\mathbf{x}_{\text{bn}})
   $$

   Compute KL divergence loss:  
   $$
   \mathcal{L}_{\text{KL}} = D_{\text{KL}} \left( y_{\text{teacher}} \| y_{\text{student}} \right)
   $$

   Apply weighting factor:  
   $$
   \mathcal{L}_{\text{KL}} \leftarrow \lambda_{\text{KL}} \cdot \mathcal{L}_{\text{KL}}
   $$

   Update parameters $\theta$ and context vectors $p_1, \ldots, p_M$ via backpropagation on $\mathcal{L}_{\text{KL}}$.
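
A toy sketch of the two-phase loop follows, again with hypothetical linear heads standing in for CoCoOp and the frozen CLIP teacher. Each phase performs its own backward pass and optimizer step, so only one computation graph is alive at a time. In this toy the two phases update disjoint rows of `student_text`, so it does not reproduce the interference between objectives that can arise when, as in CoCoOp, both phases update the same context vectors and meta-network.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
feat_dim = 8
seen_classes = torch.tensor([0, 1, 2, 3])  # pseudo-base-klv2
unseen_classes = torch.tensor([4, 5])      # pseudo-novel-klv2

# Hypothetical stand-ins for the prompt-conditioned student head and the frozen teacher head
student_text = nn.Parameter(torch.randn(6, feat_dim))
teacher_text = torch.randn(6, feat_dim)
optimizer = torch.optim.SGD([student_text], lr=0.1)
lambda_kl = 0.5


def phase1_ce(x, y_contig):
    # Phase 1: CE on a pseudo-base batch; y_contig indexes into seen_classes
    logits = x @ student_text[seen_classes].T
    loss = F.cross_entropy(logits, y_contig)
    optimizer.zero_grad()
    loss.backward()  # this graph is freed before phase 2 builds its own
    optimizer.step()
    return loss.item()


def phase2_kl(x):
    # Phase 2: weighted KL(teacher || student) on a pseudo-novel batch
    with torch.no_grad():
        teacher_logits = x @ teacher_text[unseen_classes].T
    student_logits = x @ student_text[unseen_classes].T
    loss = lambda_kl * F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


kl_history = []
x_novel = torch.randn(16, feat_dim)  # fixed pseudo-novel batch, to make the trend visible
for _ in range(50):
    phase1_ce(torch.randn(16, feat_dim), torch.randint(0, 4, (16,)))
    kl_history.append(phase2_kl(x_novel))
```

Over the iterations the weighted KL on the fixed pseudo-novel batch decreases, showing the student head moving toward the teacher's distribution.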


In [None]:
# From: ./training_systems/training_methods/KLCoCoOpV2.py
"""
This module defines the KLCoCoOpV2 training method, which alternates cross-entropy updates on a pseudo-base
dataloader with KL-divergence updates (against a frozen CLIP teacher) on a separate pseudo-novel dataloader.
"""
import random
from typing import Dict, Any, Optional

import clip
import torch
from torch.utils.data import DataLoader, Dataset

# from training_systems.core import DoubleDatasetTrainingMethod

# from utils import AverageMeter,CLASS_NAMES, ContiguousLabelDataset, get_kl_loss


class KLCoCoOpV2(DoubleDatasetTrainingMethod):
    """
    KLCoCoOpV2 training method with two dataloaders: cross-entropy classification on pseudo-base samples and
    KL divergence loss on pseudo-novel samples, optimized in separate backward passes.

    Attributes:
        lambda_kl (float): Weight for the KL divergence loss component.
    """

    def __init__(
            self,
            model: Any,
            optimizer: Any,
            lambda_kl,
            debug: bool = False,
    ) -> None:
        super().__init__(model, optimizer, "Base CoCoOp + KL", debug)
        self.lambda_kl = lambda_kl
        if self.debug:
            print(f"[KLCoCoOpV2] Initialized with lambda_kl={self.lambda_kl}, device={self.device}")


    def get_metrics(self) -> Dict[str, AverageMeter]:
        """
        Initializes training metrics including total, cross-entropy, KL divergence losses, and classification accuracy.

        Returns:
            Dict[str, AverageMeter]: Dictionary of metric names mapped to their respective AverageMeter instances.
        """
        if self.debug:
            print("[KLCoCoOpV2] Initializing metrics.")
        return {
            "ce_loss_metric": AverageMeter(),
            "kl_loss_metric": AverageMeter(),
            "ce_accuracy_metric": AverageMeter(),
        }

    def get_data_loader1(self, pseudo_base: ContiguousLabelDataset, batch_size: int) -> DataLoader:
        """
        Returns a DataLoader over the pseudo-base split, used for the cross-entropy phase.

        Args:
            pseudo_base (ContiguousLabelDataset): Pseudo-base training samples with contiguous labels.
            batch_size (int): Number of samples per batch.

        Returns:
            DataLoader: Shuffled PyTorch DataLoader over the pseudo-base samples.
        """
        if self.debug:
            print(f"[KLCoCoOpV2] Creating DataLoader1 for pseudo_base with batch_size={batch_size}.")
        return DataLoader(
            pseudo_base,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
        )

    def get_data_loader2(self, pseudo_novel_dataset: ContiguousLabelDataset, batch_size: int) -> DataLoader:
        """
        Returns a shuffled DataLoader over the pseudo-novel split, used for the KL-divergence phase.
        """
        if self.debug:
            print(f"[KLCoCoOpV2] Creating DataLoader2 for pseudo_novel with batch_size={batch_size}.")
        return DataLoader(pseudo_novel_dataset, batch_size=batch_size, shuffle=True, num_workers=2)


    def forward_backward1(
            self,
            sample,
            batch_idx,
            metrics: Dict[str, AverageMeter],
            dataset: ContiguousLabelDataset,
            classes: list[int]
    ) -> Dict[str, float]:
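        """
        Cross-entropy phase: forward and backward pass on a pseudo-base batch, followed by an optimizer step.

        Args:
            sample (Tuple[Tensor, Tensor]): Images and contiguous labels from the pseudo-base loader.
            batch_idx (int): Current batch index (debug output is printed for the first batches only).
            metrics (Dict[str, AverageMeter]): Metrics to update.
            dataset (ContiguousLabelDataset): Pseudo-base dataset, used to map labels back to class names.
            classes (list[int]): Original category indices of the pseudo-base split.

        Returns:
            Dict[str, float]: CE loss and accuracy on this batch.
        """
        if self.debug and batch_idx < 3: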
            print(f"[KLCoCoOpV2] forward_backward1: batch_idx={batch_idx}")
            print(f"[KLCoCoOpV2] Sample type: {type(sample)}")
        # Load data into GPU
        inputs, targets = sample
        # === Pseudo-base: cross-entropy ===
        inputs_base = inputs.to(self.device)
        targets_base = targets.to(self.device)
        # Use remapped class names to match the remapped labels
        # classes contains original category indices, but we need to map them to 0-based indices
        # The ContiguousLabelDataset remaps labels to 0-based indices, so we need class names in the same order
        # Use the remapped indices to get class names in the correct order
        remapped_class_names = [ CLASS_NAMES[ dataset.idx2cat[i] ] for i in range(len(dataset.idx2cat)) ]
        assert set(classes) == set(dataset.idx2cat.values()), \
            f"Split classes ({classes}) != dataset labels ({list(dataset.idx2cat.values())})"
        self.model.train()
        with self.model.temporary_classnames(remapped_class_names):
            logits_base, loss_ce = self.model(inputs_base, targets_base)
            # === Optional debug output, then backward pass on the CE loss ===
            if self.debug and batch_idx < 3:
                print(f"[KLCoCoOpV2] LOGITS shape: {logits_base.shape}, TARGETS shape: {targets_base.shape}")
                print(f"[KLCoCoOpV2] CE loss: {loss_ce.item()}")
                print(f"[KLCoCoOpV2] Remapped class names: {remapped_class_names}")
                print(f"[KLCoCoOpV2] Original classes: {classes}")
                print(f"[KLCoCoOpV2] Targets: {targets_base}")
                print(f"[KLCoCoOpV2] Logits max: {logits_base.max(dim=1)[1]}")
                print(f"[KLCoCoOpV2] Dataset cat2idx: {dataset.cat2idx}")
                print(f"[KLCoCoOpV2] Dataset idx2cat: {dataset.idx2cat}")
                print(f"[KLCoCoOpV2] Expected class names order: {[CLASS_NAMES[dataset.idx2cat[i]] for i in range(len(classes))]}")
                print(f"[KLCoCoOpV2] Model class names: {[CLASS_NAMES[c] for c in classes]}")
            loss_ce.backward()

            self.optimizer_step()
            batch_size_total = inputs_base.size(0)

            metrics["ce_loss_metric"].update(loss_ce.item(), n=batch_size_total)

            _, predicted = logits_base.max(dim=1)
            correct = (predicted == targets_base).sum().item()
            total = targets_base.size(0)
            metrics["ce_accuracy_metric"].update(correct, n=total, raw=True)

            if self.debug and batch_idx < 3:
                print(f"[KLCoCoOpV2] Batch accuracy: {correct}/{total} = {correct/total if total > 0 else 0}")
        return {
            "ce_loss": loss_ce.item(),
            "accuracy": correct / targets_base.size(0),
        }

    def forward_backward2(
            self,
            sample,
            batch_idx,
            metrics: Dict[str, AverageMeter],
            dataset: ContiguousLabelDataset,
            classes: list[int]
    ) -> Dict[str, float]:
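        """
        KL phase: distills frozen CLIP zero-shot predictions into the model on a pseudo-novel batch,
        followed by an optimizer step.

        Args:
            sample (Tuple[Tensor, Tensor]): Images and contiguous labels from the pseudo-novel loader.
            batch_idx (int): Current batch index (debug output is printed for the first batches only).
            metrics (Dict[str, AverageMeter]): Metrics to update.
            dataset (ContiguousLabelDataset): Pseudo-novel dataset, used to map labels back to class names.
            classes (list[int]): Original category indices of the pseudo-novel split.

        Returns:
            Dict[str, float]: Weighted KL loss on this batch.
        """
        # Class names ordered by contiguous label index (0..K-1), matching the dataset's remapped labels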
        remapped_class_names = [ CLASS_NAMES[ dataset.idx2cat[i] ] for i in range(len(dataset.idx2cat)) ]
        pseudo_novel_class_names = remapped_class_names
        if self.debug and batch_idx < 3:
            print(f"[KLCoCoOpV2] forward_backward2: batch_idx={batch_idx}")
            print(f"[KLCoCoOpV2] Sample type: {type(sample)}; Sample len: {len(sample) if hasattr(sample, '__len__') else 'N/A'}")
        # Load data into GPU
        inputs_novel, targets_novel = sample
        # === Pseudo-novel: KL divergence with frozen CLIP ===
        inputs_novel = inputs_novel.to(self.device)
        targets_novel = targets_novel.to(self.device)

        with torch.no_grad():
            image_features_clip = self.model.clip_model.encode_image(inputs_novel)
            image_features_clip = image_features_clip / image_features_clip.norm(dim=-1, keepdim=True)

            #category_idxs = [dataset.idx2cat[c.item()] for c in list(set(targets_novel))] # type: ignore

            text_inputs = clip.tokenize(
                [f"a photo of a {cn}, a type of flower." for cn in pseudo_novel_class_names]
            ).to(self.device)

            text_features_clip = self.model.clip_model.encode_text(text_inputs)
            text_features_clip = text_features_clip / text_features_clip.norm(dim=-1, keepdim=True)

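            # Cosine similarities without CLIP's logit_scale, so this teacher distribution is softer
            # than CLIP's standard zero-shot softmax (which multiplies by logit_scale.exp())
            clip_logits = image_features_clip @ text_features_clip.T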


        self.model.train()
        with self.model.temporary_classnames(remapped_class_names):
            student_logits, student_loss = self.model(inputs_novel, targets_novel)
            kl_loss = torch.nn.functional.kl_div(
                torch.nn.functional.log_softmax(student_logits, dim=-1),
                torch.nn.functional.softmax(clip_logits, dim=-1),
                reduction="batchmean"
            )

            # === Weight the KL loss and backpropagate ===
            kl_loss = self.lambda_kl * kl_loss

            if self.debug and batch_idx < 3:
                print(f"[KLCoCoOpV2] KL loss (weighted): {kl_loss.item()}")
            kl_loss.backward()

            self.optimizer_step()

            metrics["kl_loss_metric"].update(kl_loss.item(), n=inputs_novel.size(0))

        return {
            "kl_loss": kl_loss.item(),
        }

    def debug_metrics_to_pbar_args1(self, debug_metrics: Dict[str, float]) -> Dict[str, float]:
        """
        Prepares debug metrics for visualization in progress bars or logs.

        Args:
            debug_metrics (Dict[str, float]): Dictionary of debug metrics from the current step.

        Returns:
            Dict[str, float]: Same metrics for direct display.
        """
        if self.debug:
            print(f"[KLCoCoOpV2] debug_metrics_to_pbar_args1: {debug_metrics}")
        return debug_metrics

    def debug_metrics_to_pbar_args2(self, debug_metrics: Dict[str, float]) -> Dict[str, float]:
        """
        Prepares debug metrics for visualization in progress bars or logs.
        """
        if self.debug:
            print(f"[KLCoCoOpV2] debug_metrics_to_pbar_args2: {debug_metrics}")
        return debug_metrics

    def training_step_return(self, metrics: Dict[str, AverageMeter]) -> list[float]:
        """
        Returns the average values of tracked metrics after a training step.

        Args:
            metrics (Dict[str, AverageMeter]): Metric dictionary to extract averages from.

        Returns:
            List[float]: Averages of KL loss, CE loss, and CE accuracy, in that order.
        """
        if self.debug:
            print(f"[KLCoCoOpV2] training_step_return: KL={metrics['kl_loss_metric'].avg}, CE={metrics['ce_loss_metric'].avg}, Acc={metrics['ce_accuracy_metric'].avg}")
        return [
            metrics["kl_loss_metric"].avg,
            metrics["ce_loss_metric"].avg,
            metrics["ce_accuracy_metric"].avg,
        ]

    def update_lambda_kl(self, lambda_kl):
        """
        Update the lambda_kl value used for KL loss weighting.

        Args:
            lambda_kl (float): New lambda_kl value.
        """
        self.lambda_kl = lambda_kl
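To make the KL term concrete, here is a minimal, dependency-free sketch of what `torch.nn.functional.kl_div(log_softmax(student), softmax(teacher), reduction="batchmean")` computes for a batch of logit rows. The helper names `log_softmax_row` and `kl_batchmean` are hypothetical and are not part of the project code. The sketch assumes that the frozen CLIP logits play the teacher role, as in the step above.

```python
import math

def log_softmax_row(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def kl_batchmean(student_logits, teacher_logits):
    """Sum over classes of p_t * (log p_t - log p_s), averaged over the batch."""
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        log_q = log_softmax_row(s_row)  # student: the log-probability `input`
        log_p = log_softmax_row(t_row)  # teacher: log of the probability `target`
        total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    # reduction="batchmean" divides the summed divergence by the batch size
    return total / len(student_logits)

# Two-class check: uniform student, teacher probabilities (0.25, 0.75)
print(kl_batchmean([[0.0, 0.0]], [[0.0, math.log(3.0)]]))  # about 0.1308
```

`kl_div` takes the student as log-probabilities and the teacher as probabilities, so the term is KL(teacher ‖ student). It penalizes the student for assigning low probability to classes where frozen CLIP is confident.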






## Evaluation Methods

### Eval Step

In [None]:
# From: ./training_systems/evaluation_methods/EvalStep.py

from typing import Dict
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import torch

# from training_systems.core import EvaluationMethod
# from utils import ContiguousLabelDataset, CLASS_NAMES, AverageMeter


class EvalStep(EvaluationMethod):
    """
    Generic evaluation step for models that support temporary class name modification.
    """
    @torch.no_grad()
    def evaluate(self, dataset, classnames, desc_add="") -> Dict[str, float]:
        """
        Evaluate model performance on the provided dataset.

        Args:
            dataset: Dataset for evaluation.
            classnames (list[int]): Original category indices to evaluate on; labels are remapped
                to a contiguous range and the matching class names are applied temporarily.
            desc_add (str): Suffix to append to tqdm description.

        Returns:
            Dict[str, float]: Dictionary containing average loss and accuracy.
        """
        self.model.eval()
        loss_meter = AverageMeter()
        accuracy_meter = AverageMeter()
        tmp_dataset = ContiguousLabelDataset(dataset, classnames)
        dataloader = DataLoader(tmp_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1)

        remapped_classnames = [CLASS_NAMES[tmp_dataset.idx2cat[i]] for i in range(len(tmp_dataset.idx2cat))]
        with self.model.temporary_classnames(remapped_classnames):
            self.walk(loss_meter, accuracy_meter, dataloader, desc_add)

        return {"loss": loss_meter.avg, "accuracy": accuracy_meter.avg}

    @torch.no_grad()
    def walk(self, loss_meter, accuracy_meter, dataloader, desc_add=""):
        """
        Perform the evaluation loop over the dataset.

        Args:
            loss_meter: Tracks average loss.
            accuracy_meter: Tracks average accuracy.
            dataloader: DataLoader to iterate over.
            desc_add (str): Additional string to append to tqdm description.
        """
        for images, targets in tqdm(dataloader, desc="Validation" + desc_add, position=1, leave=False):
            images = images.to(self.device)
            targets = targets.to(self.device)
            logits = self.model(images)
            loss = F.cross_entropy(logits, targets)
            predictions = logits.argmax(dim=-1)
            correct = (predictions == targets).sum().item()
            loss_meter.update(loss.item(), n=targets.size(0))
            accuracy_meter.update(correct, n=targets.size(0), raw=True)
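`EvalStep.evaluate` relies on `self.model.temporary_classnames(...)`, and this excerpt does not show its implementation. The sketch below illustrates one way the pattern can work. It assumes the model stores its class names in a `classnames` attribute, and the `PromptModelStub` class is hypothetical. The pattern itself is a `contextlib.contextmanager` that swaps the names in and restores them in a `finally` block:

```python
from contextlib import contextmanager

class PromptModelStub:
    """Hypothetical stand-in for a prompt-learning model that keeps a list of class names."""

    def __init__(self, classnames):
        self.classnames = list(classnames)

    @contextmanager
    def temporary_classnames(self, new_classnames):
        # Swap in the new names for the duration of the `with` block
        original = self.classnames
        self.classnames = list(new_classnames)
        try:
            yield self
        finally:
            # Restore the original names even if the block raised
            self.classnames = original

model = PromptModelStub(["pink primrose", "sweet pea"])
with model.temporary_classnames(["sunflower", "petunia", "lotus"]):
    print(model.classnames)  # ['sunflower', 'petunia', 'lotus']
print(model.classnames)      # ['pink primrose', 'sweet pea']
```

The `finally` clause is the key design choice. The original names are restored even if evaluation fails midway, so a crash during novel-class testing cannot leave the model configured with the wrong label set. The real model presumably also rebuilds its text prompts when the names change; the stub omits that step.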

### Zero-Shot Test Step + Finetuned Test Step

In [None]:
# From: ./training_systems/evaluation_methods/TestSteps.py
from typing import Dict

import clip
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import torch

# from training_systems.core import EvaluationMethod
# from utils import ContiguousLabelDataset, CLASS_NAMES, AverageMeter


class ZeroShotTestStep(EvaluationMethod):
    """
    Zero-shot evaluation of the frozen base CLIP model using a fixed hand-crafted prompt template.
    """

    def __init__(self, model, batch_size, categories):
        super().__init__(model, batch_size)
        self.categories = categories
        # here we apply the standard CLIP template used for oxford flowers to all categories
        text_inputs = clip.tokenize(
            [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in self.categories]
        ).to(self.device)
        with torch.no_grad():
            # we can encode the text features once as they are shared for all images
            # therefore we do it outside the evaluation loop
            self.text_features = self.model.encode_text(text_inputs)
            # and here we normalize them (standard practice with CLIP)
            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)

    @torch.no_grad()
    def evaluate(self, dataset, desc_add="") -> Dict[str, float]:
        self.model.eval()
        accuracy_meter = AverageMeter()
        # Remap labels into a contiguous set starting from zero
        tmp_dataset = ContiguousLabelDataset(dataset, self.categories)
        dataloader = DataLoader(tmp_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)
        self.walk(dataloader, accuracy_meter, desc_add)

        return {"accuracy": accuracy_meter.avg}

    @torch.no_grad()
    def walk(self, dataloader, accuracy_meter, desc_add):
        pbar = tqdm(total=len(dataloader), desc="Test (Zero Shots) " + desc_add, position=1, leave=False)
        for image, target in dataloader:
            # Labels were already remapped to [0, len(categories) - 1] by ContiguousLabelDataset,
            # so they line up with the rows of self.text_features

            image = image.to(self.device)
            target = target.to(self.device)

            # Forward images through the CLIP image encoder and L2-normalize
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # Cosine similarity between each image and every class prompt; keep the argmax per row
            predicted_class = (image_features @ self.text_features.T).argmax(dim=-1)

            # Count correct predictions (True == 1, False == 0)
            correct = (predicted_class == target).sum().item()
            accuracy_meter.update(correct, n=target.size(0), raw=True)
            pbar.set_postfix({
                "accuracy": accuracy_meter.avg
            })
            pbar.update(1)
        pbar.close()


class FineTunedTestStep(EvaluationMethod):
    """
    Evaluation method for fine-tuned prompt-learning models (e.g., CoCoOp or adversarial models).
    """

    @torch.no_grad()
    def evaluate(self, dataset, classnames: list[int], desc_add="") -> Dict[str, float]:
        self.model.eval()
        accuracy_meter = AverageMeter()
        tmp_dataset = ContiguousLabelDataset(dataset, classnames)
        dataloader = DataLoader(tmp_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)
        remapped_class_names = [CLASS_NAMES[tmp_dataset.idx2cat[i]] for i in range(len(tmp_dataset.idx2cat))]
        with self.model.temporary_classnames(remapped_class_names):
            pbar = tqdm(total=len(dataloader), desc="Test (Finetuned) " + desc_add, position=1, leave=False)
            for images, targets in dataloader:
                images = images.to(self.device)
                targets = targets.to(self.device)
                logits = self.model(images)
                predictions = logits.argmax(dim=-1)
                correct = (predictions == targets).sum().item()
                accuracy_meter.update(correct, n=targets.size(0), raw=True)
                pbar.set_postfix({
                    "accuracy": accuracy_meter.avg
                })
                pbar.update(1)
            pbar.close()

        return {"accuracy": accuracy_meter.avg}
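The core of `ZeroShotTestStep` is cosine-similarity classification. The pure-Python sketch below uses hypothetical helper names and plain lists instead of tensors. It L2-normalizes both sides and takes the per-image argmax over class prompts, mirroring `(image_features @ self.text_features.T).argmax(dim=-1)`:

```python
import math

def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def zero_shot_predict(image_features, text_features):
    """For each image, return the index of the class prompt with the highest cosine similarity."""
    texts = [l2_normalize(t) for t in text_features]
    preds = []
    for img in image_features:
        img = l2_normalize(img)
        sims = [sum(a * b for a, b in zip(img, t)) for t in texts]
        preds.append(max(range(len(sims)), key=sims.__getitem__))
    return preds

preds = zero_shot_predict(
    image_features=[[2.0, 0.1], [0.3, 5.0], [10.0, 9.0]],
    text_features=[[1.0, 0.0], [0.0, 1.0]],
)
targets = [0, 1, 1]
accuracy = sum(p == t for p, t in zip(preds, targets)) / len(targets)
print(preds, accuracy)  # preds == [0, 1, 0]; accuracy == 2/3

# Text normalization changes the ranking: the raw dot product would pick class 0 here
print(zero_shot_predict([[1.0, 2.0]], [[10.0, 0.0], [0.0, 1.0]]))  # [1]
```

Normalizing the image features does not change the per-row argmax, because every row is scaled by one positive constant. Normalizing the text features does change it: without that step, a prompt embedding with a larger norm could win regardless of its direction.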

# Utils & Generics

## Dataset

The dataset used for this project is **Oxford 102 Flowers**, a fine-grained classification dataset consisting of 102 flower categories. The dataset was predefined by the course instructors as a common benchmark for all groups participating in the project.

To simulate a zero-shot generalization setting, the dataset is split into two disjoint subsets:

- **Base classes** ($\mathcal{C}_\text{base}$): 51 classes used during training.  
- **Novel classes** ($\mathcal{C}_\text{novel}$): 51 classes held out for evaluation only.

Following the standard few-shot protocol, we use only **10 shots per class** for training and validation. No external or unlabeled data is used beyond this few-shot setup.

In certain methods, we simulate exposure to novel-like conditions by creating **pseudo-novel** ($\mathcal{D}_\text{unseen}$) subsets within the base classes, split dynamically or statically depending on the method (e.g., KLv2). The pseudo-novel ratio determines how many base classes are treated as novel at a time.

### Data Augmentation

To promote generalization to novel classes and reduce overfitting to spurious image components such as background or object shape, we apply data augmentation during training. Specifically, we replicate the augmentation pipeline used in the original CoCoOp implementation [^3], as defined in its configuration files.

The augmentation includes a random resized crop with bicubic interpolation, random horizontal flip, and normalization using CLIP-specific mean and standard deviation values:

```yaml
INPUT:
  SIZE: (224, 224)
  INTERPOLATION: "bicubic"
  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
  TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
```
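As a quick sanity check on the `PIXEL_MEAN`/`PIXEL_STD` values, the sketch below applies the same per-channel arithmetic as `T.ToTensor()` followed by `T.Normalize(mean, std)`. First, 8-bit intensities are scaled to [0, 1]; then each channel becomes (x - mean) / std. The function names are hypothetical, and the constants are copied from the configuration above.

```python
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

def to_unit_range(rgb_uint8):
    """Scale 8-bit RGB intensities (0..255) to floats in [0, 1], as ToTensor does."""
    return tuple(v / 255.0 for v in rgb_uint8)

def normalize_pixel(rgb):
    """Per-channel (x - mean) / std, as Normalize does for every pixel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, CLIP_MEAN, CLIP_STD))

# A pure black pixel lands roughly 1.5 to 1.8 standard deviations below the CLIP mean
print(normalize_pixel(to_unit_range((0, 0, 0))))
```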


### *Seen* and *Unseen* Split

#### Motivation

As part of the zero-shot learning framework, it is critical to respect the constraint that **no data from the novel classes** ($\mathcal{D}_\text{novel}$) is used in any form of training or model selection. This includes:

- Updating model parameters using novel data.
- Using novel data for validation, early stopping, or loss evaluation during training.

Allowing access to $\mathcal{D}_\text{novel}$ during training would compromise the core assumption of zero-shot learning: generalization to unseen categories without direct supervision.

**However**, in this work, we occasionally report performance on *val_novel* for the purpose of monitoring how different techniques affect generalization to novel classes. This evaluation is performed strictly *post hoc*, with no model updates or checkpoint selection based on novel data. This use is consistent with a case-study setup, where understanding the behavior of novel-class performance under various constraints is central to the investigation.

#### Split

Regardless of the training pipeline, at the beginning of training we split the base categories into $b$ seen and $\texttt{total categories} - b$ unseen categories, where $\texttt{total categories} = |\mathcal{C}_\text{base}|$ and:

$$
b = \texttt{total categories} \cdot \texttt{seen ratio}
$$

with $0.7 < \texttt{seen ratio} < 0.9$, so that most base classes remain available for supervised training.
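Below is a minimal sketch of this class-level split. It assumes a seeded random shuffle and floor rounding of $b$; the document does not fix either choice, and `split_seen_unseen` is a hypothetical helper. With the 51 base classes and a seen ratio of 0.8, $b = 51 \cdot 0.8 = 40.8$, which floors to 40 seen and 11 unseen classes:

```python
import random

def split_seen_unseen(base_classes, seen_ratio=0.8, seed=0):
    """Partition base classes into (seen, unseen) lists.

    Assumptions: b = floor(len(base_classes) * seen_ratio), and the
    partition is drawn with a seeded shuffle so it is reproducible.
    """
    b = int(len(base_classes) * seen_ratio)  # int() floors positive values
    shuffled = list(base_classes)
    random.Random(seed).shuffle(shuffled)
    return sorted(shuffled[:b]), sorted(shuffled[b:])

seen, unseen = split_seen_unseen(list(range(51)), seen_ratio=0.8)
print(len(seen), len(unseen))  # 40 11
```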

The formal structure is defined as:

$$
\mathcal{C}_\text{base} = \mathcal{C}_\text{seen} \cup \mathcal{C}_\text{unseen}, \qquad \mathcal{C}_\text{seen} \cap \mathcal{C}_\text{unseen} = \emptyset
$$

$$
\mathcal{D}_\text{base} = \{ \mathbf{x}_i \mid y_i \in \mathcal{C}_\text{base} \} = \mathcal{D}_\text{seen} \cup \mathcal{D}_\text{unseen}
$$

$$
\mathcal{D}_\text{seen} = \{ \mathbf{x}_i \mid y_i \in \mathcal{C}_\text{seen} \}
$$

$$
\mathcal{D}_\text{unseen} = \{ \mathbf{x}_i \mid y_i \in \mathcal{C}_\text{unseen} \}
$$

$$
\mathcal{D}_\text{seen} \cap \mathcal{D}_\text{unseen} = \emptyset
$$
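At the sample level, the partition only needs each sample's label. The sketch below uses a hypothetical `partition_samples` helper over plain lists of labels. It returns the index sets for $\mathcal{D}_\text{seen}$ and $\mathcal{D}_\text{unseen}$. Samples whose label is outside $\mathcal{C}_\text{base}$ (that is, novel classes) land in neither set:

```python
def partition_samples(labels, seen_classes, unseen_classes):
    """Return (seen_indices, unseen_indices) for a list of per-sample labels."""
    seen, unseen = set(seen_classes), set(unseen_classes)
    if seen & unseen:
        raise ValueError("C_seen and C_unseen must be disjoint")
    seen_idx = [i for i, y in enumerate(labels) if y in seen]
    unseen_idx = [i for i, y in enumerate(labels) if y in unseen]
    return seen_idx, unseen_idx

# Toy labels: classes 0-3 are base (0, 1 seen; 2, 3 unseen), class 60 is novel
labels = [0, 1, 2, 60, 3, 1]
d_seen, d_unseen = partition_samples(labels, seen_classes=[0, 1], unseen_classes=[2, 3])
print(d_seen, d_unseen)  # [0, 1, 5] [2, 4]
```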

The roles of $\mathcal{D}_\text{seen}$ and $\mathcal{D}_\text{unseen}$ vary depending on the training method employed. Below, we summarize how each method leverages the data split:

- **Base CoCoOp**:  
  - $\mathcal{D}_\text{seen}$ is used to update model parameters via cross-entropy loss.  
  - $\mathcal{D}_\text{unseen}$ is used only for validation ("novel" accuracy tracking).  
  - No knowledge distillation or gradients from $\mathcal{D}_\text{unseen}$ flow into the model.

- **CoCoOp + KL (v1)**:  
  - A batch from $\mathcal{D}_\text{seen}$ is split into pseudo-base/pseudo-novel samples.  
  - All samples contribute to both CE and KL loss within the batch.  
  - $\mathcal{D}_\text{unseen}$ is only used for validation.

- **CoCoOp + KL (v2)**:  
  - $\mathcal{D}_\text{seen}$ is used for standard CoCoOp training with cross-entropy.  
  - $\mathcal{D}_\text{unseen}$ is used to compute KL divergence from frozen CLIP logits.  
  - $\mathcal{D}_\text{seen}$ and $\mathcal{D}_\text{unseen}$ are rotated epoch-by-epoch (see [Section](#subsec:KLV2design)).  
  - Both sets contribute to training updates.

- **Adversarial Training**:  
  - The full dataset $\mathcal{D}_\text{base} = \mathcal{D}_\text{seen} \cup \mathcal{D}_\text{unseen}$ is used.  
  - All samples contribute via cross-entropy and adversarial losses.

---

**Summary:**  
$\mathcal{D}_\text{unseen}$ is used *only for validation* in **Base CoCoOp** and **KL(v1)**. It becomes active in **KL(v2)** through KL regularization and is fully used in **Adversarial Training**.  
In code, the terms `pseudo_base` and `pseudo_novel` correspond to $\mathcal{C}_\text{seen}$ and $\mathcal{C}_\text{unseen}$ respectively.



### ContiguousLabelDataset + Dataset Utilities

In [None]:
# From: ./utils/datasets.py
import clip
import torch
import torchvision
from torch.utils.data import Subset
import torchvision.transforms as T

CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

class ContiguousLabelDataset(torch.utils.data.Dataset):
    """
    A dataset wrapper that remaps arbitrary class labels to contiguous integers starting from 0.

    This is useful for classification tasks where models expect class indices to be in a 0-based contiguous range.

    Attributes:
        dataset (Dataset): The original dataset to wrap.
        cat2idx (Dict[Any, int]): Mapping from original class labels to contiguous integer indices.
        idx2cat (Dict[int, Any]): Reverse mapping from indices back to original class labels.
    """
    def __init__(self, dataset, class_order: list[int]):
        self.dataset = dataset
        # Build the mapping in the order given by class_order (index i <-> class_order[i])
        self.cat2idx = { cat: idx for idx, cat in enumerate(class_order) }
        self.idx2cat = { idx: cat for cat, idx in self.cat2idx.items() }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]
        mapped_label = self.cat2idx[label]
        return image, mapped_label


# --- Data Augmentation Pipeline ---
def build_default_transform(resolution=224):
    """
    Builds the default data augmentation pipeline as specified:
    - RandomResizedCrop with bicubic interpolation
    - RandomHorizontalFlip
    - ToTensor, then Normalize with the CLIP mean and std
    """
    return T.Compose([
        T.RandomResizedCrop(resolution, interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
    ])


def build_eval_transform(resolution=224):
    """
    Builds the evaluation transform pipeline (no random augmentations):
    - Resize to resolution
    - CenterCrop to resolution
    - ToTensor, then Normalize with the CLIP mean and std
    """
    return T.Compose([
        T.Resize(resolution, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(resolution),
        T.ToTensor(),
        T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
    ])


def get_data(data_dir="./data", train_transform=None, eval_transform=None, resolution=224):
    """
    Loads the Flowers102 dataset from torchvision, returning separate splits for training, validation, and testing.

    Args:
        data_dir (str): Directory where the dataset will be downloaded/stored. Defaults to "./data".
        train_transform (torchvision.transforms.Compose or None): Transformations to apply to training data.
        eval_transform (torchvision.transforms.Compose or None): Transformations to apply to validation/test data.
        resolution (int): Image resolution for transforms (default 224).

    Returns:
        tuple: A tuple (train, val, test) of Flowers102 dataset splits.
    """
    if train_transform is None:
        train_transform = build_default_transform(resolution)
    if eval_transform is None:
        eval_transform = build_eval_transform(resolution)

    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=train_transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=eval_transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=eval_transform)
    return train, val, test


def base_novel_categories(dataset):
    # set returns the unique set of all dataset classes
    all_classes = set(dataset._labels)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes


def get_labels(dataset):
    """
    Recursively retrieve labels from dataset or nested Subset.
    Assumes the base dataset has a `_labels` attribute.
    """
    if hasattr(dataset, '_labels'):
        return dataset._labels
    elif isinstance(dataset, Subset):
        parent_labels = get_labels(dataset.dataset)
        return [parent_labels[i] for i in dataset.indices]
    else:
        raise AttributeError("Dataset does not have _labels or is not a Subset of a dataset with _labels.")

def split_data(dataset, base_classes):
    """
    Splits the dataset into base and novel subsets based on base_classes.
    Works even if the input dataset is already a Subset.

    Args:
        dataset: PyTorch Dataset or Subset
        base_classes (List[int]): List of class indices considered as base.

    Returns:
        base_dataset (Subset): Subset containing samples from base classes.
        novel_dataset (Subset): Subset containing samples from novel classes.
    """
    base_categories_samples = []
    novel_categories_samples = []

    labels = get_labels(dataset)
    base_set = set(base_classes)

    for sample_id, label in enumerate(labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = Subset(dataset, base_categories_samples)
    novel_dataset = Subset(dataset, novel_categories_samples)

    return base_dataset, novel_dataset
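The evaluation code builds `remapped_classnames` from `idx2cat`. To see why, here is a torch-free sketch of the same two steps. The first is the half/half base-novel rule of `base_novel_categories`, applied to the 102 Flowers classes. The second is the `cat2idx`/`idx2cat` mapping built by `ContiguousLabelDataset`. The helper names are hypothetical, and the placeholder class names stand in for `CLASS_NAMES`:

```python
def base_novel_split(num_classes):
    # First half of the class indices is base, second half is novel
    classes = list(range(num_classes))
    return classes[: num_classes // 2], classes[num_classes // 2 :]

def contiguous_maps(class_order):
    # Same construction as ContiguousLabelDataset.cat2idx / idx2cat
    cat2idx = {cat: idx for idx, cat in enumerate(class_order)}
    idx2cat = {idx: cat for cat, idx in cat2idx.items()}
    return cat2idx, idx2cat

base, novel = base_novel_split(102)
cat2idx, idx2cat = contiguous_maps(novel)

# Placeholder names standing in for CLASS_NAMES
names = [f"class_{c}" for c in range(102)]
remapped = [names[idx2cat[i]] for i in range(len(idx2cat))]

print(cat2idx[51], cat2idx[101])  # 0 50
print(remapped[0], remapped[-1])  # class_51 class_101
```

Under `temporary_classnames(remapped)`, logit column `i` therefore corresponds to original class `idx2cat[i]`. That is exactly the remapped target label, so `argmax` predictions can be compared with targets directly.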


## Metrics

To evaluate training dynamics and final performance, we track a comprehensive set of metrics using TensorBoard. These metrics are recorded for both training and validation phases and are logged for every configuration run.

#### Loss Curves  
For each method, we log the evolution of the main losses on both training and validation sets:

- **Cross-Entropy Loss**  
- **KL Divergence Loss** (when applicable, e.g., KLv1 or KLv2)  
- **Adversarial Loss** (when adversarial training with the MLP discriminator is active)

#### Accuracy Tracking  
We monitor accuracy on both training and validation datasets, according to the structure of the current method:

- **Validation Accuracy:** Includes base, novel, pseudo-base, and pseudo-novel subsets depending on the method (e.g., pseudo-novel for KLv2, base-only for CE-only baselines).  
- **Training Accuracy:** Captured for the subset used in each training phase (e.g., pseudo-base for CE step).

#### Final Evaluation  
At the end of training, we report test accuracy on:

- **Base Classes:** Accuracy on the 51 base classes used during training.  
- **Novel Classes:** Accuracy on the 51 novel classes, never used for training, to assess generalization.  
- **Harmonic Mean:** The harmonic mean of base and novel accuracy is computed to summarize the trade-off between performance on seen and unseen classes:

$$
H = \frac{2 \cdot \text{Acc}_\text{base} \cdot \text{Acc}_\text{novel}}{\text{Acc}_\text{base} + \text{Acc}_\text{novel}}
$$
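Here is a short worked example. It uses a zero-safe variant of the `harmonic_mean` helper defined in the logging utilities; `harmonic_mean_safe` is a hypothetical name. The accuracy values are illustrative inputs, not results from this project:

```python
def harmonic_mean_safe(base_acc, novel_acc):
    """H = 2 * base * novel / (base + novel), returning 0.0 when both are zero."""
    if base_acc + novel_acc == 0:
        return 0.0
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)

# Balanced case: H stays close to the arithmetic mean (0.80)
print(harmonic_mean_safe(0.9, 0.7))  # about 0.7875
# Unbalanced case: H collapses toward the weaker score (arithmetic mean 0.50)
print(harmonic_mean_safe(0.9, 0.1))  # about 0.18
```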

#### Experiment Configuration  
All experiment settings, including hyperparameters such as learning rate, loss weights, number of epochs, architectural dimensions, and optimizer choices, are stored and logged from the YAML configuration file associated with each run. This allows for consistent tracking and comparison across experiments.
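Whether the settings come from YAML or from a Python definition (as in this notebook), nested sections have to be flattened before they are logged. `SummaryWriter.add_hparams` is documented to accept a dictionary whose values are `bool`, `str`, `int`, `float`, or `None`. Below is a minimal sketch that flattens a nested configuration into dotted keys. `flatten_config` is a hypothetical helper, and the configuration keys are made up for illustration:

```python
def flatten_config(config, prefix=""):
    """Flatten nested dicts into {"section.key": value} with TensorBoard-friendly values."""
    flat = {}
    for key, value in config.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into nested sections, extending the dotted prefix
            flat.update(flatten_config(value, prefix=f"{name}."))
        elif value is None or isinstance(value, (bool, int, float, str)):
            flat[name] = value
        else:
            # Lists and other objects are stored as their string form
            flat[name] = str(value)
    return flat

# Hypothetical configuration keys, for illustration only
config = {
    "epochs": 10,
    "optimizer": {"name": "sgd", "lr": 0.002},
    "input": {"transforms": ["random_resized_crop", "random_flip", "normalize"]},
}
print(flatten_config(config))
```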

These metrics provide both detailed insights into learning progress and a clear summary of final generalization performance across different configurations.


In [None]:
# From: ./utils/metrics.py
from collections import defaultdict

import torch


class AverageMeter:
    """Compute and store the average and current value.

    Examples::
        >>> # 1. Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # 2. Update meter after every mini-batch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self, ema=False):
        """
        Args:
            ema (bool, optional): apply exponential moving average.
        """
        self.ema = ema
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1, raw=False):
        if isinstance(val, torch.Tensor):
            val = val.item()

        self.val = val
        if raw:
            self.sum += val
        else:
            self.sum += val * n

        self.count += n

        if self.ema:
            self.avg = self.avg * 0.9 + self.val * 0.1
        else:
            self.avg = self.sum / self.count


class MetricMeter:
    """Store the average and current value for a set of metrics.

    Examples::
        >>> # 1. Create an instance of MetricMeter
        >>> metric = MetricMeter()
        >>> # 2. Update using a dictionary as input
        >>> input_dict = {'loss_1': value_1, 'loss_2': value_2}
        >>> metric.update(input_dict)
        >>> # 3. Convert to string and print
        >>> print(str(metric))
    """

    def __init__(self, delimiter=" "):
        self.meters = defaultdict(AverageMeter)
        self.delimiter = delimiter

    def update(self, input_dict):
        if input_dict is None:
            return

        if not isinstance(input_dict, dict):
            raise TypeError(
                "Input to MetricMeter.update() must be a dictionary"
            )

        for k, v in input_dict.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            self.meters[k].update(v)

    def __str__(self):
        output_str = []
        for name, meter in self.meters.items():
            output_str.append(f"{name} {meter.val:.4f} ({meter.avg:.4f})")
        return self.delimiter.join(output_str)


#### Tensorboard Logger + CSV Logger

In [None]:
# From: ./utils/tensor_board_logger.py
import os

from torch.utils.tensorboard.writer import SummaryWriter


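def harmonic_mean(a, b):
    # Return 0 instead of raising ZeroDivisionError when both accuracies are 0
    if a + b == 0:
        return 0.0
    return 2 * (a * b) / (a + b)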


class CSVLogger:
    """
    A simple logger that writes training metrics to a CSV file.

    Attributes:
        filename (str): The path to the output CSV file.
        file (file object): Open file handle for writing.
    """

    def __init__(self, filename):
        """
        Initializes the CSV logger and creates the output directory and file.

        Args:
            filename (str): Path to the CSV file for logging.
        """
        self.filename = filename
        dirname = os.path.dirname(filename)
        if dirname:  # os.makedirs("") raises, e.g. for a bare "metrics.csv"
            os.makedirs(dirname, exist_ok=True)  # Create folder if it doesn't exist
        self.file = open(filename, "w")
        self.file.write("epoch,base_acc,novel_acc,harmonic_mean\n")

    def log(self, epoch, base_acc, novel_acc):
        """
        Logs a training epoch's results including harmonic mean to the CSV file.

        Args:
            epoch (int): The epoch number.
            base_acc (float): Accuracy on base classes.
            novel_acc (float): Accuracy on novel classes.
        """
        hm = harmonic_mean(base_acc, novel_acc)
        self.file.write(f"{epoch},{base_acc},{novel_acc},{hm}\n")
        self.file.flush()
        print(
            f"Logged: epoch {epoch}, base_acc {base_acc}, novel_acc {novel_acc}, harmonic_mean {hm}"
        )

    def close(self):
        """
        Closes the CSV file.
        """
        self.file.close()


class BaseAndNovelMetrics:
    """
    Tracks base and novel accuracy values and their harmonic mean across epochs.

    Attributes:
        tmp (List[Tuple[int, float, float, float]]): Logged metrics per epoch.
    """

    def __init__(self):
        self.tmp = []

    def update(self, epoch, base_acc, novel_acc):
        """
        Records a new set of metrics for a specific epoch.

        Args:
            epoch (int): Epoch number.
            base_acc (float): Accuracy on base classes.
            novel_acc (float): Accuracy on novel classes.
        """
        self.tmp.append(
            (epoch, base_acc, novel_acc, harmonic_mean(base_acc, novel_acc))
        )

    def get_metrics(self):
        """
        Retrieves the collected metrics.

        Returns:
            List[Tuple[int, float, float, float]] or None: Logged metrics or None if empty.
        """
        if len(self.tmp) == 0:
            return None
        return self.tmp


class TensorboardLogger:
    """
    Handles logging of training and evaluation metrics to TensorBoard and CSV.

    Attributes:
        writer (SummaryWriter): TensorBoard writer instance.
        csv_logger (CSVLogger): CSV logger instance for persistent metric storage.
        hparams (dict): Hyperparameters dictionary.
        base_and_novel_metrics (BaseAndNovelMetrics): Metric tracker.
    """

    def __init__(self, writer: SummaryWriter):
        """
        Initializes the TensorboardLogger.

        Args:
            writer (SummaryWriter): TensorBoard writer instance.
        """
        self.writer = writer
        self.csv_logger = CSVLogger(f"{writer.log_dir}/metrics.csv")
        self.hparams = None
        self.base_and_novel_metrics = BaseAndNovelMetrics()

    def log_hparams(self, hparams: dict):
        """
        Stores hyperparameters for later logging.

        Args:
            hparams (dict): Hyperparameters dictionary.
        """
        self.hparams = hparams

    def log_training_base(self, epoch, lr, ce_loss, acc, kl_loss, total_loss):
        """
        Logs training metrics for the base training phase.

        Args:
            epoch (int): Current epoch.
            lr (float): Learning rate.
            ce_loss (float): Cross entropy loss.
            acc (float): Accuracy.
            kl_loss (float or None): KL divergence loss, optional.
            total_loss (float): Total loss.
        """
        self.writer.add_scalar("learning_rate", lr, epoch)
        self.writer.add_scalar("train_base/ce_loss", ce_loss, epoch)
        self.writer.add_scalar("train_base/ce_accuracy", acc, epoch)
        if kl_loss is not None:
            self.writer.add_scalar("train_base/kl_loss", kl_loss, epoch)
        self.writer.add_scalar("train_base/total_loss", total_loss, epoch)

    def log_training_adv(
        self, epoch, lambda_adv, ce_loss, acc, adv_loss, total_loss, kl_loss=None
    ):
        """
        Logs training metrics for the adversarial training phase.

        Args:
            epoch (int): Current epoch.
            lambda_adv (float): Adversarial loss weight.
            ce_loss (float): Cross entropy loss.
            acc (float): Accuracy.
            adv_loss (float): Adversarial loss.
            total_loss (float): Total loss.
            kl_loss (float or None): KL divergence loss, optional.
        """
        self.writer.add_scalar("lambda_adv", lambda_adv, epoch)
        self.writer.add_scalar("train_adv/ce_loss", ce_loss, epoch)
        self.writer.add_scalar("train_adv/ce_accuracy", acc, epoch)
        self.writer.add_scalar("train_adv/mlp_loss", adv_loss, epoch)
        if kl_loss is not None:
            self.writer.add_scalar("train_adv/kl_loss", kl_loss, epoch)
        self.writer.add_scalar("train_adv/total_loss", total_loss, epoch)

    def log_validation(
        self, epoch, base_loss, base_acc, novel_loss, novel_acc, is_adv=False
    ):
        """
        Logs validation metrics.

        Args:
            epoch (int): Current epoch.
            base_loss (float): Loss on base classes.
            base_acc (float): Accuracy on base classes.
            novel_loss (float): Loss on novel classes.
            novel_acc (float): Accuracy on novel classes.
            is_adv (bool): Whether validation is adversarial.
        """
        self.writer.add_scalar("validation_base/loss", base_loss, epoch)
        self.writer.add_scalar("validation_base/accuracy", base_acc, epoch)
        self.writer.add_scalar("validation_novel/loss", novel_loss, epoch)
        self.writer.add_scalar("validation_novel/accuracy", novel_acc, epoch)

        prefix = "validation_adv" if is_adv else "validation_ce"
        self.writer.add_scalar(f"{prefix}_base/loss", base_loss, epoch)
        self.writer.add_scalar(f"{prefix}_base/accuracy", base_acc, epoch)
        self.writer.add_scalar(f"{prefix}_novel/loss", novel_loss, epoch)
        self.writer.add_scalar(f"{prefix}_novel/accuracy", novel_acc, epoch)

    def log_final_metrics(self, tag, base_acc, novel_acc, step):
        """
        Logs final accuracy metrics and updates CSV and metric tracker.

        Args:
            tag (str): Tag name for TensorBoard.
            base_acc (float): Accuracy on base classes.
            novel_acc (float): Accuracy on novel classes.
            step (int): Training step or epoch index.
        """
        harmonic = harmonic_mean(base_acc, novel_acc)
        self.writer.add_scalars(
            tag,
            {
                "Harmonic Mean": harmonic,
                "Base Accuracy": base_acc,
                "Novel Accuracy": novel_acc,
            },
            global_step=step + 1,
        )

        self.csv_logger.log(step + 1, base_acc, novel_acc)

        if self.hparams is not None:
            self.base_and_novel_metrics.update(step + 1, base_acc, novel_acc)

        self.writer.flush()

    def log_test_accuracy(self, step, acc, label):
        """
        Logs test accuracy for a given label.

        Args:
            step (int): Step or epoch index.
            acc (float): Accuracy value.
            label (str): Label name for the accuracy metric.
        """
        self.writer.add_scalar(f"{label}/accuracy", acc, step)

    def close(self):
        """
        Closes the logger, writes hyperparameters and final metrics if available.
        """
        metrics = self.base_and_novel_metrics.get_metrics() or []

        """
        metric_dict = {
            "base_acc_after_base": metrics[0][1] if metrics else 0,
            "novel_acc_after_base": metrics[0][2] if metrics else 0,
            "harmonic_mean_after_base": metrics[0][3] if metrics else 0,
            "base_acc_after_adv": metrics[1][1] if metrics else 0,
            "novel_acc_after_adv": metrics[1][2] if metrics else 0,
            "harmonic_mean_after_adv": metrics[1][3] if metrics else 0,
        }"""

        if self.hparams is not None and metrics:

            tmp = {}
            # Set the prefix based on whether metrics are from base phase (index 0) or adversarial phase (index 1)
            for idx, m in enumerate(metrics):
                prefix = "after_base" if idx == 0 else "after_adv"
                tmp[f"epoch_{prefix}"] = m[0]
                tmp[f"base_acc_{prefix}"] = m[1]
                tmp[f"novel_acc_{prefix}"] = m[2]
                tmp[f"harmonic_mean_{prefix}"] = m[3]

            self.writer.add_hparams(
                hparam_dict=self.hparams,
                metric_dict=tmp,
            )
        self.writer.close()
        self.csv_logger.close()

## Training CoOp

In [None]:
# From: ./utils/training_coop.py
"""
This module provides training, evaluation, and testing functions for the CoOp model and standard CLIP evaluation,
including fine-tuning and zero-shot classification on image datasets.
"""
from clip.model import CLIP
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import clip

# from model.coop.custom_clip import CustomCLIPCoOp
# from utils.datasets import ContiguousLabelDataset, CLASS_NAMES


@torch.no_grad()
def eval_step(model, dataset, cost_function, new_classnames, batch_size=32, device="cuda"):
    """
    Evaluates the model on a given dataset using cross-entropy loss.

    Args:
        model (nn.Module): The model to evaluate.
        dataset (Dataset): Dataset to evaluate on.
        cost_function (Callable): Loss function to use.
        batch_size (int): Batch size for evaluation.
        device (str): Computation device ("cuda" or "cpu").
        new_classnames (List[int] or None): Optional list of class indices to temporarily substitute for evaluation.

    Returns:
        Tuple[float, float]: Average loss and accuracy.
    """
    model.eval()

    tmp_dataset = ContiguousLabelDataset(dataset, new_classnames)
    dataloader = DataLoader(tmp_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    total_loss = 0.0
    correct = 0
    total = 0
    if new_classnames is not None:
        new_classnames = [CLASS_NAMES[c] for c in new_classnames]
        with model.temporary_classnames(new_classnames):
            correct, total, total_loss = walk_the_dataset(correct, cost_function, dataloader, device, model, total,
                                                          total_loss)

    else:
        correct, total, total_loss = walk_the_dataset(correct, cost_function, dataloader, device, model, total,
                                                      total_loss)
    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy


def walk_the_dataset(correct, cost_function, dataloader, device, model, total, total_loss):
    """
    Iterates over the dataset and computes cumulative loss and accuracy.

    Args:
        correct (int): Running count of correct predictions.
        cost_function (Callable): Loss function used.
        dataloader (DataLoader): DataLoader for the dataset.
        device (str): Computation device.
        model (nn.Module): Model being evaluated.
        total (int): Total number of samples evaluated so far.
        total_loss (float): Accumulated loss value.

    Returns:
        Tuple[int, int, float]: Updated correct, total, and total_loss.
    """
    for images, targets in tqdm(dataloader, desc="Validation", position=1, leave=False):
        images = images.to(device)
        targets = targets.to(device)

        loss, logits = model(images, targets)

        total_loss += loss.item() * targets.size(0)
        predictions = logits.argmax(dim=-1)
        correct += (predictions == targets).sum().item()
        total += targets.size(0)
    return correct, total, total_loss


def training_step(model: CustomCLIPCoOp, dataset, optimizer, batch_size, classnames, device="cuda"):
    """
    Performs one full training epoch for the CoOp model.

    Args:
        model (CustomCLIPCoOp): The model to train.
        dataset (Dataset): Dataset to train on.
        optimizer (Optimizer): Optimizer used for updating model parameters.
        batch_size (int): Batch size for training.
        classnames (List[int]): Category indices used to remap labels into a contiguous range.
        device (str): Computation device.

    Returns:
        Tuple[float, float]: Average training loss and accuracy.
    """
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    # Set the network to training mode

    model.train()

    tmp_dataset = ContiguousLabelDataset(dataset, classnames)
    dataloader = DataLoader(tmp_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    pbar = tqdm(dataloader, desc="Training", position=1, leave=False)
    # Iterating over pbar advances the progress bar automatically
    for inputs, targets in pbar:
        # Load data into GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass + loss computation
        loss, logits = model(inputs, targets)

        if torch.isnan(loss):
            print("⚠️ NaN loss encountered!")
            print("Targets:", targets)

        # Backward pass
        loss.backward()

        # Parameters update
        optimizer.step()

        # Gradients reset
        optimizer.zero_grad()

        # Fetch prediction and loss value
        samples += inputs.shape[0]
        cumulative_loss += loss.item() * inputs.shape[0]
        _, predicted = logits.max(dim=1)  # max() returns (maximum_value, index_of_maximum_value)

        # Compute training accuracy
        cumulative_accuracy += predicted.eq(targets).sum().item()

        pbar.set_postfix(train_loss=loss.item(), train_acc=cumulative_accuracy / samples)

    return cumulative_loss / samples, cumulative_accuracy / samples


@torch.no_grad()
def test_step(model, dataset, batch_size, device, categories, label="test", base=False):
    """
    Evaluates the model using either fine-tuned or base (zero-shot) strategy.

    Args:
        model (nn.Module): The model to test.
        dataset (Dataset): Dataset for testing.
        batch_size (int): Batch size for testing.
        device (str): Device used for computation.
        categories (List[int]): Category indices to evaluate on.
        label (str): Label for the progress bar.
        base (bool): Whether to use zero-shot CLIP instead of the fine-tuned model.

    Returns:
        float: Accuracy score.
    """
    if not base:
        return finetuned_test_step(model, dataset, batch_size, device, categories, label)
    else:
        return base_test_step(model, dataset, categories, batch_size, device, label)


@torch.no_grad()
def finetuned_test_step(model: CustomCLIPCoOp, dataset, batch_size, device, categories, label="test"):
    """
    Evaluates a fine-tuned CustomCLIPCoOp model on the given dataset.

    Args:
        model (CustomCLIPCoOp): Fine-tuned CoOp model.
        dataset (Dataset): Dataset for evaluation.
        batch_size (int): Batch size.
        device (str): Computation device.
        categories (List[int]): Category indices used to remap labels into a contiguous range.
        label (str): Label for progress display.

    Returns:
        float: Accuracy of the model on the dataset.
    """
    model.eval()

    tmp_dataset = ContiguousLabelDataset(dataset, categories)
    dataloader = DataLoader(tmp_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    correct = 0
    total = 0

    for images, targets in tqdm(dataloader, desc=label):
        images = images.to(device)
        targets = targets.to(device)

        logits = model(images)
        predictions = logits.argmax(dim=-1)

        correct += (predictions == targets).sum().item()
        total += targets.size(0)

    accuracy = correct / total
    return accuracy


@torch.no_grad()  # we don't want gradients
def base_test_step(model: CLIP, dataset, categories, batch_size, device, label=""):
    """
    Evaluates a zero-shot CLIP model using cosine similarity between image and text embeddings.

    Args:
        model (CLIP): Pretrained CLIP model.
        dataset (Dataset): Dataset to evaluate.
        categories (List[int]): List of category indices to evaluate.
        batch_size (int): Batch size for evaluation.
        device (str): Computation device.
        label (str): Optional label for progress bar.

    Returns:
        float: Accuracy of zero-shot CLIP classification.
    """
    # let's set the model in evaluation mode
    model.eval()

    # Remap labels into a contiguous set starting from zero
    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

    # Apply the standard CLIP prompt template for Oxford Flowers to all categories
    # and tokenize each sentence (convert natural language into token ids; print text_inputs to inspect them)
    text_inputs = clip.tokenize(
        [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]
    ).to(device)

    # we can encode the text features once as they are shared for all images
    # therefore we do it outside the evaluation loop
    text_features = model.encode_text(text_inputs)
    # and here we normalize them (standard practice with CLIP)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # simple dataloader creation
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # here we store the number of correct predictions we will make
    correct_predictions = 0
    for image, target in tqdm(dataloader, desc=label):
        # Base categories range from 0 to 50 and novel ones from 51 to 101,
        # so targets must be remapped to a contiguous range starting from zero
        # (e.g. [0, 50] for base classes), otherwise predictions would be compared against the wrong labels.
        # Labels need to be of dtype long in PyTorch.
        target = torch.tensor([contig_cat2idx[t.item()] for t in target], dtype=torch.long)

        image = image.to(device)
        target = target.to(device)

        # forward image through CLIP image encoder
        image_features = model.encode_image(image)
        # and normalize
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Cosine similarity between the normalized image and text features; keep the argmax for every row (every image)
        predicted_class = (image_features @ text_features.T).argmax(dim=-1)
        # now we check which are correct, and sum them (False == 0, True == 1)
        correct_predictions += (predicted_class == target).sum().item()

    # and now we compute the accuracy
    accuracy = correct_predictions / len(dataset)
    return accuracy

# Main

### Main (Disabled for Jupyter Execution)

In [None]:
# From: ./main.py
import yaml
import argparse

# from training_systems.cocoop import CoCoOpSystem
# from training_systems.coop import CoOpSystem
import torch
import os
from datetime import datetime
import pickle
from collections import Counter


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default=os.getenv("DEVICE", "cuda:0"))
    parser.add_argument('--run_name', default=f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    parser.add_argument('--using_coop', default=False, type=lambda x: x.lower() in ("1", "true", "yes"))
    parser.add_argument('--config', default="train_config.yaml")
    parser.add_argument('--debug', default=True, type=lambda x: x.lower() in ("1", "true", "yes"))
    return parser.parse_args()

####
#Disabled for jupyter notebook execution
####

# if __name__ == "__main__":
#     # Parse command-line arguments
#     # Assign parsed arguments to variables
#     # Display which device is being used
#     # Handle MPS backend by setting default tensor type to float32
#     # Indicate whether CoOp or CoCoOp is used for training
#     # Load training configuration from YAML file

#     # Initialize and train using CoOpSystem if specified in arguments
#     # Initialize and train using CoCoOpSystem otherwise
#     args = parse_args()

#     device = args.device
#     run_name = args.run_name
#     debug = args.debug
#     use_coop = args.using_coop

#     print(f"Using device: {device}")

#     if torch.backends.mps.is_available():
#         print("\u26a0\ufe0f Forcing float32 due to MPS limitations")
#         torch.set_default_dtype(torch.float32)

#     print(f"Using {'CoOp' if use_coop else 'CoCoOp'} for training")

#     # Load hyperparameters from YAML
#     with open(args.config, "r") as file:
#         config = yaml.safe_load(file)

#     if use_coop:
#         coop_cfg = config['coop']
#         train_sys = CoOpSystem(
#             device=device,
#             run_name=run_name,
#             **coop_cfg
#         )
#     else:
#         cocoop_cfg = config['cocoop']
#         train_sys = CoCoOpSystem(
#             device=device,
#             run_name=run_name,
#             debug=debug,
#             hparams_file=args.config,
#             **cocoop_cfg
#         )

#     train_sys.train()

### Jupyter Main

In [None]:
# New main cell for Jupyter execution

if __name__ == "__main__":

    runs_yaml = ["test_kl+adv"]

    # Configuration converted from the YAML file into a list of dictionaries (one per run)

    runs_config = [{
        "coop": {
            "batch_size": 10,
            "learning_rate": 0.002,
            "weight_decay": 0.0001,
            "momentum": 0.9,
            "epochs": 20,
            "n_ctx": 8,
            "ctx_init": "",
            "class_token_position": "end",
            "csc": False
        },
        "cocoop": {
            "report": True,
            "pat": False,
            "cnn_model": "ViT-B/16",
            "test_batch_size": 10,
            "train_base_checkpoint_path": None,  # External files are not available in this notebook
            "optimizer_configs": [
                {
                    "prompt_lr": 0.002,
                    "weight_decay": 0.0001,
                    "momentum": 0.9
                },
                {
                    "prompt_lr": 0.002,
                    "mlp_lr": 0.005,
                    "weight_decay": 0.0005,
                    "momentum": 0.8
                }
            ],
            "skip_tests": [True, True, False],
            "prompt_learner_opt": {
                "n_ctx": 4,
                "ctx_init": "",
                "class_token_position": "end",
                "csc": False
            },
            "kl_loss_opt": {
                "lambda_kl": [0.1, 0.1],
                "using_kl": [True, False]
            },
            "adv_training_opt": {
                "adv_training_epochs": 1,
                "batch_size": 10,
                "warmup_lambda_adv": 1,
                "lambda_adv": 3,
                "grl_lambda": 1,
                "mlp_opt": {
                    "hidden_structure": [563, 256, 128]
                },
                "prompt_learner_warmup_epochs": 3,
                "ignore_no_improvement": True,
                "use_bias_ctx": True
            },
            "base_training_opt": {
                "epochs": 1,
                "batch_size": 10,
                "warmup_epoch": 0,
                "warmup_cons_lr": 1e-5
            },
            "clustering_opt": {
                "n_clusters": 4,
                "variance": 0.95,
                "vision_encoder": "ViT-B/32",
                "clustering_type": "semantic"
            }
        }
    }]

    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
    debug = False
    use_coop = False

    for run_id, run_yaml in enumerate(runs_yaml):
        run_name = "jupyter_notebook_"+ run_yaml+ "_" + datetime.now().strftime("%Y%m%d_%H%M%S")

        print(f"Using device: {device}")

        if torch.backends.mps.is_available():
            print("\u26a0\ufe0f Forcing float32 due to MPS limitations")
            torch.set_default_dtype(torch.float32)

        print(f"Using {'CoOp' if use_coop else 'CoCoOp'} for training")

        config = runs_config[run_id] if isinstance(runs_config, list) else runs_config

        if use_coop:
            coop_cfg = config['coop']
            train_sys = CoOpSystem(
                device=device,
                run_name=run_name,
                **coop_cfg
            )
        else:
            cocoop_cfg = config['cocoop']
            train_sys = CoCoOpSystem(
                device=device,
                run_name=run_name,
                debug=debug,
                hparams_file=run_yaml,
                **cocoop_cfg
            )
        print(f"\nTraining system initialized with run name: {train_sys.run_name}")
        train_sys.train()

# Charts and results

## Main results table

| **Method**                 | **KL V1** (λ_kl) | **KL V2** (λ_kl) | **Adversarial** | **n_ctx** | **Base** | **Harmonic Mean** | **Novel** | **Notes**                                                                 |
|---------------------------|------------------|------------------|------------------|----------|---------|-------------------|----------|------------------------------------------------------------------------------|
| CLIP zero-shot            |                  |                  |                  | 4        | 72.08   | 74.83             | **77.80** | Reported in the CoCoOp paper (Zhou et al., 2022)                           |
| CoCoOp                    |                  |                  |                  | 4        | **94.87** | **81.71**         | 71.75    | Reported in the CoCoOp paper (Zhou et al., 2022)                           |
|                           | ✅               |                  |                  | 4        | 63.69   | 67.62             | 72.06    | KLv1 baseline, 4 context tokens.                                           |
|                           |                  | ✅               |                  | 4        | 75.66   | 75.33             | 75.00    | KLv2 with λ_kl = 0.3, rotation period of 3 epochs.                         |
|                           |                  |                  | ✅               | 4        | **87.26** | 75.54             | 66.59    | Adversarial-only training with 4 clusters.                                 |
|                           |                  | ✅               |                  | 4        | 77.23   | 75.68             | 74.18    | KLv2 with λ_kl = 0.1, rotation every 4 epochs.                             |
|                           |                  | ✅               |                  | 8        | 75.62   | 73.83             | 72.12    | KLv2 with 8 context tokens.                                                |
|                           |                  | ✅               | ✅               | 4        | 79.01   | 77.20             | **75.46** | Adversary input: `ctx + bias` (`ctx_shifted`).                             |
|                           |                  | ✅               | ✅               | 4        | 77.80   | 76.20             | 74.67    | Adversary input: `selected_text_feature + ctx_shifted.mean()`.             |
| Custom CoCoOp             |                  | ✅               | ✅               | 4        | 85.32   | **79.18**         | 73.86    | Adversary input: `avg_text_features + masked logits`.                      |
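
The Harmonic Mean column can be recomputed from the Base and Novel columns as H = 2·B·N / (B + N). The function below is a minimal stand-alone re-implementation for illustration; the notebook's own `harmonic_mean` helper (used by the logger) is not shown in this section, so this is not that helper verbatim. The rows are copied from the table above.

```python
def harmonic_mean(base_acc: float, novel_acc: float) -> float:
    """Harmonic mean of base and novel accuracy; returns 0 when both are 0."""
    total = base_acc + novel_acc
    if total == 0:
        return 0.0
    return 2 * base_acc * novel_acc / total


# (base, novel, reported harmonic mean) taken from the main results table
rows = {
    "CLIP zero-shot": (72.08, 77.80, 74.83),
    "CoCoOp": (94.87, 71.75, 81.71),
    "Custom CoCoOp": (85.32, 73.86, 79.18),
}

for name, (base, novel, reported) in rows.items():
    print(f"{name}: computed {harmonic_mean(base, novel):.2f}, reported {reported:.2f}")
```

Because the harmonic mean is dominated by the smaller of the two accuracies, a large base gain (as in CoCoOp) cannot fully compensate for a drop in novel accuracy.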

[2]: Yaroslav Ganin and Victor Lempitsky. *Unsupervised Domain Adaptation by Backpropagation*. 2015. arXiv: [1409.7495](https://arxiv.org/abs/1409.7495) [stat.ML].

## Graph Plotting
### Cloning the repository is required: run the cell below

In [None]:
# Clone the project repository to access the experiment logs needed to plot the graphs
!git clone https://github.com/bacobax/DeepL-project /content/deepl_project
import sys
sys.path.insert(0, '/content/deepl_project')
%cd /content/deepl_project
# %pip install -r /content/deepl_project/req_clean.txt    
!git pull

In [None]:
import os
import math
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def plot_tensorboard_scalars_grid(log_dirs: dict, scalar_names: list, max_cols: int = 2):
    """
    Plots TensorBoard scalar values from multiple runs in a grid layout.

    Args:
        log_dirs (dict): Mapping from label name (str) to log directory path (str).
        scalar_names (list): List of scalar tag names to plot (str).
        max_cols (int): Maximum number of columns in the grid layout.
    """
    num_scalars = len(scalar_names)
    cols = min(max_cols, num_scalars)
    rows = math.ceil(num_scalars / cols)

    fig, axs = plt.subplots(rows, cols, figsize=(6 * cols, 4 * rows))
    axs = axs.flatten() if num_scalars > 1 else [axs]

    for i, scalar_name in enumerate(scalar_names):
        ax = axs[i]
        found = False

        for label, log_dir in log_dirs.items():
            if not os.path.exists(log_dir):
                print(f"[WARN] Directory not found: {log_dir}. Skipping '{label}'.")
                continue

            event_acc = EventAccumulator(log_dir)
            try:
                event_acc.Reload()
            except Exception as e:
                print(f"[ERROR] Failed to load {log_dir}: {e}")
                continue

            if scalar_name not in event_acc.Tags().get('scalars', []):
                print(f"[INFO] Scalar '{scalar_name}' not found in '{label}'.")
                continue

            events = event_acc.Scalars(scalar_name)
            steps = [e.step for e in events]
            values = [e.value for e in events]

            ax.plot(steps, values, label=label)
            found = True

        ax.set_title(scalar_name)
        ax.set_xlabel("Step")
        ax.set_ylabel("Value")
        ax.grid(True)
        if found:
            ax.legend()
        else:
            ax.text(0.5, 0.5, "Scalar not found", ha='center', va='center')
            ax.set_axis_off()

    # Turn off any unused axes
    for j in range(len(scalar_names), len(axs)):
        fig.delaxes(axs[j])

    fig.suptitle("TensorBoard Scalars Comparison", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def plot_novel_base_hmean_bar_chart(log_dirs: dict):
    """
    Creates a bar chart comparing 'novel_classes/accuracy', 'base_classes/accuracy',
    and their harmonic mean for multiple TensorBoard runs.

    Args:
        log_dirs (dict): Mapping from label name (str) to log directory path (str).
    """
    novel_values = []
    base_values = []
    hmean_values = []
    run_labels = []

    for label, log_dir in log_dirs.items():
        if not os.path.exists(log_dir):
            print(f"[WARN] Directory not found: {log_dir}. Skipping '{label}'.")
            continue

        event_acc = EventAccumulator(log_dir)
        try:
            event_acc.Reload()
        except Exception as e:
            print(f"[ERROR] Failed to load {log_dir}: {e}")
            continue

        tags = event_acc.Tags().get('scalars', [])
        if 'novel_classes/accuracy' not in tags or 'base_classes/accuracy' not in tags:
            print(f"[INFO] Missing scalars in '{label}'. Skipping.")
            continue

        novel_events = event_acc.Scalars('novel_classes/accuracy')
        base_events = event_acc.Scalars('base_classes/accuracy')

        novel = novel_events[-1].value if novel_events else 0
        base = base_events[-1].value if base_events else 0

        # Harmonic mean: 2 * (base * novel) / (base + novel)
        hmean = 2 * base * novel / (base + novel) if (base + novel) != 0 else 0

        novel_values.append(novel)
        base_values.append(base)
        hmean_values.append(hmean)
        run_labels.append(label)

    # Plotting
    x = np.arange(len(run_labels))
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(x - width, base_values, width, label='Base Accuracy')
    ax.bar(x,         novel_values, width, label='Novel Accuracy')
    ax.bar(x + width, hmean_values, width, label='Harmonic Mean')

    ax.set_ylabel('Accuracy')
    ax.set_title('Base / Novel / Harmonic Mean Accuracy per Run')
    ax.set_xticks(x)
    ax.set_xticklabels(run_labels, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, axis='y')
    plt.tight_layout()
    plt.show()

## Training metrics for the most relevant non-adversarial runs

In [None]:
run_names = {
    'kl v2 kl0.3' : './runs/report_no_pat/from_yaml_base_kl_v2_80_20_kl_03_rot_period_3_4_ctx_balanced_20250720_134756',
    'kl v2 kl0.1' :  './runs/report_no_pat/from_yaml_base_kl_v2_80_20_kl_01_rot_period_4_4_ctx_base_acc_20250720_125637',
    'kl v1' : './runs/report_no_pat/from_yaml_base_kl_v1_4_ctx_20250720_073108',
    'kl v2 rotation period 6 ep' : './runs/report_no_pat/from_yaml_base_kl_v2_80_20_kl_03_rot_period_rel_4_ctx_novel_acc_20250720_133801',
    'kl v2 kl0.1 8 ctx' : './runs/report_no_pat/from_yaml_base_kl_v2_80_20_kl_01_rot_period_4_8_ctx_20250719_125847',
}
scalar_names = [

    'validation_novel/accuracy',
    'validation_base/accuracy',
]


In [None]:
plot_tensorboard_scalars_grid(run_names, scalar_names, max_cols=2)
plot_novel_base_hmean_bar_chart(run_names)

### KL Divergence Training (KLv1 vs. KLv2)

**KLv2** clearly outperforms **KLv1** in both base and novel accuracy (75.66 vs. 63.69 base and 75.00 vs. 72.06 novel with 4 context tokens, per the main results table), which we attribute to its staged curriculum. In KLv2, one full epoch is first used to train on seen classes with cross-entropy (CE), followed by a full epoch of distillation on unseen classes with the KL loss. This ensures sufficient class diversity and a stable signal for both objectives.

In contrast, KLv1 splits pseudo-base and pseudo-novel classes within each mini-batch. This can result in limited or imbalanced class distributions, especially when batch sizes are small, weakening the effectiveness of KL training. Additionally, KLv1 tends to overfit after 4–5 epochs, leading to degradation in novel accuracy. KLv2 is more robust, converging quickly and tolerating longer training schedules.
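
The difference between the two schemes can be sketched as follows. This is a minimal illustration under stated assumptions, not the project's training code: `klv1_split_batch` and `klv2_phase` are hypothetical helpers, the batch of 10 labels is made up, and classes 41 to 50 are arbitrarily chosen as the pseudo-novel set. It shows how a small KLv1 mini-batch can contain no pseudo-novel samples at all, whereas KLv2 dedicates whole epochs to each objective.

```python
from typing import List, Set, Tuple


def klv1_split_batch(labels: List[int], pseudo_novel: Set[int]) -> Tuple[List[int], List[int]]:
    """KLv1-style (hypothetical): split one mini-batch into pseudo-base (CE) and pseudo-novel (KL) indices."""
    ce_idx = [i for i, y in enumerate(labels) if y not in pseudo_novel]
    kl_idx = [i for i, y in enumerate(labels) if y in pseudo_novel]
    return ce_idx, kl_idx


def klv2_phase(epoch: int) -> str:
    """KLv2-style (assumed alternation): CE on seen classes, then KL distillation on unseen classes."""
    return "ce_seen" if epoch % 2 == 0 else "kl_unseen"


# Made-up batch of 10 labels from base classes 0..50; classes 41..50 act as pseudo-novel
batch_labels = [3, 17, 17, 25, 8, 33, 12, 40, 5, 29]
pseudo_novel_classes = set(range(41, 51))

ce_idx, kl_idx = klv1_split_batch(batch_labels, pseudo_novel_classes)
print(len(ce_idx), len(kl_idx))  # 10 0: this batch gives the KL term nothing to learn from
print([klv2_phase(e) for e in range(4)])
```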

### Rotation Periods

We evaluated several schedules for rotating the seen/unseen (pseudo-base/pseudo-novel) class split: every 1, 3, 4, and 6 epochs, and at random. While all configurations yield comparable results (harmonic mean between 0.71 and 0.76), rotating every 3 to 4 epochs performs best overall. Rotating every epoch is too frequent and leaves too little time to adapt to a fixed pseudo-novel set. Conversely, rotating every 6 epochs can yield higher novel than base accuracy, suggesting underfitting on base classes.
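
A rotation schedule of this kind can be expressed as a function of the epoch index. In the sketch below, `split_for_epoch` is a hypothetical helper. The 80/20 seen/unseen ratio is an assumption inferred from the `80_20` fragment in the run directory names, and seeding the shuffle with `epoch // period` is just one way to keep the split fixed within a rotation period. The random-rotation variant is not covered.

```python
import random
from typing import List, Tuple


def split_for_epoch(
    base_classes: List[int], epoch: int, period: int, unseen_frac: float = 0.2, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Return (seen, unseen) class lists; the split only changes every `period` epochs."""
    rotation_idx = epoch // period             # identical for all epochs inside one period
    rng = random.Random(seed + rotation_idx)   # deterministic shuffle per rotation
    shuffled = list(base_classes)
    rng.shuffle(shuffled)
    n_unseen = round(len(shuffled) * unseen_frac)
    unseen = sorted(shuffled[:n_unseen])
    seen = sorted(shuffled[n_unseen:])
    return seen, unseen


base = list(range(51))  # base classes 0..50
for epoch in range(7):
    seen, unseen = split_for_epoch(base, epoch, period=3)
    print(epoch, epoch // 3, unseen)
```

With `period=3`, epochs 0 to 2 share one split, epochs 3 to 5 share the next, and so on; a larger period gives the model more time on a fixed pseudo-novel set at the cost of seeing fewer different splits.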

### KL Weight

Adjusting the KL loss weight (`λ_KL`) impacts the balance between base and novel performance. A value of **0.3** favors novel accuracy but may reduce base performance due to gradient conflict between CE and KL objectives. A lower value like **0.1** improves base accuracy and often yields a higher harmonic mean. We found `λ_KL = 0.3` to offer a reasonable trade-off.
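
The role of `λ_KL` can be illustrated with a toy loss. This sketch assumes the two objectives are combined as L = L_CE + λ_KL · KL(p_CLIP ‖ p_model), a common distillation formulation; the section does not show the project's exact loss code or KL direction. All helpers and the logit values are hypothetical, and the teacher distribution merely stands in for zero-shot CLIP's prediction.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def cross_entropy(logits: List[float], target: int) -> float:
    return -math.log(softmax(logits)[target])


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def combined_loss(student_logits: List[float], target: int, teacher_probs: List[float], lambda_kl: float) -> float:
    """Assumed weighting: L = L_CE + lambda_kl * KL(teacher || student)."""
    ce = cross_entropy(student_logits, target)
    kl = kl_divergence(teacher_probs, softmax(student_logits))
    return ce + lambda_kl * kl


student = [2.0, 0.5, -1.0]
teacher = softmax([0.5, 1.5, 0.0])  # stand-in for zero-shot CLIP's prediction
for lam in (0.1, 0.3):
    print(lam, round(combined_loss(student, 0, teacher, lam), 4))
```

A larger `λ_KL` makes disagreement with the teacher more expensive, so more of each update goes toward matching CLIP's distribution rather than fitting the ground-truth label.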

### Rotation Periods Comparison

We tested various rotation intervals for changing the seen/unseen class splits:

- every epoch  
- every 3 epochs  
- every 4 epochs  
- every 6 epochs  
- random rotation  

All approaches produced similar harmonic means (between **0.707** and **0.757**), but the best results were achieved when rotating every **3 or 4 epochs**.

- Rotating every epoch is too frequent, limiting time for the model to distill zero-shot knowledge.  
- Rotating every 6 epochs may cause underfitting on base classes, as the model focuses too long on a fixed unseen set.  
- Interestingly, rotating every 6 epochs tends to result in higher novel accuracy than base accuracy.

### `λ_KL = 0.3` vs `λ_KL = 0.1`

Increasing `λ_KL` emphasizes the objective of mimicking CLIP’s zero-shot predictions on unseen classes, which may lead to decreased performance on base classes.

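This is likely because the gradients of the CE and KL losses can be poorly aligned (close to orthogonal) or even point in opposing directions, so progress on one objective does little for, or partially undoes, progress on the other.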

- `λ_KL = 0.3` strikes the best balance between base and novel accuracy.  
- `λ_KL = 0.1` achieves higher base accuracy and harmonic mean compared to the `λ_KL = 0.3` setting.
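
The gradient-alignment explanation can be checked directly: compute the gradients of the CE and KL terms with respect to the prompt parameters separately (in PyTorch, for example, with two `torch.autograd.grad` calls), flatten them, and measure their cosine similarity. A value near 0 means the objectives are roughly orthogonal; a negative value means they conflict. The helper below and the toy gradient values are illustrative assumptions, not measurements from the project.

```python
import math
from typing import Sequence


def cosine_similarity(g1: Sequence[float], g2: Sequence[float]) -> float:
    """Cosine of the angle between two flattened gradient vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(g1, g2))
    n1 = math.sqrt(sum(a * a for a in g1))
    n2 = math.sqrt(sum(b * b for b in g2))
    if n1 == 0 or n2 == 0:
        return 0.0
    return dot / (n1 * n2)


# Made-up gradients of the two losses w.r.t. the same prompt parameters
grad_ce = [0.4, -0.2, 0.1]
grad_kl = [-0.3, 0.1, 0.2]
print(round(cosine_similarity(grad_ce, grad_kl), 3))
```

Tracking this value per step for different `λ_KL` settings would show whether the base-accuracy drop coincides with negatively aligned gradients.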

## Training metrics for the most relevant adversarial-only and pretraining + adversarial runs

In [None]:
"""
Legend for the discriminator-input variants compared below:
(no suffix) -> avg_text_feat + logits
noise -> avg_text_feat + noisy logits
bias -> selected_text_feat + ctx_shifted.mean()

The BIAS run from the night between the 20th and 21st used a different input and did not work;
it is kept only to show that this variant fails:
bias -> ctx_shifted (as ctx_shifted = ctx+bias)
"""
run_names = {

    'l-adv3.0+avg_txt_ft+noisy_logits':'./runs/report_no_pat/from_yaml_kl_pretrain_adv_4_ctx_4_clusters_adv_3_noise_20250720_215302',
    'l-adv1.0+avg_txt_ft+logits':'./runs/report_no_pat/from_yaml_kl_pretrain_adv_4_ctx_4_clusters_20250720_164121',
    'l-adv3.0+post_metanet_ctx' : './runs/report_no_pat/from_yaml_kl_pretrain_adv_4_ctx_4_clusters_adv_3_bias_20250721_004545',
    'l-adv3.0+avg_post_metanet_ctx+selected_txt_ft':'./runs/report_no_pat/from_yaml_kl_pretrain_adv_4_ctx_4_clusters_adv_3_bias_20250721_130018',
    'all adv 4 clusters' : './runs/report_no_pat/from_yaml_all_adv_4_ctx_4_clusters_20250721_104046',
}
scalar_names = [

    'validation_novel/accuracy',
    'validation_base/accuracy',
    'train_adv/mlp_loss',
]



In [None]:
plot_tensorboard_scalars_grid(run_names, scalar_names, max_cols=2)
plot_novel_base_hmean_bar_chart(run_names)

## Adversarial Training

Adversarial training, when applied **after KLv2**, significantly improves **test base accuracy** (up to 0.91) while maintaining **competitive test novel performance** (e.g., 0.71).  
However, when used **in isolation**, it tends to **converge more slowly** and exhibits **instability**.  
Using a **higher value of** `λ_adv` (e.g., 3 vs. 1) improves the **stability of novel accuracy** across epochs, particularly in **longer training schedules** (e.g., 60+ epochs).
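
The configuration above exposes a `grl_lambda` parameter, which suggests (the section does not show the code) that the adversarial stage relies on a gradient reversal layer in the style of Ganin and Lempitsky [2]: identity in the forward pass, gradient multiplied by −λ in the backward pass. The toy scalar example below uses hand-derived gradients instead of autograd; `w` stands in for a prompt-learner parameter and `v` for a discriminator parameter, both hypothetical. It shows that the discriminator descends its loss while the parameter upstream of the reversal layer receives the reversed gradient and therefore ascends it.

```python
def grl_forward(x: float) -> float:
    """Gradient reversal layer, forward pass: identity."""
    return x


def grl_backward(grad_output: float, grl_lambda: float) -> float:
    """Gradient reversal layer, backward pass: flip the sign and scale by grl_lambda."""
    return -grl_lambda * grad_output


# Toy chain: f = w * x -> GRL -> d = v * f, loss L = (d - y)^2
x, y = 1.0, 0.0   # one input and its hypothetical cluster target
w, v = 0.5, 2.0   # w: stand-in for a prompt-learner parameter, v: discriminator parameter
grl_lambda = 1.0

f = w * x
d = v * grl_forward(f)
loss = (d - y) ** 2

dL_dd = 2 * (d - y)
grad_v = dL_dd * f                             # discriminator gradient: unaffected by the GRL
dL_df = grl_backward(dL_dd * v, grl_lambda)    # gradient reaching the feature is reversed
grad_w = dL_df * x                             # without the GRL this would be +4.0

print(loss, grad_v, grad_w)  # 1.0 1.0 -4.0
```

A gradient-descent step on `w` with this reversed gradient increases the discriminator's loss, which is how the upstream parameters are pushed to make cluster membership harder to predict.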

### Discriminator's Input Impact

- **With logits:**  
  Feeding logits directly to the discriminator can be considered a form of "cheating", as the model may simply learn the mapping `recognized_category → cluster`, bypassing any visual feature representation.  
  To address this, **Gaussian noise** is added to the logits:
  - Without noise: overfitting to base classes often occurs.
  - With noise: novel accuracy remains slightly more stable, but the model still overfits to base classes.

- **Without logits:**  
  Learning meaningful cluster mappings from alternative features is challenging.  
  We tested feeding the **MetaNet-biased prompt context** directly:
  - Direct input resulted in poor performance (discriminator loss ≈ 1.33, close to the chance level of ln 4 ≈ 1.39 for 4 clusters).
  - Best results occurred when averaging the **MetaNet-biased prompt** with the **text feature of the predicted category** (from softmax output).  

  This allowed the discriminator to learn category → cluster mappings *without relying on logits*.

Nevertheless, the performance gains were marginal, comparable to those of setups in which the discriminator did not learn at all.  
This suggests that **any update to the prompt learner weights** may reduce the discriminator's usefulness as a regularizer as much as exposing it directly to logits does.
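The discriminator inputs compared above can be sketched in NumPy as follows. This is a minimal illustration, not the repository's implementation: the function names, the mean-pooling of the context over tokens, the equal 0.5/0.5 averaging weights, the noise scale `sigma=0.1`, and the assumption that context tokens and text features share the same dimension `D` are all hypothetical choices made for the example.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def noisy_logits_input(logits, sigma, rng):
    """'With logits' variant (hypothetical): logits perturbed with Gaussian noise of std `sigma`."""
    return logits + sigma * rng.standard_normal(logits.shape)

def avg_ctx_text_input(ctx, text_features, logits):
    """'Without logits' variant (hypothetical): average the MetaNet-biased context,
    mean-pooled over tokens, with the text feature of the predicted category."""
    pred = softmax(logits).argmax(axis=-1)  # predicted category per image, shape (B,)
    selected_txt = text_features[pred]      # (C, D) indexed -> (B, D)
    pooled_ctx = ctx.mean(axis=1)           # (B, n_ctx, D) -> (B, D)
    return 0.5 * (pooled_ctx + selected_txt)

# Toy shapes: batch B, classes C, context tokens n_ctx, feature dim D
rng = np.random.default_rng(0)
B, C, n_ctx, D = 2, 3, 4, 8
logits = rng.standard_normal((B, C))
text_features = rng.standard_normal((C, D))
ctx = rng.standard_normal((B, n_ctx, D))  # stands in for context vectors + MetaNet bias

x_noisy = noisy_logits_input(logits, sigma=0.1, rng=rng)
x_avg = avg_ctx_text_input(ctx, text_features, logits)
print(x_noisy.shape, x_avg.shape)  # (2, 3) (2, 8)
```

Note that the averaged input still depends on the predicted category, which is consistent with the observation that the discriminator learns category → cluster mappings even without seeing the logits directly.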

### Limitations

Training the model to convergence using KLv2 (`λ = 0.3`) plus adversarial learning (4 clusters, `λ_adv = 1.0`) results in:

- Base accuracy ≈ 0.90  
- Novel accuracy ≈ 0.71  

These results are in line with the benchmarks reported in the CoCoOp paper [3^], and this configuration yields the **highest harmonic mean** among our runs.
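For reference, the harmonic mean used in CoCoOp-style evaluation [3^] is $H = 2BN/(B+N)$, where $B$ and $N$ are the base and novel accuracies. A small sketch (the function name is illustrative) plugging in the approximate values above:

```python
def harmonic_mean(base_acc: float, novel_acc: float) -> float:
    """Harmonic mean of base and novel accuracy; returns 0.0 if both are zero."""
    if base_acc + novel_acc == 0:
        return 0.0
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)

h = harmonic_mean(0.90, 0.71)
print(f"{h:.3f}")  # 0.794
```

Because the harmonic mean is dominated by the smaller of the two values, a configuration cannot reach a high $H$ by trading away novel accuracy for base accuracy.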

---

*Open question:*  
If updates to the prompt learner weights impair discriminator performance,  
**could such degradation lead to improved generalization on novel categories?**  
Our adversarial learning hypothesis suggests: **Yes**.

## Overall View

In [None]:
run_names = {
    'kl v2 kl0.3': './runs/report_no_pat/from_yaml_base_kl_v2_80_20_kl_03_rot_period_3_4_ctx_balanced_20250720_134756',
    'kl v2 kl0.1': './runs/report_no_pat/from_yaml_base_kl_v2_80_20_kl_01_rot_period_4_4_ctx_base_acc_20250720_125637',
    'kl v1': './runs/report_no_pat/from_yaml_base_kl_v1_4_ctx_20250720_073108',
    'kl v2 rotation period 6 ep': './runs/report_no_pat/from_yaml_base_kl_v2_80_20_kl_03_rot_period_rel_4_ctx_novel_acc_20250720_133801',
    'kl v2 kl0.1 8 ctx': './runs/report_no_pat/from_yaml_base_kl_v2_80_20_kl_01_rot_period_4_8_ctx_20250719_125847',
    'l-adv3.0+avg_txt_ft+noisy_logits': './runs/report_no_pat/from_yaml_kl_pretrain_adv_4_ctx_4_clusters_adv_3_noise_20250720_215302',
    'l-adv1.0+avg_txt_ft+logits': './runs/report_no_pat/from_yaml_kl_pretrain_adv_4_ctx_4_clusters_20250720_164121',
    'l-adv3.0+post_metanet_ctx': './runs/report_no_pat/from_yaml_kl_pretrain_adv_4_ctx_4_clusters_adv_3_bias_20250721_004545',
    'l-adv3.0+avg_post_metanet_ctx+selected_txt_ft': './runs/report_no_pat/from_yaml_kl_pretrain_adv_4_ctx_4_clusters_adv_3_bias_20250721_130018',
    'all adv 4 clusters': './runs/report_no_pat/from_yaml_all_adv_4_ctx_4_clusters_20250721_104046',
}
scalar_names = [
    'validation_novel/accuracy',
    'validation_base/accuracy',
    'train_adv/mlp_loss',
]



In [None]:
plot_tensorboard_scalars_grid(run_names, scalar_names, max_cols=2)
plot_novel_base_hmean_bar_chart(run_names)

### Impact of Prompt Length
Prompt length influences generalization. Shorter prompts ($n_{\text{ctx}} = 4$) outperform longer ones ($n_{\text{ctx}} = 8$), particularly on novel classes. Longer prompts offer more capacity but tend to overfit faster.
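To make the capacity difference concrete, the learnable context vectors grow linearly with $n_{\text{ctx}}$. The sketch below assumes a CLIP ViT-B/16 backbone (the CoCoOp default), whose text encoder has width 512; the function name is illustrative. In CoCoOp's design the MetaNet outputs a single bias vector that is added to every context token, so the MetaNet's own size does not depend on $n_{\text{ctx}}$ and only the context vectors scale.

```python
def ctx_param_count(n_ctx: int, ctx_dim: int = 512) -> int:
    """Learnable parameters in the shared context vectors: one ctx_dim-sized vector per token."""
    return n_ctx * ctx_dim

for n in (4, 8):
    print(n, ctx_param_count(n))
# 4 2048
# 8 4096
```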




# Conclusion

## Summary

In this project, we explored ways to improve CoCoOp [3^], a prompt learning method built on CLIP [2^] that aims to generalize better in zero-shot image classification tasks. We focused on two main strategies: a curriculum learning method based on KL divergence, and a second training phase that uses an adversarial MLP network.

We developed a two-stage training approach (which we call **KLv2**) that gradually shifts supervision from base classes to pseudo-novel ones. Compared to a simpler, single-stage version (**KLv1**), KLv2 led to better accuracy and faster, more stable training.

We also experimented with **adversarial training**, where we used a small MLP network to predict cluster labels. While this didn’t boost novel class performance on its own, it worked well as a **regularizer when used after KLv2**, improving base class accuracy and helping prevent overfitting.

Throughout the project, we looked closely at how different design choices, such as the **length of the context prompt** or the **type of features fed into the MLP**, affect the results. We also confirmed that using **data augmentation similar to that used to train CLIP** helped improve generalization to novel classes.

**Overall**, this project showed how thoughtful training strategies and architectural tweaks can make prompt learning systems like CoCoOp more robust. The combination of curriculum learning and regularization turned out to be especially effective for balancing performance between base and novel classes.

## Future Work

In future work, we aim to **enhance the adversarial training component** by replacing the current MLP with a **lightweight transformer architecture**. This would allow better modeling of the interactions between context tokens and could improve the adversary's effectiveness. We also plan to explore **dividing the available training data into specialized subsets** to enable targeted training strategies.

One promising direction is to **incorporate contrastive loss functions** on the prompt embeddings. Contrastive learning encourages the model to bring similar class prompts closer in the feature space while pushing apart dissimilar ones. This approach could help improve the **representation quality and generalization ability**, especially for novel classes, by explicitly enforcing class-discriminative embeddings.
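Since this is future work, the following is only a sketch of a standard InfoNCE-style contrastive loss in NumPy, not part of the project. It assumes each prompt embedding in a batch has exactly one positive (for instance, a prompt embedding of the same class obtained from a different image, which is a hypothetical pairing choice) and treats all other rows in the batch as negatives; the function name and the default temperature are illustrative.

```python
import numpy as np

def info_nce(anchors: np.ndarray, positives: np.ndarray, temperature: float = 0.07) -> float:
    """InfoNCE loss: row i of `positives` is the positive for row i of `anchors`;
    every other row in the batch acts as a negative."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)   # L2-normalize
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    logits = a @ p.T / temperature                                  # (N, N) scaled cosine similarities
    logits = logits - logits.max(axis=1, keepdims=True)             # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))                      # cross-entropy, diagonal targets

emb = np.eye(4)
aligned = info_nce(emb, emb, temperature=0.1)                        # positives match anchors
shuffled = info_nce(emb, np.roll(emb, 1, axis=0), temperature=0.1)   # positives mismatched
print(aligned < shuffled)  # True
```

Minimizing this loss pulls each anchor toward its positive and away from the other rows, which is the "closer / further apart" behavior described above. CLIP itself uses a symmetric variant that averages the loss over both directions (anchors to positives and positives to anchors).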

## References

[1^] Yaroslav Ganin and Victor Lempitsky. *Unsupervised Domain Adaptation by Backpropagation*. 2015. arXiv: [1409.7495](https://arxiv.org/abs/1409.7495) [stat.ML].  
[2^] Alec Radford et al. *Learning Transferable Visual Models From Natural Language Supervision*. 2021. arXiv: [2103.00020](https://arxiv.org/abs/2103.00020) [cs.CV].  
[3^] Kaiyang Zhou et al. *Conditional Prompt Learning for Vision-Language Models*. 2022. arXiv: [2203.05557](https://arxiv.org/abs/2203.05557) [cs.CV].  
[4^] Kaiyang Zhou et al. *Learning to Prompt for Vision-Language Models*. 2021. arXiv: [2109.01134](https://arxiv.org/abs/2109.01134) [cs.CV].