# CLIP Model Evaluation and Visualization

Evaluate the trained CLIP model with:
- **Recall@K Metrics**: Image-to-text and text-to-image retrieval
- **Text Query Visualization**: Retrieve top-K images for text queries
- **Zero-shot Classification**: Classify images using text prompts

## Prerequisites
1. Run `coco_dataset_prep.ipynb` to prepare the dataset
2. Run `train_clip.ipynb` to train the model
3. Make sure a trained model checkpoint exists at `/content/checkpoints/best_model.pt` (without it, the notebook falls back to an untrained model)

## 1. Setup and Load Model

In [None]:
# Imports
import os
import json
import random
from pathlib import Path
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

from transformers import CLIPTokenizer, CLIPTextModel
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Seed every RNG used below (PyTorch, NumPy, Python) so sampled captions
# and randomly chosen examples are repeatable between runs
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

In [None]:
# Configuration
# Locations of the prepared dataset and the trained checkpoint (Colab layout)
DATASET_DIR = Path('/content/coco2014')
CHECKPOINT_PATH = Path('/content/checkpoints') / 'best_model.pt'
MODEL_NAME = 'openai/clip-vit-base-patch32'

# Image preprocessing: per-channel (RGB) normalization statistics and the
# square input resolution fed to the image encoder
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
IMAGE_SIZE = 224

print(f"Dataset directory: {DATASET_DIR}")
print(f"Checkpoint path: {CHECKPOINT_PATH}")
print(f"Checkpoint exists: {CHECKPOINT_PATH.exists()}")

In [None]:
# Define model architecture (same as training)
# NOTE: the attribute names below (image_encoder.features,
# projection_head.projection, text_encoder) and the layer order inside each
# nn.Sequential determine the state-dict keys, so they must match the training
# notebook exactly for load_state_dict() on the checkpoint to succeed.

class ResNet50ImageEncoder(nn.Module):
    """ResNet50 image encoder.

    Uses torchvision's ResNet-50 with its last child module (the final
    fully-connected classifier) dropped and flattens the remaining output,
    giving ``output_dim`` = 2048 features per image.
    """
    def __init__(self, pretrained=True):
        super().__init__()
        resnet = models.resnet50(pretrained=pretrained)
        # Keep every child except the last one (the classification layer)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.output_dim = 2048
        
    def forward(self, x):
        features = self.features(x)
        # Flatten everything after the batch dimension -> [batch, 2048]
        # (torchvision's pooled ResNet-50 output is [batch, 2048, 1, 1]).
        return features.view(features.size(0), -1)


class ProjectionHead(nn.Module):
    """2-layer MLP projection head.

    Linear(input_dim -> hidden_dim) -> GELU -> Linear(hidden_dim -> output_dim);
    with the defaults it maps 2048-d image features to 512-d embeddings.
    """
    def __init__(self, input_dim=2048, hidden_dim=1024, output_dim=512):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def forward(self, x):
        return self.projection(x)


class CLIPModel(nn.Module):
    """Combined CLIP model.

    Pairs a text encoder (in this notebook a ``CLIPTextModel``) with a
    ResNet-50 image encoder followed by a projection head. ``encode_image``
    and ``encode_text`` both return L2-normalized embeddings, so their dot
    product is a cosine similarity.

    Args:
        text_encoder: Module called as
            ``text_encoder(input_ids=..., attention_mask=...)`` whose output
            exposes ``pooler_output``. Assumed to be 512-d so it lives in the
            same space as the projected image embeddings — confirm if swapping
            the text model.
        freeze_text_encoder: If True, disable gradients for every text-encoder
            parameter and put the text encoder in eval mode.
    """
    def __init__(self, text_encoder, freeze_text_encoder=True):
        super().__init__()
        self.text_encoder = text_encoder
        if freeze_text_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False
            # NOTE(review): a later model.train() call switches this back to train mode
            self.text_encoder.eval()
        
        self.image_encoder = ResNet50ImageEncoder(pretrained=True)
        self.projection_head = ProjectionHead(input_dim=2048, hidden_dim=1024, output_dim=512)
        
    def encode_image(self, images):
        """Encode a batch of normalized images into unit-length embeddings."""
        features = self.image_encoder(images)
        embeddings = self.projection_head(features)
        return F.normalize(embeddings, p=2, dim=1)
    
    def encode_text(self, input_ids, attention_mask):
        """Encode tokenized text into unit-length embeddings (pooler output)."""
        # NOTE(review): always runs under no_grad, even when
        # freeze_text_encoder=False, so the text encoder never receives
        # gradients through this path.
        with torch.no_grad():
            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.pooler_output
            return F.normalize(embeddings, p=2, dim=1)
    
    def forward(self, images, input_ids, attention_mask):
        """Return (image_embeddings, text_embeddings), both L2-normalized."""
        image_embeddings = self.encode_image(images)
        text_embeddings = self.encode_text(input_ids, attention_mask)
        return image_embeddings, text_embeddings

print("✓ Model architecture defined")

In [None]:
# Load tokenizer and create model
print("Loading CLIP text encoder...")
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME)
text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME)
text_encoder = text_encoder.to(device)

# Create model (the architecture must match training so the checkpoint's
# state-dict keys line up)
model = CLIPModel(text_encoder=text_encoder, freeze_text_encoder=True)
model = model.to(device)

# Load trained weights
if CHECKPOINT_PATH.exists():
    print(f"\nLoading checkpoint from {CHECKPOINT_PATH}...")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # The metadata entries are informational only. A checkpoint saved without
    # some of them should still load instead of failing on a KeyError while
    # printing.
    if 'epoch' in checkpoint:
        print(f"✓ Loaded checkpoint from epoch {checkpoint['epoch']+1}")
    else:
        print("✓ Loaded checkpoint")
    if 'val_loss' in checkpoint:
        print(f"  Validation loss: {checkpoint['val_loss']:.4f}")
    if 'val_acc' in checkpoint:
        print(f"  Validation accuracy: {checkpoint['val_acc']:.4f}")
else:
    print("⚠️  No checkpoint found. Using untrained model.")

model.eval()
print("\n✓ Model ready for evaluation")

## 2. Load Validation Dataset

In [None]:
class COCOClipDataset(Dataset):
    """COCO 2014 images paired with cached caption embeddings, for evaluation.

    Images are read from ``<dataset_dir>/<split>2014``. Captions and their
    embeddings come from ``<dataset_dir>/<split>_text_embeddings.pt``, whose
    ``'data'`` entry is a list of per-image records with ``image_id``,
    ``embeddings`` and ``captions`` keys.

    Args:
        split: COCO split name, e.g. ``'val'``.
        dataset_dir: Dataset root (``Path`` or ``str``).
        return_all_captions: If True, each item carries every caption and its
            embeddings. Otherwise one caption is sampled with ``random`` on
            every access.
    """
    def __init__(self, split='val', dataset_dir=DATASET_DIR, return_all_captions=True):
        dataset_dir = Path(dataset_dir)  # accept plain string paths too
        self.split = split
        self.image_dir = dataset_dir / f'{split}2014'
        self.cache_file = dataset_dir / f'{split}_text_embeddings.pt'
        self.return_all_captions = return_all_captions
        
        # Image transforms
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
        ])
        
        # Load cached embeddings onto the CPU so a cache written from GPU
        # tensors also loads on a CPU-only machine
        cache = torch.load(self.cache_file, map_location='cpu')
        self.cache_data = cache['data']
        
        print(f"Loaded {split} dataset: {len(self.cache_data):,} images")
        
    def __len__(self):
        return len(self.cache_data)
    
    def __getitem__(self, idx):
        """Return one image (normalized tensor + raw PIL image) with its caption data.

        A missing or undecodable image file yields ``image_raw=None`` and an
        all-zero ``image`` tensor instead of raising, so one bad file does not
        abort a full evaluation pass.
        """
        item = self.cache_data[idx]
        image_id = item['image_id']
        embeddings = item['embeddings']
        captions = item['captions']
        
        # Load image
        image_filename = f'COCO_{self.split}2014_{image_id:012d}.jpg'
        image_path = self.image_dir / image_filename
        
        try:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded image.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
        except (OSError, ValueError):
            # OSError covers missing files and unreadable/truncated images
            # (PIL.UnidentifiedImageError subclasses it). Unlike a bare
            # `except:`, KeyboardInterrupt and genuine bugs still propagate.
            image = None
            image_tensor = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)
        else:
            image_tensor = self.transform(image)
        
        if self.return_all_captions:
            return {
                'image': image_tensor,
                'image_raw': image,
                'text_embeddings': embeddings,
                'captions': captions,
                'image_id': image_id,
                'image_path': str(image_path)
            }
        caption_idx = random.randint(0, len(captions) - 1)
        return {
            'image': image_tensor,
            'image_raw': image,
            'text_embedding': embeddings[caption_idx],
            'caption': captions[caption_idx],
            'image_id': image_id,
            'image_path': str(image_path)
        }

# Create validation dataset
val_dataset = COCOClipDataset(split='val', return_all_captions=False)
print(f"✓ Dataset loaded: {len(val_dataset):,} samples")

## 3. Compute Embeddings for Entire Validation Set

In [None]:
def _collate_without_raw_images(samples):
    """Collate evaluation samples into a batch, dropping the ``image_raw`` field.

    ``COCOClipDataset`` returns the PIL image (or ``None`` when the file could
    not be read) under ``image_raw``. PyTorch's default collate function
    cannot batch either of those and raises ``TypeError``, so the field is
    removed before the remaining tensors, numbers and strings are collated.
    """
    stripped = [
        {key: value for key, value in sample.items() if key != 'image_raw'}
        for sample in samples
    ]
    return torch.utils.data.default_collate(stripped)


@torch.no_grad()
def compute_embeddings(model, dataset, batch_size=128, num_workers=2):
    """
    Compute image and text embeddings for the entire dataset.

    Images are encoded with ``model.encode_image``. Text embeddings are taken
    from each item's cached ``text_embedding`` (one caption per image), so
    row i of both outputs refers to the same dataset item.

    Args:
        model: CLIP model exposing ``encode_image``.
        dataset: Dataset built with ``return_all_captions=False`` (items hold
            ``image``, ``text_embedding``, ``caption`` and ``image_id``).
        batch_size: Images per forward pass.
        num_workers: DataLoader worker processes.

    Returns:
        image_embeddings: [N, 512] L2-normalized image embeddings (CPU)
        text_embeddings: [N, 512] L2-normalized caption embeddings (CPU, float32)
        captions: List[str] of length N
        image_ids: List[int] of length N

    Raises:
        ValueError: If the dataset is empty.
    """
    model.eval()
    
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep output rows aligned with dataset indices
        num_workers=num_workers,
        collate_fn=_collate_without_raw_images,
    )
    
    all_image_embeddings = []
    all_text_embeddings = []
    all_captions = []
    all_image_ids = []
    
    print("Computing embeddings...")
    for batch in tqdm(loader):
        images = batch['image'].to(device)
        image_embeddings = model.encode_image(images)

        # Cached caption embeddings are not guaranteed to be unit-norm or
        # float32. Normalize them (a no-op if they already are unit-norm) so
        # dot products with the normalized image embeddings are cosine
        # similarities, as with model.encode_text.
        text_embeddings = F.normalize(batch['text_embedding'].float(), p=2, dim=1)
        
        all_image_embeddings.append(image_embeddings.cpu())
        all_text_embeddings.append(text_embeddings)
        all_captions.extend(batch['caption'])
        all_image_ids.extend(batch['image_id'].tolist())

    if not all_image_embeddings:
        raise ValueError("compute_embeddings received an empty dataset")
    
    # Concatenate
    image_embeddings = torch.cat(all_image_embeddings, dim=0)
    text_embeddings = torch.cat(all_text_embeddings, dim=0)
    
    print(f"✓ Computed embeddings:")
    print(f"  Images: {image_embeddings.shape}")
    print(f"  Texts: {text_embeddings.shape}")
    
    return image_embeddings, text_embeddings, all_captions, all_image_ids

# Compute embeddings
image_embeddings, text_embeddings, captions, image_ids = compute_embeddings(model, val_dataset)

## 4. Recall@K Metrics

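Compute retrieval performance metrics on the validation set. Each image is paired with one randomly sampled caption, so image *i* and caption *i* form the ground-truth match:
- **Image → Text**: Given an image, rank all captions and check whether its paired caption is among the top K
- **Text → Image**: Given a caption, rank all images and check whether its paired image is among the top K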

In [None]:
def compute_recall_at_k(similarity_matrix, k_values=(1, 5, 10)):
    """
    Compute Recall@K for paired retrieval.

    Row ``i`` is a query whose ground-truth candidate is column ``i`` (image i
    <-> text i). The query counts as a hit at K when that column is among the
    K highest-scoring columns of its row.

    Args:
        similarity_matrix: [N, M] similarity scores (queries x candidates).
        k_values: Iterable of K values to compute. A K larger than M is
            evaluated over all M candidates instead of raising.

    Returns:
        recall_dict: Dictionary of {K: recall_value}, in the order of k_values.
    """
    k_values = list(k_values)
    num_queries, num_candidates = similarity_matrix.shape

    # torch.topk raises when k exceeds the dimension size, so cap it at M
    max_k = min(max(k_values, default=0), num_candidates)
    top_k_indices = torch.topk(similarity_matrix, k=max_k, dim=1, largest=True).indices

    # hits[i, r] is True when query i's correct candidate (column i) is ranked r
    targets = torch.arange(num_queries, device=similarity_matrix.device).unsqueeze(1)
    hits = top_k_indices == targets

    recalls = {}
    for k in k_values:
        # Slicing past max_k simply keeps every ranked candidate
        recalls[k] = hits[:, :k].any(dim=1).float().mean().item()

    return recalls


# Cosine similarity of every image (rows) with every caption (columns);
# both embedding sets are unit-norm, so a plain matrix product suffices
print("Computing similarity matrix...")
similarity_matrix = image_embeddings @ text_embeddings.T  # [N, N]
print(f"Similarity matrix shape: {similarity_matrix.shape}")

# Rows are image queries (I2T); transposing turns captions into the queries (T2I)
i2t_recalls = compute_recall_at_k(similarity_matrix, k_values=[1, 5, 10])
t2i_recalls = compute_recall_at_k(similarity_matrix.T, k_values=[1, 5, 10])

for header, recall_table in (("\nImage → Text Retrieval:", i2t_recalls),
                             ("\nText → Image Retrieval:", t2i_recalls)):
    print(header)
    for k, recall in recall_table.items():
        print(f"  Recall@{k}: {recall:.4f} ({recall*100:.2f}%)")

# Mean of the two retrieval directions
print("\nAverage (I2T + T2I):")
for k in [1, 5, 10]:
    avg_recall = (i2t_recalls[k] + t2i_recalls[k]) * 0.5
    print(f"  Recall@{k}: {avg_recall:.4f} ({avg_recall*100:.2f}%)")

## 5. Text Query → Image Retrieval Visualization

Given a text query, retrieve and display the top-K most similar images.

In [None]:
def denormalize_image(tensor):
    """Undo the CLIP mean/std normalization of a [3, H, W] tensor and clamp to [0, 1] for display."""
    channel_mean = torch.as_tensor(CLIP_MEAN)[:, None, None]
    channel_std = torch.as_tensor(CLIP_STD)[:, None, None]
    restored = tensor * channel_std + channel_mean
    return restored.clamp(0, 1)


@torch.no_grad()
def retrieve_images_for_text(query_text, model, tokenizer, dataset, image_embeddings, top_k=5):
    """
    Find the dataset images whose embeddings best match a text query.

    The query is tokenized and encoded with ``model.encode_text``, scored
    against every precomputed image embedding by dot product, and the
    ``top_k`` best matches are then fetched from ``dataset``.

    Args:
        query_text: Free-text query string
        model: CLIP model
        tokenizer: CLIP tokenizer
        dataset: Dataset the embeddings were computed from (same index order)
        image_embeddings: Pre-computed image embeddings [N, 512] on the CPU
        top_k: How many images to return (capped at N)

    Returns:
        top_images: List of PIL images, best match first
        top_scores: List of similarity scores
        top_captions: List of captions
        top_indices: List of dataset indices
    """
    model.eval()

    tokens = tokenizer(
        [query_text],
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors='pt'
    )
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    query_embedding = model.encode_text(input_ids, attention_mask).cpu()

    # One score per image: [1, 512] @ [512, N] -> [N]
    similarities = (query_embedding @ image_embeddings.T).squeeze(0)

    k = min(top_k, len(similarities))
    top_scores, top_indices = similarities.topk(k, largest=True)

    samples = [dataset[i] for i in top_indices.tolist()]
    top_images = [sample['image_raw'] for sample in samples]
    top_captions = [sample['caption'] for sample in samples]

    return top_images, top_scores.tolist(), top_captions, top_indices.tolist()


def visualize_text_query_results(query_text, top_images, top_scores, top_captions, figsize=(15, 3)):
    """Visualize retrieval results for a text query.

    Draws one panel per retrieved image, left to right in rank order. Each
    panel is titled with its rank and similarity score and has its caption
    wrapped underneath.

    Args:
        query_text: The query, shown as the figure title.
        top_images: PIL images in rank order. ``None`` entries (images the
            dataset could not load) are drawn as a placeholder panel.
        top_scores: Similarity score per image.
        top_captions: Caption per image.
        figsize: Matplotlib figure size.
    """
    import textwrap  # local import: only needed for caption wrapping

    num_images = len(top_images)
    if num_images == 0:
        # plt.subplots(1, 0) raises, so report the empty result instead
        print(f'No images retrieved for query: "{query_text}"')
        return

    # squeeze=False always returns a 2-D array, so a single result needs no special case
    fig, axes = plt.subplots(1, num_images, figsize=figsize, squeeze=False)
    fig.suptitle(f'Query: "{query_text}"', fontsize=14, fontweight='bold', y=1.05)

    for rank, (ax, img, score, caption) in enumerate(
        zip(axes[0], top_images, top_scores, top_captions), start=1
    ):
        if img is None:
            ax.text(0.5, 0.5, 'image unavailable', transform=ax.transAxes,
                    ha='center', va='center')
        else:
            ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"Rank {rank}\nScore: {score:.3f}", fontsize=10)

        # Wrap on word boundaries (over-long words are still split) to fit under the panel
        wrapped_caption = textwrap.fill(caption, width=25)
        ax.text(0.5, -0.15, wrapped_caption,
                transform=ax.transAxes,
                ha='center', va='top', fontsize=8,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

    plt.tight_layout()
    plt.show()

print("✓ Text query retrieval functions defined")

In [None]:
# Example: show the five best-matching validation images for a few free-text queries
queries = [
    'sport',
    'a cat',
    'food on a plate',
    'people playing',
    'a car on the street',
]

for query in queries:
    print(f"\nQuery: '{query}'")
    retrieved = retrieve_images_for_text(
        query, model, tokenizer, val_dataset, image_embeddings, top_k=5
    )
    top_images, top_scores, top_captions, top_indices = retrieved
    visualize_text_query_results(
        query_text=query,
        top_images=top_images,
        top_scores=top_scores,
        top_captions=top_captions,
    )
    print("-" * 60)

## 6. Zero-Shot Image Classification

Given an image and a list of class labels, classify the image by comparing its embedding with the text embedding of each class prompt and predicting the most similar class.

In [None]:
@torch.no_grad()
def zero_shot_classify(image, class_labels, model, tokenizer, use_templates=True, logit_scale=100.0):
    """
    Classify an image using zero-shot CLIP.

    Each class label is expanded into text prompts and encoded with the
    model's text encoder. Every class is then scored by cosine similarity
    with the image embedding, and the highest-scoring class is predicted.

    Args:
        image: PIL Image, or an already-normalized tensor [3, 224, 224]
        class_labels: List of class names (e.g., ['a person', 'an animal'])
        model: CLIP model
        tokenizer: CLIP tokenizer
        use_templates: If True, average the embeddings of several prompt
            templates per class (prompt ensembling)
        logit_scale: Multiplier applied to the similarities before the softmax
            (temperature scaling; 100 keeps the previous behavior)

    Returns:
        probs: Probability distribution over classes (CPU tensor)
        predicted_class: Index of predicted class
        class_scores: Cosine similarity per class (CPU tensor)

    Raises:
        ValueError: If ``image`` is None (e.g. the dataset could not load it)
            or ``class_labels`` is empty.
    """
    if image is None:
        raise ValueError("image is None; the sample's image file could not be loaded")
    if len(class_labels) == 0:
        raise ValueError("class_labels must contain at least one label")

    model.eval()
    
    # Prepare image
    if isinstance(image, Image.Image):
        transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
        ])
        image_tensor = transform(image).unsqueeze(0).to(device)
    else:
        image_tensor = image.unsqueeze(0).to(device)
    
    # Encode image -> [1, D], unit-norm
    image_embedding = model.encode_image(image_tensor)
    
    # Prompt templates (like the CLIP paper). Without templates, each label
    # is used verbatim, i.e. a single '{}' template.
    if use_templates:
        templates = [
            'a photo of {}',
            'a picture of {}',
            'an image of {}',
        ]
    else:
        templates = ['{}']
    # Label-major order so the view() below groups each class's prompts together
    texts = [template.format(label) for label in class_labels for template in templates]
    
    # Encode texts
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors='pt'
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    text_embeddings = model.encode_text(inputs['input_ids'], inputs['attention_mask'])
    
    # Average each class's prompt embeddings, then re-normalize. The mean of
    # unit vectors is shorter than 1 by a class-dependent amount, which would
    # otherwise scale each class's score differently and bias the prediction.
    class_embeddings = text_embeddings.view(len(class_labels), len(templates), -1).mean(dim=1)
    class_embeddings = F.normalize(class_embeddings, p=2, dim=1)  # [num_classes, D]
    
    similarities = torch.matmul(image_embedding, class_embeddings.T).squeeze(0)
    
    # Convert to probabilities (temperature-scaled softmax)
    probs = F.softmax(similarities * logit_scale, dim=0)
    
    predicted_class = torch.argmax(similarities).item()
    
    return probs.cpu(), predicted_class, similarities.cpu()


def visualize_classification(image, class_labels, probs, predicted_class):
    """Visualize zero-shot classification results.

    Shows the input image next to a horizontal bar chart of the class
    probabilities, highlights the predicted class in green, and prints the
    prediction.

    Args:
        image: PIL Image, normalized image tensor [3, H, W], or None (drawn
            as a placeholder).
        class_labels: Class names, in the same order as ``probs``.
        probs: Per-class probabilities (tensor or array-like).
        predicted_class: Index of the predicted class.
    """
    prob_values = torch.as_tensor(probs).detach().cpu().numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    # Display image
    if image is None:
        ax1.text(0.5, 0.5, 'image unavailable', transform=ax1.transAxes,
                 ha='center', va='center')
    elif isinstance(image, torch.Tensor):
        image_display = denormalize_image(image.detach().cpu()).permute(1, 2, 0).numpy()
        ax1.imshow(image_display)
    else:
        ax1.imshow(image)
    ax1.axis('off')
    ax1.set_title('Input Image', fontsize=12, fontweight='bold')
    
    # Display probabilities
    y_pos = np.arange(len(class_labels))
    colors = ['green' if i == predicted_class else 'skyblue' for i in range(len(class_labels))]
    
    ax2.barh(y_pos, prob_values, color=colors)
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels(class_labels)
    ax2.set_xlabel('Probability', fontsize=11)
    ax2.set_title('Class Probabilities', fontsize=12, fontweight='bold')
    # Leave headroom past 1.0. The percentage label sits at prob + 0.02, so a
    # near-certain prediction's label would otherwise be drawn past the right
    # edge of the axes.
    ax2.set_xlim(0, 1.15)
    
    # Add percentage labels
    for i, prob in enumerate(prob_values):
        ax2.text(prob + 0.02, i, f'{prob*100:.1f}%', va='center')
    
    # Highlight prediction
    ax2.axhline(predicted_class, color='green', linestyle='--', alpha=0.3, linewidth=2)
    
    plt.tight_layout()
    plt.show()
    
    print(f"Predicted class: '{class_labels[predicted_class]}' ({prob_values[predicted_class]*100:.2f}%)")

print("✓ Zero-shot classification functions defined")

In [None]:
# Example: zero-shot classify a handful of randomly chosen validation images
num_examples = 5
class_labels = ['a person', 'an animal', 'a landscape', 'food', 'a vehicle']

print(f"Classifying {num_examples} random images...\n")

for i in range(num_examples):
    idx = random.randint(0, len(val_dataset) - 1)
    sample = val_dataset[idx]
    image, true_caption = sample['image_raw'], sample['caption']

    print(f"Example {i+1}")
    print(f"True caption: '{true_caption}'")

    probs, predicted_class, scores = zero_shot_classify(
        image, class_labels, model, tokenizer, use_templates=True
    )
    visualize_classification(image, class_labels, probs, predicted_class)

    # Separator line followed by one blank line
    print("-" * 80, end="\n\n")

## 7. Custom Classification Examples

Try your own classification tasks!

In [None]:
# Example 1: Indoor vs Outdoor (a two-way scene-type decision)
class_labels = ['an indoor scene', 'an outdoor scene']
idx = random.randint(0, len(val_dataset) - 1)
sample = val_dataset[idx]

probs, pred, scores = zero_shot_classify(
    sample['image_raw'], class_labels, model, tokenizer
)
print(f"True caption: '{sample['caption']}'")
visualize_classification(sample['image_raw'], class_labels, probs, pred)

In [None]:
# Example 2: Activity classification (what are the people in the image doing?)
class_labels = ['people eating', 'people playing sports', 'people working', 'people relaxing']
idx = random.randint(0, len(val_dataset) - 1)
sample = val_dataset[idx]

probs, pred, scores = zero_shot_classify(
    sample['image_raw'], class_labels, model, tokenizer
)
print(f"True caption: '{sample['caption']}'")
visualize_classification(sample['image_raw'], class_labels, probs, pred)

In [None]:
# Example 3: Animal classification (which animal, if any, is pictured)
class_labels = ['a dog', 'a cat', 'a bird', 'a horse', 'a cow', 'a sheep']
idx = random.randint(0, len(val_dataset) - 1)
sample = val_dataset[idx]

probs, pred, scores = zero_shot_classify(
    sample['image_raw'], class_labels, model, tokenizer
)
print(f"True caption: '{sample['caption']}'")
visualize_classification(sample['image_raw'], class_labels, probs, pred)

## 8. Summary Report

In [None]:
# Generate evaluation summary
# Mean of image->text and text->image recall for each K (reused for printing and saving)
avg_recalls = {k: (i2t_recalls[k] + t2i_recalls[k]) / 2 for k in i2t_recalls}

print(f"{'='*60}")
print("EVALUATION SUMMARY")
print(f"{'='*60}\n")

print("Dataset:")
print(f"  Validation set size: {len(val_dataset):,} images")
print(f"  Embedding dimension: {image_embeddings.shape[1]}")

print("\nRetrieval Performance:")
print("\n  Image → Text Retrieval:")
for k, recall in i2t_recalls.items():
    print(f"    Recall@{k}: {recall:.4f} ({recall*100:.2f}%)")

print("\n  Text → Image Retrieval:")
for k, recall in t2i_recalls.items():
    print(f"    Recall@{k}: {recall:.4f} ({recall*100:.2f}%)")

print("\n  Average Recall:")
for k, avg_recall in avg_recalls.items():
    print(f"    Recall@{k}: {avg_recall:.4f} ({avg_recall*100:.2f}%)")

print("\nCapabilities Demonstrated:")
print("  ✓ Text query → Image retrieval")
print("  ✓ Zero-shot image classification")
print("  ✓ Multi-class categorization")
print("  ✓ Prompt template ensembling")

print(f"\n{'='*60}")

# Save evaluation results
eval_results = {
    'i2t_recalls': i2t_recalls,
    't2i_recalls': t2i_recalls,
    'avg_recalls': avg_recalls,
    'dataset_size': len(val_dataset),
    'embedding_dim': image_embeddings.shape[1],
}

results_path = Path('/content/logs/evaluation_results.pt')
# torch.save does not create missing directories, so ensure the log dir exists
results_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(eval_results, results_path)
print(f"\n✓ Evaluation results saved to {results_path}")