# IRMAS Test Evaluation

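Load the trained `CNNVarTime` checkpoint and run it on the precomputed IRMAS test mel windows. Window predictions are aggregated to the clip level: window logits are averaged and the argmax is the clip prediction. The notebook then reports clip-level top-1 accuracy plus a per-class breakdown. For clips with several ground-truth instruments, a prediction counts as correct if it matches any of them.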


In [1]:
# -------- IRMAS test evaluation (single-class model, clip-level) --------
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

from src.models import CNNVarTime
from src.utils.datasets import IRMASTestWindowDataset
from src.classes import IRMAS_CLASSES  # only for label names if ckpt lacks map

# ----------------- CONFIG -----------------
# Everything a reader may need to tune for their own setup lives in this block.
IRMAS_TEST_MANIFEST = "data/manifests/irmas_test_mels.csv"   # <-- set yours
RESUME_CKPT = (
    "saved_weights/irmas_pretrain_single_class/train_2/"
    "best_epoch_0075_val_acc_66.50.pt"
)
BATCH_SIZE = 64
NUM_WORKERS = 2
PROJECT_ROOT = Path.cwd()  # handed to the dataset as `project_root`

# Use the GPU when one is visible; pinned host memory only pays off for CUDA copies.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
PIN_MEM = DEVICE == "cuda"


In [16]:
# ----------------- HELPERS -----------------
from typing import Counter


def load_model_state(ckpt_path: str) -> Dict[str, torch.Tensor]:
    """Load a model state_dict from a checkpoint file, with keys normalised for a bare model.

    Accepted file contents:
      * a training-checkpoint dict that nests the weights under one of
        ``model_state`` / ``model_state_dict`` / ``state_dict`` / ``model``;
      * a plain state_dict saved directly;
      * an object exposing ``state_dict()`` (e.g. a whole pickled module), if the
        installed ``torch.load`` is allowed to unpickle it.
    The ``module.`` prefix that ``nn.DataParallel`` / DDP wrappers add is stripped.

    Args:
        ckpt_path: path to the checkpoint; tensors are mapped onto the CPU.

    Returns:
        Mapping of parameter/buffer name to tensor.

    Raises:
        TypeError: if the file holds neither a dict nor an object with ``state_dict()``.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")

    sd = None
    if isinstance(ckpt, dict):
        # Training checkpoints bundle weights with other state; take the first known key.
        for key in ("model_state", "model_state_dict", "state_dict", "model"):
            if isinstance(ckpt.get(key), dict):
                sd = ckpt[key]
                break
        if sd is None:
            sd = ckpt  # the file is already a plain state_dict
    elif callable(getattr(ckpt, "state_dict", None)):
        sd = ckpt.state_dict()  # a whole pickled module
    else:
        raise TypeError(
            f"Unsupported checkpoint content in {ckpt_path!r}: {type(ckpt).__name__}"
        )

    # Strip the DataParallel/DDP prefix so keys match an unwrapped model.
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

def class_names_from_ckpt(ckpt_path: str, fallback: List[str]) -> Tuple[List[str], Dict[str, int]]:
    """Return the class-name order that the checkpoint's output head was trained with.

    If the checkpoint is a dict holding a ``label_to_idx`` mapping, names are ordered
    by their index, so position ``i`` of the returned list names logit ``i``.
    Otherwise ``fallback`` is used in its given order.

    Args:
        ckpt_path: path to the checkpoint; tensors are mapped onto the CPU.
        fallback: class names to use when the checkpoint carries no label map.

    Returns:
        ``(ordered_names, label_to_idx)``.

    Raises:
        ValueError: if the stored indices are not exactly ``0..N-1``. Without this
            check, ``ordered_names[i]`` would silently mislabel logit ``i``.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    l2i = ckpt.get("label_to_idx") if isinstance(ckpt, dict) else None
    if not isinstance(l2i, dict):
        return list(fallback), {name: i for i, name in enumerate(fallback)}

    indices = sorted(l2i.values())
    if indices != list(range(len(l2i))):
        raise ValueError(
            f"label_to_idx in {ckpt_path!r} must map to indices 0..{len(l2i) - 1}; got {indices}"
        )
    ordered = [name for name, _ in sorted(l2i.items(), key=lambda kv: kv[1])]
    return ordered, l2i

# ----------------- CLASS ORDER (LOCK TO CKPT) -----------------
# Position i of train_order_names names logit i. The order comes from the checkpoint's
# label map when it has one, otherwise from IRMAS_CLASSES.
train_order_names, label_to_idx = class_names_from_ckpt(RESUME_CKPT, IRMAS_CLASSES)
num_classes = len(train_order_names)
print("Class order from ckpt:", train_order_names)

# ----------------- MODEL -----------------
# in_ch=2: the test windows are expected to be 2-channel mels (a sample is checked
# when the dataset is built).
model = CNNVarTime(in_ch=2, num_classes=num_classes, p_drop=0.5).to(DEVICE)
state_dict = load_model_state(RESUME_CKPT)

# strict=True already makes load_state_dict raise on any key mismatch; the assert
# is a redundant second guard kept for an explicit message.
missing, unexpected = model.load_state_dict(state_dict, strict=True)
assert not (missing or unexpected), (
    f"state_dict mismatch: missing={missing} unexpected={unexpected}"
)
model.eval()  # eval mode for inference (e.g. dropout from p_drop is inactive)

# ----------------- DATASET / LOADER -----------------
# The model's class order is passed in so that target vectors line up with the logits
# (assumes the dataset builds targets in `class_names` order — TODO confirm).
test_ds = IRMASTestWindowDataset(
    manifest_csv=Path(IRMAS_TEST_MANIFEST),
    project_root=PROJECT_ROOT,
    class_names=train_order_names,
    per_example_norm=True,
)
assert len(test_ds) > 0, "Empty test dataset."

# Smoke-check one window before building the loader.
# Items unpack as (mel, target, clip_id, <4th field, unused here>).
x0, y0, clip0, p0 = test_ds[0]
assert x0.ndim == 3 and x0.shape[0] == 2, f"Expected mel shape (2, 128, T); got {tuple(x0.shape)}"
assert y0.numel() == num_classes, f"Target length {y0.numel()} != num_classes {num_classes}"

# Fixed iteration order; pinned host memory only when running on CUDA.
test_loader = DataLoader(
    test_ds,
    shuffle=False,
    batch_size=BATCH_SIZE,
    pin_memory=PIN_MEM,
    num_workers=NUM_WORKERS,
)
print("Test windows:", len(test_ds))

# ----------------- INFERENCE (window level) -----------------
# Collect raw per-window logits grouped by clip id; each clip's target is stored once.
clip_logits = defaultdict(list)
clip_targets = {}
window_count = 0

with torch.no_grad():
    for inputs, targets, clip_ids, _ in test_loader:
        inputs = inputs.to(DEVICE, non_blocking=PIN_MEM)
        logits_cpu = model(inputs).cpu().numpy()  # no_grad: nothing to detach
        targets_np = targets.cpu().numpy().astype(np.float32)

        for clip_id, logit_vec, target_vec in zip(clip_ids, logits_cpu, targets_np):
            clip_logits[clip_id].append(logit_vec)
            # All windows of one clip must carry the same clip-level label. A mismatch
            # means clip ids collide or the manifest is inconsistent, so fail loudly
            # instead of silently keeping whichever window came last.
            first_target = clip_targets.setdefault(clip_id, target_vec)
            if not np.array_equal(first_target, target_vec):
                raise ValueError(f"Inconsistent targets across windows of clip {clip_id!r}")
            window_count += 1

# The loader must visit every window exactly once.
assert window_count == len(test_ds), (
    f"Processed {window_count} windows but the dataset has {len(test_ds)}"
)

def softmax1d(vec: np.ndarray) -> np.ndarray:
    """Numerically stable softmax of a 1-D score vector.

    The maximum is subtracted before exponentiating so large scores cannot
    overflow; the returned probabilities sum to 1.
    """
    exps = np.exp(vec - np.max(vec))
    return exps / np.sum(exps)

# ----------------- CLIP-LEVEL AGGREGATION -----------------
# One results row per clip. The clip prediction is the argmax of the window logits
# averaged over the clip; a per-window top-1 vote tally is kept alongside for inspection.
rows = []

for clip_id, logit_list in clip_logits.items():
    windows = len(logit_list)
    stacked = np.stack(logit_list, axis=0)  # (W windows, C classes)

    # Per-window votes: argmax of raw logits equals argmax of their softmax.
    # A fresh Counter per clip, so the tally always sums to `windows`.
    per_window_pred_idx = stacked.argmax(axis=1)
    per_clip_tally = Counter(train_order_names[i] for i in per_window_pred_idx)
    assert sum(per_clip_tally.values()) == windows

    # Clip-level decision from the mean logits; probabilities are only for display.
    mean_logits = stacked.mean(axis=0)
    probs = softmax1d(mean_logits)
    pred_idx = int(mean_logits.argmax())

    # Ground-truth labels are the target entries above 0.5 (a clip may have several).
    true_vec = clip_targets[clip_id]
    true_names = [train_order_names[i] for i in np.flatnonzero(true_vec > 0.5)]

    top_indices = np.argsort(mean_logits)[::-1]  # highest score first
    scores_text = ", ".join(
        f"{train_order_names[i]} ({probs[i]:.2f})" for i in top_indices
    )

    rows.append({
        "clip": clip_id,
        "true_labels": "|".join(true_names),
        "pred_label": train_order_names[pred_idx],
        "classes_scores_aggregated": scores_text,
        "class_tally": dict(per_clip_tally),
        "windows": windows,
        "hit@1": bool(true_vec[pred_idx] > 0.5),
        "n_true_labels": int((true_vec > 0.5).sum()),
    })



Class order from ckpt: ['cel', 'cla', 'flu', 'gac', 'gel', 'org', 'pia', 'sax', 'tru', 'vio', 'voi']
Test windows: 8992


In [17]:
# ----------------- CLIP-LEVEL SUMMARY -----------------
results_df = pd.DataFrame(rows)
if results_df.empty:
    raise RuntimeError("No clip-level results were produced; check the manifest path and contents.")

num_clips = len(results_df)
overall_accuracy = results_df["hit@1"].mean()
print(f"Evaluated {num_clips} clips across {window_count} windows from the IRMAS test set.")
print(f"Clip-level top-1 accuracy: {overall_accuracy:.4f} ({overall_accuracy * 100:.2f}%)")

# hit@1 is true when the single predicted class is among the clip's ground-truth labels,
# so clips with several labels can be hit through any one of them.
num_multi_label = int((results_df["n_true_labels"] > 1).sum())
if num_multi_label:
    print(f"{num_multi_label} clip(s) have multiple ground-truth labels; a prediction counts as correct if it matches any of them.")

# `display` only exists under IPython; fall back to plain text elsewhere.
try:
    display(results_df.head(40))
except NameError:
    print(results_df.head(40).to_string(index=False))


Evaluated 807 clips across 8992 windows from the IRMAS test set.
Clip-level top-1 accuracy: 0.7385 (73.85%)
505 clip(s) have multiple ground-truth labels; a prediction counts as correct if it matches any of them.


Unnamed: 0,clip,true_labels,pred_label,classes_scores_aggregated,class_tally,windows,hit@1,n_true_labels
0,(02) dont kill the whale-1,gel,vio,"vio (0.26), sax (0.19), tru (0.15), gel (0.13)...","{'tru': 2, 'sax': 1, 'vio': 7, 'gel': 3}",13,False,1
1,(02) dont kill the whale-11,gel,gel,"gel (0.27), sax (0.16), vio (0.13), tru (0.09)...","{'gel': 4, 'sax': 2}",6,True,1
2,(02) dont kill the whale-12,gel|voi,gel,"gel (0.24), vio (0.19), voi (0.13), sax (0.11)...","{'vio': 1, 'gel': 3}",4,True,2
3,(02) dont kill the whale-13,gel|voi,gel,"gel (0.22), vio (0.18), voi (0.16), cel (0.10)...","{'voi': 1, 'gel': 2, 'vio': 1}",4,True,2
4,(02) dont kill the whale-14,gel|voi,gel,"gel (0.26), vio (0.22), voi (0.12), org (0.12)...","{'gel': 6, 'vio': 1, 'voi': 1}",8,True,2
5,(02) dont kill the whale-15,gel|pia,gel,"gel (0.52), vio (0.13), sax (0.10), voi (0.06)...",{'gel': 8},8,True,2
6,(02) dont kill the whale-2,gel|voi,voi,"voi (0.73), gel (0.11), sax (0.05), vio (0.04)...",{'voi': 3},3,True,2
7,(02) dont kill the whale-3,gel|voi,voi,"voi (0.72), gel (0.11), vio (0.05), org (0.04)...","{'voi': 4, 'gel': 1}",5,True,2
8,(02) dont kill the whale-4,gel,tru,"tru (0.19), sax (0.18), vio (0.18), cla (0.14)...","{'tru': 6, 'sax': 2, 'gel': 3, 'vio': 2}",13,False,1
9,(02) dont kill the whale-6,gel|voi,voi,"voi (0.56), gel (0.12), vio (0.08), org (0.07)...","{'voi': 11, 'org': 2}",13,True,2


In [5]:
# ----------------- PER-CLASS BREAKDOWN -----------------
# For each class, `support` is the number of clips whose ground truth includes it, and
# `correct` counts how many of those were predicted as exactly that class. Only one
# class is predicted per clip, so multi-label clips can credit at most one of their classes.
if "results_df" not in globals():
    raise RuntimeError("results_df is not defined; run the previous cell first.")

# Split "gel|voi" into label sets once. This matches class names literally instead of
# interpolating them into a regex.
true_label_sets = results_df["true_labels"].fillna("").str.split("|").map(set)

per_class_rows = []
for cname in train_order_names:
    mask = true_label_sets.map(lambda labels: cname in labels).astype(bool)
    support = int(mask.sum())
    correct = int((results_df.loc[mask, "pred_label"] == cname).sum()) if support else 0
    hit_rate = (correct / support) if support else float("nan")
    per_class_rows.append({
        "class": cname,
        "support": support,
        "correct": correct,
        "hit_rate": hit_rate,
    })

per_class_df = pd.DataFrame(per_class_rows).sort_values("hit_rate", ascending=False)
try:
    display(per_class_df)
except NameError:
    print(per_class_df.to_string(index=False))

misclassified = results_df[~results_df["hit@1"]]
print(f"Misclassified clips: {len(misclassified)}")
if len(misclassified):
    try:
        display(misclassified.head(20))
    except NameError:
        print(misclassified.head(20).to_string(index=False))


Unnamed: 0,class,support,correct,hit_rate
10,voi,229,179,0.781659
7,sax,167,95,0.568862
3,gac,145,81,0.558621
5,org,91,38,0.417582
4,gel,241,97,0.40249
2,flu,28,10,0.357143
6,pia,367,75,0.20436
9,vio,51,9,0.176471
8,tru,79,11,0.139241
0,cel,26,1,0.038462


Misclassified clips: 211


Unnamed: 0,clip,true_labels,pred_label,pred_score,top3,hit@1,windows,n_true_labels
0,01 Chuck Mangione_Feels So Good_Feels So Good-2,gac|tru,cla,0.783961,"cla (0.78), flu (0.11), sax (0.08)",False,13,2
1,01 - Inolvidable-7,pia,gac,0.757457,"gac (0.76), gel (0.08), pia (0.07)",False,13,1
2,01 Chuck Mangione_Feels So Good_Feels So Good-1,gac|tru,cla,0.727761,"cla (0.73), sax (0.12), flu (0.07)",False,13,2
3,01 Blue Train-14,pia|tru,sax,0.647161,"sax (0.65), tru (0.28), cla (0.05)",False,13,2
4,01 Chuck Mangione_Feels So Good_Feels So Good-4,gac|tru,cla,0.6358,"cla (0.64), flu (0.18), sax (0.14)",False,13,2
5,01. The Best Of Wayne Shorter - The Blue Note ...,pia|sax,tru,0.627016,"tru (0.63), sax (0.28), vio (0.05)",False,13,2
6,01. Offering-9,org,pia,0.596978,"pia (0.60), gel (0.17), gac (0.09)",False,6,1
7,01 - Chet Baker - Prayer For The Newborn-15,pia|tru,cla,0.595186,"cla (0.60), sax (0.17), tru (0.09)",False,13,2
8,01 - Honky Cat-28,pia,gac,0.594202,"gac (0.59), pia (0.26), gel (0.10)",False,3,1
9,01 - Honky Cat-1,pia,gac,0.57674,"gac (0.58), pia (0.39), gel (0.03)",False,4,1
