# IRMAS Test Evaluation

Load the trained `CNNVarTime` checkpoint, run it on the precomputed IRMAS test mel windows, aggregate predictions back to the clip level, and report clip-level accuracy.


In [34]:
import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from src.models import CNNVarTime
from src.utils.utils import CLASSES, decode_label_bits


In [35]:
PROJECT_ROOT = Path.cwd()
IRMAS_TEST_MANIFEST = PROJECT_ROOT / "data/manifests/irmas_test_mels.csv"
IRMAS_CONFIG_PATH = IRMAS_TEST_MANIFEST.with_suffix(".config.json")
RESUME_CKPT = PROJECT_ROOT / "saved_weights/irmas_pretrain/last.pt"
# RESUME_CKPT = PROJECT_ROOT / "saved_weights/irmas_pretrain/best_val_acc.pt"
# RESUME_CKPT = PROJECT_ROOT / "saved_weights/irmas_pretrain/val_0.56.pt"

BATCH_SIZE = 32
NUM_WORKERS = 0

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
PIN_MEMORY = DEVICE.type == "cuda"

assert IRMAS_TEST_MANIFEST.exists(), f"Manifest missing: {IRMAS_TEST_MANIFEST}"
assert RESUME_CKPT.exists(), f"Checkpoint missing: {RESUME_CKPT}"

if IRMAS_CONFIG_PATH.exists():
    irmas_config = json.loads(IRMAS_CONFIG_PATH.read_text())
    eval_class_names = irmas_config.get("classes", [])
else:
    irmas_config = {}
    eval_class_names = []

if not eval_class_names:
    eval_class_names = list(CLASSES)
else:
    eval_class_names = list(eval_class_names)

print(f"Using device: {DEVICE}")
print(f"Test manifest rows: {sum(1 for _ in open(IRMAS_TEST_MANIFEST)) - 1}")
print(f"Eval classes ({len(eval_class_names)}): {eval_class_names}")


Using device: mps
Test manifest rows: 8992
Eval classes (11): ['cel', 'cla', 'flu', 'gac', 'gel', 'org', 'pia', 'sax', 'tru', 'vio', 'voi']


In [36]:
class IRMASTestWindowDataset(Dataset):
    def __init__(self, manifest_csv: Path, project_root: Path, class_names, per_example_norm: bool = True):
        self.project_root = project_root
        self.class_names = list(class_names)
        self.label_to_idx = {label: idx for idx, label in enumerate(self.class_names)}
        self.per_example_norm = per_example_norm

        df = pd.read_csv(
            manifest_csv,
            dtype={"label_multi": "string"},
            keep_default_na=False,
        )
        if "label_multi" not in df.columns:
            raise ValueError("Manifest missing 'label_multi' column")
        df["filepath"] = df["filepath"].astype(str)
        df["label_multi"] = (
            df["label_multi"]
            .astype(str)
            .str.strip()
            .str.replace(r"[^01]", "", regex=True)
            .str.pad(len(self.class_names), side="right", fillchar="0")
            .str[:len(self.class_names)]
        )
        if "irmas_filename" in df.columns:
            df["clip_id"] = df["irmas_filename"].astype(str)
        elif "clip_id" not in df.columns:
            df["clip_id"] = df["filepath"].map(lambda p: Path(p).stem)
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, raw_path: str) -> Path:
        path = Path(raw_path)
        if not path.is_absolute():
            path = self.project_root / raw_path
        return path

    def __getitem__(self, index):
        row = self.df.iloc[index]
        mel_path = self._resolve_path(row["filepath"])
        if not mel_path.exists():
            raise FileNotFoundError(mel_path)
        mel = np.load(mel_path, allow_pickle=False).astype(np.float32)
        if self.per_example_norm:
            mean = mel.mean(axis=(1, 2), keepdims=True)
            std = mel.std(axis=(1, 2), keepdims=True).clip(min=1e-6)
            mel = (mel - mean) / std
        bits = row["label_multi"]
        positives = [self.class_names[i] for i, ch in enumerate(bits) if ch == "1" and i < len(self.class_names)]
        target = np.zeros(len(self.class_names), dtype=np.float32)
        for label in positives:
            target[self.label_to_idx[label]] = 1.0
        return torch.from_numpy(mel), torch.from_numpy(target), row["clip_id"], str(mel_path)


def pad_collate(batch):
    xs, ys, clip_ids, paths = zip(*batch)
    B = len(xs)
    C, F = xs[0].shape[:2]
    Tmax = max(x.shape[-1] for x in xs)
    padded = torch.zeros(B, C, F, Tmax, dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        T = x.shape[-1]
        padded[i, :, :, :T] = x
    targets = torch.stack(ys).float()
    return padded, targets, list(clip_ids), list(paths)


In [37]:
test_dataset = IRMASTestWindowDataset(
    manifest_csv=IRMAS_TEST_MANIFEST,
    project_root=PROJECT_ROOT,
    class_names=eval_class_names,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=pad_collate,
)

model = CNNVarTime(num_classes=len(eval_class_names))

ckpt = torch.load(RESUME_CKPT, map_location="cpu")
state_dict = None
if isinstance(ckpt, dict):
    for key in ("model_state_dict", "state_dict", "model"):
        maybe = ckpt.get(key)
        if isinstance(maybe, dict):
            state_dict = maybe
            break
    if state_dict is None:
        tensor_items = {k: v for k, v in ckpt.items() if isinstance(v, torch.Tensor)}
        state_dict = tensor_items or ckpt
else:
    state_dict = ckpt

state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
    print(f"Missing keys from checkpoint: {sorted(missing)}")
if unexpected:
    print(f"Unexpected keys in checkpoint: {sorted(unexpected)}")

model.to(DEVICE)
model.eval()
eval_class_names = tuple(eval_class_names)

print(f"Loaded checkpoint: {RESUME_CKPT}")
print(f"Windows: {len(test_dataset)} | Unique clips: {test_dataset.df['clip_id'].nunique()}")


Missing keys from checkpoint: ['bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var', 'bn2.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var', 'bn3.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var', 'bn4.weight', 'conv1.bias', 'conv1.weight', 'conv2.bias', 'conv2.weight', 'conv3.bias', 'conv3.weight', 'conv4.bias', 'conv4.weight', 'fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight']
Unexpected keys in checkpoint: ['epoch', 'history', 'label_to_idx', 'model_state', 'opt_state', 'sched_state']
Loaded checkpoint: /Users/hughsignoriello/Developer/ML-based-analysis-of-sound/saved_weights/irmas_pretrain/last.pt
Windows: 8992 | Unique clips: 807


In [38]:
probs_by_clip = defaultdict(list)
labels_by_clip = {}
windows_per_clip = defaultdict(int)

model.eval()
with torch.no_grad():
    for inputs, targets, clip_ids, paths in test_loader:
        inputs = inputs.to(DEVICE)
        logits = model(inputs)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        targets_np = targets.cpu().numpy()
        for prob_vec, target_vec, clip_id in zip(probs, targets_np, clip_ids):
            probs_by_clip[clip_id].append(prob_vec)
            labels_by_clip[clip_id] = target_vec
            windows_per_clip[clip_id] += 1

num_clips = len(probs_by_clip)
total_windows = sum(windows_per_clip.values())
print(f"Collected predictions for {num_clips} clips across {total_windows} windows.")


Collected predictions for 807 clips across 8992 windows.


In [39]:
# print(probs_by_clip.items())
dict_length = len(probs_by_clip)
total_probs = sum(len(probs) for probs in probs_by_clip.values())
memory_size = sum(sum(prob.nbytes for prob in probs) for probs in probs_by_clip.values()) / (1024 * 1024)  # MB

print(f"Dictionary length: {dict_length} clips")
print(f"Total probability arrays: {total_probs}")
print(f"Approximate memory usage: {memory_size:.2f} MB")

Dictionary length: 807 clips
Total probability arrays: 8992
Approximate memory usage: 0.38 MB


In [40]:
results = []
top1_hits = 0
for clip_id, prob_list in probs_by_clip.items():
    # print(clip_id)
    stacked = np.stack(prob_list, axis=0)
    agg = stacked.max(axis=0)
    gt = labels_by_clip[clip_id]
    true_idx = np.where(gt > 0.5)[0]
    order = np.argsort(agg)[::-1]
    top1 = order[0]
    hit1 = len(true_idx) > 0 and (top1 in true_idx)
    top1_hits += int(hit1)
    top3_labels = [f"{eval_class_names[idx]} ({agg[idx]:.2f})" for idx in order[:3]]
    bits = "".join("1" if v > 0.5 else "0" for v in gt)
    true_names = decode_label_bits(bits, eval_class_names)
    results.append({
        "clip": clip_id,
        "true_labels": true_names,
        "pred_label": eval_class_names[top1],
        "pred_score": float(agg[top1]),
        "top3": ", ".join(top3_labels),
        "hit@1": bool(hit1),
        "windows": windows_per_clip[clip_id],
    })

clip_count = len(results) or 1
top1_acc = top1_hits / clip_count
print(f"Clips evaluated: {clip_count}")
print(f"Top-1 accuracy (hit@1): {top1_acc:.3f}")


Clips evaluated: 807
Top-1 accuracy (hit@1): 0.032


In [None]:
results_df = pd.DataFrame(results).sort_values("clip").reset_index(drop=True)
display(results_df.head(20))

misses = results_df[~results_df["hit@1"]]
print(f"Misclassified clips: {len(misses)}")
if not misses.empty:
    display(misses.head(10))




Unnamed: 0,clip,true_labels,pred_label,pred_score,top3,hit@1,windows
0,(02) dont kill the whale-1.wav,gel,cel,0.100376,"cel (0.10), org (0.10), tru (0.09)",False,13
1,(02) dont kill the whale-11.wav,gel,cel,0.100346,"cel (0.10), org (0.10), tru (0.09)",False,6
2,(02) dont kill the whale-12.wav,"gel, voi",cel,0.100255,"cel (0.10), org (0.10), tru (0.09)",False,4
3,(02) dont kill the whale-13.wav,"gel, voi",cel,0.100265,"cel (0.10), org (0.10), tru (0.09)",False,4
4,(02) dont kill the whale-14.wav,"gel, voi",cel,0.100492,"cel (0.10), org (0.10), tru (0.09)",False,8
5,(02) dont kill the whale-15.wav,"gel, pia",cel,0.100244,"cel (0.10), org (0.10), tru (0.09)",False,8
6,(02) dont kill the whale-2.wav,"gel, voi",cel,0.100433,"cel (0.10), org (0.10), tru (0.09)",False,3
7,(02) dont kill the whale-3.wav,"gel, voi",cel,0.100423,"cel (0.10), org (0.10), tru (0.09)",False,5
8,(02) dont kill the whale-4.wav,gel,cel,0.100518,"cel (0.10), org (0.10), tru (0.09)",False,13
9,(02) dont kill the whale-6.wav,"gel, voi",cel,0.10052,"cel (0.10), org (0.10), tru (0.09)",False,13


Misclassified clips: 781


Unnamed: 0,clip,true_labels,pred_label,pred_score,top3,hit@1,windows
0,(02) dont kill the whale-1.wav,gel,cel,0.100376,"cel (0.10), org (0.10), tru (0.09)",False,13
1,(02) dont kill the whale-11.wav,gel,cel,0.100346,"cel (0.10), org (0.10), tru (0.09)",False,6
2,(02) dont kill the whale-12.wav,"gel, voi",cel,0.100255,"cel (0.10), org (0.10), tru (0.09)",False,4
3,(02) dont kill the whale-13.wav,"gel, voi",cel,0.100265,"cel (0.10), org (0.10), tru (0.09)",False,4
4,(02) dont kill the whale-14.wav,"gel, voi",cel,0.100492,"cel (0.10), org (0.10), tru (0.09)",False,8
5,(02) dont kill the whale-15.wav,"gel, pia",cel,0.100244,"cel (0.10), org (0.10), tru (0.09)",False,8
6,(02) dont kill the whale-2.wav,"gel, voi",cel,0.100433,"cel (0.10), org (0.10), tru (0.09)",False,3
7,(02) dont kill the whale-3.wav,"gel, voi",cel,0.100423,"cel (0.10), org (0.10), tru (0.09)",False,5
8,(02) dont kill the whale-4.wav,gel,cel,0.100518,"cel (0.10), org (0.10), tru (0.09)",False,13
9,(02) dont kill the whale-6.wav,"gel, voi",cel,0.10052,"cel (0.10), org (0.10), tru (0.09)",False,13


NameError: name 'hits' is not defined