# CLIP zero-shot Evaluation
This short notebook implements the dataset split into base and novel categories (see the project assignment) and runs the zero-shot evaluation with CLIP.
Feel free to copy the code contained in this notebook, or to use the notebook directly as a starting point for your project.

In [2]:
# CLIP is not pre-installed, so we need to install it
# you are also free to use open_clip, which provides more models
# https://github.com/mlfoundations/open_clip
%pip install openai_clip

In [3]:
import os

import torch
import torchvision
import clip
from tqdm import tqdm
from typing import TypeVar, cast, TypedDict
from torch.utils.data import Dataset



class CategoryLabel(TypedDict):
    id: int
    name: str
    novel: bool

## Dataset Loading
Let's get the data directly from torchvision as we have seen during labs.

In [4]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (Callable, optional): Transform applied to every image (e.g. CLIP's `preprocess`).
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return (
        cast(Dataset[tuple[torch.Tensor,int]], train),
        cast(Dataset[tuple[torch.Tensor,int]], val),
        cast(Dataset[tuple[torch.Tensor,int]], test)
    )

## Base and Novel categories
To split the classes into base and novel categories, we list all dataset classes and count them (we already know there are 102, but let's do it properly).
Then, we assign the first half to the base categories and the remaining half to the novel ones.
We can do this because we are only simulating a real-world application; keep in mind that such a clean split will not be available out there!

In [5]:
T = TypeVar("T")
E = TypeVar("E")
def base_novel_categories(dataset: Dataset[tuple[T,E]]):
    # set returns the unique set of all dataset classes
    all_classes = set(l for _, l in dataset)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes
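
The halving rule can be checked on a toy label list without loading any image. The sketch below is illustrative only: `split_class_ids` is a hypothetical helper that is not part of the notebook, and it assumes (as the function above does) that class ids are contiguous integers starting at 0.

```python
def split_class_ids(labels):
    """Split contiguous class ids 0..C-1 into a base half and a novel half."""
    num_classes = len(set(labels))
    class_ids = list(range(num_classes))
    half = num_classes // 2
    return class_ids[:half], class_ids[half:]


toy_labels = [0, 3, 1, 2, 5, 4, 0, 2]  # six distinct classes
base_ids, novel_ids = split_class_ids(toy_labels)
print(base_ids)   # [0, 1, 2]
print(novel_ids)  # [3, 4, 5]
```

With an odd number of classes, the extra class ends up in the novel half, because the base half is cut at `num_classes // 2`.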

## Inspect Classes
Let's now look at which classes are base and which are novel.
To do so, we first get a dummy test set (without augmentations), as we are only interested in the dataset labels. Then, we split it using `base_novel_categories`.
Finally, we use the hard-coded `CLASS_NAMES` to print the class names in natural language.

> Note: the list of class names was only recently added to `torchvision.datasets.Flowers102`. To spare you errors with older versions, we decided to also provide the list here.

In [6]:
base_classes, novel_classes = base_novel_categories(get_data()[2])

CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
NAMES_TO_INDEX = {x: i for i,x in enumerate(CLASS_NAMES)}
print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

Base Class Names: [(0, 'pink primrose'), (1, 'hard-leaved pocket orchid'), (2, 'canterbury bells'), (3, 'sweet pea'), (4, 'english marigold'), (5, 'tiger lily'), (6, 'moon orchid'), (7, 'bird of paradise'), (8, 'monkshood'), (9, 'globe thistle'), (10, 'snapdragon'), (11, "colt's foot"), (12, 'king protea'), (13, 'spear thistle'), (14, 'yellow iris'), (15, 'globe-flower'), (16, 'purple coneflower'), (17, 'peruvian lily'), (18, 'balloon flower'), (19, 'giant white arum lily'), (20, 'fire lily'), (21, 'pincushion flower'), (22, 'fritillary'), (23, 'red ginger'), (24, 'grape hyacinth'), (25, 'corn poppy'), (26, 'prince of wales feathers'), (27, 'stemless gentian'), (28, 'artichoke'), (29, 'sweet william'), (30, 'carnation'), (31, 'garden phlox'), (32, 'love in the mist'), (33, 'mexican aster'), (34, 'alpine sea holly'), (35, 'ruby-lipped cattleya'), (36, 'cape flower'), (37, 'great masterwort'), (38, 'siam tulip'), (39, 'lenten rose'), (40, 'barbeton daisy'), (41, 'daffodil'), (42, 'sword 

## Split Dataset
The next step is to actually split the dataset into the base and novel categories returned by `base_novel_categories`.
To split the data we need the dataset (obviously) and the list of base classes. If a sample's label is not among the base categories, then it must belong to the novel ones.

In [7]:
from torch.utils.data import Dataset, Subset

T = TypeVar("T")
E = TypeVar("E")
def split_data(dataset: Dataset[tuple[T,E]], base_classes: list[int]):
    # these two lists will store the sample indexes
    base_categories_samples = []
    novel_categories_samples = []

    # we create a set of base classes to compute the test below in O(1)
    # this is optional and can be removed
    base_set = set(base_classes)

    # here we iterate over the sample labels and also get the corresponding sample index
    labels = [l for _,l in dataset]
    for sample_id, label in enumerate(labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    # here we create the dataset subsets
    # the torch Subset is just a wrapper around the dataset
    # it simply stores the subset indexes and the original dataset (your_subset.dataset)
    # when asking for sample i in the subset, torch will look for its original position in the dataset and retrieve it
    # https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
    base_dataset = Subset(dataset, base_categories_samples)
    novel_dataset = Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset
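
To see how `Subset` maps its own indexes back to the original dataset, here is a minimal sketch on a toy `TensorDataset`. The data and the choice of "base" classes are made up for illustration; only `Subset` and `TensorDataset` come from PyTorch.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Toy dataset: 6 samples with a single feature each, labels 0..2.
features = torch.arange(6, dtype=torch.float32).unsqueeze(1)
labels = torch.tensor([0, 1, 2, 0, 1, 2])
toy = TensorDataset(features, labels)

base_set = {0, 1}  # pretend classes 0 and 1 are the base classes
base_idx = [i for i, y in enumerate(labels.tolist()) if y in base_set]
novel_idx = [i for i, y in enumerate(labels.tolist()) if y not in base_set]

base_subset = Subset(toy, base_idx)
novel_subset = Subset(toy, novel_idx)

# Subset index 2 maps back to position base_idx[2] of the original dataset.
x, y = base_subset[2]
print(len(base_subset), len(novel_subset))  # 4 2
print(base_idx[2], x.item(), y.item())      # 3 3.0 0
```

No sample is copied: both subsets keep a reference to `toy` and only store their index lists.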

## Extract k shots
As the dataset already provides 10 training and 10 validation shots per class, we do not need to extract them.
Be aware that few-shot adaptation papers must perform this step, as most datasets contain significantly more samples in both the training and validation sets.
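
For datasets that do not come pre-split into shots, the extraction could look roughly like the sketch below. `sample_k_shots` is a hypothetical helper (not part of this notebook) that works on a plain list of labels and returns the indexes of at most `k` samples per class.

```python
import random
from collections import defaultdict


def sample_k_shots(labels, k, seed=0):
    """Return sorted sample indexes holding at most k samples per class."""
    rng = random.Random(seed)
    per_class = defaultdict(list)
    for idx, label in enumerate(labels):
        per_class[label].append(idx)

    chosen = []
    for label in sorted(per_class):
        indexes = per_class[label]
        # keep everything if the class has fewer than k samples
        chosen.extend(rng.sample(indexes, min(k, len(indexes))))
    return sorted(chosen)


toy_labels = [0, 0, 0, 1, 1, 2]
shots = sample_k_shots(toy_labels, k=2)
print(len(shots))  # 5: two from class 0, two from class 1, one from class 2
```

The returned indexes can then be wrapped in a `torch.utils.data.Subset`, exactly as `split_data` does for the base and novel splits.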

## Load CLIP

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# available models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
model, preprocess = clip.load("ViT-B/16", device=device)

## Load and Prepare Data
Here we get the three dataset splits and pass CLIP's predefined preprocessing as the transform.
Then, we compute the base and novel categories (redundant in this case, as we already did it before).
Finally, we split the three datasets into base and novel categories.
As we want to use the novel categories only for the test set, we drop `train_novel` and `val_novel`.

In [9]:
from typing import Callable

# defining the templates that we are going to use
prompt_template: list[Callable[[str], str]] = [
    # deliberately wrong prompt, to check whether it gets discarded
    lambda x: f"a photo of a {CLASS_NAMES[(NAMES_TO_INDEX[x]+1)%5]}, a type of flower.",
    lambda x: f"a photo of a {x}, a type of flower.",
    lambda x: f"a {x} flower.",
    lambda x: f"a photo of some {x}, a type of flower.",
    lambda x: f"some {x} flowers.",
    lambda x: f"a close-up of a {x} flower.",
    lambda x: f"an image of a {x} blossom.",
    lambda x: f"a beautiful {x} in bloom.",
    lambda x: f"a bunch of {x} flowers.",
    lambda x: f"a macro shot of a {x} flower.",
    lambda x: f"a single {x} flower.",
    lambda x: f"fresh {x} flowers in a garden.",
    lambda x: f"a {x} flower in the wild.",
    lambda x: f"a botanical photograph of a {x}.",
    lambda x: f"a vibrant {x} bloom.",
    lambda x: f"a {x} plant with flowers.",
    lambda x: f"a {x} growing in nature.",
    lambda x: f"a {x} flower in sunlight.",
    lambda x: f"a colorful {x} flower close-up.",
    lambda x: f"a {x}, commonly found in gardens.",
    lambda x: f"wild {x} flowers blooming.",
    lambda x: f"a garden filled with {x} flowers.",
    lambda x: f"floral photography featuring a {x}.",
    lambda x: f"an aesthetic photo of a {x}.",
    lambda x: f"a {x} flower in full bloom.",

    # Descriptive
    lambda x: f"a large blooming {x}.",
    lambda x: f"a freshly picked {x}.",
    lambda x: f"a wilted {x} flower.",
    lambda x: f"a {x} with dewdrops on its petals.",
    lambda x: f"a delicate {x} on a green stem.",
    lambda x: f"a colorful bouquet with {x}.",

    # Scientific-ish
    lambda x: f"a botanical illustration of {x}.",
    lambda x: f"a herbarium specimen of {x}.",
    lambda x: f"field photo of {x} species.",
    lambda x: f"{x} photographed for a flora study.",
    lambda x: f"a study sample of the {x} flower.",
    lambda x: f"{x} genus flower in bloom.",

    # Casual / Internet Style
    lambda x: f"my favorite flower: the {x}.",
    lambda x: f"saw a {x} today!",
    lambda x: f"check out this {x} flower!",
    lambda x: f"flowers like {x} are amazing.",
    lambda x: f"the {x} is blooming this season.",

    # Photographic / Artistic
    lambda x: f"an artistic photo of a {x}.",
    lambda x: f"film photo of a {x} flower.",
    lambda x: f"a {x} in black and white.",
    lambda x: f"the silhouette of a {x} in sunset light.",
    lambda x: f"a {x} flower in a vintage vase.",
    lambda x: f"an abstract painting of a {x}.",
    lambda x: f"macro photography of a {x} blossom.",

    # Poetic or Metaphorical
    lambda x: f"a {x}, soft as a whisper.",
    lambda x: f"a {x} dancing in the wind.",
    lambda x: f"petals of the {x}, kissed by rain.",
    lambda x: f"a lonely {x} on a quiet morning.",
    lambda x: f"a {x} symbolizing peace and beauty.",
    lambda x: f"like a {x} in springtime.",

    # Contextual / Scene-based
    lambda x: f"a {x} in a wildflower meadow.",
    lambda x: f"a {x} flower on a wedding table.",
    lambda x: f"{x} flowers in a forest clearing.",
    lambda x: f"a {x} growing beside a stone path.",
    lambda x: f"{x} blossoms in a city garden.",
    lambda x: f"{x} petals scattered on the ground.",

    # class specific templates
    # pink primrose
    lambda x: f"a photo of a {x}, a delicate flower with soft pink petals.",
    lambda x: f"an image of a {x}, a blooming plant often found in springtime gardens.",
    lambda x: f"a close-up of a {x}, known for its pale pink blossoms and gentle appearance.",

    # hard-leaved pocket orchid
    lambda x: f"a photo of a {x}, a tropical orchid with stiff, glossy leaves.",
    lambda x: f"an image of a {x}, an exotic flower with waxy petals and leathery foliage.",
    lambda x: f"a botanical image of a {x}, an orchid species with hard, durable leaves.",

    # canterbury bells
    lambda x: f"a photo of a {x}, a bell-shaped flower in shades of purple and blue.",
    lambda x: f"an image of a {x}, known for its tall spikes of bell-like blossoms.",
    lambda x: f"a close-up of a {x}, a cottage garden flower with cup-shaped blooms.",

    # sweet pea
    lambda x: f"a photo of a {x}, a fragrant flower with delicate, ruffled petals.",
    lambda x: f"an image of a {x}, often grown for its pastel colors and pleasant scent.",
    lambda x: f"a close-up of a {x}, a climbing plant with butterfly-shaped flowers.",

    # english marigold
    lambda x: f"a photo of a {x}, a bright orange or yellow flower with daisy-like blooms.",
    lambda x: f"an image of a {x}, known for its healing properties and sunny appearance.",
    lambda x: f"a close-up of a {x}, a calendula flower common in herb gardens.",

    # tiger lily
    lambda x: f"a photo of a {x}, an orange flower with dark spots and recurved petals.",
    lambda x: f"an image of a {x}, a wild-looking lily with dramatic coloring.",
    lambda x: f"a close-up of a {x}, known for its bold, tiger-striped blooms.",

    # moon orchid
    lambda x: f"a photo of a {x}, an elegant white orchid with a moon-like glow.",
    lambda x: f"an image of a {x}, a phalaenopsis flower often found in tropical climates.",
    lambda x: f"a close-up of a {x}, a soft and symmetrical flower with wide petals.",

    # bird of paradise
    lambda x: f"a photo of a {x}, a tropical flower resembling a colorful bird.",
    lambda x: f"an image of a {x}, known for its bright orange and blue petals.",
    lambda x: f"a close-up of a {x}, an exotic bloom that looks like a flying bird.",

    # monkshood
    lambda x: f"a photo of a {x}, a hooded purple flower with toxic properties.",
    lambda x: f"an image of a {x}, often called wolfsbane, with dark violet petals.",
    lambda x: f"a close-up of a {x}, a tall plant with helmet-shaped blooms.",

    # globe thistle
    lambda x: f"a photo of a {x}, a spherical flower with spiky blue petals.",
    lambda x: f"an image of a {x}, known for its round shape and thistle-like texture.",
    lambda x: f"a close-up of a {x}, a unique ornamental flower with a metallic hue.",

    # snapdragon
    lambda x: f"a photo of a {x}, a colorful flower that resembles a dragon's mouth.",
    lambda x: f"an image of a {x}, known for its vertical clusters of blooming petals.",
    lambda x: f"a close-up of a {x}, a common garden flower with hinged, snout-like blooms.",

    # colt's foot
    lambda x: f"a photo of a {x}, a yellow wildflower that appears before its leaves.",
    lambda x: f"an image of a {x}, a small flower resembling a dandelion in early spring.",
    lambda x: f"a close-up of a {x}, a plant with hoof-shaped leaves and bright yellow blooms.",

    # king protea
    lambda x: f"a photo of a {x}, a large flower with a spiky crown-like appearance.",
    lambda x: f"an image of a {x}, a South African bloom with a central cone and pink petals.",
    lambda x: f"a close-up of a {x}, a striking flower often used in bold arrangements.",

    # spear thistle
    lambda x: f"a photo of a {x}, a spiny plant with purple tufted blooms.",
    lambda x: f"an image of a {x}, known for its sharp leaves and thistle head.",
    lambda x: f"a close-up of a {x}, a wildflower with prickly stems and a vibrant purple flower.",

    # yellow iris
    lambda x: f"a photo of a {x}, a bright yellow iris with upright petals.",
    lambda x: f"an image of a {x}, commonly found near water, with sword-shaped leaves.",
    lambda x: f"a close-up of a {x}, an elegant flower with golden hues and frilled edges.",

    # globe-flower
    lambda x: f"a photo of a {x}, a round yellow bloom resembling a buttercup.",
    lambda x: f"an image of a {x}, a spherical flower found in alpine meadows.",
    lambda x: f"a close-up of a {x}, a glowing, globe-shaped flower with dense petals.",

    # purple coneflower
    lambda x: f"a photo of a {x}, a daisy-like flower with purple petals and a spiky cone.",
    lambda x: f"an image of a {x}, often used in herbal remedies and garden borders.",
    lambda x: f"a close-up of a {x}, known for its downward-sloping petals and orange center.",

    # peruvian lily
    lambda x: f"a photo of a {x}, a spotted flower with multiple colorful petals.",
    lambda x: f"an image of a {x}, a long-lasting bloom used in cut flower arrangements.",
    lambda x: f"a close-up of a {x}, a lily-like flower with striped inner petals.",

    # balloon flower
    lambda x: f"a photo of a {x}, a flower bud that inflates like a balloon before opening.",
    lambda x: f"an image of a {x}, a star-shaped bloom in shades of blue or purple.",
    lambda x: f"a close-up of a {x}, a unique flower with puffy unopened buds.",

    # giant white arum lily
    lambda x: f"a photo of a {x}, a large white flower with a trumpet-like shape.",
    lambda x: f"an image of a {x}, also known as a calla lily, with a central yellow spadix.",
    lambda x: f"a close-up of a {x}, an elegant flower with smooth white petals.",

    # fire lily
    lambda x: f"a photo of a {x}, a bright red or orange lily with curled petals.",
    lambda x: f"an image of a {x}, a dramatic flower known for its flame-like appearance.",
    lambda x: f"a close-up of a {x}, a fiery-looking lily with backward-bending petals.",

    # pincushion flower
    lambda x: f"a photo of a {x}, a flower with a domed center and delicate fringe.",
    lambda x: f"an image of a {x}, often purple or lavender, resembling a pin-filled cushion.",
    lambda x: f"a close-up of a {x}, known for its central disk and lace-like petals.",

    # fritillary
    lambda x: f"a photo of a {x}, a checkered bell-shaped flower often in purple tones.",
    lambda x: f"an image of a {x}, a rare flower with a distinctive petal pattern.",
    lambda x: f"a close-up of a {x}, a delicate wildflower with hanging, nodding blooms.",

    # red ginger
    lambda x: f"a photo of a {x}, a tropical flower with bright red bracts.",
    lambda x: f"an image of a {x}, known for its bold color and upright flower spikes.",
    lambda x: f"a close-up of a {x}, a striking plant native to rainforests.",

    # grape hyacinth
    lambda x: f"a photo of a {x}, a small bulbous plant with clusters of blue-purple flowers.",
    lambda x: f"an image of a {x}, resembling tiny grapes arranged on a spike.",
    lambda x: f"a close-up of a {x}, a spring flower with densely packed florets.",

    # corn poppy
    lambda x: f"a photo of a {x}, a bright red flower with papery petals and a dark center.",
    lambda x: f"an image of a {x}, a wild poppy often found in meadows and fields.",
    lambda x: f"a close-up of a {x}, a flower symbolizing remembrance and resilience.",

    # prince of wales feathers
    lambda x: f"a photo of a {x}, a spiky flower head resembling a feathery plume.",
    lambda x: f"an image of a {x}, known for its upright purple-pink floral spikes.",
    lambda x: f"a close-up of a {x}, a member of the amaranth family with plume-like blooms.",

    # stemless gentian
    lambda x: f"a photo of a {x}, a vivid blue flower growing close to the ground.",
    lambda x: f"an image of a {x}, a low-growing gentian with trumpet-shaped petals.",
    lambda x: f"a close-up of a {x}, a mountain flower with intensely blue blossoms.",

    # artichoke
    lambda x: f"a photo of a {x}, a thistle-like plant with edible buds and purple flowers.",
    lambda x: f"an image of a {x}, a spiky flower head that blooms into a vibrant violet.",
    lambda x: f"a close-up of a {x}, a large budding flower with a layered appearance.",

    # sweet william
    lambda x: f"a photo of a {x}, a cluster of small flowers in pink, red, or white.",
    lambda x: f"an image of a {x}, a garden flower known for its fringed petal edges.",
    lambda x: f"a close-up of a {x}, a fragrant bloom often used in cottage gardens.",

    # carnation
    lambda x: f"a photo of a {x}, a ruffled flower commonly seen in bouquets.",
    lambda x: f"an image of a {x}, a traditional bloom with a clove-like scent.",
    lambda x: f"a close-up of a {x}, known for its layered petals and vibrant color range.",

    # garden phlox
    lambda x: f"a photo of a {x}, a tall plant with clusters of pink or purple blooms.",
    lambda x: f"an image of a {x}, a perennial flower found in cottage-style gardens.",
    lambda x: f"a close-up of a {x}, a phlox with star-shaped petals growing in dense bunches.",

    # love in the mist
    lambda x: f"a photo of a {x}, a blue flower surrounded by fine, feathery foliage.",
    lambda x: f"an image of a {x}, a delicate flower with a misty, lace-like background.",
    lambda x: f"a close-up of a {x}, a whimsical bloom with soft, threadlike leaves.",

    # mexican aster
    lambda x: f"a photo of a {x}, a daisy-like flower with bright pink or purple petals.",
    lambda x: f"an image of a {x}, a tall annual bloom often seen in wildflower fields.",
    lambda x: f"a close-up of a {x}, a lightweight flower with yellow centers and soft petals.",

    # alpine sea holly
    lambda x: f"a photo of a {x}, a spiky blue flower with thistle-like bracts.",
    lambda x: f"an image of a {x}, a unique alpine plant with metallic-colored petals.",
    lambda x: f"a close-up of a {x}, a flower with a cone center and pointed star-shaped sepals.",

    # ruby-lipped cattleya
    lambda x: f"a photo of a {x}, an orchid with bold purple lips and pastel petals.",
    lambda x: f"an image of a {x}, a showy flower with ruby-colored accents on its lip.",
    lambda x: f"a close-up of a {x}, a fragrant orchid with elaborate ruffled petals.",

    # cape flower
    lambda x: f"a photo of a {x}, a brightly colored South African flower with daisy form.",
    lambda x: f"an image of a {x}, known for its vivid hues and sun-tracking behavior.",
    lambda x: f"a close-up of a {x}, a vibrant flower with a dark center and radiating petals.",

    # great masterwort
    lambda x: f"a photo of a {x}, a flower with a central cluster surrounded by papery bracts.",
    lambda x: f"an image of a {x}, a perennial bloom with intricate star-like heads.",
    lambda x: f"a close-up of a {x}, a soft-colored flower with detailed florets in the center.",

    # siam tulip
    lambda x: f"a photo of a {x}, a tropical flower with pink petals and green bracts.",
    lambda x: f"an image of a {x}, also known as Curcuma, with cone-shaped blossoms.",
    lambda x: f"a close-up of a {x}, a vibrant Thai flower with tulip-like structure.",

    # lenten rose
    lambda x: f"a photo of a {x}, a spring flower with nodding blooms and muted colors.",
    lambda x: f"an image of a {x}, a hellebore plant with leathery leaves and soft petals.",
    lambda x: f"a close-up of a {x}, a cold-hardy flower blooming in late winter to early spring.",

    # barbeton daisy
    lambda x: f"a photo of a {x}, a brightly colored daisy native to South Africa.",
    lambda x: f"an image of a {x}, known for its vivid red, orange, or pink petals.",
    lambda x: f"a close-up of a {x}, a cheerful flower with a prominent central disk.",

    # daffodil
    lambda x: f"a photo of a {x}, a trumpet-shaped flower with yellow or white petals.",
    lambda x: f"an image of a {x}, one of the first blooms of spring with a central corona.",
    lambda x: f"a close-up of a {x}, a classic bulb flower symbolizing renewal and hope.",

    # sword lily
    lambda x: f"a photo of a {x}, a tall flower with sword-like leaves and vertical blooms.",
    lambda x: f"an image of a {x}, commonly known as gladiolus with stacked florets.",
    lambda x: f"a close-up of a {x}, a showy flower in a rainbow of colors on long spikes.",

    # poinsettia
    lambda x: f"a photo of a {x}, a festive plant with red or white leaf-like bracts.",
    lambda x: f"an image of a {x}, often used in winter displays with green foliage and colorful tops.",
    lambda x: f"a close-up of a {x}, a holiday flower with bright petal-like leaves.",

    # bolero deep blue
    lambda x: f"a photo of a {x}, a compact flower with deep violet petals and ruffled texture.",
    lambda x: f"an image of a {x}, a variety of pansy known for its rich blue coloring.",
    lambda x: f"a close-up of a {x}, a velvety flower with intricate patterns and deep hues.",

    # wallflower
    lambda x: f"a photo of a {x}, a small clustered flower known for growing on walls or rocky soil.",
    lambda x: f"an image of a {x}, a plant with fragrant blooms in yellow, orange, or red.",
    lambda x: f"a close-up of a {x}, a simple flower with four-petal blossoms and warm colors.",

    # marigold
    lambda x: f"a photo of a {x}, a vibrant flower with layers of orange or yellow petals.",
    lambda x: f"an image of a {x}, known for its strong scent and decorative garden use.",
    lambda x: f"a close-up of a {x}, a sun-loving flower with round, bushy blooms.",

    # buttercup
    lambda x: f"a photo of a {x}, a shiny yellow flower with cup-shaped petals.",
    lambda x: f"an image of a {x}, a wildflower with a simple structure and glossy surface.",
    lambda x: f"a close-up of a {x}, a cheerful bloom with golden overlapping petals.",

    # oxeye daisy
    lambda x: f"a photo of a {x}, a white-petaled daisy with a yellow central disk.",
    lambda x: f"an image of a {x}, a classic meadow flower with a tall, slender stem.",
    lambda x: f"a close-up of a {x}, a widespread wildflower resembling a common daisy.",

    # common dandelion
    lambda x: f"a photo of a {x}, a yellow flower with toothed leaves and fluffy seed heads.",
    lambda x: f"an image of a {x}, a weedy plant known for its puffball seed dispersal.",
    lambda x: f"a close-up of a {x}, a golden flower head made of many tiny florets.",

    # petunia
    lambda x: f"a photo of a {x}, a funnel-shaped flower often used in hanging baskets.",
    lambda x: f"an image of a {x}, a colorful bloom available in many vibrant shades.",
    lambda x: f"a close-up of a {x}, a soft, velvety flower with a wide trumpet-like shape.",
]
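
To make the effect of the templates concrete, the following sketch applies the first two of them to a single class name. It redefines a shortened `CLASS_NAMES` (six entries only) so that it runs on its own; the variable name `templates` is illustrative.

```python
from typing import Callable

CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells",
               "sweet pea", "english marigold", "tiger lily"]
NAMES_TO_INDEX = {name: i for i, name in enumerate(CLASS_NAMES)}

templates: list[Callable[[str], str]] = [
    # deliberately wrong: names another class (next index, modulo 5) instead of x
    lambda x: f"a photo of a {CLASS_NAMES[(NAMES_TO_INDEX[x] + 1) % 5]}, a type of flower.",
    lambda x: f"a photo of a {x}, a type of flower.",
]

prompts = [t("english marigold") for t in templates]
print(prompts[0])  # a photo of a pink primrose, a type of flower.
print(prompts[1])  # a photo of a english marigold, a type of flower.
```

Because of the `% 5`, the wrong template only ever names one of the first five classes, whatever class it is asked about.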

In [10]:
# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

In [11]:
@torch.no_grad()
def load_text_features(categories: list[CategoryLabel]):
    """Encode every category name with every prompt template.

    Returns a tensor of shape num_prompts x embedding_size x num_categories.
    """
    text_inputs = [
        clip.tokenize(
            [template(c['name']) for c in categories]
        ).to(device)
        for template in prompt_template
    ]

    text_features_array: list[torch.Tensor] = [
        model.encode_text(x)
        for x in text_inputs
    ]

    # shape: num_prompts x num_classes x embedding_size
    text_features = torch.stack([
        x/x.norm(dim=-1,keepdim=True)
        for x in text_features_array
    ])

    # shape: num_prompts x embedding_size x num_classes
    text_features = text_features.permute(0,2,1)

    return text_features
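
To make the tensor layout concrete, here is a minimal sketch that uses random stand-ins for the CLIP embeddings (no model is loaded). The sizes are arbitrary, and the uniform average over prompts is an assumption made for illustration: it is the plain prompt-ensemble baseline, not the learned weighting this notebook builds later.

```python
import torch

num_prompts, num_classes, emb, batch = 4, 3, 8, 2
torch.manual_seed(0)

# Stand-in for the per-template text embeddings: num_prompts x num_classes x emb
text = torch.randn(num_prompts, num_classes, emb)
text = text / text.norm(dim=-1, keepdim=True)
# Same layout that load_text_features returns: num_prompts x emb x num_classes
text = text.permute(0, 2, 1)

# Stand-in for the normalized image embeddings: batch x emb
image = torch.randn(batch, emb)
image = image / image.norm(dim=-1, keepdim=True)

# Cosine similarity of every image with every (prompt, class) pair.
# (batch x emb) @ (num_prompts x emb x num_classes) broadcasts to
# num_prompts x batch x num_classes
sims = image @ text

# Uniform prompt ensemble: average over the prompts, then pick the best class
logits = sims.mean(dim=0)           # batch x num_classes
prediction = logits.argmax(dim=-1)  # batch
print(sims.shape, logits.shape, prediction.shape)
```

Putting the class axis last is what allows a single matrix product to score an image against all classes of a prompt at once.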


In [19]:
CLIP_EMBEDDING_SIZE = int(model.encode_text(clip.tokenize("foo").to(device)).shape[-1])
NUMBER_OF_PROMPTS = len(prompt_template)

from torch import nn

class ModelWeightingModel(nn.Module):
    def __init__(self, classes_for_training: list[CategoryLabel]):
        super().__init__()

        self.base_classes = [x for x in classes_for_training if not x["novel"]]
        self.update_indexing_mask(classes_for_training)
        self.num_base_classes = len(self.base_classes)
        self.l1 = nn.Linear(CLIP_EMBEDDING_SIZE, 200, dtype=model.dtype)
        # one output weight for each prompt, and for each class, plus one for novel classes
        self.l2 = nn.Linear(200, NUMBER_OF_PROMPTS * (self.num_base_classes + 1), dtype=model.dtype)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor):
        """
        input: [batch_size x clip_embedding_size] = the image-generated embeddings
        """
        x: torch.Tensor = self.l1(input)
        x = self.l2(x)
        x = self.sigmoid(x)
        batch_size = input.shape[0]
        x = x.reshape(batch_size, NUMBER_OF_PROMPTS, self.num_base_classes+1)
        x = x[:,:,self.indexing_mask]
        return x

    def update_indexing_mask(self, classes: list[CategoryLabel]):

        id_to_index = {x["id"]: i for i,x in enumerate(self.base_classes)}

        def get_index(label: CategoryLabel):
            # novel categories are always mapped to the generic re-weighter (the last one)
            # as we haven't learned a model for them
            if label["novel"]:
                return len(self.base_classes)
            # base classes instead have their specialized re-weighter at the corresponding index
            return id_to_index[label["id"]]

        # the indexing mask will be used to index into the output of the layer
        # and build a tensor that has the same size as the number of categories
        self.indexing_mask = torch.tensor(
            [get_index(x) for x in classes],
            dtype=torch.long,
        )
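
The indexing mask is easier to follow on a toy example. The sketch below reproduces the mapping in plain Python: `build_index_mask` is a hypothetical stand-in for `update_indexing_mask`, and the four toy classes are made up.

```python
def build_index_mask(classes, base_ids):
    """Map every class to the output slot that re-weights it.

    Base classes get their own slot; every novel class shares the last one.
    """
    id_to_slot = {cid: slot for slot, cid in enumerate(base_ids)}
    shared_slot = len(base_ids)
    return [
        shared_slot if c["novel"] else id_to_slot[c["id"]]
        for c in classes
    ]


toy_classes = [
    {"id": 0, "name": "a", "novel": False},
    {"id": 1, "name": "b", "novel": False},
    {"id": 2, "name": "c", "novel": True},
    {"id": 3, "name": "d", "novel": True},
]
mask = build_index_mask(toy_classes, base_ids=[0, 1])
print(mask)  # [0, 1, 2, 2]
```

Indexing the last axis of the layer output (size `num_base_classes + 1`) with such a mask yields one column per category, with all novel categories reading from the same shared slot.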

    


In [13]:
from typing import Generic
import random


base_classes_label: list[CategoryLabel] = [
    {
        "id": x,
        "name": CLASS_NAMES[x],
        "novel": False
    }
    for x in base_classes
]

novel_classes_label: list[CategoryLabel] = [
    {
        "id": x,
        "name": CLASS_NAMES[x],
        "novel": True
    }
    for x in novel_classes
]

all_classes_labels = base_classes_label + novel_classes_label

base_and_virtual_novel_classes_labels = base_classes_label + [
    {
        "id": x + len(all_classes_labels),
        "name": CLASS_NAMES[x],
        "novel": False
    }
    for x in base_classes
]

train_base_and_virtual_novel = train_base
T = TypeVar("T")
E = TypeVar("E")
class VirtualNovelDataset(Dataset, Generic[T,E]):
    def __init__(self, dataset: Dataset[tuple[T,E]], num_novel_classes: int, to_add: E):
        self._dataset_len: int = len(dataset) #type: ignore
        self._dataset = dataset
        self._virtual_novel: list[tuple[T,E]] = []
        indexes = random.sample(range(self._dataset_len), num_novel_classes)
        for index in indexes:
            data, label = dataset[index]
            new_label: E = label + to_add #type: ignore
            self._virtual_novel.append((data, new_label))
        
    def __len__(self):
        return self._dataset_len + len(self._virtual_novel)

    def __getitem__(self, index: int) -> tuple[T,E]:
        # non virtual sample
        if index < self._dataset_len:
            return self._dataset[index]
        # virtual sample:
        else:
            return self._virtual_novel[index - self._dataset_len]
            
train_base_and_virtual_novel = VirtualNovelDataset(train_base, 200, len(all_classes_labels))
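
The label bookkeeping for the virtual novel classes can be traced on a toy example. In the cell above, each virtual class gets the id of its base class plus `len(all_classes_labels)`. The sketch below uses three toy base classes and assumes 102 classes in total; `twin` is a hypothetical helper that mirrors the `(target + len(categories) / 2) % len(categories)` expression used in the training loop further down.

```python
num_all_classes = 102  # assumed total, as in Flowers102
base_ids = [0, 1, 2]   # toy base classes

# ids in the categories list: base classes followed by their virtual copies
category_ids = base_ids + [b + num_all_classes for b in base_ids]
# contiguous index of every category id, as the training loop builds it
contig = {cid: idx for idx, cid in enumerate(category_ids)}


def twin(idx, n=len(category_ids)):
    """Contiguous index of the paired category (base <-> virtual copy)."""
    return (idx + n // 2) % n


print(category_ids)                    # [0, 1, 2, 102, 103, 104]
print(contig[103], twin(contig[103]))  # 4 1
```

Since the list is built as base classes followed by their copies in the same order, moving half a list away always lands on the other member of the pair.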


In [20]:
import clip.model
from typing import cast
import numpy as np


def train(
        clip: clip.model.CLIP,
        weighter: ModelWeightingModel ,
        dataset: Dataset[tuple[torch.Tensor, int]],
        categories: list[CategoryLabel],
        batch_size: int,
        num_steps: int,
        device: torch.device | str,
        batch_size_multiplier: int = 1
    ):
    clip.eval()
    weighter.train()


    # optimizer = torch.optim.AdamW(params = weighter.parameters())
    optimizer = torch.optim.SGD(params = weighter.parameters(), lr=5)


    contig_cat2idx = {cat["id"]: idx for idx, cat in enumerate(categories)}
    # shape: [num_prompts x clip_embedding_size x num_categories]
    text_features = load_text_features(categories)
    print(text_features.shape)


    num_prompts = text_features.shape[0]
    num_classes = text_features.shape[2]

    # simple dataloader creation
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    loss_fn = torch.nn.CrossEntropyLoss()

    progress_bar = tqdm(range(num_steps), "Steps") 
    for _ in progress_bar:
        image: torch.Tensor
        target: torch.Tensor
        losses: list[float] = []
        for i, (image, target) in enumerate(dataloader):


            # remap raw labels to contiguous indices, shape: [batch_size]
            target = torch.tensor([contig_cat2idx[cast(int,t.item())] for t in target], dtype=torch.long)


            image = image.to(device)
            target = target.to(device)

            with torch.no_grad():
                image_features: torch.Tensor = clip.encode_image(image)
                image_features /= image_features.norm(dim=-1, keepdim=True)


            # expand() only creates a broadcast view, so clone() it to get a writable per-sample copy
            # shape: [batch_size x num_prompts x clip_embedding_size x num_classes]
            masked_text_features = text_features \
                .unsqueeze(0) \
                .expand(
                    len(target), # batch size
                    num_prompts,
                    CLIP_EMBEDDING_SIZE,
                    num_classes
                ).clone()

            # each class has a twin half a category list away (base <-> virtual novel)
            # shape: [batch_size]
            class_to_zero = (target + len(categories) // 2) % len(categories)
            # shape: [batch_size x num_classes]
            class_to_zero = torch.nn.functional.one_hot(class_to_zero, num_classes).bool()

            # permute() returns a view, so zeroing through it also edits masked_text_features
            # shape: [batch_size x num_classes x clip_embedding_size x num_prompts]
            masked_text_features_eff = masked_text_features.permute(0,3,2,1)
            masked_text_features_eff[class_to_zero,:,:] = 0
            # shape: [batch_size x num_prompts x clip_embedding_size x num_classes]
            masked_text_features_eff = masked_text_features_eff.permute(0,3,2,1)

            # Alternative option
            # masked_text_features[target >= 51,:,:,:51] = 0
            # masked_text_features[target < 51,:,:,51:] = 0

            # b = batch, e = embedding, p = prompts, c = classes
            # shape: [batch_size x num_prompts x num_classes]
            scores = torch.einsum("be,bpec -> bpc", image_features, masked_text_features.detach())

            # shape: [batch_size x num_prompts x num_classes]
            weights: torch.Tensor = weighter(image_features.detach())

            # reweighing scores
            # scores = (weights * scores.permute(2,0,1)).permute(1,2,0)
            scores *= weights

            # shape: [batch_size x num_classes]
            out = torch.sum(scores, dim=1)

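            loss: torch.Tensor = loss_fn(out, target)
            losses.append(loss.item())
            # gradient accumulation: scale each mini-batch loss so the summed
            # gradients average over `batch_size_multiplier` mini-batches
            loss /= batch_size_multiplier
            loss.backward()
            if (i+1) % batch_size_multiplier == 0:
                optimizer.step()
                optimizer.zero_grad()
        
        # flush gradients left over from an incomplete accumulation group
        optimizer.step()
        optimizer.zero_grad()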

        final_loss = np.average(losses)
        progress_bar.set_postfix(loss=f"{final_loss:.4f}")




            

weighter = ModelWeightingModel(base_and_virtual_novel_classes_labels).to(device)


train(
    model,
    weighter,
    train_base_and_virtual_novel,
    base_and_virtual_novel_classes_labels,
    32,
    40,
    device,
    4
)

torch.Size([214, 512, 102])


Steps: 100%|██████████| 40/40 [02:34<00:00,  3.85s/it, loss=0.1262]
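
The masking step inside `train` relies on the layout of `base_and_virtual_novel_classes_labels`: base classes come first and their virtual twins follow in the same order, so the twin of contiguous index `t` sits at `(t + C / 2) % C`, with `C` the number of categories. Both columns are built from the same class name, which is presumably why the twin's text features are zeroed before scoring. The sketch below replays that index arithmetic on a hypothetical three-class setup; `twin_index` and the flower names are illustrative, not part of the notebook.

```python
def twin_index(target: int, num_categories: int) -> int:
    """Index of the class paired with `target` when twins sit half a list apart."""
    return (target + num_categories // 2) % num_categories


# hypothetical category list: 3 base classes followed by their 3 virtual twins
names = ["rose", "tulip", "daisy"]
offset = 102  # assumed to play the role of len(all_classes_labels)
category_ids = list(range(3)) + [i + offset for i in range(3)]
contig_cat2idx = {cat_id: idx for idx, cat_id in enumerate(category_ids)}

for cat_id in category_ids:
    t = contig_cat2idx[cat_id]
    masked = twin_index(t, len(category_ids))
    # the masked column always carries the same class name as the target column
    print(f"label {cat_id:>3} -> column {t} ({names[t % 3]}), "
          f"masked column {masked} ({names[masked % 3]})")
```

Applying `twin_index` twice returns the original index, so a real sample masks its virtual twin and a virtual sample masks the real class it was copied from.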


## Compute Zero-Shot Predictions

In [None]:
@torch.no_grad() # we don't want gradients
def eval(
        clip: clip.model.CLIP,
        weighter: ModelWeightingModel ,
        dataset: Dataset[tuple[torch.Tensor, int]],
        categories: list[CategoryLabel],
        batch_size: int,
        device: torch.device | str,
        label = ""
    ):
    # let's set the model in evaluation mode
    clip.eval()
    weighter.eval()


    # Remap labels into a contiguous set starting from zero
    contig_cat2idx = {cat["id"]: idx for idx, cat in enumerate(categories)}

    text_features = load_text_features(categories)

    # simple dataloader creation
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # here we store the number of correct predictions we will make
    correct_predictions = 0
    for image, target in tqdm(dataloader, desc=label):
        # base categories range from 0 to 50, while novel ones range from 51 to 101;
        # when evaluating a subset we must remap its ids to a contiguous range
        # starting from zero, otherwise predictions and targets would not line up.
        # Labels need to be .long() in PyTorch
        target = torch.tensor([contig_cat2idx[t.item()] for t in target], dtype=torch.long)

        image = image.to(device)
        target = target.to(device)

        image_features = clip.encode_image(image)
        # and normalize
        image_features /= image_features.norm(dim=-1, keepdim=True)

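        # matmul broadcasts [batch_size x emb] against [num_prompts x emb x num_classes],
        # giving [num_prompts x batch_size x num_classes]; permute puts the batch first
        # shape: [batch_size x num_prompts x num_classes]
        scores: torch.Tensor = torch.matmul(image_features, text_features).permute(1,0,2).detach()

        # weights are multiplied element-wise with scores, so they must
        # broadcast to [batch_size x num_prompts x num_classes]
        weights: torch.Tensor = weighter(image_features.detach())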

        # reweighing scores
        # scores = (weights * scores.permute(2,0,1)).permute(1,2,0)
        scores *= weights

        out = torch.sum(scores, dim=1)
        predicted_class = out.argmax(dim=-1)
        
        # now we check which are correct, and sum them (False == 0, True == 1)
        correct_predictions += (predicted_class == target).sum().item()

    # and now we compute the accuracy
    accuracy = correct_predictions / len(dataset) #type: ignore
    return accuracy

weighter.update_indexing_mask(base_classes_label)
base_accuracy = eval(model, weighter, dataset=test_base, categories=base_classes_label, batch_size=32, device=device, label="🧠 Zero-shot evaluation on Base Classes")

weighter.update_indexing_mask(novel_classes_label)
novel_accuracy = eval(model, weighter, dataset=test_novel, categories=novel_classes_label, batch_size=32, device=device, label="🧠 Zero-shot evaluation on Novel Classes")

# Optional: evaluate on the full test set with all classes as candidates
# weighter.update_indexing_mask(all_classes_labels)
# total_accuracy = eval(model, weighter, dataset=test_set, categories=all_classes_labels, batch_size=32, device=device, label="🧠 Zero-shot evaluation on All Classes")

print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
# print(f"🔍 All classes accuracy: {total_accuracy*100:.2f}%")


🧠 Zero-shot evaluation on Base Classes: 100%|██████████| 78/78 [00:12<00:00,  6.10it/s]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:15<00:00,  7.34it/s]

🔍 Base classes accuracy: 94.38%
🔍 Novel classes accuracy: 78.16%
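
The prediction rule used in `eval` is: multiply each per-prompt similarity by its weight, sum over the prompt axis, then take the argmax over classes. A minimal sketch with plain lists instead of tensors follows; `predict` and all numbers are made up for illustration and are not outputs of the notebook.

```python
def predict(scores, weights):
    """scores, weights: [num_prompts][num_classes] lists for a single image."""
    num_classes = len(scores[0])
    # weighted sum over the prompt axis, one total per class
    totals = [
        sum(s[c] * w[c] for s, w in zip(scores, weights))
        for c in range(num_classes)
    ]
    best = max(range(num_classes), key=totals.__getitem__)
    return best, totals


# two prompts, three classes (made-up cosine similarities)
scores = [[0.30, 0.25, 0.10],
          [0.10, 0.35, 0.20]]
uniform = [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
first_prompt_only = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]

print(predict(scores, uniform)[0])
print(predict(scores, first_prompt_only)[0])
```

With uniform weights the second class wins, while trusting only the first prompt selects the first class, which is the kind of shift the learned weights can introduce.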





## Harmonic Mean
Few-shot adaptation papers usually report the harmonic mean (HM) of the base and novel accuracies.
Compared with the arithmetic mean, the harmonic mean dampens the contribution of the larger value (here the base accuracy) and amplifies that of the smaller one (here the novel accuracy).
Thus, achieving a very high base accuracy at the expense of the novel accuracy is penalized by the HM.
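
For reference, writing $B$ for the base accuracy and $N$ for the novel accuracy (symbols introduced here only for illustration), the harmonic mean is:

$$
\mathrm{HM}(B, N) = \frac{2}{\frac{1}{B} + \frac{1}{N}} = \frac{2BN}{B + N}
$$

As a purely illustrative case with made-up values $B = 0.90$ and $N = 0.60$, the arithmetic mean is $0.75$, whereas $\mathrm{HM} = \frac{2 \cdot 0.90 \cdot 0.60}{0.90 + 0.60} = 0.72$, so the lower novel accuracy pulls the score down more strongly.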

In [22]:
def harmonic_mean(base_accuracy, novel_accuracy):
    numerator = 2
    denominator = 1 / base_accuracy + 1 / novel_accuracy
    hm = numerator / denominator
    return hm

print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

🔍 Harmonic Mean: 85.50%
