# Recall@K evaluation

In [None]:
import torch
from text_encoding_exp import TextEncoder, VisionEncoder

# TextEncoder and VisionEncoder are imported from the training notebook module (text_encoding_exp)

In [None]:
# Choose the compute device in priority order: Apple-silicon MPS (PyTorch's
# Metal backend), then CUDA, and finally the CPU. The choice is printed so the
# notebook output records where inference ran.
def _select_device():
    """Return ``(device, message)`` for the first available backend."""
    if torch.backends.mps.is_available():
        return torch.device("mps"), "Using device: MPS (Apple)"
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        return torch.device("cuda"), f"Using device: CUDA (GPU) - {gpu_name}"
    return torch.device("cpu"), "Using device: CPU"


device, _device_message = _select_device()
print(_device_message)

In [None]:
# Load the trained weights
text_enc = TextEncoder(embed_dim=512)
vision_enc = VisionEncoder(embed_dim=512)

text_enc.load_state_dict(torch.load("text_encoder.pth"))
vision_enc.load_state_dict(torch.load("vision_encoder.pth"))

text_enc.eval()
vision_enc.eval()
text_enc = text_enc.to(device)
vision_enc = vision_enc.to(device)

### Generate embeddings for the validation set

In [None]:
# The next few cells are copied from the training notebook to rebuild the
# validation loader. `pd` was used here without ever being imported, which made
# this cell fail with NameError.
import pandas as pd

val_full = pd.read_csv("CheXpert-v1.0-small/valid.csv")

In [None]:
# CustomDataset was neither defined nor imported in this notebook, so this cell
# raised NameError.
# NOTE(review): assumes the class lives in the training-notebook module next to
# TextEncoder/VisionEncoder — confirm. That module is already imported at the
# top of the notebook, so importing from it again does not re-execute it.
# Also verify the dataset applies no training-time augmentation when used for
# evaluation.
from text_encoding_exp import CustomDataset

val_dataset = CustomDataset(val_full)

In [None]:
# DataLoader was used here without being imported.
from torch.utils.data import DataLoader

# shuffle=False keeps the embeddings in the same order as the rows of val_full,
# so results are reproducible and row i maps back to CSV row i.
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [None]:
# Run both encoders over the whole validation set without tracking gradients.
# Per-batch outputs are moved to the CPU and collected, then stacked so that
# row i of each matrix comes from the same validation example.
img_chunks, txt_chunks = [], []

with torch.no_grad():
    for batch in val_loader:
        pixel_batch = batch["image_tensor"].to(device)
        # squeeze(1) drops the extra singleton dim on the tokenizer outputs
        # (presumably [B, 1, L] -> [B, L]; confirm against CustomDataset).
        ids_batch = batch["token_ids"].squeeze(1).to(device)
        mask_batch = batch["attention_masks"].squeeze(1).to(device)

        img_chunks.append(vision_enc(pixel_batch).cpu())
        txt_chunks.append(text_enc(ids_batch, mask_batch).cpu())

# Expected shape [N, 512] given embed_dim=512 above — assumes each encoder
# returns (batch, embed_dim).
image_embeddings = torch.cat(img_chunks, dim=0)
text_embeddings = torch.cat(txt_chunks, dim=0)

### Implement Recall@K

In [None]:
def recall_at_k(image_embeds, text_embeds, k_values=(1, 5, 10), *, normalize=False):
    """
    Compute Recall@K for image-to-text and text-to-image retrieval.

    Row ``i`` of ``image_embeds`` and row ``i`` of ``text_embeds`` form the only
    matching pair. A query counts as a hit at cut-off K when its matching item
    is among the K highest-scoring candidates.

    Args:
        image_embeds: 2-D tensor of image embeddings, one row per example.
        text_embeds: 2-D tensor of text embeddings with the same number of rows
            (and the same embedding width) as ``image_embeds``.
        k_values: iterable of cut-offs K to evaluate. A K larger than the
            number of examples is clamped to it, which yields a recall of 1.0.
        normalize: if True, L2-normalise both embedding sets first so that
            candidates are ranked by cosine similarity. If False (default),
            ranking uses the raw dot product.

    Returns:
        dict mapping ``"image_to_text_recall@K"`` and
        ``"text_to_image_recall@K"`` to a float in [0, 1]. All image-to-text
        entries come first, each direction in the order of ``k_values``.

    Raises:
        ValueError: if the inputs have different row counts or are empty.
    """
    n = len(image_embeds)
    if n != len(text_embeds):
        raise ValueError(
            f"recall_at_k needs paired inputs: got {n} images and {len(text_embeds)} texts"
        )
    if n == 0:
        raise ValueError("recall_at_k needs at least one image/text pair")

    if normalize:
        image_embeds = torch.nn.functional.normalize(image_embeds, dim=1)
        text_embeds = torch.nn.functional.normalize(text_embeds, dim=1)

    # similarity[i, j] is the score of image i against text j: [N, N].
    similarity = image_embeds @ text_embeds.T
    # Ground-truth index of each query, shaped [N, 1] to broadcast over the K
    # retrieved columns.
    targets = torch.arange(n, device=similarity.device).unsqueeze(1)

    results = {}
    # Rows of `similarity` rank texts per image; rows of its transpose rank
    # images per text. Ties are broken in topk's implementation-defined order.
    for direction, scores in (("image_to_text", similarity), ("text_to_image", similarity.T)):
        for k in k_values:
            top_k = scores.topk(min(k, n), dim=1).indices  # [N, min(k, N)]
            hits = (top_k == targets).any(dim=1)
            results[f"{direction}_recall@{k}"] = hits.float().mean().item()

    return results


# Score the validation embeddings at the usual retrieval cut-offs and print
# every metric with four decimal places.
k_values = [1, 5, 10]
results = recall_at_k(image_embeddings, text_embeddings, k_values)

for metric_name in results:
    print(f"{metric_name}: {results[metric_name]:.4f}")