# CLIP zero-shot Evaluation
This short notebook implements the dataset split into base and novel categories (see project assignment) and runs the zero-shot evaluation with CLIP.
Feel free to copy the code in this notebook, or to use the notebook directly as a starting point for your project.

In [1]:
#@title Dependencies Installation
%pip install -q openai_clip



In [2]:
#@title Imports
import torch
import torchvision
import clip
import os
import gc
from torchvision import transforms
from google.colab import userdata
from google import genai
from google.genai import types
from pydantic import BaseModel, Field
from typing import List, Dict, Any
import json
from torchvision.datasets import Flowers102 as dataset_used
import contextlib
import numpy as np
import time
from tqdm.notebook import tqdm

print("✅ All dependencies imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🖼️ Torchvision version: {torchvision.__version__}")
print(f"🤖 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name()}")

✅ All dependencies imported successfully!
🔥 PyTorch version: 2.6.0+cu124
🖼️ Torchvision version: 0.21.0+cu124
🤖 CUDA available: True
🎮 GPU: Tesla T4


In [3]:
#@title Classes setup
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
dataset_name = "Oxford flowers102"
generic_category = "flowers"

print(f"📊 Total flower classes: {len(CLASS_NAMES)}")
print(f"🌸 Sample classes: {CLASS_NAMES[:5]}...")

📊 Total flower classes: 102
🌸 Sample classes: ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold']...


In [4]:
#@title Memory management utils
def aggressive_cleanup():
    """
    Comprehensive memory cleanup function to prevent VRAM leaks.

    This function:
    - Forces Python garbage collection
    - Empties CUDA cache
    - Collects IPC (Inter-Process Communication) resources
    - Synchronizes CUDA operations
    """
    gc.collect()                        # Python garbage collection
    if torch.cuda.is_available():       # CUDA calls are skipped on CPU-only runtimes
        torch.cuda.empty_cache()        # Clear CUDA cache
        torch.cuda.ipc_collect()        # Clean up IPC resources
        torch.cuda.synchronize()        # Wait for all CUDA operations to complete

def print_memory_stats():
    """Display current GPU memory usage statistics"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3  # Convert to GB
        reserved = torch.cuda.memory_reserved() / 1024**3    # Convert to GB
        print(f"🔋 GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
    else:
        print("❌ CUDA not available - running on CPU")

In [5]:
#@title Utils functions

def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (callable, optional): Preprocessing applied to each image (e.g. the CLIP preprocess function).
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    print("📥 Downloading and loading Flowers102 dataset...")
    train = dataset_used(root=data_dir, split="train", download=True, transform=transform)
    val = dataset_used(root=data_dir, split="val", download=True, transform=transform)
    test = dataset_used(root=data_dir, split="test", download=True, transform=transform)
    print(f"✅ Dataset loaded - Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")
    return train, val, test

def base_novel_categories(dataset):
    """
    Split all classes into base and novel categories.

    Base classes: First half of classes (for training/fine-tuning scenarios)
    Novel classes: Second half of classes (for zero-shot evaluation)

    Args:
        dataset: PyTorch dataset object

    Returns:
        base_classes, novel_classes: Lists of class indices
    """
    all_classes = set(dataset._labels)
    num_classes = len(all_classes)

    # Split classes 50/50
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]

    print(f"📊 Class split - Base: {len(base_classes)}, Novel: {len(novel_classes)}")
    return base_classes, novel_classes

def split_data(dataset, base_classes):
    """
    Split dataset samples based on base/novel class membership.

    Args:
        dataset: PyTorch dataset
        base_classes: List of base class indices

    Returns:
        base_dataset, novel_dataset: Subsets containing only base/novel samples
    """
    base_categories_samples = []
    novel_categories_samples = []
    base_set = set(base_classes)

    # Iterate through all samples and categorize by class
    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    # Create dataset subsets
    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)

    print(f"📊 Data split - Base samples: {len(base_dataset)}, Novel samples: {len(novel_dataset)}")
    return base_dataset, novel_dataset
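
The class split and the sample split above can be tried on a toy label list without downloading Flowers102. This is a minimal sketch, assuming labels are integers in `0..N-1`; `split_classes`, `split_indices` and `toy_labels` are hypothetical names used only for illustration.

```python
def split_classes(labels):
    """Split the class id range 50/50 into base (first half) and novel (second half)."""
    num_classes = len(set(labels))
    class_ids = list(range(num_classes))
    return class_ids[:num_classes // 2], class_ids[num_classes // 2:]


def split_indices(labels, base_classes):
    """Return the sample indices that belong to base classes and to novel classes."""
    base_set = set(base_classes)
    base_idx = [i for i, y in enumerate(labels) if y in base_set]
    novel_idx = [i for i, y in enumerate(labels) if y not in base_set]
    return base_idx, novel_idx


# Toy labels for 8 samples spread over 4 classes (illustrative data only)
toy_labels = [0, 3, 1, 2, 3, 0, 1, 2]
base_classes, novel_classes = split_classes(toy_labels)
base_idx, novel_idx = split_indices(toy_labels, base_classes)

print("Base classes:", base_classes, "Novel classes:", novel_classes)
print("Base samples:", base_idx, "Novel samples:", novel_idx)
```

The index lists play the role of the arguments passed to `torch.utils.data.Subset` in `split_data`.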


In [6]:
#@title CLIP context manager
@contextlib.contextmanager
def clip_model_context(model_name="ViT-B/16"):
    """
    Context manager for CLIP model to ensure proper cleanup.

    This ensures that:
    - Model is properly loaded on the correct device
    - Model is set to evaluation mode
    - Memory is cleaned up after use, even if errors occur

    Args:
        model_name: CLIP model variant to load

    Yields:
        model: CLIP model
        preprocess: Image preprocessing function
        device: Device (cuda/cpu) the model is loaded on
    """
    print(f"🤖 Loading CLIP model: {model_name}")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load CLIP model and preprocessing
    model, preprocess = clip.load(model_name, device=device)
    model.eval()  # Set to evaluation mode

    print(f"✅ CLIP model loaded on {device}")

    try:
        yield model, preprocess, device
    finally:
        # Ensure cleanup happens even if errors occur
        print("🧹 Cleaning up CLIP model...")
        del model
        aggressive_cleanup()
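
The cleanup guarantee of `clip_model_context` comes from the `try`/`finally` around `yield`. The sketch below shows the same pattern with a plain dictionary standing in for the model; `managed_resource` and `events` are hypothetical names, and no CLIP model is loaded.

```python
import contextlib

events = []  # Records the order in which each step runs


@contextlib.contextmanager
def managed_resource(name):
    """Toy stand-in for a model context: load, yield, always clean up."""
    events.append(f"load {name}")
    resource = {"name": name}
    try:
        yield resource
    finally:
        # Runs on normal exit and when the body of the `with` block raises
        events.append(f"cleanup {name}")


try:
    with managed_resource("toy-model") as res:
        events.append(f"use {res['name']}")
        raise RuntimeError("simulated failure")
except RuntimeError:
    events.append("error handled")

print(events)
```

The cleanup step is recorded before the exception reaches the outer `except`, which is the property the CLIP context manager relies on to free GPU memory after a failed evaluation.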

In [7]:
#@title Evaluation function
@torch.no_grad()
def eval(model, dataset, categories, batch_size, device, text_features, label=""):
    """
    Memory-optimized evaluation function for CLIP zero-shot classification.

    Computes:
    - Top-1, Top-5, Top-10 accuracies
    - Confidence gaps between predicted and true classes
    - Splits confidence gaps by error type (top-1 hit, top-5 hit, etc.)

    Args:
        model: CLIP model
        dataset: PyTorch dataset to evaluate
        categories: List of class indices for this evaluation
        batch_size: Batch size for evaluation
        device: Device to run on
        text_features: Pre-computed text embeddings
        label: Description label for progress bar

    Returns:
        Dictionary with accuracy metrics and confidence gaps
    """

    model.eval()

    # Create mapping from original class IDs to contiguous indices (0, 1, 2, ...)
    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

    # Create dataloader (num_workers=0 to avoid subprocess memory issues in Colab)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )

    # Initialize counters
    correct_top1 = 0
    correct_top5 = 0
    correct_top10 = 0
    total = 0

    # Lists to store confidence gaps for different error categories
    gap_top1_hit = []      # When prediction is correct
    gap_top5_hit = []      # When true class is in top-5 but not top-1
    gap_top10_hit = []     # When true class is in top-10 but not top-5
    gap_top10_miss = []    # When true class is not in top-10

    try:
        for batch_idx, (images, targets) in enumerate(tqdm(dataloader, desc=label)):
            # Remap targets to contiguous space and move to device
            targets = torch.tensor([contig_cat2idx[t.item()] for t in targets], dtype=torch.long).to(device)
            images = images.to(device)

            # Encode images
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Compute similarities between images and text
            similarities = image_features @ text_features.T

            # Get top-10 predictions
            top10 = similarities.topk(10, dim=-1)
            top10_indices = top10.indices  # Class predictions
            top10_values = top10.values    # Confidence scores

            # Calculate accuracies
            correct_top1 += (top10_indices[:, 0] == targets).sum().item()
            correct_top5 += (top10_indices[:, :5] == targets[:, None]).any(dim=1).sum().item()
            correct_top10 += (top10_indices == targets[:, None]).any(dim=1).sum().item()

            # Calculate confidence gaps for each sample
            for i in range(len(targets)):
                true_idx = targets[i].item()
                pred_conf = top10_values[i, 0].item()  # Highest prediction confidence
                true_conf = similarities[i, true_idx].item()  # True class confidence

                # Categorize by error type
                if top10_indices[i, 0].item() == true_idx:
                    # Correct prediction - gap between 1st and 2nd choice
                    second_conf = top10_values[i, 1].item()
                    gap_top1_hit.append((pred_conf - second_conf) * 100)
                elif true_idx in top10_indices[i, 1:5]:
                    # True class in top-5 but not top-1
                    gap_top5_hit.append((pred_conf - true_conf) * 100)
                elif true_idx in top10_indices[i, 5:10]:
                    # True class in top-10 but not top-5
                    gap_top10_hit.append((pred_conf - true_conf) * 100)
                else:
                    # True class not in top-10
                    gap_top10_miss.append((pred_conf - true_conf) * 100)

            total += targets.size(0)

            # Clean up batch tensors immediately to save memory
            del images, targets, image_features, similarities, top10, top10_indices, top10_values

            # Periodic cleanup during long evaluations
            if batch_idx % 10 == 0:
                aggressive_cleanup()

    finally:
        # Clean up dataloader
        del dataloader
        aggressive_cleanup()

    # Calculate final metrics
    top1_acc = correct_top1 / total
    top5_acc = correct_top5 / total
    top10_acc = correct_top10 / total

    # Display results
    print(f"\n📊 Total samples evaluated: {total}\n")
    print(f"✅ Top-1 Accuracy:      {top1_acc*100:.2f}%")
    print(f"✅ Top-5 Accuracy:      {top5_acc*100:.2f}%")
    print(f"✅ Top-10 Accuracy:     {top10_acc*100:.2f}%")
    print(f"✅ Avg. Conf. Gap (Top-1 hit):      {safe_mean(gap_top1_hit):.2f}%")
    print(f"❌ Avg. Conf. Gap (Top-5 hit):     {safe_mean(gap_top5_hit):.2f}%")
    print(f"❌ Avg. Conf. Gap (Top-10 hit):    {safe_mean(gap_top10_hit):.2f}%")
    print(f"❌ Avg. Conf. Gap (Beyond top-10): {safe_mean(gap_top10_miss):.2f}%")

    return {
        "top1": top1_acc,
        "top5": top5_acc,
        "top10": top10_acc,
        "avg_gap_top1_hit": safe_mean(gap_top1_hit),
        "avg_error_top5_hit": safe_mean(gap_top5_hit),
        "avg_error_top10_hit": safe_mean(gap_top10_hit),
        "avg_error_top10_miss": safe_mean(gap_top10_miss),
    }
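
To make the metrics concrete, the following sketch recomputes top-1 accuracy, top-k accuracy and the confidence gap on a hand-written similarity matrix in pure Python. It is an illustration under simplified assumptions (three classes, `k=2` instead of 5 or 10); `topk_indices`, `evaluate_toy` and the toy values are hypothetical and not part of the notebook.

```python
def topk_indices(scores, k):
    """Indices of the k largest scores, highest first."""
    return sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k]


def evaluate_toy(similarities, targets, k=2):
    """Return top-1 accuracy, top-k accuracy and per-sample gaps (requires k >= 2)."""
    top1_hits = 0
    topk_hits = 0
    gaps = []
    for scores, target in zip(similarities, targets):
        ranked = topk_indices(scores, k)
        best = scores[ranked[0]]
        if ranked[0] == target:
            top1_hits += 1
            # Correct prediction: margin between first and second choice
            gaps.append(best - scores[ranked[1]])
        else:
            # Wrong prediction: margin between predicted and true class
            gaps.append(best - scores[target])
        topk_hits += int(target in ranked)
    n = len(targets)
    return top1_hits / n, topk_hits / n, gaps


toy_similarities = [
    [0.30, 0.20, 0.10],  # true class 0 is ranked first
    [0.25, 0.15, 0.05],  # true class 1 is ranked second
    [0.40, 0.30, 0.10],  # true class 2 is ranked last
]
toy_targets = [0, 1, 2]

top1_acc, top2_acc, gaps = evaluate_toy(toy_similarities, toy_targets, k=2)
print(f"Top-1: {top1_acc:.2f}, Top-2: {top2_acc:.2f}")
print("Gaps:", [round(g, 2) for g in gaps])
```

Note that the gaps are differences of cosine similarities; the evaluation function above multiplies them by 100 before averaging.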

In [8]:
#@title Text features functions
def get_text_features_standard(model, class_ids, device):
    """
    Generate standard CLIP text features using simple template prompts.

    Uses the template: "a photo of a {class_name}, a type of flower."

    Args:
        model: CLIP model
        class_ids: List of class indices to generate features for
        device: Device to run computation on

    Returns:
        text_features: Normalized text embeddings tensor
    """
    # Create simple template prompts
    prompts = [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in class_ids]
    print(f"📝 Generated {len(prompts)} standard prompts")
    print(f"📄 Example prompt: '{prompts[0]}'")

    # Tokenize prompts
    text_inputs = clip.tokenize(prompts).to(device)

    try:
        # Generate embeddings
        with torch.no_grad():
            text_features = model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    finally:
        # Clean up the token tensor. Only names guaranteed to exist are deleted,
        # so a failure in encode_text is not masked by a NameError.
        del text_inputs
        aggressive_cleanup()

def get_llm_text_features(model, prompt_dict, class_ids, class_names, device):
    """
    Generate text features from LLM-generated prompts.

    For each class, uses multiple detailed prompts generated by an LLM,
    then averages their embeddings to get a richer representation.

    Args:
        model: CLIP model
        prompt_dict: Dictionary mapping class names to lists of prompts
        class_ids: List of class indices
        class_names: List of all class names
        device: Device to run computation on

    Returns:
        text_features: Tensor of averaged normalized embeddings per class
    """
    text_features = []

    print(f"🤖 Processing LLM-generated prompts for {len(class_ids)} classes...")

    for c in class_ids:
        class_name = class_names[c]
        prompts = prompt_dict[class_name]

        #print(f"📝 Processing {len(prompts)} prompts for '{class_name}'")

        # Tokenize all prompts for this class
        text_inputs = clip.tokenize(prompts).to(device)

        try:
            # Generate embeddings for all prompts
            with torch.no_grad():
                embeddings = model.encode_text(text_inputs)
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

                # Average across all prompts for this class
                mean_embedding = embeddings.mean(dim=0)
                mean_embedding = mean_embedding / mean_embedding.norm(dim=-1, keepdim=True)

                # Store the averaged embedding
                text_features.append(mean_embedding.clone())

        finally:
            # Clean up the token tensor. Only names guaranteed to exist are deleted,
            # so a failure in encode_text is not masked by a NameError.
            del text_inputs
            aggressive_cleanup()

    return torch.stack(text_features).to(device)

def safe_mean(arr):
    """Safely compute mean, returning 0 if array is empty"""
    return np.mean(arr) if arr else 0.0
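
The averaging step in `get_llm_text_features` (normalize each prompt embedding, average, normalize again) can be checked by hand on two-dimensional vectors. This is a pure-Python sketch with made-up embeddings; `l2_normalize` and `ensemble_embedding` are hypothetical helper names.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def ensemble_embedding(prompt_embeddings):
    """Normalize each prompt embedding, average them, then renormalize the mean."""
    normalized = [l2_normalize(v) for v in prompt_embeddings]
    dim = len(normalized[0])
    mean = [sum(v[d] for v in normalized) / len(normalized) for d in range(dim)]
    return l2_normalize(mean)


# Two made-up prompt embeddings for one class (illustrative values only)
toy_embeddings = [[3.0, 4.0], [4.0, 3.0]]
class_embedding = ensemble_embedding(toy_embeddings)
print([round(x, 4) for x in class_embedding])
```

The second normalization matters because the mean of unit vectors is generally shorter than unit length; without it, the dot product with a normalized image embedding would no longer be a cosine similarity.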

In [9]:
#@title Evaluation suite
def run_evaluation_suite():
    """
    Run complete evaluation suite with proper memory management.

    Evaluations performed:
    1. Standard prompts on base classes
    2. Standard prompts on novel classes
    3. LLM-enhanced prompts on base classes (if available)
    4. LLM-enhanced prompts on novel classes (if available)

    Returns:
        Dictionary containing all evaluation results
    """
    print("🚀 Starting memory-optimized CLIP evaluation suite...")

    # Try to load LLM-generated prompts
    try:
        with open("generated_prompts.json", "r") as f:
            generated_prompts_for_classes = json.load(f)
        has_llm_prompts = True
        print("✅ Found generated prompts file")
        print(f"📊 Loaded prompts for {len(generated_prompts_for_classes)} classes")
    except FileNotFoundError:
        print("❌ No generated prompts found. Will only run standard evaluation.")
        has_llm_prompts = False

    results = {}

    # Use context manager for proper model lifecycle management
    with clip_model_context("ViT-B/16") as (model, preprocess, device):
        print(f"📱 Using device: {device}")
        batch_sz = 128

        # Load and prepare datasets
        print("\n📥 Loading and preparing datasets...")
        train_set, val_set, test_set = get_data(transform=preprocess)
        base_classes, novel_classes = base_novel_categories(train_set)

        # Split datasets by base/novel classes
        train_base, _ = split_data(train_set, base_classes)
        val_base, _ = split_data(val_set, base_classes)
        test_base, test_novel = split_data(test_set, base_classes)

        print(f"📊 Dataset prepared:")
        print(f"   Base classes: {len(base_classes)} Novel classes: {len(novel_classes)}")
        print(f"   Test base samples: {len(test_base)} Test novel samples: {len(test_novel)}")

        # EVALUATION 1: Standard Base Classes
        print("\n" + "="*60)
        print("🔄 EVALUATION 1: Standard Prompts on Base Classes")
        print("="*60)

        base_text_features = get_text_features_standard(model, base_classes, device)
        try:
            results['standard_base'] = eval(
                model=model,
                dataset=test_base,
                categories=base_classes,
                batch_size=batch_sz,
                device=device,
                text_features=base_text_features,
                label="🧠 Zero-shot evaluation on Base Classes"
            )
        finally:
            del base_text_features
            aggressive_cleanup()

        # EVALUATION 2: Standard Novel Classes
        print("\n" + "="*60)
        print("🔄 EVALUATION 2: Standard Prompts on Novel Classes")
        print("="*60)

        novel_text_features = get_text_features_standard(model, novel_classes, device)
        try:
            results['standard_novel'] = eval(
                model=model,
                dataset=test_novel,
                categories=novel_classes,
                batch_size=batch_sz,
                device=device,
                text_features=novel_text_features,
                label="🧠 Zero-shot evaluation on Novel Classes"
            )
        finally:
            del novel_text_features
            aggressive_cleanup()

        # EVALUATIONS 3 & 4: LLM-Enhanced (if prompts available)
        if has_llm_prompts:
            print("\n" + "="*60)
            print("🔄 EVALUATION 3: LLM-Enhanced Prompts on Base Classes")
            print("="*60)

            base_llm_features = get_llm_text_features(
                model, generated_prompts_for_classes, base_classes, CLASS_NAMES, device
            )
            try:
                results['llm_base'] = eval(
                    model=model,
                    dataset=test_base,
                    categories=base_classes,
                    batch_size=batch_sz,
                    device=device,
                    text_features=base_llm_features,
                    label="🌸 Zero-shot eval with LLM prompts on Base Classes"
                )
            finally:
                del base_llm_features
                aggressive_cleanup()

            print("\n" + "="*60)
            print("🔄 EVALUATION 4: LLM-Enhanced Prompts on Novel Classes")
            print("="*60)

            novel_llm_features = get_llm_text_features(
                model, generated_prompts_for_classes, novel_classes, CLASS_NAMES, device
            )
            try:
                results['llm_novel'] = eval(
                    model=model,
                    dataset=test_novel,
                    categories=novel_classes,
                    batch_size=batch_sz,
                    device=device,
                    text_features=novel_llm_features,
                    label="🌸 Zero-shot eval with LLM prompts on Novel Classes"
                )
            finally:
                del novel_llm_features
                aggressive_cleanup()

    return results

print("✅ Main evaluation suite function defined")

✅ Main evaluation suite function defined


In [10]:
#@title Results
def analyze_results(results):
    """
    Analyze and display comprehensive results from all evaluations.

    Calculates harmonic means between base and novel class performance,
    which is a standard metric for base-to-novel generalization.

    Args:
        results: Dictionary containing evaluation results
    """
    print("\n" + "="*60)
    print("📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("="*60)

    def harmonic_mean(base_acc, novel_acc):
        """Calculate harmonic mean of two accuracies"""
        if base_acc > 0 and novel_acc > 0:
            return 2 / (1/base_acc + 1/novel_acc)
        return 0

    # Standard prompts analysis
    if 'standard_base' in results and 'standard_novel' in results:
        base_top1 = results['standard_base']['top1']
        novel_top1 = results['standard_novel']['top1']
        std_hm = harmonic_mean(base_top1, novel_top1)

        print("🔤 STANDARD PROMPTS RESULTS:")
        print(f"   📈 Harmonic Mean (Top-1): {std_hm*100:.2f}%")
        print(f"   🎯 Base Classes Top-1:    {base_top1*100:.2f}%")
        print(f"   🆕 Novel Classes Top-1:   {novel_top1*100:.2f}%")
        print(f"   📊 Base Classes Top-5:    {results['standard_base']['top5']*100:.2f}%")
        print(f"   📊 Novel Classes Top-5:   {results['standard_novel']['top5']*100:.2f}%")
        print()

    # LLM-enhanced prompts analysis
    if 'llm_base' in results and 'llm_novel' in results:
        base_top1_llm = results['llm_base']['top1']
        novel_top1_llm = results['llm_novel']['top1']
        llm_hm = harmonic_mean(base_top1_llm, novel_top1_llm)

        print("🤖 LLM-ENHANCED PROMPTS RESULTS:")
        print(f"   📈 Harmonic Mean (Top-1): {llm_hm*100:.2f}%")
        print(f"   🎯 Base Classes Top-1:    {base_top1_llm*100:.2f}%")
        print(f"   🆕 Novel Classes Top-1:   {novel_top1_llm*100:.2f}%")
        print(f"   📊 Base Classes Top-5:    {results['llm_base']['top5']*100:.2f}%")
        print(f"   📊 Novel Classes Top-5:   {results['llm_novel']['top5']*100:.2f}%")
        print()

    # Improvement analysis
    if all(key in results for key in ['standard_base', 'standard_novel', 'llm_base', 'llm_novel']):
        base_improvement = (base_top1_llm - base_top1) * 100
        novel_improvement = (novel_top1_llm - novel_top1) * 100
        hm_improvement = (llm_hm - std_hm) * 100

        print("📈 IMPROVEMENT WITH LLM PROMPTS:")
        print(f"   🎯 Base Classes:    {base_improvement:+.2f} percentage points")
        print(f"   🆕 Novel Classes:   {novel_improvement:+.2f} percentage points")
        print(f"   📈 Harmonic Mean:   {hm_improvement:+.2f} percentage points")
        print()

    # Confidence gap analysis
    print("🔍 CONFIDENCE GAP ANALYSIS:")
    for eval_name, eval_results in results.items():
        eval_display = eval_name.replace('_', ' ').title()
        print(f"\n   {eval_display}:")
        print(f"     ✅ Top-1 Hit Gap:     {eval_results['avg_gap_top1_hit']:.2f}%")
        print(f"     ❌ Top-5 Hit Gap:    {eval_results['avg_error_top5_hit']:.2f}%")
        print(f"     ❌ Top-10 Hit Gap:   {eval_results['avg_error_top10_hit']:.2f}%")
        print(f"     ❌ Top-10 Miss Gap:  {eval_results['avg_error_top10_miss']:.2f}%")

def save_results(results, filename="clip_evaluation_results.json"):
    """Save results to JSON file for later analysis"""
    with open(filename, 'w') as f:
        json.dump(results, f, indent=4)  # The file handle is required as second argument
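
A short worked example shows why the harmonic mean is reported instead of the plain average: it is pulled toward the weaker of the two accuracies. The accuracies below are hypothetical values chosen for the arithmetic, not results from this notebook.

```python
def harmonic_mean(base_acc, novel_acc):
    """Harmonic mean of two accuracies; 0 if either one is 0."""
    if base_acc > 0 and novel_acc > 0:
        return 2 * base_acc * novel_acc / (base_acc + novel_acc)
    return 0.0


# Hypothetical accuracies, used only to illustrate the formula
base_acc, novel_acc = 0.80, 0.40
hm = harmonic_mean(base_acc, novel_acc)
am = (base_acc + novel_acc) / 2

print(f"Arithmetic mean: {am*100:.2f}%")
print(f"Harmonic mean:   {hm*100:.2f}%")
```

Here `2 * 0.8 * 0.4 / (0.8 + 0.4)` is about 0.533, below the arithmetic mean of 0.6, so a method cannot score well by improving base classes at the expense of novel ones.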

In [57]:
#@title Gemini setup
from google.colab import userdata

# Configure the client
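# Read the key from Colab secrets instead of hard-coding it in the notebook.
# The secret name "GEMINI_API_KEY" is an assumption; use the name configured in your Colab secrets panel.
GEMINI_API_KEY = userdata.get("GEMINI_API_KEY")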
client = genai.Client(api_key=GEMINI_API_KEY)

# Define the grounding tool
grounding_tool = types.Tool(
    google_search=types.GoogleSearch()
)

# Configure generation settings for web search
web_call_config = types.GenerateContentConfig(
    tools=[grounding_tool]
)

# Define schema using Pydantic
class Prompts(BaseModel):
    d1: str
    d2: str
    d3: str
    d4: str
    d5: str
    d6: str
    d7: str

class PromptDescriptions(BaseModel):
    # One 'Prompts' object per class in the batch.
    # The schema itself only enforces a non-empty list, not the per-batch count.
    descriptions: List[Prompts] = Field(
        ...,  # '...' indicates the field is required
        min_length=1,
        description="A list where each element is an object with 7 short descriptions for a specific category (d1-d7)."
    )

# Configure generation settings for prompt generation
prompt_gen_config = types.GenerateContentConfig(
    temperature=0.0,
    response_mime_type="application/json",
    response_schema=PromptDescriptions
)

def get_visual_features(class_name, dataset_name, generic_category) -> str:
  web_search_prompt = f"""What are the unique visual characteristics that distinguish an object of the class {class_name} within the {dataset_name} dataset?
  If applicable, consider its features in relation to {generic_category} in general.
  Please perform a web search to find reliable references to ensure the accuracy of the visual description.
  Provide a detailed yet concise summary (bullet list style) (max 100 tokens) focusing exclusively on the visual features that are most helpful for differentiating this class from other classes within the dataset.
  You should avoid writing information about the dataset"""

  # Make the request
  response = client.models.generate_content(
      model="gemini-2.0-flash",
      contents=web_search_prompt,
      config=web_call_config,
  )

  # Return the grounded response
  return response.text

print("Google GenAI SDK configured successfully!")

Google GenAI SDK configured successfully!


In [None]:
#@title Feature Extraction and Saving (Resumable & Interactive)

# --- Configuration for resuming ---
output_filename = "features.json"
starting_index = 0

# --- Load existing features or create an empty list ---
features = []
if os.path.exists(output_filename):
    try:
        with open(output_filename, 'r') as f:
            features = json.load(f)
        print(f"Loaded {len(features)} existing features from '{output_filename}'.")
        # Adjust starting_index if loaded features are more than starting_index
        if starting_index < len(features):
            print(f"Adjusting starting_index to {len(features)} to append new entries.")
            starting_index = len(features)
    except json.JSONDecodeError:
        print(f"Warning: '{output_filename}' exists but is corrupted. Starting from scratch.")
        features = []
else:
    print(f"'{output_filename}' not found. A new file will be created.")

# 1) features list and saving file are handled above.

print(f"Starting feature extraction from index {starting_index} for {len(CLASS_NAMES) - starting_index} remaining classes...")

last_successfully_saved_index = -1
if features: # If features were loaded, the last index is length - 1
    last_successfully_saved_index = len(features) - 1

# Iterate only over the indices of the remaining classes
for i in tqdm(range(starting_index, len(CLASS_NAMES)), desc="Extracting features"):
    class_name = CLASS_NAMES[i]
    print(f"\n--- Processing class {i+1}/{len(CLASS_NAMES)}: '{class_name}' (Index: {i}) ---")

    current_feature_description = None # To hold the result before user decision

    try:
        # Call the function to get visual features
        current_feature_description = get_visual_features(class_name, dataset_name, generic_category)
        print(f"Description of {class_name}:\n{current_feature_description}")

        user_input = input("Press C to save and continue, or A to abort: ").strip().upper()

        if user_input != 'A':
            # Ensure the features list has enough space up to the current index
            # This handles cases where starting_index might have been manually set
            while len(features) <= i:
                features.append(None) # Pad with None if needed

            features[i] = current_feature_description
            print(f"Description for '{class_name}' SAVED.")

            # Save the list into a features.json file immediately after saving a class
            try:
                with open(output_filename, 'w') as f:
                    json.dump(features, f, indent=4)
                print(f"Features saved incrementally to '{output_filename}' (current count: {len(features)})")
                last_successfully_saved_index = i # Update last saved index
            except Exception as e:
                print(f"CRITICAL ERROR: Could not save features to file after processing '{class_name}': {e}")
                print("Data might be lost if Colab session terminates. Please check file permissions or disk space.")
                # Even if saving failed, we continue unless 'A' was pressed, but warn the user.
        else:
            print(f"Aborting execution. Description for '{class_name}' was NOT SAVED.")
            break # Exit the loop

    except Exception as e:
        print(f"\nError processing class '{class_name}' at index {i}: {e}")
        print("Appending an error placeholder to maintain order if saving occurs. This entry will NOT be saved unless 'C' is pressed.")
        # If an error occurs, we still ask for user input to decide to save or abort.
        # The user can decide to save the error placeholder or abort.
        user_input_after_error = input("Error occurred. Press C to save an error placeholder and continue, or A to abort: ").strip().upper()

        if user_input_after_error == 'C':
            while len(features) <= i:
                features.append(None) # Pad with None if needed
            features[i] = f"ERROR at index {i}: {e}"
            print(f"Error placeholder for '{class_name}' SAVED.")
            try:
                with open(output_filename, 'w') as f:
                    json.dump(features, f, indent=4)
                print(f"Features saved incrementally to '{output_filename}' (current count: {len(features)})")
                last_successfully_saved_index = i
            except Exception as e_save:
                print(f"CRITICAL ERROR: Could not save error placeholder to file after processing '{class_name}': {e_save}")
                print("Data might be lost if Colab session terminates. Please check file permissions or disk space.")
        elif user_input_after_error == 'A':
            print(f"Aborting execution due to error. Description for '{class_name}' was NOT SAVED.")
            break # Exit the loop
        else:
            print(f"Invalid input after error ('{user_input_after_error}'). Skipping this class. Continuing to next.")


print("\n--- Feature Extraction Finished ---")
print(f"Final count of features in list: {len(features)}")
if last_successfully_saved_index != -1:
    print(f"Last successfully saved class index: {last_successfully_saved_index} (Class: '{CLASS_NAMES[last_successfully_saved_index]}')")
else:
    print("No classes were successfully saved during this run.")
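
The incremental saves above rewrite the output file in place, so an interruption in the middle of a write could leave a truncated JSON file. One possible safeguard, sketched here as an assumption rather than as part of the notebook, is to write to a temporary file and rename it over the target. The helper name `save_json_atomically` and the demo path are hypothetical.

```python
import json
import os
import tempfile


def save_json_atomically(data, path):
    """Write `data` as JSON to `path` without ever leaving a half-written file."""
    directory = os.path.dirname(os.path.abspath(path))
    # The temporary file lives in the same directory so the rename stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(data, f, indent=4)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)  # atomically replaces any previous version
    except BaseException:
        # On failure, remove the temporary file and leave the old file untouched
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Demo with made-up entries shaped like the list built above
features = ["desc 0", None, "ERROR at index 2: timeout"]
demo_path = os.path.join(tempfile.gettempdir(), "features_demo.json")
save_json_atomically(features, demo_path)

with open(demo_path) as f:
    reloaded = json.load(f)
print(reloaded == features)
```

With a helper like this, both `json.dump` blocks in the loop could be replaced by a single call, and a crash during saving would leave the last good file intact.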

In [61]:
# @title 3. LLM Prompt Generation (Batched Calls)

generated_prompts_for_classes = {}  # {class_name: [list of prompts]}
prompt_batch_size = 34  # 102 classes -> 3 LLM calls

# Try to load the LLM-generated descriptions (one entry per class, in CLASS_NAMES order)
try:
    with open("features.json", "r") as f:
        classes_features = json.load(f)
    has_llm_descriptions = True
    print("✅ Found features file")
    print(f"📊 Loaded features for {len(classes_features)} classes")
except FileNotFoundError:
    print("❌ No features found. Will only run standard evaluation.")
    classes_features = []  # keeps the loop below from raising a NameError
    has_llm_descriptions = False

print("\n--- Generating the prompts for each class using LLM ---")
print("This might take a few minutes depending on the number of classes and API response times.")

# Only iterate over classes that have a description; without a features file the loop is a no-op
num_described = min(len(CLASS_NAMES), len(classes_features))
for i in tqdm(range(0, num_described, prompt_batch_size), desc="Generating prompts in batches"):
    # Pair each class with its description, skipping None padding and the
    # "ERROR at index ..." placeholders written by the feature-extraction cell
    batch_pairs = [
        (class_name, description)
        for class_name, description in zip(
            CLASS_NAMES[i : i + prompt_batch_size],
            classes_features[i : i + prompt_batch_size],
        )
        if isinstance(description, str) and not description.startswith("ERROR at index")
    ]
    if not batch_pairs:
        continue
    # Keep the class list aligned with what is actually sent to the LLM
    current_batch_classes = [class_name for class_name, _ in batch_pairs]

    current_batch_for_llm = [
        {"class_name": class_name, "visual_features_description": description.strip()}
        for class_name, description in batch_pairs
    ]
    llm_input_json_string = json.dumps(current_batch_for_llm, indent=4)

    # Dynamically build the prompt for the current batch
    prompt_batch = f"""You are an expert in visual recognition and prompt engineering for Vision-Language Models (VLMs) like CLIP. Your task is to generate **7 diverse, concise, and highly discriminative text prompts** for each object class, specifically optimized for **zero-shot classification performance with CLIP**.

You will be provided with a JSON list of object classes. For each object class in the list, you will receive:
- `class_name`: The precise name of the object category.
- `visual_features_description`: A detailed and reliable description of its key distinguishing visual characteristics. This description includes features that make the category unique. Rely more on this description than on your knowledge.

**Your Goal:**
For each `class_name` provided, generate exactly **7 distinct and effective text prompts**. These prompts should aim to maximize the accuracy of a CLIP model when classifying images of this particular class.

**Guidelines for Prompt Construction:**
1.  **Direct CLIP Alignment:** Focus exclusively on visual attributes that a VLM like CLIP can effectively recognize from an image. Avoid abstract concepts, misleading category names, or non-visual information.
2.  **Leverage Visual Description:** Integrate key details from the `visual_features_description` into your prompts. This is crucial for distinguishing similar classes.
3.  **Attribute Emphasis:** Clearly highlight unique visual traits such as color, shape, pattern, texture, size, or distinctive parts of the object.
4.  **Contextualization & Image Type:** In addition to object-specific contexts, consider the dataset's characteristics. For the current dataset, which is '{dataset_name}', you should sometimes include terms describing the image style (sketch, drawing, professional photo, headshot, etc.) or broader environmental context relevant to the dataset (e.g. day/night, surroundings, distance, feeling).
5.  **Conciseness & Token Limit:** Keep each prompt succinct, aiming for 20 tokens (10-12 words). Focus on high-signal visual words.
6.  **Diversity for Ensembling:** Ensure the 7 prompts for a single class are not mere paraphrases. They should explore different facets of the class's visual identity, leading to diverse embeddings that enhance the power of prompt ensembling.

**Input Classes for Prompt Generation:**
```json
{llm_input_json_string}
```
"""

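    response_llm = None  # reset so a failed call never reports the previous batch's response
    try:
        response_llm = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt_batch,
            config=prompt_gen_config,
        )

        parsed_response = PromptDescriptions.model_validate_json(response_llm.text)

        # Prompt sets are matched to classes by position, so the counts must agree
        if len(parsed_response.descriptions) != len(current_batch_classes):
            print(f"\nWarning: expected {len(current_batch_classes)} prompt sets, got {len(parsed_response.descriptions)}. Unmatched entries are ignored.")

        for category_name, prompt_obj in zip(current_batch_classes, parsed_response.descriptions):
            generated_prompts_for_classes[category_name] = [
                prompt_obj.d1,
                prompt_obj.d2,
                prompt_obj.d3,
                prompt_obj.d4,
                prompt_obj.d5,
                prompt_obj.d6,
                prompt_obj.d7,
            ]

    except Exception as e:
        print(f"\nError generating prompts for batch starting with '{current_batch_classes[0]}': {e}")
        print(f"Raw response (if available): {response_llm.text if response_llm is not None else 'N/A'}")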

print("\nLLM prompt generation complete.")

with open("generated_prompts.json", "w") as f:
    json.dump(generated_prompts_for_classes, f, indent=4)


✅ Found features file
📊 Loaded features for 102 classes

--- Generating the prompts for each class using LLM ---
This might take a few minutes depending on the number of classes and API response times.


Generating prompts in batches:   0%|          | 0/3 [00:00<?, ?it/s]


LLM prompt generation complete.
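
The cell above pairs each returned prompt set with a class by position, which silently misassigns prompts if the model drops or reorders an entry. A possible safeguard is to match by name instead. The sketch below assumes the response has first been converted into plain dictionaries with a `class_name` and a `prompts` list; this shape and the helper `match_prompts_by_name` are hypothetical and are not the notebook's `PromptDescriptions` schema.

```python
def match_prompts_by_name(batch_classes, llm_items, prompts_per_class=7):
    """Return ({class_name: prompts}, [classes whose prompts are missing or malformed])."""
    wanted = set(batch_classes)
    matched = {}
    for item in llm_items:
        name = item.get("class_name")
        prompts = item.get("prompts", [])
        # Accept an entry only if it names a requested class and has the expected prompt count
        if name in wanted and len(prompts) == prompts_per_class:
            matched[name] = list(prompts)
    missing = [name for name in batch_classes if name not in matched]
    return matched, missing


# Made-up response: out of order, one unknown class, one class with too few prompts
batch_classes = ["rose", "lotus", "azalea"]
llm_items = [
    {"class_name": "lotus", "prompts": [f"lotus prompt {k}" for k in range(7)]},
    {"class_name": "tulip", "prompts": [f"tulip prompt {k}" for k in range(7)]},
    {"class_name": "azalea", "prompts": ["azalea prompt 0", "azalea prompt 1"]},
    {"class_name": "rose", "prompts": [f"rose prompt {k}" for k in range(7)]},
]

matched, missing = match_prompts_by_name(batch_classes, llm_items)
print(sorted(matched))
print(missing)
```

Classes listed in `missing` could then be sent again in a smaller follow-up batch instead of being left without prompts.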


In [62]:
#@title Run Evaluation
results = run_evaluation_suite()

🚀 Starting memory-optimized CLIP evaluation suite...
✅ Found generated prompts file
📊 Loaded prompts for 102 classes
🤖 Loading CLIP model: ViT-B/16
✅ CLIP model loaded on cuda
📱 Using device: cuda

📥 Loading and preparing datasets...
📥 Downloading and loading Flowers102 dataset...
✅ Dataset loaded - Train: 1020, Val: 1020, Test: 6149
📊 Class split - Base: 51, Novel: 51
📊 Data split - Base samples: 510, Novel samples: 510
📊 Data split - Base samples: 510, Novel samples: 510
📊 Data split - Base samples: 2473, Novel samples: 3676
📊 Dataset prepared:
   Base classes: 51 Novel classes: 51
   Test base samples: 2473 Test novel samples: 3676

🔄 EVALUATION 1: Standard Prompts on Base Classes
📝 Generated 51 standard prompts
📄 Example prompt: 'a photo of a pink primrose, a type of flower.'


🧠 Zero-shot evaluation on Base Classes:   0%|          | 0/20 [00:00<?, ?it/s]


📊 Total samples evaluated: 2473

✅ Top-1 Accuracy:      71.29%
✅ Top-5 Accuracy:      90.86%
✅ Top-10 Accuracy:     97.53%
✅ Avg. Conf. Gap (Top-1 hit):      3.27%
❌ Avg. Conf. Gap (Top-5 hit):     1.67%
❌ Avg. Conf. Gap (Top-10 hit):    3.86%
❌ Avg. Conf. Gap (Beyond top-10): 5.00%

🔄 EVALUATION 2: Standard Prompts on Novel Classes
📝 Generated 51 standard prompts
📄 Example prompt: 'a photo of a wild pansy, a type of flower.'


🧠 Zero-shot evaluation on Novel Classes:   0%|          | 0/29 [00:00<?, ?it/s]


📊 Total samples evaluated: 3676

✅ Top-1 Accuracy:      78.24%
✅ Top-5 Accuracy:      89.15%
✅ Top-10 Accuracy:     92.79%
✅ Avg. Conf. Gap (Top-1 hit):      3.69%
❌ Avg. Conf. Gap (Top-5 hit):     1.37%
❌ Avg. Conf. Gap (Top-10 hit):    3.45%
❌ Avg. Conf. Gap (Beyond top-10): 5.70%

🔄 EVALUATION 3: LLM-Enhanced Prompts on Base Classes
🤖 Processing LLM-generated prompts for 51 classes...


🌸 Zero-shot eval with LLM prompts on Base Classes:   0%|          | 0/20 [00:00<?, ?it/s]


📊 Total samples evaluated: 2473

✅ Top-1 Accuracy:      77.07%
✅ Top-5 Accuracy:      94.86%
✅ Top-10 Accuracy:     98.79%
✅ Avg. Conf. Gap (Top-1 hit):      3.66%
❌ Avg. Conf. Gap (Top-5 hit):     1.97%
❌ Avg. Conf. Gap (Top-10 hit):    4.62%
❌ Avg. Conf. Gap (Beyond top-10): 6.12%

🔄 EVALUATION 4: LLM-Enhanced Prompts on Novel Classes
🤖 Processing LLM-generated prompts for 51 classes...


🌸 Zero-shot eval with LLM prompts on Novel Classes:   0%|          | 0/29 [00:00<?, ?it/s]


📊 Total samples evaluated: 3676

✅ Top-1 Accuracy:      78.40%
✅ Top-5 Accuracy:      91.57%
✅ Top-10 Accuracy:     93.12%
✅ Avg. Conf. Gap (Top-1 hit):      4.37%
❌ Avg. Conf. Gap (Top-5 hit):     1.49%
❌ Avg. Conf. Gap (Top-10 hit):    4.96%
❌ Avg. Conf. Gap (Beyond top-10): 9.95%
🧹 Cleaning up CLIP model...
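
The evaluation code that produced these numbers is not shown in this part of the notebook. As a rough illustration only, top-k accuracy and the harmonic mean commonly used to summarize base and novel accuracy can be sketched in plain Python. The function names and the toy scores are assumptions; a real implementation would more likely operate on torch tensors.

```python
def top_k_accuracy(scores, labels, k):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for sample_scores, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score
        ranked = sorted(range(len(sample_scores)), key=lambda c: sample_scores[c], reverse=True)
        hits += label in ranked[:k]
    return 100.0 * hits / len(labels)


def harmonic_mean(base_acc, novel_acc):
    """Harmonic mean of two accuracies; it is dominated by the weaker of the two."""
    if base_acc + novel_acc == 0:
        return 0.0
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)


# Toy similarity scores for 4 images over 3 classes (made-up numbers)
scores = [
    [0.90, 0.05, 0.05],  # true class 0 is ranked first
    [0.20, 0.50, 0.30],  # true class 2 is ranked second
    [0.10, 0.20, 0.70],  # true class 2 is ranked first
    [0.60, 0.30, 0.10],  # true class 2 is ranked last
]
labels = [0, 2, 2, 2]

top1 = top_k_accuracy(scores, labels, k=1)
top2 = top_k_accuracy(scores, labels, k=2)
print(top1, top2)  # 50.0 75.0

# Top-1 accuracies printed above: base first, novel second
hm_standard = harmonic_mean(71.29, 78.24)
hm_llm = harmonic_mean(77.07, 78.40)
print(f"{hm_standard:.2f} {hm_llm:.2f}")  # 74.60 77.73
```

Because the harmonic mean penalizes imbalance, the gain on base classes with LLM prompts lifts the combined score even though novel accuracy barely changes.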


In [63]:
#@title Show results
analyze_results(results=results)


📊 COMPREHENSIVE RESULTS ANALYSIS
🔤 STANDARD PROMPTS RESULTS:
   📈 Harmonic Mean (Top-1): 74.60%
   🎯 Base Classes Top-1:    71.29%
   🆕 Novel Classes Top-1:   78.24%
   📊 Base Classes Top-5:    90.86%
   📊 Novel Classes Top-5:   89.15%

🤖 LLM-ENHANCED PROMPTS RESULTS:
   📈 Harmonic Mean (Top-1): 77.73%
   🎯 Base Classes Top-1:    77.07%
   🆕 Novel Classes Top-1:   78.40%
   📊 Base Classes Top-5:    94.86%
   📊 Novel Classes Top-5:   91.57%

📈 IMPROVEMENT WITH LLM PROMPTS:
   🎯 Base Classes:    +5.78 percentage points
   🆕 Novel Classes:   +0.16 percentage points
   📈 Harmonic Mean:   +3.13 percentage points

🔍 CONFIDENCE GAP ANALYSIS:

   Standard Base:
     ✅ Top-1 Hit Gap:     3.27%
     ❌ Top-5 Hit Gap:    1.67%
     ❌ Top-10 Hit Gap:   3.86%
     ❌ Top-10 Miss Gap:  5.00%

   Standard Novel:
     ✅ Top-1 Hit Gap:     3.69%
     ❌ Top-5 Hit Gap:    1.37%
     ❌ Top-10 Hit Gap:   3.45%
     ❌ Top-10 Miss Gap:  5.70%

   Llm Base:
     ✅ Top-1 Hit Gap:     3.66%
     ❌ Top-5 Hit Gap: