# CLIP zero-shot Evaluation
This short notebook implements the dataset split into base and novel categories (see project assignment) and runs the zero-shot evaluation with CLIP.
Feel free to copy the code contained in this notebook or to directly use this notebook as starting point for you project.

In [1]:
#@title Dependencies Installation
%pip install -q google-genai pydantic openai_clip

In [2]:
#@title Imports
import torch
import torchvision
import clip
from tqdm import tqdm
import os
import gc
from torchvision import transforms
from google.colab import userdata
import google.generativeai as genai
from pydantic import BaseModel, Field
from typing import List, Dict, Any
import json
from torchvision.datasets import Flowers102 as datset_used
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]


In [3]:
#@title Clip Utils functions

def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (torch.Compose)
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = datset_used(root=data_dir, split="train", download=True, transform=transform)
    val = datset_used(root=data_dir, split="val", download=True, transform=transform)
    test = datset_used(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test

def base_novel_categories(dataset):
    # set returns the unique set of all dataset classes
    all_classes = set(dataset._labels)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes

def split_data(dataset, base_classes):
    # these two lists will store the sample indexes
    base_categories_samples = []
    novel_categories_samples = []

    # we create a set of base classes to compute the test below in O(1)
    # this is optional and can be removed
    base_set = set(base_classes)

    # here we iterate over sample labels and also get the correspondent sample index
    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    # here we create the dataset subsets
    # the torch Subset is just a wrapper around the dataset
    # it simply stores the subset indexes and the original dataset (your_subset.dataset)
    # when asking for sample i in the subset, torch will look for its original position in the dataset and retrieve it
    # https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset

@torch.no_grad()
def eval(model, dataset, categories, batch_size, device, text_features, label=""):
    from tqdm import tqdm
    import numpy as np

    model.eval()

    # Remap categories to contiguous label space
    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    correct_top1 = 0
    correct_top5 = 0
    correct_top10 = 0
    total = 0

    # Confidence gaps split by error category
    gap_top1_hit = []
    gap_top5_hit = []
    gap_top10_hit = []
    gap_top10_miss = []

    for images, targets in tqdm(dataloader, desc=label):
        targets = torch.tensor([contig_cat2idx[t.item()] for t in targets], dtype=torch.long).to(device)
        images = images.to(device)

        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        similarities = image_features @ text_features.T
        top10 = similarities.topk(10, dim=-1)

        top10_indices = top10.indices
        top10_values = top10.values

        # Accuracies
        correct_top1 += (top10_indices[:, 0] == targets).sum().item()
        correct_top5 += sum([targets[i] in top10_indices[i, :5] for i in range(len(targets))])
        correct_top10 += sum([targets[i] in top10_indices[i, :10] for i in range(len(targets))])

        for i in range(len(targets)):
            true_idx = targets[i].item()
            pred_conf = top10_values[i, 0].item()
            true_conf = similarities[i, true_idx].item()

            if top10_indices[i, 0].item() == true_idx:
                second_conf = top10_values[i, 1].item()
                gap_top1_hit.append((pred_conf - second_conf) * 100)
                continue
            if true_idx in top10_indices[i, 1:5]:
                gap_top5_hit.append((pred_conf - true_conf) * 100)
            elif true_idx in top10_indices[i, 5:10]:
                gap_top10_hit.append((pred_conf - true_conf) * 100)
            else:
                gap_top10_miss.append((pred_conf - true_conf) * 100)

        total += targets.size(0)

    # Final metrics
    top1_acc = correct_top1 / total
    top5_acc = correct_top5 / total
    top10_acc = correct_top10 / total

    print(f"\n📊 Total samples evaluated: {total}\n")

    print(f"✅ Top-1 Accuracy:      {top1_acc*100:.2f}%")
    print(f"✅ Top-5 Accuracy:      {top5_acc*100:.2f}%")
    print(f"✅ Top-10 Accuracy:     {top10_acc*100:.2f}%")
    def safe_mean(arr): return np.mean(arr) if arr else 0.0
    print(f"✅ Avg. Conf. Gap (Top-1 hit):      {safe_mean(gap_top1_hit):.2f}%")

    print(f"❌ Avg. Conf. Gap (Top-5 hit):     {safe_mean(gap_top5_hit):.2f}%")
    print(f"❌ Avg. Conf. Gap (Top-10 hit):    {safe_mean(gap_top10_hit):.2f}%")
    print(f"❌ Avg. Conf. Gap (Beyond top-10): {safe_mean(gap_top10_miss):.2f}%")

    del text_features, dataloader, image_features, similarities, targets, images, top10, top10_indices, top10_values
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    return {
    "top1": top1_acc,
    "top5": top5_acc,
    "top10": top10_acc,
    "avg_gap_top1_hit": safe_mean(gap_top1_hit),
    "avg_error_top5_hit": safe_mean(gap_top5_hit),
    "avg_error_top10_hit": safe_mean(gap_top10_hit),
    "avg_error_top10_miss": safe_mean(gap_top10_miss),
    }

In [4]:
#@title CLIP Loading
device = "cuda" if torch.cuda.is_available() else "cpu"
# available models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
model, preprocess = clip.load("ViT-B/16", device=device)

# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)


In [9]:
#@title Compute Zero-Shot Standard Predictions

base_text_inputs = clip.tokenize(
        [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in base_classes]
    ).to(device)
base_text_features = model.encode_text(base_text_inputs)
base_text_features /= base_text_features.norm(dim=-1, keepdim=True)
base_accuracy = eval(model=model, dataset=test_base, categories=base_classes, batch_size=128, device=device, text_features=base_text_features, label="🧠 Zero-shot evaluation on Base Classes")
del base_text_inputs

novel_text_inputs = clip.tokenize(
        [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in novel_classes]
    ).to(device)
novel_text_features = model.encode_text(novel_text_inputs)
novel_text_features /= novel_text_features.norm(dim=-1, keepdim=True)
novel_accuracies = eval(model=model, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, text_features=novel_text_features, label="🧠 Zero-shot evaluation on Novel Classes")
del novel_text_inputs

torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()

🧠 Zero-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:19<00:00,  1.05it/s]



📊 Total samples evaluated: 2473

✅ Top-1 Accuracy:      71.29%
✅ Top-5 Accuracy:      90.86%
✅ Top-10 Accuracy:     97.53%
❌ Avg. Conf. Gap (Top-5 hit):     1.67%
❌ Avg. Conf. Gap (Top-10 hit):    3.86%
❌ Avg. Conf. Gap (Beyond top-10): 5.00%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:29<00:00,  1.02s/it]



📊 Total samples evaluated: 3676

✅ Top-1 Accuracy:      78.24%
✅ Top-5 Accuracy:      89.15%
✅ Top-10 Accuracy:     92.79%
❌ Avg. Conf. Gap (Top-5 hit):     1.37%
❌ Avg. Conf. Gap (Top-10 hit):    3.45%
❌ Avg. Conf. Gap (Beyond top-10): 5.70%


In [5]:
#@title Gemini setup
# Load your API key from Colab secrets.
try:
    GENAI_API_KEY = userdata.get('GENAI_KEY')
    genai.configure(api_key=GENAI_API_KEY)
except userdata.SecretNotFoundError:
    print("GENAI_KEY not found in Colab secrets. Please add it to proceed.")
    exit()
except Exception as e:
    print(f"An error occurred during API key setup: {e}")
    exit()

print("Google GenAI SDK configured successfully!")

# Define the Pydantic schema for the LLM output
class PromptDescriptions(BaseModel):
    # This will hold a list of lists: [[desc1_flower1, desc2_flower1, ...], [desc1_flower2, ...], ...]
    flower_descriptions: List[List[str]] = Field(
        description="A list where each element is a list of 5 short descriptions for a specific flower."
    )

print("Pydantic schema defined.")

GENAI_KEY not found in Colab secrets. Please add it to proceed.
Google GenAI SDK configured successfully!
Pydantic schema defined.


In [None]:
# @title 3. LLM Prompt Generation (Batched Calls)

genai_model = genai.GenerativeModel('gemini-2.5-flash')  # Using Flash for speed

generated_prompts_for_classes = {}  # {class_name: [list of 5 prompts]}
prompt_batch_size = 13

# Define schema using Pydantic
class Prompts(BaseModel):
    prompt1: str
    prompt2: str
    prompt3: str
    prompt4: str
    prompt5: str

class PromptDescriptions(BaseModel):
    flower_descriptions: List[Prompts]

print("\n--- Generating 5 prompts for each flower class using LLM ---")
print("This might take a few minutes depending on the number of classes and API response times.")

for i in tqdm(range(0, len(CLASS_NAMES), prompt_batch_size), desc="Generating prompts in batches"):
    current_batch_classes = CLASS_NAMES[i : min(i + prompt_batch_size, len(CLASS_NAMES))]

    # Dynamically build the prompt for the current batch
    prompt_batch = f"""You are a professional botanical photographer and creative writer for a visual nature magazine.

Given a list of flower species, generate exactly 5 short and visually evocative descriptions per flower. Each description should be phrased as if accompanying a stunning photograph of the flower, aiming to evoke its unique beauty and essence for an AI model like CLIP.

Each prompt must be:
- Descriptive, with strong visual language (color, shape, textures, dimension of petals, feeling)
- Distinct across the 5 prompts (no overlap or simple rewordings)
- Specific to the given flower (no generic phrases)
- Emphasize features that distinguish this flower from visually similar species
- Realistic and natural (imagine you're writing the description of your own photo of the flower)

Use the word 'flower' in at least 2 out of 5 prompts. Prioritize conciseness (max 20 words) and features that distinguish the flower.

Return the output as a JSON object with a single key 'flower_descriptions', whose value is a list of objects. Each object corresponds to a flower and includes exactly 5 fields: 'prompt1' through 'prompt5' (all strings).

Here are the flowers for this batch: {current_batch_classes} """


    try:
        response_llm = genai_model.generate_content(
            prompt_batch,
            generation_config=genai.types.GenerationConfig(
                response_mime_type="application/json",
                response_schema=PromptDescriptions
            )
        )

        parsed_response = PromptDescriptions.model_validate_json(response_llm.text)

        for j, prompt_obj in enumerate(parsed_response.flower_descriptions):
            flower_name = current_batch_classes[j]
            generated_prompts_for_classes[flower_name] = [
                prompt_obj.prompt1,
                prompt_obj.prompt2,
                prompt_obj.prompt3,
                prompt_obj.prompt4,
                prompt_obj.prompt5,
            ]

    except Exception as e:
        print(f"\nError generating prompts for batch starting with '{current_batch_classes[0]}': {e}")
        print(f"Raw response (if available): {response_llm.text if 'response_llm' in locals() else 'N/A'}")

print("\nLLM prompt generation complete.")

with open("generated_prompts.json", "w") as f:
    json.dump(generated_prompts_for_classes, f, indent=4)


In [5]:
#@title Import generated Prompts
import json

# 1. Load prompts
with open("generated_prompts.json", "r") as f:
    generated_prompts_for_classes = json.load(f)

# 2. Generate averaged text embeddings per class
def get_llm_text_features(model, prompt_dict, class_ids, class_names, device):
    """
    Compute mean text embedding per class from LLM-generated prompts.
    """
    text_features = []

    for c in class_ids:
        class_name = class_names[c]
        prompts = prompt_dict[class_name]

        # Tokenize and encode prompts
        text_inputs = clip.tokenize(prompts).to(device)
        embeddings = model.encode_text(text_inputs)
        embeddings /= embeddings.norm(dim=-1, keepdim=True)  # Normalize

        mean_embedding = embeddings.mean(dim=0)
        mean_embedding /= mean_embedding.norm(dim=-1, keepdim=True)  # Re-normalize after mean
        text_features.append(mean_embedding)
        del text_inputs, embeddings, mean_embedding
        torch.cuda.empty_cache()

    return torch.stack(text_features).to(device)


In [6]:
#@title Base classes evaluation
base_text_features = get_llm_text_features(
    model=model,
    prompt_dict=generated_prompts_for_classes,
    class_ids=base_classes,
    class_names=CLASS_NAMES,
    device=device
)

# Evaluate using precomputed text features
base_accuracies = eval(
    model=model,
    dataset=test_base,
    categories=base_classes,
    batch_size=128,
    device=device,
    text_features=base_text_features,
    label="🌸 Zero-shot eval with LLM prompts on Base Classes"
)

# Clean up large tensors
del base_text_features, base_accuracies
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()

🌸 Zero-shot eval with LLM prompts on Base Classes: 100%|██████████| 20/20 [00:19<00:00,  1.01it/s]



📊 Total samples evaluated: 2473

✅ Top-1 Accuracy:      79.18%
✅ Top-5 Accuracy:      94.99%
✅ Top-10 Accuracy:     96.24%
✅ Avg. Conf. Gap (Top-1 hit):      3.85%
❌ Avg. Conf. Gap (Top-5 hit):     1.51%
❌ Avg. Conf. Gap (Top-10 hit):    4.15%
❌ Avg. Conf. Gap (Beyond top-10): 7.13%


0

In [7]:
#@title Novel classes evaluation
novel_text_features = get_llm_text_features(
    model=model,
    prompt_dict=generated_prompts_for_classes,
    class_ids=novel_classes,
    class_names=CLASS_NAMES,
    device=device
)

# Evaluate using precomputed text features
novel_accuracies = eval(
    model=model,
    dataset=test_novel,
    categories=novel_classes,
    batch_size=128,
    device=device,
    text_features=novel_text_features,
    label="🌸 Zero-shot eval with LLM prompts on Novel Classes"
)

# Clean up large tensors
del novel_text_features, novel_accuracies
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()

🌸 Zero-shot eval with LLM prompts on Novel Classes: 100%|██████████| 29/29 [00:28<00:00,  1.02it/s]



📊 Total samples evaluated: 3676

✅ Top-1 Accuracy:      80.66%
✅ Top-5 Accuracy:      92.19%
✅ Top-10 Accuracy:     93.25%
✅ Avg. Conf. Gap (Top-1 hit):      4.32%
❌ Avg. Conf. Gap (Top-5 hit):     1.93%
❌ Avg. Conf. Gap (Top-10 hit):    4.39%
❌ Avg. Conf. Gap (Beyond top-10): 13.61%


0

## Harmonic Mean
Few-Shot Adaptations papers usually report the Harmonic Mean.
The harmonic mean tends to mitigate the impact of large outliers (base accuracy) and aggravate the impact of small ones (novel accuracy).
Thus, achieving very high base accuracies at the expense of the novel accuracy will be penalized by the HM.

In [None]:
def harmonic_mean(base_accuracy, novel_accuracy):
    numerator = 2
    denominator = 1 / base_accuracy + 1 / novel_accuracy
    hm = numerator / denominator
    return hm

print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

🔍 Harmonic Mean: 74.62%
