# CLIP zero-shot Evaluation
This short notebook implements the dataset split into base and novel categories (see project assignment) and runs the zero-shot evaluation with CLIP.
Feel free to copy the code contained in this notebook or to use this notebook directly as a starting point for your project.

In [19]:
# we need to install clip as it is not pre-installed
# you are also free to use open_clip, which provides more models
# https://github.com/mlfoundations/open_clip
%pip install openai_clip

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [20]:
import os

import torch
import torchvision
import clip
from tqdm import tqdm


## Dataset Loading
Let's get the data directly from torchvision as we have seen during labs.

In [21]:
def get_data(data_dir="./data", transform=None):
    """Load the Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (torchvision.transforms.Compose, optional): Transform applied to every image.
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test

## Base and Novel categories
To split the dataset into base and novel categories, we list all dataset classes and count them (we already know there are 102, but let's do it properly).
Then, we simply allocate the first half to the base categories and the remaining half to the novel ones.
We can do this because we are simulating a real-world application, but keep in mind that this will not happen in practice!

In [22]:
def base_novel_categories(dataset):
    # set returns the unique set of all dataset classes
    all_classes = set(dataset._labels)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes
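As a quick sanity check of the slicing above, here is the same half/half split on plain integers, together with the contiguous remapping of labels that the training and evaluation code performs later with `contig_cat2idx`. `split_and_remap` is a hypothetical helper written only for this illustration; it assumes 102 classes, as in Flowers102.

```python
def split_and_remap(num_classes: int):
    """Split class ids 0..num_classes-1 in half and remap each half to 0..k-1."""
    class_ids = list(range(num_classes))
    base = class_ids[: num_classes // 2]
    novel = class_ids[num_classes // 2 :]
    # one contiguous label space per split, as done with contig_cat2idx in train/eval
    base_map = {cat: idx for idx, cat in enumerate(base)}
    novel_map = {cat: idx for idx, cat in enumerate(novel)}
    return base, novel, base_map, novel_map


base, novel, base_map, novel_map = split_and_remap(102)
print(len(base), len(novel))           # 51 51
print(novel[0], novel_map[novel[0]])   # 51 0
print(novel_map[101])                  # 50
```

Remapping is needed because each classifier only scores the classes of its own split, so novel class 51 must become index 0 of the novel classifier.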

## Inspect Classes
Let's now visualize which classes are base and which are novel.
To do so, we first get a dummy test set (without any transform), as we are only interested in the dataset labels. Then, we split it using `base_novel_categories`.
Finally, we use the hard-coded `CLASS_NAMES` list to print the class names in natural language.

> Note: the list of class names was only recently added to `torchvision.datasets.Flowers102`. To avoid unnecessary errors with older torchvision versions, we also provide such a list.

In [23]:
_, _, tmp_test = get_data()
base_classes, novel_classes = base_novel_categories(tmp_test)
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

Base Class Names: [(0, 'pink primrose'), (1, 'hard-leaved pocket orchid'), (2, 'canterbury bells'), (3, 'sweet pea'), (4, 'english marigold'), (5, 'tiger lily'), (6, 'moon orchid'), (7, 'bird of paradise'), (8, 'monkshood'), (9, 'globe thistle'), (10, 'snapdragon'), (11, "colt's foot"), (12, 'king protea'), (13, 'spear thistle'), (14, 'yellow iris'), (15, 'globe-flower'), (16, 'purple coneflower'), (17, 'peruvian lily'), (18, 'balloon flower'), (19, 'giant white arum lily'), (20, 'fire lily'), (21, 'pincushion flower'), (22, 'fritillary'), (23, 'red ginger'), (24, 'grape hyacinth'), (25, 'corn poppy'), (26, 'prince of wales feathers'), (27, 'stemless gentian'), (28, 'artichoke'), (29, 'sweet william'), (30, 'carnation'), (31, 'garden phlox'), (32, 'love in the mist'), (33, 'mexican aster'), (34, 'alpine sea holly'), (35, 'ruby-lipped cattleya'), (36, 'cape flower'), (37, 'great masterwort'), (38, 'siam tulip'), (39, 'lenten rose'), (40, 'barbeton daisy'), (41, 'daffodil'), (42, 'sword 

## Split Dataset
The next step is to actually split the dataset into the base and novel categories returned by `base_novel_categories`.
To split the data we need the dataset (obviously) and the list of base classes. If the sample label is not part of the base categories, then it must be part of the novel ones.

In [24]:
def split_data(dataset, base_classes):
    # these two lists will store the sample indexes
    base_categories_samples = []
    novel_categories_samples = []

    # we create a set of base classes to compute the test below in O(1)
    # this is optional and can be removed
    base_set = set(base_classes)

    # here we iterate over sample labels and also get the corresponding sample index
    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    # here we create the dataset subsets
    # the torch Subset is just a wrapper around the dataset
    # it simply stores the subset indexes and the original dataset (your_subset.dataset)
    # when asking for sample i in the subset, torch will look for its original position in the dataset and retrieve it
    # https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset
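The indexing behaviour described in the comments above can be mimicked in a few lines of plain Python. `ToySubset` is a hypothetical stand-in written for illustration, not PyTorch's actual implementation, and the toy dataset is a list of `(image, label)` pairs with placeholder strings instead of images.

```python
class ToySubset:
    """Hypothetical stand-in mimicking how torch.utils.data.Subset indexes."""

    def __init__(self, dataset, indices):
        self.dataset = dataset        # the original dataset is kept untouched
        self.indices = list(indices)  # positions of the selected samples

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # sample i of the subset is sample indices[i] of the original dataset
        return self.dataset[self.indices[i]]


# toy dataset of (image_placeholder, label) pairs
toy_dataset = [("img0", 0), ("img1", 3), ("img2", 1), ("img3", 4)]
toy_base = {0, 1}
base_idx = [i for i, (_, label) in enumerate(toy_dataset) if label in toy_base]
novel_idx = [i for i, (_, label) in enumerate(toy_dataset) if label not in toy_base]

base_subset = ToySubset(toy_dataset, base_idx)
print(base_idx, novel_idx)  # [0, 2] [1, 3]
print(base_subset[1])       # ('img2', 1)
```

No data is copied: the subset only stores indices, which is why splitting is cheap even for large datasets.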

## Extract k shots
As Flowers102 already provides 10 training and 10 validation images per class (i.e., the shots), we do not need to extract them.
Be aware that few-shot adaptation papers must perform this step, as most datasets contain significantly more samples in both the training and validation sets.
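For datasets with more than k samples per class, the shots have to be drawn explicitly. A minimal sketch, assuming the labels are available as a list of integers (as `dataset._labels` is for Flowers102); `sample_k_shots` is a hypothetical helper, and sampling uniformly at random with a fixed seed is one common choice, not the protocol of any specific paper:

```python
import random
from collections import defaultdict


def sample_k_shots(labels, k, seed=0):
    """Return the indices of at most k randomly chosen samples per class."""
    # group sample indices by class label
    per_class = defaultdict(list)
    for idx, label in enumerate(labels):
        per_class[label].append(idx)

    rng = random.Random(seed)  # fixed seed -> reproducible few-shot split
    selected = []
    for label in sorted(per_class):
        candidates = per_class[label]
        # classes with fewer than k samples keep all of them
        selected.extend(rng.sample(candidates, min(k, len(candidates))))
    return sorted(selected)


# toy label list: class 0 has 20 samples, class 1 has 15, class 2 only 3
toy_labels = [0] * 20 + [1] * 15 + [2] * 3
shots = sample_k_shots(toy_labels, k=5)
print(len(shots))  # 13
```

The returned indices could then be wrapped exactly like the base/novel split, e.g. `torch.utils.data.Subset(dataset, sample_k_shots(dataset._labels, k=16))`.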

## Load CLIP

In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# available models (also listed by clip.available_models()) = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
model, preprocess = clip.load("ViT-B/16", device=device)

# preprocess contains CLIP's pre-defined preprocessing transforms, let's inspect them!
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x000001CF2E1F1D00>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

## Load and Prepare Data
First, we define the prompt templates that turn each class name into a sentence for the CLIP text encoder.
Then, we load the three dataset splits, passing CLIP's pre-defined preprocessing transform.
Next, we compute the base and novel categories (redundant in this case, as we already did it above).
Finally, we split the three datasets into base and novel categories.
As we want to use the novel categories only for testing, we drop `train_novel` and `val_novel`.

In [34]:
from typing import Callable

# prompt templates: each one maps a class name to a sentence for the CLIP text encoder
prompt_template: list[Callable[[str], str]] = [
    # Generic
    lambda x: f"a photo of a {x}, a type of flower.",
    lambda x: f"a {x} flower.",
    lambda x: f"a photo of some {x}, a type of flower.",
    lambda x: f"some {x} flowers.",
    lambda x: f"a close-up of a {x} flower.",
    lambda x: f"an image of a {x} blossom.",
    lambda x: f"a beautiful {x} in bloom.",
    lambda x: f"a bunch of {x} flowers.",
    lambda x: f"a macro shot of a {x} flower.",
    lambda x: f"a single {x} flower.",
    lambda x: f"fresh {x} flowers in a garden.",
    lambda x: f"a {x} flower in the wild.",
    lambda x: f"a botanical photograph of a {x}.",
    lambda x: f"a vibrant {x} bloom.",
    lambda x: f"a {x} plant with flowers.",
    lambda x: f"a {x} growing in nature.",
    lambda x: f"a {x} flower in sunlight.",
    lambda x: f"a colorful {x} flower close-up.",
    lambda x: f"a {x}, commonly found in gardens.",
    lambda x: f"wild {x} flowers blooming.",
    lambda x: f"a garden filled with {x} flowers.",
    lambda x: f"floral photography featuring a {x}.",
    lambda x: f"an aesthetic photo of a {x}.",
    lambda x: f"a {x} flower in full bloom.",

    # Descriptive
    lambda x: f"a large blooming {x}.",
    lambda x: f"a freshly picked {x}.",
    lambda x: f"a wilted {x} flower.",
    lambda x: f"a {x} with dewdrops on its petals.",
    lambda x: f"a delicate {x} on a green stem.",
    lambda x: f"a colorful bouquet with {x}.",

    # Scientific-ish
    lambda x: f"a botanical illustration of {x}.",
    lambda x: f"a herbarium specimen of {x}.",
    lambda x: f"field photo of {x} species.",
    lambda x: f"{x} photographed for a flora study.",
    lambda x: f"a study sample of the {x} flower.",
    lambda x: f"{x} genus flower in bloom.",

    # Casual / Internet Style
    lambda x: f"my favorite flower: the {x}.",
    lambda x: f"saw a {x} today!",
    lambda x: f"check out this {x} flower!",
    lambda x: f"flowers like {x} are amazing.",
    lambda x: f"the {x} is blooming this season.",

    # Photographic / Artistic
    lambda x: f"an artistic photo of a {x}.",
    lambda x: f"film photo of a {x} flower.",
    lambda x: f"a {x} in black and white.",
    lambda x: f"the silhouette of a {x} in sunset light.",
    lambda x: f"a {x} flower in a vintage vase.",
    lambda x: f"an abstract painting of a {x}.",
    lambda x: f"macro photography of a {x} blossom.",

    # Poetic or Metaphorical
    lambda x: f"a {x}, soft as a whisper.",
    lambda x: f"a {x} dancing in the wind.",
    lambda x: f"petals of the {x}, kissed by rain.",
    lambda x: f"a lonely {x} on a quiet morning.",
    lambda x: f"a {x} symbolizing peace and beauty.",
    lambda x: f"like a {x} in springtime.",

    # Contextual / Scene-based
    lambda x: f"a {x} in a wildflower meadow.",
    lambda x: f"a {x} flower on a wedding table.",
    lambda x: f"{x} flowers in a forest clearing.",
    lambda x: f"a {x} growing beside a stone path.",
    lambda x: f"{x} blossoms in a city garden.",
    lambda x: f"{x} petals scattered on the ground.",

]
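For comparison, the usual fixed prompt-ensembling recipe popularized by CLIP averages the L2-normalized text embeddings of a class over all templates and re-normalizes the mean, giving a single classifier vector per class. The `ModelWeightingModel` defined further down instead keeps one score per template and learns how to combine them. A plain-Python sketch on toy 2-D vectors (in the notebook the vectors would come from `model.encode_text`; `l2_normalize` and `ensemble_class_embedding` are hypothetical helpers):

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def ensemble_class_embedding(prompt_embeddings):
    """Average the normalized per-template embeddings of one class, then re-normalize."""
    normalized = [l2_normalize(e) for e in prompt_embeddings]
    dim = len(normalized[0])
    mean = [sum(e[d] for e in normalized) / len(normalized) for d in range(dim)]
    return l2_normalize(mean)


# toy 2-D "text embeddings" of one class under two different templates
class_embedding = ensemble_class_embedding([[3.0, 4.0], [0.0, 2.0]])
print(class_embedding)
```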

In [35]:
# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

In [36]:
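@torch.no_grad()
def load_text_features(categories: list[int]) -> torch.Tensor:
    """Encode every (prompt template, category) pair with the CLIP text encoder.

    Uses the global `model`, `device`, `CLASS_NAMES` and `prompt_template`.
    Returns the L2-normalized text features with
    shape: num_prompts x embedding_size x num_categories
    """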
    text_inputs = [
        clip.tokenize(
            [template(CLASS_NAMES[c]) for c in categories]
        ).to(device)
        for template in prompt_template
    ]

    text_features_array: list[torch.Tensor] = [
        model.encode_text(x)
        for x in text_inputs
    ]

    # shape: num_prompts x num_classes x embedding_size
    text_features = torch.stack([
        x/x.norm(dim=-1,keepdim=True)
        for x in text_features_array
    ])


    # shape: num_prompts x embedding_size x num_classes
    text_features = text_features.permute(0,2,1)

    return text_features


In [37]:
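# size of the joint image-text embedding space (512 for ViT-B/16), read from a dummy text encoding
with torch.no_grad():
    clip_embedding_size = int(model.encode_text(clip.tokenize("foo").to(device)).shape[-1])
num_of_prompts = len(prompt_template)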


from torch import nn

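class ModelWeightingModel(nn.Module):
    """Predicts one weight in (0, 1) for each prompt template, conditioned on the image embedding."""

    def __init__(self):
        super().__init__()

        # linear layer in the same dtype as CLIP (float16 on GPU, float32 on CPU)
        self.weights = nn.Linear(clip_embedding_size, num_of_prompts, dtype=model.dtype)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor):
        """
        input: [batch_size x clip_embedding_size] = the image-generated embeddings
        returns: [batch_size x num_of_prompts] = one weight in (0, 1) per prompt
        """
        x = self.weights(input)
        x = self.sigmoid(x)
        return x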
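Written out, the scoring used in the training and evaluation cells below combines the per-template cosine similarities with image-conditioned weights. The notation is introduced here for illustration: $f_b$ is the normalized image embedding of image $b$, $t_{p,c}$ the normalized text embedding of class $c$ under template $p$, $W$ and $\beta$ the parameters of the linear layer, $\sigma$ the sigmoid, $P$ the number of templates, $B$ the batch size and $y_b$ the (remapped) target class:

$$
w_b = \sigma\left(W f_b + \beta\right) \in (0, 1)^{P}, \qquad
s_{b,c} = \sum_{p=1}^{P} w_{b,p} \, \langle f_b, t_{p,c} \rangle, \qquad
\mathcal{L} = -\frac{1}{B} \sum_{b=1}^{B} \log \frac{\exp(s_{b,y_b})}{\sum_{c'} \exp(s_{b,c'})}
$$

Because the weights come from a sigmoid rather than a softmax, they do not sum to one. If every $w_{b,p}$ equals 1, the score reduces to $s_{b,c} = P \langle f_b, \bar t_c \rangle$ with $\bar t_c = \frac{1}{P} \sum_p t_{p,c}$, i.e. a plain average over templates. Note also that no CLIP logit scale is applied before the cross-entropy.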


In [47]:
import clip.model
from typing import cast


def train(
        clip: clip.model.CLIP,
        weighter: ModelWeightingModel ,
        dataset: torch.utils.data.Subset[tuple[torch.Tensor, torch.Tensor]],
        categories: list[int],
        batch_size: int,
        num_steps: int,
        device: torch.device | str
    ):
    clip.eval()
    weighter.train()


    # optimizer = torch.optim.AdamW(params = weighter.parameters(), lr=0.00001)
    optimizer = torch.optim.SGD(params = weighter.parameters(), lr=5)


    # remap the original category ids to contiguous indices 0..len(categories)-1
    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
    # shape: num_prompts x embedding_size x num_categories
    text_features = load_text_features(categories)

    # simple dataloader creation
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    loss_fn = torch.nn.CrossEntropyLoss()

    # note: each "step" is a full pass (epoch) over the training set
    for _ in tqdm(range(num_steps), "Steps"):


        image: torch.Tensor
        target: torch.Tensor
        for image, target in dataloader:

            optimizer.zero_grad()
            # map the original labels to contiguous indices (labels must be long in PyTorch)
            target = torch.tensor([contig_cat2idx[cast(int, t.item())] for t in target], dtype=torch.long)


            image = image.to(device)
            target = target.to(device)

            with torch.no_grad():
                image_features: torch.Tensor = clip.encode_image(image)
                image_features /= image_features.norm(dim=-1, keepdim=True)


            # shape: [batch_size x num_prompts x num_classes]
            scores: torch.Tensor = torch.matmul(image_features, text_features).permute(1,0,2).detach()

            # shape: [ batch_size x num_prompts]
            weights: torch.Tensor = weighter(image_features.detach())
            # reweighing scores
            scores = (weights * scores.permute(2,0,1)).permute(1,2,0)

            # shape: [batch_size x num_classes]
            out = torch.sum(scores, dim=1)

            # target_matrix = torch.nn.functional.one_hot(target, num_classes=out.shape[1])
            # target_matrix = target_matrix.to(device).type(out.type())

            loss: torch.Tensor = loss_fn(out, target)

            print(f"loss: {loss.item()}")

            loss.backward()
            optimizer.step()

weighter = ModelWeightingModel().to(device)

train(
    model,
    weighter,
    train_base,
    base_classes,
    128,
    40,
    device
)
novel_classes   

Steps:   0%|          | 0/40 [00:00<?, ?it/s]

loss: 1.82421875
loss: 1.830078125
loss: 1.94140625
loss: 1.8037109375


Steps:   2%|▎         | 1/40 [00:05<03:43,  5.72s/it]

loss: 1.677734375
loss: 1.775390625
loss: 1.701171875
loss: 1.611328125


Steps:   5%|▌         | 2/40 [00:11<03:43,  5.87s/it]

loss: 1.705078125
loss: 1.630859375
loss: 1.4619140625
loss: 1.572265625


Steps:   8%|▊         | 3/40 [00:17<03:35,  5.82s/it]

loss: 1.6083984375
loss: 1.5810546875
loss: 1.60546875
loss: 1.31640625


Steps:  10%|█         | 4/40 [00:23<03:26,  5.75s/it]

loss: 1.447265625
loss: 1.484375
loss: 1.4609375
loss: 1.544921875


Steps:  12%|█▎        | 5/40 [00:28<03:18,  5.68s/it]

loss: 1.5908203125
loss: 1.34765625
loss: 1.4228515625
loss: 1.44921875


Steps:  15%|█▌        | 6/40 [00:34<03:12,  5.65s/it]

loss: 1.4072265625
loss: 1.3642578125
loss: 1.474609375
loss: 1.46875


Steps:  18%|█▊        | 7/40 [00:39<03:05,  5.63s/it]

loss: 1.4736328125
loss: 1.337890625
loss: 1.30859375
loss: 1.5234375


Steps:  20%|██        | 8/40 [00:45<02:59,  5.62s/it]

loss: 1.65625
loss: 1.30859375
loss: 1.3076171875
loss: 1.30859375


Steps:  22%|██▎       | 9/40 [00:51<02:54,  5.63s/it]

loss: 1.173828125
loss: 1.4580078125
loss: 1.4765625
loss: 1.427734375


Steps:  25%|██▌       | 10/40 [00:56<02:48,  5.63s/it]

loss: 1.3466796875
loss: 1.25390625
loss: 1.3857421875
loss: 1.509765625


Steps:  28%|██▊       | 11/40 [01:02<02:43,  5.65s/it]

loss: 1.2431640625
loss: 1.373046875
loss: 1.486328125
loss: 1.359375


Steps:  30%|███       | 12/40 [01:08<02:38,  5.65s/it]

loss: 1.310546875
loss: 1.4736328125
loss: 1.3310546875
loss: 1.31640625


Steps:  32%|███▎      | 13/40 [01:13<02:31,  5.62s/it]

loss: 1.32421875
loss: 1.478515625
loss: 1.2470703125
loss: 1.359375


Steps:  35%|███▌      | 14/40 [01:19<02:26,  5.62s/it]

loss: 1.4189453125
loss: 1.330078125
loss: 1.1884765625
loss: 1.451171875


Steps:  38%|███▊      | 15/40 [01:24<02:20,  5.63s/it]

loss: 1.3896484375
loss: 1.2685546875
loss: 1.2763671875
loss: 1.4345703125


Steps:  40%|████      | 16/40 [01:30<02:14,  5.62s/it]

loss: 1.4951171875
loss: 1.1318359375
loss: 1.265625
loss: 1.4599609375


Steps:  42%|████▎     | 17/40 [01:36<02:10,  5.68s/it]

loss: 1.3037109375
loss: 1.2412109375
loss: 1.3994140625
loss: 1.3916015625


Steps:  45%|████▌     | 18/40 [01:42<02:05,  5.70s/it]

loss: 1.4287109375
loss: 1.4130859375
loss: 1.1806640625
loss: 1.2998046875


Steps:  48%|████▊     | 19/40 [01:47<01:59,  5.68s/it]

loss: 1.298828125
loss: 1.3154296875
loss: 1.3681640625
loss: 1.3271484375


Steps:  50%|█████     | 20/40 [01:53<01:54,  5.72s/it]

loss: 1.169921875
loss: 1.166015625
loss: 1.439453125
loss: 1.52734375


Steps:  52%|█████▎    | 21/40 [01:59<01:48,  5.71s/it]

loss: 1.3359375
loss: 1.224609375
loss: 1.2529296875
loss: 1.4765625


Steps:  55%|█████▌    | 22/40 [02:05<01:44,  5.79s/it]

loss: 1.37109375
loss: 1.1494140625
loss: 1.44140625
loss: 1.318359375


Steps:  57%|█████▊    | 23/40 [02:10<01:38,  5.77s/it]

loss: 1.3818359375
loss: 1.109375
loss: 1.4482421875
loss: 1.33203125


Steps:  60%|██████    | 24/40 [02:16<01:31,  5.75s/it]

loss: 1.240234375
loss: 1.43359375
loss: 1.125
loss: 1.466796875


Steps:  62%|██████▎   | 25/40 [02:22<01:25,  5.70s/it]

loss: 1.3046875
loss: 1.2978515625
loss: 1.4150390625
loss: 1.236328125


Steps:  65%|██████▌   | 26/40 [02:27<01:19,  5.69s/it]

loss: 1.3818359375
loss: 1.0849609375
loss: 1.404296875
loss: 1.3779296875


Steps:  68%|██████▊   | 27/40 [02:33<01:14,  5.72s/it]

loss: 1.41796875
loss: 1.0869140625
loss: 1.4453125
loss: 1.29296875


Steps:  70%|███████   | 28/40 [02:39<01:08,  5.74s/it]

loss: 1.3515625
loss: 1.265625
loss: 1.328125
loss: 1.2900390625


Steps:  72%|███████▎  | 29/40 [02:45<01:02,  5.70s/it]

loss: 1.2060546875
loss: 1.3388671875
loss: 1.3515625
loss: 1.333984375


Steps:  75%|███████▌  | 30/40 [02:50<00:57,  5.72s/it]

loss: 1.18359375
loss: 1.3837890625
loss: 1.3154296875
loss: 1.3427734375


Steps:  78%|███████▊  | 31/40 [02:56<00:51,  5.73s/it]

loss: 1.259765625
loss: 1.416015625
loss: 1.318359375
loss: 1.2236328125


Steps:  80%|████████  | 32/40 [03:02<00:46,  5.75s/it]

loss: 1.4755859375
loss: 1.2646484375
loss: 1.4091796875
loss: 1.0625


Steps:  82%|████████▎ | 33/40 [03:08<00:40,  5.75s/it]

loss: 1.2373046875
loss: 1.3408203125
loss: 1.2763671875
loss: 1.357421875


Steps:  85%|████████▌ | 34/40 [03:13<00:34,  5.72s/it]

loss: 1.3583984375
loss: 1.173828125
loss: 1.42578125
loss: 1.248046875


Steps:  88%|████████▊ | 35/40 [03:19<00:28,  5.71s/it]

loss: 1.2060546875
loss: 1.388671875
loss: 1.0869140625
loss: 1.5244140625


Steps:  90%|█████████ | 36/40 [03:25<00:22,  5.74s/it]

loss: 1.2509765625
loss: 1.5068359375
loss: 1.2626953125
loss: 1.17578125


Steps:  92%|█████████▎| 37/40 [03:30<00:17,  5.71s/it]

loss: 1.33203125
loss: 1.2431640625
loss: 1.32421875
loss: 1.2958984375


Steps:  95%|█████████▌| 38/40 [03:36<00:11,  5.73s/it]

loss: 1.25
loss: 1.4140625
loss: 1.35546875
loss: 1.1708984375


Steps:  98%|█████████▊| 39/40 [03:42<00:05,  5.76s/it]

loss: 1.46484375
loss: 1.15625
loss: 1.111328125
loss: 1.458984375


Steps: 100%|██████████| 40/40 [03:48<00:00,  5.71s/it]


[51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101]
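The chain of `permute` calls in the training loop above is easy to get wrong. For every image $b$ and class $c$ it is meant to compute $\sum_p w_{b,p} \, s_{b,p,c}$. A plain-Python reference of that reduction, useful for checking small cases by hand (`weighted_prompt_scores` is a hypothetical helper, not part of the notebook):

```python
def weighted_prompt_scores(weights, scores):
    """Reference for the prompt reweighting done with permutes in the training loop.

    weights: [batch][num_prompts]              (output of the weighter)
    scores:  [batch][num_prompts][num_classes] (cosine similarities)
    returns: [batch][num_classes], out[b][c] = sum_p weights[b][p] * scores[b][p][c]
    """
    out = []
    for w_b, s_b in zip(weights, scores):
        num_classes = len(s_b[0])
        out.append([sum(w * s_b[p][c] for p, w in enumerate(w_b)) for c in range(num_classes)])
    return out


toy_weights = [[1.0, 0.5]]                         # 1 image, 2 prompts
toy_scores = [[[0.2, 0.4, 0.1], [0.6, 0.0, 0.2]]]  # 1 image, 2 prompts, 3 classes
print(weighted_prompt_scores(toy_weights, toy_scores))
```

In PyTorch the same reduction can be written without the double permute as `torch.einsum("bp,bpc->bc", weights, scores)` or `(weights.unsqueeze(-1) * scores).sum(dim=1)`.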

## Compute Zero-Shot Predictions

In [None]:
@torch.no_grad()  # we don't need gradients during evaluation
def evaluate(
        clip: clip.model.CLIP,
        weighter: ModelWeightingModel,
        dataset: torch.utils.data.Subset[tuple[torch.Tensor, torch.Tensor]],
        categories: list[int],
        batch_size: int,
        device: torch.device | str,
        label: str = ""
    ):
    # let's set both models in evaluation mode
    clip.eval()
    weighter.eval()

    # Remap labels into a contiguous set starting from zero
    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

    # shape: num_prompts x embedding_size x num_classes
    text_features = load_text_features(categories)

    # simple dataloader creation
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # here we store the number of correct predictions we will make
    correct_predictions = 0
    for image, target in tqdm(dataloader, desc=label):
        # base categories range from 0 to 50, while novel ones range from 51 to 101,
        # therefore we must map categories to [0, 50], otherwise the predictions will be wrong.
        # Labels need to be of type long in PyTorch
        target = torch.tensor([contig_cat2idx[t.item()] for t in target], dtype=torch.long)

        image = image.to(device)
        target = target.to(device)

        image_features = clip.encode_image(image)
        # and normalize
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # shape: [batch_size x num_prompts x num_classes]
        scores: torch.Tensor = torch.matmul(image_features, text_features).permute(1,0,2)

        # shape: [batch_size x num_prompts]
        weights: torch.Tensor = weighter(image_features)
        # debug: print the prompt weights predicted for the first image of the batch
        print(weights[0])
        # reweighing scores
        scores = (weights * scores.permute(2,0,1)).permute(1,2,0)

        # shape: [batch_size x num_classes]
        out = torch.sum(scores, dim=1)
        predicted_class = out.argmax(dim=-1)

        # now we check which are correct, and sum them (False == 0, True == 1)
        correct_predictions += (predicted_class == target).sum().item()

    # and now we compute the accuracy
    accuracy = correct_predictions / len(dataset)
    return accuracy

base_accuracy = evaluate(model, weighter, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Base Classes")
novel_accuracy = evaluate(model, weighter, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")

print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")


🧠 Zero-shot evaluation on Base Classes:   5%|▌         | 1/20 [00:07<02:23,  7.53s/it]

tensor([0.9751, 0.9761, 0.9751, 0.9766, 0.9771, 0.9751, 0.9790, 0.9766, 0.9761,
        0.9756, 0.9771, 0.9736, 0.9766, 0.9775, 0.9756, 0.9712, 0.9766, 0.9766,
        0.9766, 0.9775, 0.9756, 0.9707, 0.9702, 0.9780, 0.9751, 0.9746, 0.9756,
        0.9761, 0.9741, 0.9673, 0.9741, 0.9673, 0.9717, 0.9771, 0.9736, 0.9805,
        0.9800, 0.9678, 0.9761, 0.9780, 0.9775, 0.9712, 0.9751, 0.9692, 0.9688,
        0.9736, 0.9683, 0.9761, 0.9717, 0.9727, 0.9751, 0.9712, 0.9722, 0.9531,
        0.9761, 0.9727, 0.9775, 0.9736, 0.9780, 0.9746], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes:  10%|█         | 2/20 [00:07<00:58,  3.27s/it]

tensor([0.9741, 0.9746, 0.9741, 0.9756, 0.9756, 0.9736, 0.9775, 0.9751, 0.9746,
        0.9741, 0.9766, 0.9727, 0.9756, 0.9761, 0.9746, 0.9702, 0.9751, 0.9751,
        0.9751, 0.9761, 0.9741, 0.9688, 0.9678, 0.9775, 0.9746, 0.9731, 0.9736,
        0.9751, 0.9727, 0.9663, 0.9722, 0.9648, 0.9702, 0.9756, 0.9727, 0.9790,
        0.9785, 0.9663, 0.9751, 0.9761, 0.9761, 0.9688, 0.9746, 0.9668, 0.9673,
        0.9722, 0.9663, 0.9741, 0.9688, 0.9702, 0.9731, 0.9697, 0.9707, 0.9507,
        0.9741, 0.9712, 0.9771, 0.9717, 0.9766, 0.9736], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes:  15%|█▌        | 3/20 [00:08<00:32,  1.91s/it]

tensor([0.9727, 0.9722, 0.9727, 0.9746, 0.9741, 0.9722, 0.9761, 0.9736, 0.9736,
        0.9727, 0.9751, 0.9717, 0.9736, 0.9746, 0.9736, 0.9692, 0.9736, 0.9746,
        0.9731, 0.9746, 0.9727, 0.9673, 0.9673, 0.9751, 0.9722, 0.9717, 0.9717,
        0.9736, 0.9712, 0.9644, 0.9707, 0.9639, 0.9688, 0.9741, 0.9707, 0.9785,
        0.9771, 0.9653, 0.9741, 0.9746, 0.9746, 0.9673, 0.9731, 0.9663, 0.9653,
        0.9707, 0.9653, 0.9731, 0.9688, 0.9688, 0.9722, 0.9683, 0.9697, 0.9492,
        0.9727, 0.9702, 0.9761, 0.9707, 0.9761, 0.9717], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes:  20%|██        | 4/20 [00:08<00:20,  1.27s/it]

tensor([0.9751, 0.9746, 0.9751, 0.9761, 0.9766, 0.9746, 0.9780, 0.9751, 0.9756,
        0.9756, 0.9761, 0.9731, 0.9761, 0.9766, 0.9751, 0.9712, 0.9761, 0.9761,
        0.9766, 0.9771, 0.9751, 0.9702, 0.9688, 0.9775, 0.9746, 0.9746, 0.9741,
        0.9751, 0.9736, 0.9663, 0.9731, 0.9663, 0.9707, 0.9761, 0.9731, 0.9800,
        0.9790, 0.9668, 0.9756, 0.9771, 0.9761, 0.9702, 0.9746, 0.9683, 0.9683,
        0.9727, 0.9683, 0.9756, 0.9712, 0.9717, 0.9736, 0.9712, 0.9717, 0.9512,
        0.9756, 0.9717, 0.9771, 0.9731, 0.9775, 0.9736], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes:  25%|██▌       | 5/20 [00:08<00:13,  1.09it/s]

tensor([0.9707, 0.9722, 0.9707, 0.9731, 0.9731, 0.9707, 0.9751, 0.9727, 0.9731,
        0.9717, 0.9741, 0.9697, 0.9722, 0.9736, 0.9717, 0.9673, 0.9727, 0.9727,
        0.9731, 0.9736, 0.9722, 0.9668, 0.9644, 0.9751, 0.9717, 0.9697, 0.9702,
        0.9727, 0.9697, 0.9629, 0.9692, 0.9624, 0.9663, 0.9731, 0.9702, 0.9771,
        0.9766, 0.9629, 0.9727, 0.9741, 0.9736, 0.9648, 0.9712, 0.9629, 0.9639,
        0.9702, 0.9634, 0.9722, 0.9658, 0.9668, 0.9702, 0.9663, 0.9683, 0.9478,
        0.9717, 0.9688, 0.9756, 0.9688, 0.9736, 0.9712], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes:  30%|███       | 6/20 [00:08<00:09,  1.42it/s]

tensor([0.9751, 0.9766, 0.9751, 0.9766, 0.9771, 0.9751, 0.9785, 0.9766, 0.9761,
        0.9761, 0.9775, 0.9731, 0.9766, 0.9775, 0.9756, 0.9717, 0.9766, 0.9766,
        0.9766, 0.9775, 0.9756, 0.9707, 0.9697, 0.9780, 0.9756, 0.9751, 0.9751,
        0.9761, 0.9741, 0.9678, 0.9736, 0.9673, 0.9717, 0.9771, 0.9741, 0.9805,
        0.9800, 0.9683, 0.9766, 0.9775, 0.9775, 0.9707, 0.9756, 0.9688, 0.9692,
        0.9736, 0.9683, 0.9756, 0.9717, 0.9722, 0.9751, 0.9712, 0.9727, 0.9531,
        0.9756, 0.9731, 0.9775, 0.9731, 0.9780, 0.9751], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes:  35%|███▌      | 7/20 [00:09<00:07,  1.76it/s]

tensor([0.9746, 0.9746, 0.9746, 0.9761, 0.9766, 0.9741, 0.9780, 0.9751, 0.9751,
        0.9751, 0.9766, 0.9736, 0.9756, 0.9766, 0.9746, 0.9712, 0.9751, 0.9761,
        0.9761, 0.9771, 0.9751, 0.9702, 0.9692, 0.9771, 0.9741, 0.9741, 0.9741,
        0.9751, 0.9736, 0.9663, 0.9731, 0.9663, 0.9707, 0.9766, 0.9731, 0.9805,
        0.9790, 0.9673, 0.9756, 0.9766, 0.9766, 0.9697, 0.9746, 0.9688, 0.9683,
        0.9727, 0.9683, 0.9751, 0.9712, 0.9717, 0.9736, 0.9712, 0.9717, 0.9517,
        0.9751, 0.9722, 0.9771, 0.9727, 0.9775, 0.9736], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes:  40%|████      | 8/20 [00:09<00:05,  2.08it/s]

tensor([0.9731, 0.9722, 0.9727, 0.9746, 0.9751, 0.9722, 0.9761, 0.9741, 0.9731,
        0.9731, 0.9746, 0.9722, 0.9736, 0.9746, 0.9736, 0.9692, 0.9741, 0.9741,
        0.9736, 0.9751, 0.9727, 0.9683, 0.9658, 0.9756, 0.9722, 0.9712, 0.9717,
        0.9741, 0.9717, 0.9639, 0.9707, 0.9634, 0.9683, 0.9741, 0.9717, 0.9780,
        0.9775, 0.9644, 0.9736, 0.9751, 0.9746, 0.9668, 0.9722, 0.9663, 0.9648,
        0.9702, 0.9653, 0.9736, 0.9683, 0.9688, 0.9717, 0.9683, 0.9692, 0.9492,
        0.9731, 0.9707, 0.9761, 0.9712, 0.9756, 0.9712], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes:  45%|████▌     | 9/20 [00:09<00:04,  2.38it/s]

tensor([0.9736, 0.9741, 0.9736, 0.9756, 0.9761, 0.9731, 0.9775, 0.9751, 0.9746,
        0.9746, 0.9761, 0.9727, 0.9746, 0.9761, 0.9741, 0.9697, 0.9751, 0.9751,
        0.9756, 0.9761, 0.9741, 0.9688, 0.9673, 0.9771, 0.9741, 0.9722, 0.9731,
        0.9751, 0.9722, 0.9653, 0.9722, 0.9653, 0.9697, 0.9756, 0.9722, 0.9790,
        0.9785, 0.9653, 0.9746, 0.9766, 0.9761, 0.9678, 0.9736, 0.9668, 0.9668,
        0.9717, 0.9663, 0.9741, 0.9692, 0.9702, 0.9727, 0.9692, 0.9702, 0.9517,
        0.9746, 0.9717, 0.9771, 0.9722, 0.9766, 0.9736], device='cuda:0',
       dtype=torch.float16)


🧠 Zero-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:13<00:00,  1.50it/s]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:16<00:00,  1.76it/s]

🔍 Base classes accuracy: 72.34%
🔍 Novel classes accuracy: 78.94%
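The accuracies above are top-1 accuracies. An image counts as correct when its embedding is most similar to the text embedding of its own class prompt. In the usual base-to-novel protocol, the candidate classes are only those of the split being evaluated. The sketch below shows that computation and rests on a few assumptions: `image_features` is an `(N, D)` tensor, `text_features` is a `(C, D)` tensor with one row per class of the split, and `labels` already index those rows. The helper name `zero_shot_top1` is hypothetical and is not part of the `clip` package. CLIP loaded on a GPU runs in `float16`, so the sketch casts to `float32` before normalizing.

```python
import torch
import torch.nn.functional as F


def zero_shot_top1(image_features: torch.Tensor,
                   text_features: torch.Tensor,
                   labels: torch.Tensor) -> float:
    """Top-1 zero-shot accuracy (hypothetical helper, not part of `clip`).

    image_features: (N, D) image embeddings
    text_features:  (C, D) one prompt embedding per class of the split
    labels:         (N,) class indices into the rows of text_features
    """
    # Cast to float32 (CLIP on GPU yields float16) and L2-normalise,
    # so the matrix product below is a cosine similarity
    image_features = F.normalize(image_features.float(), dim=-1)
    text_features = F.normalize(text_features.float(), dim=-1)

    # (N, D) @ (D, C) -> (N, C): similarity of every image to every class
    similarity = image_features @ text_features.T
    predictions = similarity.argmax(dim=-1)

    correct = (predictions == labels.to(predictions.device)).sum().item()
    return correct / labels.numel()


# Toy usage: 3 images, 2 classes, 2-D embeddings
images = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.1]])
prompts = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
targets = torch.tensor([0, 1, 1])
print(f"{zero_shot_top1(images, prompts, targets) * 100:.2f}%")
```

The toy call prints `66.67%`. The third image is closer to class 0 than to its label, class 1, so it counts as a miss. CLIP's learned logit scale (about 100) multiplies every similarity by the same positive factor, so it leaves the argmax unchanged and the sketch omits it.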





## Harmonic Mean
Few-shot adaptation papers usually report the harmonic mean (HM) of base and novel accuracy.
The HM is pulled toward the smaller of the two values. It dampens the effect of a large value (typically the base accuracy) and amplifies the effect of a small one (typically the novel accuracy).
Thus, a method that reaches very high base accuracy at the expense of novel accuracy is penalized by the HM.
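For two accuracies $A_\text{base}$ and $A_\text{novel}$, the definition implemented by `harmonic_mean` is shown below. It is followed by a worked example in which the 95% / 55% split is a made-up illustration, not a result from this notebook:

$$
\mathrm{HM} = \frac{2}{\frac{1}{A_\text{base}} + \frac{1}{A_\text{novel}}} = \frac{2\,A_\text{base}\,A_\text{novel}}{A_\text{base} + A_\text{novel}}
$$

$$
A_\text{base} = 0.95,\quad A_\text{novel} = 0.55:\qquad \mathrm{HM} = \frac{2 \cdot 0.95 \cdot 0.55}{0.95 + 0.55} = \frac{1.045}{1.5} \approx 0.697
$$

Both this split and a balanced 0.75 / 0.75 split have an arithmetic mean of 0.75. However, the balanced split has HM = 0.75, while the unbalanced one loses about 5 points.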

In [32]:
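def harmonic_mean(base_accuracy, novel_accuracy):
    """Harmonic mean of base and novel accuracy (both in [0, 1])."""
    # HM = 2 / (1/base + 1/novel) = 2 * base * novel / (base + novel);
    # the product form avoids a ZeroDivisionError when one accuracy is 0
    if base_accuracy + novel_accuracy == 0:
        return 0.0
    return 2 * base_accuracy * novel_accuracy / (base_accuracy + novel_accuracy)

print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")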

🔍 Harmonic Mean: 74.99%
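As a sanity check, the HM can be recomputed by hand from the two accuracies printed by the evaluation cell (72.34% and 78.94%). The plain-Python sketch below does this. The helper name `hm` is illustrative and is not from the assignment. It gives about 75.50%, not the 74.99% in the output above. This suggests that `base_accuracy` and `novel_accuracy` held values from a different run when the cell was executed, so re-running the notebook top to bottom should make the numbers agree.

```python
def hm(a, b):
    # Closed form of 2 / (1/a + 1/b); returns 0 if both accuracies are 0
    return 0.0 if a + b == 0 else 2 * a * b / (a + b)


base, novel = 0.7234, 0.7894  # accuracies printed by the evaluation cell
print(f"HM from printed accuracies: {hm(base, novel) * 100:.2f}%")
print(f"Arithmetic mean for comparison: {(base + novel) / 2 * 100:.2f}%")
```

This prints 75.50% for the HM and 75.64% for the arithmetic mean. The gap is small here because the two accuracies are close.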
