# IRMAS Test Evaluation

Load the trained `CNNVarTime` checkpoint, run it on the precomputed IRMAS test mel windows, aggregate predictions back to the clip level, and report clip-level accuracy.


In [None]:

from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader
import torch, numpy as np, pandas as pd
from collections import defaultdict
from src.models import CNNVarTime

PROJECT_ROOT = Path.cwd()
IRMAS_TEST_MANIFEST = PROJECT_ROOT / "data/manifests/irmas_test_mels.csv"
RESUME_CKPT = PROJECT_ROOT / "saved_weights/irmas_pretrain/best_val_acc.pt"
BATCH_SIZE = 32

In [22]:
from src.utils.checkpoint import class_names_from_ckpt, load_model_state
from src.utils.datasets import IRMASTestWindowDataset
from src.utils.utils import CLASSES, pick_device

# Align class order with training
train_order_names, _ = class_names_from_ckpt(RESUME_CKPT, fallback=list(CLASSES))
test_dataset = IRMASTestWindowDataset(
    manifest_csv=IRMAS_TEST_MANIFEST,
    project_root=PROJECT_ROOT,
    class_names=train_order_names,
    per_example_norm=True,   # must match your train setting
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=None,
    collate_fn=None,
)

DEVICE = pick_device()
# Build model + strict load
model = CNNVarTime(num_classes=len(train_order_names)).to(DEVICE)
state_dict = load_model_state(RESUME_CKPT)
missing, unexpected = model.load_state_dict(state_dict, strict=True)
assert not missing and not unexpected, f"state_dict mismatch: {missing} | {unexpected}"
w = model.state_dict()["conv1.weight"]
print("conv1 weight mean/std:", float(w.mean()), float(w.std()))
model.eval()


conv1 weight mean/std: 0.0006829480407759547 0.10078245401382446


CNNVarTime(
  (conv1): Conv2d(2, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1))
  (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(32, 64, kernel_size=(2, 2), stride=(1, 1))
  (bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (gap): AdaptiveAvgPool2d(outpu

In [None]:
logits_by_clip = defaultdict(list)
labels_by_clip = {}
windows_per_clip = defaultdict(int)
pin_memory = (DEVICE == "cuda")

with torch.no_grad():
    for inputs, targets, clip_ids, paths in test_loader:
        inputs = inputs.to(DEVICE, non_blocking=pin_memory)
        logits = model(inputs)                         # (B, C)
        logits_cpu = logits.detach().cpu().numpy()
        targets_np = targets.cpu().numpy()

        for logit_vec, target_vec, clip_id in zip(logits_cpu, targets_np, clip_ids):
            logits_by_clip[clip_id].append(logit_vec)
            labels_by_clip[clip_id] = target_vec.astype(np.float32)
            windows_per_clip[clip_id] += 1

# --- aggregate per clip & score single-label membership ---
rows = []
for clip_id, logit_list in logits_by_clip.items():
    stacked = np.stack(logit_list, 0)                # (W, C)
    mean_logits = stacked.mean(axis=0)               # (C,)
    exp = np.exp(mean_logits - mean_logits.max())
    probs = exp / exp.sum()
    pred_idx = int(mean_logits.argmax())
    gt_vec = labels_by_clip[clip_id]
    hit1 = bool(gt_vec[pred_idx] > 0.5)

    order = np.argsort(mean_logits)[::-1]
    top3 = ", ".join(f"{train_order_names[i]} ({probs[i]:.2f})" for i in order[:3])
    true_names = [train_order_names[i] for i, v in enumerate(gt_vec) if v > 0.5]

    rows.append({
        "clip": clip_id,
        "true_labels": "|".join(true_names),
        "pred_label": train_order_names[pred_idx],
        "pred_score": float(probs[pred_idx]),
        "top3": top3,
        "hit@1": hit1,
        "windows": int(windows_per_clip[clip_id]),
    })

df = pd.DataFrame(rows).sort_values(["hit@1", "pred_score"], ascending=[True, False])
print("Top-1 clip accuracy:", df["hit@1"].mean())
df.head()

Top-1 clip accuracy: 0.08550185873605948


Unnamed: 0,clip,true_labels,pred_label,pred_score,top3,hit@1,windows
420,01 Organ Grinder's Swing-7.wav,cel,voi,0.978553,"voi (0.98), gel (0.01), org (0.00)",False,2
12,00 - gold fronts-1.wav,cel,voi,0.948858,"voi (0.95), sax (0.01), org (0.01)",False,4
511,01) Bert Jansch - Avocet-61.wav,cel,gac,0.874771,"gac (0.87), pia (0.06), gel (0.04)",False,7
128,01 - Inolvidable-9.wav,cel|gel,voi,0.872394,"voi (0.87), org (0.03), sax (0.02)",False,2
236,01 - roads-7.wav,cel,voi,0.8671,"voi (0.87), gel (0.04), org (0.04)",False,4


In [None]:
per_class = []
for ci, cname in enumerate(train_order_names):
    mask = df["true_labels"].str.contains(rf"(?:^|[|]){cname}(?:$|[|])")
    support = int(mask.sum())
    correct = int((df.loc[mask, "pred_label"] == cname).sum())
    per_class.append({"class": cname, "support": support, "hit_rate": (correct / support if support else float("nan"))})
per_class_df = pd.DataFrame(per_class).sort_values("hit_rate", ascending=False)
per_class_df


Unnamed: 0,class,support,hit_rate
10,voi,23,0.826087
5,org,29,0.517241
9,vio,22,0.090909
3,gac,32,0.0625
1,cla,218,0.055046
4,gel,64,0.046875
0,cel,807,0.01487
2,flu,112,0.0
6,pia,98,0.0
7,sax,27,0.0
