# CLIPP Evaluation: retrieval metrics

This notebook loads the best CLIPP checkpoint, computes image and text embeddings on the validation set,
and reports text-to-image retrieval accuracy (Top-1, Top-5, Top-10): for each caption, whether its paired image is among the k highest-scoring images.

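Run this notebook with the repository root as the working directory so that the checkpoint and validation-CSV paths in the configuration cell resolve correctly; if you run it from elsewhere, update those paths.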

In [None]:
# Configuration
# Locates the repository root, makes it importable, and defines the paths and
# runtime options used by the cells below.
from pathlib import Path
import sys

import torch

# Walk upward from the current working directory until a directory containing
# `models/baseCLIPP` is found, so the notebook works whether it is launched from
# the repository root or from a sub-directory of it. When no such directory is
# found, fall back to the parent of the CWD (the previous hard-coded behaviour).
_cwd = Path.cwd().resolve()
repo_root = next(
    (p for p in (_cwd, *_cwd.parents) if (p / 'models' / 'baseCLIPP').is_dir()),
    _cwd.parent,
)

# Avoid piling up duplicate entries when this cell is re-run.
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

# Anchor file paths to the repository root instead of the working directory.
CHECKPOINT_PATH = repo_root / 'models/baseCLIPP/checkpoints/best_clipp.pth'
VAL_CSV = repo_root / 'data/alpaca_mbj_bandgap_val.csv'
BATCH_SIZE = 32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {DEVICE}, checkpoint: {CHECKPOINT_PATH}')

In [None]:
# Imports and model/dataset loading
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# CLIPP and ImageTextDataset are defined in the training script
from models.baseCLIPP.training import CLIPP, ImageTextDataset

# Target device, tokenizer, and the model instance that will receive the weights
device = torch.device(DEVICE)
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = CLIPP(proj_dim=256)

# Restore the trained weights; fail early with a clear message if the file is missing.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
assert CHECKPOINT_PATH.exists(), f"Checkpoint not found: {CHECKPOINT_PATH}"
ckpt = torch.load(str(CHECKPOINT_PATH), map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
model.to(device)
model.eval()  # switch to evaluation mode

# Validation split: read the CSV, wrap it in the dataset, and iterate in file order
val_df = pd.read_csv(VAL_CSV)
val_ds = ImageTextDataset(val_df, tokenizer, train=False)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep order fixed so image i and caption i stay aligned
    num_workers=4,
)
print(f'Validation examples: {len(val_ds)}')

In [None]:
# Compute embeddings for entire validation set
import torch

# Per-batch results, concatenated once the loop finishes.
image_embs, text_embs, captions = [], [], []

# Gradients are not needed for evaluation.
with torch.no_grad():
    for batch in val_loader:
        # The model is called with (images, input_ids, attention_mask) and
        # returns one image embedding and one text embedding per example.
        img_e, txt_e = model(
            batch['image'].to(device),
            batch['input_ids'].to(device),
            batch['attention_mask'].to(device),
        )
        # Move each batch's output to the CPU before storing it.
        image_embs.append(img_e.cpu())
        text_embs.append(txt_e.cpu())
        captions.extend(batch['caption'])

# Stack the batches along the example axis.
image_embeddings = torch.cat(image_embs, dim=0)
text_embeddings = torch.cat(text_embs, dim=0)
print(f'Computed embeddings: images {image_embeddings.shape}, texts {text_embeddings.shape}')

In [None]:
# Compute similarity scores and retrieval metrics
import torch


def topk_accuracy(scores: torch.Tensor, k: int) -> float:
    """Fraction of rows whose same-index column is among the k highest scores.

    Args:
        scores: 2-D similarity matrix; row i is treated as matching column i.
        k: Number of top-scoring columns considered per row. Clamped to the
            number of columns so that small evaluation sets do not make
            ``topk`` raise.

    Returns:
        Accuracy in [0, 1], or NaN when ``scores`` is empty.
    """
    n_rows, n_cols = scores.shape
    if n_rows == 0 or n_cols == 0:
        return float('nan')
    k = min(k, n_cols)
    top_idx = scores.topk(k, dim=1).indices  # (n_rows, k)
    targets = torch.arange(n_rows, device=scores.device).unsqueeze(1)  # (n_rows, 1)
    # A row is a hit when its own index appears anywhere in its top-k columns.
    return (top_idx == targets).any(dim=1).float().mean().item()


# Text-to-image similarity: row i holds caption i scored against every image.
# NOTE(review): this is a raw dot product; it equals cosine similarity only if
# the model returns L2-normalised embeddings -- confirm against the model.
scores = text_embeddings @ image_embeddings.T  # (N_text, N_image)

# Top-k metrics as requested
top1 = topk_accuracy(scores, 1)
top5 = topk_accuracy(scores, 5)
top10 = topk_accuracy(scores, 10)

print(f"Top-1: {top1:.4f}, Top-5: {top5:.4f}, Top-10: {top10:.4f}")

Notes:
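- If the validation set is large, the full N × N similarity matrix may not fit in memory. Consider computing the scores in chunks (for example, a block of text embeddings against all image embeddings at a time) and accumulating the Top-k hits per chunk; note that a smaller `BATCH_SIZE` only reduces memory during embedding computation, not the size of the similarity matrix.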
- You can extend this notebook to compute per-class metrics, confusion matrices, or visualize nearest neighbor matches.