# Data

In [None]:
# One cell: load 1 cache item, resolve codec/vocab/pad from cache+ckpt, plot, run, decode, audio + plots
from pathlib import Path
import json, numpy as np, torch

# Notebook-level configuration read by the rest of this cell (and by the
# plotting cell below, which reuses EVAL_SR).
CACHE_DIR = Path("cache/encodec_acoustic")  # xcodec_acoustic, dac_acoustic
CKPT_PATH = Path("artifacts/checkpoints/encodec_small_single_kit.pt") # expressivegrid_to_xcodec, expressivegrid_to_dac
DEVICE = "cuda:0"         # model forward
DECODE_DEVICE = "cuda:0"  # codec decode (set to "cpu" if VRAM is tight)
# Sample rate that reference and decoded audio are resampled to before they
# are compared, played back and plotted.
EVAL_SR = 32000

# --- pick 1 item ---
# Use the alphabetically first test manifest in the cache directory. Fail with
# a readable message (instead of a bare IndexError) when the cache is missing.
manifests = sorted(CACHE_DIR.glob("manifest_midigroove_test_*.jsonl"))
if not manifests:
    raise FileNotFoundError(f"No manifest_midigroove_test_*.jsonl found in {CACHE_DIR}")
manifest = manifests[0]

# Only the first record is needed, so stop reading at the first non-blank
# line rather than loading the whole JSONL file into memory.
with manifest.open(encoding="utf-8") as fh:
    first_record = next((ln for ln in fh if ln.strip()), None)
if first_record is None:
    raise ValueError(f"Manifest {manifest} contains no records")
rec = json.loads(first_record)
npz_path = Path(rec["npz"])
print("NPZ:", npz_path)

# Materialise every array into a plain dict so the NpzFile handle can be
# closed immediately; later code relies on dict methods such as ``ex.get``.
with np.load(npz_path, allow_pickle=False) as d:
    ex = {k: np.asarray(d[k]) for k in d.files}

# --- load conditioning + targets ---
# "drum_hit", "beat_pos" and "tgt" are required keys (a missing one raises
# KeyError). The remaining entries are optional and fall back to zeros, or to
# the scalar defaults 120.0 (bpm) and 0 (drummer_id).
drum_hit = ex["drum_hit"].astype(np.float32)                                # [D,T]
drum_vel = ex.get("drum_vel", np.zeros_like(drum_hit)).astype(np.float32)   # [D,T]
drum_sus = ex.get("drum_sustain", np.zeros_like(drum_hit)).astype(np.float32)  # [D,T]
hh_cc4   = ex.get("hh_open_cc4", np.zeros((drum_hit.shape[1],), np.float32)).astype(np.float32)  # [T]
beat_pos = ex["beat_pos"].astype(np.int64)                                  # [T]
bpm = float(ex.get("bpm", 120.0))
drummer_id = int(ex.get("drummer_id", 0))
tgt = ex["tgt"].astype(np.int64)                                            # [C,T]
# D and T come from drum_hit's two axes; C is the first axis of the targets.
D, T = drum_hit.shape
C = int(tgt.shape[0])
print(f"D={D} T={T} C={C} bpm={bpm} drummer_id={drummer_id}")

# --- resolve codec from cache semantics (fallback to ckpt cfg) ---
# "semantics" is read as a JSON string stored in the npz; a missing or empty
# value is treated as "{}". An empty "encoder" entry becomes None so the
# fallback chain below can continue.
cache_sem = json.loads(str(ex.get("semantics", np.asarray("{}")).item() or "{}"))
cache_codec = str(cache_sem.get("encoder", "") or "").strip().lower() or None

from midigroove_poc import expressivegrid as eg
# NOTE(review): torch.load deserialises with pickle unless weights-only loading
# is in effect, so only point CKPT_PATH at trusted checkpoints - confirm the
# behaviour for the installed PyTorch version.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
# A non-dict "cfg" entry is ignored and replaced by an empty config.
cfg = ckpt.get("cfg", {}) if isinstance(ckpt.get("cfg", {}), dict) else {}
ckpt_codec = str(cfg.get("encoder_model", "") or "").strip().lower() or None
# Precedence: cache semantics, then checkpoint config, then "encodec".
codec = cache_codec or ckpt_codec or "encodec"

# --- resolve vocab_size / pad_id from ckpt (new) or infer from head (old) ---
num_codebooks = int(ckpt["num_codebooks"])
state = ckpt["model"]
vocab_size = int(cfg.get("vocab_size", 0) or 0)
if vocab_size <= 0:
    # No usable vocab_size in cfg: ask the project helper to infer it from the
    # state dict, and if that returns None derive it from the codec's default
    # codebook size.
    vs = eg._infer_vocab_size_from_state_dict(state, num_codebooks=num_codebooks)
    vocab_size = int(vs) if vs is not None else int(eg._vocab_size_for_codebook(eg._default_codebook_size_for_encoder(codec)))
# Without an explicit pad_id the last vocabulary index is used (2048 if the
# vocabulary size is not greater than 1).
pad_id = int(cfg.get("pad_id", vocab_size - 1 if vocab_size > 1 else 2048))

print("resolved codec:", codec, "vocab_size:", vocab_size, "pad_id:", pad_id)

# --- plotting inputs (always show all lanes) ---
# Five stacked panels sharing the frame axis: hit grid, velocity grid,
# sustain grid, hi-hat CC4 curve and beat position strip.
try:
    import matplotlib.pyplot as plt
except Exception as e:
    raise RuntimeError("Install matplotlib in this env (e.g. `pip install matplotlib`)") from e

from data.midigroove_encodec_dataset import CHANNELS
fig, axs = plt.subplots(5, 1, figsize=(14, 10), sharex=True)

axs[0].imshow(drum_hit, aspect="auto", origin="lower", interpolation="nearest")
axs[0].set_title("drum_hit [D,T]")
# NOTE(review): tick labels assume len(CHANNELS) equals the number of rows
# in drum_hit - confirm against the dataset module.
axs[0].set_yticks(range(len(CHANNELS)))
axs[0].set_yticklabels(CHANNELS, fontsize=8)

axs[1].imshow(drum_vel, aspect="auto", origin="lower", interpolation="nearest")
axs[1].set_title("drum_vel [D,T]")

axs[2].imshow(drum_sus, aspect="auto", origin="lower", interpolation="nearest")
axs[2].set_title("drum_sustain [D,T] (optional)")

axs[3].plot(hh_cc4)
axs[3].set_ylim(-0.05, 1.05)
axs[3].set_title("hh_open_cc4 [T] (optional)")

# beat_pos is 1-D, so add a leading axis to draw it as a one-row image.
axs[4].imshow(beat_pos[None, :], aspect="auto", origin="lower", interpolation="nearest", vmin=0, vmax=3)
axs[4].set_title("beat_pos [T] (0..3)")
axs[4].set_xlabel("frame")

plt.tight_layout()
plt.show()

# --- build input grid respecting ckpt feature flags ---
# The grid always contains hits and velocities; sustain and hi-hat CC4 rows
# are appended only when the checkpoint config enables them.
include_sustain = bool(cfg.get("include_sustain", False))
include_hh_cc4 = bool(cfg.get("include_hh_cc4", False))

pieces = [drum_hit, drum_vel]
if include_sustain:
    pieces.append(drum_sus)
if include_hh_cc4:
    pieces.append(hh_cc4[None, :])

grid = np.concatenate(pieces, axis=0).astype(np.float32)  # [F,T]
in_dim = int(ckpt["in_dim"])
# Sanity check: the stacked feature rows must match the input width stored
# in the checkpoint, and the time axis must be unchanged.
assert grid.shape[0] == in_dim and grid.shape[1] == T, (grid.shape, in_dim, T)

# --- run checkpoint -> predicted tokens ---
# Work on a copy of cfg so the resolved codec / vocab_size / pad_id are passed
# to the model builder without mutating the checkpoint's own config dict.
cfg2 = dict(cfg)
cfg2.setdefault("encoder_model", codec)
cfg2["vocab_size"] = int(vocab_size)
cfg2["pad_id"] = int(pad_id)

model = eg._build_model(num_codebooks=num_codebooks, in_dim=in_dim, cfg=cfg2)
model.load_state_dict(state, strict=True)
model.to(torch.device(DEVICE)).eval()

# Add a leading batch axis of size 1 to every model input.
grid_t = torch.from_numpy(grid).unsqueeze(0).to(DEVICE)            # [1,F,T]
beat_pos_t = torch.from_numpy(beat_pos).unsqueeze(0).to(DEVICE)    # [1,T]
bpm_t = torch.tensor([bpm], dtype=torch.float32, device=DEVICE)    # [1]
drummer_id_t = torch.tensor([drummer_id], dtype=torch.long, device=DEVICE)  # [1]
# Every frame of this single window is marked valid.
valid_mask_t = torch.ones((1, T), dtype=torch.bool, device=DEVICE)

with torch.inference_mode():
    logits = model(grid=grid_t, beat_pos=beat_pos_t, bpm=bpm_t, drummer_id=drummer_id_t, valid_mask=valid_mask_t)  # [1,C,T,V]
    # Greedy decoding: take the highest-scoring token at every position.
    pred = logits.argmax(dim=-1).squeeze(0).to(torch.long).cpu()  # [C,T]

tgt_t = torch.from_numpy(tgt).to(torch.long)  # [C,T]
# Accuracy is computed over non-PAD target positions only; max(1, ...) avoids
# a division by zero when every target position is PAD.
mask = tgt_t.ne(pad_id)
tok_acc_masked = float(((pred == tgt_t) & mask).sum().item() / max(1, mask.sum().item()))
print("used sustain:", include_sustain, "used hh_cc4:", include_hh_cc4)
print("token_acc(masked):", tok_acc_masked, "pred shape:", tuple(pred.shape), "tgt shape:", tuple(tgt_t.shape))

# for decoding: map PAD -> 0
pred_clean = torch.where(pred == pad_id, torch.zeros_like(pred), pred)
tgt_clean = torch.where(tgt_t == pad_id, torch.zeros_like(tgt_t), tgt_t)

# --- decode audio (prediction vs gt tokens) and compare to original segment ---
from IPython.display import Audio, display
from data.codecs import decode_tokens_to_audio
from midigroove_poc.eval import _load_audio_segment, _resample_linear

# Locate the source window: the start and length are stored in seconds in the
# cache item and converted to sample counts at the file's native rate.
audio_path = Path(str(ex["audio_path"].item()))
sr_native = int(ex["sr"].item())
start_sec = float(ex["start_sec"].item())
window_seconds = float(ex["window_seconds"].item())
start_sample = int(round(start_sec * sr_native))
window_samples = int(round(window_seconds * sr_native))

ref, sr_ref = _load_audio_segment(audio_path, start_sample=start_sample, num_samples=window_samples)
ref_rs = _resample_linear(ref, sr_ref, EVAL_SR)

# Decode both the ground-truth tokens (codec reconstruction) and the model's
# predicted tokens, then bring both to EVAL_SR.
# NOTE(review): indexing [0] assumes the decoder returns a batch-first array -
# confirm against data.codecs.
audio_gt_b1, sr_gt = decode_tokens_to_audio(tgt_clean, encoder_model=codec, device=DECODE_DEVICE)
audio_pr_b1, sr_pr = decode_tokens_to_audio(pred_clean, encoder_model=codec, device=DECODE_DEVICE)
gt_rs = _resample_linear(audio_gt_b1[0], sr_gt, EVAL_SR)
pr_rs = _resample_linear(audio_pr_b1[0], sr_pr, EVAL_SR)

# Trim all three signals to the shortest one so they can be compared
# sample-by-sample.
N = min(ref_rs.size, gt_rs.size, pr_rs.size)
ref_rs, gt_rs, pr_rs = ref_rs[:N], gt_rs[:N], pr_rs[:N]

print("ref_rs:", ref_rs.shape, "gt_rs:", gt_rs.shape, "pr_rs:", pr_rs.shape)

print("Original audio (resampled):")
display(Audio(ref_rs, rate=EVAL_SR))
print(f"Codec reconstruction from ground-truth tokens ({codec}) (resampled):")
display(Audio(gt_rs, rate=EVAL_SR))
print(f"Model prediction decoded ({codec}) (resampled):")
display(Audio(pr_rs, rate=EVAL_SR))

# --- waveform comparison plots ---
import numpy as np
def l1(a, b):
    """Return the mean absolute difference between ``a`` and ``b`` as a float."""
    abs_diff = np.abs(a - b)
    return float(np.mean(abs_diff))
# Mean absolute error of each decoded signal against the resampled reference.
print("L1(gt vs ref):", l1(gt_rs, ref_rs))
print("L1(pred vs ref):", l1(pr_rs, ref_rs))

# Time axis in seconds for the trimmed, resampled signals.
t = np.arange(N) / float(EVAL_SR)
plt.figure(figsize=(14,4))
plt.plot(t, ref_rs, label="ref (resampled)", alpha=0.7)
plt.plot(t, gt_rs, label=f"gt decode ({codec})", alpha=0.7)
plt.plot(t, pr_rs, label=f"pred decode ({codec})", alpha=0.7)
# Zoom to at most the first two seconds. NOTE(review): t[-1] raises
# IndexError when N == 0 (i.e. any of the signals is empty).
plt.xlim(0, min(t[-1], 2.0))
plt.legend()
plt.title("Waveforms (first ~2s)")
plt.xlabel("seconds")
plt.show()


In [None]:
# --- plotting: main (waveform + hit + vel) and optionals (sus + cc4 + beat_pos) ---
try:
    import matplotlib.pyplot as plt
except Exception as e:
    raise RuntimeError("Install matplotlib in this env (e.g. `pip install matplotlib`)") from e

from data.midigroove_encodec_dataset import CHANNELS
from midigroove_poc.eval import _load_audio_segment, _resample_linear

# Reload the reference audio window for plotting. This repeats the lookup done
# in the first cell and relies on `ex` and EVAL_SR defined there.
audio_path = Path(str(ex["audio_path"].item()))
sr_native = int(ex["sr"].item())
start_sec = float(ex["start_sec"].item())
window_seconds = float(ex["window_seconds"].item())
# Seconds -> sample counts at the file's native sample rate.
start_sample = int(round(start_sec * sr_native))
window_samples = int(round(window_seconds * sr_native))

ref, sr_ref = _load_audio_segment(audio_path, start_sample=start_sample, num_samples=window_samples)
ref_rs = _resample_linear(ref, sr_ref, EVAL_SR)

# Time axis (seconds) for the waveform, plus an image extent that stretches the
# grids over [0, window_seconds] so they line up with the waveform.
t_sec = np.arange(ref_rs.size) / float(EVAL_SR)
extent_dt = [0.0, float(window_seconds), 0.0, float(D)]  # x=seconds, y=drum lanes

# Main figure: waveform + drum_hit + drum_vel
fig1, axs1 = plt.subplots(4, 1, figsize=(14, 9), sharex=True, gridspec_kw={"height_ratios": [1, 1.6, 1.6, 0.7]})

axs1[0].plot(t_sec, ref_rs, lw=0.8)
axs1[0].set_title("waveform (original segment, resampled)")
axs1[0].set_ylabel("amp")

axs1[1].imshow(drum_hit, aspect="auto", origin="lower", interpolation="nearest", extent=extent_dt)
axs1[1].set_title("drum_hit [D,T]")
# With the extent above each lane spans one unit in y, so +0.5 centres the
# tick label on its lane.
axs1[1].set_yticks(np.arange(D) + 0.5)
axs1[1].set_yticklabels(CHANNELS, fontsize=8)

axs1[2].imshow(drum_vel, aspect="auto", origin="lower", interpolation="nearest", extent=extent_dt)
axs1[2].set_title("drum_vel [D,T]")
axs1[2].set_yticks(np.arange(D) + 0.5)
axs1[2].set_yticklabels(CHANNELS, fontsize=8)
axs1[2].set_xlabel("seconds")

# beat_pos is 1-D; draw it as a one-row strip over the same time extent.
axs1[3].imshow(beat_pos[None, :], aspect="auto", origin="lower", interpolation="nearest",
               extent=[0.0, float(window_seconds), 0.0, 1.0], vmin=0, vmax=3)
axs1[3].set_title("beat_pos [T] (0..3)")
axs1[3].set_yticks([])
axs1[3].set_xlabel("seconds")

plt.tight_layout()
plt.show()

# Optional figure: drum_sustain + hh_open_cc4 + beat_pos
# These inputs default to zeros when absent from the cache item (see the
# loading code in the first cell), so empty panels are expected in that case.
fig2, axs2 = plt.subplots(3, 1, figsize=(14, 7), sharex=True, gridspec_kw={"height_ratios": [1.6, 1.0, 0.7]})

axs2[0].imshow(drum_sus, aspect="auto", origin="lower", interpolation="nearest", extent=extent_dt)
axs2[0].set_title("drum_sustain [D,T] (optional)")
axs2[0].set_yticks(np.arange(D) + 0.5)
axs2[0].set_yticklabels(CHANNELS, fontsize=8)

# One x position per frame, spread evenly over the window duration.
axs2[1].plot(np.linspace(0.0, float(window_seconds), num=T, endpoint=False), hh_cc4, lw=0.9)
axs2[1].set_ylim(-0.05, 1.05)
axs2[1].set_title("hh_open_cc4 [T] (optional)")
axs2[1].set_ylabel("cc4")

axs2[2].imshow(beat_pos[None, :], aspect="auto", origin="lower", interpolation="nearest",
               extent=[0.0, float(window_seconds), 0.0, 1.0], vmin=0, vmax=3)
axs2[2].set_title("beat_pos [T] (0..3)")
axs2[2].set_yticks([])
axs2[2].set_xlabel("seconds")

plt.tight_layout()
plt.show()


In [None]:
# Count rows per kit, most frequent first.
# NOTE(review): `df` is not defined in any cell visible here - it must come
# from an earlier session/cell. value_counts() presumably already returns
# counts in descending order, which would make the extra sort redundant -
# confirm before removing it.
df['kit_name'].value_counts().sort_values(ascending=False)

# LaTeX tables

In [4]:
# Big/Small models (one-kit + all-kits) → LaTeX table
# Applies your exact p{..} widths, fixes FAD header (single occurrence), adds OneKit/AllKits separator.

from pathlib import Path
import json
import numpy as np
from IPython.display import Markdown, display

# Which model family to tabulate; selects the eval directories below and is
# embedded in the table caption and label.
MODEL_SIZE = "small"  # "small" | "big"

# (row label, path to that run's summary.json), in the order the blocks
# appear in the table.
RUNS = [
    ("OneKit", Path(f"artifacts/eval/{MODEL_SIZE}_one_kit/summary.json")),
    ("AllKits", Path(f"artifacts/eval/{MODEL_SIZE}_all_kits/summary.json")),
]
# Keys looked up under summary["systems"]; one table row per (run, system).
SYSTEMS = ["encodec", "xcodec", "dac"]

TABLE_ENV = "table*"
DASH = r"---"           # placeholder printed for missing / non-finite values
HICOLOR = r"green!15"   # \cellcolor argument used to highlight the best cell

# Exact column widths you provided
# 14 columns: Eval, Codec, 3 token metrics, 5 audio metrics, 3 onset metrics, FAD.
COLSPEC = (
    r"|p{0.060\linewidth}|p{0.060\linewidth}|"
    r"p{0.060\linewidth}|p{0.060\linewidth}|p{0.060\linewidth}|"
    r"p{0.07\linewidth}|p{0.07\linewidth}|p{0.11\linewidth}|p{0.11\linewidth}|p{0.12\linewidth}|"
    r"p{0.067\linewidth}|p{0.067\linewidth}|p{0.067\linewidth}|"
    r"p{0.064\linewidth}|"
)

def _get_metric_dict(s, keys):
    """Return the value stored in ``s`` under the first name in ``keys`` it contains.

    ``keys`` lists alternative spellings of one metric in priority order. When
    none of them is present, a placeholder dict with NaN ``mean`` and ``std``
    is returned.
    """
    present = [alias for alias in keys if alias in s]
    if present:
        return s[present[0]]
    return {"mean": np.nan, "std": np.nan}

def _mean(x):
    """Return the ``mean`` entry of metric dict ``x`` as a float.

    A missing entry yields NaN. A ``None`` (JSON ``null``) or non-numeric
    entry also yields NaN instead of raising, so such a metric is rendered as
    missing rather than aborting the whole table.
    """
    try:
        return float(x.get("mean", np.nan))
    except (TypeError, ValueError):
        return np.nan

def pm(x, digits=3):
    """Format metric dict ``x`` as a LaTeX math ``mean\\pm std`` cell.

    Both numbers are printed with ``digits`` decimals. ``DASH`` is returned
    when either the mean or the std is not finite.
    """
    mean_val = _mean(x)
    std_val = float(x.get("std", np.nan))
    if np.isfinite(mean_val) and np.isfinite(std_val):
        return f"${mean_val:.{digits}f}\\pm{std_val:.{digits}f}$"
    return DASH

def pm_pct(x, digits=1):
    """Format metric dict ``x`` as a LaTeX ``mean\\pm std`` cell in percent.

    Mean and std are multiplied by 100 before being printed with ``digits``
    decimals. ``DASH`` is returned when either value is not finite.
    """
    mean_val = _mean(x)
    std_val = float(x.get("std", np.nan))
    if np.isfinite(mean_val) and np.isfinite(std_val):
        return f"${(100*mean_val):.{digits}f}\\pm{(100*std_val):.{digits}f}$"
    return DASH

def fnum(v, digits=3):
    """Format a single number as a LaTeX math cell with ``digits`` decimals.

    Anything that cannot be converted to a float, or that converts to a
    non-finite float, is rendered as ``DASH``.
    """
    try:
        number = float(v)
    except Exception:
        return DASH
    if np.isfinite(number):
        return f"${number:.{digits}f}$"
    return DASH

def fadtk_key(sys_dict):
    """Return the first key of ``sys_dict`` whose name starts with ``fad_fadtk_``.

    Keys are compared and returned in their string form; ``None`` is returned
    when no key matches.
    """
    candidates = (str(name) for name in sys_dict.keys())
    return next((name for name in candidates if name.startswith("fad_fadtk_")), None)

def fad_val(run_summary, sys_name):
    """Return the plain FAD score of ``sys_name`` in ``run_summary``, or NaN.

    The score is read from the ``fad`` field of the system's ``fad_fadtk_*``
    entry. NaN is returned when the system, the fadtk entry or the field is
    missing, and also when the stored value is ``None`` (JSON ``null``) or
    otherwise not convertible to float.
    """
    entry = run_summary["systems"].get(sys_name, {}) or {}
    key = fadtk_key(entry)
    if not key:
        return np.nan
    raw = (entry.get(key, {}) or {}).get("fad", np.nan)
    try:
        return float(raw)
    except (TypeError, ValueError):
        # A null / malformed score is treated like a missing one.
        return np.nan

def fadinf_val(run_summary, sys_name):
    """Return the FAD-infinity score of ``sys_name`` in ``run_summary``, or NaN.

    The score is read from ``fad_inf["fad_inf"]`` inside the system's
    ``fad_fadtk_*`` entry. NaN is returned when any level of that path is
    missing, and also when the stored value is ``None`` (JSON ``null``) or
    otherwise not convertible to float.
    """
    entry = run_summary["systems"].get(sys_name, {}) or {}
    key = fadtk_key(entry)
    if not key:
        return np.nan
    inf_block = (entry.get(key, {}) or {}).get("fad_inf", {}) or {}
    try:
        return float(inf_block.get("fad_inf", np.nan))
    except (TypeError, ValueError):
        # A null / malformed score is treated like a missing one.
        return np.nan

# Table column specifications, one dict per metric column:
#   name   - identifier used by format_metric / best_cells
#   keys   - accepted key spellings in summary.json, first match wins
#   fmt    - cell formatter (pm: raw mean+-std, pm_pct: values scaled by 100)
#   digits - decimals printed
#   dir    - "min" if lower is better, "max" if higher is better
METRICS = [
    dict(name="token_nll",    keys=["token_nll"],    fmt=pm,     digits=3, dir="min"),
    dict(name="token_ppl",    keys=["token_ppl"],    fmt=pm,     digits=1, dir="min"),
    dict(name="token_acc",    keys=["token_acc"],    fmt=pm_pct, digits=1, dir="max"),

    dict(name="rmse",         keys=["rmse"],         fmt=pm,     digits=4, dir="min"),
    dict(name="mae",          keys=["mae"],          fmt=pm,     digits=4, dir="min"),
    dict(name="mr_stft_sc",   keys=["mr_stft_sc"],   fmt=pm,     digits=3, dir="min"),
    dict(name="env_rms_corr", keys=["env_rms_corr"], fmt=pm,     digits=3, dir="max"),
    dict(name="tter_db_mae",  keys=["tter_db_mae"],  fmt=pm,     digits=2, dir="min"),

    dict(name="onset_precision", keys=["onset_precision","onset_prec","onset_p","onset_pr"], fmt=pm_pct, digits=1, dir="max"),
    dict(name="onset_recall",    keys=["onset_recall","onset_rec","onset_r"],                fmt=pm_pct, digits=1, dir="max"),
    dict(name="onset_f1",        keys=["onset_f1"],                                        fmt=pm_pct, digits=1, dir="max"),
]

# ---- load summaries ----
# Read as UTF-8 explicitly (as load_eval does in the per-kit cell) so the
# result does not depend on the platform's locale encoding.
summ = {run_label: json.loads(p.read_text(encoding="utf-8")) for run_label, p in RUNS}

# ---- choose per-run FAD column value: FAD-infinity if present, else FAD ----
# A run uses FAD-infinity as soon as at least one of its systems has a finite
# value for it; otherwise the plain FAD score is shown for the whole run.
run_fad_kind = {}
for run_label, _ in RUNS:
    has_inf = any(np.isfinite(fadinf_val(summ[run_label], s)) for s in SYSTEMS)
    run_fad_kind[run_label] = "inf" if has_inf else "fad"

def fad_value(run_label, sys_name):
    """Return the FAD number shown for ``sys_name`` in run ``run_label``.

    Uses the FAD-infinity score when ``run_fad_kind`` marks the run as "inf",
    and the plain FAD score otherwise.
    """
    picker = fadinf_val if run_fad_kind[run_label] == "inf" else fad_val
    return picker(summ[run_label], sys_name)

# ---- best cells ----
# For every run and every metric column, remember which system has the best
# finite mean (smallest for dir == "min", largest otherwise) so that cell can
# be highlighted when the rows are formatted.
best_cells = set()  # (run_label, sys_name, metric_name)

for run_label, _ in RUNS:
    sysdict = summ[run_label]["systems"]
    for ms in METRICS:
        vals = []
        for sys_name in SYSTEMS:
            # NOTE(review): direct indexing raises KeyError if a system listed
            # in SYSTEMS is absent from this run's summary.
            x = _get_metric_dict(sysdict[sys_name], ms["keys"])
            m = _mean(x)
            if np.isfinite(m):
                vals.append((sys_name, m))
        if vals:
            # min()/max() return the first extreme element, so ties go to the
            # system listed first in SYSTEMS.
            best_sys = min(vals, key=lambda t: t[1])[0] if ms["dir"] == "min" else max(vals, key=lambda t: t[1])[0]
            best_cells.add((run_label, best_sys, ms["name"]))

# The FAD column is handled separately (lower is better) under the pseudo
# metric name "fad_col".
for run_label, _ in RUNS:
    vals = [(s, fad_value(run_label, s)) for s in SYSTEMS]
    vals = [(s, v) for (s, v) in vals if np.isfinite(v)]
    if vals:
        best_cells.add((run_label, min(vals, key=lambda t: t[1])[0], "fad_col"))

def format_metric(run_label, sys_name, sysdict, metric_name):
    """Render one metric cell for the (run, system) row.

    The metric's spec is looked up by name in ``METRICS``, its formatter is
    applied to the matching entry of ``sysdict``, and the result is wrapped in
    ``\\cellcolor`` when the cell is registered in ``best_cells``. Placeholder
    cells (``DASH``) are never highlighted.
    """
    spec = next(entry for entry in METRICS if entry["name"] == metric_name)
    stats = _get_metric_dict(sysdict, spec["keys"])
    cell = spec["fmt"](stats, spec["digits"])
    if cell == DASH:
        return cell
    if (run_label, sys_name, metric_name) not in best_cells:
        return cell
    return rf"\cellcolor{{{HICOLOR}}} {cell}"

def format_fad(run_label, sys_name):
    """Render the FAD cell for the (run, system) row.

    The value chosen by ``fad_value`` is printed with three decimals and is
    highlighted when it is the run's best FAD. Placeholder cells (``DASH``)
    are never highlighted.
    """
    cell = fnum(fad_value(run_label, sys_name), digits=3)
    if cell == DASH:
        return cell
    if (run_label, sys_name, "fad_col") not in best_cells:
        return cell
    return rf"\cellcolor{{{HICOLOR}}} {cell}"

def row_for(run_label, sys_name):
    """Return the list of cell strings for one table row.

    Layout: run label, system name, the eleven metric cells in table column
    order, then the FAD cell.
    """
    sys_summary = summ[run_label]["systems"][sys_name]
    # Column order of the table body; must stay in sync with the header row.
    column_order = (
        "token_nll",
        "token_ppl",
        "token_acc",
        "rmse",
        "mae",
        "mr_stft_sc",
        "env_rms_corr",
        "tter_db_mae",
        "onset_precision",
        "onset_recall",
        "onset_f1",
    )
    cells = [run_label, sys_name]
    cells.extend(
        format_metric(run_label, sys_name, sys_summary, metric_name)
        for metric_name in column_order
    )
    cells.append(format_fad(run_label, sys_name))
    return cells

def fadtk_note(run_label):
    """Return the footnote fragment describing how FAD was computed for a run.

    The fragment names the FAD variant shown for the run, the embedding model,
    the number of items used and the median clip duration, all read from the
    run summary's ``fadtk`` section. A missing, ``None`` (JSON ``null``) or
    non-numeric median is printed as ``?`` instead of raising.
    """
    m = summ[run_label].get("fadtk", {}) or {}
    model = m.get("model", "clap-laion-music")
    # Prefer reference.n_items_fad; fall back to the top-level n_used.
    n_used = (m.get("reference", {}) or {}).get("n_items_fad", m.get("n_used", ""))
    clip = (m.get("clip_dur_s", {}) or {})
    try:
        clip_med = float(clip.get("median", np.nan))
    except (TypeError, ValueError):
        clip_med = np.nan
    clip_med_s = f"{clip_med:.3f}s" if np.isfinite(clip_med) else "?"
    which = "FAD$\\infty$" if run_fad_kind[run_label] == "inf" else "FAD"
    # "~" is a LaTeX non-breaking space, so this typesets as "clip med=...".
    return f"{run_label}: {which}, {model}, n={n_used}, clip~med={clip_med_s}"

def H(tex):
    """Wrap ``tex`` in a centred, script-size ``\\shortstack`` for a header cell."""
    opening = r"\shortstack[c]{\scriptsize "
    return opening + format(tex) + "}"

# Caption and label both embed MODEL_SIZE so the small / big tables can
# coexist in one document.
caption = f"Evaluation on one-kit and all-kits test sets ({MODEL_SIZE} models)."
label = f"tab:eval_{MODEL_SIZE}_models"

# ---- build LaTeX ----
# The table is accumulated line by line in `lines` and joined at the end.
lines = []
lines.append(rf"\begin{{{TABLE_ENV}}}[htbp]")
lines.append(rf"\caption{{{caption}}}")
lines.append(r"\begin{center}")
lines.append(r"\renewcommand{\arraystretch}{1}")
lines.append(r"\setlength{\tabcolsep}{1pt}")
lines.append(r"\resizebox{\textwidth}{!}{%")
lines.append(r"\begin{tabular}{" + COLSPEC + r"}")
lines.append(r"\hline")

# Header row 1: group headers; keep FAD cell empty to avoid repeating “FAD”
# Column groups: 2 label columns, 3 token, 5 audio, 3 onset, 1 FAD (14 total,
# matching COLSPEC).
lines.append(
    r"\textbf{Eval} & \textbf{Codec} "
    r"& \multicolumn{3}{c|}{\textbf{Token metrics}} "
    r"& \multicolumn{5}{c|}{\textbf{Audio metrics}} "
    r"& \multicolumn{3}{c|}{\textbf{Onset metrics}} "
    r"& \\"
)
# Partial rules only under the three metric groups (columns 3-13).
lines.append(r"\cline{3-5}\cline{6-10}\cline{11-13}")

# Header row 2: per-metric titles (arrows inline) including a single FAD entry
# The superscripts a-d refer to the footnotes emitted after the tabular; the
# column order here must match the cell order produced by row_for.
lines.append(
    r"\textbf{Setting} & \textbf{}"
    r" & " + H(r"\textbf{\textit{NLL}$^{\mathrm{a}}$\,\,$\downarrow$}") +
    r" & " + H(r"\textbf{\textit{PPL}$^{\mathrm{a}}$\,\,$\downarrow$}") +
    r" & " + H(r"\textbf{\textit{Acc(\%)}$^{\mathrm{a}}$\,\,$\uparrow$}") +
    r" & " + H(r"\textbf{\textit{RMSE}$^{\mathrm{b}}$\,\,$\downarrow$}") +
    r" & " + H(r"\textbf{\textit{MAE}$^{\mathrm{b}}$\,\,$\downarrow$}") +
    r" & " + H(r"\textbf{\textit{MR-STFT SC}$^{\mathrm{b}}$\,\,$\downarrow$}") +
    r" & " + H(r"\textbf{\textit{Env RMS corr}$^{\mathrm{b}}$\,\,$\uparrow$}") +
    r" & " + H(r"\textbf{\textit{TTER (dB) MAE}$^{\mathrm{b}}$\,\,$\downarrow$}") +
    r" & " + H(r"\textbf{\textit{P(\%)}$^{\mathrm{c}}$\,\,$\uparrow$}") +
    r" & " + H(r"\textbf{\textit{R(\%)}$^{\mathrm{c}}$\,\,$\uparrow$}") +
    r" & " + H(r"\textbf{\textit{F1(\%)}$^{\mathrm{c}}$\,\,$\uparrow$}") +
    r" & " + H(r"\textbf{\textit{FAD}$^{\mathrm{d}}$\,\,$\downarrow$}") +
    r" \\"
)
lines.append(r"\hline")

# Body + separator between consecutive run blocks (OneKit / AllKits)
# One row per (run, system), each followed by a rule. A second rule is added
# after every run block except the last, so the separator also appears when
# RUNS holds more than two entries.
for run_idx, (run_label, _p) in enumerate(RUNS):
    for sys_name in SYSTEMS:
        lines.append(" " + " & ".join(row_for(run_label, sys_name)) + r" \\")
        lines.append(r"\hline")
    if run_idx < len(RUNS) - 1:
        lines.append(r"\hline")  # extra separator line between blocks

# Close the tabular / resizebox, then emit the footnotes a-d referenced by the
# header superscripts. Footnote d is completed with one fadtk_note per run.
lines.append(r"\end{tabular}%")
lines.append(r"}")  # resizebox
lines.append(r"\vspace{2pt}")
lines.append(
    r"\parbox{\linewidth}{\footnotesize "
    r"$^{\mathrm{a}}$PAD ignored; mean$\pm$std over windows. "
    r"$^{\mathrm{b}}$Tokens decoded then resampled to 32\,kHz; mean$\pm$std over windows. "
    r"$^{\mathrm{c}}$Onset metrics match predicted-audio onsets to grid-derived GT onsets within 50\,ms (GT velocity$\ge$0.30). "
    r"$^{\mathrm{d}}$fadtk embedding: CLAP-LAION-music; per-run variant: "
    + "; ".join([fadtk_note(rl) for rl, _ in RUNS])
    + r".}"
)
lines.append(rf"\label{{{label}}}")
lines.append(r"\end{center}")
lines.append(rf"\end{{{TABLE_ENV}}}")

# Show the result as a fenced code block so it can be copied verbatim.
latex = "\n".join(lines)
display(Markdown("```latex\n" + latex + "\n```"))


```latex
\begin{table*}[htbp]
\caption{Evaluation on one-kit and all-kits test sets (small models).}
\begin{center}
\renewcommand{\arraystretch}{1}
\setlength{\tabcolsep}{1pt}
\resizebox{\textwidth}{!}{%
\begin{tabular}{|p{0.060\linewidth}|p{0.060\linewidth}|p{0.060\linewidth}|p{0.060\linewidth}|p{0.060\linewidth}|p{0.07\linewidth}|p{0.07\linewidth}|p{0.11\linewidth}|p{0.11\linewidth}|p{0.12\linewidth}|p{0.067\linewidth}|p{0.067\linewidth}|p{0.067\linewidth}|p{0.064\linewidth}|}
\hline
\textbf{Eval} & \textbf{Codec} & \multicolumn{3}{c|}{\textbf{Token metrics}} & \multicolumn{5}{c|}{\textbf{Audio metrics}} & \multicolumn{3}{c|}{\textbf{Onset metrics}} & \\
\cline{3-5}\cline{6-10}\cline{11-13}
\textbf{Setting} & \textbf{} & \shortstack[c]{\scriptsize \textbf{\textit{NLL}$^{\mathrm{a}}$\,\,$\downarrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{PPL}$^{\mathrm{a}}$\,\,$\downarrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{Acc(\%)}$^{\mathrm{a}}$\,\,$\uparrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{RMSE}$^{\mathrm{b}}$\,\,$\downarrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{MAE}$^{\mathrm{b}}$\,\,$\downarrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{MR-STFT SC}$^{\mathrm{b}}$\,\,$\downarrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{Env RMS corr}$^{\mathrm{b}}$\,\,$\uparrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{TTER (dB) MAE}$^{\mathrm{b}}$\,\,$\downarrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{P(\%)}$^{\mathrm{c}}$\,\,$\uparrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{R(\%)}$^{\mathrm{c}}$\,\,$\uparrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{F1(\%)}$^{\mathrm{c}}$\,\,$\uparrow$}} & \shortstack[c]{\scriptsize \textbf{\textit{FAD}$^{\mathrm{d}}$\,\,$\downarrow$}} \\
\hline
 OneKit & encodec & \cellcolor{green!15} $2.142\pm0.681$ & \cellcolor{green!15} $10.8\pm8.3$ & \cellcolor{green!15} $42.7\pm13.7$ & $0.0201\pm0.0106$ & $0.0100\pm0.0058$ & \cellcolor{green!15} $0.842\pm0.160$ & \cellcolor{green!15} $0.690\pm0.228$ & \cellcolor{green!15} $1.29\pm1.22$ & \cellcolor{green!15} $78.3\pm15.5$ & \cellcolor{green!15} $68.4\pm17.8$ & \cellcolor{green!15} $71.0\pm13.3$ & \cellcolor{green!15} $0.281$ \\
\hline
 OneKit & xcodec & $4.422\pm0.590$ & $102.1\pm81.8$ & $11.9\pm3.9$ & $0.0305\pm0.0176$ & $0.0161\pm0.0110$ & $1.357\pm0.702$ & $0.552\pm0.242$ & $1.92\pm1.60$ & $76.6\pm17.6$ & $64.8\pm16.5$ & $67.8\pm12.4$ & $0.350$ \\
\hline
 OneKit & dac & $6.265\pm0.450$ & $563.5\pm236.7$ & $3.8\pm6.0$ & \cellcolor{green!15} $0.0184\pm0.0094$ & \cellcolor{green!15} $0.0095\pm0.0054$ & $0.982\pm0.085$ & $0.580\pm0.235$ & $1.44\pm1.24$ & $69.8\pm15.1$ & $65.0\pm18.1$ & $65.0\pm12.3$ & $0.545$ \\
\hline
\hline
 AllKits & encodec & \cellcolor{green!15} $2.153\pm0.743$ & \cellcolor{green!15} $11.6\pm11.0$ & \cellcolor{green!15} $43.4\pm14.7$ & $0.0200\pm0.0118$ & $0.0103\pm0.0070$ & \cellcolor{green!15} $0.827\pm0.171$ & \cellcolor{green!15} $0.710\pm0.220$ & \cellcolor{green!15} $1.47\pm1.30$ & $78.2\pm15.8$ & \cellcolor{green!15} $68.2\pm18.2$ & \cellcolor{green!15} $70.6\pm13.4$ & \cellcolor{green!15} $0.193$ \\
\hline
 AllKits & xcodec & $4.429\pm0.617$ & $104.9\pm90.0$ & $12.5\pm4.3$ & $0.0336\pm0.0203$ & $0.0176\pm0.0125$ & $1.669\pm0.925$ & $0.568\pm0.252$ & $2.22\pm1.89$ & \cellcolor{green!15} $78.9\pm16.5$ & $64.4\pm16.7$ & $68.8\pm12.7$ & $0.277$ \\
\hline
 AllKits & dac & $6.153\pm0.540$ & $521.8\pm304.5$ & $4.7\pm6.8$ & \cellcolor{green!15} $0.0190\pm0.0109$ & \cellcolor{green!15} $0.0099\pm0.0065$ & $1.034\pm0.179$ & $0.602\pm0.265$ & $1.99\pm1.66$ & $75.7\pm15.6$ & $68.1\pm17.5$ & $69.3\pm12.2$ & $0.405$ \\
\hline
\end{tabular}%
}
\vspace{2pt}
\parbox{\linewidth}{\footnotesize $^{\mathrm{a}}$PAD ignored; mean$\pm$std over windows. $^{\mathrm{b}}$Tokens decoded then resampled to 32\,kHz; mean$\pm$std over windows. $^{\mathrm{c}}$Onset metrics match predicted-audio onsets to grid-derived GT onsets within 50\,ms (GT velocity$\ge$0.30). $^{\mathrm{d}}$fadtk embedding: CLAP-LAION-music; per-run variant: OneKit: FAD$\infty$, clap-laion-music, n=1748, clip~med=2.286s; AllKits: FAD$\infty$, clap-laion-music, n=68180, clip~med=2.286s.}
\label{tab:eval_small_models}
\end{center}
\end{table*}
```

In [3]:
from pathlib import Path
import json
import numpy as np
import pandas as pd
from IPython.display import Markdown, display

# Number of best and of worst kits to show (the table lists 2*K kits).
K = 3
# Directory holding summary.json and items.csv for the all-kits evaluation.
OUT_DIR = Path("artifacts/eval/small_all_kits")
# Optional per-kit FAD scores; the FAD column shows dashes if the file is absent.
FAD_PER_KIT_CSV = Path("artifacts/eval/fadtk_per_kit_small_all_kits/fadtk_per_kit.csv")

TABLE_ENV = "table*"
HICOLOR = r"green!15"   # \cellcolor argument used to highlight the best cell
DASH = r"---"           # placeholder printed for missing / non-finite values

# Compact col widths
# 15 columns: Eval, Kit, Codec, 3 token metrics, 5 audio metrics, 3 onset metrics, FAD.
COLSPEC = (
    r"|p{0.055\linewidth}|p{0.165\linewidth}|p{0.055\linewidth}|"
    r"p{0.050\linewidth}|p{0.050\linewidth}|p{0.055\linewidth}|"
    r"p{0.060\linewidth}|p{0.060\linewidth}|p{0.070\linewidth}|p{0.070\linewidth}|p{0.070\linewidth}|"
    r"p{0.052\linewidth}|p{0.052\linewidth}|p{0.052\linewidth}|"
    r"p{0.055\linewidth}|"
)

def load_eval(out_dir: Path):
    """Load one evaluation run from ``out_dir``.

    Returns a ``(summary, items)`` pair: the parsed ``summary.json`` (read as
    UTF-8) and ``items.csv`` as a pandas DataFrame.
    """
    summary_path = out_dir / "summary.json"
    items_path = out_dir / "items.csv"
    parsed_summary = json.loads(summary_path.read_text(encoding="utf-8"))
    return parsed_summary, pd.read_csv(items_path)

def base_systems(summary: dict) -> list[str]:
    """Return the system names in ``summary["systems"]``, skipping baselines.

    Names ending in ``_oracle`` or ``_random`` are dropped; the remaining
    names keep their original order.
    """
    baseline_suffixes = ("_oracle", "_random")
    systems = summary["systems"]
    return [name for name in systems.keys() if not name.endswith(baseline_suffixes)]

def codec_label(sys_key: str) -> str:
    """Map a system key to the short codec name shown in the table.

    The key is lower-cased; if it starts with ``encodec``, ``xcodec`` or
    ``dac`` that prefix is returned, otherwise the lower-cased key itself.
    """
    lowered = str(sys_key).lower()
    for family in ("encodec", "xcodec", "dac"):
        if lowered.startswith(family):
            return family
    return lowered

def H(tex: str) -> str:
    """Wrap ``tex`` in a centred, script-size ``\\shortstack`` for a header cell."""
    opening = r"\shortstack[c]{\scriptsize "
    return opening + format(tex) + "}"

def mean_only(mu: float, digits: int, *, pct: bool=False) -> str:
    """Format a mean value as a LaTeX math cell with ``digits`` decimals.

    With ``pct=True`` the value is multiplied by 100 first. A non-finite
    ``mu`` is rendered as ``DASH``.
    """
    if not np.isfinite(mu):
        return DASH
    shown = mu * 100.0 if pct else mu
    return rf"${shown:.{digits}f}$"

def fnum(v: float, digits: int = 3) -> str:
    """Format ``v`` as a LaTeX math cell with ``digits`` decimals.

    A non-finite ``v`` is rendered as ``DASH``.
    """
    if np.isfinite(v):
        return rf"${float(v):.{digits}f}$"
    return DASH

# Load the all-kits evaluation and keep only the non-baseline systems.
summary, items = load_eval(OUT_DIR)
bases = base_systems(summary)

if "token_nll" not in items.columns:
    raise RuntimeError("items.csv must contain token_nll")

# Perplexity is derived per item as exp(NLL). The per-kit value shown later is
# the mean of these per-item perplexities, which is not the same as
# exp(mean NLL).
items = items.copy()
items["token_ppl"] = np.exp(items["token_nll"].astype(float))

# Kits are ranked by the EnCodec system's mean token accuracy.
encodec_sys = next((b for b in bases if str(b).lower().startswith("encodec")), None)
if encodec_sys is None:
    raise RuntimeError(f"No EnCodec system in {bases}")

enc_kit_acc = (
    items.loc[items["system"] == encodec_sys]
    .groupby("kit")["token_acc"].mean()
    .sort_values(ascending=True)
)
# Ascending order: the head holds the worst kits, the tail the best (reversed
# so the best kit comes first).
# NOTE(review): with fewer than 2*K kits the two lists overlap and the same
# kit would be listed in both blocks.
worst_kits = enc_kit_acc.head(K).index.tolist()
top_kits   = enc_kit_acc.tail(K).index.tolist()[::-1]
selected_kits = top_kits + worst_kits

# Table column specifications, one dict per metric column:
#   key    - column name in items.csv
#   disp   - LaTeX header text (superscripts refer to the table footnotes)
#   digits - decimals printed
#   dir    - "min" if lower is better, "max" if higher is better
#   pct    - multiply by 100 before printing
METRICS = [
    dict(key="token_nll",        disp=r"\textbf{NLL}$^{\mathrm{a}}$\,\,$\downarrow$",  digits=2, dir="min", pct=False),
    dict(key="token_ppl",        disp=r"\textbf{PPL}$^{\mathrm{a}}$\,\,$\downarrow$",  digits=1, dir="min", pct=False),
    dict(key="token_acc",        disp=r"\textbf{Acc(\%)}$^{\mathrm{a}}$\,\,$\uparrow$",digits=1, dir="max", pct=True),

    dict(key="rmse",             disp=r"\textbf{RMSE}$^{\mathrm{b}}$\,\,$\downarrow$", digits=4, dir="min", pct=False),
    dict(key="mae",              disp=r"\textbf{MAE}$^{\mathrm{b}}$\,\,$\downarrow$",  digits=4, dir="min", pct=False),
    dict(key="mr_stft_sc",       disp=r"\textbf{MRSTFT}$^{\mathrm{b}}$\,\,$\downarrow$",digits=3, dir="min", pct=False),
    dict(key="env_rms_corr",     disp=r"\textbf{Env}$^{\mathrm{b}}$\,\,$\uparrow$",    digits=2, dir="max", pct=False),
    dict(key="tter_db_mae",      disp=r"\textbf{TTER}$^{\mathrm{b}}$\,\,$\downarrow$", digits=2, dir="min", pct=False),

    dict(key="onset_precision",  disp=r"\textbf{P(\%)}$^{\mathrm{c}}$\,\,$\uparrow$",  digits=0, dir="max", pct=True),
    dict(key="onset_recall",     disp=r"\textbf{R(\%)}$^{\mathrm{c}}$\,\,$\uparrow$",  digits=0, dir="max", pct=True),
    dict(key="onset_f1",         disp=r"\textbf{F1(\%)}$^{\mathrm{c}}$\,\,$\uparrow$", digits=0, dir="max", pct=True),
]

# Keep only the metrics that actually exist as columns in items.csv.
# NOTE(review): the header row built further down indexes METRICS[0]..METRICS[10]
# and uses fixed 3/5/3 column groups, so dropping any metric here leads to an
# IndexError (or a misaligned table) later - all eleven columns are effectively
# required.
metric_keys = [m["key"] for m in METRICS if m["key"] in items.columns]
METRICS = [m for m in METRICS if m["key"] in items.columns]

# Per (kit, system) mean of every metric, restricted to the selected kits and
# the non-baseline systems.
sub = items.loc[items["kit"].isin(selected_kits) & items["system"].isin(bases)].copy()
g_mean = sub.groupby(["kit", "system"])[metric_keys].mean()

# Optional per-kit FAD scores keyed by (kit, system). The map stays empty when
# the CSV does not exist.
fad_map = {}
if FAD_PER_KIT_CSV.is_file():
    fad_df = pd.read_csv(FAD_PER_KIT_CSV)
    # Force string keys so lookups with (kit, system) name pairs match.
    fad_df["kit"] = fad_df["kit"].astype(str)
    fad_df["system"] = fad_df["system"].astype(str)
    fad_map = {(r["kit"], r["system"]): float(r["fad"]) for _, r in fad_df.iterrows()}
# True when at least one selected (kit, system) pair has a finite FAD value.
HAS_FAD = any((kit, sys) in fad_map and np.isfinite(fad_map[(kit, sys)]) for kit in selected_kits for sys in bases)

# For every selected kit and metric, remember which system has the best finite
# mean (smallest for dir == "min", largest otherwise) so it can be highlighted.
best_cells = set()
for kit in selected_kits:
    for m in METRICS:
        vals = []
        for sys in bases:
            if (kit, sys) in g_mean.index:
                v = float(g_mean.loc[(kit, sys), m["key"]])
                if np.isfinite(v):
                    vals.append((sys, v))
        if vals:
            # min()/max() return the first extreme element, so ties go to the
            # system listed first in `bases`.
            best_sys = min(vals, key=lambda t: t[1])[0] if m["dir"] == "min" else max(vals, key=lambda t: t[1])[0]
            best_cells.add((kit, best_sys, m["key"]))
# The FAD column (lower is better) is tracked under the pseudo key "__fad__".
if HAS_FAD:
    for kit in selected_kits:
        vals = [(sys, fad_map.get((kit, sys), np.nan)) for sys in bases]
        vals = [(sys, float(v)) for (sys, v) in vals if np.isfinite(v)]
        if vals:
            best_cells.add((kit, min(vals, key=lambda t: t[1])[0], "__fad__"))

def metric_cell(kit: str, sys: str, m: dict) -> str:
    """Render the cell for metric spec ``m`` of system ``sys`` on ``kit``.

    The per-(kit, system) mean is taken from ``g_mean`` (NaN when the pair is
    absent), formatted according to the spec, and wrapped in ``\\cellcolor``
    when the cell is registered in ``best_cells``. Placeholder cells
    (``DASH``) are never highlighted.
    """
    if (kit, sys) in g_mean.index:
        mu = float(g_mean.loc[(kit, sys), m["key"]])
    else:
        mu = np.nan
    cell = mean_only(mu, m["digits"], pct=m["pct"])
    if cell == DASH:
        return cell
    if (kit, sys, m["key"]) not in best_cells:
        return cell
    return rf"\cellcolor{{{HICOLOR}}} {cell}"

def fad_cell(kit: str, sys: str) -> str:
    """Render the per-kit FAD cell of system ``sys`` on ``kit``.

    The value comes from ``fad_map`` (NaN when the pair is absent), is printed
    with three decimals, and is highlighted when it is the kit's best FAD.
    Placeholder cells (``DASH``) are never highlighted.
    """
    score = fad_map.get((kit, sys), np.nan)
    cell = fnum(score, digits=3)
    if cell == DASH:
        return cell
    if (kit, sys, "__fad__") not in best_cells:
        return cell
    return rf"\cellcolor{{{HICOLOR}}} {cell}"

# Caption embeds K so it stays correct if the number of kits changes.
caption = rf"Top-{K} and worst-{K} kits by EnCodec token accuracy (all-kits). Best values per kit highlighted."
label = "tab:top_worst_kits_allkits_compact"

# The table is accumulated line by line in `lines` and joined at the end.
lines = []
lines.append(rf"\begin{{{TABLE_ENV}}}[htbp]")
lines.append(rf"\caption{{{caption}}}")
lines.append(r"\begin{center}")
lines.append(r"\renewcommand{\arraystretch}{0.95}")
lines.append(r"\setlength{\tabcolsep}{0.6pt}")
lines.append(r"\scriptsize")
lines.append(r"\resizebox{\textwidth}{!}{%")
lines.append(r"\begin{tabular}{" + COLSPEC + r"}")
lines.append(r"\hline")

# Header row 1: three label columns, then the Token (3), Audio (5) and
# Onset (3) groups, and an empty cell above the FAD column (15 total,
# matching COLSPEC).
lines.append(
    r"\textbf{Eval} & \textbf{Kit} & \textbf{Codec} "
    r"& \multicolumn{3}{c|}{\textbf{Token}} "
    r"& \multicolumn{5}{c|}{\textbf{Audio}} "
    r"& \multicolumn{3}{c|}{\textbf{Onset}} "
    r"& \\"
)
# Partial rules only under the three metric groups (columns 4-14).
lines.append(r"\cline{4-6}\cline{7-11}\cline{12-14}")

# Header row 2: one title per metric, taken from METRICS by position.
# NOTE(review): this assumes METRICS still holds all eleven entries; it raises
# IndexError if any metric column was missing from items.csv and filtered out.
lines.append(
    r"\textbf{Set} & \textbf{} & \textbf{}"
    r" & " + H(METRICS[0]["disp"]) +
    r" & " + H(METRICS[1]["disp"]) +
    r" & " + H(METRICS[2]["disp"]) +
    r" & " + H(METRICS[3]["disp"]) +
    r" & " + H(METRICS[4]["disp"]) +
    r" & " + H(METRICS[5]["disp"]) +
    r" & " + H(METRICS[6]["disp"]) +
    r" & " + H(METRICS[7]["disp"]) +
    r" & " + H(METRICS[8]["disp"]) +
    r" & " + H(METRICS[9]["disp"]) +
    r" & " + H(METRICS[10]["disp"]) +
    r" & " + H(r"\textbf{FAD}$^{\mathrm{d}}$\,\,$\downarrow$") +
    r" \\"
)
lines.append(r"\hline")

def emit_block(block_name, kits):
    """Append the table rows for every kit in *kits* to the global ``lines``.

    Each kit produces one row per entry of ``bases``. Only the first row of a
    kit carries the bold *block_name* and the kit name; the following rows
    leave those two cells empty. Every row holds the codec label, one cell per
    entry of ``METRICS`` and a final FAD cell (``DASH`` when ``HAS_FAD`` is
    false). A horizontal rule is appended after each kit.
    """
    for kit in kits:
        for position, system in enumerate(bases):
            leading = position == 0
            cells = [
                rf"\textbf{{{block_name}}}" if leading else "",
                kit if leading else "",
                codec_label(system),
            ]
            cells.extend(metric_cell(kit, system, m) for m in METRICS)
            cells.append(fad_cell(kit, system) if HAS_FAD else DASH)
            lines.append(" " + " & ".join(cells) + r" \\")
        lines.append(r"\hline")

# Body: "Top" kits, a double rule, then "Worst" kits.
emit_block("Top", top_kits)
lines.append(r"\hline\hline")
emit_block("Worst", worst_kits)

# Footnotes a-d referenced by the superscripts in the header row.
footnotes = "".join((
    r"\parbox{\linewidth}{\scriptsize ",
    r"$^{\mathrm{a}}$PAD ignored. ",
    r"$^{\mathrm{b}}$Decoded audio at 32\,kHz. ",
    r"$^{\mathrm{c}}$Onsets: pred-audio vs grid GT within 50\,ms (GT vel$>0.30$). ",
    r"$^{\mathrm{d}}$Per-kit FAD from \texttt{fadtk} (CLAP-LAION-music).}",
))

# Close the tabular / resizebox, add footnotes and label, close the environments.
lines.extend([
    r"\end{tabular}%",
    r"}",  # resizebox
    r"\vspace{1pt}",
    footnotes,
    rf"\label{{{label}}}",
    r"\end{center}",
    rf"\end{{{TABLE_ENV}}}",
])

# Show the LaTeX source as a fenced block so it can be copied from the notebook.
latex_source = "\n".join(lines)
display(Markdown("```latex\n" + latex_source + "\n```"))


```latex
\begin{table*}[htbp]
\caption{Top-3 and worst-3 kits by EnCodec token accuracy (all-kits). Best values per kit highlighted.}
\begin{center}
\renewcommand{\arraystretch}{0.95}
\setlength{\tabcolsep}{0.6pt}
\scriptsize
\resizebox{\textwidth}{!}{%
\begin{tabular}{|p{0.055\linewidth}|p{0.165\linewidth}|p{0.055\linewidth}|p{0.050\linewidth}|p{0.050\linewidth}|p{0.055\linewidth}|p{0.060\linewidth}|p{0.060\linewidth}|p{0.070\linewidth}|p{0.070\linewidth}|p{0.070\linewidth}|p{0.052\linewidth}|p{0.052\linewidth}|p{0.052\linewidth}|p{0.055\linewidth}|}
\hline
\textbf{Eval} & \textbf{Kit} & \textbf{Codec} & \multicolumn{3}{c|}{\textbf{Token}} & \multicolumn{5}{c|}{\textbf{Audio}} & \multicolumn{3}{c|}{\textbf{Onset}} & \\
\cline{4-6}\cline{7-11}\cline{12-14}
\textbf{Set} & \textbf{} & \textbf{} & \shortstack[c]{\scriptsize \textbf{NLL}$^{\mathrm{a}}$\,\,$\downarrow$} & \shortstack[c]{\scriptsize \textbf{PPL}$^{\mathrm{a}}$\,\,$\downarrow$} & \shortstack[c]{\scriptsize \textbf{Acc(\%)}$^{\mathrm{a}}$\,\,$\uparrow$} & \shortstack[c]{\scriptsize \textbf{RMSE}$^{\mathrm{b}}$\,\,$\downarrow$} & \shortstack[c]{\scriptsize \textbf{MAE}$^{\mathrm{b}}$\,\,$\downarrow$} & \shortstack[c]{\scriptsize \textbf{MRSTFT}$^{\mathrm{b}}$\,\,$\downarrow$} & \shortstack[c]{\scriptsize \textbf{Env}$^{\mathrm{b}}$\,\,$\uparrow$} & \shortstack[c]{\scriptsize \textbf{TTER}$^{\mathrm{b}}$\,\,$\downarrow$} & \shortstack[c]{\scriptsize \textbf{P(\%)}$^{\mathrm{c}}$\,\,$\uparrow$} & \shortstack[c]{\scriptsize \textbf{R(\%)}$^{\mathrm{c}}$\,\,$\uparrow$} & \shortstack[c]{\scriptsize \textbf{F1(\%)}$^{\mathrm{c}}$\,\,$\uparrow$} & \shortstack[c]{\scriptsize \textbf{FAD}$^{\mathrm{d}}$\,\,$\downarrow$} \\
\hline
 \textbf{Top} & Shuffle (Blues) & dac & $6.17$ & $521.7$ & $4.8$ & \cellcolor{green!15} $0.0131$ & \cellcolor{green!15} $0.0063$ & $1.080$ & $0.68$ & $2.46$ & $78$ & \cellcolor{green!15} $67$ & \cellcolor{green!15} $70$ & $0.449$ \\
  &  & encodec & \cellcolor{green!15} $1.83$ & \cellcolor{green!15} $7.9$ & \cellcolor{green!15} $50.5$ & $0.0134$ & $0.0066$ & \cellcolor{green!15} $0.826$ & \cellcolor{green!15} $0.74$ & \cellcolor{green!15} $1.40$ & \cellcolor{green!15} $83$ & $58$ & $66$ & \cellcolor{green!15} $0.332$ \\
  &  & xcodec & $4.49$ & $109.8$ & $12.1$ & $0.0252$ & $0.0139$ & $2.072$ & $0.47$ & $2.10$ & $76$ & $66$ & $68$ & $0.407$ \\
\hline
 \textbf{Top} & Warmer Funk & dac & $5.89$ & $429.1$ & $7.6$ & \cellcolor{green!15} $0.0150$ & \cellcolor{green!15} $0.0061$ & $0.996$ & $0.65$ & $1.86$ & $75$ & $70$ & $71$ & $0.408$ \\
  &  & encodec & \cellcolor{green!15} $1.89$ & \cellcolor{green!15} $9.0$ & \cellcolor{green!15} $50.2$ & $0.0164$ & $0.0068$ & \cellcolor{green!15} $0.847$ & \cellcolor{green!15} $0.70$ & \cellcolor{green!15} $1.85$ & $77$ & \cellcolor{green!15} $73$ & \cellcolor{green!15} $73$ & \cellcolor{green!15} $0.253$ \\
  &  & xcodec & $4.44$ & $110.4$ & $12.2$ & $0.0267$ & $0.0107$ & $1.566$ & $0.64$ & $1.96$ & \cellcolor{green!15} $79$ & $62$ & $68$ & $0.310$ \\
\hline
 \textbf{Top} & 60s Rock & dac & $5.95$ & $444.1$ & $7.4$ & \cellcolor{green!15} $0.0139$ & \cellcolor{green!15} $0.0065$ & $0.973$ & $0.65$ & $1.83$ & $77$ & $67$ & $69$ & $0.346$ \\
  &  & encodec & \cellcolor{green!15} $1.84$ & \cellcolor{green!15} $8.1$ & \cellcolor{green!15} $49.9$ & $0.0153$ & $0.0071$ & \cellcolor{green!15} $0.844$ & \cellcolor{green!15} $0.72$ & \cellcolor{green!15} $1.61$ & $80$ & \cellcolor{green!15} $68$ & \cellcolor{green!15} $72$ & \cellcolor{green!15} $0.266$ \\
  &  & xcodec & $4.38$ & $103.0$ & $13.1$ & $0.0222$ & $0.0100$ & $1.274$ & $0.65$ & $1.83$ & \cellcolor{green!15} $81$ & $66$ & $71$ & $0.329$ \\
\hline
\hline\hline
 \textbf{Worst} & Classic Rock & dac & $6.33$ & $610.1$ & $3.3$ & $0.0234$ & \cellcolor{green!15} $0.0132$ & $1.023$ & $0.62$ & $2.64$ & $72$ & \cellcolor{green!15} $66$ & $66$ & $0.499$ \\
  &  & encodec & \cellcolor{green!15} $2.63$ & \cellcolor{green!15} $19.3$ & \cellcolor{green!15} $34.5$ & \cellcolor{green!15} $0.0226$ & $0.0132$ & \cellcolor{green!15} $0.788$ & \cellcolor{green!15} $0.72$ & \cellcolor{green!15} $1.18$ & $77$ & $65$ & \cellcolor{green!15} $68$ & \cellcolor{green!15} $0.197$ \\
  &  & xcodec & $4.52$ & $112.8$ & $12.0$ & $0.0450$ & $0.0266$ & $2.177$ & $0.40$ & $2.82$ & \cellcolor{green!15} $77$ & $61$ & $66$ & $0.455$ \\
\hline
 \textbf{Worst} & Arena Stage & dac & $6.32$ & $593.7$ & $3.1$ & $0.0265$ & $0.0159$ & $1.153$ & $0.50$ & $2.65$ & $75$ & $68$ & $69$ & $0.489$ \\
  &  & encodec & \cellcolor{green!15} $2.59$ & \cellcolor{green!15} $18.3$ & \cellcolor{green!15} $34.5$ & \cellcolor{green!15} $0.0241$ & \cellcolor{green!15} $0.0146$ & \cellcolor{green!15} $0.780$ & \cellcolor{green!15} $0.70$ & \cellcolor{green!15} $1.63$ & $78$ & \cellcolor{green!15} $69$ & \cellcolor{green!15} $70$ & \cellcolor{green!15} $0.155$ \\
  &  & xcodec & $4.44$ & $107.0$ & $12.2$ & $0.0404$ & $0.0247$ & $1.604$ & $0.49$ & $2.67$ & \cellcolor{green!15} $79$ & $66$ & $69$ & $0.334$ \\
\hline
 \textbf{Worst} & Ele-Drum & dac & $6.05$ & $483.3$ & $5.9$ & \cellcolor{green!15} $0.0250$ & \cellcolor{green!15} $0.0143$ & $0.986$ & $0.70$ & $1.62$ & $70$ & $67$ & $66$ & $0.325$ \\
  &  & encodec & \cellcolor{green!15} $2.46$ & \cellcolor{green!15} $14.7$ & \cellcolor{green!15} $35.9$ & $0.0262$ & $0.0149$ & \cellcolor{green!15} $0.834$ & \cellcolor{green!15} $0.75$ & \cellcolor{green!15} $1.59$ & $74$ & $66$ & $67$ & $0.187$ \\
  &  & xcodec & $4.27$ & $89.8$ & $12.9$ & $0.0581$ & $0.0355$ & $2.486$ & $0.54$ & $1.88$ & \cellcolor{green!15} $76$ & \cellcolor{green!15} $68$ & \cellcolor{green!15} $70$ & \cellcolor{green!15} $0.155$ \\
\hline
\end{tabular}%
}
\vspace{1pt}
\parbox{\linewidth}{\scriptsize $^{\mathrm{a}}$PAD ignored. $^{\mathrm{b}}$Decoded audio at 32\,kHz. $^{\mathrm{c}}$Onsets: pred-audio vs grid GT within 50\,ms (GT vel$>0.30$). $^{\mathrm{d}}$Per-kit FAD from \texttt{fadtk} (CLAP-LAION-music).}
\label{tab:top_worst_kits_allkits_compact}
\end{center}
\end{table*}
```

# Eval metrics viz

In [None]:
# One-example onset-metrics debug (GT = cached GRID onsets w/ filtering) + drum grid plot + audio playback
# GT filtering: channels listed in GT_EXCLUDE (currently empty, so this cell requests no channel exclusion) + low-velocity hits; preds from audio onsets (madmom if available); ±50ms eval.

from pathlib import Path
import os, json, hashlib
import numpy as np
import matplotlib.pyplot as plt
import wave
import importlib

from IPython.display import Audio, display

# Select the onset backend unless the environment already sets one.
# Set before importing/reloading the eval module below.
os.environ.setdefault("MIDIGROOVE_ONSET_BACKEND", "madmom")  # "madmom" | "native"

from midigroove_poc import eval as mg_eval
# Reload so edits to the eval module are picked up without restarting the kernel.
importlib.reload(mg_eval)

# Run directory holding "pred/<system>/<item_id>.wav" and "ref/<item_id>.wav".
PRED_RUN_DIR = Path("artifacts/pred/small_one_kit")
# Cache directory whose test manifest is scanned for the matching npz.
CACHE_DIR = Path("cache/encodec_acoustic")
# Systems compared; each needs a saved prediction for the chosen item.
SYSTEMS = ["encodec", "xcodec", "dac"]

# Onset matching tolerance in milliseconds (converted to samples further down).
TOL_MS = 50.0
# None = play/plot the whole clip; otherwise a cap in seconds.
MAX_SECONDS_TO_SHOW = None

# Keyword arguments forwarded to mg_eval._onsets_from_audio.
ONSET_KW = dict(min_separation_s=0.05, backtrack_ms=0.0, refine_ms=12.0, rms_gate_db=35.0)

# Passed as vel_thresh / exclude_channels to mg_eval._onsets_from_grid_npz.
GT_VEL_THRESH = 0.30
GT_EXCLUDE = []

def read_wav_mono_float32(path: Path):
    """Read a 16-bit PCM WAV file as mono float32.

    Returns ``(samples, sample_rate)``. Samples are the int16 values divided
    by 32767.0; multi-channel files are averaged across channels.

    Raises:
        RuntimeError: if the file's sample width is not 2 bytes.
    """
    with wave.open(str(path), "rb") as reader:
        rate = reader.getframerate()
        frame_count = reader.getnframes()
        width = reader.getsampwidth()
        channels = reader.getnchannels()
        raw = reader.readframes(frame_count)
    if width != 2:
        raise RuntimeError(f"Expected 16-bit PCM wav, got sampwidth={width} at {path}")
    pcm = np.frombuffer(raw, dtype=np.int16)
    samples = pcm.astype(np.float32) / 32767.0
    if channels > 1:
        # Frames are interleaved: one row per frame, one column per channel.
        samples = samples.reshape(-1, channels).mean(axis=1)
    return samples, int(rate)

def stable_key_str(audio_path: str, midi_path: str, sr: int, start_sample: int, window_samples: int) -> str:
    """Return a 16-hex-character key for an (audio, midi, sr, window) tuple.

    The five fields are joined with ``|``, UTF-8 encoded and hashed with
    SHA-1; the first 16 characters of the hex digest are returned.
    """
    fields = (audio_path, midi_path, sr, start_sample, window_samples)
    payload = "|".join(f"{field}" for field in fields)
    digest = hashlib.sha1(payload.encode("utf-8")).hexdigest()
    return digest[:16]

def vlines(ax, samps, *, sr, color, alpha, lw, label=None):
    """Draw vertical lines on *ax* at sample positions *samps*.

    Positions are converted to seconds by dividing by *sr*. Lines span the
    axis' current y-limits. Nothing is drawn when *samps* is ``None`` or empty.
    """
    if samps is None:
        return
    if len(samps) == 0:
        return
    times = np.asarray(samps, dtype=np.float64) / float(sr)
    bottom = ax.get_ylim()[0]
    top = ax.get_ylim()[1]
    ax.vlines(times, ymin=bottom, ymax=top, color=color, alpha=alpha, linewidth=lw, label=label)

def shade_tol(ax, ref_samps, *, sr, tol_samps, color="green", alpha=0.06):
    """Shade a +/- *tol_samps* band around each reference onset on *ax*.

    *ref_samps* are sample positions (cast to int64); band edges are converted
    to seconds by dividing by *sr*.
    """
    centers = np.asarray(ref_samps, dtype=np.int64)
    for center in centers:
        left = (center - tol_samps) / float(sr)
        right = (center + tol_samps) / float(sr)
        ax.axvspan(left, right, color=color, alpha=alpha, linewidth=0)

def clip_to_seconds(y, sr, sec):
    """Return the first *sec* seconds of *y* at sample rate *sr*.

    ``sec=None`` returns *y* unchanged. The sample count is rounded to the
    nearest integer and capped at ``y.size``.
    """
    if sec is not None:
        wanted = round(float(sec) * float(sr))
        keep = int(min(y.size, wanted))
        return y[:keep]
    return y

# --- pick a saved item present for all systems ---
# Collect, per system, the stems of the saved prediction wavs.
avail = []
for sys in SYSTEMS:
    ids = {p.stem for p in (PRED_RUN_DIR / "pred" / sys).glob("*.wav")}
    if not ids:
        raise RuntimeError(f"No saved preds in {PRED_RUN_DIR/'pred'/sys}")
    avail.append(ids)

# Every system may have preds while no single item is shared by all of them;
# fail with a readable message instead of an IndexError from sorted(...)[0].
common_ids = set.intersection(*avail)
if not common_ids:
    raise RuntimeError(f"No item_id has saved preds for all systems {SYSTEMS} in {PRED_RUN_DIR/'pred'}")
# Deterministic choice: lexicographically smallest shared id.
item_id = sorted(common_ids)[0]
print("Using saved item_id:", item_id)

ref_wav_path = PRED_RUN_DIR / "ref" / f"{item_id}.wav"
if not ref_wav_path.is_file():
    raise FileNotFoundError(f"Missing ref wav: {ref_wav_path}")

# The eval summary records the sample rate the wavs were saved at.
summary = json.loads(Path("artifacts/eval/small_one_kit/summary.json").read_text(encoding="utf-8"))
eval_sr = int(summary["eval_sr"])
# Tolerance in samples, never below one sample.
tol_samps = max(1, int(round((TOL_MS * 1e-3) * eval_sr)))
print("eval_sr:", eval_sr, "tol_ms:", TOL_MS, "tol_samps:", tol_samps)

# --- find matching npz in cache test manifest, load grid for visualization ---
# Uses the first (sorted) test manifest in CACHE_DIR.
manifest = sorted(CACHE_DIR.glob("manifest_midigroove_test_*.jsonl"))[0]
npz_match = None
drum_hit = drum_vel = None
window_sec = None

# Linear scan: recompute each record's key from its npz metadata and stop at
# the first one equal to item_id.
for line in manifest.read_text(encoding="utf-8").splitlines():
    rec = json.loads(line)
    npz = Path(rec["npz"])
    with np.load(npz, allow_pickle=False) as d:
        audio_path = str(d["audio_path"].item())
        midi_p = str(d["midi_path"].item())
        sr0 = int(d["sr"].item())
        st = float(d["start_sec"].item())
        ws = float(d["window_seconds"].item())
        # Window start / length in samples at the npz's own sample rate.
        ss = int(round(st * sr0))
        ns = int(round(ws * sr0))
        # NOTE(review): assumes item ids were produced with the same key
        # recipe as stable_key_str — confirm against the eval code.
        if stable_key_str(audio_path, midi_p, sr0, ss, ns) != item_id:
            continue
        npz_match = npz
        window_sec = float(ws)
        drum_hit = np.asarray(d["drum_hit"], dtype=np.float32)
        # drum_vel is optional in the cache; fall back to zeros.
        drum_vel = np.asarray(d["drum_vel"], dtype=np.float32) if "drum_vel" in d else np.zeros_like(drum_hit)
        break

if npz_match is None:
    raise RuntimeError(f"Could not find item_id={item_id} in {manifest}")
print("npz:", npz_match.name, "D,T:", drum_hit.shape, "window_sec:", window_sec)

# --- load reference audio ---
y_ref, sr_ref = read_wav_mono_float32(ref_wav_path)
# The saved reference must already be at the eval sample rate.
assert sr_ref == eval_sr

# GT from GRID (filtered)
gt_onsets = mg_eval._onsets_from_grid_npz(
    npz_match,
    eval_sr=eval_sr,
    vel_thresh=GT_VEL_THRESH,
    exclude_channels=GT_EXCLUDE,
)
print("grid_gt_onsets(filtered):", len(gt_onsets))

# optional overlay: ref-audio detected onsets (NOT GT)
ref_audio_onsets = mg_eval._onsets_from_audio(y_ref, sr=eval_sr, **ONSET_KW)

# --- per-system pred onsets + scores vs GRID GT ---
# Three dicts keyed by system name: waveform, detected onsets, metric dict.
pred_audio, pred_onsets, scores = {}, {}, {}
for sys in SYSTEMS:
    y, sr = read_wav_mono_float32(PRED_RUN_DIR / "pred" / sys / f"{item_id}.wav")
    assert sr == eval_sr
    pred_audio[sys] = y
    pred_onsets[sys] = mg_eval._onsets_from_audio(y, sr=eval_sr, **ONSET_KW)

    # Onsets are supplied explicitly: predictions come from the audio
    # detector, references from the filtered grid GT.
    m = mg_eval._onset_pr_metrics(
        pred=y,
        ref=y_ref,
        sr=eval_sr,
        pred_onsets=pred_onsets[sys],
        ref_onsets=gt_onsets,
        midi_path=None, start_sec=None, end_sec=None,
        tol_ms=TOL_MS,
    )
    scores[sys] = m
    print(f"{sys:7s} P={m['onset_precision']:.3f} R={m['onset_recall']:.3f} F1={m['onset_f1']:.3f} (pred={len(pred_onsets[sys])} gt={len(gt_onsets)})")

print("\nAudio: reference then each system prediction")
display(Audio(clip_to_seconds(y_ref, eval_sr, MAX_SECONDS_TO_SHOW), rate=eval_sr))
for sys in SYSTEMS:
    display(Audio(clip_to_seconds(pred_audio[sys], eval_sr, MAX_SECONDS_TO_SHOW), rate=eval_sr))

# Time axis (seconds) of the reference, and the number of samples to plot.
t = np.arange(y_ref.size) / float(eval_sr)
n = y_ref.size if MAX_SECONDS_TO_SHOW is None else int(min(y_ref.size, round(float(MAX_SECONDS_TO_SHOW) * eval_sr)))

# One row for the grid, one for the reference, one per system.
fig, axes = plt.subplots(2 + len(SYSTEMS), 1, figsize=(14, 2.0 * (2 + len(SYSTEMS))), sharex=True)

# Row 0: cached grid (hits stacked on top of velocities) with GT onsets.
ax0 = axes[0]
D, T = drum_hit.shape
grid_img = np.concatenate([drum_hit, drum_vel], axis=0)
ax0.imshow(grid_img, aspect="auto", origin="lower", interpolation="nearest",
           extent=[0.0, float(window_sec), 0.0, float(2 * D)])
ax0.set_title(f"Cached grid + GT onsets (green) [exclude={GT_EXCLUDE}, vel>={GT_VEL_THRESH}]")
shade_tol(ax0, gt_onsets, sr=eval_sr, tol_samps=tol_samps, color="green", alpha=0.05)
vlines(ax0, gt_onsets, sr=eval_sr, color="green", alpha=0.85, lw=1.2, label="GRID GT (filtered)")
ax0.legend(loc="upper right", frameon=False)

# Row 1: reference waveform with GT and (optional) audio-detected onsets.
ax = axes[1]
ax.plot(t[:n], y_ref[:n], color="0.25", linewidth=0.9)
ax.set_title("Reference waveform + GT (green) + ref-audio onsets (gray, optional)")
shade_tol(ax, gt_onsets, sr=eval_sr, tol_samps=tol_samps, color="green", alpha=0.06)
vlines(ax, gt_onsets, sr=eval_sr, color="green", alpha=0.85, lw=1.2, label="GRID GT (filtered)")
vlines(ax, ref_audio_onsets, sr=eval_sr, color="0.6", alpha=0.35, lw=1.0, label="ref audio onsets")
ax.legend(loc="upper right", frameon=False)

# Remaining rows: one per system, prediction waveform with pred vs GT onsets.
for i, sys in enumerate(SYSTEMS, 2):
    ax = axes[i]
    y = pred_audio[sys]
    m = scores[sys]
    # A prediction can be shorter than the reference window; clamp so the
    # x and y arrays passed to plot always have the same length.
    k = min(n, y.size)
    ax.plot(t[:k], y[:k], color="C0", linewidth=0.9)
    ax.set_title(f"{sys}: pred onsets (red) vs GT (green);  P={m['onset_precision']:.3f} R={m['onset_recall']:.3f} F1={m['onset_f1']:.3f}")
    shade_tol(ax, gt_onsets, sr=eval_sr, tol_samps=tol_samps, color="green", alpha=0.05)
    vlines(ax, gt_onsets, sr=eval_sr, color="green", alpha=0.50, lw=1.0, label="GT")
    vlines(ax, pred_onsets[sys], sr=eval_sr, color="red", alpha=0.70, lw=1.2, label="pred")
    # One legend is enough; the colours are the same in every system row.
    if i == 2:
        ax.legend(loc="upper right", frameon=False)

axes[-1].set_xlabel("seconds")
plt.tight_layout()
plt.show()


# FAD check

In [None]:
# Sanity checks for FAD(PANN): does it separate obvious wrong audio?
# - FAD(ref, ref) should be ~0
# - FAD(ref, pred) should be > 0
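# - FAD(ref, shuffled_pred) should be ~= FAD(ref, pred): FAD compares set-level embedding statistics, so re-pairing the same pred files under different names should not change it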
# - FAD(ref, noise) should be much larger than ref/pred
# - FAD(ref, oracle) should usually be <= FAD(ref, pred)

from pathlib import Path
import random
import shutil
import numpy as np

from frechet_audio_distance import FrechetAudioDistance

RUN = Path("artifacts/pred/small_one_kit")  # <-- change
SYS = "encodec"  # <-- change: "dac" | "encodec" | "xcodec"
EVAL_SR = 32000  # <-- must match what eval saved

# Per-system directories inside the run; the oracle directory is optional
# (it is only used further down when it exists).
ref_dir    = RUN / "ref_by_system" / SYS
pred_dir   = RUN / "pred" / SYS
oracle_dir = RUN / "oracle" / SYS

assert ref_dir.is_dir(), ref_dir
assert pred_dir.is_dir(), pred_dir

# FAD scorer shared by every check in this cell (PANN embedding model).
fad = FrechetAudioDistance(
    model_name="pann",
    sample_rate=int(EVAL_SR),
    use_pca=False,
    use_activation=False,
    verbose=False,
)

def _list_wavs(d: Path):
    """Return all regular ``*.wav`` files under *d* (recursive), sorted by path."""
    found = []
    for candidate in d.rglob("*.wav"):
        if candidate.is_file():
            found.append(candidate)
    found.sort()
    return found

def _make_clean_dir(d: Path):
    """Ensure *d* exists and is empty: remove any existing tree, then recreate it."""
    already_present = d.exists()
    if already_present:
        shutil.rmtree(d)
    d.mkdir(parents=True, exist_ok=True)

def _copy_subset(src: Path, dst: Path, wavs: list[Path]):
    """Copy *wavs* into a freshly emptied *dst*, keeping their paths relative to *src*.

    Every path in *wavs* must live under *src*; intermediate directories are
    created as needed and file metadata is preserved by ``shutil.copy2``.
    """
    _make_clean_dir(dst)
    for wav in wavs:
        target = dst / wav.relative_to(src)
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(wav, target)

def _shuffled_pred_subset(ref_wavs: list[Path], pred_wavs: list[Path], dst: Path):
    """Copy predictions into a freshly emptied *dst* under reference file names.

    A shuffled copy of *pred_wavs* is paired with *ref_wavs* in order, so each
    prediction is stored under the name of a (generally different) reference
    item. Output is flat: only the reference's file name is kept. The input
    lists are not modified.
    """
    _make_clean_dir(dst)
    shuffled = list(pred_wavs)
    random.shuffle(shuffled)
    for ref_path, pred_path in zip(ref_wavs, shuffled):
        # Flat layout (file name only) is fine for FAD.
        shutil.copy2(pred_path, dst / ref_path.name)

def _noise_like_ref_subset(ref_wavs: list[Path], dst: Path):
    """Write one white-noise file per reference into a freshly emptied *dst*.

    Each noise file has the same number of frames as its reference (after
    down-mixing to mono), is scaled to roughly the reference's RMS and is
    saved under the reference's file name at ``EVAL_SR``. Requires
    ``soundfile``. A single generator seeded with 0 is used for all files.
    """
    import soundfile as sf
    _make_clean_dir(dst)
    rng = np.random.default_rng(0)
    for ref_path in ref_wavs:
        ref_audio, rate = sf.read(ref_path, dtype="float32", always_2d=False)
        assert int(rate) == int(EVAL_SR), (rate, EVAL_SR, ref_path)
        if ref_audio.ndim > 1:
            ref_audio = ref_audio.mean(axis=-1)
        noise = rng.standard_normal(size=ref_audio.shape[0]).astype("float32")
        # Match the reference loudness approximately; the epsilons guard
        # against silent inputs.
        ref_rms = float(np.sqrt(np.mean(np.square(ref_audio)) + 1e-12))
        noise_rms = float(np.sqrt(np.mean(np.square(noise)) + 1e-12))
        noise = noise * (ref_rms / max(1e-8, noise_rms))
        sf.write(dst / ref_path.name, noise, int(EVAL_SR))

def score(a: Path, b: Path) -> float:
    """Return the FAD between directories *a* and *b* using the cell's ``fad`` scorer."""
    value = fad.score(str(a), str(b), dtype="float32")
    return float(value)

# ---- run checks on a subset for speed ----
ref_wavs  = _list_wavs(ref_dir)
pred_wavs = _list_wavs(pred_dir)
# Subsets below are taken by position, so both sorted listings must have the
# same length.
assert len(ref_wavs) == len(pred_wavs), (len(ref_wavs), len(pred_wavs))

N = min(512, len(ref_wavs))  # <-- increase for stability
# Fixed seed: the same indices are drawn on every run.
random.seed(0)
idx = random.sample(range(len(ref_wavs)), k=N)
ref_sub  = [ref_wavs[i] for i in idx]
pred_sub = [pred_wavs[i] for i in idx]

# Scratch directories, one per comparison set.
TMP = RUN / "fad_sanity_tmp" / SYS
ref_tmp     = TMP / "ref"
pred_tmp    = TMP / "pred"
shuf_tmp    = TMP / "pred_shuffled"
noise_tmp   = TMP / "noise"
oracle_tmp  = TMP / "oracle"

# Each helper empties its destination before writing.
_copy_subset(ref_dir, ref_tmp, ref_sub)
_copy_subset(pred_dir, pred_tmp, pred_sub)
_shuffled_pred_subset(ref_sub, pred_sub, shuf_tmp)
_noise_like_ref_subset(ref_sub, noise_tmp)

print(f"Subset N={N} @ sr={EVAL_SR} sys={SYS}")
# The reference subset is always the first argument.
fad_ref_ref   = score(ref_tmp, ref_tmp)
fad_ref_pred  = score(ref_tmp, pred_tmp)
fad_ref_shuf  = score(ref_tmp, shuf_tmp)
fad_ref_noise = score(ref_tmp, noise_tmp)

print(f"FAD(ref, ref)        = {fad_ref_ref:.8e}")
print(f"FAD(ref, pred)       = {fad_ref_pred:.8e}")
print(f"FAD(ref, shuffled)   = {fad_ref_shuf:.8e}")
print(f"FAD(ref, noise)      = {fad_ref_noise:.8e}")

# Optional oracle comparison, only when the oracle directory exists.
if oracle_dir.is_dir():
    oracle_wavs = _list_wavs(oracle_dir)
    # idx holds positions drawn from range(len(ref_wavs)), so its largest
    # entry can exceed N - 1. Checking only len(oracle_wavs) >= N would let
    # oracle_wavs[i] raise IndexError; require the largest index to be valid.
    if len(oracle_wavs) > max(idx, default=-1):
        oracle_sub = [oracle_wavs[i] for i in idx]
        _copy_subset(oracle_dir, oracle_tmp, oracle_sub)
        fad_ref_oracle = score(ref_tmp, oracle_tmp)
        print(f"FAD(ref, oracle)     = {fad_ref_oracle:.8e}")
    else:
        print("Oracle present but not enough wavs for subset.")

print("Tmp written to:", TMP)


In [None]:
# Optional: check if FAD is mostly driven by loudness (RMS) mismatch
# (If FAD barely changes after RMS-normalizing, it’s not sensitive to your errors.)

from pathlib import Path
import numpy as np
import shutil
import soundfile as sf
from frechet_audio_distance import FrechetAudioDistance

RUN = Path("artifacts/pred/small_one_kit")  # <-- change
SYS = "encodec"
EVAL_SR = 32000

# Full reference / prediction directories, used when no subset from the
# sanity-check cell is available (see the fallback further down).
ref_dir  = RUN / "ref_by_system" / SYS
pred_dir = RUN / "pred" / SYS

# FAD scorer for this cell (PANN embedding model).
fad = FrechetAudioDistance(model_name="pann", sample_rate=int(EVAL_SR), use_pca=False, use_activation=False, verbose=False)

def _list_wavs(d: Path):
    """List the regular ``*.wav`` files found anywhere under *d*, in sorted order."""
    wav_files = filter(Path.is_file, d.rglob("*.wav"))
    return sorted(wav_files)

def _make_clean_dir(d: Path):
    """Leave *d* as an existing, empty directory (any previous tree is removed)."""
    if not d.exists():
        d.mkdir(parents=True, exist_ok=True)
        return
    shutil.rmtree(d)
    d.mkdir(parents=True, exist_ok=True)

def _rms_normalize_tree(src: Path, dst: Path, target_rms: float = 0.05):
    """Write an RMS-normalized mono copy of every wav under *src* into *dst*.

    *dst* is emptied first. Each file must be at ``EVAL_SR``; multi-channel
    audio is averaged to mono, scaled so its RMS is approximately
    *target_rms*, and saved flat under its own file name.
    """
    _make_clean_dir(dst)
    for wav_path in _list_wavs(src):
        audio, rate = sf.read(wav_path, dtype="float32", always_2d=False)
        assert int(rate) == int(EVAL_SR), (rate, EVAL_SR, wav_path)
        if audio.ndim > 1:
            audio = audio.mean(axis=-1)
        # The epsilons keep the gain finite for silent files.
        current_rms = float(np.sqrt(np.mean(np.square(audio)) + 1e-12))
        gain = float(target_rms) / max(1e-8, current_rms)
        sf.write(dst / wav_path.name, audio * gain, int(EVAL_SR))

# Scratch directories for the normalized copies.
TMP = RUN / "fad_rms_tmp" / SYS
ref_n = TMP / "ref_norm"
pred_n = TMP / "pred_norm"

ref = RUN / "fad_sanity_tmp" / SYS / "ref"     # reuse subset from previous snippet if present
pred = RUN / "fad_sanity_tmp" / SYS / "pred"

# Fall back to the full directories when the subset does not exist.
# NOTE(review): the two fallbacks are independent, so a subset ref could be
# paired with the full pred directory (or vice versa) — confirm that is acceptable.
if not ref.is_dir():
    ref = ref_dir
if not pred.is_dir():
    pred = pred_dir

_rms_normalize_tree(ref, ref_n, target_rms=0.05)
_rms_normalize_tree(pred, pred_n, target_rms=0.05)

# Same pair of sets scored before and after loudness normalization.
fad_raw = float(fad.score(str(ref), str(pred), dtype="float32"))
fad_norm = float(fad.score(str(ref_n), str(pred_n), dtype="float32"))

print(f"FAD raw  = {fad_raw:.8e}")
print(f"FAD norm = {fad_norm:.8e}")
print("Tmp written to:", TMP)


In [None]:
# 1) Verify the "noise" files are actually noise (not accidentally refs)
import numpy as np
import soundfile as sf
from pathlib import Path

RUN = Path("artifacts/pred/small_one_kit")
SYS = "encodec"
# Directories written by the FAD sanity-check cell.
ref_dir  = RUN / "fad_sanity_tmp" / SYS / "ref"
noise_dir = RUN / "fad_sanity_tmp" / SYS / "noise"

# Compare the first reference file with the noise file of the same name.
p = sorted(ref_dir.glob("*.wav"))[0]
r, sr = sf.read(p, dtype="float32")
n, sr2 = sf.read(noise_dir / p.name, dtype="float32")
# Truncate both to the shorter length so they can be compared sample by sample.
m = min(len(r), len(n))
r = r[:m]; n = n[:m]

print("sr:", sr, sr2)
# RMS values should be close (the noise was RMS-matched); correlation should
# be near zero if the noise file really is noise.
print("rms ref  :", float(np.sqrt(np.mean(r*r) + 1e-12)))
print("rms noise:", float(np.sqrt(np.mean(n*n) + 1e-12)))
print("corr(ref,noise):", float(np.corrcoef(r, n)[0,1]))
print("max|ref-noise|:", float(np.max(np.abs(r - n))))


In [None]:
# 2) Make more extreme baselines: silence + full-scale noise (no RMS matching)
#    Expect: FAD(ref, silence) and FAD(ref, loud_noise) >> FAD(ref, pred)
import numpy as np
import soundfile as sf
import shutil
from pathlib import Path
from frechet_audio_distance import FrechetAudioDistance

RUN = Path("artifacts/pred/small_one_kit")
SYS = "encodec"
EVAL_SR = 32000

# Reference subset written by the FAD sanity-check cell, plus the two output
# directories created by this cell.
ref_tmp = RUN / "fad_sanity_tmp" / SYS / "ref"
sil_dir = RUN / "fad_sanity_tmp" / SYS / "silence"
ln_dir  = RUN / "fad_sanity_tmp" / SYS / "loud_noise"

def clean(d: Path):
    """Recreate *d* as an empty directory, deleting any existing contents."""
    if d.exists():
        shutil.rmtree(d)
    d.mkdir(parents=True, exist_ok=True)

clean(sil_dir); clean(ln_dir)

# One generator for the whole loop. Re-seeding inside the loop would give
# every file of the same shape the exact same noise samples, collapsing the
# "loud noise" set to copies of a single signal.
rng = np.random.default_rng(0)
for p in sorted(ref_tmp.glob("*.wav")):
    y, sr = sf.read(p, dtype="float32")
    assert int(sr) == int(EVAL_SR)
    # Silence baseline: same shape as the reference, all zeros.
    sf.write(sil_dir / p.name, np.zeros_like(y), int(EVAL_SR))
    # loud full-scale noise (no RMS match)
    n = rng.standard_normal(size=y.shape).astype("float32")
    # Peak-normalize to 0.95 of full scale.
    n = n / (np.max(np.abs(n)) + 1e-8) * 0.95
    sf.write(ln_dir / p.name, n, int(EVAL_SR))

fad = FrechetAudioDistance(model_name="pann", sample_rate=int(EVAL_SR), use_pca=False, use_activation=False, verbose=False)
print("FAD(ref, silence)   =", float(fad.score(str(ref_tmp), str(sil_dir), dtype="float32")))
print("FAD(ref, loud_noise)=", float(fad.score(str(ref_tmp), str(ln_dir), dtype="float32")))


In [None]:

import zipfile
from pathlib import Path
import numpy as np

# Scan the cache for .npz files that cannot be opened as zip archives
# (read-only: nothing is modified here).
CACHE = Path("cache/dac_allkits")
bad = []
for p in sorted((CACHE/"items").glob("*.npz")):
    try:
        # Opening the archive and listing its members is enough to trigger
        # the zip-format check.
        with np.load(p, allow_pickle=False) as d:
            _ = d.files
    except Exception as e:
        # Only zip-format failures are recorded.
        # NOTE(review): any other exception is silently ignored — confirm intended.
        if isinstance(e, zipfile.BadZipFile) or "BadZipFile" in str(e):
            bad.append(p)
print("bad:", len(bad))
# Print at most the first 50 offending paths.
for p in bad[:50]:
    print(p)



In [None]:

import zipfile
from pathlib import Path
import numpy as np

# Same scan as the previous cell, but DESTRUCTIVE: every file that fails the
# zip-format check is deleted from the cache.
CACHE = Path("cache/dac_allkits")
bad = []
for p in sorted((CACHE/"items").glob("*.npz")):
    try:
        with np.load(p, allow_pickle=False) as d:
            _ = d.files
    except Exception as e:
        # Only zip-format failures are collected; other errors are ignored.
        if isinstance(e, zipfile.BadZipFile) or "BadZipFile" in str(e):
            bad.append(p)
# missing_ok=True: a file that has already disappeared is not an error.
for p in bad:
    p.unlink(missing_ok=True)
print("deleted:", len(bad))



In [None]:
# Test whether PAD->0 before decoding fixes “noisy” big-model preds.
# Plays A/B audio (raw decode vs PAD-mapped decode) for a few items.

from pathlib import Path
import json, numpy as np, torch
from IPython.display import Audio, display

from midigroove_poc import expressivegrid as eg
from midigroove_poc.eval import _load_audio_segment, _resample_linear
from data.codecs import decode_tokens_to_audio

CACHE_DIR = Path("cache/encodec_acoustic")  # use the cache matching the system you're testing
CKPT_PATH = Path("artifacts/checkpoints/encodec_big_single_kit.pt")
DEVICE = "cuda:0"         # device for the model forward pass
DECODE_DEVICE = "cuda:0"  # device passed to decode_tokens_to_audio
EVAL_SR = 32000
N_EX = 5  # how many examples to audition

# --- load a few test items ---
# First (sorted) test manifest; the first N_EX records are used.
manifest = sorted(CACHE_DIR.glob("manifest_midigroove_test_*.jsonl"))[0]
recs = [json.loads(x) for x in manifest.read_text().splitlines()[:N_EX]]
npz_paths = [Path(r["npz"]) for r in recs]
print("NPZs:", [p.name for p in npz_paths])

# --- load ckpt + build model ---
ckpt = torch.load(CKPT_PATH, map_location="cpu")
state = ckpt["model"]
# Tolerate a missing or non-dict cfg entry by using an empty config.
cfg = ckpt.get("cfg", {}) if isinstance(ckpt.get("cfg", {}), dict) else {}
num_codebooks = int(ckpt["num_codebooks"])
in_dim = int(ckpt["in_dim"])
codec = str(cfg.get("encoder_model", "encodec") or "encodec").strip().lower()

# resolve vocab/pad exactly like eval does
cfg2 = dict(cfg)
cfg2["in_dim"] = in_dim
cfg2["num_codebooks"] = num_codebooks
# If the config does not say, detect kit-name conditioning from the weights.
if "use_kit_name" not in cfg2:
    cfg2["use_kit_name"] = bool("kit_name_emb.weight" in state)
# Missing vocab size: try to infer it from the state dict.
if "vocab_size" not in cfg2:
    vs = eg._infer_vocab_size_from_state_dict(state, num_codebooks=num_codebooks)
    if vs is not None:
        cfg2["vocab_size"] = int(vs)
# With a known vocab size, default PAD to the last id and the codebook size
# to vocab_size - 1 (only for keys the config does not already provide).
if "vocab_size" in cfg2 and ("pad_id" not in cfg2 or "codebook_size" not in cfg2):
    vs2 = int(cfg2.get("vocab_size", 0) or 0)
    if vs2 > 1:
        cfg2.setdefault("pad_id", vs2 - 1)
        cfg2.setdefault("codebook_size", vs2 - 1)

# Final values: config first, otherwise the encoder-specific defaults from eg.
codebook_size = int(cfg2.get("codebook_size", eg._default_codebook_size_for_encoder(codec)))
pad_id = int(cfg2.get("pad_id", eg._pad_id_for_codebook(codebook_size)))
vocab_size = int(cfg2.get("vocab_size", eg._vocab_size_for_codebook(codebook_size)))
cfg2["codebook_size"] = codebook_size
cfg2["pad_id"] = pad_id
cfg2["vocab_size"] = vocab_size

print("codec:", codec, "codebook_size:", codebook_size, "pad_id:", pad_id, "vocab_size:", vocab_size)

model = eg._build_model(num_codebooks=num_codebooks, in_dim=in_dim, cfg=cfg2)
# strict=True: any mismatch between the checkpoint and the built model raises.
model.load_state_dict(state, strict=True)
model.to(torch.device(DEVICE)).eval()

# helper: build grid the same way training does
def load_item(npz_path: Path):
    """Load one cached item and build the model inputs plus reference audio.

    Returns ``(grid, beat_pos, bpm, drummer_id, kit_name_id, tgt, ref_rs)``:

    * ``grid``: float32 array made by stacking ``drum_hit`` and ``drum_vel``
      along axis 0, plus ``drum_sustain`` and/or a one-row ``hh_open_cc4``
      when ``cfg2`` enables ``include_sustain`` / ``include_hh_cc4``.
    * ``beat_pos`` / ``tgt``: int64 arrays read from the npz.
    * ``bpm`` / ``drummer_id`` / ``kit_name_id``: scalars, defaulting to
      120.0 / 0 / 0 when absent.
    * ``ref_rs``: the source audio window resampled to ``EVAL_SR``.

    Optional grid arrays missing from the npz are replaced by zeros.
    """
    with np.load(npz_path, allow_pickle=False) as npz:
        arrays = {name: np.asarray(npz[name]) for name in npz.files}

    def as_f32(name, fallback):
        # Optional array: use the fallback when the key is absent.
        return arrays.get(name, fallback).astype(np.float32)

    hits = arrays["drum_hit"].astype(np.float32)
    vels = as_f32("drum_vel", np.zeros_like(hits))
    sustain = as_f32("drum_sustain", np.zeros_like(hits))
    hh_open = as_f32("hh_open_cc4", np.zeros((hits.shape[1],), np.float32))
    beat_pos = arrays["beat_pos"].astype(np.int64)
    bpm = float(arrays.get("bpm", 120.0))
    drummer_id = int(arrays.get("drummer_id", 0))
    kit_name_id = int(arrays.get("kit_name_id", 0))
    tgt = arrays["tgt"].astype(np.int64)

    # Feature rows follow the checkpoint's config flags.
    channels = [hits, vels]
    if bool(cfg2.get("include_sustain", False)):
        channels.append(sustain)
    if bool(cfg2.get("include_hh_cc4", False)):
        channels.append(hh_open[None, :])
    grid = np.concatenate(channels, axis=0).astype(np.float32)

    # reference audio segment, cut at the file's native sample rate
    audio_path = Path(str(arrays["audio_path"].item()))
    sr_native = int(arrays["sr"].item())
    start_sec = float(arrays["start_sec"].item())
    window_seconds = float(arrays["window_seconds"].item())
    first_sample = int(round(start_sec * sr_native))
    sample_count = int(round(window_seconds * sr_native))
    segment, segment_sr = _load_audio_segment(audio_path, start_sample=first_sample, num_samples=sample_count)
    ref_rs = _resample_linear(segment, segment_sr, EVAL_SR)

    return grid, beat_pos, bpm, drummer_id, kit_name_id, tgt, ref_rs

# For each item: run the model, decode the argmax tokens twice (as predicted,
# and with PAD ids replaced by 0) and play both next to the reference.
for i, p in enumerate(npz_paths):
    grid, beat_pos, bpm, drummer_id, kit_name_id, tgt, ref_rs = load_item(p)
    T = int(grid.shape[1])

    # Add a batch dimension of 1 to every input.
    grid_t = torch.from_numpy(grid).unsqueeze(0).to(DEVICE)                 # [1,F,T]
    beat_t = torch.from_numpy(beat_pos).unsqueeze(0).to(DEVICE)             # [1,T]
    bpm_t  = torch.tensor([bpm], dtype=torch.float32, device=DEVICE)
    dr_t   = torch.tensor([drummer_id], dtype=torch.long, device=DEVICE)
    kit_t  = torch.tensor([kit_name_id], dtype=torch.long, device=DEVICE)
    # Every frame is marked valid (no padding inside a single item).
    valid  = torch.ones((1, T), dtype=torch.bool, device=DEVICE)

    with torch.inference_mode():
        # The kit id is only passed when the config enables kit conditioning.
        logits = model(grid=grid_t, beat_pos=beat_t, bpm=bpm_t, drummer_id=dr_t,
                       kit_name_id=kit_t if bool(cfg2.get("use_kit_name", True)) else None,
                       valid_mask=valid)
        # Greedy decoding: most likely token per position.
        pred = logits.argmax(dim=-1).squeeze(0).to(torch.long).cpu()        # [C,T]

    # Fraction of predicted tokens equal to PAD.
    pad_rate = float((pred == pad_id).float().mean().item())
    pred_raw = pred
    # Same tokens with every PAD replaced by token id 0.
    pred_fix = torch.where(pred == pad_id, torch.zeros_like(pred), pred)

    audio_raw_b1, sr_raw = decode_tokens_to_audio(pred_raw, encoder_model=codec, device=DECODE_DEVICE)
    audio_fix_b1, sr_fix = decode_tokens_to_audio(pred_fix, encoder_model=codec, device=DECODE_DEVICE)

    # First element of each decoded output, resampled to the eval rate.
    raw_rs = _resample_linear(audio_raw_b1[0], sr_raw, EVAL_SR)
    fix_rs = _resample_linear(audio_fix_b1[0], sr_fix, EVAL_SR)

    # Trim all three signals to a common length for a fair A/B listen.
    N = min(ref_rs.size, raw_rs.size, fix_rs.size)
    ref_rs2 = ref_rs[:N]
    raw_rs2 = raw_rs[:N]
    fix_rs2 = fix_rs[:N]

    print(f"\n{i+1}/{len(npz_paths)} {p.name}  pad_rate={pad_rate:.3f}  T={T}  C={pred.shape[0]}")
    print("Reference:")
    display(Audio(ref_rs2, rate=EVAL_SR))
    print("Pred decode (RAW, may include PAD):")
    display(Audio(raw_rs2, rate=EVAL_SR))
    print("Pred decode (PAD->0 before decode):")
    display(Audio(fix_rs2, rate=EVAL_SR))


In [None]:
# Diagnose “big model sounds like noise”: compare SMALL vs BIG on the same cache items.
# Prints PAD rate + token diversity, and plays: reference, GT decode, small pred, big pred.

from pathlib import Path
import json, numpy as np, torch
from IPython.display import Audio, display

from midigroove_poc import expressivegrid as eg
from midigroove_poc.eval import _load_audio_segment, _resample_linear
from data.codecs import decode_tokens_to_audio

# Cache to draw test items from, and the two checkpoints being compared.
CACHE_DIR = Path("cache/encodec_acoustic")
CKPT_SMALL = Path("artifacts/checkpoints/encodec_small_single_kit.pt")
CKPT_BIG   = Path("artifacts/checkpoints/encodec_big_single_kit.pt")

DEVICE = "cuda:0"         # device for the model forward pass
DECODE_DEVICE = "cuda:0"  # device for codec decoding
EVAL_SR = 32000           # sample rate the reference audio is resampled to
N_EX = 4                  # number of items to audition

# ---- pick a few test items deterministically ----
# First (sorted) test manifest, first N_EX records.
manifest = sorted(CACHE_DIR.glob("manifest_midigroove_test_*.jsonl"))[0]
recs = [json.loads(x) for x in manifest.read_text().splitlines()[:N_EX]]
npz_paths = [Path(r["npz"]) for r in recs]
print("NPZs:", [p.name for p in npz_paths])

def build_model(ckpt_path: Path):
    """Load a checkpoint and return ``(model, cfg2, codec, pad_id)``.

    ``cfg2`` is a copy of the checkpoint config completed with
    ``use_kit_name``, ``vocab_size``, ``pad_id`` and ``codebook_size``.
    Missing values are inferred from the state dict where possible and
    otherwise taken from the encoder defaults in ``eg``. The model is built
    with ``eg._build_model``, loaded strictly, moved to ``DEVICE`` and put in
    eval mode.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    weights = checkpoint["model"]
    stored_cfg = checkpoint.get("cfg", {})
    base_cfg = stored_cfg if isinstance(stored_cfg, dict) else {}
    n_codebooks = int(checkpoint["num_codebooks"])
    input_dim = int(checkpoint["in_dim"])
    codec = str(base_cfg.get("encoder_model", "encodec") or "encodec").strip().lower()

    cfg2 = dict(base_cfg)
    # If the config does not say, detect kit-name conditioning from the weights.
    cfg2.setdefault("use_kit_name", "kit_name_emb.weight" in weights)

    if "vocab_size" not in cfg2:
        inferred = eg._infer_vocab_size_from_state_dict(weights, num_codebooks=n_codebooks)
        if inferred is not None:
            cfg2["vocab_size"] = int(inferred)
    needs_pad_defaults = "pad_id" not in cfg2 or "codebook_size" not in cfg2
    if "vocab_size" in cfg2 and needs_pad_defaults:
        known_vocab = int(cfg2.get("vocab_size", 0) or 0)
        if known_vocab > 1:
            # PAD defaults to the last id; the codebook is everything before it.
            cfg2.setdefault("pad_id", known_vocab - 1)
            cfg2.setdefault("codebook_size", known_vocab - 1)

    codebook_size = int(cfg2.get("codebook_size", eg._default_codebook_size_for_encoder(codec)))
    pad_id = int(cfg2.get("pad_id", eg._pad_id_for_codebook(codebook_size)))
    vocab_size = int(cfg2.get("vocab_size", eg._vocab_size_for_codebook(codebook_size)))
    cfg2.update(codebook_size=codebook_size, pad_id=pad_id, vocab_size=vocab_size)

    model = eg._build_model(num_codebooks=n_codebooks, in_dim=input_dim, cfg=cfg2)
    model.load_state_dict(weights, strict=True)
    model.to(torch.device(DEVICE)).eval()
    return model, cfg2, codec, pad_id

def load_item(npz_path: Path):
    """Read one cached example and its reference audio segment.

    Required npz keys: "drum_hit", "beat_pos", "tgt", "audio_path", "sr",
    "start_sec" and "window_seconds". Optional keys fall back to defaults:
    "drum_vel" / "drum_sustain" (zeros shaped like "drum_hit"), "hh_open_cc4"
    (zeros of length ``drum_hit.shape[1]``), "bpm" (120.0), "drummer_id" (0)
    and "kit_name_id" (0).

    The reference audio window is cut with ``_load_audio_segment`` and
    resampled to EVAL_SR with ``_resample_linear``.

    Returns:
        dict with keys drum_hit, drum_vel, drum_sus, hh_cc4, beat_pos, bpm,
        drummer_id, kit_name_id, tgt and ref_rs.
    """
    with np.load(npz_path, allow_pickle=False) as archive:
        arrays = {name: np.asarray(archive[name]) for name in archive.files}

    def _as_f32(key, fallback):
        # Optional conditioning array, always returned as float32.
        return arrays.get(key, fallback).astype(np.float32)

    hits = arrays["drum_hit"].astype(np.float32)
    velocities = _as_f32("drum_vel", np.zeros_like(hits))
    sustains = _as_f32("drum_sustain", np.zeros_like(hits))
    hihat_open = _as_f32("hh_open_cc4", np.zeros((hits.shape[1],), np.float32))

    beat_pos = arrays["beat_pos"].astype(np.int64)
    bpm = float(arrays.get("bpm", 120.0))
    drummer_id = int(arrays.get("drummer_id", 0))
    kit_name_id = int(arrays.get("kit_name_id", 0))
    tgt = arrays["tgt"].astype(np.int64)

    # Locate the window in the source file, in samples at the stored rate.
    audio_path = Path(str(arrays["audio_path"].item()))
    sr_native = int(arrays["sr"].item())
    start_sec = float(arrays["start_sec"].item())
    window_seconds = float(arrays["window_seconds"].item())
    first_sample = int(round(start_sec * sr_native))
    n_samples = int(round(window_seconds * sr_native))

    # Resample from the rate the loader reports, not from the stored "sr".
    segment, sr_segment = _load_audio_segment(
        audio_path, start_sample=first_sample, num_samples=n_samples
    )
    ref_rs = _resample_linear(segment, sr_segment, EVAL_SR)

    return {
        "drum_hit": hits,
        "drum_vel": velocities,
        "drum_sus": sustains,
        "hh_cc4": hihat_open,
        "beat_pos": beat_pos,
        "bpm": bpm,
        "drummer_id": drummer_id,
        "kit_name_id": kit_name_id,
        "tgt": tgt,
        "ref_rs": ref_rs,
    }

def predict_tokens(model, cfg, pad_id, ex):
    """Run one forward pass and return greedy token predictions with stats.

    The conditioning grid stacks ``ex["drum_hit"]`` and ``ex["drum_vel"]``
    along axis 0, followed by ``ex["drum_sus"]`` and ``ex["hh_cc4"]`` (as a
    single row) when ``cfg`` enables "include_sustain" / "include_hh_cc4".
    All inputs get a leading batch dimension of 1 and are placed on DEVICE.

    Returns:
        Tuple ``(pred, pad_rate, uniq, total)``: the argmax tokens on the CPU
        as ``torch.long`` (presumably [C,T] -- confirm against the model's
        output layout), the fraction of predictions equal to ``pad_id``, the
        number of distinct predicted tokens, and the total token count.
    """
    with_sustain = bool(cfg.get("include_sustain", False))
    with_hh_cc4 = bool(cfg.get("include_hh_cc4", False))

    channels = [ex["drum_hit"], ex["drum_vel"]]
    if with_sustain:
        channels.append(ex["drum_sus"])
    if with_hh_cc4:
        # hh_cc4 is 1-D; give it a row axis so it concatenates with the rest.
        channels.append(ex["hh_cc4"][None, :])
    stacked = np.concatenate(channels, axis=0).astype(np.float32)
    n_steps = stacked.shape[1]

    def _batched(array):
        # numpy array -> tensor with a leading batch dimension, on DEVICE.
        return torch.from_numpy(array).unsqueeze(0).to(DEVICE)

    def _scalar(value, dtype):
        # Python scalar -> one-element tensor on DEVICE.
        return torch.tensor([value], dtype=dtype, device=DEVICE)

    grid_in = _batched(stacked)
    beat_in = _batched(ex["beat_pos"])
    bpm_in = _scalar(ex["bpm"], torch.float32)
    drummer_in = _scalar(ex["drummer_id"], torch.long)
    kit_in = _scalar(ex["kit_name_id"], torch.long)
    all_valid = torch.ones((1, n_steps), dtype=torch.bool, device=DEVICE)

    # The kit id is only passed on when the config enables it.
    kit_arg = kit_in if bool(cfg.get("use_kit_name", True)) else None

    with torch.inference_mode():
        logits = model(
            grid=grid_in,
            beat_pos=beat_in,
            bpm=bpm_in,
            drummer_id=drummer_in,
            kit_name_id=kit_arg,
            valid_mask=all_valid,
        )
        greedy = logits.argmax(dim=-1).squeeze(0)
        pred = greedy.to(torch.long).cpu()

    is_pad = pred == int(pad_id)
    pad_rate = float(is_pad.float().mean().item())
    uniq = int(torch.unique(pred).numel())
    total = int(pred.numel())
    return pred, pad_rate, uniq, total

# Build both models up front so every test item is run through the same pair.
model_s, cfg_s, codec_s, pad_s = build_model(CKPT_SMALL)
model_b, cfg_b, codec_b, pad_b = build_model(CKPT_BIG)

# A single codec name is used for every decode below, so both checkpoints
# must agree on it.
assert codec_s == codec_b, (codec_s, codec_b)
codec = codec_s

print("codec:", codec)
print("small pad_id:", pad_s, "big pad_id:", pad_b)

# For each selected item: predict with both models, decode predictions and
# ground-truth tokens to audio, then show stats and audio players.
for i, npz in enumerate(npz_paths, 1):
    ex = load_item(npz)

    pred_s, pad_rate_s, uniq_s, total_s = predict_tokens(model_s, cfg_s, pad_s, ex)
    pred_b, pad_rate_b, uniq_b, total_b = predict_tokens(model_b, cfg_b, pad_b, ex)

    # decode (also decode GT)
    # Tokens equal to the pad id are replaced by 0 before decoding
    # (presumably because the pad id is not a valid codec index -- confirm).
    # NOTE(review): the GT mask uses pad_s; if pad_s and pad_b ever differ,
    # check which one matches the cached targets.
    tgt = torch.from_numpy(ex["tgt"]).to(torch.long)
    tgt_clean = torch.where(tgt == pad_s, torch.zeros_like(tgt), tgt)  # safe even if no PAD
    pred_s_clean = torch.where(pred_s == pad_s, torch.zeros_like(pred_s), pred_s)
    pred_b_clean = torch.where(pred_b == pad_b, torch.zeros_like(pred_b), pred_b)

    # Each decode returns the audio plus the sample rate it was produced at.
    gt_b1, sr_gt = decode_tokens_to_audio(tgt_clean, encoder_model=codec, device=DECODE_DEVICE)
    s_b1,  sr_s  = decode_tokens_to_audio(pred_s_clean, encoder_model=codec, device=DECODE_DEVICE)
    b_b1,  sr_b  = decode_tokens_to_audio(pred_b_clean, encoder_model=codec, device=DECODE_DEVICE)

    # Take the first entry of each decode result (presumably the batch
    # dimension -- confirm), bring everything to EVAL_SR, and trim all four
    # clips to the shortest one so they have the same length.
    ref = ex["ref_rs"]
    gt  = _resample_linear(gt_b1[0], sr_gt, EVAL_SR)
    s   = _resample_linear(s_b1[0],  sr_s,  EVAL_SR)
    b   = _resample_linear(b_b1[0],  sr_b,  EVAL_SR)
    N = min(ref.size, gt.size, s.size, b.size)
    ref, gt, s, b = ref[:N], gt[:N], s[:N], b[:N]

    # Per-item report: pad rate and distinct-token count for each model.
    print(f"\n[{i}/{len(npz_paths)}] {npz.name}")
    print(f"SMALL: pad_rate={pad_rate_s:.4f}, unique_tokens={uniq_s}/{total_s}")
    print(f"BIG:   pad_rate={pad_rate_b:.4f}, unique_tokens={uniq_b}/{total_b}")

    # Audio players, in order: original segment, decoded GT tokens, small
    # model prediction, big model prediction.
    print("Reference (original segment):"); display(Audio(ref, rate=EVAL_SR))
    print("GT decode (oracle codec reconstruction):"); display(Audio(gt, rate=EVAL_SR))
    print("SMALL pred decode:"); display(Audio(s, rate=EVAL_SR))
    print("BIG pred decode:"); display(Audio(b, rate=EVAL_SR))
