# CLIP Weighted Model Mixture - Few shot learning technique
This notebook implement a novel strategy for clip few shot adaptation.

Our strategy involve the creation of multiple "models" where the output is then weighted
to then create an "average guess" that should be, on average more accurate than any individual model.

So far this is a pretty common technique, however there are two key differences that make this approach novel:
1. Each model is not a learned model, but is essentially just a prompt. That is then passed through the clip text encoder.
The prompts are NOT learned, instead they are manually crafted base on domain knowledge.
2. The "averaging" is not a simple average, but is a weighted average, and the weights are provided by a "Re-Weighting Model"
This model is a simple neural network that take as input the clip image embeddings.

## Installing and Importing dependency

In [1]:
%pip install openai_clip

from torchvision.datasets import Flowers102
import random
import torch
import clip
from tqdm import tqdm
from typing import TypeVar, cast, TypedDict, Generic
import random
from torch.utils.data import Dataset, Subset, DataLoader
from typing import Callable
from torch import nn
import math
import copy
import numpy as np
import clip.model
from functools import lru_cache

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


### Settings

This section include a few configurable parameters of the notebook

In [2]:
# Whether to use the default train-test spit provided by torchvision or not
# > *If set to True*: we will have the default split where each training class has 10 samples
# > *if set to False* the number of shot will be defined by the `NUMBER_OF_SHOTS` constants
#   this is done to make it easier to compare our algorithm with other options, that often uses 16 as the default number of shots
USE_DEFAULT_SPLIT = False
# The number of shot to use in learning. Only has effect if `USE_DEFAULT_SPLIT` is set to False
NUMBER_OF_SHOTS = 16

# the device to use for training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# underlying model to use for the clip visual encoder
# available models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
CLIP_MODEL="ViT-B/16"


### Type definitions
A few type definition that will be used throughout the notebook

In [3]:
class CategoryLabel(TypedDict):
    id: int
    name: str
    novel: bool

### Dataset loading

Nothing special here, just the definition of the functions to load the dataset

In [4]:
def get_data(data_dir="./data", transform=None):
    """
    Load Flowers102 train, validation and test sets.
    Uses the default split provided by torchvision.

    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (torch.Compose)
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return (
        cast(Dataset[tuple[torch.Tensor,int]], train),
        cast(Dataset[tuple[torch.Tensor,int]], val),
        cast(Dataset[tuple[torch.Tensor,int]], test)
    )

def get_data_custom_split(data_dir="./data", transform=None):
    """
    Load Flowers102 train, validation and test sets.
    Uses a custom split that allow to specify the number of items per class that will be assigned to the train set.
    The validation set will always be empty.

    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (torch.Compose)
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    
    a,b,c = get_data(data_dir, transform)
    full_dataset: Dataset[tuple[torch.Tensor, int]] = torch.utils.data.ConcatDataset([a,b,c])

    labels_set = set(l for _,l in full_dataset)
    class_to_index_dict: dict[int, list[int]] = {l: [] for l in labels_set}

    for i in range(len(full_dataset)):
        l = full_dataset[i][1]
        class_to_index_dict[l].append(i)

    train: list[int] = []
    test: list[int] = []

    for indexes in class_to_index_dict.values():
        random.shuffle(indexes)
        train += indexes[0:NUMBER_OF_SHOTS]
        test += indexes[NUMBER_OF_SHOTS:]

    train_dataset = torch.utils.data.Subset(full_dataset, train)
    validation_dataset = torch.utils.data.Subset(full_dataset, [])
    test_dataset = torch.utils.data.Subset(full_dataset, test)
    return (
        cast(Dataset[tuple[torch.Tensor,int]], train_dataset),
        cast(Dataset[tuple[torch.Tensor,int]], validation_dataset),
        cast(Dataset[tuple[torch.Tensor,int]], test_dataset)
    )


## Base and Novel categories
Function definition to split the dataset into novel and base classes

In [5]:
T = TypeVar("T")
E = TypeVar("E")
def base_novel_categories(dataset: Dataset[tuple[T,E]]):
    # set returns the unique set of all dataset classes
    all_classes = set(l for _, l in dataset)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes

## Split Dataset
The next step is to actually split the dataset into the base and novel categories we extract from `base_novel_categories`.
To split the data we need the dataset (obviously) and the list of base classes. If the sample label is not part of the base categories, then it must be part of the novel ones.

In [6]:

T = TypeVar("T")
E = TypeVar("E")
def split_data(dataset: Dataset[tuple[T,E]], base_classes: list[int]):
    # these two lists will store the sample indexes
    base_categories_samples: list[int] = []
    novel_categories_samples: list[int] = []

    # set with the base classes (so that checking existence is O(1))
    base_set = set(base_classes)

    # here we iterate over sample labels and also get the correspondent sample index
    label: int
    for sample_id, (_, label) in enumerate(dataset): #type: ignore
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    base_dataset = Subset(dataset, base_categories_samples)
    novel_dataset = Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset

## Load CLIP

In [7]:
model, preprocess = clip.load(CLIP_MODEL, device=DEVICE)

### Defining the prompt templates

This section define a list of "prompt templates". A template is just a function that take as input the name of a class, and return
a string describing the underlying object.

The templates defined in this section can be divided into two broad categories:
1. **General prompts**: those prompts are generic and could be used for potentially any class.
2. **Class specific prompts**: those prompts are specific for one category.

Here you should note that there is nothing that maps the class specific prompts to the associated class,
infact we don't worry at all about that, and we let the reweighed take care of that.




In [8]:

#####################################################
################# GENERAL PROMPTS ###################
#####################################################
general_prompt_template: list[Callable[[str], str]] = [
    lambda x: f"a photo of a {x}, a type of flower.",
    lambda x: f"a {x} flower.",
    lambda x: f"a photo of some {x}, a type of flower.",
    lambda x: f"some {x} flowers.",
    lambda x: f"a close-up of a {x} flower.",
    lambda x: f"an image of a {x} blossom.",
    lambda x: f"a beautiful {x} in bloom.",
    lambda x: f"a bunch of {x} flowers.",
    lambda x: f"a macro shot of a {x} flower.",
    lambda x: f"a single {x} flower.",
    lambda x: f"fresh {x} flowers in a garden.",
    lambda x: f"a {x} flower in the wild.",
    lambda x: f"a botanical photograph of a {x}.",
    lambda x: f"a vibrant {x} bloom.",
    lambda x: f"a {x} plant with flowers.",
    lambda x: f"a {x} growing in nature.",
    lambda x: f"a {x} flower in sunlight.",
    lambda x: f"a colorful {x} flower close-up.",
    lambda x: f"a {x}, commonly found in gardens.",
    lambda x: f"wild {x} flowers blooming.",
    lambda x: f"a garden filled with {x} flowers.",
    lambda x: f"floral photography featuring a {x}.",
    lambda x: f"an aesthetic photo of a {x}.",
    lambda x: f"a {x} flower in full bloom.",

    # Descriptive
    lambda x: f"a large blooming {x}.",
    lambda x: f"a freshly picked {x}.",
    lambda x: f"a wilted {x} flower.",
    lambda x: f"a {x} with dewdrops on its petals.",
    lambda x: f"a delicate {x} on a green stem.",
    lambda x: f"a colorful bouquet with {x}.",

    # Scientific-ish
    lambda x: f"a botanical illustration of {x}.",
    lambda x: f"a herbarium specimen of {x}.",
    lambda x: f"field photo of {x} species.",
    lambda x: f"{x} photographed for a flora study.",
    lambda x: f"a study sample of the {x} flower.",
    lambda x: f"{x} genus flower in bloom.",

    # Casual / Internet Style
    lambda x: f"my favorite flower: the {x}.",
    lambda x: f"saw a {x} today!",
    lambda x: f"check out this {x} flower!",
    lambda x: f"flowers like {x} are amazing.",
    lambda x: f"the {x} is blooming this season.",

    # Photographic / Artistic
    lambda x: f"an artistic photo of a {x}.",
    lambda x: f"film photo of a {x} flower.",
    lambda x: f"a {x} in black and white.",
    lambda x: f"the silhouette of a {x} in sunset light.",
    lambda x: f"a {x} flower in a vintage vase.",
    lambda x: f"an abstract painting of a {x}.",
    lambda x: f"macro photography of a {x} blossom.",

    # Poetic or Metaphorical
    lambda x: f"a {x}, soft as a whisper.",
    lambda x: f"a {x} dancing in the wind.",
    lambda x: f"petals of the {x}, kissed by rain.",
    lambda x: f"a lonely {x} on a quiet morning.",
    lambda x: f"a {x} symbolizing peace and beauty.",
    lambda x: f"like a {x} in springtime.",

    # Contextual / Scene-based
    lambda x: f"a {x} in a wildflower meadow.",
    lambda x: f"a {x} flower on a wedding table.",
    lambda x: f"{x} flowers in a forest clearing.",
    lambda x: f"a {x} growing beside a stone path.",
    lambda x: f"{x} blossoms in a city garden.",
    lambda x: f"{x} petals scattered on the ground.",
]

#####################################################
############## CLASS SPECIFIC PROMPTS ###############
#####################################################

class_specific_prompt_templates: list[Callable[[str], str]] = [
    # pink primrose
    lambda x: f"a photo of a {x}, a delicate flower with soft pink petals.",
    lambda x: f"an image of a {x}, a blooming plant often found in springtime gardens.",
    lambda x: f"a close-up of a {x}, known for its pale pink blossoms and gentle appearance.",

    # hard-leaved pocket orchid
    lambda x: f"a photo of a {x}, a tropical orchid with stiff, glossy leaves.",
    lambda x: f"an image of a {x}, an exotic flower with waxy petals and leathery foliage.",
    lambda x: f"a botanical image of a {x}, an orchid species with hard, durable leaves.",

    # canterbury bells
    lambda x: f"a photo of a {x}, a bell-shaped flower in shades of purple and blue.",
    lambda x: f"an image of a {x}, known for its tall spikes of bell-like blossoms.",
    lambda x: f"a close-up of a {x}, a cottage garden flower with cup-shaped blooms.",

    # sweet pea
    lambda x: f"a photo of a {x}, a fragrant flower with delicate, ruffled petals.",
    lambda x: f"an image of a {x}, often grown for its pastel colors and pleasant scent.",
    lambda x: f"a close-up of a {x}, a climbing plant with butterfly-shaped flowers.",

    # english marigold
    lambda x: f"a photo of a {x}, a bright orange or yellow flower with daisy-like blooms.",
    lambda x: f"an image of a {x}, known for its healing properties and sunny appearance.",
    lambda x: f"a close-up of a {x}, a calendula flower common in herb gardens.",

    # tiger lily
    lambda x: f"a photo of a {x}, an orange flower with dark spots and recurved petals.",
    lambda x: f"an image of a {x}, a wild-looking lily with dramatic coloring.",
    lambda x: f"a close-up of a {x}, known for its bold, tiger-striped blooms.",

    # moon orchid
    lambda x: f"a photo of a {x}, an elegant white orchid with a moon-like glow.",
    lambda x: f"an image of a {x}, a phalaenopsis flower often found in tropical climates.",
    lambda x: f"a close-up of a {x}, a soft and symmetrical flower with wide petals.",

    # bird of paradise
    lambda x: f"a photo of a {x}, a tropical flower resembling a colorful bird.",
    lambda x: f"an image of a {x}, known for its bright orange and blue petals.",
    lambda x: f"a close-up of a {x}, an exotic bloom that looks like a flying bird.",

    # monkshood
    lambda x: f"a photo of a {x}, a hooded purple flower with toxic properties.",
    lambda x: f"an image of a {x}, often called wolfsbane, with dark violet petals.",
    lambda x: f"a close-up of a {x}, a tall plant with helmet-shaped blooms.",

    # globe thistle
    lambda x: f"a photo of a {x}, a spherical flower with spiky blue petals.",
    lambda x: f"an image of a {x}, known for its round shape and thistle-like texture.",
    lambda x: f"a close-up of a {x}, a unique ornamental flower with a metallic hue.",

    # snapdragon
    lambda x: f"a photo of a {x}, a colorful flower that resembles a dragon's mouth.",
    lambda x: f"an image of a {x}, known for its vertical clusters of blooming petals.",
    lambda x: f"a close-up of a {x}, a common garden flower with hinged, snout-like blooms.",

    # colt's foot
    lambda x: f"a photo of a {x}, a yellow wildflower that appears before its leaves.",
    lambda x: f"an image of a {x}, a small flower resembling a dandelion in early spring.",
    lambda x: f"a close-up of a {x}, a plant with hoof-shaped leaves and bright yellow blooms.",

    # king protea
    lambda x: f"a photo of a {x}, a large flower with a spiky crown-like appearance.",
    lambda x: f"an image of a {x}, a South African bloom with a central cone and pink petals.",
    lambda x: f"a close-up of a {x}, a striking flower often used in bold arrangements.",

    # spear thistle
    lambda x: f"a photo of a {x}, a spiny plant with purple tufted blooms.",
    lambda x: f"an image of a {x}, known for its sharp leaves and thistle head.",
    lambda x: f"a close-up of a {x}, a wildflower with prickly stems and a vibrant purple flower.",

    # yellow iris
    lambda x: f"a photo of a {x}, a bright yellow iris with upright petals.",
    lambda x: f"an image of a {x}, commonly found near water, with sword-shaped leaves.",
    lambda x: f"a close-up of a {x}, an elegant flower with golden hues and frilled edges.",

    # globe-flower
    lambda x: f"a photo of a {x}, a round yellow bloom resembling a buttercup.",
    lambda x: f"an image of a {x}, a spherical flower found in alpine meadows.",
    lambda x: f"a close-up of a {x}, a glowing, globe-shaped flower with dense petals.",

    # purple coneflower
    lambda x: f"a photo of a {x}, a daisy-like flower with purple petals and a spiky cone.",
    lambda x: f"an image of a {x}, often used in herbal remedies and garden borders.",
    lambda x: f"a close-up of a {x}, known for its downward-sloping petals and orange center.",

    # peruvian lily
    lambda x: f"a photo of a {x}, a spotted flower with multiple colorful petals.",
    lambda x: f"an image of a {x}, a long-lasting bloom used in cut flower arrangements.",
    lambda x: f"a close-up of a {x}, a lily-like flower with striped inner petals.",

    # balloon flower
    lambda x: f"a photo of a {x}, a flower bud that inflates like a balloon before opening.",
    lambda x: f"an image of a {x}, a star-shaped bloom in shades of blue or purple.",
    lambda x: f"a close-up of a {x}, a unique flower with puffy unopened buds.",

    # giant white arum lily
    lambda x: f"a photo of a {x}, a large white flower with a trumpet-like shape.",
    lambda x: f"an image of a {x}, also known as a calla lily, with a central yellow spadix.",
    lambda x: f"a close-up of a {x}, an elegant flower with smooth white petals.",

    # fire lily
    lambda x: f"a photo of a {x}, a bright red or orange lily with curled petals.",
    lambda x: f"an image of a {x}, a dramatic flower known for its flame-like appearance.",
    lambda x: f"a close-up of a {x}, a fiery-looking lily with backward-bending petals.",

    # pincushion flower
    lambda x: f"a photo of a {x}, a flower with a domed center and delicate fringe.",
    lambda x: f"an image of a {x}, often purple or lavender, resembling a pin-filled cushion.",
    lambda x: f"a close-up of a {x}, known for its central disk and lace-like petals.",

    # fritillary
    lambda x: f"a photo of a {x}, a checkered bell-shaped flower often in purple tones.",
    lambda x: f"an image of a {x}, a rare flower with a distinctive petal pattern.",
    lambda x: f"a close-up of a {x}, a delicate wildflower with hanging, nodding blooms.",

    # red ginger
    lambda x: f"a photo of a {x}, a tropical flower with bright red bracts.",
    lambda x: f"an image of a {x}, known for its bold color and upright flower spikes.",
    lambda x: f"a close-up of a {x}, a striking plant native to rainforests.",

    # grape hyacinth
    lambda x: f"a photo of a {x}, a small bulbous plant with clusters of blue-purple flowers.",
    lambda x: f"an image of a {x}, resembling tiny grapes arranged on a spike.",
    lambda x: f"a close-up of a {x}, a spring flower with densely packed florets.",

    # corn poppy
    lambda x: f"a photo of a {x}, a bright red flower with papery petals and a dark center.",
    lambda x: f"an image of a {x}, a wild poppy often found in meadows and fields.",
    lambda x: f"a close-up of a {x}, a flower symbolizing remembrance and resilience.",

    # prince of wales feathers
    lambda x: f"a photo of a {x}, a spiky flower head resembling a feathery plume.",
    lambda x: f"an image of a {x}, known for its upright purple-pink floral spikes.",
    lambda x: f"a close-up of a {x}, a member of the amaranth family with plume-like blooms.",

    # stemless gentian
    lambda x: f"a photo of a {x}, a vivid blue flower growing close to the ground.",
    lambda x: f"an image of a {x}, a low-growing gentian with trumpet-shaped petals.",
    lambda x: f"a close-up of a {x}, a mountain flower with intensely blue blossoms.",

    # artichoke
    lambda x: f"a photo of a {x}, a thistle-like plant with edible buds and purple flowers.",
    lambda x: f"an image of a {x}, a spiky flower head that blooms into a vibrant violet.",
    lambda x: f"a close-up of a {x}, a large budding flower with a layered appearance.",

    # sweet william
    lambda x: f"a photo of a {x}, a cluster of small flowers in pink, red, or white.",
    lambda x: f"an image of a {x}, a garden flower known for its fringed petal edges.",
    lambda x: f"a close-up of a {x}, a fragrant bloom often used in cottage gardens.",

    # carnation
    lambda x: f"a photo of a {x}, a ruffled flower commonly seen in bouquets.",
    lambda x: f"an image of a {x}, a traditional bloom with a clove-like scent.",
    lambda x: f"a close-up of a {x}, known for its layered petals and vibrant color range.",

    # garden phlox
    lambda x: f"a photo of a {x}, a tall plant with clusters of pink or purple blooms.",
    lambda x: f"an image of a {x}, a perennial flower found in cottage-style gardens.",
    lambda x: f"a close-up of a {x}, a phlox with star-shaped petals growing in dense bunches.",

    # love in the mist
    lambda x: f"a photo of a {x}, a blue flower surrounded by fine, feathery foliage.",
    lambda x: f"an image of a {x}, a delicate flower with a misty, lace-like background.",
    lambda x: f"a close-up of a {x}, a whimsical bloom with soft, threadlike leaves.",

    # mexican aster
    lambda x: f"a photo of a {x}, a daisy-like flower with bright pink or purple petals.",
    lambda x: f"an image of a {x}, a tall annual bloom often seen in wildflower fields.",
    lambda x: f"a close-up of a {x}, a lightweight flower with yellow centers and soft petals.",

    # alpine sea holly
    lambda x: f"a photo of a {x}, a spiky blue flower with thistle-like bracts.",
    lambda x: f"an image of a {x}, a unique alpine plant with metallic-colored petals.",
    lambda x: f"a close-up of a {x}, a flower with a cone center and pointed star-shaped sepals.",

    # ruby-lipped cattleya
    lambda x: f"a photo of a {x}, an orchid with bold purple lips and pastel petals.",
    lambda x: f"an image of a {x}, a showy flower with ruby-colored accents on its lip.",
    lambda x: f"a close-up of a {x}, a fragrant orchid with elaborate ruffled petals.",

    # cape flower
    lambda x: f"a photo of a {x}, a brightly colored South African flower with daisy form.",
    lambda x: f"an image of a {x}, known for its vivid hues and sun-tracking behavior.",
    lambda x: f"a close-up of a {x}, a vibrant flower with a dark center and radiating petals.",

    # great masterwort
    lambda x: f"a photo of a {x}, a flower with a central cluster surrounded by papery bracts.",
    lambda x: f"an image of a {x}, a perennial bloom with intricate star-like heads.",
    lambda x: f"a close-up of a {x}, a soft-colored flower with detailed florets in the center.",

    # siam tulip
    lambda x: f"a photo of a {x}, a tropical flower with pink petals and green bracts.",
    lambda x: f"an image of a {x}, also known as Curcuma, with cone-shaped blossoms.",
    lambda x: f"a close-up of a {x}, a vibrant Thai flower with tulip-like structure.",

    # lenten rose
    lambda x: f"a photo of a {x}, a spring flower with nodding blooms and muted colors.",
    lambda x: f"an image of a {x}, a hellebore plant with leathery leaves and soft petals.",
    lambda x: f"a close-up of a {x}, a cold-hardy flower blooming in late winter to early spring.",

    # barbeton daisy
    lambda x: f"a photo of a {x}, a brightly colored daisy native to South Africa.",
    lambda x: f"an image of a {x}, known for its vivid red, orange, or pink petals.",
    lambda x: f"a close-up of a {x}, a cheerful flower with a prominent central disk.",

    # daffodil
    lambda x: f"a photo of a {x}, a trumpet-shaped flower with yellow or white petals.",
    lambda x: f"an image of a {x}, one of the first blooms of spring with a central corona.",
    lambda x: f"a close-up of a {x}, a classic bulb flower symbolizing renewal and hope.",

    # sword lily
    lambda x: f"a photo of a {x}, a tall flower with sword-like leaves and vertical blooms.",
    lambda x: f"an image of a {x}, commonly known as gladiolus with stacked florets.",
    lambda x: f"a close-up of a {x}, a showy flower in a rainbow of colors on long spikes.",

    # poinsettia
    lambda x: f"a photo of a {x}, a festive plant with red or white leaf-like bracts.",
    lambda x: f"an image of a {x}, often used in winter displays with green foliage and colorful tops.",
    lambda x: f"a close-up of a {x}, a holiday flower with bright petal-like leaves.",

    # bolero deep blue
    lambda x: f"a photo of a {x}, a compact flower with deep violet petals and ruffled texture.",
    lambda x: f"an image of a {x}, a variety of pansy known for its rich blue coloring.",
    lambda x: f"a close-up of a {x}, a velvety flower with intricate patterns and deep hues.",

    # wallflower
    lambda x: f"a photo of a {x}, a small clustered flower known for growing on walls or rocky soil.",
    lambda x: f"an image of a {x}, a plant with fragrant blooms in yellow, orange, or red.",
    lambda x: f"a close-up of a {x}, a simple flower with four-petal blossoms and warm colors.",

    # marigold
    lambda x: f"a photo of a {x}, a vibrant flower with layers of orange or yellow petals.",
    lambda x: f"an image of a {x}, known for its strong scent and decorative garden use.",
    lambda x: f"a close-up of a {x}, a sun-loving flower with round, bushy blooms.",

    # buttercup
    lambda x: f"a photo of a {x}, a shiny yellow flower with cup-shaped petals.",
    lambda x: f"an image of a {x}, a wildflower with a simple structure and glossy surface.",
    lambda x: f"a close-up of a {x}, a cheerful bloom with golden overlapping petals.",

    # oxeye daisy
    lambda x: f"a photo of a {x}, a white-petaled daisy with a yellow central disk.",
    lambda x: f"an image of a {x}, a classic meadow flower with a tall, slender stem.",
    lambda x: f"a close-up of a {x}, a widespread wildflower resembling a common daisy.",

    # common dandelion
    lambda x: f"a photo of a {x}, a yellow flower with toothed leaves and fluffy seed heads.",
    lambda x: f"an image of a {x}, a weedy plant known for its puffball seed dispersal.",
    lambda x: f"a close-up of a {x}, a golden flower head made of many tiny florets.",

    # petunia
    lambda x: f"a photo of a {x}, a funnel-shaped flower often used in hanging baskets.",
    lambda x: f"an image of a {x}, a colorful bloom available in many vibrant shades.",
    lambda x: f"a close-up of a {x}, a soft, velvety flower with a wide trumpet-like shape.",

]

### Text feature loading function

This function take as input a list of categories, and apply clip's text processing pipeline.
in particular:
 - It first uses the prompt templates to generate a matrix of strings (shape num_prompts x num classes)
 - Then it tokenize the matrix
 - finally it pass the tokenized prompts thought the text encoder and thus getting the text embedding

the final result is a matrix having the shape: `num_prompts` x `embedding size` x `num_classes`

In [None]:
@torch.no_grad()
def load_text_features(categories: list[CategoryLabel]):
    """
    size: num_prompts x num_categories x input_size"]
    """
    prompts_for_base_classes = general_prompt_template + class_specific_prompt_templates
    
    multiplier = len(prompts_for_base_classes) // len(general_prompt_template) + 1
    prompts_for_novel_classes = general_prompt_template * multiplier
    prompts_for_novel_classes = prompts_for_novel_classes[0:len(prompts_for_base_classes)]

    def get_prompts_for_one_class(category: CategoryLabel) -> list[str]:
        if category["novel"]:
            prompts = prompts_for_novel_classes
        else:
            prompts = prompts_for_base_classes
        name = category["name"]
        return [t(name) for t in prompts ]
    

    text_inputs = [
        clip.tokenize(
            get_prompts_for_one_class(category)
        ).to(DEVICE)
        for category in categories
    ]

    text_features_array: list[torch.Tensor] = [
        model.encode_text(x)
        for x in text_inputs
    ]

    # shape: num_classes x num_prompts x embedding_size
    text_features = torch.stack([
        x/x.norm(dim=-1,keepdim=True)
        for x in text_features_array
    ])

    # shape: num_prompts x embedding_size x num_classes
    text_features = text_features.permute(1,2,0)

    return text_features

['a photo of a Foo, a type of flower.', 'a Foo flower.', 'a photo of some Foo, a type of flower.', 'some Foo flowers.', 'a close-up of a Foo flower.', 'an image of a Foo blossom.', 'a beautiful Foo in bloom.', 'a bunch of Foo flowers.', 'a macro shot of a Foo flower.', 'a single Foo flower.', 'fresh Foo flowers in a garden.', 'a Foo flower in the wild.', 'a botanical photograph of a Foo.', 'a vibrant Foo bloom.', 'a Foo plant with flowers.', 'a Foo growing in nature.', 'a Foo flower in sunlight.', 'a colorful Foo flower close-up.', 'a Foo, commonly found in gardens.', 'wild Foo flowers blooming.', 'a garden filled with Foo flowers.', 'floral photography featuring a Foo.', 'an aesthetic photo of a Foo.', 'a Foo flower in full bloom.', 'a large blooming Foo.', 'a freshly picked Foo.', 'a wilted Foo flower.', 'a Foo with dewdrops on its petals.', 'a delicate Foo on a green stem.', 'a colorful bouquet with Foo.', 'a botanical illustration of Foo.', 'a herbarium specimen of Foo.', 'field ph

tensor([[[ 0.0075, -0.0435],
         [-0.0228, -0.0388],
         [-0.0130,  0.0512],
         ...,
         [ 0.0048,  0.0242],
         [ 0.0002,  0.0071],
         [ 0.0062,  0.0005]],

        [[ 0.0044, -0.0224],
         [-0.0175, -0.0142],
         [-0.0146,  0.0381],
         ...,
         [-0.0116, -0.0145],
         [-0.0079, -0.0091],
         [-0.0070, -0.0056]],

        [[-0.0155, -0.0438],
         [-0.0216, -0.0368],
         [-0.0066,  0.0457],
         ...,
         [ 0.0014,  0.0219],
         [ 0.0100,  0.0101],
         [ 0.0185,  0.0099]],

        ...,

        [[-0.0140, -0.0014],
         [-0.0831, -0.0115],
         [-0.0108,  0.0322],
         ...,
         [-0.0479, -0.0076],
         [-0.0432, -0.0148],
         [ 0.0068,  0.0228]],

        [[-0.0074, -0.0090],
         [-0.0310,  0.0085],
         [-0.0151,  0.0207],
         ...,
         [-0.0029, -0.0300],
         [-0.0027, -0.0009],
         [ 0.0246,  0.0055]],

        [[ 0.0266,  0.0274],
       

### Loading the dataset

In [None]:
if USE_DEFAULT_SPLIT:
    train_set, val_set, test_set = get_data(transform=preprocess)
else:
    train_set, val_set, test_set = get_data_custom_split(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

### Re-Weighter model

This is the core model at the base of our architecture.

It is a small FFNN that has the following Input/Outputs
 - **Input**: A one dimensional tensor (or two if considering batch size) that is
 the visual embedding of the image (i.e. the output of the visual encoder)
 - **Output**: A two dimensional tensor (or three if considering batch size) that is
 shaped `number of prompts` x `number of classes`

The output will then be used to re-weight the various prompts so to obtain a single score

#### Mathematical formulation

To put it in a more mathematical way, if we define the following notation:
 - $T_{ij}$: The text embedding of the class `j` generated with the prompt template `i`
 - $V$: The visual embedding of the image we are trying to classify
 - $W = Reweighter(V)$: The output of the re-weighter (shaped `number of prompts` x `number of classes` )
 - $w_{ij}$: This is the weight (i.e. estimated relative importance) that the prompt template `i` with class `j` has, in particular for the classification of the image that generated `V`
 - $s_j$: The absolute (i.e. weighted) score for the class `j`

We can then use the aforementioned notation to express the formula of $s_j$
$$
s_j =  \sum_{i=1}^{n\_prompts} w_{ij} (T_{ij} \cdot V)
$$
where $(T_{ij} \cdot V)$ is the dot product between the text and the image embedding representing the similarity


Finally we can get the classifier's output by taking the absolute score with the highest value
$$
class = argmax_j(s_j)
$$


#### A note on novel classes

At this point a question may emerge:
> If your model need to output a set of weights for each class, how will you be able to generalize to novel classes?

And the answer is quite straightforward: We add to the output of the model a set of weights that represent the "novel class"

The way they are trained is quite interesting, but we will discuss it, when it become relevant, for now you just need to
note that we have create a custom "indexing" layer at the end, that map the original model output from a dimension of `num_base_classes + 1`
to a dimension that depends on the number of classes we are trying to classify at a certain point in time.

In this example we can se that a model has being trained on 4 base classes (A,B,C and D) and than
it is adapted to be used in a context where 3 of the base classes are still relevant (A,B and D),
one is discarded (C) and 3 novel classes are added.


<img src="https://raw.githubusercontent.com/lucaSartore/CLIP-Few-shot/refs/heads/main/images/insexing_layer.png" alt="drawing" height="400"/>

This is the reason why the method `update_indexing_mask` is needed. It essentially re-configure the indexing layer to adapt the re-weighter to a context where the number of classes has changed


TODO: Add a section that talks about the parameters once they are being tuned

In [None]:
CLIP_EMBEDDING_SIZE = int(model.encode_text(clip.tokenize("foo").to(DEVICE)).shape[-1])
NUMBER_OF_PROMPTS = len(general_prompt_template) + len(class_specific_prompt_templates)


class ReWeighterModel(nn.Module):
    def __init__(self, classes_for_training: list[CategoryLabel], internal_size = 75):
        super().__init__()
        # we need to store the classes used for training in order to be able to update the indexing mask
        self.base_classes = [x for x in classes_for_training if not x["novel"]]
        self.num_base_classes = len(self.base_classes)
        # We use the update method for initialization here
        self.update_indexing_mask(classes_for_training)
        ####################### DESIGN_CHOICE ######################################
        # We observed that mapping the visual embedding in a small
        # embedding space before calculating the weights help
        # reducing overfitting, as we are "compressing" the
        # representation, and therefore reducing the model capacity.
        # We hypnotize that a low-dimensional space does not limit
        # much the performance of our model, as the task of re-weighting
        # is (at least intuitively) much simpler (and therefore lower dimensional)
        # than the task of comparing images and text
        ############################################################################
        # the first layer maps the visual embedding into a lower-dimensional space.
        self.l1 = nn.Linear(CLIP_EMBEDDING_SIZE, internal_size, dtype=model.dtype)
        # one output weight for each prompt, and for each class, plus one for novel classes
        self.l2 = nn.Linear(internal_size, NUMBER_OF_PROMPTS * (self.num_base_classes + 1), dtype=model.dtype)
        # sigmoid used to normalize the weights in the 0-1 range
        # TODO: idea, maybe allowing for negative weight can improve things?
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor):
        """
        input: [batch_size x clip_embedding_size] = the image-generated embeddings
        output: [batch_size x number_of_prompts x num_classes] = the weight estimated for each prompt-class pair
        """
        batch_size = input.shape[0]
        x: torch.Tensor = self.l1(input)
        x = self.l2(x)
        x = self.sigmoid(x)
        x = x.reshape(batch_size, NUMBER_OF_PROMPTS, self.num_base_classes+1)
        # shape: [batch_size, number_of_prompts, num_classes]
        x = x[:,:,self.indexing_mask]
        return x

    def update_indexing_mask(self, classes: list[CategoryLabel]):
        """
        Update the indexing mask.
        This function need to be used after a model has being trained for a specific set of classes,
        """

        # map the class id, to the position on on the model's last layer
        id_to_index = {x["id"]: i for i,x in enumerate(self.base_classes)}

        # given a class, it returns the index inside the model's last layer that should be used
        # to re-weight the classes's prompts.
        def get_index(label: CategoryLabel):
            # novel categories are always mapped to the generic re-weighter (the last one)
            # as we haven't learned a model for them
            if label["novel"]:
                return len(self.base_classes)
            # base classes instead have their specialized re-weighter at the corresponding index
            return id_to_index[label["id"]]

        # the indexing mask will be used to index into the output of the layer
        # and build a tensor that has the same size as the number of categories
        self.indexing_mask = torch.Tensor([
            get_index(x) for x in classes
        ]).type(torch.long)

    


### Creation of virtual novel samples
We mentioned in the previous section that the re-weighter has an output dedicated to "novel classes", but
how can we train it if we don't have access to the novel classes?

The answer is to use "Virtual novel classes" were we essentially duplicate all of the base classes and create a "novel" version of them.

So essentially, if (like in our case) the classes are 102, and classes with IDs 0 to 50 are base, and IDs 51 to 101 are novel
we create 51 new classes, and assign them IDs 102 to 152. Those new classes will have the same images, and same class names
as the base one, but will have a different ID, and they will be marked as "novel"

What this means is that all of the virtual novel classes will share a single re-weighter output (it will still be a tensor, because
there is one weight for each prompt, but it will not discriminate between the ID 102 and 103 for example).

By doing so we essentially force the re-weighter to learn a set of weight that can be used to 
discriminate all the virtual novel classes AT THE SAME TIME, and the assumption we made here is
that if a SINGLE set of weights is good at discriminating 51 different classes, the same set
will also be good at discriminating 51 new un-seen classes.

We think this assumption is reasonable if we also assume that the split of novel and base classes is not particularly biased.

In [None]:

# class names have being added to torch vision recently, but to avoid compatibility issues we added them here.
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]

# list of category labels of the base classes
base_classes_label: list[CategoryLabel] = [
    {
        "id": x,
        "name": CLASS_NAMES[x],
        "novel": False
    }
    for x in base_classes
]

# list of category labels for the novel classes
novel_classes_label: list[CategoryLabel] = [
    {
        "id": x,
        "name": CLASS_NAMES[x],
        "novel": True
    }
    for x in novel_classes
]

# list of category labels for all classes
all_classes_labels = base_classes_label + novel_classes_label

# list of category labels for base classes plus virtual-novel classes
# this is the one that will be used during training, and you can see
# that we are essentially duplicating the number of novel classes,
# keeping the name unchanged, but shifting the ID, so to avoid overlap.
base_and_virtual_novel_classes_labels = base_classes_label + [
    {
        "id": x + len(all_classes_labels),
        "name": CLASS_NAMES[x],
        "novel": False
    }
    for x in base_classes
]

T = TypeVar("T")
E = TypeVar("E")
class VirtualNovelDataset(Dataset, Generic[T,E]):
    """
    This is a custom dataset implementation, that create "virtual classes"
    In short id create "num_novel_classes" new training samples, that are
    assigned a new ID (by adding "to_add" at the original ID)
    """
    def __init__(self, dataset: Dataset[tuple[T,E]], num_novel_classes: int, to_add: E):
        self._dataset_len: int = len(dataset) #type: ignore
        self._dataset = dataset
        self._virtual_novel: list[tuple[T,E]] = []
        indexes = random.sample(range(self._dataset_len), num_novel_classes)
        for index in indexes:
            data, label = dataset[index]
            new_label: E = label + to_add #type: ignore
            self._virtual_novel.append((data, new_label))
        
    def __len__(self):
        return self._dataset_len + len(self._virtual_novel)

    def __getitem__(self, index: int) -> tuple[T,E]:
        # non virtual sample
        if index < self._dataset_len:
            return self._dataset[index]
        # virtual sample:
        else:
            return self._virtual_novel[index - self._dataset_len]
            
train_base_and_virtual_novel = VirtualNovelDataset(train_base, len(train_base), len(all_classes_labels))

### Training

This is the section where we train the re-weighter. Overall, the training is pretty standard, with only two key differences:
1. **The masking of some input prompts:** The "virtual novel" training samples, create an issue, assume we pass a picture of
a rose, through the model twice, the first time we pass it with the label "rose" and the second with the label "rose-novel".
Obviously the model can't predict both of them at the same time, and it will have to pick one, this results in the model being
wrong 50% of the times, and thus creating some instability during training.
To fix this issue we simply disabled one of the prompts during training, so for example, if I am evaluating "rose-novel" I disable
all the prompts associated with "rose" by setting the text embeddings to zero.
The exact implementation details can be seen in the code.

1. **Virtually larger batch size** We had to implement some logic (namely waiting a few batches before calling `zero_grad` and `step`)
to simulate larger batch sizes. We found that the re-weighter is extremely hard to train, if we have small batch sizes, on the other hand
we also found that the training is really memory intensive sometimes, therefore we we use the aforementioned technique to
achieve the same result as large batch sizes, without the same memory requirements.


In [None]:
def train(
        clip: clip.model.CLIP,
        weighter: ReWeighterModel ,
        dataset: Dataset[tuple[torch.Tensor, int]],
        categories: list[CategoryLabel],
        batch_size: int,
        num_steps: int,
        device: torch.device | str,
        batch_size_multiplier: int = 1,
        
    ):
    clip.eval()
    weighter.train()

    # We noted that the training is a bit unstable some times.
    # the loss start at 1.5, and gradually go down reaching 0.8, 
    # but then it jumps to 12 or something. We haven't figured out
    # the reason of this instability yet, but we "fixed" it by
    # saving the best model, and returning that. So in the worst case
    # scenario we can at least return this.
    best_model: ReWeighterModel | None = None
    best_loss: float = math.inf

    # Todo: Add some interesting comments on why we chose the optimizer/parameters
    # that we did, once we are done with the hyperparameters optimization
    # optimizer = torch.optim.AdamW(params = weighter.parameters())
    optimizer = torch.optim.SGD(params = weighter.parameters(), lr=5)

    # Remap labels into a contiguous set starting from zero
    # this is the same as what was done in the zero-shot example notebook:
    #   > contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
    # but is more efficient later on, as the indexing can be done natively
    # inside tensor, without using python structures
    contig_cat2idx = torch.zeros(1+max(cat["id"] for cat in categories)).long()
    for idx, cat in enumerate(categories):
        contig_cat2idx[cat["id"]] = idx
    contig_cat2idx.to(DEVICE)
    
    # loading the text features
    # size: num_prompts x clip_embedding_size x num_categories"]
    text_features = load_text_features(categories)
    print(text_features.shape)

    # these constants are useful when re-shaping some tensors later
    num_prompts = text_features.shape[0]
    num_classes = text_features.shape[2]

    # simple dataloader creation
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # definition of the loss
    loss_fn = nn.CrossEntropyLoss()

    progress_bar = tqdm(range(num_steps), "Steps") 
    for _ in progress_bar:
        image: torch.Tensor
        target: torch.Tensor
        losses: list[float] = []
        for i, (image, target) in enumerate(dataloader):


            # Converting the class indexes to contiguous indexes
            # Equivalent line, in the zero-shot notebook:
            #   > target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()
            # shape: [batch_size]
            target = contig_cat2idx[target]

            target = target.to(DEVICE)
            image = image.to(DEVICE)

            # calculating the image features
            with torch.no_grad():
                image_features: torch.Tensor = clip.encode_image(image)
                image_features /= image_features.norm(dim=-1, keepdim=True)


            # We need to be able to mask out some text features (to avoid issues
            # with virtual novel classes, as explained in markdown cell)
            # Do do so, we need to expand the text features, and add a dimension (batch size)
            # this is because every single element in a batch will have a different target
            # (and therefore a different input feature to mask)
            # shape: [batch_size x num_prompts x clip_embedding_size x num_classes]
            masked_text_features = text_features \
                .clone() \
                .unsqueeze(0) \
                .expand(
                    len(target), # batch size
                    num_prompts,
                    CLIP_EMBEDDING_SIZE,
                    num_classes
                ).clone()

            # the following 15 lines are used to mask the text feature in the correct place,
            # they are a bit hard to understand, however, here you can find a equivalent snippet of code
            # that is much more readable, even tho it is less efficient computationally
            # Note: 51 is the number of base classes
            # >     for i in range(51):
            # >         masked_text_features[target == i, :, :, i+51] = 0
            # >         masked_text_features[target == i+51, :, :, i] = 0
                

            # vector tell me which class i need to zero in each element of a batch size
            #shape: [batch_size]
            class_to_zero = (target + len(categories) / 2) % len(categories)
            class_to_zero = class_to_zero.long()
            #shape: [batch_size x num_classes]
            # I can use the one_hot notation to obtain exactly the boolean mask I need below
            class_to_zero = torch.nn.functional.one_hot(class_to_zero, num_classes).bool()

            # I need to invert the axis order so that the boolean mask can be used correctly
            # shape: [batch_size x num_classes x num_prompts x clip_embedding_size]
            masked_text_features = masked_text_features.permute(0,3,2,1)
            masked_text_features[class_to_zero,:,:] = 0
            # shape: [batch_size x num_prompts x clip_embedding_size x num_classes]
            # set teh correct order back
            masked_text_features = masked_text_features.permute(0,3,2,1)

            # computing the similarity scores (using the masked text features)
            # b = batch, e = embedding, p = prompts, c = classes
            # shape: [batch_size x num_prompts x num_classes]
            scores = torch.einsum("be,bpec -> bpc", image_features, masked_text_features.detach())

            # computing the weights
            # shape: [batch_size x num_prompts x num_classes]
            weights: torch.Tensor = weighter(image_features)
    

            # reweighing scores
            # scores = (weights * scores.permute(2,0,1)).permute(1,2,0)
            scores *= weights

            # summing up the scores of every different prompt
            # shape: [batch_size x num_classes]
            out = torch.sum(scores, dim=1)

            # calculating the loss, and updating the gradient
            loss: torch.Tensor = loss_fn(out, target)
            losses.append(loss.item())
            # this is to avoid having different losses (and therefore
            # different behaver of the optimizer) when the batch
            # size multiplier changes
            loss /= batch_size_multiplier
            loss.backward()
            if (i+1) % batch_size_multiplier == 0:
                # todo: add a note on gradient clipping when the hyperparameters are fully optimized
                # (we may even remove it, but for now it seem to work fine...)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)  
                optimizer.step()
                optimizer.zero_grad()

        # repeat here, to avoid losing some information if the number
        # of batches is not divisible by batch_size_multiplier
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)  
        optimizer.step()
        optimizer.zero_grad()

        # printing the current loss in the progress bar
        final_loss = np.average(losses)
        progress_bar.set_postfix(loss=f"{final_loss:.4f}")

        # saving the best model to cpu
        if final_loss < best_loss:
            best_model = copy.deepcopy(weighter).to("cpu")

    # returning the best model, according to the loss evaluation.
    assert best_model is not None
    return best_model.to(device)


weighter = ReWeighterModel(base_and_virtual_novel_classes_labels, 75).to(DEVICE)


weighter = train(
    model,
    weighter,
    train_base_and_virtual_novel,
    base_and_virtual_novel_classes_labels,
    32,
    30,
    DEVICE,
    4
)

['a photo of a pink primrose, a type of flower.', 'a pink primrose flower.', 'a photo of some pink primrose, a type of flower.', 'some pink primrose flowers.', 'a close-up of a pink primrose flower.', 'an image of a pink primrose blossom.', 'a beautiful pink primrose in bloom.', 'a bunch of pink primrose flowers.', 'a macro shot of a pink primrose flower.', 'a single pink primrose flower.', 'fresh pink primrose flowers in a garden.', 'a pink primrose flower in the wild.', 'a botanical photograph of a pink primrose.', 'a vibrant pink primrose bloom.', 'a pink primrose plant with flowers.', 'a pink primrose growing in nature.', 'a pink primrose flower in sunlight.', 'a colorful pink primrose flower close-up.', 'a pink primrose, commonly found in gardens.', 'wild pink primrose flowers blooming.', 'a garden filled with pink primrose flowers.', 'floral photography featuring a pink primrose.', 'an aesthetic photo of a pink primrose.', 'a pink primrose flower in full bloom.', 'a large bloomin

KeyboardInterrupt: 

### Evaluation

In this section we evaluate teh model

In [None]:

@torch.no_grad()
def eval(
        clip: clip.model.CLIP,
        weighter: ReWeighterModel ,
        dataset: Dataset[tuple[torch.Tensor, int]],
        categories: list[CategoryLabel],
        batch_size: int,
        device: torch.device | str,
        label = ""
    ):
    # let's set the model in evaluation mode
    clip.eval()
    weighter.eval()

    # Remap labels into a contiguous set starting from zero
    contig_cat2idx = {cat["id"]: idx for idx, cat in enumerate(categories)}

    text_features = load_text_features(categories)

    # simple dataloader creation
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # here we store the number of correct predictions we will make
    correct_predictions = 0
    select_base_instead_of_novel = 0
    select_novel_instead_of_base = 0
    total_wrong_predictions = 0
    wrong_prediction_in_the_same_category = 0
    for image, target in tqdm(dataloader, desc=label):
        # base categories range from 0 to 50, whil novel ones from 51 to 101
        # therefore we must map categories to the [0, 50], otherwise we will have wrong predictions
        # Map targets in contiguous set starting from zero
        # Labels needs to be .long() in pytorch
        target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

        image = image.to(device)
        target = target.to(device)

        image_features = clip.encode_image(image)
        # and normalize
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # shape: [batch_size x num_prompts x num_classes]
        scores: torch.Tensor = torch.matmul(image_features, text_features).permute(1,0,2).detach()

        # shape: [ batch_size x (num_prompts x num_classes)]
        weights: torch.Tensor = weighter(image_features.detach())

        # reweighing scores
        # scores = (weights * scores.permute(2,0,1)).permute(1,2,0)
        scores *= weights

        out = torch.sum(scores, dim=1)
        predicted_class = out.argmax(dim=-1)

        select_base_instead_of_novel += ((predicted_class < 51) & (target >= 51)).sum().item()
        select_novel_instead_of_base += ((predicted_class >= 51) & (target < 51)).sum().item()
        total_wrong_predictions += (predicted_class != target).sum().item()
        wrong_prediction_in_the_same_category += (
            (predicted_class != target) &
            (
                (predicted_class < 51) & (target < 51) |
                (predicted_class >= 51) & (target >= 51) 
            )
        ).sum().item()

        (predicted_class != target).sum().item()
        # now we check which are correct, and sum them (False == 0, True == 1)
        correct_predictions += (predicted_class == target).sum().item()


    print("select_base_instead_of_novel: ", select_base_instead_of_novel)
    print("select_novel_instead_of_base: ", select_novel_instead_of_base)
    print("total_wrong_predictions: ", total_wrong_predictions)
    print("wrong_prediction_in_the_same_category: ", wrong_prediction_in_the_same_category)
    # and now we compute the accuracy
    accuracy = correct_predictions / len(dataset) #type: ignore
    return accuracy

weighter.update_indexing_mask(base_classes_label)
base_accuracy = eval(model, weighter, dataset=test_base, categories=base_classes_label, batch_size=32, device=DEVICE, label="🧠 Evaluation on Base Classes")

weighter.update_indexing_mask(novel_classes_label)
novel_accuracy = eval(model, weighter, dataset=test_novel, categories=novel_classes_label, batch_size=32, device=DEVICE, label="🧠 Evaluation on Novel Classes")

weighter.update_indexing_mask(all_classes_labels)
total_accuracy = eval(model, weighter, dataset=test_set, categories=all_classes_labels, batch_size=32, device=DEVICE, label="🧠 Evaluation on all Classes")

print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
print(f"🔍 Novel and base classes accuracy: {total_accuracy*100:.2f}%")


🧠 Evaluation on Base Classes: 100%|██████████| 84/84 [00:13<00:00,  6.35it/s]


select_base_instead_of_novel:  0
select_novel_instead_of_base:  0
total_wrong_predictions:  105
wrong_prediction_in_the_same_category:  105


🧠 Evaluation on Novel Classes: 100%|██████████| 122/122 [00:18<00:00,  6.75it/s]


select_base_instead_of_novel:  0
select_novel_instead_of_base:  0
total_wrong_predictions:  823
wrong_prediction_in_the_same_category:  823


🧠 Evaluation on all Classes: 100%|██████████| 205/205 [00:22<00:00,  9.10it/s]

select_base_instead_of_novel:  1636
select_novel_instead_of_base:  0
total_wrong_predictions:  1929
wrong_prediction_in_the_same_category:  293
🔍 Base classes accuracy: 96.08%
🔍 Novel classes accuracy: 78.79%
🔍 Novel and base classes accuracy: 70.58%





## Harmonic Mean
Few-Shot Adaptations papers usually report the Harmonic Mean.
The harmonic mean tends to mitigate the impact of large outliers (base accuracy) and aggravate the impact of small ones (novel accuracy).
Thus, achieving very high base accuracies at the expense of the novel accuracy will be penalized by the HM.

In [None]:
def harmonic_mean(base_accuracy, novel_accuracy):
    numerator = 2
    denominator = 1 / base_accuracy + 1 / novel_accuracy
    hm = numerator / denominator
    return hm

print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

🔍 Harmonic Mean: 86.58%
