# CLIP Weighted Model Mixture - Few-Shot Learning Technique

This notebook presents a novel strategy for few-shot adaptation using CLIP.

Our approach involves the construction of multiple models whose outputs are combined through a weighted average, yielding an aggregated prediction generally more accurate than any single model.

This method is characterized by the following aspects:
1. The models are not trained; instead, they serve as prompts fed into the CLIP text encoder.
2. Prompts are manually designed based on domain-specific knowledge.
3. A weighted average is computed, where the weights are predicted by a _Re-Weighting Model_—a simple neural network that takes CLIP image embeddings as input.

## Installing and Importing dependency

In [1]:
%pip install openai_clip

from torchvision.datasets import Flowers102
import random
import torch
import clip
from tqdm import tqdm
from typing import TypeVar, cast, TypedDict, Generic
import random
from torch.utils.data import Dataset, Subset, DataLoader
from typing import Callable
from torch import nn
import math
import copy
import numpy as np
import clip.model
from functools import lru_cache
from torch.optim.lr_scheduler import StepLR
from typing import Literal

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


### Settings

Configurable parameters of the notebook

In [2]:
# Whether to use the default train-test spit provided by torchvision or not
# > *If set to True*: we will have the default split where each training class has 10 samples
# > *if set to False* the number of shot will be defined by the `NUMBER_OF_SHOTS` constants
#   this is done to make it easier to compare our algorithm with other papers, that often uses 16 as the default number of shots
USE_DEFAULT_SPLIT = False
# The number of shot to use in learning. Only has effect if `USE_DEFAULT_SPLIT` is set to False
NUMBER_OF_SHOTS = 16

# the device to use for training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# underlying model to use for the clip visual encoder
# available models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
CLIP_MODEL="ViT-B/16"

# The noise to add on top of the image embedding.
# We noted that adding noise increases the accuracy of novel and based classes when evaluated separately
# at minor accuracy cost when evaluating base and novel classes together.
NOISE: float | Literal[False] = 0.2


# The batch size used.
# We noted high batch size are better for training, however due to memory constraint
# we used gradient accumulation every BATCH_SIZE_MULTIPLIER batch to simulate larger sizes
BATCH_SIZE: int = 32
BATCH_SIZE_MULTIPLIER: int = 4

# Number of training steps
NUM_STEPS = 100

# The parameters for the optimizer
LR: float = 5
MOMENTUM: float = 0.5
SCHEDULER_STEP: int = 50
GAMMA: float = 0.9
CLIP_GRADIENT_VALUE = 5

# Number of neurons in the re-weighter's intermediate layer
RE_WEIGHTER_L1_SIZE: int = 75


### Type definitions

In [3]:
class CategoryLabel(TypedDict):
    id: int
    name: str
    novel: bool

### Dataset loading

Functions for dataset loading

In [4]:
def get_data(data_dir="./data", transform=None):
    """
    Load Flowers102 train, validation and test sets.
    Uses the default split provided by torchvision.

    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (torch.Compose)
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return (
        cast(Dataset[tuple[torch.Tensor,int]], train),
        cast(Dataset[tuple[torch.Tensor,int]], val),
        cast(Dataset[tuple[torch.Tensor,int]], test)
    )

def get_data_custom_split(data_dir="./data", transform=None):
    """
    Load Flowers102 train, validation and test sets.
    Uses a custom split that allow to specify the number of items per class that will be assigned to the train set.
    The validation set will always be empty.

    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (torch.Compose)
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    
    a,b,c = get_data(data_dir, transform)
    full_dataset: Dataset[tuple[torch.Tensor, int]] = torch.utils.data.ConcatDataset([a,b,c])

    labels_set = set(l for _,l in full_dataset)
    class_to_index_dict: dict[int, list[int]] = {l: [] for l in labels_set}

    for i in range(len(full_dataset)):
        l = full_dataset[i][1]
        class_to_index_dict[l].append(i)

    train: list[int] = []
    test: list[int] = []

    for indexes in class_to_index_dict.values():
        random.shuffle(indexes)
        train += indexes[0:NUMBER_OF_SHOTS]
        test += indexes[NUMBER_OF_SHOTS:]

    train_dataset = torch.utils.data.Subset(full_dataset, train)
    validation_dataset = torch.utils.data.Subset(full_dataset, [])
    test_dataset = torch.utils.data.Subset(full_dataset, test)
    return (
        cast(Dataset[tuple[torch.Tensor,int]], train_dataset),
        cast(Dataset[tuple[torch.Tensor,int]], validation_dataset),
        cast(Dataset[tuple[torch.Tensor,int]], test_dataset)
    )


## Base and novel categories
Function definition to split the dataset into novel and base classes

In [5]:
T = TypeVar("T")
E = TypeVar("E")
def base_novel_categories(dataset: Dataset[tuple[T,E]]):
    # set returns the unique set of all dataset classes
    all_classes = set(l for _, l in dataset)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes

## Split dataset
The dataset is partitioned into base and novel categories based on the `base_novel_categories` definition. This operation requires both the dataset and the list of base classes. Each sample is assigned to the base set if its label belongs to the base categories; otherwise, it is assigned to the novel set.

In [6]:

T = TypeVar("T")
E = TypeVar("E")
def split_data(dataset: Dataset[tuple[T,E]], base_classes: list[int]):
    # these two lists will store the sample indexes
    base_categories_samples: list[int] = []
    novel_categories_samples: list[int] = []

    # set with the base classes (so that checking existence is O(1))
    base_set = set(base_classes)

    # here we iterate over sample labels and also get the correspondent sample index
    label: int
    for sample_id, (_, label) in enumerate(dataset): #type: ignore
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = Subset(dataset, base_categories_samples)
    novel_dataset = Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset

## Load CLIP

In [7]:
model, preprocess = clip.load(CLIP_MODEL, device=DEVICE)

### Defining the prompt templates

A template is a function that takes as input a class and returns a description of the underlying object.

They are divided into two categories:  
1. **General prompts**: generic prompts that can be applied to any class.  
2. **Class-specific prompts**: category-dependent prompts.

Prompt-class mappings are computed by the re-weighter model, hence the absence of explicit code for this task.

In [8]:

#####################################################
################# GENERAL PROMPTS ###################
#####################################################
general_prompt_template: list[Callable[[str], str]] = [
    lambda x: f"a photo of a {x}, a type of flower.",
    lambda x: f"a photo of some {x}, a type of flower.",
    lambda x: f"a close-up of a {x} flower.",
    lambda x: f"an image of a {x} blossom.",
    lambda x: f"a beautiful {x} in bloom.",
    lambda x: f"a bunch of {x} flowers.",
    lambda x: f"a macro shot of a {x} flower.",
    lambda x: f"a single {x} flower.",
    lambda x: f"fresh {x} flowers in a garden.",
]

#####################################################
############## CLASS SPECIFIC PROMPTS ###############
#####################################################

class_specific_prompt_templates: list[Callable[[str], str]] = [
    # pink primrose
    lambda x: f"a photo of a {x}, a delicate flower with soft pink petals.",
    lambda x: f"an image of a {x}, a blooming plant often found in springtime gardens.",
    lambda x: f"a close-up of a {x}, known for its pale pink blossoms and gentle appearance.",

    # hard-leaved pocket orchid
    lambda x: f"a photo of a {x}, a tropical orchid with stiff, glossy leaves.",
    lambda x: f"an image of a {x}, an exotic flower with waxy petals and leathery foliage.",
    lambda x: f"a botanical image of a {x}, an orchid species with hard, durable leaves.",

    # canterbury bells
    lambda x: f"a photo of a {x}, a bell-shaped flower in shades of purple and blue.",
    lambda x: f"an image of a {x}, known for its tall spikes of bell-like blossoms.",
    lambda x: f"a close-up of a {x}, a cottage garden flower with cup-shaped blooms.",

    # sweet pea
    lambda x: f"a photo of a {x}, a fragrant flower with delicate, ruffled petals.",
    lambda x: f"an image of a {x}, often grown for its pastel colors and pleasant scent.",
    lambda x: f"a close-up of a {x}, a climbing plant with butterfly-shaped flowers.",

    # english marigold
    lambda x: f"a photo of a {x}, a bright orange or yellow flower with daisy-like blooms.",
    lambda x: f"an image of a {x}, known for its healing properties and sunny appearance.",
    lambda x: f"a close-up of a {x}, a calendula flower common in herb gardens.",

    # tiger lily
    lambda x: f"a photo of a {x}, an orange flower with dark spots and recurved petals.",
    lambda x: f"an image of a {x}, a wild-looking lily with dramatic coloring.",
    lambda x: f"a close-up of a {x}, known for its bold, tiger-striped blooms.",

    # moon orchid
    lambda x: f"a photo of a {x}, an elegant white orchid with a moon-like glow.",
    lambda x: f"an image of a {x}, a phalaenopsis flower often found in tropical climates.",
    lambda x: f"a close-up of a {x}, a soft and symmetrical flower with wide petals.",

    # bird of paradise
    lambda x: f"a photo of a {x}, a tropical flower resembling a colorful bird.",
    lambda x: f"an image of a {x}, known for its bright orange and blue petals.",
    lambda x: f"a close-up of a {x}, an exotic bloom that looks like a flying bird.",

    # monkshood
    lambda x: f"a photo of a {x}, a hooded purple flower with toxic properties.",
    lambda x: f"an image of a {x}, often called wolfsbane, with dark violet petals.",
    lambda x: f"a close-up of a {x}, a tall plant with helmet-shaped blooms.",

    # globe thistle
    lambda x: f"a photo of a {x}, a spherical flower with spiky blue petals.",
    lambda x: f"an image of a {x}, known for its round shape and thistle-like texture.",
    lambda x: f"a close-up of a {x}, a unique ornamental flower with a metallic hue.",

    # snapdragon
    lambda x: f"a photo of a {x}, a colorful flower that resembles a dragon's mouth.",
    lambda x: f"an image of a {x}, known for its vertical clusters of blooming petals.",
    lambda x: f"a close-up of a {x}, a common garden flower with hinged, snout-like blooms.",

    # colt's foot
    lambda x: f"a photo of a {x}, a yellow wildflower that appears before its leaves.",
    lambda x: f"an image of a {x}, a small flower resembling a dandelion in early spring.",
    lambda x: f"a close-up of a {x}, a plant with hoof-shaped leaves and bright yellow blooms.",

    # king protea
    lambda x: f"a photo of a {x}, a large flower with a spiky crown-like appearance.",
    lambda x: f"an image of a {x}, a South African bloom with a central cone and pink petals.",
    lambda x: f"a close-up of a {x}, a striking flower often used in bold arrangements.",

    # spear thistle
    lambda x: f"a photo of a {x}, a spiny plant with purple tufted blooms.",
    lambda x: f"an image of a {x}, known for its sharp leaves and thistle head.",
    lambda x: f"a close-up of a {x}, a wildflower with prickly stems and a vibrant purple flower.",

    # yellow iris
    lambda x: f"a photo of a {x}, a bright yellow iris with upright petals.",
    lambda x: f"an image of a {x}, commonly found near water, with sword-shaped leaves.",
    lambda x: f"a close-up of a {x}, an elegant flower with golden hues and frilled edges.",

    # globe-flower
    lambda x: f"a photo of a {x}, a round yellow bloom resembling a buttercup.",
    lambda x: f"an image of a {x}, a spherical flower found in alpine meadows.",
    lambda x: f"a close-up of a {x}, a glowing, globe-shaped flower with dense petals.",

    # purple coneflower
    lambda x: f"a photo of a {x}, a daisy-like flower with purple petals and a spiky cone.",
    lambda x: f"an image of a {x}, often used in herbal remedies and garden borders.",
    lambda x: f"a close-up of a {x}, known for its downward-sloping petals and orange center.",

    # peruvian lily
    lambda x: f"a photo of a {x}, a spotted flower with multiple colorful petals.",
    lambda x: f"an image of a {x}, a long-lasting bloom used in cut flower arrangements.",
    lambda x: f"a close-up of a {x}, a lily-like flower with striped inner petals.",

    # balloon flower
    lambda x: f"a photo of a {x}, a flower bud that inflates like a balloon before opening.",
    lambda x: f"an image of a {x}, a star-shaped bloom in shades of blue or purple.",
    lambda x: f"a close-up of a {x}, a unique flower with puffy unopened buds.",

    # giant white arum lily
    lambda x: f"a photo of a {x}, a large white flower with a trumpet-like shape.",
    lambda x: f"an image of a {x}, also known as a calla lily, with a central yellow spadix.",
    lambda x: f"a close-up of a {x}, an elegant flower with smooth white petals.",

    # fire lily
    lambda x: f"a photo of a {x}, a bright red or orange lily with curled petals.",
    lambda x: f"an image of a {x}, a dramatic flower known for its flame-like appearance.",
    lambda x: f"a close-up of a {x}, a fiery-looking lily with backward-bending petals.",

    # pincushion flower
    lambda x: f"a photo of a {x}, a flower with a domed center and delicate fringe.",
    lambda x: f"an image of a {x}, often purple or lavender, resembling a pin-filled cushion.",
    lambda x: f"a close-up of a {x}, known for its central disk and lace-like petals.",

    # fritillary
    lambda x: f"a photo of a {x}, a checkered bell-shaped flower often in purple tones.",
    lambda x: f"an image of a {x}, a rare flower with a distinctive petal pattern.",
    lambda x: f"a close-up of a {x}, a delicate wildflower with hanging, nodding blooms.",

    # red ginger
    lambda x: f"a photo of a {x}, a tropical flower with bright red bracts.",
    lambda x: f"an image of a {x}, known for its bold color and upright flower spikes.",
    lambda x: f"a close-up of a {x}, a striking plant native to rainforests.",

    # grape hyacinth
    lambda x: f"a photo of a {x}, a small bulbous plant with clusters of blue-purple flowers.",
    lambda x: f"an image of a {x}, resembling tiny grapes arranged on a spike.",
    lambda x: f"a close-up of a {x}, a spring flower with densely packed florets.",

    # corn poppy
    lambda x: f"a photo of a {x}, a bright red flower with papery petals and a dark center.",
    lambda x: f"an image of a {x}, a wild poppy often found in meadows and fields.",
    lambda x: f"a close-up of a {x}, a flower symbolizing remembrance and resilience.",

    # prince of wales feathers
    lambda x: f"a photo of a {x}, a spiky flower head resembling a feathery plume.",
    lambda x: f"an image of a {x}, known for its upright purple-pink floral spikes.",
    lambda x: f"a close-up of a {x}, a member of the amaranth family with plume-like blooms.",

    # stemless gentian
    lambda x: f"a photo of a {x}, a vivid blue flower growing close to the ground.",
    lambda x: f"an image of a {x}, a low-growing gentian with trumpet-shaped petals.",
    lambda x: f"a close-up of a {x}, a mountain flower with intensely blue blossoms.",

    # artichoke
    lambda x: f"a photo of a {x}, a thistle-like plant with edible buds and purple flowers.",
    lambda x: f"an image of a {x}, a spiky flower head that blooms into a vibrant violet.",
    lambda x: f"a close-up of a {x}, a large budding flower with a layered appearance.",

    # sweet william
    lambda x: f"a photo of a {x}, a cluster of small flowers in pink, red, or white.",
    lambda x: f"an image of a {x}, a garden flower known for its fringed petal edges.",
    lambda x: f"a close-up of a {x}, a fragrant bloom often used in cottage gardens.",

    # carnation
    lambda x: f"a photo of a {x}, a ruffled flower commonly seen in bouquets.",
    lambda x: f"an image of a {x}, a traditional bloom with a clove-like scent.",
    lambda x: f"a close-up of a {x}, known for its layered petals and vibrant color range.",

    # garden phlox
    lambda x: f"a photo of a {x}, a tall plant with clusters of pink or purple blooms.",
    lambda x: f"an image of a {x}, a perennial flower found in cottage-style gardens.",
    lambda x: f"a close-up of a {x}, a phlox with star-shaped petals growing in dense bunches.",

    # love in the mist
    lambda x: f"a photo of a {x}, a blue flower surrounded by fine, feathery foliage.",
    lambda x: f"an image of a {x}, a delicate flower with a misty, lace-like background.",
    lambda x: f"a close-up of a {x}, a whimsical bloom with soft, threadlike leaves.",

    # mexican aster
    lambda x: f"a photo of a {x}, a daisy-like flower with bright pink or purple petals.",
    lambda x: f"an image of a {x}, a tall annual bloom often seen in wildflower fields.",
    lambda x: f"a close-up of a {x}, a lightweight flower with yellow centers and soft petals.",

    # alpine sea holly
    lambda x: f"a photo of a {x}, a spiky blue flower with thistle-like bracts.",
    lambda x: f"an image of a {x}, a unique alpine plant with metallic-colored petals.",
    lambda x: f"a close-up of a {x}, a flower with a cone center and pointed star-shaped sepals.",

    # ruby-lipped cattleya
    lambda x: f"a photo of a {x}, an orchid with bold purple lips and pastel petals.",
    lambda x: f"an image of a {x}, a showy flower with ruby-colored accents on its lip.",
    lambda x: f"a close-up of a {x}, a fragrant orchid with elaborate ruffled petals.",

    # cape flower
    lambda x: f"a photo of a {x}, a brightly colored South African flower with daisy form.",
    lambda x: f"an image of a {x}, known for its vivid hues and sun-tracking behavior.",
    lambda x: f"a close-up of a {x}, a vibrant flower with a dark center and radiating petals.",

    # great masterwort
    lambda x: f"a photo of a {x}, a flower with a central cluster surrounded by papery bracts.",
    lambda x: f"an image of a {x}, a perennial bloom with intricate star-like heads.",
    lambda x: f"a close-up of a {x}, a soft-colored flower with detailed florets in the center.",

    # siam tulip
    lambda x: f"a photo of a {x}, a tropical flower with pink petals and green bracts.",
    lambda x: f"an image of a {x}, also known as Curcuma, with cone-shaped blossoms.",
    lambda x: f"a close-up of a {x}, a vibrant Thai flower with tulip-like structure.",

    # lenten rose
    lambda x: f"a photo of a {x}, a spring flower with nodding blooms and muted colors.",
    lambda x: f"an image of a {x}, a hellebore plant with leathery leaves and soft petals.",
    lambda x: f"a close-up of a {x}, a cold-hardy flower blooming in late winter to early spring.",

    # barbeton daisy
    lambda x: f"a photo of a {x}, a brightly colored daisy native to South Africa.",
    lambda x: f"an image of a {x}, known for its vivid red, orange, or pink petals.",
    lambda x: f"a close-up of a {x}, a cheerful flower with a prominent central disk.",

    # daffodil
    lambda x: f"a photo of a {x}, a trumpet-shaped flower with yellow or white petals.",
    lambda x: f"an image of a {x}, one of the first blooms of spring with a central corona.",
    lambda x: f"a close-up of a {x}, a classic bulb flower symbolizing renewal and hope.",

    # sword lily
    lambda x: f"a photo of a {x}, a tall flower with sword-like leaves and vertical blooms.",
    lambda x: f"an image of a {x}, commonly known as gladiolus with stacked florets.",
    lambda x: f"a close-up of a {x}, a showy flower in a rainbow of colors on long spikes.",

    # poinsettia
    lambda x: f"a photo of a {x}, a festive plant with red or white leaf-like bracts.",
    lambda x: f"an image of a {x}, often used in winter displays with green foliage and colorful tops.",
    lambda x: f"a close-up of a {x}, a holiday flower with bright petal-like leaves.",

    # bolero deep blue
    lambda x: f"a photo of a {x}, a compact flower with deep violet petals and ruffled texture.",
    lambda x: f"an image of a {x}, a variety of pansy known for its rich blue coloring.",
    lambda x: f"a close-up of a {x}, a velvety flower with intricate patterns and deep hues.",

    # wallflower
    lambda x: f"a photo of a {x}, a small clustered flower known for growing on walls or rocky soil.",
    lambda x: f"an image of a {x}, a plant with fragrant blooms in yellow, orange, or red.",
    lambda x: f"a close-up of a {x}, a simple flower with four-petal blossoms and warm colors.",

    # marigold
    lambda x: f"a photo of a {x}, a vibrant flower with layers of orange or yellow petals.",
    lambda x: f"an image of a {x}, known for its strong scent and decorative garden use.",
    lambda x: f"a close-up of a {x}, a sun-loving flower with round, bushy blooms.",

    # buttercup
    lambda x: f"a photo of a {x}, a shiny yellow flower with cup-shaped petals.",
    lambda x: f"an image of a {x}, a wildflower with a simple structure and glossy surface.",
    lambda x: f"a close-up of a {x}, a cheerful bloom with golden overlapping petals.",

    # oxeye daisy
    lambda x: f"a photo of a {x}, a white-petaled daisy with a yellow central disk.",
    lambda x: f"an image of a {x}, a classic meadow flower with a tall, slender stem.",
    lambda x: f"a close-up of a {x}, a widespread wildflower resembling a common daisy.",

    # common dandelion
    lambda x: f"a photo of a {x}, a yellow flower with toothed leaves and fluffy seed heads.",
    lambda x: f"an image of a {x}, a weedy plant known for its puffball seed dispersal.",
    lambda x: f"a close-up of a {x}, a golden flower head made of many tiny florets.",

    # petunia
    lambda x: f"a photo of a {x}, a funnel-shaped flower often used in hanging baskets.",
    lambda x: f"an image of a {x}, a colorful bloom available in many vibrant shades.",
    lambda x: f"a close-up of a {x}, a soft, velvety flower with a wide trumpet-like shape.",

]

### Text feature loading function

CLIP's text processing pipeline operates as follows:
 - A matrix of strings is generated with shape `number of prompts × number of classes`.
 - The matrix is tokenized.
 - The tokenized prompts are passed through the text encoder to obtain the embeddings.

The resulting output is a matrix of shape `number of prompts × embedding size × number of classes`.

#### Difference between base and novel class embeddings

When generating prompts for novel classes, general prompts are repeated to match the required vector size.

Initially, a padding prompt was used. However, the re-weighter learned to assign, on average, higher weights to the padding class to compensate for this design.

It was later observed that replacing padding prompts with well-crafted prompts led to improved accuracy. We hypothesize that this improvement arises from reducing constraints on the re-weighter, thereby facilitating the learning process.


In [9]:
@torch.no_grad()
def load_text_features(categories: list[CategoryLabel]):
    """
    return size: num_prompts x num_categories x input_size"]
    """

    # prompts for base classes (include both general and class specific templates)
    prompts_for_base_classes = general_prompt_template + class_specific_prompt_templates
    
    # prompts for novel classes (are obtained with the repetition of general templates)
    multiplier = len(prompts_for_base_classes) // len(general_prompt_template) + 1
    prompts_for_novel_classes = general_prompt_template * multiplier
    prompts_for_novel_classes = prompts_for_novel_classes[0:len(prompts_for_base_classes)]

    def get_prompts_for_one_class(category: CategoryLabel) -> list[str]:
        if category["novel"]:
            prompts = prompts_for_novel_classes
        else:
            prompts = prompts_for_base_classes
        name = category["name"]
        return [t(name) for t in prompts ]
    

    text_inputs = [
        clip.tokenize(
            get_prompts_for_one_class(category)
        ).to(DEVICE)
        for category in categories
    ]

    text_features_array: list[torch.Tensor] = [
        model.encode_text(x)
        for x in text_inputs
    ]

    # shape: num_classes x num_prompts x embedding_size
    text_features = torch.stack([
        x/x.norm(dim=-1,keepdim=True)
        for x in text_features_array
    ])

    # shape: num_prompts x embedding_size x num_classes
    text_features = text_features.permute(1,2,0)

    return text_features

### Loading the dataset

In [10]:
if USE_DEFAULT_SPLIT:
    train_set, val_set, test_set = get_data(transform=preprocess)
else:
    train_set, val_set, test_set = get_data_custom_split(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

### Re-weighter model

A lightweight feedforward neural network (FFNN) designed as follows:  
- **Input**: a 2D tensor representing the visual embedding of an image (i.e., the output of the visual encoder).  
- **Output**: a 3D tensor of shape `batch size × number of prompts × number of classes`.

The output is used to re-weight the prompts and compute a single classification score.

#### Mathematical formulation

Notation:  
- $T_{ij}$: the embedding of class `j` generated using prompt template `i`.  
- $V$: the embedding of the image to classify.  
- $W = \text{Reweighter}(V)$: the output of the re-weighter, shaped `(number of prompts × number of classes)`.  
- $w_{ij}$: the weight assigned to template `i` and class `j`, used to classify the image associated with $V$.  
- $s_j$: the weighted score for class `j`.

Using this notation, the classification score $s_j$ is computed as:
$$
s_j = \sum_{i=1}^{n\_prompts} w_{ij} \, (T_{ij} \cdot V)
$$
where $(T_{ij} \cdot V)$ denotes the dot product between the text and image embeddings, representing their similarity.

The predicted class is then given by:
$$
c_{s_j} = \operatorname*{arg\,max}_j(s_j)
$$

#### Generalization to novel classes

Since the model must output weights for each class, generalization is enabled by extending the output with weights representing novel categories.

A custom indexing layer maps the original model output (with shape `num_base_classes + 1`) to the actual number of target classes.

In the example below, the model is trained on 4 base classes (A, B, C, and D), and later adapted to a setting where 3 of the original base classes (A, B, and D) are retained, one is discarded (C), and 3 novel classes are introduced.

<p align="center">
    <img src="https://raw.githubusercontent.com/lucaSartore/CLIP-Few-shot/refs/heads/main/images/insexing_layer.png" alt="indexing layer" height="400"/>
</p>

The method `update_indexing_mask` reconfigures the indexing layer to align the re-weighter's output with the updated class structure.

#### Internal architecture

Given a visual embedding as input, the network performs the following steps:

1. Projects the input to a lower-dimensional latent space.  
2. Computes weights over prompts and classes.  
   - Weights for novel classes are learned independently of the input embedding.

Earlier versions of the model computed novel class weights as a function of the image embedding. However, using class-independent weights for novel categories was empirically shown to improve performance and was thus adopted.

#### Parameters

- **Clip**: clip model
- **Weighter**: re-weighter model
- **Dataset**: dataset 
- **Categories**: list of all the categories
- **Batch size**: batch size 
- **Number of steps**: number of training steps to take
- **Device**: device in which to run the training procedure 
- **Batch size multiplier**: by how much to multiply the batch size (see point **2** in **Training** section) 
- **Noise**: multiplier for gaussian noise to add to input data 
- **Learning rate**: learning rate for the optimizer 
- **Momentum**: multiplier for the momentum
- **Scheduler step**: number of steps after which the learning rate scheduler is triggered
- **Gamma**: multiplier for the learning rate scheduler


In [11]:
CLIP_EMBEDDING_SIZE = int(model.encode_text(clip.tokenize("foo").to(DEVICE)).shape[-1])
NUMBER_OF_PROMPTS = len(general_prompt_template) + len(class_specific_prompt_templates)


class ReWeighterModel(nn.Module):
    def __init__(self, classes_for_training: list[CategoryLabel], internal_size = 75):
        super().__init__()
        # we need to store the classes used for training in order to be able to update the indexing mask
        self.base_classes = [x for x in classes_for_training if not x["novel"]]
        self.num_base_classes = len(self.base_classes)
        # We use the update method for initialization here
        self.update_indexing_mask(classes_for_training)
        ####################### DESIGN_CHOICE ######################################
        # We observed that mapping the visual embedding in a small
        # embedding space before calculating the weights help
        # reducing overfitting, as we are "compressing" the
        # representation, and therefore reducing the model capacity.
        # We hypnotize that a low-dimensional space does not limit
        # much the performance of our model, as the task of re-weighting
        # is (at least intuitively) much simpler (and therefore lower dimensional)
        # than the task of comparing images and text
        ############################################################################
        # the first layer maps the visual embedding into a lower-dimensional space.
        self.l1 = nn.Linear(CLIP_EMBEDDING_SIZE, internal_size, dtype=model.dtype)
        # one output weight for each prompt, and for each class
        self.l2 = nn.Linear(internal_size, NUMBER_OF_PROMPTS * (self.num_base_classes), dtype=model.dtype)

        # weights for novel classes don't depends on the input
        weights = torch.rand(NUMBER_OF_PROMPTS).type(model.dtype)
        self.novel_classes_weights = nn.parameter.Parameter(weights)
        # sigmoid used to normalize the weights in the 0-1 range
        # TODO: idea, maybe allowing for negative weight can improve things?
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor):
        """
        input: [batch_size x clip_embedding_size] = the image-generated embeddings
        output: [batch_size x number_of_prompts x num_classes] = the weight estimated for each prompt-class pair
        """
        batch_size = input.shape[0]
        # the weights for base classes
        x: torch.Tensor = self.l1(input)
        x = self.l2(x)
        x = x.reshape(batch_size, NUMBER_OF_PROMPTS, self.num_base_classes)

        # the weights for novel classes
        x_novel = self.novel_classes_weights.unsqueeze(0)
        x_novel = x_novel.reshape(1, NUMBER_OF_PROMPTS, 1)
        x_novel = x_novel.expand(batch_size, NUMBER_OF_PROMPTS, 1)

        # concatenating into a single output
        x = torch.concat([x, x_novel], dim=-1)
        x = self.sigmoid(x)

        x = x[:,:,self.indexing_mask]
        return x

    def update_indexing_mask(self, classes: list[CategoryLabel]):
        """
        Update the indexing mask.
        This function need to be used after a model has being trained for a specific set of classes,
        """

        # map the class id, to the position on on the model's last layer
        id_to_index = {x["id"]: i for i,x in enumerate(self.base_classes)}

        # given a class, it returns the index inside the model's last layer that should be used
        # to re-weight the classes's prompts.
        def get_index(label: CategoryLabel):
            # novel categories are always mapped to the generic re-weighter (the last one)
            # as we haven't learned a model for them
            if label["novel"]:
                return len(self.base_classes)
            # base classes instead have their specialized re-weighter at the corresponding index
            return id_to_index[label["id"]]

        # building the indexing mask
        self.indexing_mask = torch.Tensor([
            get_index(x) for x in classes
        ]).type(torch.long)

    

### Creation of virtual novel samples

The re-weighter includes a dedicated output for novel classes. During training, "virtual novel classes" are introduced to facilitate learning. These are generated by duplicating base classes and assigning them new class IDs, effectively creating novel-like counterparts.

For instance, consider a dataset with 102 classes. Classes with IDs 0 to 50 are treated as base classes, and those from 51 to 101 as novel. An additional set of 51 virtual novel classes is created with IDs ranging from 102 to 152. These virtual classes are identical to the base classes in content but differ in class ID.

All virtual novel classes share a single output channel in the re-weighter (a tensor containing one weight per prompt). This design compels the re-weighter to learn a generalizable set of weights capable of distinguishing between the 51 virtual novel classes.

The underlying assumption is that if this shared weighting mechanism can effectively separate 51 distinct classes, it is likely to generalize well to 51 truly novel, unseen categories. This is considered a reasonable expectation, given the balanced division between base and novel classes.


In [12]:

# class names have being added to torch vision recently, but to avoid compatibility issues we added them here.
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# list of category labels of the base classes
base_classes_label: list[CategoryLabel] = [
    {
        "id": x,
        "name": CLASS_NAMES[x],
        "novel": False
    }
    for x in base_classes
]

# list of category labels for the novel classes
novel_classes_label: list[CategoryLabel] = [
    {
        "id": x,
        "name": CLASS_NAMES[x],
        "novel": True
    }
    for x in novel_classes
]

# list of category labels for all classes
all_classes_labels = base_classes_label + novel_classes_label

# list of category labels for base classes plus virtual-novel classes.
# This is the one that will be used during training, and you can see
# that we are essentially duplicating the number of novel classes,
# keeping the name unchanged, but shifting the ID, so to avoid overlap.
base_and_virtual_novel_classes_labels = base_classes_label + [
    {
        "id": x + len(all_classes_labels),
        "name": CLASS_NAMES[x],
        "novel": True
    }
    for x in base_classes
]

T = TypeVar("T")
E = TypeVar("E")
class VirtualNovelDataset(Dataset, Generic[T,E]):
    """
    This is a custom dataset implementation, that create "virtual classes"
    In short id create "num_novel_classes" new training samples, that are
    assigned a new ID (by adding "to_add" at the original ID)
    """
    def __init__(self, dataset: Dataset[tuple[T,E]], num_novel_classes: int, to_add: E):
        self._dataset_len: int = len(dataset) #type: ignore
        self._dataset = dataset
        self._virtual_novel: list[tuple[T,E]] = []
        indexes = random.sample(range(self._dataset_len), num_novel_classes)
        for index in indexes:
            data, label = dataset[index]
            new_label: E = label + to_add #type: ignore
            self._virtual_novel.append((data, new_label))
        
    def __len__(self):
        return self._dataset_len + len(self._virtual_novel)

    def __getitem__(self, index: int) -> tuple[T,E]:
        # non virtual sample
        if index < self._dataset_len:
            return self._dataset[index]
        # virtual sample:
        else:
            return self._virtual_novel[index - self._dataset_len]
            
train_base_and_virtual_novel = VirtualNovelDataset(train_base, len(train_base), len(all_classes_labels))

### Training

The training procedure largely follows standard practice, with four key modifications:

1. **Masking of input prompts**:  
   Virtual-novel training samples can introduce ambiguity. Consider an image of a rose appearing twice in a batch—once labeled "rose" and once "rose-novel". Since the visual content is identical, the model cannot predict both labels correctly at the same time, leading to instability.  
   To mitigate this, one of the prompt sets is masked during training. For example, when evaluating "rose-novel", all prompts associated with "rose" are disabled by zeroing out their corresponding text embeddings.

2. **Virtually larger batch size**:  
   To simulate larger batch sizes without exceeding memory constraints, a gradient accumulation strategy is used. Instead of calling `zero_grad` and `step` after every batch, updates are delayed across several batches.  
   This is particularly useful for re-weighter training, which performs better with larger batch sizes but is otherwise memory intensive.

3. **Gaussian noise injection**:  
   During training, Gaussian noise is added to the input data. This technique has been shown to improve the smoothness of decision boundaries, enhance the model's generalization capabilities, and reduce overfitting. It also acts as an implicit data augmentation strategy, which is particularly beneficial given the few-shot nature of the task.

4. **Learning rate scheduler**:  
   A hyperparameter defines the number of steps after which the scheduler activates, scaling the current learning rate by a fixed factor. In this case, a decay factor of 0.8 was found to offer a good balance between convergence speed and loss reduction.



In [13]:
def train(
        clip: clip.model.CLIP,
        weighter: ReWeighterModel ,
        dataset: Dataset[tuple[torch.Tensor, int]],
        categories: list[CategoryLabel],
        batch_size: int,
        num_steps: int,
        device: torch.device | str,
        batch_size_multiplier: int = 1,
        noise: bool | float = 0.01,
        lr: float = 5,
        momentum: float = 0.5,
        scheduler_step: int = 50,
        gamma: float = 0.9,
    ):
    clip.eval()
    weighter.train()

    #Saving the best model, in case of an unstable loss
    best_model: ReWeighterModel | None = None
    best_loss: float = math.inf

    # Todo: Add some interesting comments on why we chose the optimizer/parameters
    # that we did, once we are done with the hyperparameters optimization
    #optimizer = torch.optim.AdamW(params = weighter.parameters(), lr=lr)
    optimizer = torch.optim.SGD(params = weighter.parameters(), lr=lr, momentum=momentum)
    scheduler = StepLR(optimizer, scheduler_step, gamma=gamma)

    # Remap labels into a contiguous set starting from zero.
    # this is the same as what was done in the zero-shot example notebook:
    #   > contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
    # but is more efficient later on, as the indexing can be done natively
    # inside tensor, without using python structures
    contig_cat2idx = torch.zeros(1+max(cat["id"] for cat in categories)).long()
    for idx, cat in enumerate(categories):
        contig_cat2idx[cat["id"]] = idx
    
    # loading the text features
    # size: num_prompts x clip_embedding_size x num_categories"]
    text_features = load_text_features(categories)

    # these constants are useful when re-shaping some tensors later
    num_prompts = text_features.shape[0]
    num_classes = text_features.shape[2]

    # simple dataloader creation
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # definition of the loss
    loss_fn = nn.CrossEntropyLoss()

    progress_bar = tqdm(range(num_steps), "Steps") 
    for _ in progress_bar:
        image: torch.Tensor
        target: torch.Tensor
        losses: list[float] = []
        for i, (image, target) in enumerate(dataloader):


            # Converting the class indexes to contiguous indexes
            # Equivalent line, in the zero-shot notebook:
            #   > target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()
            # shape: [batch_size]
            target = contig_cat2idx[target]

            target = target.to(DEVICE)
            image = image.to(DEVICE)

            # calculating the image features
            with torch.no_grad():
                image_features: torch.Tensor = clip.encode_image(image)

                # Gaussian noise injection
                if noise:
                    noise_feature = torch.randn_like(image_features) * noise
                    noise_feature[target <= num_classes//2] = 0
                    image_features += noise_feature

                image_features /= image_features.norm(dim=-1, keepdim=True)


            # We need to be able to mask out some text features (to avoid issues
            # with virtual novel classes, as explained in markdown cell)
            # Do do so, we need to expand the text features, and add a dimension (batch size)
            # this is because every single element in a batch will have a different target
            # (and therefore a different input feature to mask)
            # shape: [batch_size x num_prompts x clip_embedding_size x num_classes]
            masked_text_features = text_features \
                .clone() \
                .unsqueeze(0) \
                .expand(
                    len(target), # batch size
                    num_prompts,
                    CLIP_EMBEDDING_SIZE,
                    num_classes
                ).clone()

            # the following 15 lines are used to mask the text feature in the correct place,
            # they are a bit hard to understand, however, here you can find a equivalent snippet of code
            # that is much more readable, even tho it is less efficient computationally
            # Note: 51 is the number of base classes
            # >     for i in range(51):
            # >         masked_text_features[target == i, :, :, i+51] = 0
            # >         masked_text_features[target == i+51, :, :, i] = 0
                

            # vector tell me which class we need to zero in each element of a batch size
            #shape: [batch_size]
            class_to_zero = (target + len(categories) / 2) % len(categories)
            class_to_zero = class_to_zero.long()
            #shape: [batch_size x num_classes]
            # I can use the one_hot notation to obtain exactly the boolean mask I need below
            class_to_zero = nn.functional.one_hot(class_to_zero, num_classes).bool()

            # I need to invert the axis order so that the boolean mask can be used correctly
            # shape: [batch_size x num_classes x num_prompts x clip_embedding_size]
            masked_text_features = masked_text_features.permute(0,3,2,1)
            # masking out the undesired text features
            masked_text_features[class_to_zero,:,:] = 0
            # shape: [batch_size x num_prompts x clip_embedding_size x num_classes]
            # set teh correct order back
            masked_text_features = masked_text_features.permute(0,3,2,1)

            # computing the similarity scores (using the masked text features)
            # b = batch, e = embedding, p = prompts, c = classes
            # shape: [batch_size x num_prompts x num_classes]
            scores = torch.einsum("be,bpec -> bpc", image_features, masked_text_features.detach())

            # computing the weights
            # shape: [batch_size x num_prompts x num_classes]
            weights: torch.Tensor = weighter(image_features)
    
            # reweighing scores
            scores *= weights

            # summing up the scores of every different prompt
            # shape: [batch_size x num_classes]
            out = torch.sum(scores, dim=1)

            # calculating the loss, and updating the gradient
            loss: torch.Tensor = loss_fn(out, target)
            losses.append(loss.item())
            # this is to avoid having different losses (and therefore
            # different behaver of the optimizer) when the batch
            # size multiplier changes
            loss /= batch_size_multiplier
            loss.backward()
            if (i+1) % batch_size_multiplier == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRADIENT_VALUE)  
                optimizer.step()
                optimizer.zero_grad()

        # repeat here, to avoid losing some information if the number
        # of batches is not divisible by batch_size_multiplier
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRADIENT_VALUE)  
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        # printing the current loss in the progress bar
        final_loss = np.average(losses)
        progress_bar.set_postfix(loss=f"{final_loss:.4f}")

        # saving the best model to cpu
        if final_loss < best_loss:
            best_model = copy.deepcopy(weighter).to("cpu")

    # returning the best model, according to the loss evaluation.
    assert best_model is not None
    return best_model.to(device)


weighter = ReWeighterModel(base_and_virtual_novel_classes_labels, RE_WEIGHTER_L1_SIZE).to(DEVICE)


weighter = train(
    model,
    weighter,
    train_base_and_virtual_novel,
    base_and_virtual_novel_classes_labels,
    BATCH_SIZE,
    NUM_STEPS,
    DEVICE,
    BATCH_SIZE_MULTIPLIER,
    NOISE,
    LR,
    MOMENTUM,
    SCHEDULER_STEP,
    GAMMA
)



Steps: 100%|██████████| 100/100 [13:55<00:00,  8.36s/it, loss=0.6393]


### Evaluation

#### A note on performance

Model mixture architectures are generally slower; however, performance remains comparable to zero-shot CLIP. This is because the image embedding is computed only once, while text embeddings are precomputed and stored.  
The only computational overhead comes from the re-weighter and the subsequent dot product operations.

In [14]:

@torch.no_grad()
def eval(
        clip: clip.model.CLIP,
        weighter: ReWeighterModel ,
        dataset: Dataset[tuple[torch.Tensor, int]],
        categories: list[CategoryLabel],
        batch_size: int,
        device: torch.device | str,
        label = ""
    ):
    # let's set the model in evaluation mode
    clip.eval()
    weighter.eval()

    # Remap labels into a contiguous set starting from zero
    contig_cat2idx = torch.zeros(1+max(cat["id"] for cat in categories)).long()
    for idx, cat in enumerate(categories):
        contig_cat2idx[cat["id"]] = idx

    text_features = load_text_features(categories)

    # simple dataloader creation
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # here we store the number of correct predictions we will make
    correct_predictions = 0

    image: torch.Tensor
    target: torch.Tensor
    for image, target in tqdm(dataloader, desc=label):

        target = contig_cat2idx[target]

        image = image.to(device)
        target = target.to(device)

        image_features: torch.Tensor = clip.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # shape: [batch_size x num_prompts x num_classes]
        scores: torch.Tensor = torch.matmul(image_features, text_features).permute(1,0,2)

        # shape: [ batch_size x (num_prompts x num_classes)]
        weights: torch.Tensor = weighter(image_features)

        # reweighing scores
        scores *= weights

        out = torch.sum(scores, dim=1)
        predicted_class = out.argmax(dim=-1)

        # now we check which are correct, and sum them (False == 0, True == 1)
        correct_predictions += (predicted_class == target).sum().item()


    # and now we compute the accuracy
    accuracy = correct_predictions / len(dataset) #type: ignore
    return accuracy

weighter.update_indexing_mask(base_classes_label)
base_accuracy = eval(model, weighter, dataset=test_base, categories=base_classes_label, batch_size=32, device=DEVICE, label="🧠 Evaluation on Base Classes")

weighter.update_indexing_mask(novel_classes_label)
novel_accuracy = eval(model, weighter, dataset=test_novel, categories=novel_classes_label, batch_size=32, device=DEVICE, label="🧠 Evaluation on Novel Classes")


print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")


🧠 Evaluation on Base Classes: 100%|██████████| 84/84 [00:16<00:00,  5.05it/s]
🧠 Evaluation on Novel Classes: 100%|██████████| 122/122 [00:19<00:00,  6.25it/s]

🔍 Base classes accuracy: 96.90%
🔍 Novel classes accuracy: 78.89%





## Harmonic Mean

In [15]:
def get_harmonic_mean(base_accuracy, novel_accuracy):
    numerator = 2
    denominator = 1 / base_accuracy + 1 / novel_accuracy
    hm = numerator / denominator
    return hm


harmonic_mean = get_harmonic_mean(base_accuracy, novel_accuracy)
print(f"🔍 Harmonic Mean: {harmonic_mean*100:.2f}%")

🔍 Harmonic Mean: 86.97%


### Classification of novel and base categories

The model architecture may appear as a composition of two distinct systems, given that novel classes rely on different prompts and weights compared to base classes. Most state-of-the-art models show reduced performance on unseen classes relative to zero-shot CLIP, despite not introducing any additional components.

To demonstrate that our model functions as a unified system, we evaluate its performance on mixed datasets and observe that it maintains high accuracy even when both base and novel classes are present. This indicates that the model is not simply a conjunction of two specialized sub-models.

#### Results

When evaluated on a mixed dataset, there is a 3–4% drop in accuracy relative to the harmonic mean. This reduction is expected due to the increased entropy introduced by the larger number of classes and is consistent with the behavior of zero-shot CLIP. These results support the claim that the model operates as a single architecture.

#### How we got here

In early versions, the model behaved like two separate systems, showing a strong bias toward base classes and frequently misclassifying images from unseen categories.

To address this issue, the following techniques were applied:

1. **Precise masking of text features during training**  
   As described in the **Training** section, prompt masking is used to stabilize training. Early versions masked either all virtual-novel prompts or all base prompts, leading to a lack of training samples with both prompt types coexisting. Consequently, the model learned to treat them separately.

2. **Avoiding class-specific prompts for novel classes and using padding instead**  
   Although the model is capable of learning to ignore irrelevant prompts, the re-weighter tended to assign weight to class-specific prompts, especially due to the virtual nature of novel-class training. This created a bias in favor of base classes. Using padding instead of class-specific prompts mitigated this issue.

3. **Using learnable parameters for novel-class weights instead of image embedding-based predictions**  
   Earlier implementations used the same mechanism to compute weights for both base and novel classes. Replacing this with learnable, input-independent parameters for novel classes improved performance. This adjustment reflects the fact that real novel-class embeddings differ from virtual-novel ones, and removing this dependency enhanced accuracy when transitioning to true novel categories.


In [16]:
weighter.update_indexing_mask(all_classes_labels)
total_accuracy = eval(model, weighter, dataset=test_set, categories=all_classes_labels, batch_size=32, device=DEVICE, label="🧠 Evaluation on all Classes")
print(f"🔍 Novel and base classes accuracy: {total_accuracy*100:.2f}%")

delta = harmonic_mean - total_accuracy
print(f"🔍 Delta WRT harmonic mean: {delta*100:.2f}%")

🧠 Evaluation on all Classes: 100%|██████████| 205/205 [00:26<00:00,  7.88it/s]

🔍 Novel and base classes accuracy: 82.58%
🔍 Delta WRT harmonic mean: 4.39%



