# Recall@K

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import BertModel
from torchvision import models
from tqdm import tqdm
import pandas as pd
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image

In [35]:
# Choose the compute backend in order of preference: Apple's MPS backend
# first, then CUDA, and finally the CPU as the fallback.
if torch.backends.mps.is_available():
    backend_name = "mps"
elif torch.cuda.is_available():
    backend_name = "cuda"
else:
    backend_name = "cpu"

device = torch.device(backend_name)

# Report the choice; the CUDA message also names the GPU at index 0.
if backend_name == "mps":
    print("Using device: MPS (Apple)")
elif backend_name == "cuda":
    print(f"Using device: CUDA (GPU) - {torch.cuda.get_device_name(0)}")
else:
    print("Using device: CPU")

Using device: MPS (Apple)


In [36]:
# Load the saved training checkpoint; map_location remaps its tensors onto
# the device selected above.
# NOTE(review): depending on the PyTorch version and the weights_only
# setting, torch.load may unpickle arbitrary objects - only load checkpoint
# files from a trusted source.
best_chckpt = torch.load("checkpoint_iter_14000.pt", map_location=device)

In [11]:
# Show which entries the checkpoint dict holds before pulling weights out of it.
print("checkpoint keys:", best_chckpt.keys())

checkpoint keys: dict_keys(['epoch', 'iteration', 'text_enc', 'vision_enc', 'optimizer', 'scaler', 'train_loss'])


In [21]:
# Encoder class definitions, copied from the training code.
from transformers import AutoTokenizer, AutoModel

# Hugging Face model id passed to AutoModel.from_pretrained (text encoder)
# and AutoTokenizer.from_pretrained (dataset).
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"


class ClinicalTextEncoder(nn.Module):
    """Text tower: a pretrained transformer followed by a linear projection.

    The hidden state at sequence position 0 is projected to ``embed_dim``
    features and scaled to unit L2 norm, so it can be compared with the
    image embeddings via a dot product.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        self.bert = AutoModel.from_pretrained(MODEL_NAME)
        # 768 is the projection's input width; it has to match the hidden
        # size of the model loaded above.
        self.proj = nn.Linear(768, embed_dim)

    def forward(self, token_ids, attention_masks):
        """Return one unit-norm embedding per input sequence.

        Args:
            token_ids: token id tensor passed to the transformer.
            attention_masks: matching attention mask tensor.
        """
        hidden_states = self.bert(
            token_ids, attention_mask=attention_masks
        ).last_hidden_state
        # Position 0 of every sequence is used as the sentence summary.
        first_token = hidden_states[:, 0, :]
        projected = self.proj(first_token)
        # Unit length is required for the contrastive similarity with the
        # image embeddings.
        return F.normalize(projected, p=2, dim=-1, eps=1e-6)


# Use the torchvision's implementation of ResNeXt, but add FC layer to generate 512d embedding.
class VisionEncoder(nn.Module):
    """Image tower: ResNeXt-50 (32x4d) backbone plus a linear projection.

    The backbone output is projected to ``embed_dim`` features and scaled to
    unit L2 norm so it can be compared with the text embeddings via a dot
    product.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        resnet = models.resnext50_32x4d(pretrained=True)
        # Keep every child module except the last one (the `fc` classifier),
        # so the backbone ends at the pooled feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        in_dim = resnet.fc.in_features
        self.proj = nn.Linear(in_dim, embed_dim)

    def forward(self, x):
        """Return one unit-norm embedding per input image.

        Args:
            x: batch of image tensors accepted by the ResNeXt backbone.
        """
        features = self.backbone(x)
        # Drop the two trailing size-1 spatial dimensions left by pooling.
        features = features.squeeze(-1).squeeze(-1)
        z = self.proj(features)
        # Same normalisation call as the text encoder. Unlike a bare
        # `z / z.norm(...)`, the eps floor prevents NaN/inf when a projected
        # vector has (near-)zero norm; for any other vector the result is
        # unchanged.
        return F.normalize(z, p=2, dim=-1, eps=1e-6)

In [22]:
# Build both towers with embed_dim=512, then restore the weights stored in
# the checkpoint under "text_enc" and "vision_enc".
text_encoder = ClinicalTextEncoder(embed_dim=512)
image_encoder = VisionEncoder(embed_dim=512)
text_encoder.load_state_dict(best_chckpt["text_enc"])
image_encoder.load_state_dict(best_chckpt["vision_enc"])

# nn.Module.to() and .eval() both act in place and return the module itself,
# so `text_enc` is just a second name for the same object as `text_encoder`.
text_encoder.to(device)
text_encoder.eval()
text_enc = text_encoder

image_encoder.to(device)
image_encoder.eval()



config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]



In [26]:
def generate_report_updated(row):
    """Render one metadata/label row as a short text report.

    Args:
        row: pandas Series with "Age", "Sex" and "Frontal/Lateral" entries.
            Every entry from position 5 onward is treated as a finding
            label, where 1.0 means present and -1.0 means uncertain.

    Returns:
        A string of the form "<view> chest radiograph [of <demographics>
        patient] <findings>." where the findings part is either the
        "no acute cardiopulmonary abnormality" phrase or the listed labels.
    """

    def _tidy(names):
        # Lower-case the label names and turn underscores into spaces.
        return ", ".join(name.lower().replace("_", " ") for name in names)

    # The first five entries (Path, Sex, Age, Frontal/Lateral, AP/PA) are
    # metadata, not findings.
    finding_labels = row.iloc[5:]
    confirmed = list(finding_labels[finding_labels == 1.0].index)
    uncertain = list(finding_labels[finding_labels == -1.0].index)

    # Demographics and view; missing values become None.
    raw_age = row["Age"]
    age = None if pd.isna(raw_age) else int(raw_age)
    raw_sex = row["Sex"]
    sex = None if pd.isna(raw_sex) else raw_sex.lower()
    raw_view = row["Frontal/Lateral"]
    view = None if pd.isna(raw_view) else raw_view.lower()

    # Opening phrase: the view type when known.
    pieces = [f"{view.capitalize()} chest radiograph" if view else "Chest radiograph"]

    # Truthiness checks on purpose mirror training: a missing age and an age
    # of 0 are both left out of the sentence.
    age_text = f"{age}-year-old" if age else ""
    demographics = " ".join(part for part in (age_text, sex) if part)
    if demographics:
        pieces.append(f"of {demographics} patient")

    # Findings: definite ones first, then uncertain ones.
    if not confirmed and not uncertain:
        pieces.append("demonstrates no acute cardiopulmonary abnormality")
    else:
        clauses = []
        if confirmed:
            clauses.append("shows " + _tidy(confirmed))
        if uncertain:
            clauses.append("possible " + _tidy(uncertain))
        pieces.append(". ".join(clauses))

    return " ".join(pieces) + "."

In [27]:
# dataset definition from training file
class ClinicalCustomDataset(torch.utils.data.Dataset):
    """Pairs each row's generated text report with its preprocessed image.

    Dataset definition carried over from the training code. Each item is a
    dict holding the tokenized report, the transformed image tensor, the raw
    report string and the image path.
    """

    def __init__(self, df):
        # Re-number rows from 0 so that integer idx values in __getitem__
        # line up with the Series labels.
        frame = df.reset_index(drop=True)

        # Text side: tokenizer plus one generated report per row.
        self.text_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.reports = frame.apply(generate_report_updated, axis=1)

        # Image side: file paths plus the resize / tensor / normalise pipeline.
        self.images = frame["Path"]
        preprocessing_steps = [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

    def __len__(self):
        """Number of samples, i.e. number of generated reports."""
        return len(self.reports)

    def __getitem__(self, idx):
        """Return the tokenized report and transformed image for row ``idx``."""
        # Text: every report is padded or truncated to exactly 512 tokens.
        report_text = self.reports[idx]
        tokenized = self.text_tokenizer.encode_plus(
            report_text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Image: force three channels before applying the transform pipeline.
        image_path = self.images[idx]
        pixel_tensor = self.transform(Image.open(image_path).convert("RGB"))

        return {
            "token_ids": tokenized["input_ids"],
            "attention_masks": tokenized["attention_mask"],
            "image_tensor": pixel_tensor,
            "report": report_text,
            "img_path": image_path,
        }

### Generate embeddings for the validation set

In [30]:
# Copied from the training notebook (used there to build the val loader):
# read the validation split CSV.
val_full = pd.read_csv("CheXpert-v1.0-small/valid.csv")

In [32]:
# shuffle=False keeps batches in CSV row order, so the i-th image embedding
# and the i-th text embedding come from the same row - recall_at_k relies on
# that index pairing.
val_dataset = ClinicalCustomDataset(val_full)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)



In [37]:
# Encode the whole validation set batch by batch, collecting the embeddings
# on the CPU.
image_chunks = []
text_chunks = []

with torch.no_grad():  # evaluation only, no autograd graph needed
    for batch in val_loader:
        pixels = batch["image_tensor"].to(device)
        # squeeze(1) removes dimension 1 of the tokenizer tensors when it
        # has size 1.
        ids = batch["token_ids"].squeeze(1).to(device)
        mask = batch["attention_masks"].squeeze(1).to(device)

        image_chunks.append(image_encoder(pixels).cpu())
        text_chunks.append(text_encoder(ids, mask).cpu())

# Join the per-batch results along the sample dimension.
image_embeddings = torch.cat(image_chunks, dim=0)
text_embeddings = torch.cat(text_chunks, dim=0)

### Implement Recall@K

In [39]:
def recall_at_k(image_embeds, text_embeds, k_values=(1, 5, 10)):
    """Compute Recall@K for image-to-text and text-to-image retrieval.

    Row ``i`` of ``image_embeds`` and row ``i`` of ``text_embeds`` are
    treated as the one matching pair; a query counts as a hit when its own
    index is among its ``k`` highest-scoring candidates. Candidates with
    identical content but a different index therefore count as misses.

    Args:
        image_embeds: tensor of shape [N, D].
        text_embeds: tensor of shape [N, D]. Scores are plain dot products,
            so they equal cosine similarity only for unit-norm rows.
        k_values: iterable of cut-offs. A ``k`` larger than N is evaluated
            as ``k = N`` (recall 1.0) instead of raising.

    Returns:
        Dict mapping ``"image_to_text_recall@{k}"`` and
        ``"text_to_image_recall@{k}"`` to a float in [0, 1].

    Raises:
        ValueError: if the inputs are empty or have different numbers of
            rows, since index-based pairing is undefined in that case.
    """
    num_pairs = image_embeds.shape[0]
    if num_pairs == 0 or text_embeds.shape[0] != num_pairs:
        raise ValueError(
            "recall_at_k needs the same non-zero number of image and text "
            f"embeddings, got {num_pairs} and {text_embeds.shape[0]}"
        )

    # Similarity matrix: [N_images, N_texts]
    similarity = torch.matmul(image_embeds, text_embeds.T)

    results = {}
    directions = (
        ("image_to_text", similarity),    # rows are image queries
        ("text_to_image", similarity.T),  # rows are text queries
    )
    for direction, scores in directions:
        for k in k_values:
            results[f"{direction}_recall@{k}"] = _paired_recall(scores, k)

    return results


def _paired_recall(scores, k):
    """Fraction of rows of ``scores`` whose own index is in the row's top-k."""
    num_queries, num_candidates = scores.shape
    # topk raises when k exceeds the number of candidates, so cap it.
    top_k = scores.topk(min(k, num_candidates), dim=1).indices  # [N, k]
    targets = torch.arange(num_queries, device=scores.device).unsqueeze(1)
    hits = (top_k == targets).any(dim=1)
    return hits.float().mean().item()

In [40]:
# Compute Recall@1/5/10 in both retrieval directions on the validation
# embeddings gathered above.
k_values = [1, 5, 10]
results = recall_at_k(image_embeddings, text_embeddings, k_values)

# Print one "name: value" line per metric, rounded to four decimals.
for metric, value in results.items():
    print(f"{metric}: {value:.4f}")

image_to_text_recall@1: 0.2521
image_to_text_recall@5: 0.6453
image_to_text_recall@10: 0.8162
text_to_image_recall@1: 0.2650
text_to_image_recall@5: 0.6282
text_to_image_recall@10: 0.7735
