# IRMAS Test Evaluation

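Load the trained `CNNVarTime` checkpoint and run it on the precomputed IRMAS test mel windows. Average each clip's window logits into a single clip-level prediction. Then report the clip-level top-1 hit rate, overall and per class. A clip counts as a hit when the predicted class is one of its ground-truth labels.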


In [1]:
# -------- Fixed IRMAS test evaluation (single-class model, clip-level) --------
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

from src.models import CNNVarTime
from src.utils.datasets import IRMASTestWindowDataset
from src.utils.utils import IRMAS_CLASSES  # only for label names if ckpt lacks map

# ----------------- CONFIG -----------------
# Relative paths resolve against the current working directory (also PROJECT_ROOT).
PROJECT_ROOT        = Path.cwd()
IRMAS_TEST_MANIFEST = "data/manifests/irmas_test_mels.csv"   # <-- set yours
RESUME_CKPT         = (
    "saved_weights/irmas_pretrain_single_class/train_2/"
    "best_epoch_0075_val_acc_66.50.pt"
)
BATCH_SIZE          = 64
NUM_WORKERS         = 2

# Use the GPU when present; pinned host memory is only requested on CUDA.
DEVICE  = "cuda" if torch.cuda.is_available() else "cpu"
PIN_MEM = DEVICE == "cuda"

In [2]:


# ----------------- HELPERS -----------------
def load_model_state(ckpt_path: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint file and return a state_dict ready for ``load_state_dict``.

    Accepted checkpoint layouts:
      * a dict wrapping the weights under one of ``model_state``,
        ``model_state_dict``, ``state_dict`` or ``model`` (first dict-valued key wins);
      * a bare state_dict (the dict itself is returned, minus prefixes);
      * a pickled ``torch.nn.Module`` (its ``state_dict()`` is used).

    A leading ``module.`` prefix (as added by DataParallel / DDP wrappers) is
    stripped from every key.

    Args:
        ckpt_path: path to a file produced by ``torch.save``.

    Returns:
        Mapping of parameter/buffer names to their stored values.

    Raises:
        TypeError: if the file holds neither a dict nor an ``nn.Module``.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    if isinstance(ckpt, torch.nn.Module):
        sd = ckpt.state_dict()
    elif isinstance(ckpt, dict):
        sd = next(
            (ckpt[k] for k in ("model_state", "model_state_dict", "state_dict", "model")
             if isinstance(ckpt.get(k), dict)),
            ckpt,  # no wrapper key found -> treat the dict itself as the state_dict
        )
    else:
        raise TypeError(
            f"Unsupported checkpoint type {type(ckpt).__name__!r} in {ckpt_path}; "
            "expected a dict or an nn.Module"
        )

    # strip potential DataParallel/DDP "module." prefix
    return {(k[7:] if k.startswith("module.") else k): v for k, v in sd.items()}

def class_names_from_ckpt(ckpt_path: str, fallback: List[str]) -> Tuple[List[str], Dict[str,int]]:
    """Return the class order the checkpoint was trained with.

    If the checkpoint is a dict with a ``label_to_idx`` mapping, class names are
    ordered by their index so that ``names[i]`` labels logit column ``i``.
    Otherwise ``fallback`` is used in its given order and a warning is emitted,
    because a mismatched order silently mislabels every prediction.

    Args:
        ckpt_path: path to a file produced by ``torch.save``.
        fallback: class names to use when the checkpoint carries no mapping.

    Returns:
        ``(ordered_names, label_to_idx)``.

    Raises:
        ValueError: if the checkpoint's indices are not exactly ``0..N-1``
            (gaps or duplicates would break the position == column invariant).
    """
    import warnings

    ckpt = torch.load(ckpt_path, map_location="cpu")
    l2i = ckpt.get("label_to_idx") if isinstance(ckpt, dict) else None
    if isinstance(l2i, dict):
        indices = sorted(l2i.values())
        if indices != list(range(len(l2i))):
            raise ValueError(
                f"label_to_idx in {ckpt_path} must map to indices 0..{len(l2i) - 1}; got {indices}"
            )
        ordered = sorted(l2i, key=l2i.__getitem__)
        return ordered, l2i

    # Fallback only if absolutely necessary -- make it visible.
    warnings.warn(
        f"{ckpt_path} has no 'label_to_idx'; falling back to the provided class order",
        stacklevel=2,
    )
    return list(fallback), {n: i for i, n in enumerate(fallback)}

# ----------------- CLASS ORDER (LOCK TO CKPT) -----------------
# Logit column i corresponds to train_order_names[i]; targets and reports below
# all index through this list, so it is taken from the checkpoint when possible.
train_order_names, label_to_idx = class_names_from_ckpt(RESUME_CKPT, IRMAS_CLASSES)
num_classes = len(train_order_names)
print("Class order from ckpt:", train_order_names)

# ----------------- MODEL -----------------
model = CNNVarTime(in_ch=2, num_classes=num_classes, p_drop=0.5)
model = model.to(DEVICE)
state_dict = load_model_state(RESUME_CKPT)
# strict=True already raises on a key mismatch; the assert double-checks the report.
load_report = model.load_state_dict(state_dict, strict=True)
missing, unexpected = load_report
assert not (missing or unexpected), f"state_dict mismatch: missing={missing} unexpected={unexpected}"
model.eval()  # inference mode for dropout / normalization layers

# ----------------- DATASET / LOADER -----------------
test_ds = IRMASTestWindowDataset(
    manifest_csv=Path(IRMAS_TEST_MANIFEST),
    project_root=PROJECT_ROOT,
    class_names=train_order_names,  # IMPORTANT: targets follow the checkpoint's class order
    per_example_norm=True,          # keep consistent with training
)

# Sanity-check a single window before building the loader:
# a 3-D, 2-channel input and one target slot per model output.
assert len(test_ds) > 0, "Empty test dataset."
x0, y0, clip0, p0 = test_ds[0]
assert x0.ndim == 3 and x0.shape[0] == 2, f"Expected mel shape (2, 128, T); got {tuple(x0.shape)}"
assert y0.numel() == num_classes, f"Target length {y0.numel()} != num_classes {num_classes}"

test_loader = DataLoader(
    dataset=test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,            # fixed order; windows are regrouped by clip id afterwards
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEM,
)
print("Test windows:", len(test_ds))

# ----------------- EVAL (clip-level by averaging logits) -----------------
# Run the model over every test window and bucket the raw logits by clip id;
# the aggregation step averages each bucket into one clip-level prediction.
logits_by_clip = defaultdict(list)   # clip_id -> list of per-window logit vectors
labels_by_clip = {}                  # clip_id -> multi-hot GT vector (float32)
windows_per_clip = defaultdict(int)  # clip_id -> number of windows seen

with torch.no_grad():
    for inputs, targets, clip_ids, _paths in test_loader:
        inputs = inputs.to(DEVICE, non_blocking=PIN_MEM)
        logits = model(inputs)                     # (B, C) raw logits (pre-softmax)
        logits_cpu = logits.detach().cpu().numpy()
        targets_np = targets.cpu().numpy()

        for logit_vec, target_vec, clip_id in zip(logits_cpu, targets_np, clip_ids):
            # default collation turns numeric ids into a tensor; its 0-d elements hash
            # by identity, which would split one clip into many keys -> use the value.
            key = clip_id.item() if torch.is_tensor(clip_id) else clip_id
            gt_vec = target_vec.astype(np.float32)
            prev = labels_by_clip.get(key)
            if prev is not None and not np.array_equal(prev, gt_vec):
                raise ValueError(f"Inconsistent labels across windows of clip {key!r}")
            logits_by_clip[key].append(logit_vec)  # accumulate windows
            labels_by_clip[key] = gt_vec           # multi-hot GT per clip
            windows_per_clip[key] += 1

# ----------------- CLIP-LEVEL REPORT -----------------
# Average each clip's window logits, predict the argmax, and count a hit when the
# prediction is one of the clip's ground-truth labels (GT vectors are multi-hot).
rows = []
for clip_id, window_logits in logits_by_clip.items():
    mean_logits = np.stack(window_logits, 0).mean(axis=0)   # (W, C) -> (C,)

    # Softmax of the mean logits, only used to print readable scores.
    shifted_exp = np.exp(mean_logits - mean_logits.max())
    probs = shifted_exp / shifted_exp.sum()

    gt_vec = labels_by_clip[clip_id]
    pred_idx = int(mean_logits.argmax())
    ranked = np.argsort(mean_logits)[::-1]

    rows.append({
        "clip": clip_id,
        "true_labels": "|".join(
            train_order_names[i] for i, v in enumerate(gt_vec) if v > 0.5
        ),
        "pred_label": train_order_names[pred_idx],
        "pred_score": float(probs[pred_idx]),
        "top3": ", ".join(
            f"{train_order_names[i]} ({probs[i]:.2f})" for i in ranked[:3]
        ),
        "hit@1": bool(gt_vec[pred_idx] > 0.5),
    })

# Misses first, most confident first within each group: the head shows the worst errors.
df = pd.DataFrame(rows).sort_values(["hit@1", "pred_score"], ascending=[True, False])
print("Clip-level Top-1 accuracy:", df["hit@1"].mean())
try:
    display(df.head(10))
except NameError:  # outside IPython there is no display()
    print(df.head(10).to_string(index=False))

  ckpt = torch.load(ckpt_path, map_location="cpu")
  ckpt = torch.load(ckpt_path, map_location="cpu")


Class order from ckpt: ['cel', 'cla', 'flu', 'gac', 'gel', 'org', 'pia', 'sax', 'tru', 'vio', 'voi']


FileNotFoundError: e:\qingchaolaopian\Instrument Sound\Github\ML-based-analysis-of-sound\.cache\mels\irmas\test\(02) dont kill the whale-1__53f276b471__sr44100_dur3.0_m128_w30_h10_s0.npy

In [None]:
# Per-class hit rate: among clips whose GT includes a class, the fraction whose
# top-1 prediction is that class. Labels are matched exactly on the "|"-joined
# field (no regex), so class names containing regex metacharacters are safe.
label_sets = df["true_labels"].fillna("").str.split("|").map(set)

per_class = []
for cname in train_order_names:
    mask = label_sets.map(lambda labels: cname in labels).astype(bool)
    support = int(mask.sum())
    correct = int((df.loc[mask, "pred_label"] == cname).sum())
    per_class.append({
        "class": cname,
        "support": support,
        "hit_rate": correct / support if support else float("nan"),
    })

per_class_df = pd.DataFrame(per_class).sort_values("hit_rate", ascending=False)
per_class_df


NameError: name 'df' is not defined