# Zero-Shot Classification with CLIP

This notebook demonstrates how to use the VLM evaluation framework for zero-shot classification using CLIP.
We will query the model with arbitrary class names and see how it classifies images.

## 1. Setup and Imports

In [None]:
import sys
from pathlib import Path

# Add the project root (the parent of the current working directory) to the
# import path so the framework imports below can be resolved.
project_root = Path.cwd().parent
project_root_entry = str(project_root)
# Re-running this cell must not stack duplicate entries on sys.path: drop any
# existing entry first, then (re)insert it at the highest-priority position.
if project_root_entry in sys.path:
    sys.path.remove(project_root_entry)
sys.path.insert(0, project_root_entry)

import torch
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO

# Import framework components
from vlm_eval import EncoderRegistry, HeadRegistry
from vlm_eval.encoders import CLIPEncoder
from vlm_eval.heads import ZeroShotHead

print("✓ Imports successful!")

## 2. Load CLIP Encoder

We'll use the ViT-B-32 variant of CLIP with the `laion2b_s34b_b79k` pretrained weights.

In [None]:
# Pick the compute device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Look up the "clip" encoder in the registry, selecting the architecture
# variant and the pretrained-weights tag.
clip_options = {"variant": "ViT-B-32", "pretrained": "laion2b_s34b_b79k"}
encoder = EncoderRegistry.get("clip", **clip_options)

# Move the encoder to the selected device, then call eval() on it
# (inference mode for torch modules).
encoder = encoder.to(device)
encoder.eval()

# Report which encoder class was built and its advertised output width.
encoder_name = encoder.__class__.__name__
print(f"\nEncoder: {encoder_name}")
print(f"Output channels: {encoder.output_channels}")

## 3. Define Classes and Create Head

Here we define the classes we want to query. You can change these to anything!

In [None]:
# Labels to score the image against. Each one already carries its own
# article ("a ..."), which is why a bare "{}" template is passed below.
class_names = ["a dog", "a cat", "a car", "a bicycle", "a person", "a tree"]
print(f"Classes: {class_names}")

# Build the zero-shot head from the registry and move it to the same device
# as the encoder.
head = HeadRegistry.get(
    "zero_shot",
    encoder=encoder,
    class_names=class_names,
    # The class names above are complete phrases; omit this argument to use
    # the head's default template instead.
    template="{}",
).to(device)

print("✓ Zero-shot head created and text embeddings computed!")

## 4. Load and Preprocess Image

We'll load an image from a URL.

In [None]:
# Download the sample image.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # Cats image
# A timeout keeps a stalled connection from hanging the kernel forever, and
# raise_for_status() turns an HTTP error into a clear exception instead of
# letting PIL fail later on a non-image error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
# Normalise to 3-channel RGB so grayscale / palette / RGBA images (if the URL
# is swapped for another one) display and preprocess consistently.
image = Image.open(BytesIO(response.content)).convert("RGB")

# Display image
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.axis('off')
plt.show()

# Preprocess
# CLIP encoder has a preprocess method from open_clip; unsqueeze(0) adds the
# leading batch dimension before the tensor is moved to the model's device.
image_tensor = encoder.preprocess(image).unsqueeze(0).to(device)
print(f"Image tensor shape: {image_tensor.shape}")

## 5. Run Prediction

We'll get the logits and convert them to probabilities.

In [None]:
# Inference only: disable gradient tracking for the whole forward pass.
with torch.no_grad():
    # Image -> features via the encoder.
    features = encoder(image_tensor)

    # Features -> one logit per entry in class_names via the zero-shot head.
    logits = head(features)

    # Normalise the logits over the last dimension into probabilities.
    probs = torch.softmax(logits, dim=-1)

# Text report: one line per class for the first (and only) image in the batch.
print("\nPredictions:")
for label, score in zip(class_names, probs[0].tolist()):
    print(f"{label}: {score:.2%}")

# Bar chart of the same probabilities.
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(class_names, probs[0].cpu().numpy())
ax.set_title("Class Probabilities")
ax.set_ylabel("Probability")
plt.show()

## 6. Try with Different Classes

Let's try a different set of classes on the same image.

In [None]:
# A second, coarser label set to score against the same image.
new_classes = ["indoor", "outdoor", "animal", "vehicle", "furniture"]

# Build a second head for the new labels (no template argument is passed
# here, so the head's default applies) and move it to the device.
head2 = HeadRegistry.get("zero_shot", encoder=encoder, class_names=new_classes)
head2 = head2.to(device)

# Reuse the image features computed by the earlier encoder forward pass;
# only the head needs to be run again.
with torch.no_grad():
    logits2 = head2(features)
    probs2 = torch.softmax(logits2, dim=-1)

print("\nNew Predictions:")
for label, score in zip(new_classes, probs2[0].tolist()):
    print(f"{label}: {score:.2%}")