# Scoring + Evaluation

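This notebook loads a trained SSL (SimCLR) checkpoint, extracts embeddings for the train and test splits, scores each test image with a k-NN model fitted on the training embeddings, and reports evaluation metrics for those anomaly scores.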


In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader

from src.augment import get_mae_transform
from src.data import NIHChestXrayDataset
from src.eval import evaluate_scores
from src.scoring import (
    extract_embeddings,
    fit_knn,
    load_ssl_encoder,
    score_knn,
)

# --- Inputs -----------------------------------------------------------------
# Split manifests (one CSV per split) and the SSL checkpoint to evaluate.
TRAIN_CSV = "splits/train.csv"
TEST_CSV = "splits/test.csv"
CHECKPOINT = "checkpoints/simclr/simclr_epoch_100.pt"

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load both split manifests (train first, then test) into DataFrames.
train_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, TEST_CSV)
)


In [None]:
# One preprocessing transform is shared by both splits so that train and test
# embeddings are produced from identically prepared inputs.
# NOTE(review): this is the MAE transform while CHECKPOINT/method are SimCLR —
# confirm this is the intended evaluation-time preprocessing.
transform = get_mae_transform(224)
train_ds, test_ds = (
    NIHChestXrayDataset(frame, transform=transform)
    for frame in (train_df, test_df)
)

# Both loaders use the same settings; shuffle stays off so batches are read in
# dataset order.
loader_kwargs = dict(batch_size=64, shuffle=False, num_workers=2)
train_loader = DataLoader(train_ds, **loader_kwargs)
test_loader = DataLoader(test_ds, **loader_kwargs)

# Build the device object once and reuse it for both extraction passes.
device = torch.device(DEVICE)
encoder = load_ssl_encoder(CHECKPOINT, method="simclr", backbone="resnet50")

# Second return value of the train pass is discarded; only the test pass keeps
# it (as `labels`) for evaluation.
train_feats, _ = extract_embeddings(encoder, train_loader, device)
test_feats, labels = extract_embeddings(encoder, test_loader, device)

# Fit the k-NN model on training embeddings, then score the test embeddings
# against it.
knn = fit_knn(train_feats, k=5)
scores = score_knn(knn, test_feats)

# Last expression of the cell: the notebook displays the metrics object.
metrics = evaluate_scores(labels, scores)
metrics
