# CLIPP Evaluation: retrieval metrics

This notebook loads the best CLIPP checkpoint, computes image and text embeddings on the validation set,
and reports retrieval metrics (Top-1, Top-5, Top-10).

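The relative paths below assume the working directory is two levels below the repository root (the data CSVs are read from `../../data/`, and the parent directory `..` is added to `sys.path`). Run the notebook from that directory, or update the paths below.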

In [1]:
# Configuration
from pathlib import Path
import sys

import torch

# Make the parent directory importable so modules located there can be imported.
repo_root = Path('..').resolve()  # adjust if running from a different CWD
# Guard the append so re-running this cell does not add duplicate sys.path entries.
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

# Paths are relative to the current working directory.
CHECKPOINT_PATH = Path('checkpoints/best_clipp.pth')
VAL_CSV = Path('../../data/alpaca_mbj_bandgap_test.csv')
BATCH_SIZE = 32
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {DEVICE}, checkpoint: {CHECKPOINT_PATH}')

Using device: cuda, checkpoint: checkpoints/best_clipp.pth


In [2]:
# Display the resolved path that the configuration cell added to sys.path.
repo_root

PosixPath('/home/jipengsun/MaterialVision/models')

In [3]:
# Imports and model/dataset loading
import torch
import pandas as pd
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Import CLIPP and ImageTextDataset from the training script
from training import CLIPP, ImageTextDataset

# Load tokenizer and model
import os

# The DataLoader below forks worker processes (num_workers=4). Fixing the tokenizers
# parallelism setting up front avoids the "current process just got forked" warnings
# that huggingface/tokenizers otherwise prints; setdefault keeps any value the user
# has already exported.
os.environ.setdefault('TOKENIZERS_PARALLELISM', 'false')

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = CLIPP(proj_dim=256)
device = torch.device(DEVICE)

# Load checkpoint
assert CHECKPOINT_PATH.exists(), f"Checkpoint not found: {CHECKPOINT_PATH}"
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(str(CHECKPOINT_PATH), map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
model.to(device)
model.eval()  # inference mode for the evaluation cells that follow

# Load validation data
val_df = pd.read_csv(VAL_CSV)
val_ds = ImageTextDataset(val_df, tokenizer, train=False)
# shuffle=False keeps row i of every embedding matrix aligned with row i of the CSV.
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
print(f'Validation examples: {len(val_ds)}')

  from .autonotebook import tqdm as notebook_tqdm
2025-10-17 00:58:24,721 INFO: Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2025-10-17 00:58:24,762 INFO: [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Validation examples: 1000


In [4]:
# Compute embeddings for entire validation set
import torch

# Per-batch accumulators: image embeddings, text embeddings (both moved to CPU)
# and the raw caption strings, all in loader order.
val_img_embs_list, val_txt_embs_list, captions = [], [], []

with torch.no_grad():  # inference only, no gradients needed
    for batch in val_loader:
        img_e, txt_e = model(
            batch['image'].to(device),
            batch['input_ids'].to(device),
            batch['attention_mask'].to(device),
        )
        val_img_embs_list += [img_e.cpu()]
        val_txt_embs_list += [txt_e.cpu()]
        captions += batch['caption']

# Stack the per-batch pieces into one tensor per modality.
val_img_embs = torch.cat(val_img_embs_list, dim=0)
val_txt_embs = torch.cat(val_txt_embs_list, dim=0)
print(f'Computed validation embeddings: images {val_img_embs.shape}, texts {val_txt_embs.shape}')

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Computed validation embeddings: images torch.Size([1000, 256]), texts torch.Size([1000, 256])


In [5]:
# Compute similarity scores and retrieval metrics
import torch

# Text-to-image similarity: row i holds the scores of caption i against every image.
scores = val_txt_embs @ val_img_embs.T  # (N_text, N_image)

# Top-k metrics as requested.
# Pair i counts as a hit when image i is among the k highest-scoring images for caption i.
# A single batched top-k replaces the per-row Python loops; torch.topk returns indices
# sorted by descending score, so the first k columns of `ranked` are the top-k set.
# Clamping k keeps the cell working when there are fewer than 10 images.
max_k = min(10, scores.shape[1])
ranked = torch.topk(scores, max_k, dim=1).indices  # (N_text, max_k)
hits = ranked == torch.arange(scores.shape[0], device=scores.device).unsqueeze(1)

top1 = hits[:, :1].any(dim=1).float().mean().item()
top5 = hits[:, :5].any(dim=1).float().mean().item()
top10 = hits[:, :10].any(dim=1).float().mean().item()

print(f"Top-1: {top1:.4f}, Top-5: {top5:.4f}, Top-10: {top10:.4f}")

Top-1: 0.1670, Top-5: 0.4040, Top-10: 0.5300


In [6]:
# Compute retrieval metrics on the training set (Top-1 / Top-5 / Top-10)
# WARNING: this computes an N x N similarity matrix and can be memory intensive for large datasets.
TRAIN_CSV = Path('../../data/alpaca_mbj_bandgap_train.csv')
train_df = pd.read_csv(TRAIN_CSV)
# train=False and shuffle=False mirror the validation setup, so embedding row i
# corresponds to CSV row i.
train_ds = ImageTextDataset(train_df, tokenizer, train=False)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

train_image_embs = []
train_text_embs = []
with torch.no_grad():  # inference only
    for batch in train_loader:
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        img_e, txt_e = model(images, input_ids, attention_mask)
        # Move results to the CPU so the accumulated lists do not occupy device memory.
        train_image_embs.append(img_e.cpu())
        train_text_embs.append(txt_e.cpu())

train_image_embeddings = torch.cat(train_image_embs, dim=0)
train_text_embeddings = torch.cat(train_text_embs, dim=0)
print(f'Computed training embeddings: images {train_image_embeddings.shape}, texts {train_text_embeddings.shape}')

# compute similarity and retrieval metrics for the training set
scores_train = train_text_embeddings @ train_image_embeddings.T  # (N_text, N_image)

# Top-k metrics.
# Pair i counts as a hit when image i is among the k highest-scoring images for caption i.
# A single batched top-k replaces the per-row Python loops; torch.topk returns indices
# sorted by descending score, so the first k columns are the top-k set.
# Clamping k keeps the cell working when there are fewer than 10 images.
train_max_k = min(10, scores_train.shape[1])
train_ranked = torch.topk(scores_train, train_max_k, dim=1).indices  # (N_text, max_k)
train_hits = train_ranked == torch.arange(
    scores_train.shape[0], device=scores_train.device
).unsqueeze(1)

train_top1 = train_hits[:, :1].any(dim=1).float().mean().item()
train_top5 = train_hits[:, :5].any(dim=1).float().mean().item()
train_top10 = train_hits[:, :10].any(dim=1).float().mean().item()

print(f"Train Top-1: {train_top1:.4f}, Top-5: {train_top5:.4f}, Top-10: {train_top10:.4f}")


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Computed training embeddings: images torch.Size([5000, 256]), texts torch.Size([5000, 256])
Train Top-1: 0.1930, Top-5: 0.4856, Top-10: 0.6360


# Embedding Visualization

Let's visualize how well our model aligns the image and text embeddings in the shared space using t-SNE dimensionality reduction.

In [11]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

# Combine image and text embeddings so both modalities are projected by the same t-SNE fit
combined_embs = torch.cat([val_img_embs, val_txt_embs], dim=0)

# Apply t-SNE
tsne = TSNE(n_components=2, random_state=42)
combined_tsne = tsne.fit_transform(combined_embs.numpy())

# Split back into image and text embeddings (images were concatenated first)
n = len(val_img_embs)
img_tsne = combined_tsne[:n]
txt_tsne = combined_tsne[n:]

# Create visualization
fig, ax = plt.subplots(figsize=(15, 15))

# Plot all points
ax.scatter(img_tsne[:, 0], img_tsne[:, 1], c='blue', label='Images', alpha=0.5, s=50)
ax.scatter(txt_tsne[:, 0], txt_tsne[:, 1], c='red', label='Text', alpha=0.5, s=50)

# Draw lines connecting corresponding pairs for a subset of examples.
# A seeded generator makes the highlighted pairs reproducible across runs, and the
# min() keeps the draw valid when there are fewer than 10 examples.
num_examples = min(10, n)  # Number of example pairs to highlight
rng = np.random.default_rng(42)
random_indices = rng.choice(n, num_examples, replace=False)

label_box = dict(facecolor='white', edgecolor='none', alpha=0.7)
for i, idx in enumerate(random_indices):
    # Draw a line connecting the image-text pair
    ax.plot([img_tsne[idx, 0], txt_tsne[idx, 0]],
            [img_tsne[idx, 1], txt_tsne[idx, 1]],
            'k-', alpha=0.3)

    # Label both endpoints with the same pair number
    for point in (img_tsne[idx], txt_tsne[idx]):
        ax.annotate(f'Pair {i+1}',
                    xy=(point[0], point[1]),
                    xytext=(10, 10), textcoords='offset points',
                    bbox=label_box)

ax.legend(fontsize=12)
ax.set_title('t-SNE visualization of image and text embeddings (CLIPP-SciBERT)\nValidation Set', fontsize=14)
ax.set_xlabel('t-SNE dimension 1')
ax.set_ylabel('t-SNE dimension 2')

# Apply the layout before saving so the PNG on disk matches the figure shown inline.
fig.tight_layout()
fig.savefig('clipp_scibert_tsne.png', dpi=600)
plt.show()
