In [None]:
import torch
from torch.utils.data import DataLoader
from retriever import FurnishedRoomSTFTDataset, compute_audio_distance, RIRRetrievalMLP
import matplotlib.pyplot as plt

# ─── Config ──────────────────────────────────────────────────────────
root            = '../data/RAF/FurnishedRoom'
ckpt_path       = './outputs/20250730_171523/rir_retrieval_model.ckpt'
grid_vec_p      = "./features.pt"   # single global grid vector shared by all samples
use_global_grid = True              # only the global-grid mode is supported in this cell
metric          = "MAG"             # audio metric used to find the ground-truth neighbours
# query_ids = [
#     "005501","032691","016513",
#     "002445","027043","019617","011118",
#     "002524","015172","036618","017953",
# ]

query_ids = ["038176","017044","024526","032983", "015883", "036124", "036298", "013798", "038419", "003398"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── Load model & global grid ────────────────────────────────────────
ckpt = torch.load(ckpt_path, map_location=device)
model = RIRRetrievalMLP(**ckpt["model_config"]).to(device)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Everything below builds its grid input from `grid_vec`, so fail here with a
# clear message instead of a NameError further down.
if not use_global_grid:
    raise NotImplementedError("Per-sample grid not implemented")
grid_vec = torch.load(grid_vec_p).to(device)

# ─── Load full evaluation & reference sets ───────────────────────────
q_ds = FurnishedRoomSTFTDataset(root, split="test",
                                return_wav=True, mode="reference")
g_ds = FurnishedRoomSTFTDataset(root, split="reference",
                                return_wav=True, mode="reference")

# Build quick-access lists/dicts.
# Pick query indices from the dataset's id list, which the rest of this
# notebook treats as index-aligned with q_ds[i]. Only the requested items are
# loaded, and each one only once; the original version loaded every test
# item (with audio) twice.
wanted_ids = set(query_ids)
q_items = [q_ds[i] for i, sid in enumerate(q_ds.ids) if sid in wanted_ids]
q_id2item = {s['id']: s for s in q_items}
missing_ids = [qid for qid in query_ids if qid not in q_id2item]
if missing_ids:
    raise KeyError(f"query ids not found in the test split: {missing_ids}")

g_items = [g_ds[i] for i in range(len(g_ds))]
g_ids    = [s['id'] for s in g_items]

# ─── Precompute gallery embeddings & audio features ─────────────────
# Embeddings (same global grid vector for every gallery item)
Gg    = grid_vec.unsqueeze(0).expand(len(g_items), -1).to(device)
mic_g = torch.stack([s['mic_pose']    for s in g_items]).to(device)
src_g = torch.stack([s['source_pose'] for s in g_items]).to(device)
rot_g = torch.stack([s['rot']         for s in g_items]).to(device)
with torch.no_grad():
    Zg = model(Gg, mic_g, src_g, rot_g)

# Audio features, stacked in gallery order so row k matches g_ids[k]
flat_g = torch.stack([s['stft'] for s in g_items]).to(device)
wav_g  = torch.stack([s['wav']  for s in g_items]).to(device)

# ─── Loop over queries & plot ────────────────────────────────────────
# The query-side grid input is the same for every query; build it once.
Gq = grid_vec.unsqueeze(0).to(device)


def _pair_distances(flat_q, wav_q, gi):
    """Return (MAG, SPL, MSE) distances between one query and gallery item `gi`.

    Each value is the off-diagonal entry of the 2x2 matrix that
    compute_audio_distance returns for the stacked (query, gallery) pair.
    """
    flat_pair = torch.cat([flat_q, flat_g[gi:gi+1]], dim=0)
    wav_pair  = torch.cat([wav_q,  wav_g[gi:gi+1]],  dim=0)
    return tuple(
        compute_audio_distance(flat_pair, wav_pair, metric=m)[0, 1].item()
        for m in ("MAG", "SPL", "MSE")
    )


for qid in query_ids:
    q = q_id2item[qid]

    # — embedding-space prediction —
    mic_q = q['mic_pose'].unsqueeze(0).to(device)
    src_q = q['source_pose'].unsqueeze(0).to(device)
    rot_q = q['rot'].unsqueeze(0).to(device)
    with torch.no_grad():
        Zq = model(Gq, mic_q, src_q, rot_q)

    S = (Zq @ Zg.T).squeeze(0)
    pred_top3 = S.argsort(descending=True)[:3]

    # — ground-truth nearest via audio metric —
    flat_q = q['stft'].unsqueeze(0).to(device)
    wav_q  = q['wav'].unsqueeze(0).to(device)
    all_flat = torch.cat([flat_q, flat_g], dim=0)
    all_wav  = torch.cat([wav_q,  wav_g],  dim=0)
    D_full = compute_audio_distance(all_flat, all_wav, metric=metric)
    D_qg   = D_full[0, 1:]              # distances query→each gallery
    gt_order = D_qg.argsort()           # gallery indices, closest first (computed once)
    gt_top3  = gt_order[:3]

    # — Plot setup —
    fig, axes = plt.subplots(2, 4, figsize=(20, 6))
    fig.suptitle(f"Query ID: {qid}", fontsize=16)

    # Column 0 of both rows shows the query itself.
    for row in range(2):
        axes[row, 0].plot(q['wav'].cpu())
        axes[row, 0].set_title("Query")

    # Row 1: GT-1/2/3
    for i, gi in enumerate(gt_top3.tolist()):
        mag, spl, mse = _pair_distances(flat_q, wav_q, gi)
        axes[0, i+1].plot(g_items[gi]['wav'].cpu())
        axes[0, i+1].set_title(f"GT#{i+1}: {g_ids[gi]}\nMAG={mag:.2f} SPL={spl:.4f} MSE={mse:.2e}")

    # Row 2: Pred-1/2/3, annotated with their 1-based rank under the GT metric
    for i, gi in enumerate(pred_top3.tolist()):
        gt_rank = (gt_order == gi).nonzero(as_tuple=True)[0].item() + 1
        mag, spl, mse = _pair_distances(flat_q, wav_q, gi)
        axes[1, i+1].plot(g_items[gi]['wav'].cpu())
        axes[1, i+1].set_title(
            f"Pred#{i+1}: {g_ids[gi]}\nGT-Rank={gt_rank} MAG={mag:.2f} SPL={spl:.4f} MSE={mse:.2e}"
        )

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per query

In [None]:
# Single‐cell evaluation notebook (fixed splits)

import os
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from retriever import (
    FurnishedRoomSTFTDataset,
    RIRRetrievalMLP,
    compute_audio_distance
)

# ------------- Config -------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Locations of the data, the trained model and the global grid feature
root         = '../data/RAF/FurnishedRoom'
model_path   = './outputs/20250729_021957/rir_retrieval_model.pth'
feature_path = './features.pt'

# Retrieval settings (NOTE(review): not referenced elsewhere in this cell)
sort_metric = 'MAG'      # metric used to rank ground-truth neighbours
sim_metric  = 'cosine'   # similarity used between model embeddings
TOP_K       = 3

QUERY_IDS = [
    "005501", "032691", "016513", "002445",
    "027043", "019617", "011118", "002524",
    "015172", "036618", "017953",
]

# ------------- Load model -------------
model = RIRRetrievalMLP().to(device).eval()
ckpt  = torch.load(model_path, map_location=device)
# Accept either a bare state_dict or a checkpoint dict that wraps it under
# "model_state_dict" (the layout the .ckpt files in the other cells use).
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
    ckpt = ckpt["model_state_dict"]
state = model.state_dict()
# Only tensors whose name and shape match the current model are copied.
compatible = {k: v for k, v in ckpt.items() if k in state and v.shape == state[k].shape}
if not compatible:
    # Previously this silently left the model at its random initialisation.
    raise RuntimeError(f"No checkpoint tensors matched RIRRetrievalMLP in {model_path}")
skipped    = sorted(k for k in ckpt if k not in compatible)
left_init  = sorted(k for k in state if k not in compatible)
if skipped or left_init:
    print(f"[load] partial load: {len(compatible)}/{len(state)} tensors loaded; "
          f"skipped from checkpoint: {skipped}; left at init: {left_init}")
state.update(compatible)
model.load_state_dict(state)

# ------------- Prepare datasets -------------
# Gallery: the "reference" split, loaded from data‑split‑references.json.
ds_ref = FurnishedRoomSTFTDataset(
    root, split='reference', return_wav=True, mode='reference'
)
loader_ref = DataLoader(ds_ref, batch_size=64, pin_memory=True)

# Queries: the "validation" split of the *same* JSON.
ds_q = FurnishedRoomSTFTDataset(
    root, split='validation', return_wav=True, mode='reference'
)

# ------------- Precompute reference features -------------
# Per-batch results collected by the reference pass below.
ref_ids = []
ref_flats = []
ref_wavs = []
ref_embs = []
# Global grid feature vector, shared by every sample.
eats = torch.load(feature_path).to(device)

def pair_metrics(flat_q, wav_q, ref_flat, ref_wav):
    """Audio distances between one query and one reference item.

    Args:
        flat_q:   query spectrogram that already carries a leading batch dim of 1.
        wav_q:    query waveform without a batch dim.
        ref_flat: reference spectrogram without a batch dim.
        ref_wav:  reference waveform without a batch dim (same shape as wav_q).

    Returns:
        A 3-tuple (mag, sc, lsd). Despite those names, the values are the
        'MAG2', 'SPL' and 'EDC' metrics of compute_audio_distance, in that
        order, each read from the off-diagonal entry of the 2x2 result.
    """
    both_flat = torch.cat((flat_q, ref_flat[None]), dim=0)
    both_wav  = torch.stack((wav_q, ref_wav), dim=0)

    def _dist(metric_name):
        return compute_audio_distance(both_flat, both_wav, metric=metric_name)[0, 1].item()

    return _dist('MAG2'), _dist('SPL'), _dist('EDC')

for batch in loader_ref:
    n_batch = len(batch['id'])
    ref_ids.extend(batch['id'])

    # The stored "flat" feature is the STFT exactly as the dataset returns it
    # (it is not reshaped here, despite the name).
    flat = batch['stft'].to(device)
    wav  = batch['wav'].to(device)
    grid = eats.unsqueeze(0).expand(n_batch, -1)   # same grid vector for every row

    with torch.no_grad():
        emb = model(grid, *(batch[key].to(device) for key in ('mic_pose', 'source_pose', 'rot')))

    for bucket, value in ((ref_flats, flat), (ref_wavs, wav), (ref_embs, emb)):
        bucket.append(value.cpu())

# Concatenate per-batch pieces along the sample dimension.
ref_flats = torch.cat(ref_flats)
ref_wavs  = torch.cat(ref_wavs)
ref_embs  = torch.cat(ref_embs)

# ------------- Diversity of references (all metrics) -------------
def _offdiag_summary(dist):
    """Return (mean, median, mode, std) of the off-diagonal entries of a square matrix.

    The mode is taken over values rounded to 3 decimals; on ties the smallest
    value wins (np.unique sorts ascending and argmax picks the first maximum).
    """
    off = dist[~np.eye(dist.shape[0], dtype=bool)]
    rounded, counts = np.unique(np.round(off, 3), return_counts=True)
    return off.mean(), np.median(off), rounded[counts.argmax()], off.std()


for metric in ['MAG', 'LSD', 'SC']:
    D = compute_audio_distance(
        ref_flats.to(device), ref_wavs.to(device), metric=metric
    ).cpu().numpy()                           # [N_ref, N_ref]
    mean_d, median_d, mode_d, std_d = _offdiag_summary(D)
    print(f"[{metric}] mean={mean_d:.3f}, median={median_d:.3f}, "
          f"mode≈{mode_d:.3f}, std={std_d:.3f}")


# Recorded output of a previous run:
# [MAG] mean=2.404, median=1.884, mode≈1.146, std=1.967
# [LSD] mean=0.724, median=0.683, mode≈0.601, std=0.216
# [SC] mean=1.489, median=0.926, mode≈0.915, std=1.629

In [None]:
import torch
from torch.utils.data import DataLoader
from retriever import (
    FurnishedRoomSTFTDataset,
    compute_audio_distance,
    RIRRetrievalMLP,
)

# ─── configs ─────────────────────────────────────────────────────
root            = '../data/RAF/FurnishedRoom'
ckpt_path       = './outputs/20250806_194739/rir_retrieval_model.ckpt'
feats_map_p     = "./features.pt"   # NOTE(review): not read in this cell
grid_vec_p      = "./features.pt"   # Tensor[grid_feat_dim]: one global grid vector
use_global_grid = True              # only the global-grid mode is supported below
metric          = "SPL"             # ground‐truth audio metric
Ks              = [1, 2, 3]         # cut-offs for map@K and top‐K acc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not use_global_grid:
    raise NotImplementedError("Per-sample grid not implemented")

# Load model checkpoint and global grid (the configured path is now actually
# used; it was previously a placeholder while a literal path was loaded).
ckpt = torch.load(ckpt_path, map_location=device)
grid_vec = torch.load(grid_vec_p).to(device)  # used for all samples

# Reconstruct model from checkpoint config
model = RIRRetrievalMLP(**ckpt["model_config"]).to(device)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# ─── prepare data ─────────────────────────────────────────────────
def _full_batch(dataset):
    """Load a whole dataset as one batch, in dataset order."""
    return next(iter(DataLoader(dataset, batch_size=len(dataset), shuffle=False)))


q_ds = FurnishedRoomSTFTDataset(root, split="validation", return_wav=True, mode="reference")
g_ds = FurnishedRoomSTFTDataset(root, split="reference", return_wav=True, mode="reference")
batch_q = _full_batch(q_ds)
batch_g = _full_batch(g_ds)

# ─── grid inputs + forward pass to get embeddings ─────────────────
def _embed(batch, n_items):
    """Embed a full batch: global grid vector repeated per row, plus pose tensors."""
    grid = grid_vec.unsqueeze(0).expand(n_items, -1).to(device)
    poses = [batch[key].to(device) for key in ("mic_pose", "source_pose", "rot")]
    return model(grid, *poses)


with torch.no_grad():
    Zq = _embed(batch_q, len(q_ds))
    Zg = _embed(batch_g, len(g_ds))

# ─── prepare audio‐distance for ground truth ───────────────────────
n_q = len(q_ds)
flat_q = batch_q["stft"].to(device)
flat_g = batch_g["stft"].to(device)
wav_q  = batch_q["wav"].to(device)
wav_g  = batch_g["wav"].to(device)

# Distances over the stacked [queries; gallery] set; keep the query×gallery block.
D_full = compute_audio_distance(
    torch.cat((flat_q, flat_g), dim=0),
    torch.cat((wav_q, wav_g), dim=0),
    metric=metric,
)
D_qg = D_full[:n_q, n_q:]

# ground‐truth nearest: gallery index of each query's true NN and its distance
gt_idxs  = D_qg.argmin(dim=1)
gt_dists = D_qg.gather(1, gt_idxs.unsqueeze(1)).squeeze(1)

# ─── prediction via embedding similarity ──────────────────────────
# NOTE(review): a plain dot product; it is cosine similarity only if
# RIRRetrievalMLP L2-normalises its outputs — confirm.
S_pred     = Zq @ Zg.t()
pred_ranks = S_pred.argsort(dim=1, descending=True)  # [Q, G], best gallery index first

# 0-based position of each query's true nearest neighbour in its predicted
# ranking. Every row of pred_ranks is a permutation, so each row has exactly
# one match and nonzero() returns the matches in row order.
ranks = (pred_ranks == gt_idxs.unsqueeze(1)).nonzero()[:, 1].to(torch.int).cpu()

# Top-K accuracy (hit-rate / recall@K): fraction of queries whose true
# nearest neighbour is among the first K predictions.
topk_acc = {k: (ranks < k).float().mean().item() for k in Ks}

# MAP@K with a single relevant item per query: AP@K = 1/(rank+1) when that
# item is inside the top K, else 0 (so map@1 equals top-1 accuracy).
map_at_k = {
    k: ((ranks < k).float() / (ranks.float() + 1)).mean().item()
    for k in Ks
}
mean_rank = (ranks.float() + 1).mean().item()  # 1‐indexed

# mean audio‐metric gap between the GT neighbour and the top-1 prediction
pred1_idxs  = pred_ranks[:, 0]
pred1_dists = D_qg[torch.arange(len(q_ds)), pred1_idxs]
mean_diff   = (pred1_dists - gt_dists).abs().mean().item()

# ─── report ────────────────────────────────────────────────────────
print(", ".join(f"map@{k}: {map_at_k[k]:.3f}" for k in Ks))
print(f"mean_rank: {mean_rank:.3f}")
print("top‑K accuracies:", {k: f"{topk_acc[k]:.3f}" for k in Ks})
print(f"mean |Δ{metric}| (pred@1 vs GT): {mean_diff:.4f}")

In [None]:
import torch
from torch.utils.data import DataLoader
from retriever import FurnishedRoomSTFTDataset, RIRRetrievalMLP
import json

# ─── configs ─────────────────────────────────────────────────────
root            = '../data/RAF/FurnishedRoom'
ckpt_path       = './outputs/20250806_194739/rir_retrieval_model.ckpt'
# if you trained with a global grid vector, point this to your .pt file
grid_vec_p      = "./features.pt"
use_global_grid = True
topk            = 10  # number of gallery IDs retrieved per query

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── load model and global grid ──────────────────────────────────
ckpt = torch.load(ckpt_path, map_location=device)
model = RIRRetrievalMLP(**ckpt["model_config"]).to(device)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

if not use_global_grid:
    raise NotImplementedError("Per-sample grid not implemented")
grid_vec = torch.load(grid_vec_p).to(device)    # Tensor[grid_feat_dim]

# ─── embed the gallery once (identical for every split) ──────────
g_ds = FurnishedRoomSTFTDataset(root, split="reference",
                                return_wav=False, mode="reference")
g_batch = next(iter(DataLoader(g_ds, batch_size=len(g_ds), shuffle=False)))
with torch.no_grad():
    Zg = model(
        grid_vec.unsqueeze(0).expand(len(g_ds), -1).to(device),
        g_batch["mic_pose"].to(device),
        g_batch["source_pose"].to(device),
        g_batch["rot"].to(device),
    )

# ─── define splits to process ────────────────────────────────────
# evaluation corresponds to your "validation" split in the original code
splits = {
    "train":      "train",
    "evaluation": "validation",
    "test":       "test",
}

references = {}

for split_name, split_id in splits.items():
    # load the whole query split in one ordered batch
    q_ds = FurnishedRoomSTFTDataset(root, split=split_id,
                                    return_wav=False, mode="reference")
    q_batch = next(iter(DataLoader(q_ds, batch_size=len(q_ds), shuffle=False)))

    Gq = grid_vec.unsqueeze(0).expand(len(q_ds), -1)
    with torch.no_grad():
        Zq = model(
            Gq.to(device),
            q_batch["mic_pose"].to(device),
            q_batch["source_pose"].to(device),
            q_batch["rot"].to(device),
        )

    # embedding similarity (dot product) and top-k gallery indices per query
    S = Zq @ Zg.t()  # [Q, G]
    topk_idxs = S.argsort(dim=1, descending=True)[:, :topk]  # [Q, topk]

    # map each query ID to its top-k gallery IDs (ids are index-aligned with rows)
    references[split_name] = {
        qid: [g_ds.ids[j] for j in topk_idxs[i].tolist()]
        for i, qid in enumerate(q_ds.ids)
    }

# ─── save to JSON ─────────────────────────────────────────────────
with open("references.json", "w") as f:
    json.dump(references, f, indent=2)

print(f"Wrote references.json with top-{topk} retrievals for each split")


In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from retriever import FurnishedRoomSTFTDataset, compute_audio_distance
import json

# ─── configs ─────────────────────────────────────────────────────
root               = '../data/RAF/FurnishedRoom'
metric             = 'SPL'       # ground-truth audio distance
topk               = 5           # nearest gallery IDs kept per query
query_batch_size   = 2048        # tune so (batch_size + G)^2 fits GPU
device             = torch.device("cuda" if torch.cuda.is_available() else "cpu")

splits = {
    "train":      "train",        # 40 000 queries
    "evaluation": "validation",   #  3 000 queries
    "test":       "test",         #  3 000 queries
}

# ─── preload gallery ONCE ────────────────────────────────────────
g_ds     = FurnishedRoomSTFTDataset(root, split="reference",
                                    return_wav=True, mode="reference")
g_loader = DataLoader(g_ds, batch_size=len(g_ds), shuffle=False)
batch_g  = next(iter(g_loader))
G        = len(g_ds)
g_ids    = list(g_ds.ids)   # row k of the gallery tensors ↔ g_ids[k]

# flatten + move gallery to the compute device
flat_g = batch_g["stft"].view(G, -1).to(device)  # [G, D]
wav_g  = batch_g["wav"].to(device)               # [G, L]

references = {}

for split_name, split_id in splits.items():
    # ─── prepare query loader ────────────────────────────────────
    q_ds     = FurnishedRoomSTFTDataset(root, split=split_id,
                                        return_wav=True, mode="reference")
    q_loader = DataLoader(q_ds, batch_size=query_batch_size,
                          shuffle=False, drop_last=False)

    refs_for_split = {}
    start = 0   # offset of the current batch in q_ds (the loader is unshuffled)

    # ─── process each query batch on the compute device ─────────
    for batch_q in tqdm(q_loader, desc=f"{split_name} batches", total=len(q_loader)):
        bsize = batch_q["stft"].size(0)
        batch_ids = q_ds.ids[start:start + bsize]
        start += bsize
        if len(batch_ids) != bsize:
            raise RuntimeError(f"{split_name}: q_ds.ids is shorter than the loaded data")

        flat_q = batch_q["stft"].view(bsize, -1).to(device)  # [bsize, D]
        wav_q  = batch_q["wav"].to(device)                   # [bsize, L]

        # compute_audio_distance is called on one stacked set, so the full
        # (bsize + G)² matrix is computed and only the query×gallery block kept.
        with torch.no_grad():
            all_flats = torch.cat([flat_q, flat_g], dim=0)  # [bsize+G, D]
            all_wavs  = torch.cat([wav_q,  wav_g],  dim=0)  # [bsize+G, L]
            D_full    = compute_audio_distance(all_flats, all_wavs, metric=metric)
            D_qg = D_full[:bsize, bsize:].cpu()             # [bsize, G]

        # pick top-k smallest distances per query
        topk_idxs = D_qg.argsort(dim=1)[:, :topk]           # [bsize, topk]

        # map each query ID → list of top-k gallery IDs
        for i, qid in enumerate(batch_ids):
            refs_for_split[qid] = [g_ids[j] for j in topk_idxs[i].tolist()]

        # free device memory before the next batch
        del flat_q, wav_q, all_flats, all_wavs, D_full, D_qg
        if device.type == "cuda":
            torch.cuda.empty_cache()

    references[split_name] = refs_for_split

# ─── save to JSON ─────────────────────────────────────────────────
with open("gt_references.json", "w") as f:
    json.dump(references, f, indent=2)

print(f"Wrote gt_references.json with top-{topk} ground-truth retrievals (metric={metric})")

In [None]:
# --- CELL 1: Build data, compute metrics, cache them (NO PLOTS) ---

import os
import glob
import json
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
from retriever import compute_audio_distance, FurnishedRoomSTFTDataset
from torchaudio.transforms import GriffinLim
import sys
sys.path.append('../NeRAF')
from NeRAF_helper import compute_t60, evaluate_edt, evaluate_clarity

# ---------------- Params ----------------
eval_pattern = "../eval_results/furnishedroom_2/renders/eval_*.npy"
max_files = 200
root_dir = "../data/RAF/FurnishedRoom"
refs_file = "./references.json"
sample_rate = 48000
CACHE_PATH = "./records_cache_furnishedroom_2.pkl"
# Griffin-Lim starts from a random phase estimate, so the reconstructed
# waveforms (and every waveform-based metric cached below) change from run to
# run unless the RNG is seeded.
SEED = 0

# Griffin-Lim ISTFT setup
n_fft = (513 - 1) * 2   # 513 one-sided frequency bins -> n_fft = 1024
win_length = 512
hop_length = 256
power = 1               # the spectrogram passed in is a magnitude (power 1)
istft = GriffinLim(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=power)
torch.manual_seed(SEED)

# ─── Helpers ──────────────────────────────────────────────────────────────────
def _as_2d_numpy(wav):
    if isinstance(wav, torch.Tensor):
        arr = wav.detach().cpu().numpy()
    else:
        arr = np.asarray(wav)
    if arr.ndim == 1:
        arr = arr[None, :]
    return arr

def room_metric_diffs(wav_gt, wav_x, fs):
    """Compare room-acoustic parameters of a candidate IR against the ground truth.

    Both inputs are converted to 2-D arrays (a 1-D signal becomes one channel)
    and truncated to their common length before measuring.

    Args:
        wav_gt: ground-truth waveform (tensor or array-like).
        wav_x:  waveform to evaluate (prediction or retrieved reference).
        fs:     sample rate handed to the NeRAF helper functions.

    Returns:
        dict with
          'EDT': mean absolute EDT difference,
          'C50': mean absolute C50 difference,
          'T60': mean relative T60 error in percent. A channel counts as 100 %
                 when either T60 is below -0.5 (treated as an invalid
                 measurement) or its relative error is not finite.
    """
    gt = _as_2d_numpy(wav_gt)
    xx = _as_2d_numpy(wav_x)
    L = min(gt.shape[1], xx.shape[1])
    gt = gt[:, :L]
    xx = xx[:, :L]

    # T60: mean % error, invalids -> 100%
    t60_gt, t60_x = compute_t60(gt, xx, fs=fs, advanced=True)
    t60_gt = np.atleast_1d(t60_gt).astype(float)
    t60_x  = np.atleast_1d(t60_x).astype(float)
    with np.errstate(divide='ignore', invalid='ignore'):
        t60_diff = np.abs(t60_x - t60_gt) / (np.abs(t60_gt) + 1e-12)
    # NaN/inf errors are also treated as invalid; previously a single NaN made
    # the whole mean NaN because NaN < -0.5 is False.
    invalid_mask = (t60_gt < -0.5) | (t60_x < -0.5) | ~np.isfinite(t60_diff)
    t60_diff[invalid_mask] = 1.0
    t60_err_pct = float(np.mean(t60_diff) * 100.0)

    # EDT & C50: mean absolute differences.
    # NOTE(review): these helpers are called as (candidate, gt) while
    # compute_t60 takes (gt, candidate); only |difference| is used, so the
    # unpacking order does not affect the result — confirm the helper contract.
    edt_gt, edt_x = evaluate_edt(xx, gt, fs=fs)
    edt_mae = float(np.mean(np.abs(edt_x - edt_gt)))

    c50_gt, c50_x = evaluate_clarity(xx, gt, fs=fs)
    c50_mae = float(np.mean(np.abs(c50_x - c50_gt)))

    return {'EDT': edt_mae, 'C50': c50_mae, 'T60': t60_err_pct}

# ─── Data Gathering ───────────────────────────────────────────────────────────
with open(refs_file, 'r') as f:
    references = json.load(f)["test"]   # test id -> ranked list of reference ids


def _eval_file_index(path):
    """Numeric N from a render file named 'eval_N.npy' (sort key)."""
    return int(os.path.basename(path).split('_')[1].split('.')[0])


file_paths = sorted(glob.glob(eval_pattern), key=_eval_file_index)
if max_files is not None:
    file_paths = file_paths[:max_files]

# Use reference mode for both splits as you set earlier
ds_test = FurnishedRoomSTFTDataset(
    root_dir=root_dir, split="test", sample_rate=sample_rate, return_wav=True, mode="reference"
)
ds_ref = FurnishedRoomSTFTDataset(
    root_dir=root_dir, split="reference", sample_rate=sample_rate, return_wav=True, mode="reference"
)

records = []
for fp in file_paths:
    # NOTE: allow_pickle=True can execute code stored in the file; only load
    # render outputs you produced yourself.
    data = np.load(fp, allow_pickle=True).item()
    idx = int(data["audio_idx"])

    test_item = ds_test[idx]
    wav_gt = test_item['wav'].squeeze()
    stft_gt = test_item['stft'].squeeze(0)               # log-mag [F,T]

    stft_pred = torch.from_numpy(data["pred_stft"]).float().squeeze(0)  # log-mag
    # Back to linear magnitude. exp(x) - 1e-3 is negative wherever the
    # predicted log-magnitude is below log(1e-3); clamp so Griffin-Lim only
    # sees non-negative magnitudes.
    mag_pred = (torch.exp(stft_pred) - 1e-3).clamp_min(0.0)
    wav_pred = istft(mag_pred.unsqueeze(0)).squeeze(0)

    # align lengths for waveform-based metrics
    L = min(wav_gt.shape[0], wav_pred.shape[0])
    wav_gt, wav_pred = wav_gt[:L], wav_pred[:L]

    test_id = test_item['id']
    top3_ids = references[test_id][:3]

    ref_samples = []
    for rid in top3_ids:
        ref_item = ds_ref[ds_ref.id2idx[rid]]
        wav_r = ref_item['wav'].squeeze()
        stft_r = ref_item['stft'].squeeze(0)            # log-mag
        wav_r = wav_r[:min(wav_gt.shape[0], wav_r.shape[0])]
        ref_samples.append({'id': rid, 'wav': wav_r, 'stft': stft_r})

    records.append({
        'idx': idx, 'id': test_id,
        'wav_gt': wav_gt, 'stft_gt': stft_gt,
        'wav_pred': wav_pred, 'stft_pred': stft_pred,
        'refs': ref_samples
    })

# ─── Metric Calculation ───────────────────────────────────────────────────────
def _pair_audio_metrics(stft_gt, stft_x, wav_gt, wav_x):
    """All metrics between the ground truth and one candidate (prediction or reference).

    The waveforms are truncated to their common length; the STFTs must already
    share a shape so they can be stacked. Returns a dict with MSE, SPL, MAG and
    MAG2 (each the off-diagonal entry of compute_audio_distance's 2x2 result),
    followed by the EDT / C50 / T60 entries from room_metric_diffs.
    """
    L = min(wav_gt.shape[0], wav_x.shape[0])
    pair_stft = torch.stack([stft_gt, stft_x], dim=0)        # log-mag
    pair_wav  = torch.stack([wav_gt[:L], wav_x[:L]], dim=0)
    metrics = {
        'MSE' : compute_audio_distance(pair_stft, wavs=pair_wav, metric='MSE')[0,1].item(),
        'SPL' : compute_audio_distance(pair_stft, wavs=pair_wav, metric='SPL', fs=sample_rate)[0,1].item(),
        'MAG' : compute_audio_distance(pair_stft, metric='MAG')[0,1].item(),
        'MAG2': compute_audio_distance(pair_stft, metric='MAG2')[0,1].item(),
    }
    metrics.update(room_metric_diffs(wav_gt, wav_x, fs=sample_rate))
    return metrics


for rec in records:
    # prediction vs GT
    rec['metrics'] = _pair_audio_metrics(
        rec['stft_gt'], rec['stft_pred'], rec['wav_gt'], rec['wav_pred']
    )
    # top-3 reference metrics (each vs GT), tagged with the reference id
    rec['ref_metrics'] = [
        {'id': ref['id'],
         **_pair_audio_metrics(rec['stft_gt'], ref['stft'], rec['wav_gt'], ref['wav'])}
        for ref in rec['refs']
    ]

# ─── Save cache ──────────────────────────────────────────────────────────────
# Everything the plotting cell needs, plus the STFT settings used to build it.
cache_payload = {
    "records": records,
    "params": {
        "sample_rate": sample_rate,
        "n_fft": n_fft,
        "win_length": win_length,
        "hop_length": hop_length,
        "power": power,
    },
}
with open(CACHE_PATH, "wb") as fh:
    pickle.dump(cache_payload, fh)

metric_names = list(records[0]['metrics'].keys()) if records else '[]'
print(f"Cached {len(records)} records to {CACHE_PATH}. "
      f"Available metrics per sample: {metric_names}")

In [None]:
# --- CELL 2: Choose metric & N, then plot from cache (FAST) ---

import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt

CACHE_PATH = "./records_cache_furnishedroom_2.pkl"   # written by the caching cell

# --- User-tunable ---
# Metric that ranks the samples (largest value = worst); one of
# "SPL", "MSE", "MAG", "MAG2", "T60", "C50", "EDT".
SORT_METRIC = "MSE"
# How many of the worst-ranked samples to plot.
TOP_N = 5

# --- Small helper for titles ---
def three_line_block(id_str, m):
    """Build a three-line caption for one sample.

    Line 1 is `id_str`; line 2 lists the SPL, MSE, MAG and MAG2 entries of the
    metric dict `m`; line 3 lists its T60 percentage error and C50 / EDT
    differences. Fields within a line are separated by three spaces.
    """
    spectral = "   ".join([
        f"SPL: {m['SPL']:.4f}",
        f"MSE: {m['MSE']:.2e}",
        f"MAG: {m['MAG']:.2f}",
        f"MAG2: {m['MAG2']:.2e}",
    ])
    acoustic = "   ".join([
        f"T60%: {m['T60']:.1f}%",
        f"C50Δ: {m['C50']:.2f} dB",
        f"EDTΔ: {m['EDT']:.3f} s",
    ])
    return "\n".join([f"{id_str}", spectral, acoustic])

# --- Load cache ---
# NOTE: pickle.load runs arbitrary code; only open caches produced locally.
with open(CACHE_PATH, "rb") as fh:
    cache = pickle.load(fh)
records = cache["records"]

if not records:
    raise RuntimeError("No records in cache. Run Cell 1 first.")

first_metrics = records[0]["metrics"]
if SORT_METRIC not in first_metrics:
    raise ValueError(f"SORT_METRIC '{SORT_METRIC}' not in metrics. "
                     f"Choose from {list(first_metrics.keys())}")

# ─── Sort & Select Worst by chosen metric ─────────────────────────────────────
# Descending sort: the largest (worst) values come first.
records_sorted = sorted(records, key=lambda rec: rec['metrics'][SORT_METRIC], reverse=True)
worst = records_sorted[:TOP_N]

# ─── Plotting ─────────────────────────────────────────────────────────────────
def _plot_waveforms(rec):
    """Waveform canvas: GT and prediction on top, up to three references below."""
    fig, axes = plt.subplots(2, 3, figsize=(18, 8))
    for ax, wav, title in ((axes[0, 0], rec['wav_gt'], 'Ground Truth'),
                           (axes[0, 1], rec['wav_pred'], 'Prediction')):
        ax.plot(wav.numpy())
        ax.set_title(title, fontsize=11, pad=8)
        ax.set_xlabel('Sample'); ax.set_ylabel('Amplitude')
    axes[0, 2].axis('off')  # spacer

    ref_metrics = rec['ref_metrics'][:3]
    for i, rm in enumerate(ref_metrics):
        ax = axes[1, i]
        ax.plot(rec['refs'][i]['wav'].numpy())
        ax.set_xlabel('Sample'); ax.set_ylabel('Amplitude')
        ax.set_title(three_line_block(f"Ref {rm['id']}", rm), fontsize=10, pad=10)
    for ax in axes[1, len(ref_metrics):]:
        ax.axis('off')  # hide unused slots when fewer than three references exist

    fig.suptitle(
        three_line_block(f"{rec['id']} (sorted by {SORT_METRIC})", rec['metrics']),
        fontsize=12, y=0.90
    )
    # tight_layout recomputes all subplot parameters (margins, hspace, wspace),
    # so it is the only layout call needed.
    fig.tight_layout(rect=[0, 0.02, 1, 0.88])
    plt.show()
    plt.close(fig)  # this loop creates two figures per record


def _plot_spectrograms(rec):
    """STFT canvas with the same layout, drawn as images with colorbars."""
    fig, axes = plt.subplots(2, 3, figsize=(18, 9))
    for ax, stft, title in ((axes[0, 0], rec['stft_gt'], 'GT STFT'),
                            (axes[0, 1], rec['stft_pred'], 'Pred STFT')):
        im = ax.imshow(stft.numpy(), aspect='auto', origin='lower')
        ax.set_title(title, fontsize=11, pad=8)
        ax.set_xlabel('Time frame'); ax.set_ylabel('Frequency bin')
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    axes[0, 2].axis('off')  # spacer

    ref_metrics = rec['ref_metrics'][:3]
    for i, rm in enumerate(ref_metrics):
        ax = axes[1, i]
        im = ax.imshow(rec['refs'][i]['stft'].numpy(), aspect='auto', origin='lower')
        ax.set_title(three_line_block(f"Ref {rm['id']}", rm), fontsize=10, pad=10)
        ax.set_xlabel('Time frame'); ax.set_ylabel('Frequency bin')
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    for ax in axes[1, len(ref_metrics):]:
        ax.axis('off')

    fig.suptitle(
        three_line_block(f"{rec['id']} (sorted by {SORT_METRIC})", rec['metrics']),
        fontsize=12, y=0.98
    )
    fig.tight_layout(rect=[0, 0.02, 1, 0.87])
    plt.show()
    plt.close(fig)


for rec in worst:
    _plot_waveforms(rec)
    _plot_spectrograms(rec)


In [None]:
import torch
import os
from retriever import FurnishedRoomSTFTDataset
from NeRAF_helper import measure_rt60_advance, measure_edt, measure_clarity

def compute_edc_db(wav: torch.Tensor, T_target: int = 60) -> torch.Tensor:
    """Schroeder energy decay curve in dB, sampled at `T_target` points.

    The backward-integrated energy (energy from each sample to the end) is
    normalised by its first value, converted to dB with a 1e-12 floor, and
    then read at `T_target` evenly spaced indices (truncated to integers)
    spanning the whole signal.

    Args:
        wav: 1-D (mono) waveform tensor.
        T_target: number of output points.

    Returns:
        1-D float tensor of length `T_target`.
    """
    energy = wav.float().pow(2)
    remaining = energy.flip(0).cumsum(0).flip(0)
    remaining = remaining / (remaining[0] + 1e-12)
    decay_db = torch.log10(remaining + 1e-12) * 10.0
    picks = torch.linspace(0, decay_db.numel() - 1, T_target).long()
    return decay_db[picks]

def compute_dr(wav: torch.Tensor, fs: int = 48000, direct_ms: float = 5.0) -> float:
    """Direct-to-reverberant energy ratio (dB) of a single mono IR.

    The direct-path arrival is taken to be the sample with the largest
    absolute amplitude. "Direct" energy is summed over a window of about
    `direct_ms` milliseconds *centred* on that peak (half before, half after,
    clipped to the signal). "Reverberant" energy is everything after the
    window; samples before the window count toward neither term.

    Returns:
        10*log10((E_direct + 1e-12) / (E_reverb + 1e-12)) as a Python float.
    """
    peak = int(wav.abs().argmax())
    half = int(direct_ms * fs / 1000.0) // 2
    lo = max(peak - half, 0)
    hi = min(peak + half, wav.numel())

    e_direct = wav[lo:hi].pow(2).sum()
    e_reverb = wav[hi:].pow(2).sum()
    return (10.0 * torch.log10((e_direct + 1e-12) / (e_reverb + 1e-12))).item()

# ─── Load dataset with all samples ────────────────────────────────
# NOTE(review): this cell reads EmptyRoom data through FurnishedRoomSTFTDataset
# and imports NeRAF_helper without adding '../NeRAF' to sys.path itself.
root_dir = "../data/RAF/EmptyRoom"
ds = FurnishedRoomSTFTDataset(root_dir=root_dir, split='all', sample_rate=48000, return_wav=True)

# ─── Per-IR features: EDC curve + [T60, C50, EDT, DR] ─────────────
features_dict = {}
for idx in range(len(ds)):
    item = ds[idx]
    wav = item['wav'].float()   # [T] mono
    wav_np = wav.numpy()

    edc_curve = compute_edc_db(wav, T_target=60)
    t60 = measure_rt60_advance(wav_np, sr=48000)       # seconds
    edt = measure_edt(wav_np, fs=48000)                # seconds
    c50 = measure_clarity(wav_np, time=50, fs=48000)   # dB
    dr  = compute_dr(wav, fs=48000)                    # dB

    features_dict[item['id']] = {
        'edc': edc_curve,
        'decay_feats': torch.tensor([t60, c50, edt, dr], dtype=torch.float32),
    }

# ─── Save to one file ─────────────────────────────────────────────
save_path = os.path.join(root_dir, "edc_decay_features.pt")
torch.save(features_dict, save_path)
print(f"Saved EDC + decay features (T60, C50, EDT, DR) for {len(features_dict)} samples to {save_path}")

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

from retriever import FurnishedRoomSTFTDataset, compute_audio_distance

import sys
sys.path.append('../NeRAF')
# pairwise “helper diffs”
from NeRAF_helper import compute_t60, evaluate_edt, evaluate_clarity
# single-sample features (fallback if dataset misses precomputed)
from NeRAF_helper import measure_rt60_advance, measure_edt, measure_clarity

# ─── Config ────────────────────────────────────────────────────────
root_dir = "../data/RAF/FurnishedRoom"   # room whose reference split is inspected
sample_rate = 48000                      # Hz; used by the dataset and every metric below
i = 0                                    # dataset index of sample A (change as you like)
j = 1                                    # dataset index of sample B (change as you like)
EDC_T = 60                               # EDC length in frames when a curve has to be computed here

# ─── Tiny helpers ──────────────────────────────────────────────────
def compute_edc_db(wav_1d: torch.Tensor, T_target: int = 60) -> torch.Tensor:
    """
    Energy decay curve (backward-integrated energy) in dB, sampled at T_target points.

    The curve is normalised by its first value, so it starts at ~0 dB; 1e-12
    guards the division and the log (silent tail floors at about -120 dB).

    Args:
        wav_1d: impulse response. Any shape is accepted and flattened, so both
            [T] and [1, T] work.
        T_target: number of output points, taken at evenly spaced indices.

    Returns:
        1-D float tensor of length ``T_target``.

    Raises:
        ValueError: if the waveform is empty.
    """
    x = wav_1d.reshape(-1).float()
    if x.numel() == 0:
        raise ValueError("compute_edc_db: empty waveform")
    energy = x ** 2
    # Reverse cumulative sum: edc[t] = sum of energy from t to the end
    edc = torch.flip(torch.cumsum(torch.flip(energy, dims=[0]), dim=0), dims=[0])
    edc = edc / (edc[0] + 1e-12)
    edc_db = 10.0 * torch.log10(edc + 1e-12)
    idx = torch.linspace(0, edc_db.numel() - 1, T_target, device=edc_db.device).long()
    return edc_db[idx]

def get_or_compute_features(item, fs, edc_T=60):
    """
    Return ``(wav, stft, edc, decay)`` for one dataset item, computing missing pieces.

    Args:
        item: dataset sample dict with at least 'wav' and 'stft'. If present,
            'edc' and 'decay_feats' are returned as stored (the EDC is NOT
            resampled to ``edc_T``).
        fs: sample rate in Hz, used only for the fallback decay measurements.
        edc_T: EDC length used only when the EDC has to be computed here.

    Returns:
        wav: ``item['wav']`` squeezed.
        stft: ``item['stft']`` with a leading singleton dim removed (log-magnitude).
        edc: stored EDC, or one computed from ``wav`` with ``edc_T`` points (dB).
        decay: stored decay features with whatever width the dataset saved
            (e.g. [T60, C50, EDT] or [T60, C50, EDT, DR]), or a freshly
            measured float32 tensor [T60, C50, EDT].
    """
    wav = item['wav'].squeeze()
    stft = item['stft'].squeeze(0)

    edc = item.get('edc', None)
    if edc is None:
        edc = compute_edc_db(wav, T_target=edc_T)

    decay = item.get('decay_feats', None)
    if decay is None:
        # Single-sample measurements; convert to NumPy once for all three.
        wav_np = wav.detach().cpu().numpy()
        t60 = float(measure_rt60_advance(wav_np, sr=fs))
        edt = float(measure_edt(wav_np, fs=fs))
        c50 = float(measure_clarity(wav_np, time=50, fs=fs))
        decay = torch.tensor([t60, c50, edt], dtype=torch.float32)
    return wav, stft, edc, decay

def helper_decay_diffs(wav_a: torch.Tensor, wav_b: torch.Tensor, fs: int):
    """
    Pairwise decay-metric differences between two IRs using the NeRAF helpers.

    ``wav_a`` plays the reference role and ``wav_b`` the estimate. Both are
    cropped to their common length and passed as single-row NumPy batches.

    Returns:
        dict with
          'T60_h': mean relative T60 error in percent (values the helper
                   reports below -0.5 are treated as invalid and count as 100 %),
          'C50_h': mean absolute C50 difference (dB),
          'EDT_h': mean absolute EDT difference (s).
    """
    n = min(wav_a.numel(), wav_b.numel())

    def _as_batch(w):
        return w[:n].detach().cpu().numpy()[None, :]

    ref = _as_batch(wav_a)
    est = _as_batch(wav_b)

    # T60 — relative error, invalid estimates pinned to 100 %
    t60_ref, t60_est = (np.atleast_1d(v).astype(float)
                        for v in compute_t60(ref, est, fs=fs, advanced=True))
    with np.errstate(divide='ignore', invalid='ignore'):
        rel_err = np.abs(t60_est - t60_ref) / (np.abs(t60_ref) + 1e-12)
    rel_err[(t60_ref < -0.5) | (t60_est < -0.5)] = 1.0
    t60_pct = float(np.mean(rel_err) * 100.0)

    # EDT and C50 — mean absolute differences (symmetric in the returned pair)
    edt_ref, edt_est = evaluate_edt(est, ref, fs=fs)
    edt_mae = float(np.mean(np.abs(edt_est - edt_ref)))

    c50_ref, c50_est = evaluate_clarity(est, ref, fs=fs)
    c50_mae = float(np.mean(np.abs(c50_est - c50_ref)))

    return {'T60_h': t60_pct, 'C50_h': c50_mae, 'EDT_h': edt_mae}

# ─── Dataset & per-item features ───────────────────────────────────
ds = FurnishedRoomSTFTDataset(
    root_dir=root_dir,
    split="reference",
    sample_rate=sample_rate,
    return_wav=True,
    mode="reference",
)
a, b = ds[i], ds[j]

feats_a = get_or_compute_features(a, sample_rate, EDC_T)
feats_b = get_or_compute_features(b, sample_rate, EDC_T)
wav_a, stft_a, edc_a, decay_a = feats_a
wav_b, stft_b, edc_b, decay_b = feats_b

# ─── Pairwise helper differences (A = reference role, B = estimate) ─
helper = helper_decay_diffs(wav_a, wav_b, sample_rate)

# ─── compute_audio_distance on a 2-item batch; entry [0, 1] is A↔B ─
pair_stft = torch.stack((stft_a, stft_b))             # [2, F, T] — carries shape/device for the API
pair_decay = torch.stack((decay_a, decay_b)).float()  # [2, D] stored decay columns (T60, C50, EDT[, DR])
pair_edc = torch.stack((edc_a, edc_b)).float()        # [2, T_edc]


def _pair_distance(metric, **extra):
    """A↔B distance under ``metric`` from the 2x2 distance matrix."""
    return compute_audio_distance(pair_stft, metric=metric, **extra)[0, 1].item()


t60_cd = _pair_distance('T60PCT', decay_feats=pair_decay)
c50_cd = _pair_distance('C50', decay_feats=pair_decay)
edt_cd = _pair_distance('EDT', decay_feats=pair_decay)
edc_d = _pair_distance('EDC', edc_curves=pair_edc)

# ─── Report ────────────────────────────────────────────────────────
print("\n".join([
    f"A = {a['id']}  |  B = {b['id']}",
    "Helper (diffs):",
    f"  T60%: {helper['T60_h']:.1f}%   C50Δ: {helper['C50_h']:.3f} dB   EDTΔ: {helper['EDT_h']:.3f} s",
    "compute_audio_distance (unitless distances):",
    f"  T60_cd: {t60_cd:.6f}   C50_cd: {c50_cd:.6f}   EDT_cd: {edt_cd:.6f}   EDC_D: {edc_d:.6f}",
]))

# ─── Quick plots: waveform (left) + EDC curves (right) ─────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_wav, ax_edc = axes

for _tag, _item, _wav, _edc in (("A", a, wav_a, edc_a), ("B", b, wav_b, edc_b)):
    _label = f"{_tag} {_item['id']}"
    ax_wav.plot(_wav.cpu().numpy(), label=_label, alpha=0.8)
    ax_edc.plot(_edc.cpu().numpy(), label=_label, alpha=0.9)

ax_wav.set(title="Waveforms", xlabel="Sample", ylabel="Amplitude")
ax_wav.legend(loc="upper right", fontsize=9)

ax_edc.set(title=f"EDC (downsampled to {EDC_T} frames)", xlabel="Frame", ylabel="EDC (dB)")
ax_edc.legend(loc="upper right", fontsize=9)

plt.tight_layout()
plt.show()

In [None]:
# --- CELL 1 (GPU + BATCHED + EDC): Build data, retrieve top-3 refs, compute metrics (incl. EDC), cache ---
import os, glob, pickle
import numpy as np
import torch
from typing import Sequence
from torch.utils.data import DataLoader
from torchaudio.transforms import GriffinLim
from tqdm.auto import tqdm
from evaluator import compute_audio_distance, compute_edc_db
from retriever import FurnishedRoomSTFTDataset, RIRRetrievalMLP 
  # :contentReference[oaicite:0]{index=0}
import sys
sys.path.append('../NeRAF')
from NeRAF_helper import compute_t60, evaluate_edt, evaluate_clarity

# ---------------- Params ----------------
EVAL_PATTERN         = "../eval_results/emptyroom/emptyroom/renders/eval_*.npy"
MAX_EVAL_FILES       = 20            # None → use every matching eval file
ROOT_DIR             = "../data/RAF/EmptyRoom"
CKPT_PATH            = './outputs/20250906_184700/rir_retrieval_model.ckpt'   # EMBEDDING backend only; alt run: 20250812_204815
GRID_VEC_PATH        = "./features.pt"  # global grid vector (use_global_grid is always True); EMBEDDING backend only
SAMPLE_RATE          = 48000
CACHE_PATH           = "./records_cache_autorefs.pkl"

RETRIEVAL_BACKEND    = "METRIC"      # "METRIC" (default) or "EMBEDDING"
RETRIEVAL_METRIC     = "MIXED"       # METRIC backend: "MIXED", "FEATS", or one compute_audio_distance metric (e.g. "SPL", "EDC")
MIXED_WEIGHTS        = ['EDC', 0.6, 'SPL', 0.4]   # flat [metric, weight, ...]; used when RETRIEVAL_METRIC == "MIXED"
TOPK                 = 3
CHUNK_SIZE_G         = 2048          # gallery items moved to the device per chunk
# ---------------- "FEATS" (feature-vector) retrieval knobs ----------------
FEATS_INCLUDE_EDC    = True      # include the per-sample-normalized EDC curve
FEATS_USE_DECAYS     = ['T60','C50','EDT','DR']   # any subset; columns absent from the stored features are skipped
FEATS_CHUNK_SIZE_G   = 2048      # gallery items per chunk
FEATS_EPS            = 1e-6      # floor for the per-dimension std in the z-score
FEATS_EQUALIZE_GROUP_ENERGY = False   # if True, rescale the EDC / decay groups after z-scoring (weights below)
FEATS_EDC_WEIGHT  = 0.4   # EDC-group scale; None → auto sqrt(d_decay / d_edc). Only used if equalizing.
FEATS_DECAY_WEIGHT = 0.6  # decay-group scale. Only used if equalizing.

# ---------------- Local EDC normalizer (same logic as in retriever.py) ----
def _normalize_edc_local(edc_db: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    edc_db: [B, T_edc] in dB. Make 0 dB at t=0, then z-score over time (per sample).
    """
    edc0 = edc_db[:, :1]                   # [B,1]
    edc_rel = edc_db - edc0
    std = edc_rel.std(dim=1, keepdim=True).clamp_min(eps)
    return edc_rel / std                   # [B, T_edc]

def _select_decay_columns(decay_blk: torch.Tensor, names: Sequence[str]) -> torch.Tensor:
    """
    decay_blk: [B, D] where D can be 3 or 4 depending on whether DR was stored.
    names: subset of ['T60','C50','EDT','DR'] in any order.
    Maps to columns: T60->0, C50->1, EDT->2, DR->3 (if present).
    """
    colmap = {'T60':0, 'C50':1, 'EDT':2, 'DR':3}
    idxs = [colmap[n] for n in names if (n in colmap)]
    idxs = [i for i in idxs if i < decay_blk.shape[1]]
    return decay_blk[:, idxs] if len(idxs) else None

# ---------------- Build per-chunk feature block (query + gallery) ---------
def _fit_edc_length_feats(curve: torch.Tensor, T: int) -> torch.Tensor:
    """Trim or right-zero-pad a 1-D EDC curve to exactly ``T`` frames."""
    n = curve.shape[0]
    if n > T:
        return curve[:T]
    if n < T:
        return torch.nn.functional.pad(curve, (0, T - n))
    return curve


def _query_edc_for_feats(q_item, T_edc: int) -> torch.Tensor:
    """Query EDC on ``device``: the stored curve fitted to ``T_edc`` frames, or one computed from 'wav'."""
    q_edc = q_item.get('edc')
    if (q_edc is None) or (q_edc.numel() == 0):
        wav_q = q_item['wav'].to(device)
        return compute_edc_db(wav_q.float(), T_target=T_edc).to(device)
    return _fit_edc_length_feats(q_edc.to(device), T_edc)


def _query_decay_for_feats(q_item, n_gallery_cols: int) -> torch.Tensor:
    """Query decay vector on ``device`` in the gallery column layout [T60, C50, EDT(, DR)]."""
    q_decay = q_item.get('decay_feats')
    if q_decay is None:
        # No stored features: measure T60 / C50 / EDT with the pairwise helpers (query vs. itself).
        wav_q_np = _as_2d_numpy_cpu(q_item['wav'])
        t60_q, _ = compute_t60(wav_q_np, wav_q_np, fs=SAMPLE_RATE, advanced=True)
        edt_q, _ = evaluate_edt(wav_q_np, wav_q_np, fs=SAMPLE_RATE)
        c50_q, _ = evaluate_clarity(wav_q_np, wav_q_np, fs=SAMPLE_RATE)
        q_decay_np = np.array([
            np.atleast_1d(t60_q).astype(float)[0],
            np.atleast_1d(c50_q).astype(float)[0],
            np.atleast_1d(edt_q).astype(float)[0]
        ], dtype=np.float32)
        q_decay = torch.from_numpy(q_decay_np)
    q_decay = q_decay.to(device)

    if n_gallery_cols == 4 and q_decay.numel() == 3:
        # Gallery carries DR but the query does not: compute it from the query waveform.
        from retriever import compute_dr
        q_dr = compute_dr(q_item['wav'].detach().cpu().numpy(), fs=SAMPLE_RATE)
        q_decay = torch.cat([q_decay, torch.tensor([q_dr], device=device, dtype=q_decay.dtype)], dim=0)
    elif q_decay.numel() > n_gallery_cols:
        # Query stores more columns than the gallery: keep the shared leading columns.
        q_decay = q_decay[:n_gallery_cols]
    return q_decay


def _build_feature_blocks_for_chunk(q_item, g_edc_chunk, g_decay_chunk, T_edc_gallery, 
                                    include_edc=True, decay_names=('T60','C50','EDT','DR')):
    """
    Build the z-scored feature matrix for one gallery chunk, query row first.

    Args:
        q_item: query dataset item; uses 'wav' and, when present, 'edc' / 'decay_feats'.
        g_edc_chunk: [G_chunk, T_edc] gallery EDCs on ``device``, or None.
        g_decay_chunk: [G_chunk, D] gallery decay features on ``device`` (D = 3 or 4), or None.
        T_edc_gallery: gallery EDC length; the query EDC is fitted to it.
        include_edc: include the EDC group. This only takes effect when
            ``g_edc_chunk`` is given; without gallery EDCs a query-only EDC
            row could not be stacked with the gallery rows.
        decay_names: decay columns to keep, any subset of ('T60','C50','EDT','DR').

    Returns:
        Xz: [1 + G_chunk, D_total]. EDC rows are normalised per sample
        (``_normalize_edc_local``). Then every column is z-scored over this
        block (NaN-aware). Optionally, the EDC / decay groups are rescaled
        (FEATS_EQUALIZE_GROUP_ENERGY).

    Raises:
        ValueError: if neither an EDC nor a decay block could be built.
    """
    # --- EDC block (needs gallery EDCs to have one row per gallery item) ---
    edc_blk = None
    if include_edc and g_edc_chunk is not None:
        q_edc = _query_edc_for_feats(q_item, T_edc_gallery)
        edc_blk = torch.cat([q_edc.unsqueeze(0), g_edc_chunk], dim=0)
        edc_blk = _normalize_edc_local(edc_blk)   # [B, T_edc]

    # --- Decay features block (optional subset) ---
    decay_blk = None
    if g_decay_chunk is not None and (len(decay_names) > 0):
        q_decay = _query_decay_for_feats(q_item, g_decay_chunk.shape[1])
        decay_blk = torch.cat([q_decay.unsqueeze(0).to(g_decay_chunk.dtype), g_decay_chunk], dim=0)  # [B, D]
        decay_blk = _select_decay_columns(decay_blk, list(decay_names))  # [B, D_sel] or None

    # --- Concatenate parts along feature dim ---
    pieces = [blk for blk in (edc_blk, decay_blk) if blk is not None]
    if not pieces:
        raise ValueError("No features selected: both EDC and decay feature set are empty.")
    d_edc = edc_blk.shape[1] if edc_blk is not None else 0
    d_decay = decay_blk.shape[1] if decay_blk is not None else 0
    X = torch.cat(pieces, dim=1)  # [B, D_total]

    # --- z-normalize each column over this block (query + gallery chunk) ---
    mu = _nanmean(X, dim=0, keepdim=True)
    sd = _nanstd(X, dim=0, keepdim=True, eps=FEATS_EPS)
    Xz = (X - mu) / sd  # [B, D_total]

    # --- group energy equalization (post z-score); EDC columns come first ---
    if FEATS_EQUALIZE_GROUP_ENERGY and (d_edc > 0) and (d_decay > 0):
        alpha = (d_decay / max(d_edc, 1)) ** 0.5 if (FEATS_EDC_WEIGHT is None) else float(FEATS_EDC_WEIGHT)
        Xz[:, :d_edc] *= alpha
        Xz[:, d_edc:d_edc + d_decay] *= float(FEATS_DECAY_WEIGHT)

    return Xz

# -------------- New batched cosine distance vectorizer --------------------
def _featurevec_vector_batched(q_item, g_edc, g_decay, T_edc_gallery,
                               include_edc=True, decay_names=('T60','C50','EDT','DR'),
                               chunk_size=2048):
    """
    Cosine distance between the query's feature vector and every gallery item.

    The gallery is processed in chunks of ``chunk_size``. Each chunk is moved
    to ``device`` and stacked with the query into a feature block
    (``_build_feature_blocks_for_chunk``). Distances are then
    ``1 - cos(query_row, gallery_row)``.
    NOTE: the z-score inside the block builder uses per-chunk statistics, so
    with more than ``chunk_size`` gallery items, distances from different
    chunks are normalised against different statistics.

    Args:
        g_edc: [G, T_edc] gallery EDC curves (CPU), or None.
        g_decay: [G, D] gallery decay features (CPU), or None.

    Returns:
        CPU tensor [G] of cosine distances (empty if the gallery is empty).

    Raises:
        ValueError: if both ``g_edc`` and ``g_decay`` are None.
    """
    if g_edc is None and g_decay is None:
        raise ValueError("FEATS retrieval needs gallery EDC curves and/or decay features.")
    G = g_edc.shape[0] if g_edc is not None else g_decay.shape[0]
    if G == 0:
        return torch.empty(0)

    out_parts = []
    for s in tqdm(range(0, G, chunk_size), desc="FEATS chunks", unit="chunk", leave=False):
        e = min(s + chunk_size, G)
        edc_c = g_edc[s:e].to(device, non_blocking=True) if g_edc is not None else None
        dec_c = g_decay[s:e].to(device, non_blocking=True) if g_decay is not None else None

        # [1 + chunk, D] feature matrix, query first
        Xz = _build_feature_blocks_for_chunk(q_item, edc_c, dec_c, T_edc_gallery,
                                             include_edc=include_edc, decay_names=decay_names)
        Xn = torch.nn.functional.normalize(Xz, p=2, dim=1)   # unit rows
        qv = Xn[0:1]                                         # [1, D]
        sims = (qv @ Xn[1:].T).squeeze(0)                    # [chunk]
        out_parts.append((1.0 - sims).detach().cpu())

        del edc_c, dec_c, Xz, Xn, qv, sims
        if device.type == "cuda":
            torch.cuda.empty_cache()
    return torch.cat(out_parts, dim=0)  # [G] on CPU

# ---------------- GPU setup ----------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_amp = device.type == "cuda"          # autocast only on GPU
torch.backends.cudnn.benchmark = True

# Griffin-Lim inversion of magnitude STFTs with 513 frequency bins (n_fft = 1024).
n_fft = 2 * (513 - 1)
win_length = 512
hop_length = 256
power = 1
istft = GriffinLim(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    power=power,
).to(device)

# ---------------- Helpers ----------------
def _as_2d_numpy_cpu(wav_t):
    arr = wav_t.detach().cpu().numpy()
    return arr[None, :] if arr.ndim == 1 else arr

def room_metric_diffs_gpu(wav_gt_t, wav_x_t, fs):
    """
    Room-acoustic errors of ``wav_x_t`` against ``wav_gt_t`` via the NeRAF helpers.

    Both waveforms are cropped to their common length and converted to
    single-row NumPy batches, so the metrics themselves run on the CPU.

    Returns:
        {'EDT': mean |ΔEDT| (s), 'C50': mean |ΔC50| (dB),
         'T60': mean relative T60 error in % (values below -0.5 count as 100 %)}
    """
    n = min(wav_gt_t.shape[0], wav_x_t.shape[0])
    gt = _as_2d_numpy_cpu(wav_gt_t[:n])
    xx = _as_2d_numpy_cpu(wav_x_t[:n])

    t60_pair = compute_t60(gt, xx, fs=fs, advanced=True)
    t60_gt, t60_x = (np.atleast_1d(v).astype(float) for v in t60_pair)
    with np.errstate(divide='ignore', invalid='ignore'):
        rel = np.abs(t60_x - t60_gt) / (np.abs(t60_gt) + 1e-12)
    rel[(t60_gt < -0.5) | (t60_x < -0.5)] = 1.0

    edt_gt, edt_x = evaluate_edt(xx, gt, fs=fs)
    c50_gt, c50_x = evaluate_clarity(xx, gt, fs=fs)

    return {
        'EDT': float(np.mean(np.abs(edt_x - edt_gt))),
        'C50': float(np.mean(np.abs(c50_x - c50_gt))),
        'T60': float(np.mean(rel) * 100.0),
    }

def _pair_metrics_gpu_with_edc(stft_a, wav_a, stft_b, wav_b, edc_a, edc_b):
    """
    Score item b against reference item a.

    Returns a dict with these entries:
      * 'MSE', 'SPL', 'MAG', 'MAG2': spectral / waveform distances.
      * 'EDC': the EDC distance.
      * 'EDT', 'C50', 'T60': room-acoustic errors from ``room_metric_diffs_gpu``.
    The distance entries are each entry [0, 1] of ``compute_audio_distance``
    on the 2-item batch. Waveforms are cropped to their common length, and
    ``edc_b`` is trimmed or zero-padded to the length of ``edc_a``.
    """
    n = min(wav_a.shape[0], wav_b.shape[0])
    wav_a_n, wav_b_n = wav_a[:n], wav_b[:n]
    pair_stft = torch.stack([stft_a, stft_b], dim=0)
    pair_wav = torch.stack([wav_a_n, wav_b_n], dim=0)

    t_edc = edc_a.shape[0]
    if edc_b.shape[0] > t_edc:
        edc_b = edc_b[:t_edc]
    elif edc_b.shape[0] < t_edc:
        edc_b = torch.nn.functional.pad(edc_b, (0, t_edc - edc_b.shape[0]))
    pair_edc = torch.stack([edc_a, edc_b], dim=0)  # [2, T_edc]

    def _dist(**kwargs):
        return compute_audio_distance(pair_stft, **kwargs)[0, 1].item()

    with torch.cuda.amp.autocast(enabled=use_amp), torch.no_grad():
        out = {
            'MSE': _dist(wavs=pair_wav, metric='MSE'),
            'SPL': _dist(wavs=pair_wav, metric='SPL', fs=SAMPLE_RATE),
            'MAG': _dist(metric='MAG'),
            'MAG2': _dist(metric='MAG2'),
            'EDC': _dist(wavs=None, edc_curves=pair_edc, metric='EDC'),
        }

    out.update(room_metric_diffs_gpu(wav_a_n, wav_b_n, fs=SAMPLE_RATE))
    return out

def _build_query_blocks_for_chunk(q_stft, q_wav, q_item, g_stfts_chunk, g_wavs_chunk, g_edc_chunk, g_decay_chunk):
    """
    Prepend the query to one gallery chunk, so ``compute_audio_distance`` can run on the batch.

    Returns ``(stft_blk, wav_blk, edc_blk, decay_blk)``, with the query in row 0:
      * stft_blk  : query STFT stacked on ``g_stfts_chunk``.
      * wav_blk   : waveforms cropped to the shorter of query / gallery length.
      * edc_blk   : None if the gallery has no EDCs. Otherwise the query EDC
                    (stored, or computed from the cropped query waveform when
                    missing/empty) is trimmed or zero-padded to the gallery
                    EDC length and stacked on the chunk.
      * decay_blk : None unless both query and gallery carry decay features.
                    If their widths differ, only the shared leading columns
                    ([T60, C50, EDT, DR] order) are kept.
    """
    stft_blk = torch.cat([q_stft.unsqueeze(0), g_stfts_chunk], dim=0)
    min_len = min(q_wav.shape[0], g_wavs_chunk.shape[1])
    wav_blk = torch.cat([q_wav[:min_len].unsqueeze(0), g_wavs_chunk[:, :min_len]], dim=0)

    edc_blk = None
    if g_edc_chunk is not None:
        T_edc = g_edc_chunk.shape[1]
        q_edc = q_item.get('edc')
        if q_edc is None or q_edc.numel() == 0:
            q_edc = compute_edc_db(wav_blk[0].float(), T_target=T_edc).to(device)
        else:
            q_edc = q_edc.to(device)
            # A stored curve of a different length cannot be stacked with the gallery.
            if q_edc.shape[0] > T_edc:
                q_edc = q_edc[:T_edc]
            elif q_edc.shape[0] < T_edc:
                q_edc = torch.nn.functional.pad(q_edc, (0, T_edc - q_edc.shape[0]))
        edc_blk = torch.cat([q_edc.unsqueeze(0), g_edc_chunk], dim=0)

    decay_blk = None
    q_decay = q_item.get('decay_feats')
    if (q_decay is not None) and (g_decay_chunk is not None):
        q_decay = q_decay.to(device).reshape(1, -1)
        # Stored decay vectors may be 3 (no DR) or 4 wide; compare on the common prefix.
        width = min(q_decay.shape[1], g_decay_chunk.shape[1])
        decay_blk = torch.cat([q_decay[:, :width], g_decay_chunk[:, :width]], dim=0)
    return stft_blk, wav_blk, edc_blk, decay_blk

def _metric_vector_batched(q_stft, q_wav, q_item, g_stfts, g_wavs, g_edc, g_decay, metric, chunk_size=2048):
    """
    Distance from the query to every gallery item under one ``compute_audio_distance`` metric.

    The gallery (CPU tensors) is streamed to ``device`` ``chunk_size`` items at
    a time. For each chunk, the query is prepended (row 0), and row 0 of the
    distance matrix is kept without the self-distance.

    Returns:
        CPU tensor [G] of distances, in gallery order.
    """
    n_gallery = g_stfts.shape[0]

    def _slice_to_device(t, lo, hi):
        return None if t is None else t[lo:hi].to(device, non_blocking=True)

    pieces = []
    for lo in tqdm(range(0, n_gallery, chunk_size), desc=f"{metric} chunks", unit="chunk", leave=False):
        hi = min(lo + chunk_size, n_gallery)
        blocks = _build_query_blocks_for_chunk(
            q_stft, q_wav, q_item,
            _slice_to_device(g_stfts, lo, hi),
            _slice_to_device(g_wavs, lo, hi),
            _slice_to_device(g_edc, lo, hi),
            _slice_to_device(g_decay, lo, hi),
        )
        stft_blk, wav_blk, edc_blk, decay_blk = blocks
        with torch.cuda.amp.autocast(enabled=use_amp), torch.no_grad():
            dist = compute_audio_distance(stft=stft_blk, wavs=wav_blk, edc_curves=edc_blk,
                                          decay_feats=decay_blk, metric=metric, fs=SAMPLE_RATE)
        pieces.append(dist[0, 1:].detach().cpu())
        del blocks, stft_blk, wav_blk, edc_blk, decay_blk, dist
        torch.cuda.empty_cache()
    return torch.cat(pieces, dim=0)  # CPU [G]

def _mixed_vector_batched(q_stft, q_wav, q_item, g_stfts, g_wavs, g_edc, g_decay, weights, chunk_size=2048):
    acc = None
    for m, w in tqdm(list(zip(weights[0::2], weights[1::2])), desc="MIXED sub-metrics", unit="metric", leave=False):
        d = _metric_vector_batched(q_stft, q_wav, q_item, g_stfts, g_wavs, g_edc, g_decay, metric=m, chunk_size=chunk_size)
        mask = torch.isfinite(d); mu = d[mask].mean(); sd = d[mask].std().clamp_min(1e-6)
        z = (d - mu) / sd
        acc = z * float(w) if acc is None else acc + z * float(w)
    return acc  # CPU [G]

# --- add near imports ---
def _nanmean(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    mask = ~torch.isnan(x)
    x0 = torch.where(mask, x, torch.zeros_like(x))
    cnt = mask.sum(dim=dim, keepdim=keepdim).clamp_min(1)
    return x0.sum(dim=dim, keepdim=keepdim) / cnt

def _nanstd(x: torch.Tensor, dim: int, keepdim: bool = False, eps: float = 1e-6) -> torch.Tensor:
    m = _nanmean(x, dim=dim, keepdim=True)
    v = _nanmean((x - m) ** 2, dim=dim, keepdim=keepdim)
    return v.clamp_min(0).sqrt().clamp_min(eps)

# ---------------- Load eval files ----------------
def _eval_file_index(path):
    # "eval_<N>.npy" → N, so files sort numerically rather than lexically.
    return int(os.path.basename(path).split('_')[1].split('.')[0])

file_paths = sorted(glob.glob(EVAL_PATTERN), key=_eval_file_index)
if MAX_EVAL_FILES is not None:
    file_paths = file_paths[:MAX_EVAL_FILES]

_ds_kwargs = dict(root_dir=ROOT_DIR, sample_rate=SAMPLE_RATE, return_wav=True, mode="reference")
ds_test = FurnishedRoomSTFTDataset(split="test", **_ds_kwargs)
ds_ref = FurnishedRoomSTFTDataset(split="reference", **_ds_kwargs)

# Preload the whole reference gallery as a single CPU batch.
ref_loader = DataLoader(ds_ref, batch_size=len(ds_ref), shuffle=False)
batch_ref = next(iter(ref_loader))

ref_ids = batch_ref['id']
ref_stfts_cpu = batch_ref['stft']
ref_wavs_cpu = batch_ref['wav']
ref_edc_cpu = batch_ref.get('edc')            # None if the dataset provides no EDCs
ref_decay_cpu = batch_ref.get('decay_feats')  # None if the dataset provides no decay features
# Poses are consumed by the EMBEDDING retriever.
ref_mic_pose, ref_src_pose, ref_rot = batch_ref['mic_pose'], batch_ref['source_pose'], batch_ref['rot']

G = ref_stfts_cpu.size(0)
T_edc_gallery = 60 if ref_edc_cpu is None else ref_edc_cpu.shape[1]

# --- Embedding backend setup (only for RETRIEVAL_BACKEND == "EMBEDDING") ---
Zg = None  # gallery embeddings on `device`; stays None for the METRIC backend
if RETRIEVAL_BACKEND.upper() == "EMBEDDING":
    # NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
    ckpt = torch.load(CKPT_PATH, map_location=device)
    model = RIRRetrievalMLP(**ckpt["model_config"]).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    # Global grid vector (use_global_grid=True), shared by every gallery item.
    grid_vec = torch.load(GRID_VEC_PATH, map_location=device).to(device)
    with torch.no_grad():
        Gg = grid_vec.unsqueeze(0).expand(G, -1).to(device)
        mic_g, src_g, rot_g = (t.to(device) for t in (ref_mic_pose, ref_src_pose, ref_rot))
        Zg = model(Gg, mic_g, src_g, rot_g)  # [G, D]


records = []
# torch.topk raises if k exceeds the gallery size, so never ask for more than G references.
k_eff = min(TOPK, G)

# ---------------- Per-file loop ----------------
for fp in tqdm(file_paths, desc=f"Processing eval files (batched on {device.type.upper()})", unit="file"):
    # NOTE(review): allow_pickle=True unpickles the file — only load trusted render dumps.
    data = np.load(fp, allow_pickle=True).item()
    idx = int(data["audio_idx"])
    test_item = ds_test[idx]  # fetched once; reused by both retrieval backends below

    # Query (ground truth) on device
    wav_gt   = test_item['wav'].squeeze().to(device, non_blocking=True)
    stft_gt  = test_item['stft'].squeeze(0).to(device, non_blocking=True)

    # Prediction: exp(.) - 1e-3 undoes the stored log-magnitude; Griffin-Lim → waveform
    stft_pred = torch.from_numpy(data["pred_stft"]).float().squeeze(0).to(device)
    with torch.cuda.amp.autocast(enabled=use_amp):
        mag_pred  = torch.exp(stft_pred) - 1e-3
        wav_pred  = istft(mag_pred.unsqueeze(0)).squeeze(0)

    # Align GT and prediction for waveform metrics
    L = min(wav_gt.shape[0], wav_pred.shape[0])
    wav_gt, wav_pred = wav_gt[:L], wav_pred[:L]

    # ---- Retrieval (two backends): indices of the top-k gallery references ----
    if RETRIEVAL_BACKEND.upper() == "EMBEDDING":
        # Query embedding from the same model + global grid features
        with torch.no_grad():
            Gq    = grid_vec.unsqueeze(0).to(device)
            mic_q = test_item['mic_pose'].unsqueeze(0).to(device)
            src_q = test_item['source_pose'].unsqueeze(0).to(device)
            rot_q = test_item['rot'].unsqueeze(0).to(device)
            Zq    = model(Gq, mic_q, src_q, rot_q)     # [1, D]

            S = (Zq @ Zg.T).squeeze(0)                 # [G] scores, larger = closer
            top3_idx = torch.topk(S, k=k_eff, largest=True).indices.tolist()
        d_vec_cpu = None  # this backend yields scores, not a distance vector
    else:
        # Batched distances over the gallery; smaller = closer
        if RETRIEVAL_METRIC == "FEATS":
            d_vec_cpu = _featurevec_vector_batched(
                q_item=test_item,
                g_edc=ref_edc_cpu,
                g_decay=ref_decay_cpu,
                T_edc_gallery=T_edc_gallery,
                include_edc=FEATS_INCLUDE_EDC,
                decay_names=tuple(FEATS_USE_DECAYS),
                chunk_size=FEATS_CHUNK_SIZE_G
            )
        elif RETRIEVAL_METRIC != "MIXED":
            d_vec_cpu = _metric_vector_batched(
                stft_gt, wav_gt, test_item,
                ref_stfts_cpu, ref_wavs_cpu, ref_edc_cpu, ref_decay_cpu,
                metric=RETRIEVAL_METRIC, chunk_size=CHUNK_SIZE_G
            )
        else:
            d_vec_cpu = _mixed_vector_batched(
                stft_gt, wav_gt, test_item,
                ref_stfts_cpu, ref_wavs_cpu, ref_edc_cpu, ref_decay_cpu,
                weights=MIXED_WEIGHTS, chunk_size=CHUNK_SIZE_G
            )
        top3_idx = torch.argsort(d_vec_cpu)[:k_eff].tolist()
    top3_ids = [ref_ids[i] for i in top3_idx]

    # ---- EDC curves (GT / Pred) at the gallery EDC length T_edc_gallery ----
    # GT EDC: prefer the dataset's curve (fitted to length), else compute it
    q_edc = test_item.get('edc')
    if q_edc is None or q_edc.numel() == 0:
        q_edc = compute_edc_db(wav_gt.float(), T_target=T_edc_gallery).to(device)
    else:
        q_edc = q_edc.to(device)
        if q_edc.shape[0] > T_edc_gallery:
            q_edc = q_edc[:T_edc_gallery]
        elif q_edc.shape[0] < T_edc_gallery:
            q_edc = torch.nn.functional.pad(q_edc, (0, T_edc_gallery - q_edc.shape[0]))

    # Pred EDC: always computed from wav_pred at the same length
    p_edc = compute_edc_db(wav_pred.float(), T_target=T_edc_gallery).to(device)

    # Sanity: avoid NaNs propagating into the distances
    q_edc = torch.nan_to_num(q_edc, nan=0.0)
    p_edc = torch.nan_to_num(p_edc, nan=0.0)

    # ---- Metrics: pred vs GT (with EDC) ----
    pred_metrics = _pair_metrics_gpu_with_edc(stft_gt, wav_gt, stft_pred, wav_pred, q_edc, p_edc)

    # ---- Metrics: top-k refs vs GT (with EDC) + ref EDC curves for plotting ----
    edc_refs_curves = []
    refs_out = []
    for i in tqdm(top3_idx, desc="Top-K metrics", unit="ref", leave=False):
        stft_r = ref_stfts_cpu[i].to(device, non_blocking=True)
        wav_r  = ref_wavs_cpu[i].to(device, non_blocking=True)
        if ref_edc_cpu is not None:
            r_edc = ref_edc_cpu[i].to(device, non_blocking=True)
        else:
            r_edc = compute_edc_db(wav_r.float(), T_target=T_edc_gallery).to(device)
        m = _pair_metrics_gpu_with_edc(stft_gt, wav_gt, stft_r, wav_r, q_edc, r_edc)
        refs_out.append({'id': ref_ids[i], **m})
        edc_refs_curves.append(r_edc.detach().cpu())

    records.append({
        'idx': idx,
        'id': test_item['id'],
        'wav_gt':   wav_gt.detach().cpu(),   'stft_gt':  stft_gt.detach().cpu(),
        'wav_pred': wav_pred.detach().cpu(), 'stft_pred': stft_pred.detach().cpu(),
        'edc_gt':   q_edc.detach().cpu(),    'edc_pred': p_edc.detach().cpu(),
        'edc_refs': edc_refs_curves,          # one curve per retrieved reference
        'retrieval': {
            'backend': RETRIEVAL_BACKEND,
            'metric': RETRIEVAL_METRIC if RETRIEVAL_BACKEND.upper()=="METRIC" else None,
            'weights': MIXED_WEIGHTS if (RETRIEVAL_BACKEND.upper()=="METRIC" and RETRIEVAL_METRIC=="MIXED") else None,
            'normalize_mode': "per_query",
            'top3_ids': top3_ids,
            'ckpt_path': CKPT_PATH if RETRIEVAL_BACKEND.upper()=="EMBEDDING" else None,
        },
        'metrics_pred': pred_metrics,
        'metrics_refs': refs_out,   # one dict per retrieved reference (each includes 'EDC')
    })

    del wav_gt, stft_gt, stft_pred, wav_pred, q_edc, p_edc
    torch.cuda.empty_cache()

# ---------------- Save cache ----------------
_use_embedding = RETRIEVAL_BACKEND.upper() == "EMBEDDING"
_use_feats = (not _use_embedding) and RETRIEVAL_METRIC == "FEATS"

cache_params = {
    "sample_rate": SAMPLE_RATE,
    "n_fft": n_fft, "win_length": win_length, "hop_length": hop_length, "power": power,
    "retrieval_metric": RETRIEVAL_METRIC,
    "mixed_weights": MIXED_WEIGHTS,
    "normalize_mode": "per_query",
    "topk": TOPK,
    "root_dir": ROOT_DIR,
    "chunk_size_g": CHUNK_SIZE_G,
    "retrieval_backend": RETRIEVAL_BACKEND,
    "ckpt_path": CKPT_PATH if _use_embedding else None,
    "grid_vec_path": GRID_VEC_PATH if _use_embedding else None,
    # Extra provenance so FEATS runs and the evaluated file set can be reproduced from the cache.
    "eval_pattern": EVAL_PATTERN,
    "max_eval_files": MAX_EVAL_FILES,
    "feats": {
        "include_edc": FEATS_INCLUDE_EDC,
        "use_decays": list(FEATS_USE_DECAYS),
        "chunk_size_g": FEATS_CHUNK_SIZE_G,
        "eps": FEATS_EPS,
        "equalize_group_energy": FEATS_EQUALIZE_GROUP_ENERGY,
        "edc_weight": FEATS_EDC_WEIGHT,
        "decay_weight": FEATS_DECAY_WEIGHT,
    } if _use_feats else None,
}

with open(CACHE_PATH, "wb") as f:
    pickle.dump({"records": records, "params": cache_params}, f)

print(f"Cached {len(records)} records to {CACHE_PATH}. "
      f"Example keys: {list(records[0].keys()) if records else '[]'}")

In [None]:
# --- CELL 2: Load cache, compute averages & deltas, plot by metric (now with EDC curves) ---

import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
from retriever import FurnishedRoomSTFTDataset

CACHE_PATH = "./records_cache_autorefs.pkl"  # pickle written by Cell 1 — update if different

# --- User knobs ---
SORT_METRIC = "SPL"   # key of 'metrics_pred' used to rank samples (e.g. "SPL", "EDC")
TOP_N = 20            # how many highest-ranked (worst) samples to summarise and plot

def three_line_block(id_str, m):
    """
    Format one sample's metrics as a three-line caption.

    Line 1 is the ID. Line 2 is SPL, MSE, MAG, MAG2. Line 3 is T60 %, C50 Δ,
    EDT Δ, EDC.
    """
    spectral = "".join((
        f"SPL: {m['SPL']:.4f}   ",
        f"MSE: {m['MSE']:.2e}   ",
        f"MAG: {m['MAG']:.2f}   ",
        f"MAG2: {m['MAG2']:.2e}",
    ))
    decay = "".join((
        f"T60%: {m['T60']:.1f}%   ",
        f"C50Δ: {m['C50']:.2f} dB   ",
        f"EDTΔ: {m['EDT']:.3f} s   ",
        f"EDC: {m['EDC']:.3f}",
    ))
    return "\n".join((f"{id_str}", spectral, decay))

def _align_edc(a: torch.Tensor, b: torch.Tensor):
    La, Lb = a.shape[0], b.shape[0]
    L = min(La, Lb)
    return a[:L], b[:L], L

# NOTE(review): pickle.load runs arbitrary code from the file — only open caches you produced.
with open(CACHE_PATH, "rb") as f:
    cache = pickle.load(f)
records, params = cache["records"], cache["params"]

if not records:
    raise RuntimeError("No records in cache. Run Cell 1 first.")

# Metrics compared throughout this cell (EDC included)
sample_metrics = ['SPL','MSE','MAG','MAG2','EDC','T60','C50','EDT']
_missing = [key for key in sample_metrics if key not in records[0]['metrics_pred']]
if _missing:
    raise ValueError(f"Missing metric '{_missing[0]}' in cached data.")

def _accumulate_averages(rows):
    """Average every metric in the module-level `sample_metrics` over `rows`.

    Returns (avg_pred, avg_ref, deltas):
      avg_pred  -- {metric: mean of rec['metrics_pred'][metric]}
      avg_ref   -- list of 3 dicts, one per retrieval rank k, with the mean of
                   rec['metrics_refs'][k][metric]
      deltas    -- list of 3 dicts with avg_ref[k][m] - avg_pred[m]
    `rows` must be non-empty (an empty list raises ZeroDivisionError).
    """
    n_ranks = 3
    count = len(rows)

    avg_pred = {
        name: sum((float(rec['metrics_pred'][name]) for rec in rows), 0.0) / count
        for name in sample_metrics
    }
    avg_ref = [
        {
            name: sum((float(rec['metrics_refs'][rank][name]) for rec in rows), 0.0) / count
            for name in sample_metrics
        }
        for rank in range(n_ranks)
    ]
    deltas = [
        {name: ref_avg[name] - avg_pred[name] for name in sample_metrics}
        for ref_avg in avg_ref
    ]
    return avg_pred, avg_ref, deltas

def _print_average_report(header, avg_pred, avg_ref, deltas):
    """Print averaged prediction metrics, then each top-k reference's averages
    and its Δ(ref - pred); a negative Δ is labelled 'better'."""
    print(header)
    print("Prediction vs GT:")
    for name in sample_metrics:
        print(f"  {name:>4}: {avg_pred[name]:.8f}")
    for rank in range(3):
        print(f"\nTop-{rank+1} reference vs GT:")
        for name in sample_metrics:
            print(f"  {name:>4}: {avg_ref[rank][name]:.8f}")
        print("  Δ(ref - pred):")
        for name in sample_metrics:
            delta = deltas[rank][name]
            verdict = 'better' if delta < 0 else 'worse'
            print(f"    {name:>4}: {delta:+.8f}  ({verdict})")


avg_pred_all, avg_ref_all, deltas_all = _accumulate_averages(records)
_print_average_report("=== Averages over ALL selected eval samples ===",
                      avg_pred_all, avg_ref_all, deltas_all)

# Rank by the prediction's SORT_METRIC distance, largest (worst) first.
records_sorted = sorted(records, key=lambda r: r['metrics_pred'][SORT_METRIC], reverse=True)
worst = records_sorted[:min(TOP_N, len(records_sorted))]

avg_pred_topN, avg_ref_topN, deltas_topN = _accumulate_averages(worst)
_print_average_report(f"\n=== Averages over Top-{len(worst)} WORST by {SORT_METRIC} ===",
                      avg_pred_topN, avg_ref_topN, deltas_topN)

# --- Plotting (time, STFT, EDC) for Top_N worst by SORT_METRIC ---
# Build the reference split ONCE, from this cache's own params. The previous
# version reused any `dataset_ref` already present in the kernel, so a dataset
# left over from an earlier run or cell (possibly another room or split) was
# silently used to resolve the retrieved IDs.
dataset_ref = FurnishedRoomSTFTDataset(
    root_dir=params.get("root_dir", "../data/RAF/EmptyRoom"), split="reference",
    sample_rate=params["sample_rate"], return_wav=True, mode="reference"
)

for rec in worst:
    suptitle = three_line_block(f"{rec['id']} (sorted by {SORT_METRIC})", rec['metrics_pred'])
    ref_metrics = rec['metrics_refs'][:3]
    # Fetch each retrieved reference item once; the waveform and STFT
    # canvases both read from it.
    ref_items = [dataset_ref[dataset_ref.id2idx[rid]]
                 for rid in rec['retrieval']['top3_ids'][:len(ref_metrics)]]

    # 1) Time-domain canvas: GT / Pred on top, the top-3 references below.
    fig, axes = plt.subplots(2, 3, figsize=(18, 8))
    axes[0,0].plot(rec['wav_gt'].numpy());   axes[0,0].set_title('Ground Truth', fontsize=11, pad=8)
    axes[0,1].plot(rec['wav_pred'].numpy()); axes[0,1].set_title('Prediction', fontsize=11, pad=8)
    axes[0,0].set_xlabel('Sample'); axes[0,0].set_ylabel('Amplitude')
    axes[0,1].set_xlabel('Sample'); axes[0,1].set_ylabel('Amplitude')
    axes[0,2].axis('off')
    for i, (rm, ref_item) in enumerate(zip(ref_metrics, ref_items)):
        axes[1,i].plot(ref_item['wav'].numpy())
        axes[1,i].set_xlabel('Sample'); axes[1,i].set_ylabel('Amplitude')
        axes[1,i].set_title(three_line_block(f"Ref {rm['id']}", rm), fontsize=10, pad=10)
    fig.suptitle(suptitle, fontsize=12, y=0.90)
    fig.subplots_adjust(top=0.83, hspace=0.45, wspace=0.25)
    plt.tight_layout(rect=[0, 0.02, 1, 0.88]); plt.show()

    # 2) STFT canvas
    fig2, axes2 = plt.subplots(2, 3, figsize=(18, 9))
    im = axes2[0,0].imshow(rec['stft_gt'].numpy(), aspect='auto', origin='lower');  axes2[0,0].set_title('GT STFT', fontsize=11, pad=8)
    im2= axes2[0,1].imshow(rec['stft_pred'].numpy(), aspect='auto', origin='lower'); axes2[0,1].set_title('Pred STFT', fontsize=11, pad=8)
    axes2[0,0].set_xlabel('Time frame'); axes2[0,0].set_ylabel('Frequency bin')
    axes2[0,1].set_xlabel('Time frame'); axes2[0,1].set_ylabel('Frequency bin')
    plt.colorbar(im, ax=axes2[0,0], fraction=0.046, pad=0.04); plt.colorbar(im2, ax=axes2[0,1], fraction=0.046, pad=0.04)
    axes2[0,2].axis('off')
    for i, (rm, ref_item) in enumerate(zip(ref_metrics, ref_items)):
        imr = axes2[1,i].imshow(ref_item['stft'].numpy(), aspect='auto', origin='lower')
        axes2[1,i].set_title(three_line_block(f"Ref {rm['id']}", rm), fontsize=10, pad=10)
        axes2[1,i].set_xlabel('Time frame'); axes2[1,i].set_ylabel('Frequency bin')
        plt.colorbar(imr, ax=axes2[1,i], fraction=0.046, pad=0.04)
    fig2.suptitle(suptitle, fontsize=12, y=0.98)
    fig2.subplots_adjust(top=0.82, hspace=0.50, wspace=0.28)
    plt.tight_layout(rect=[0, 0.02, 1, 0.87]); plt.show()

    # 3) EDC canvas: GT vs Pred on top, GT vs each top-3 reference below.
    fig3, axes3 = plt.subplots(2, 3, figsize=(18, 7))

    edc_gt_al, edc_pr_al, L = _align_edc(rec['edc_gt'], rec['edc_pred'])
    t = np.arange(L)
    axes3[0,0].plot(t, edc_gt_al.numpy(), label='GT')
    axes3[0,0].plot(t, edc_pr_al.numpy(), label='Pred', alpha=0.9)
    axes3[0,0].set_title('EDC: GT vs Pred', fontsize=11, pad=8)
    axes3[0,0].set_xlabel('EDC frame'); axes3[0,0].set_ylabel('EDC (dB)')
    axes3[0,0].legend(loc='best')

    axes3[0,1].axis('off'); axes3[0,2].axis('off')

    for i, rm in enumerate(ref_metrics):
        edc_gt_al, edc_r_al, Lr = _align_edc(rec['edc_gt'], rec['edc_refs'][i])
        tr = np.arange(Lr)
        ax = axes3[1,i]
        ax.plot(tr, edc_gt_al.numpy(), label='GT')
        ax.plot(tr, edc_r_al.numpy(), label=f"Ref {rm['id']}", alpha=0.9)
        ax.set_xlabel('EDC frame'); ax.set_ylabel('EDC (dB)')
        ax.set_title(three_line_block(f"Ref {rm['id']}", rm), fontsize=10, pad=10)
        ax.legend(loc='best')

    fig3.suptitle(suptitle, fontsize=12, y=0.96)
    fig3.subplots_adjust(top=0.86, hspace=0.45, wspace=0.25)
    plt.tight_layout(rect=[0, 0.02, 1, 0.90]); plt.show()

In [None]:
# === CLEAN 5-COLUMN GALLERY (GT | Pred | Top-1 | Top-2 | Top-3) ===
# Paper-ready: no ticks, bold headers row 1, metrics above plots.

# ==== KNOBS ====
# Pickle with "records" and "params" written by the caching cell.
CACHE_PATH         = "./records_cache_autorefs.pkl"
# Fallbacks for cached params lacking "root_dir" / "sample_rate".
# NOTE: SAMPLE_RATE is also used directly to scale the waveform time axis.
ROOT_DIR           = "../data/RAF/FurnishedRoom"
SAMPLE_RATE        = 48000

NUM_ROWS           = 15      # number of samples (figure rows) shown
SORT_BY_METRIC     = "SPL"   # "SPL","T60","C50","EDT","EDC"
WORST_FIRST        = True    # True → worst→best, False → best→worst

SHOW_WAVEFORMS     = True
SHOW_STFTS         = True

FIG_DPI            = 130
FIG_W              = 14      # figure width in inches
FIG_H_WAVE_PERROW  = 1.05    # waveform figure height per row (inches)
FIG_H_STFT_PERROW  = 1.05    # STFT figure height per row (inches)

COL_TITLE_SIZE     = 12
METRIC_FONTSIZE    = 8
BORDER_LINEWIDTH   = 0.6     # axes spine line width
BORDER_ALPHA       = 0.85    # axes spine opacity

# ==== IMPORTS ====
import pickle, numpy as np, torch
import matplotlib.pyplot as plt
from retriever import FurnishedRoomSTFTDataset
# NOTE: this switches autograd off globally (for the current thread), so it
# stays in effect for every cell run afterwards, not just this one.
torch.set_grad_enabled(False)

# ==== LOAD CACHE ====
# pickle.load can execute arbitrary code — only open caches you produced.
with open(CACHE_PATH, "rb") as f:
    cache = pickle.load(f)
records = cache["records"]
params = cache["params"]
if not records:
    raise RuntimeError("No records in cache.")

# Metrics shown in this gallery (overrides the previous cell's list).
sample_metrics = ['SPL','T60','C50','EDT','EDC']
_absent_metrics = [name for name in sample_metrics if name not in records[0]['metrics_pred']]
if _absent_metrics:
    raise ValueError(f"Missing metric '{_absent_metrics[0]}' in cache.")

# ==== REF DATASETS ====
# Retrieved IDs are resolved in the "reference" split first, then in "train"
# (see _find_ref_by_id). Both share every constructor option except the split.
_ref_common_kwargs = dict(
    root_dir=params.get("root_dir", ROOT_DIR),
    sample_rate=params.get("sample_rate", SAMPLE_RATE),
    return_wav=True,
    mode="reference",
)
dataset_ref_primary = FurnishedRoomSTFTDataset(split="reference", **_ref_common_kwargs)
dataset_ref_fallback = FurnishedRoomSTFTDataset(split="train", **_ref_common_kwargs)

def _find_ref_by_id(rid):
    """Return the dataset item whose ID is `rid`, or None if neither split has it.

    `rid` is converted to str and looked up in each dataset's `id2idx` map,
    searching `dataset_ref_primary` first and `dataset_ref_fallback` second.
    """
    key = str(rid)
    for ds in (dataset_ref_primary, dataset_ref_fallback):
        pos = ds.id2idx.get(key, None)
        if pos is not None:
            return ds[pos]
    return None

# ==== HELPERS ====
def _fmt_one_line(m):
    return f"ΔSTFT:{m['SPL']:.3f}  T60%:{m['T60']:.2f}  C50:{m['C50']:.3f}  EDT:{m['EDT']:.3f}"

def _put_metrics_above(ax, text, fontsize, y_offset=1.05):
    ax.text(0.5, y_offset, text, transform=ax.transAxes,
            ha="center", va="bottom", fontsize=fontsize,
            fontweight="normal", clip_on=False)

def _collect_vmin_vmax(*arrs):
    vmin = min(float(a.min()) for a in arrs)
    vmax = max(float(a.max()) for a in arrs)
    return vmin, vmax

def _sort_key(rec, metric):
    return float(rec['metrics_pred'][metric])

# ==== PREPARE ROWS ====
recs_sorted = sorted(records, key=lambda r: _sort_key(r, SORT_BY_METRIC), reverse=WORST_FIRST)
rows = recs_sorted[:min(NUM_ROWS, len(recs_sorted))]

col_keys   = ["GT", "Pred", "Ref1", "Ref2", "Ref3"]
col_titles = ["Ground Truth", "Prediction", "Top-1 Retrieved", "Top-2 Retrieved", "Top-3 Retrieved"]


def _ref_panel(ref_item, metrics):
    """Panel payload for one retrieved reference, or a 'missing' placeholder."""
    if ref_item is None:
        return {"wav": None, "stft": None, "m": metrics, "missing": True}
    return {
        "wav":  ref_item['wav'].squeeze().cpu(),
        "stft": ref_item['stft'].squeeze(0).cpu(),
        "m":    metrics,
        "missing": False,
    }


gallery = []
missing_count = 0
for rec in rows:
    ref_blocks = []
    for i, rid in enumerate(rec['retrieval']['top3_ids'][:3]):
        panel = _ref_panel(_find_ref_by_id(rid), rec['metrics_refs'][i])
        missing_count += int(panel["missing"])
        ref_blocks.append(panel)
    gallery.append({
        "id": rec['id'],
        "GT":   {"wav": rec['wav_gt'].cpu(),  "stft": rec['stft_gt'].cpu()},
        "Pred": {"wav": rec['wav_pred'].cpu(), "stft": rec['stft_pred'].cpu(), "m": rec['metrics_pred']},
        "Ref1": ref_blocks[0], "Ref2": ref_blocks[1], "Ref3": ref_blocks[2],
    })

if missing_count:
    print(f"[INFO] {missing_count} refs not found in either split.")

# ==== FIGURES A (waveforms) & B (STFTs) ====
# Both figures share one renderer; they previously duplicated the whole
# per-cell layout/annotation loop. The waveform time axis now uses the sample
# rate the datasets were actually loaded with, instead of the SAMPLE_RATE knob
# alone.
_GALLERY_SR = params.get("sample_rate", SAMPLE_RATE)
_ORDER_TXT = 'worst→best' if WORST_FIRST else 'best→worst'


def _render_gallery(draw_panel, row_height, title, row_context=lambda row: None):
    """Draw a len(gallery) x 5 grid (GT | Pred | Top-1..3), show it, return (fig, axs).

    draw_panel(ax, blk, ctx) draws one non-missing panel's content, where ctx
    is row_context(row) computed once per gallery row. Missing references get
    a "ref not found" placeholder. Row 0 carries bold column headers, and
    every non-GT panel gets its metric caption above it.
    """
    fig, axs = plt.subplots(len(gallery), 5,
                            figsize=(FIG_W, row_height * len(gallery)), dpi=FIG_DPI)
    if len(gallery) == 1:
        axs = np.expand_dims(axs, 0)   # keep 2-D [row, col] indexing

    for r, row in enumerate(gallery):
        ctx = row_context(row)
        for c, key in enumerate(col_keys):
            ax, blk = axs[r, c], row[key]
            missing = blk.get("missing", False)
            if missing:
                ax.axis("off")
                ax.text(0.5, 0.5, "ref not found", transform=ax.transAxes,
                        ha="center", va="center", fontsize=8)
            else:
                draw_panel(ax, blk, ctx)
                ax.set_xticks([]); ax.set_yticks([])
                for sp in ax.spines.values():
                    sp.set_linewidth(BORDER_LINEWIDTH); sp.set_alpha(BORDER_ALPHA)

            if r == 0:
                ax.set_title(col_titles[c], fontweight="bold", fontsize=COL_TITLE_SIZE, pad=20)
            if key != "GT" and not missing:
                # Raised further on row 0 so the caption clears the column header.
                _put_metrics_above(ax, _fmt_one_line(blk["m"]), METRIC_FONTSIZE,
                                   y_offset=1.15 if r == 0 else 1.05)

    fig.suptitle(title, y=0.995, fontsize=10)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()
    return fig, axs


def _draw_waveform(ax, blk, _ctx):
    """Plot a waveform against time in seconds."""
    w = blk["wav"].float().numpy()
    t = np.arange(len(w)) / _GALLERY_SR
    ax.plot(t, w, linewidth=0.8)


def _stft_row_limits(row):
    """Shared (vmin, vmax) over the row's available STFTs so colours are comparable."""
    valid = [row[key]["stft"].float().numpy()
             for key in col_keys if not row[key].get("missing", False)]
    return _collect_vmin_vmax(*valid) if valid else (0, 1)


def _draw_stft(ax, blk, limits):
    """Show an STFT image using the row-wide colour limits."""
    vmin, vmax = limits
    ax.imshow(blk["stft"].float().numpy(), origin="lower", aspect="auto", vmin=vmin, vmax=vmax)


if SHOW_WAVEFORMS:
    figA, axsA = _render_gallery(
        _draw_waveform, FIG_H_WAVE_PERROW,
        f"Waveforms — rows: samples (sorted by {SORT_BY_METRIC}, {_ORDER_TXT})")

if SHOW_STFTS:
    figB, axsB = _render_gallery(
        _draw_stft, FIG_H_STFT_PERROW,
        f"STFT log-mag — rows: samples (sorted by {SORT_BY_METRIC}, {_ORDER_TXT})",
        row_context=_stft_row_limits)


In [None]:
# --- Retrieval for all splits: top-10 from reference using MIXED(EDC,SPL) or EMBEDDING ---
import os, json
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# use your retriever dataset + distance
from retriever import FurnishedRoomSTFTDataset
from evaluator import compute_audio_distance

# Optional embedding backend (RIRRetrievalMLP)
# Only loaded when RETRIEVAL_BACKEND == "EMBEDDING".
EMBEDDING_CKPT_PATH = './outputs/20250812_204815/rir_retrieval_model.ckpt'
GRID_VEC_PATH       = './features.pt'   # use_global_grid=True

# ─── configs ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): this points at EmptyRoom while the dataset class and the gallery
# cell default to FurnishedRoom — confirm which room this retrieval is meant for.
root                 = '../data/RAF/EmptyRoom'

# Backend: "MIXED" (EDC/SPL) or "EMBEDDING"
RETRIEVAL_BACKEND    = "MIXED"          # or "EMBEDDING"

# For MIXED backend (default weights)
# (metric, weight) pairs: each metric's query-vs-gallery distances are z-scored
# per query row, multiplied by the weight, and summed.
MIXED_WEIGHTS        = [('EDC', 0.6), ('SPL', 0.4)]   # change here if desired

# For both backends
topk                 = 10       # retrieved gallery ids kept per query
query_batch_size     = 2048     # tune to your GPU (works with chunking, so flexible)
gallery_chunk_size   = 2048    # for MIXED: process gallery in chunks to avoid OOM
device               = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# Splits mapping to your dataset
# key   = section name written to references.json
# value = split name passed to FurnishedRoomSTFTDataset
splits = {
    "train":      "train",
    "evaluation": "validation",
    "test":       "test",
}

# ─── helpers ─────────────────────────────────────────────────────────────────────────────
def _z_per_query(d_qg: torch.Tensor) -> torch.Tensor:
    """
    Z-score per query row, ignoring inf/nan. d_qg: [B, G] on CPU.
    Returns z-normalized distances (mean 0, std 1) per query.
    """
    x = d_qg
    mask = torch.isfinite(x)
    # mean/std per row
    row_cnt = mask.sum(dim=1, keepdim=True).clamp_min(1)
    row_sum = torch.where(mask, x, torch.zeros_like(x)).sum(dim=1, keepdim=True)
    mu = row_sum / row_cnt
    var = torch.where(mask, (x - mu)**2, torch.zeros_like(x)).sum(dim=1, keepdim=True) / row_cnt
    sd = var.clamp_min(0).sqrt().clamp_min(1e-6)
    z = (x - mu) / sd
    # keep inf/nan as large numbers so they rank last
    z = torch.where(mask, z, torch.full_like(z, 1e6))
    return z

def _compute_block_distance(metric: str, q_batch, g_batch):
    """
    Compute the [B, G_chunk] query-vs-gallery distance block for one metric.

    Queries and the gallery chunk are concatenated so `compute_audio_distance`
    (which returns an all-pairs [B+G, B+G] matrix here) can be reused; the
    query-rows x gallery-columns sub-block is sliced out and returned on CPU.
    Computation runs on the module-level `device`.

    Args:
        metric:  metric name forwarded to compute_audio_distance.
        q_batch: dict with "stft" ([B, ...]) and "wav" tensors for the queries;
                 the waveform time axis is assumed to be the LAST dim — TODO confirm.
        g_batch: dict with the same keys for the gallery chunk.
    Returns:
        CPU tensor of shape [B, G_chunk].
    """
    # One flattened feature row per item. reshape (unlike view) also accepts
    # non-contiguous inputs, e.g. sliced or custom-collated batches.
    flat_q = q_batch["stft"].reshape(q_batch["stft"].shape[0], -1).to(device, non_blocking=True)
    flat_g = g_batch["stft"].reshape(g_batch["stft"].shape[0], -1).to(device, non_blocking=True)

    wav_q  = q_batch["wav"].to(device, non_blocking=True)
    wav_g  = g_batch["wav"].to(device, non_blocking=True)

    with torch.no_grad():
        all_flat = torch.cat([flat_q, flat_g], dim=0)     # [B+G, D]
        # Trim both sides to the shorter time axis. L is measured on the last
        # dim, so slice the last dim with `...`; the old `[:, :L]` cut dim 1
        # instead whenever the waveforms were not 2-D.
        L = min(wav_q.shape[-1], wav_g.shape[-1])
        all_wav  = torch.cat([wav_q[..., :L], wav_g[..., :L]], dim=0)  # [B+G, ..., L]
        D_full   = compute_audio_distance(all_flat, all_wav, metric=metric)  # [B+G, B+G]

        B = flat_q.shape[0]
        D_qg = D_full[:B, B:].detach().cpu()             # [B, G_chunk]

    # Release GPU buffers before the next chunk.
    del flat_q, flat_g, wav_q, wav_g, all_flat, all_wav, D_full
    torch.cuda.empty_cache()
    return D_qg

def _mixed_distance_matrix(q_loader, g_ds, g_batch):
    """
    Compute the MIXED distance matrix for every query batch.

    For each (metric, weight) in MIXED_WEIGHTS, distances from the batch to the
    whole gallery are computed in chunks of `gallery_chunk_size`, z-normalized
    per query row (_z_per_query), scaled by the weight, and summed over metrics.

    Args:
        q_loader: non-shuffled DataLoader over the query dataset (which exposes `.ids`).
        g_ds:     gallery dataset (only its length is used here).
        g_batch:  the whole gallery collated into one dict; each value is sliced
                  along its first dim per chunk.
    Returns:
        list of (batch_ids, D_mix) where D_mix is a [B, G] CPU tensor
        (smaller = closer).
    """
    G = len(g_ds)
    # Gallery chunk boundaries [(0, c), (c, 2c), ...] to bound GPU memory.
    g_slices = [(gs, min(gs + gallery_chunk_size, G)) for gs in range(0, G, gallery_chunk_size)]
    q_ids = q_loader.dataset.ids
    results = []
    offset = 0  # index of the current batch's first query in q_ids

    for q_batch in tqdm(q_loader, desc="Query batches (MIXED)", leave=False):
        B = q_batch["stft"].shape[0]
        mix_acc = None

        for metric, w in MIXED_WEIGHTS:
            blocks = [
                _compute_block_distance(metric, q_batch, {k: v[gs:ge] for k, v in g_batch.items()})
                for gs, ge in g_slices
            ]
            z = _z_per_query(torch.cat(blocks, dim=1)) * float(w)   # [B, G] CPU
            mix_acc = z if mix_acc is None else (mix_acc + z)

        # A running offset keeps IDs aligned with the rows even if q_loader was
        # built with a batch size other than the global `query_batch_size`.
        results.append((q_ids[offset:offset + B], mix_acc))
        offset += B
    return results

def _embedding_topk(q_loader, g_ds, Zg, model, grid_vec):
    """
    EMBEDDING backend: rank the gallery by cosine similarity of pose embeddings.

    Each query batch is embedded with model(grid_vec, mic_pose, source_pose, rot)
    and compared against the precomputed gallery embeddings `Zg` ([G, D]).

    Args:
        q_loader: non-shuffled DataLoader over the query dataset (which exposes `.ids`).
        g_ds:     gallery dataset whose `.ids` align with the rows of Zg.
        Zg:       [G, D] gallery embeddings.
        model:    embedding network called as model(G, mic, src, rot).
        grid_vec: global grid feature vector, broadcast to every query.
    Returns:
        dict mapping query id -> list of the `topk` most similar gallery ids,
        excluding gallery entries whose id equals the query id.
    """
    refs_for_split = {}
    g_ids = g_ds.ids
    q_ids = q_loader.dataset.ids
    Zg_n = torch.nn.functional.normalize(Zg, p=2, dim=1)  # [G, D]
    offset = 0  # running position in q_ids, independent of the loader's batch size

    for q_batch in tqdm(q_loader, desc="Query batches (EMBEDDING)", leave=False):
        bsize = q_batch["stft"].size(0)
        batch_ids = q_ids[offset:offset + bsize]
        offset += bsize

        with torch.no_grad():
            Gq    = grid_vec.unsqueeze(0).expand(bsize, -1).to(device)
            mic_q = q_batch['mic_pose'].to(device)
            src_q = q_batch['source_pose'].to(device)
            rot_q = q_batch['rot'].to(device)
            Zq    = model(Gq, mic_q, src_q, rot_q)               # [B, D]
            Zq_n  = torch.nn.functional.normalize(Zq, p=2, dim=1)

            sims  = Zq_n @ Zg_n.T                                 # [B, G]
            # Largest similarity first; move to host once for the whole batch.
            top_idxs = torch.argsort(sims, dim=1, descending=True).cpu()  # [B, G]

        for i, qid in enumerate(batch_ids):
            # Self-exclusion: skip every gallery entry whose id equals the query id.
            ranked = (g_ids[j] for j in top_idxs[i].tolist())
            refs_for_split[qid] = [gid for gid in ranked if gid != qid][:topk]

        del Gq, mic_q, src_q, rot_q, Zq, Zq_n, sims, top_idxs
        torch.cuda.empty_cache()

    return refs_for_split

# ─── preload gallery ONCE ────────────────────────────────────────────────────────────────
g_ds     = FurnishedRoomSTFTDataset(root, split="reference", return_wav=True, mode="reference")
G        = len(g_ds)
g_loader = DataLoader(g_ds, batch_size=G, shuffle=False)
batch_g  = next(iter(g_loader))   # the whole reference split as ONE collated batch (CPU)

# EMBEDDING backend only: load the model and embed every gallery item once.
if RETRIEVAL_BACKEND.upper() == "EMBEDDING":
    from retriever import RIRRetrievalMLP

    ckpt = torch.load(EMBEDDING_CKPT_PATH, map_location=device)
    model = RIRRetrievalMLP(**ckpt["model_config"]).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    grid_vec = torch.load(GRID_VEC_PATH, map_location=device).to(device)
    with torch.no_grad():
        Gg = grid_vec.unsqueeze(0).expand(G, -1).to(device)   # same grid vector for every item
        mic_g, src_g, rot_g = (batch_g[k].to(device) for k in ("mic_pose", "source_pose", "rot"))
        Zg = model(Gg, mic_g, src_g, rot_g)                   # [G, D] gallery embeddings

    # Drop the model inputs; batch_g is kept for the MIXED backend.
    del Gg, mic_g, src_g, rot_g
    torch.cuda.empty_cache()

references = {}


def _rank_mixed(mix_batches, g_ids, split_name):
    """Convert MIXED distance blocks into {query id: top-`topk` gallery ids}.

    Ascending distance means closer; gallery entries equal to the query id
    are skipped (self-exclusion).
    """
    ranked = {}
    for batch_ids, D_mix in tqdm(mix_batches, desc=f"Assembling top-{topk} ({split_name})", leave=False):
        order = torch.argsort(D_mix, dim=1)  # [B, G]
        for i, qid in enumerate(batch_ids):
            ranked[qid] = [g_ids[j] for j in order[i].tolist() if g_ids[j] != qid][:topk]
    return ranked


# ─── process each split ─────────────────────────────────────────────────────────────────
for split_name, split_id in splits.items():
    q_ds     = FurnishedRoomSTFTDataset(root, split=split_id, return_wav=True, mode="reference")
    q_loader = DataLoader(q_ds, batch_size=query_batch_size, shuffle=False, drop_last=False)

    if RETRIEVAL_BACKEND.upper() == "EMBEDDING":
        refs_for_split = _embedding_topk(q_loader, g_ds, Zg, model, grid_vec)
    else:
        # MIXED: chunked per-metric distances -> per-query z-norm -> weighted sum.
        mix_batches = _mixed_distance_matrix(q_loader, g_ds, batch_g)  # list of (batch_ids, [B, G])
        refs_for_split = _rank_mixed(mix_batches, g_ds.ids, split_name)

    references[split_name] = refs_for_split

# ─── save to JSON (same shape as your example) ──────────────────────────────────────────
out_path = "references.json"
with open(out_path, "w") as f:
    json.dump(references, f, indent=2)

_weights_desc = MIXED_WEIGHTS if RETRIEVAL_BACKEND.upper()=='MIXED' else 'cosine embedding'
print(f"Wrote {out_path} with top-{topk} retrievals "
      f"(backend={RETRIEVAL_BACKEND}, weights={_weights_desc})")


In [None]:
import os, glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from torchaudio.transforms import GriffinLim
from evaluator import compute_audio_distance, compute_edc_db

# ---------------- Knobs ----------------
# Per-sample eval dumps; each file is read with np.load(..., allow_pickle=True).item().
EVAL_PATTERN = "../eval_results/refine_best/renders/eval_*.npy"   # refine_best, baseline

TOPK = 1000                 # number of worst samples to plot; set None or 0 for ALL
METRIC = "C50"            # metric name for compute_audio_distance, e.g. "MSE","MAG","MAG2","SPL","EDC","T60","C50","EDT"
PLOT_BOTH = True          # True => 2 rows (mic + source). False => only one row controlled by POSE_MODE
POSE_MODE = "mic"         # used only if PLOT_BOTH=False

SAMPLE_RATE = 48000
# ISTFT params that match your STFT (513 bins -> n_fft=1024)
n_fft = (513 - 1) * 2
win_length = 512
hop_length = 256
power = 1.0                 # GriffinLim `power`: 1.0 => the spectrogram holds magnitudes

# ---------------- Helpers ----------------
def ensure_2d_stft(x):
    """Return `x` as a float32 torch tensor of shape (F, T).

    Accepts a numpy array or a tensor; singleton dims are squeezed first, and
    a ValueError is raised if the result is not 2-D.
    """
    spec = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    spec = spec.squeeze()
    if spec.ndim == 2:
        return spec.float()
    raise ValueError(f"STFT must be 2D (F,T); got shape {tuple(spec.shape)}")

def ensure_1d_wav(x):
    """Return `x` as a 1-D float32 torch tensor.

    Accepts a numpy array or a tensor; singleton dims are squeezed first, and
    a ValueError is raised if the result is not 1-D.
    """
    wav = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    wav = wav.squeeze()
    if wav.ndim == 1:
        return wav.float()
    raise ValueError(f"Waveform must be 1D; got shape {tuple(wav.shape)}")

def stft_pair(stft_gt_t: torch.Tensor, stft_pred_t: torch.Tensor) -> torch.Tensor:
    """Trim two (F, T) STFTs to their common frame count and stack as (2, F, T)."""
    n_frames = min(stft_gt_t.shape[1], stft_pred_t.shape[1])
    trimmed = [spec[:, :n_frames] for spec in (stft_gt_t, stft_pred_t)]
    return torch.stack(trimmed, dim=0)

def reconstruct_pred_waveform_from_logmag(stft_pred_t: torch.Tensor) -> torch.Tensor:
    """
    Griffin-Lim waveform from a predicted log-magnitude STFT.

    The log is inverted as exp(logmag) - 1e-3 (the offset used by the eval
    pipeline, per the original note), then torchaudio's GriffinLim is run with
    the module-level n_fft / win_length / hop_length / power.
    NOTE(review): inputs below log(1e-3) yield negative magnitudes here.
    Returns a float32 tensor (1-D for a 2-D (F, T) input).
    """
    magnitude = torch.exp(stft_pred_t) - 1e-3
    griffin_lim = GriffinLim(n_fft=n_fft, win_length=win_length,
                             hop_length=hop_length, power=power)
    return griffin_lim(magnitude).float()

def compute_score(sample_dict, metric: str) -> float:
    """
    Scalar distance between a sample's prediction and its ground truth.

    - STFT-only metrics (MSE, MAG, MAG2) compare the two STFTs directly.
    - Waveform-based metrics (SPL, EDC, T60, C50, EDT) also need waveforms:
      the prediction's waveform is rebuilt from its log-magnitude STFT with
      Griffin-Lim and length-aligned with the GT waveform. For EDC, the two
      energy-decay curves are computed and passed in as well.

    Args:
        sample_dict: eval dump with "data" (GT STFT), "pred_stft" and, for
            waveform-based metrics, "waveform" (GT waveform).
        metric: metric name (case-insensitive) for compute_audio_distance.
    Returns:
        D[0, 1] of the 2x2 distance matrix, as a Python float.
    """
    stft_gt = ensure_2d_stft(sample_dict["data"])
    stft_pred = ensure_2d_stft(sample_dict["pred_stft"])
    pair_stft = stft_pair(stft_gt, stft_pred)

    mu = metric.upper()
    wavs = None
    edc_curves = None

    # T60 / C50 / EDT are room-acoustic parameters scored from waveforms (the
    # table cell below passes wavs for them too); previously only SPL/EDC got
    # waveforms here, so e.g. METRIC="C50" was evaluated with wavs=None.
    if mu in ("SPL", "EDC", "T60", "C50", "EDT"):
        wav_gt = ensure_1d_wav(sample_dict["waveform"])
        wav_pred = reconstruct_pred_waveform_from_logmag(stft_pred)

        # length-align
        L = min(wav_gt.shape[0], wav_pred.shape[0])
        wavs = torch.stack([wav_gt[:L], wav_pred[:L]], dim=0)  # (2, L)

        if mu == "EDC":
            T_edc = 60  # number of EDC points; adjust if you prefer
            edc_gt = compute_edc_db(wavs[0], T_target=T_edc)
            edc_pr = compute_edc_db(wavs[1], T_target=T_edc)
            edc_curves = torch.stack([edc_gt, edc_pr], dim=0)  # (2, T_edc)

    D = compute_audio_distance(
        stft=pair_stft,
        wavs=wavs,
        edc_curves=edc_curves,
        metric=mu,
        fs=SAMPLE_RATE
    )
    return float(D[0, 1].item())

def collect_records(pattern, metric):
    """Score every eval dump matching `pattern`, in sorted path order.

    Returns a list of dicts with "score" (compute_score for `metric`), "mic"
    and "src" (the dump's poses as length-3 float arrays) and "file" (path).
    NOTE: np.load(allow_pickle=True) can execute code — only load trusted dumps.
    """
    def _pose(raw):
        return np.asarray(raw, dtype=float).reshape(3)

    scored = []
    for path in tqdm(sorted(glob.glob(pattern)), desc="Scoring files", unit="file"):
        sample = np.load(path, allow_pickle=True).item()
        scored.append({
            "score": compute_score(sample, metric),
            "mic": _pose(sample["mic_pose"]),
            "src": _pose(sample["source_pose"]),
            "file": path,
        })
    return scored

def pick_indices_by_topk(scores, topk):
    """Indices of the `topk` largest (worst) scores, ordered by ascending score.

    A falsy `topk` (None/0), or one >= len(scores), selects every index in
    original order.
    """
    take_all = (not topk) or (topk >= len(scores))
    if take_all:
        return np.arange(len(scores))
    ascending = np.argsort(scores)
    return ascending[-topk:]

def _scatter_planes(axs_row, pts3, scores, title_prefix, vmin=None, vmax=None):
    """
    axs_row: 3 axes (XY, XZ, YZ)
    pts3: (N, 3)
    scores: (N,)
    """
    ax_xy, ax_xz, ax_yz = axs_row
    sc1 = ax_xy.scatter(pts3[:,0], pts3[:,1], c=scores, vmin=vmin, vmax=vmax)
    ax_xy.set_title(f"{title_prefix} XY"); ax_xy.set_xlabel("X"); ax_xy.set_ylabel("Y")

    sc2 = ax_xz.scatter(pts3[:,0], pts3[:,2], c=scores, vmin=vmin, vmax=vmax)
    ax_xz.set_title(f"{title_prefix} XZ"); ax_xz.set_xlabel("X"); ax_xz.set_ylabel("Z")

    sc3 = ax_yz.scatter(pts3[:,1], pts3[:,2], c=scores, vmin=vmin, vmax=vmax)
    ax_yz.set_title(f"{title_prefix} YZ"); ax_yz.set_xlabel("Y"); ax_yz.set_ylabel("Z")
    return sc1  # colorbar handle

# ---------------- Run ----------------
records = collect_records(EVAL_PATTERN, METRIC)
scores = np.array([r["score"] for r in records], dtype=float)

idx = pick_indices_by_topk(scores, TOPK)
scores_sel = scores[idx]
mic_pts, src_pts = (np.stack([records[i][k] for i in idx], axis=0) for k in ("mic", "src"))

# Shared color scale so rows are comparable
vmin, vmax = float(scores_sel.min()), float(scores_sel.max())

_k_label = TOPK or 'All'
if PLOT_BOTH:
    fig, axs = plt.subplots(2, 3, figsize=(15, 10), constrained_layout=True)
    sc_mic = _scatter_planes(axs[0], mic_pts, scores_sel, f"MIC (Top-{_k_label})", vmin=vmin, vmax=vmax)
    sc_src = _scatter_planes(axs[1], src_pts, scores_sel, f"SRC (Top-{_k_label})", vmin=vmin, vmax=vmax)
    _cbar_mappable = sc_mic
    _suptitle = f"Quality by Pose — metric={METRIC}, top-k={_k_label}"
else:
    which = "mic" if POSE_MODE.lower().startswith("m") else "src"
    pts = mic_pts if which == "mic" else src_pts
    fig, axs = plt.subplots(1, 3, figsize=(15, 4), constrained_layout=True)
    sc = _scatter_planes(axs, pts, scores_sel, f"{which.upper()} (Top-{_k_label})", vmin=vmin, vmax=vmax)
    _cbar_mappable = sc
    _suptitle = f"{which.upper()} poses — metric={METRIC}, top-k={_k_label}"

cbar = fig.colorbar(_cbar_mappable, ax=axs, orientation="horizontal", fraction=0.04, pad=0.08)
cbar.set_label(f"{METRIC} distance (higher = worse)")
fig.suptitle(_suptitle, y=0.995)
plt.show()

In [None]:
# === Paper-style table: NeRAF (Baseline) vs Ours vs Gain (EDT, T60, C50, ΔSTFT) ===
# - 3 rows × 4 cols, plus a thin colorbar column aligned per row
# - Row colorbars: Baseline (viridis swatch), Ours (viridis swatch), Gain (blue→grey→red swatch)
# - Gain: per-column normalization/band so no “all-grey” surprises
# - No per-axis ticks; X/Y shown once (outer labels)
# - Plots X vs Z but labels vertical axis as “Y”
# - GPU-batched GLIM; uses your compute_audio_distance

import os, glob, numpy as np, torch, matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import gridspec
from tqdm.auto import tqdm
from torchaudio.transforms import GriffinLim

from evaluator import compute_audio_distance

# ------------------- Knobs -------------------
# Eval dumps for the two systems; files are paired by basename (_load_aligned).
BASELINE_PATTERN = "../eval_results/baseline/renders/eval_*.npy"
REFINE_PATTERN   = "../eval_results/refine_best/renders/eval_*.npy"

MAX_FILES   = None          # e.g., 200 for quick tests; None/0 => all
POSE_MODE   = "mic"         # "mic" or "src" (we plot X vs Z but call it Y)
COLOR_MODE  = "spectrum"    # "spectrum" or "discrete" for Gain
NEUTRAL_BAND = "auto"       # "auto" => 5% of each column's max |Δ|; or float

# Griffin-Lim / ISTFT parameters: 513 frequency bins -> n_fft = (513 - 1) * 2 = 1024.
SAMPLE_RATE = 48000
n_fft       = (513 - 1) * 2
win_length  = 512
hop_length  = 256
power       = 1.0           # GriffinLim `power`: 1.0 => inputs are magnitudes
BATCH_GLIM  = 128           # default chunk size for _batched_logmag_to_wav

COL_METRICS = ["EDT", "T60", "C50", "SPL"]   # SPL will be labeled as ΔSTFT
COL_TITLES  = {"EDT": "EDT", "T60": "T60", "C50": "C50", "SPL": "ΔSTFT"}

plt.rcParams.update({
    "font.family": "DejaVu Serif",
    "font.size": 11,
    "axes.titlesize": 12,
})

# Red / grey / blue palette, presumably for the Gain row (as a discrete
# colormap and as spectrum stops respectively) — usage is outside this view.
COLORS_3 = ListedColormap(["#e41a1c", "#9e9e9e", "#377eb8"])  # red/grey/blue
SPECTRUM_COLORS = ["#e41a1c", "#9e9e9e", "#377eb8"]

# Autocast (AMP) is enabled only when running on CUDA.
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")
torch.backends.cudnn.benchmark = True

# ---------------- Helpers ----------------
def _load_aligned(bpat, rpat, max_files=None):
    bfiles = sorted(glob.glob(bpat))
    rfiles = sorted(glob.glob(rpat))
    bmap = {os.path.basename(f): f for f in bfiles}
    rmap = {os.path.basename(f): f for f in rfiles}
    keys = sorted(set(bmap) & set(rmap))
    if not keys:
        raise RuntimeError("No overlapping eval_*.npy files.")
    if max_files and max_files > 0:
        keys = keys[:max_files]
    return keys, bmap, rmap

def _ensure_2d_stft(x):
    t = torch.from_numpy(x).float().squeeze()
    if t.ndim != 2: raise ValueError(f"STFT must be 2D (F,T); got {tuple(t.shape)}")
    return t

def _ensure_1d_wav(x):
    t = torch.from_numpy(x).float().squeeze()
    if t.ndim != 1: raise ValueError(f"Waveform must be 1D; got {tuple(t.shape)}")
    return t

def _batch_griffin_lim(mag_batch, istft):
    """Run the Griffin-Lim module ``istft`` on a batch of magnitude spectrograms.

    Runs without autograd. Mixed precision is enabled on CUDA via the module-level
    ``use_amp`` flag.

    Args:
        mag_batch: magnitude tensor handed straight to ``istft`` (the caller in this
            cell stacks (F, T) tensors into (B, F, T)).
        istft: a ``torchaudio.transforms.GriffinLim`` instance.

    Returns:
        The waveform batch produced by ``istft``.
    """
    # torch.cuda.amp.autocast is deprecated (PyTorch >= 2.4); torch.autocast("cuda")
    # is its device-generic replacement with the same float16 default.
    with torch.autocast(device_type="cuda", enabled=use_amp), torch.no_grad():
        return istft(mag_batch)

def _batched_logmag_to_wav(stft_list, istft, batch=BATCH_GLIM):
    """Reconstruct waveforms from log-magnitude STFTs with batched Griffin-Lim.

    Each entry is expected to be a 2-D (F, T) tensor. ``exp(x) - 1e-3`` is applied
    before inversion; this presumably undoes a ``log(mag + 1e-3)`` encoding (TODO
    confirm against the feature extractor).

    Entries are grouped by frame count T before batching. Cropping a mixed batch to
    its minimum T would silently shorten the longer entries. When every T is equal,
    the batches are exactly consecutive chunks of ``batch`` entries.

    Args:
        stft_list: sequence of (F, T) log-magnitude tensors.
        istft: Griffin-Lim module used for the inversion.
        batch: maximum number of STFTs inverted per call.

    Returns:
        List of CPU waveform tensors in the same order as ``stft_list``.
    """
    out = [None] * len(stft_list)

    # Bucket indices by frame count (dicts keep insertion order, so buckets follow input order).
    by_len = {}
    for j, spec in enumerate(stft_list):
        by_len.setdefault(spec.shape[1], []).append(j)
    chunks = [ids[s:s + batch] for ids in by_len.values() for s in range(0, len(ids), batch)]

    for ids in tqdm(chunks, desc="GLIM batches", unit="batch"):
        mags = torch.stack([torch.exp(stft_list[j]) - 1e-3 for j in ids], dim=0).to(device, non_blocking=True)
        wav_b = _batch_griffin_lim(mags, istft)
        for j, w in zip(ids, wav_b.detach().cpu()):
            out[j] = w.contiguous()
        del mags, wav_b
        if device.type == "cuda":
            torch.cuda.empty_cache()
    return out

def _pair_score(metric, stft_gt, wav_gt, stft_pred, wav_pred):
    """Distance between one prediction and its ground truth under ``metric``.

    Waveforms are cropped to their common sample count and STFTs (F, T) to their
    common frame count. Each is stacked as a (GT, prediction) pair for
    ``compute_audio_distance``, and the off-diagonal entry [0, 1] is returned as a
    Python float.
    """
    n_samples = min(len(wav_gt), len(wav_pred))
    n_frames = min(stft_gt.shape[1], stft_pred.shape[1])
    wav_pair = torch.stack((wav_gt[:n_samples], wav_pred[:n_samples]), dim=0)
    spec_pair = torch.stack((stft_gt[:, :n_frames], stft_pred[:, :n_frames]), dim=0)
    with torch.no_grad():
        dist = compute_audio_distance(spec_pair, wavs=wav_pair, metric=metric.upper(), fs=SAMPLE_RATE)
    return float(dist[0, 1].item())

def _auto_band(delta_vals):
    """Half-width of the neutral (grey) band for one column of Δ values.

    When NEUTRAL_BAND == "auto", this is 5 % of max |Δ| (NaNs ignored), floored at
    1e-9; an empty input counts as max |Δ| = 1. Otherwise NEUTRAL_BAND is used as a
    fixed float.
    """
    peak = 1.0
    if len(delta_vals):
        peak = float(np.nanmax(np.abs(delta_vals)))
    if NEUTRAL_BAND != "auto":
        return float(NEUTRAL_BAND)
    return max(1e-9, 0.05 * peak)

def _norm_discrete(delta_vals):
    """Three-bin (red / grey / blue) colour mapper for Gain panels.

    Bins: Δ < -band -> red, -band <= Δ < band -> grey, Δ >= band -> blue, where
    ``band`` comes from ``_auto_band``. The outer edges are finite
    (± max(max|Δ|, 2·band)) instead of ±inf. A colorbar built from a BoundaryNorm
    takes its axis limits from the outer boundaries, and infinite limits make
    ``fig.colorbar`` fail. Values at or beyond the outer edges still get the end
    colours through the colormap's default under/over colours.

    Returns:
        (cmap, norm, band)
    """
    band = _auto_band(delta_vals)
    peak = float(np.nanmax(np.abs(delta_vals))) if len(delta_vals) else 1.0
    if not np.isfinite(peak):
        peak = 1.0
    # Keep boundaries strictly increasing even when a fixed band exceeds the data range.
    outer = max(peak, 2.0 * band)
    bounds = [-outer, -band, band, outer]
    norm = BoundaryNorm(bounds, COLORS_3.N)
    return COLORS_3, norm, band

class _NeutralTwoSlopeNorm(mcolors.Normalize):
    """Diverging norm with a flat neutral band of half-width ``band`` around 0.

    Piecewise-linear mapping:
        [vmin, -band] -> [0.0, 0.5]
        (-band, band) -> 0.5
        [band, vmax]  -> [0.5, 1.0]
    NaN inputs map to NaN. ``vcenter`` is stored for reference only; the split point
    is always 0.
    """

    def __init__(self, vmin, vcenter, vmax, band):
        super().__init__(vmin, vmax)
        self.vcenter = vcenter
        self.band = band

    def _spans(self):
        # Widths of the negative and positive ramps. They are floored so that a band
        # >= |vmin| or >= vmax (all-zero Δ, or a fixed NEUTRAL_BAND wider than the
        # data) cannot divide by zero or flip the ramp direction.
        neg_span = max(-self.band - self.vmin, 1e-12)
        pos_span = max(self.vmax - self.band, 1e-12)
        return neg_span, pos_span

    def __call__(self, value, clip=None):
        if clip is None:
            clip = self.clip
        x = np.ma.asarray(value)
        neg_span, pos_span = self._spans()
        res = np.empty_like(x, dtype=float)
        res.fill(np.nan)
        neg = x <= -self.band
        res[neg] = 0.5 * (x[neg] - self.vmin) / neg_span
        res[np.abs(x) < self.band] = 0.5
        pos = x >= self.band
        res[pos] = 0.5 + 0.5 * (x[pos] - self.band) / pos_span
        if clip:
            res = np.clip(res, 0.0, 1.0)
        return res

    def inverse(self, value):
        # Colorbars place boundaries and ticks through norm.inverse. The inherited
        # linear inverse does not match the piecewise forward map above.
        y = np.asarray(value, dtype=float)
        neg_span, pos_span = self._spans()
        return np.where(
            y < 0.5,
            self.vmin + 2.0 * y * neg_span,
            np.where(y > 0.5, self.band + 2.0 * (y - 0.5) * pos_span, 0.0),
        )


def _norm_spectrum(delta_vals):
    """Continuous red -> grey -> blue colour mapper for Gain panels.

    The range is symmetric, ±max|Δ| (1.0 for empty input), with a flat grey band of
    half-width ``_auto_band(delta_vals)`` around zero. Negative Δ (lower than baseline)
    maps toward red and positive Δ toward blue.

    Returns:
        (cmap, norm, band)
    """
    m = float(np.nanmax(np.abs(delta_vals))) if len(delta_vals) else 1.0
    band = _auto_band(delta_vals)
    cmap = mcolors.LinearSegmentedColormap.from_list("red_grey_blue", [
        (0.0,  SPECTRUM_COLORS[0]),
        (0.5,  SPECTRUM_COLORS[1]),
        (1.0,  SPECTRUM_COLORS[2]),
    ])
    norm = _NeutralTwoSlopeNorm(vmin=-m, vcenter=0.0, vmax=m, band=band)
    return cmap, norm, band

def _layered_scatter(ax, x, y, vals, cmap, norm, band, s_grey=9, s_color=11):
    vals = np.asarray(vals)
    mg = np.abs(vals) < band
    mr = vals >= band
    mb = vals <= -band
    if mg.any():
        ax.scatter(x[mg], y[mg], c=vals[mg], cmap=cmap, norm=norm,
                   s=s_grey, linewidths=0, alpha=0.9, zorder=1)
    if mr.any():
        ax.scatter(x[mr], y[mr], c=vals[mr], cmap=cmap, norm=norm,
                   s=s_color, edgecolors='none', alpha=1.0, zorder=3)
    if mb.any():
        ax.scatter(x[mb], y[mb], c=vals[mb], cmap=cmap, norm=norm,
                   s=s_color, edgecolors='none', alpha=1.0, zorder=3)

# ---------------- Main ----------------
def main():
    """Build and show the 3×4 "Baseline / Ours / Gain" pose-scatter table.

    Pipeline:
      1. Pair eval_*.npy dicts from BASELINE_PATTERN and REFINE_PATTERN by file name.
      2. Invert every predicted log-magnitude STFT with batched Griffin-Lim.
      3. Score each prediction against the ground truth stored in the *baseline*
         dict ("data" / "waveform") for every metric in COL_METRICS.
      4. Scatter each metric over the selected pose (X vs Z, labelled "Y"):
         row 0 = baseline and row 1 = ours, sharing one viridis range per column;
         row 2 = ours − baseline with a per-column diverging norm.
    """
    keys, bmap, rmap = _load_aligned(BASELINE_PATTERN, REFINE_PATTERN, max_files=MAX_FILES)
    pose_key = "source_pose" if POSE_MODE.lower().startswith("s") else "mic_pose"

    stft_gt_list, wav_gt_list = [], []
    stft_b_list,  stft_r_list = [], []
    poses = np.zeros((len(keys), 3), dtype=np.float32)
    for i, k in enumerate(tqdm(keys, desc="Loading eval dicts", unit="file")):
        # NOTE(review): allow_pickle=True can execute code embedded in the file; only load trusted dumps.
        d_b = np.load(bmap[k], allow_pickle=True).item()
        d_r = np.load(rmap[k], allow_pickle=True).item()
        stft_gt_list.append(_ensure_2d_stft(d_b["data"]))
        wav_gt_list.append(_ensure_1d_wav(d_b["waveform"]))
        stft_b_list.append(_ensure_2d_stft(d_b["pred_stft"]))
        stft_r_list.append(_ensure_2d_stft(d_r["pred_stft"]))
        poses[i] = np.asarray(d_b[pose_key], dtype=float).reshape(3)

    istft = GriffinLim(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=power).to(device)
    wav_b_list = _batched_logmag_to_wav(stft_b_list, istft, batch=BATCH_GLIM)
    wav_r_list = _batched_logmag_to_wav(stft_r_list, istft, batch=BATCH_GLIM)

    base = {m: np.zeros(len(keys), dtype=np.float32) for m in COL_METRICS}
    ours = {m: np.zeros(len(keys), dtype=np.float32) for m in COL_METRICS}
    for m in COL_METRICS:
        for i in tqdm(range(len(keys)), desc=f"Scoring {m}", unit="sample", leave=False):
            base[m][i] = _pair_score(m, stft_gt_list[i], wav_gt_list[i], stft_b_list[i], wav_b_list[i])
            ours[m][i] = _pair_score(m, stft_gt_list[i], wav_gt_list[i], stft_r_list[i], wav_r_list[i])

    # === Figure ===
    # 4 data columns + 1 thin colorbar column; thin CB aligned per row
    fig = plt.figure(figsize=(12, 7.6))
    gs = gridspec.GridSpec(
        nrows=3, ncols=5, figure=fig,
        left=0.07, right=0.96, top=0.92, bottom=0.10,
        wspace=0.18, hspace=0.38,
        width_ratios=[1, 1, 1, 1, 0.035]
    )

    # axes grid (3x4) + row colorbar axes in last column
    axes = np.array([[fig.add_subplot(gs[r, c]) for c in range(4)] for r in range(3)])
    cbar_axes = [fig.add_subplot(gs[r, 4]) for r in range(3)]

    # column headers once
    for c, m in enumerate(COL_METRICS):
        axes[0, c].set_title(COL_TITLES.get(m, m), pad=6)

    # row labels as ylabels (keeps left margin tight)
    for r, label in enumerate(("Baseline (NeRAF)", "Ours", "Gain")):
        axes[r, 0].set_ylabel(label, rotation=90, labelpad=28, fontsize=12, weight="bold")

    # outer labels only
    fig.supxlabel("X", y=0.06, fontsize=12)
    fig.supylabel("Y", x=0.055, fontsize=12)

    X = poses[:, 0]
    Y = poses[:, 2]  # plotted from Z, labeled as Y

    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap works on old and new releases.
    viridis_cmap = plt.get_cmap("viridis")

    # ----- Baseline & Ours rows -----
    # per-column vmin/vmax (units differ) shared by both rows
    for c, m in enumerate(COL_METRICS):
        vmin = float(min(base[m].min(), ours[m].min()))
        vmax = float(max(base[m].max(), ours[m].max()))
        axes[0, c].scatter(X, Y, c=base[m], cmap=viridis_cmap, vmin=vmin, vmax=vmax, s=8, linewidths=0)
        axes[1, c].scatter(X, Y, c=ours[m], cmap=viridis_cmap, vmin=vmin, vmax=vmax, s=8, linewidths=0)

    # Row colorbars are qualitative swatches: every column has its own range, so the
    # 0–1 ticks of this placeholder mapper would be meaningless and are hidden.
    sm_row = plt.cm.ScalarMappable(norm=mcolors.Normalize(0, 1), cmap=viridis_cmap)
    cb1 = fig.colorbar(sm_row, cax=cbar_axes[0])
    cb1.set_ticks([])
    cb2 = fig.colorbar(sm_row, cax=cbar_axes[1])
    cb2.set_ticks([])
    cb2.set_label("lower  is better", labelpad=4)

    # ----- Gain row -----
    # Per-column normalization/band so one metric with a large Δ doesn't neutralize the others.
    get_norm = _norm_discrete if COLOR_MODE.lower().startswith("disc") else _norm_spectrum

    first_sm = None
    for c, m in enumerate(COL_METRICS):
        delta = ours[m] - base[m]
        cmap_c, norm_c, band_c = get_norm(delta)
        _layered_scatter(axes[2, c], X, Y, delta, cmap_c, norm_c, band_c, s_grey=9, s_color=11)
        if first_sm is None:
            first_sm = plt.cm.ScalarMappable(norm=norm_c, cmap=cmap_c)

    # The swatch borrows the first column's mapper; its ticks would only be valid for that
    # column, so they are hidden as well.
    cb3 = fig.colorbar(first_sm, cax=cbar_axes[2])
    cb3.set_ticks([])

    # Clean panels: no ticks, subtle frames
    for ax in axes.ravel():
        ax.set_xticks([]); ax.set_yticks([])
        for sp in ax.spines.values():
            sp.set_linewidth(0.6); sp.set_alpha(0.6)

    fig.suptitle(f"Columns: EDT, T60, C50, ΔSTFT   |   Pose = {POSE_MODE.upper()}",
                 y=0.975, fontsize=12)
    plt.show()

if __name__ == "__main__":
    main()


In [None]:
# === Paper-style table: Baseline vs Ours (two designs) vs two Gain rows ===
# Rows (top→bottom):
#   1) Baseline (NeRAF)
#   2) Ours (Design 1: feature fusion)        <-- REFINE_PATTERN
#   3) Ours (Design 2: output modification)   <-- STFT_PATTERN
#   4) Gain (Design 1 - Baseline)
#   5) Gain (Design 2 - Baseline)
#
# Notes:
# - 4 metric columns (EDT, T60, C50, ΔSTFT) + 1 thin colorbar column per row
# - Per-column scaling for each data row; gain rows use blue↔grey↔red with neutral band
# - No per-axis ticks; outer X/Y only; X vs Z plotted, Y label says “Y” (because traditions)
# - GPU-batched GLIM; uses compute_audio_distance

import os, glob, numpy as np, torch, matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import gridspec
from tqdm.auto import tqdm
from torchaudio.transforms import GriffinLim

from evaluator import compute_audio_distance

# ------------------- Knobs -------------------
BASELINE_PATTERN = "../eval_results/baseline/renders/eval_*.npy"
REFINE_PATTERN   = "../eval_results/refine_best/renders/eval_*.npy"  # Design 1 (feature-fusion)
STFT_PATTERN     = "../eval_results/stft_no_edc/renders/eval_*.npy"    # Design 2 (output-mod)

MAX_FILES    = None          # e.g., 200 for quick tests; None/0 => all
POSE_MODE    = "mic"         # "mic" or "src" (we plot X vs Z but call it Y)
COLOR_MODE   = "spectrum"    # "spectrum" or "discrete" for Gain rows
NEUTRAL_BAND = "auto"        # "auto" => 5% of each column's max |Δ|; or float

# STFT / Griffin-Lim parameters (513 one-sided bins -> n_fft = 1024)
SAMPLE_RATE = 48000
n_fft = 2 * (513 - 1)
win_length, hop_length = 512, 256
power = 1.0
BATCH_GLIM = 128             # STFTs inverted per Griffin-Lim call

# Metric columns; the SPL distance is displayed under the "ΔSTFT" heading.
COL_METRICS = ["EDT", "T60", "C50", "SPL"]
COL_TITLES = {m: m for m in COL_METRICS}
COL_TITLES["SPL"] = "ΔSTFT"

plt.rcParams.update({"font.family": "DejaVu Serif", "font.size": 11, "axes.titlesize": 12})

# Gain palette: index 0 red (Δ < 0), 1 grey (neutral), 2 blue (Δ > 0)
SPECTRUM_COLORS = ["#e41a1c", "#9e9e9e", "#377eb8"]
COLORS_3 = ListedColormap(list(SPECTRUM_COLORS))

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_amp = device.type == "cuda"
torch.backends.cudnn.benchmark = True

# ---------------- Helpers ----------------
def _load_aligned_three(bpat, r1pat, r2pat, max_files=None):
    bfiles = sorted(glob.glob(bpat))
    r1files = sorted(glob.glob(r1pat))
    r2files = sorted(glob.glob(r2pat))
    bmap  = {os.path.basename(f): f for f in bfiles}
    r1map = {os.path.basename(f): f for f in r1files}
    r2map = {os.path.basename(f): f for f in r2files}
    keys = sorted(set(bmap) & set(r1map) & set(r2map))
    if not keys:
        raise RuntimeError("No overlapping eval_*.npy files across all three patterns.")
    if max_files and max_files > 0:
        keys = keys[:max_files]
    return keys, bmap, r1map, r2map

def _ensure_2d_stft(x):
    t = torch.from_numpy(x).float().squeeze()
    if t.ndim != 2: raise ValueError(f"STFT must be 2D (F,T); got {tuple(t.shape)}")
    return t

def _ensure_1d_wav(x):
    t = torch.from_numpy(x).float().squeeze()
    if t.ndim != 1: raise ValueError(f"Waveform must be 1D; got {tuple(t.shape)}")
    return t

def _batch_griffin_lim(mag_batch, istft):
    """Run the Griffin-Lim module ``istft`` on a batch of magnitude spectrograms.

    Runs without autograd. Mixed precision is enabled on CUDA via the module-level
    ``use_amp`` flag.

    Args:
        mag_batch: magnitude tensor handed straight to ``istft`` (the caller in this
            cell stacks (F, T) tensors into (B, F, T)).
        istft: a ``torchaudio.transforms.GriffinLim`` instance.

    Returns:
        The waveform batch produced by ``istft``.
    """
    # torch.cuda.amp.autocast is deprecated (PyTorch >= 2.4); torch.autocast("cuda")
    # is its device-generic replacement with the same float16 default.
    with torch.autocast(device_type="cuda", enabled=use_amp), torch.no_grad():
        return istft(mag_batch)

def _batched_logmag_to_wav(stft_list, istft, batch=BATCH_GLIM):
    """Reconstruct waveforms from log-magnitude STFTs with batched Griffin-Lim.

    Each entry is expected to be a 2-D (F, T) tensor. ``exp(x) - 1e-3`` is applied
    before inversion; this presumably undoes a ``log(mag + 1e-3)`` encoding (TODO
    confirm against the feature extractor).

    Entries are grouped by frame count T before batching. Cropping a mixed batch to
    its minimum T would silently shorten the longer entries. When every T is equal,
    the batches are exactly consecutive chunks of ``batch`` entries.

    Args:
        stft_list: sequence of (F, T) log-magnitude tensors.
        istft: Griffin-Lim module used for the inversion.
        batch: maximum number of STFTs inverted per call.

    Returns:
        List of CPU waveform tensors in the same order as ``stft_list``.
    """
    out = [None] * len(stft_list)

    # Bucket indices by frame count (dicts keep insertion order, so buckets follow input order).
    by_len = {}
    for j, spec in enumerate(stft_list):
        by_len.setdefault(spec.shape[1], []).append(j)
    chunks = [ids[s:s + batch] for ids in by_len.values() for s in range(0, len(ids), batch)]

    for ids in tqdm(chunks, desc="GLIM batches", unit="batch"):
        mags = torch.stack([torch.exp(stft_list[j]) - 1e-3 for j in ids], dim=0).to(device, non_blocking=True)
        wav_b = _batch_griffin_lim(mags, istft)
        for j, w in zip(ids, wav_b.detach().cpu()):
            out[j] = w.contiguous()
        del mags, wav_b
        if device.type == "cuda":
            torch.cuda.empty_cache()
    return out

def _pair_score(metric, stft_gt, wav_gt, stft_pred, wav_pred):
    """Distance between one prediction and its ground truth under ``metric``.

    Waveforms are cropped to their common sample count and STFTs (F, T) to their
    common frame count. Each is stacked as a (GT, prediction) pair for
    ``compute_audio_distance``, and the off-diagonal entry [0, 1] is returned as a
    Python float.
    """
    n_samples = min(len(wav_gt), len(wav_pred))
    n_frames = min(stft_gt.shape[1], stft_pred.shape[1])
    wav_pair = torch.stack((wav_gt[:n_samples], wav_pred[:n_samples]), dim=0)
    spec_pair = torch.stack((stft_gt[:, :n_frames], stft_pred[:, :n_frames]), dim=0)
    with torch.no_grad():
        dist = compute_audio_distance(spec_pair, wavs=wav_pair, metric=metric.upper(), fs=SAMPLE_RATE)
    return float(dist[0, 1].item())

def _auto_band(delta_vals):
    """Half-width of the neutral (grey) band for one column of Δ values.

    When NEUTRAL_BAND == "auto", this is 5 % of max |Δ| (NaNs ignored), floored at
    1e-9; an empty input counts as max |Δ| = 1. Otherwise NEUTRAL_BAND is used as a
    fixed float.
    """
    peak = 1.0
    if len(delta_vals):
        peak = float(np.nanmax(np.abs(delta_vals)))
    if NEUTRAL_BAND != "auto":
        return float(NEUTRAL_BAND)
    return max(1e-9, 0.05 * peak)

def _norm_discrete(delta_vals):
    """Three-bin (red / grey / blue) colour mapper for Gain panels.

    Bins: Δ < -band -> red, -band <= Δ < band -> grey, Δ >= band -> blue, where
    ``band`` comes from ``_auto_band``. The outer edges are finite
    (± max(max|Δ|, 2·band)) instead of ±inf. A colorbar built from a BoundaryNorm
    takes its axis limits from the outer boundaries, and infinite limits make
    ``fig.colorbar`` fail. Values at or beyond the outer edges still get the end
    colours through the colormap's default under/over colours.

    Returns:
        (cmap, norm, band)
    """
    band = _auto_band(delta_vals)
    peak = float(np.nanmax(np.abs(delta_vals))) if len(delta_vals) else 1.0
    if not np.isfinite(peak):
        peak = 1.0
    # Keep boundaries strictly increasing even when a fixed band exceeds the data range.
    outer = max(peak, 2.0 * band)
    bounds = [-outer, -band, band, outer]
    norm = BoundaryNorm(bounds, COLORS_3.N)
    return COLORS_3, norm, band

class _NeutralTwoSlopeNorm(mcolors.Normalize):
    """Diverging norm with a flat neutral band of half-width ``band`` around 0.

    Piecewise-linear mapping:
        [vmin, -band] -> [0.0, 0.5]
        (-band, band) -> 0.5
        [band, vmax]  -> [0.5, 1.0]
    NaN inputs map to NaN. ``vcenter`` is stored for reference only; the split point
    is always 0.
    """

    def __init__(self, vmin, vcenter, vmax, band):
        super().__init__(vmin, vmax)
        self.vcenter = vcenter
        self.band = band

    def _spans(self):
        # Widths of the negative and positive ramps. They are floored so that a band
        # >= |vmin| or >= vmax (all-zero Δ, or a fixed NEUTRAL_BAND wider than the
        # data) cannot divide by zero or flip the ramp direction.
        neg_span = max(-self.band - self.vmin, 1e-12)
        pos_span = max(self.vmax - self.band, 1e-12)
        return neg_span, pos_span

    def __call__(self, value, clip=None):
        if clip is None:
            clip = self.clip
        x = np.ma.asarray(value)
        neg_span, pos_span = self._spans()
        res = np.empty_like(x, dtype=float)
        res.fill(np.nan)
        neg = x <= -self.band
        res[neg] = 0.5 * (x[neg] - self.vmin) / neg_span
        res[np.abs(x) < self.band] = 0.5
        pos = x >= self.band
        res[pos] = 0.5 + 0.5 * (x[pos] - self.band) / pos_span
        if clip:
            res = np.clip(res, 0.0, 1.0)
        return res

    def inverse(self, value):
        # Colorbars place boundaries and ticks through norm.inverse. The inherited
        # linear inverse does not match the piecewise forward map above.
        y = np.asarray(value, dtype=float)
        neg_span, pos_span = self._spans()
        return np.where(
            y < 0.5,
            self.vmin + 2.0 * y * neg_span,
            np.where(y > 0.5, self.band + 2.0 * (y - 0.5) * pos_span, 0.0),
        )


def _norm_spectrum(delta_vals):
    """Continuous red -> grey -> blue colour mapper for Gain panels.

    The range is symmetric, ±max|Δ| (1.0 for empty input), with a flat grey band of
    half-width ``_auto_band(delta_vals)`` around zero. Negative Δ (lower than baseline)
    maps toward red and positive Δ toward blue.

    Returns:
        (cmap, norm, band)
    """
    m = float(np.nanmax(np.abs(delta_vals))) if len(delta_vals) else 1.0
    band = _auto_band(delta_vals)
    cmap = mcolors.LinearSegmentedColormap.from_list("red_grey_blue", [
        (0.0,  SPECTRUM_COLORS[0]),
        (0.5,  SPECTRUM_COLORS[1]),
        (1.0,  SPECTRUM_COLORS[2]),
    ])
    norm = _NeutralTwoSlopeNorm(vmin=-m, vcenter=0.0, vmax=m, band=band)
    return cmap, norm, band

def _layered_scatter(ax, x, y, vals, cmap, norm, band, s_grey=9, s_color=11):
    vals = np.asarray(vals)
    mg = np.abs(vals) < band
    mr = vals >= band
    mb = vals <= -band
    if mg.any():
        ax.scatter(x[mg], y[mg], c=vals[mg], cmap=cmap, norm=norm,
                   s=s_grey, linewidths=0, alpha=0.9, zorder=1)
    if mr.any():
        ax.scatter(x[mr], y[mr], c=vals[mr], cmap=cmap, norm=norm,
                   s=s_color, edgecolors='none', alpha=1.0, zorder=3)
    if mb.any():
        ax.scatter(x[mb], y[mb], c=vals[mb], cmap=cmap, norm=norm,
                   s=s_color, edgecolors='none', alpha=1.0, zorder=3)

# ---------------- Main ----------------
def main():
    """Build and show the 5×4 Baseline / Design 1 / Design 2 / Gain(D1) / Gain(D2) table.

    Pipeline:
      1. Pair eval_*.npy dicts present under BASELINE_PATTERN, REFINE_PATTERN
         (Design 1) and STFT_PATTERN (Design 2), matched by file name.
      2. Invert every predicted log-magnitude STFT with batched Griffin-Lim.
      3. Score each prediction against the ground truth stored in the *baseline*
         dict ("data" / "waveform") for every metric in COL_METRICS.
      4. Scatter per pose (X vs Z, labelled "Y"): rows 0–2 share one viridis range
         per column; rows 3–4 show design − baseline with a per-column diverging
         norm (negative Δ = lower than baseline).
    """
    keys, bmap, r1map, r2map = _load_aligned_three(BASELINE_PATTERN, REFINE_PATTERN, STFT_PATTERN, max_files=MAX_FILES)
    pose_key = "source_pose" if POSE_MODE.lower().startswith("s") else "mic_pose"

    stft_gt_list, wav_gt_list = [], []
    stft_b_list,  stft_r1_list, stft_r2_list = [], [], []
    poses = np.zeros((len(keys), 3), dtype=np.float32)

    for i, k in enumerate(tqdm(keys, desc="Loading eval dicts", unit="file")):
        # NOTE(review): allow_pickle=True can execute code embedded in the file; only load trusted dumps.
        d_b  = np.load(bmap[k],  allow_pickle=True).item()
        d_r1 = np.load(r1map[k], allow_pickle=True).item()
        d_r2 = np.load(r2map[k], allow_pickle=True).item()

        stft_gt_list.append(_ensure_2d_stft(d_b["data"]))
        wav_gt_list.append(_ensure_1d_wav(d_b["waveform"]))
        stft_b_list.append(_ensure_2d_stft(d_b["pred_stft"]))
        stft_r1_list.append(_ensure_2d_stft(d_r1["pred_stft"]))  # Design 1
        stft_r2_list.append(_ensure_2d_stft(d_r2["pred_stft"]))  # Design 2

        poses[i] = np.asarray(d_b[pose_key], dtype=float).reshape(3)

    istft = GriffinLim(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=power).to(device)

    wav_b_list  = _batched_logmag_to_wav(stft_b_list,  istft, batch=BATCH_GLIM)
    wav_r1_list = _batched_logmag_to_wav(stft_r1_list, istft, batch=BATCH_GLIM)
    wav_r2_list = _batched_logmag_to_wav(stft_r2_list, istft, batch=BATCH_GLIM)

    base  = {m: np.zeros(len(keys), dtype=np.float32) for m in COL_METRICS}
    ours1 = {m: np.zeros(len(keys), dtype=np.float32) for m in COL_METRICS}
    ours2 = {m: np.zeros(len(keys), dtype=np.float32) for m in COL_METRICS}

    for m in COL_METRICS:
        for i in tqdm(range(len(keys)), desc=f"Scoring {m}", unit="sample", leave=False):
            base[m][i]  = _pair_score(m, stft_gt_list[i], wav_gt_list[i], stft_b_list[i],  wav_b_list[i])
            ours1[m][i] = _pair_score(m, stft_gt_list[i], wav_gt_list[i], stft_r1_list[i], wav_r1_list[i])
            ours2[m][i] = _pair_score(m, stft_gt_list[i], wav_gt_list[i], stft_r2_list[i], wav_r2_list[i])

    # === Figure ===
    # 4 data columns + 1 thin colorbar column; thin CB aligned per row
    NROWS = 5
    fig = plt.figure(figsize=(12, 10.4))
    gs = gridspec.GridSpec(
        nrows=NROWS, ncols=5, figure=fig,
        left=0.07, right=0.96, top=0.94, bottom=0.08,
        wspace=0.18, hspace=0.35,
        width_ratios=[1, 1, 1, 1, 0.035]
    )

    axes = np.array([[fig.add_subplot(gs[r, c]) for c in range(4)] for r in range(NROWS)])
    cbar_axes = [fig.add_subplot(gs[r, 4]) for r in range(NROWS)]

    # column headers once
    for c, m in enumerate(COL_METRICS):
        axes[0, c].set_title(COL_TITLES.get(m, m), pad=6)

    # row labels (bold)
    row_labels = (
        "Baseline (NeRAF)",
        "Ours (Design 1:\nfeature fusion)",
        "Ours (Design 2:\noutput modification)",
        "Gain (D1 − Base)",
        "Gain (D2 − Base)",
    )
    for r, label in enumerate(row_labels):
        axes[r, 0].set_ylabel(label, rotation=90, labelpad=28, fontsize=12, weight="bold")

    # outer labels only
    fig.supxlabel("X", y=0.05, fontsize=12)
    fig.supylabel("Y", x=0.055, fontsize=12)

    X = poses[:, 0]
    Y = poses[:, 2]  # plotted from Z, labeled as Y

    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap works on old and new releases.
    viridis_cmap = plt.get_cmap("viridis")

    # ----- Baseline & Ours rows -----
    # Per-column vmin/vmax (units differ) shared by the three value rows.
    value_rows = ((0, base), (1, ours1), (2, ours2))
    for c, m in enumerate(COL_METRICS):
        vmin = float(min(base[m].min(), ours1[m].min(), ours2[m].min()))
        vmax = float(max(base[m].max(), ours1[m].max(), ours2[m].max()))
        for r, scores in value_rows:
            axes[r, c].scatter(X, Y, c=scores[m], cmap=viridis_cmap, vmin=vmin, vmax=vmax, s=8, linewidths=0)

    # Row colorbars are qualitative swatches: every column has its own range, so the
    # 0–1 ticks of this placeholder mapper would be meaningless and are hidden.
    sm_row = plt.cm.ScalarMappable(norm=mcolors.Normalize(0, 1), cmap=viridis_cmap)
    value_cbs = [fig.colorbar(sm_row, cax=cbar_axes[r]) for r, _ in value_rows]
    for cb in value_cbs:
        cb.set_ticks([])
    value_cbs[-1].set_label("lower is better", labelpad=4)

    # ----- Gain rows -----
    # Per-column normalization/band so one metric with a large Δ doesn't neutralize the others.
    get_norm = _norm_discrete if COLOR_MODE.lower().startswith("disc") else _norm_spectrum

    for row, ours_m in ((3, ours1), (4, ours2)):
        row_sm = None
        for c, m in enumerate(COL_METRICS):
            delta = ours_m[m] - base[m]
            cmap_c, norm_c, band_c = get_norm(delta)
            _layered_scatter(axes[row, c], X, Y, delta, cmap_c, norm_c, band_c, s_grey=9, s_color=11)
            if row_sm is None:
                row_sm = plt.cm.ScalarMappable(norm=norm_c, cmap=cmap_c)
        # The swatch borrows the first column's mapper; its ticks would only be valid for
        # that column, so they are hidden.
        fig.colorbar(row_sm, cax=cbar_axes[row]).set_ticks([])

    # Clean panels: no ticks, subtle frames
    for ax in axes.ravel():
        ax.set_xticks([]); ax.set_yticks([])
        for sp in ax.spines.values():
            sp.set_linewidth(0.6); sp.set_alpha(0.6)

    fig.suptitle(
        "Columns: EDT, T60, C50, ΔSTFT   |   Pose = "
        f"{POSE_MODE.upper()}   |   Δ rows: negative = better than Baseline",
        y=0.985, fontsize=12
    )
    plt.show()

if __name__ == "__main__":
    main()


In [None]:
# === Baseline-only pose cartography across views (XY, YZ, XZ) ===
# Rows: ΔSTFT, EDC, T60, C50, EDT   (5)
# Cols: views (XY, YZ, XZ) + 1 thin colorbar column per row   (3 + 1)
# - Per-row normalization (so all 3 views of a metric share the same scale)
# - Viridis for all rows (baseline values); thin row-aligned colorbars
# - No ticks; subtle frames; compact, paper-ready
# - Built-in print-label swap: xy→xz, xz→xy, yz→zy (labels only; coordinates are correct)

import os, glob, numpy as np, torch, matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib import gridspec
from tqdm.auto import tqdm
from torchaudio.transforms import GriffinLim

from evaluator import compute_audio_distance

# ------------------- Knobs -------------------
BASELINE_PATTERN = "../eval_results/emptyroom/emptyroom/renders/eval_*.npy"
MAX_FILES    = None
POSE_MODE    = "mic"          # "mic" or "src"
PRINT_LABELS = True           # swap labels for printing: xy→xz, xz→xy, yz→zy

# STFT / Griffin-Lim parameters (513 one-sided bins -> n_fft = 1024)
SAMPLE_RATE = 48000
n_fft = 2 * (513 - 1)
win_length, hop_length = 512, 256
power = 1.0
BATCH_GLIM = 128              # STFTs inverted per Griffin-Lim call

# One figure row per metric; the SPL distance is displayed as "ΔSTFT".
ROW_METRICS = ["SPL", "EDC", "T60", "C50", "EDT"]
ROW_TITLES = {m: m for m in ROW_METRICS}
ROW_TITLES["SPL"] = "ΔSTFT"

plt.rcParams.update({"font.family": "DejaVu Serif", "font.size": 11, "axes.titlesize": 12})

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_amp = device.type == "cuda"
torch.backends.cudnn.benchmark = True

# ---------------- Helpers ----------------
def _load_baseline(bpat, max_files=None):
    bfiles = sorted(glob.glob(bpat))
    if not bfiles:
        raise RuntimeError("No eval_*.npy files found for baseline.")
    if max_files and max_files > 0:
        bfiles = bfiles[:max_files]
    return bfiles

def _ensure_2d_stft(x):
    t = torch.from_numpy(x).float().squeeze()
    if t.ndim != 2: raise ValueError(f"STFT must be 2D (F,T); got {tuple(t.shape)}")
    return t

def _ensure_1d_wav(x):
    t = torch.from_numpy(x).float().squeeze()
    if t.ndim != 1: raise ValueError(f"Waveform must be 1D; got {tuple(t.shape)}")
    return t

def _batch_griffin_lim(mag_batch, istft):
    """Run the Griffin-Lim module ``istft`` on a batch of magnitude spectrograms.

    Runs without autograd. Mixed precision is enabled on CUDA via the module-level
    ``use_amp`` flag.

    Args:
        mag_batch: magnitude tensor handed straight to ``istft`` (the caller in this
            cell stacks (F, T) tensors into (B, F, T)).
        istft: a ``torchaudio.transforms.GriffinLim`` instance.

    Returns:
        The waveform batch produced by ``istft``.
    """
    # torch.cuda.amp.autocast is deprecated (PyTorch >= 2.4); torch.autocast("cuda")
    # is its device-generic replacement with the same float16 default.
    with torch.autocast(device_type="cuda", enabled=use_amp), torch.no_grad():
        return istft(mag_batch)

def _batched_logmag_to_wav(stft_list, istft, batch=BATCH_GLIM):
    """Reconstruct waveforms from log-magnitude STFTs with batched Griffin-Lim.

    Each entry is expected to be a 2-D (F, T) tensor. ``exp(x) - 1e-3`` is applied
    before inversion; this presumably undoes a ``log(mag + 1e-3)`` encoding (TODO
    confirm against the feature extractor).

    Entries are grouped by frame count T before batching. Cropping a mixed batch to
    its minimum T would silently shorten the longer entries. When every T is equal,
    the batches are exactly consecutive chunks of ``batch`` entries.

    Args:
        stft_list: sequence of (F, T) log-magnitude tensors.
        istft: Griffin-Lim module used for the inversion.
        batch: maximum number of STFTs inverted per call.

    Returns:
        List of CPU waveform tensors in the same order as ``stft_list``.
    """
    out = [None] * len(stft_list)

    # Bucket indices by frame count (dicts keep insertion order, so buckets follow input order).
    by_len = {}
    for j, spec in enumerate(stft_list):
        by_len.setdefault(spec.shape[1], []).append(j)
    chunks = [ids[s:s + batch] for ids in by_len.values() for s in range(0, len(ids), batch)]

    for ids in tqdm(chunks, desc="GLIM batches", unit="batch"):
        mags = torch.stack([torch.exp(stft_list[j]) - 1e-3 for j in ids], dim=0).to(device, non_blocking=True)
        wav_b = _batch_griffin_lim(mags, istft)
        for j, w in zip(ids, wav_b.detach().cpu()):
            out[j] = w.contiguous()
        del mags, wav_b
        if device.type == "cuda":
            torch.cuda.empty_cache()
    return out

def _pair_score(metric, stft_gt, wav_gt, stft_pred, wav_pred):
    """Distance between one prediction and its ground truth under ``metric``.

    Waveforms are cropped to their common sample count and STFTs (F, T) to their
    common frame count. Each is stacked as a (GT, prediction) pair for
    ``compute_audio_distance``, and the off-diagonal entry [0, 1] is returned as a
    Python float.

    NOTE(review): for metric "EDC" no ``edc_curves`` are passed here (the dashboard
    cell passes them explicitly). Presumably ``compute_audio_distance`` derives them
    from ``wavs``; confirm.
    """
    n_samples = min(len(wav_gt), len(wav_pred))
    n_frames = min(stft_gt.shape[1], stft_pred.shape[1])
    wav_pair = torch.stack((wav_gt[:n_samples], wav_pred[:n_samples]), dim=0)
    spec_pair = torch.stack((stft_gt[:, :n_frames], stft_pred[:, :n_frames]), dim=0)
    with torch.no_grad():
        dist = compute_audio_distance(spec_pair, wavs=wav_pair, metric=metric.upper(), fs=SAMPLE_RATE)
    return float(dist[0, 1].item())

def _view_coords(poses):
    X = poses[:, 0]; Y = poses[:, 1]; Z = poses[:, 2]
    return {
        "xy": (X, Y),
        "yz": (Y, Z),
        "xz": (X, Z),
    }

def _print_name(name):
    """Display label for a view name.

    With PRINT_LABELS set, the labels are swapped for printing (xy -> XZ, xz -> XY,
    yz -> ZY). The coordinates themselves are unchanged. Any other name, or
    PRINT_LABELS off, is simply upper-cased.
    """
    if PRINT_LABELS:
        return {"xy": "XZ", "xz": "XY", "yz": "ZY"}.get(name.lower(), name.upper())
    return name.upper()

# ---------------- Main ----------------
def main():
    """Build and show the baseline-only pose cartography (metric rows × XY/YZ/XZ views).

    Loads every baseline eval_*.npy dict and inverts its predicted log-magnitude STFT
    with batched Griffin-Lim. Each prediction is scored against the stored ground truth
    for every metric in ROW_METRICS. The scores are then scattered over three 2-D
    projections of the selected pose. Each row shares one viridis range across its
    views and gets a real-valued colorbar.
    """
    bfiles = _load_baseline(BASELINE_PATTERN, max_files=MAX_FILES)
    pose_key = "source_pose" if POSE_MODE.lower().startswith("s") else "mic_pose"

    stft_gt_list, wav_gt_list = [], []
    stft_b_list = []
    poses = np.zeros((len(bfiles), 3), dtype=np.float32)

    for i, f in enumerate(tqdm(bfiles, desc="Loading eval dicts", unit="file")):
        # NOTE(review): allow_pickle=True can execute code embedded in the file; only load trusted dumps.
        d = np.load(f, allow_pickle=True).item()
        stft_gt_list.append(_ensure_2d_stft(d["data"]))
        wav_gt_list.append(_ensure_1d_wav(d["waveform"]))
        stft_b_list.append(_ensure_2d_stft(d["pred_stft"]))
        poses[i] = np.asarray(d[pose_key], dtype=float).reshape(3)

    istft = GriffinLim(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=power).to(device)
    wav_b_list = _batched_logmag_to_wav(stft_b_list, istft, batch=BATCH_GLIM)

    # Score baseline for all metrics
    base = {m: np.zeros(len(bfiles), dtype=np.float32) for m in ROW_METRICS}
    for m in ROW_METRICS:
        for i in tqdm(range(len(bfiles)), desc=f"Scoring {m}", unit="sample", leave=False):
            base[m][i] = _pair_score(m, stft_gt_list[i], wav_gt_list[i], stft_b_list[i], wav_b_list[i])

    # === Figure: 5 rows × (3 views + 1 colorbar) ===
    VIEWS = ["xy", "yz", "xz"]
    coords = _view_coords(poses)

    NROWS = len(ROW_METRICS)
    NCOLS = 4  # 3 views + colorbar
    fig = plt.figure(figsize=(10.8, 11.6))
    gs = gridspec.GridSpec(
        nrows=NROWS, ncols=NCOLS, figure=fig,
        left=0.07, right=0.96, top=0.94, bottom=0.06,
        wspace=0.16, hspace=0.35,
        width_ratios=[1, 1, 1, 0.035]
    )

    axes = np.array([[fig.add_subplot(gs[r, c]) for c in range(3)] for r in range(NROWS)])
    cbar_axes = [fig.add_subplot(gs[r, 3]) for r in range(NROWS)]

    # Column headers (view names, with print-swap applied)
    for j, v in enumerate(VIEWS):
        axes[0, j].set_title(_print_name(v), pad=6)

    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap works on old and new releases.
    viridis = plt.get_cmap("viridis")

    # Rows per metric: consistent vmin/vmax across the 3 views
    for r, m in enumerate(ROW_METRICS):
        vmin, vmax = float(base[m].min()), float(base[m].max())
        for j, v in enumerate(VIEWS):
            x, y = coords[v]
            axes[r, j].scatter(x, y, c=base[m], cmap=viridis, vmin=vmin, vmax=vmax,
                               s=8, linewidths=0)
        sm = plt.cm.ScalarMappable(norm=mcolors.Normalize(vmin=vmin, vmax=vmax), cmap=viridis)
        cb = fig.colorbar(sm, cax=cbar_axes[r])
        if r == 0:
            cb.set_label("lower is better", labelpad=4)

        # Row labels (left-most panel)
        axes[r, 0].set_ylabel(ROW_TITLES.get(m, m), rotation=90, labelpad=28, fontsize=12, weight="bold")

    # Clean panels
    for ax in axes.ravel():
        ax.set_xticks([]); ax.set_yticks([])
        for sp in ax.spines.values():
            sp.set_linewidth(0.6); sp.set_alpha(0.6)

    fig.suptitle(
        f"Baseline pose cartography — Views: {', '.join(_print_name(v) for v in VIEWS)}   |   "
        f"Pose = {POSE_MODE.upper()}",
        y=0.985, fontsize=12
    )
    plt.show()

if __name__ == "__main__":
    main()


In [None]:
# --- ONE-CELL DASHBOARD (GPU + tqdm): Compare Retrieval vs Baseline, rank by improvement, plot GT/R/B ---

# ==== KNOBS ====
# NOTE(review): was ".../refine_best/render/"; every other cell in this notebook reads
# refine_best results from ".../refine_best/renders/", and the wrong directory makes
# the glob match nothing. Confirm against the actual output layout.
EVAL_PATTERN_RET   = "../eval_results/refine_best/renders/eval_*.npy"
EVAL_PATTERN_BASE  = "../eval_results/baseline/renders/eval_*.npy"
ROOT_DIR           = "../data/RAF/FurnishedRoom"
SAMPLE_RATE        = 48000
MAX_EVAL_FILES     = None          # e.g., 300 or None for all overlaps
TOP_N              = 30            # how many best-improved samples to visualize
RANK_METRIC        = "C50"         # one of: "SPL", "EDC", "EDT", "T60", "C50"
# plotting
FIGSIZE_PER_SAMPLE = (14, 6)       # width, height for each sample (two rows x three columns)
DPI                = 120

# ==== IMPORTS ====
import os, glob, pickle
import numpy as np
import torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torchaudio.transforms import GriffinLim

# Your project utilities (as in your snippet)
from evaluator import compute_audio_distance, compute_edc_db
from retriever import FurnishedRoomSTFTDataset  # RIRRetrievalMLP not needed here

import sys
sys.path.append('../NeRAF')
from NeRAF_helper import compute_t60, evaluate_edt, evaluate_clarity

# ==== DEVICE / ISTFT SETUP (GPU) ====
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_amp = device.type == "cuda"
torch.backends.cudnn.benchmark = True
# NOTE(review): this turns autograd off for the whole kernel, not just this cell; any
# later cell that trains or back-propagates must call torch.set_grad_enabled(True).
torch.set_grad_enabled(False)

# Shared Griffin-Lim inverter used by the helpers below (513 one-sided bins -> n_fft = 1024).
n_fft, win_length, hop_length, power = 2 * (513 - 1), 512, 256, 1
istft = GriffinLim(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=power).to(device)

# ==== HELPERS (GPU-friendly; only move to CPU for plotting) ====
def _as_2d_numpy_cpu(wav_t):
    arr = wav_t.detach().cpu().numpy()
    return arr[None, :] if arr.ndim == 1 else arr

def _pair_metrics_with_edc(stft_a, wav_a, stft_b, wav_b, edc_a, edc_b, fs=SAMPLE_RATE):
    """
    Score prediction ``b`` against ground truth ``a`` with every dashboard metric.

    Distances from ``compute_audio_distance`` (run under autocast when ``use_amp``):
    MSE, SPL (uses ``fs``), MAG, MAG2, and EDC (from the supplied EDC curves).
    Room-acoustic errors are computed on CPU NumPy copies of the length-aligned
    waveforms:
      * T60: mean relative error in percent. Entries where either estimate is < -0.5
        are counted as 100 %.
      * EDT, C50: mean absolute difference.

    Args:
        stft_a, stft_b: STFT tensors of identical shape (they are stacked as a pair).
        wav_a, wav_b: waveforms, both truncated to the shorter length along dim 0
            (presumably 1-D; TODO confirm).
        edc_a, edc_b: EDC curves. ``edc_b`` is truncated or zero-padded along its last
            dim to ``edc_a.shape[0]`` (presumably 1-D curves).
        fs: sample rate for SPL and the room metrics.

    Returns:
        dict with float values for 'MSE', 'SPL', 'MAG', 'MAG2', 'EDC', 'EDT', 'C50', 'T60'.
    """
    # align waveforms for fair metric computation (on GPU)
    L = min(wav_a.shape[0], wav_b.shape[0])
    wav_a = wav_a[:L]; wav_b = wav_b[:L]

    # EDC equal length (use GT length), still on GPU
    T_edc = edc_a.shape[0]
    if edc_b.shape[0] != T_edc:
        if edc_b.shape[0] > T_edc: edc_b = edc_b[:T_edc]
        else: edc_b = torch.nn.functional.pad(edc_b, (0, T_edc - edc_b.shape[0]))

    pair_stft = torch.stack([stft_a, stft_b], dim=0)
    pair_wav  = torch.stack([wav_a, wav_b], dim=0)
    pair_edc  = torch.stack([edc_a, edc_b], dim=0)

    # torch.cuda.amp.autocast is deprecated (PyTorch >= 2.4); torch.autocast("cuda") is equivalent.
    with torch.autocast(device_type="cuda", enabled=use_amp):
        mse  = compute_audio_distance(pair_stft, wavs=pair_wav, metric='MSE')[0,1].item()
        spl  = compute_audio_distance(pair_stft, wavs=pair_wav, metric='SPL', fs=fs)[0,1].item()
        mag  = compute_audio_distance(pair_stft, metric='MAG')[0,1].item()
        mag2 = compute_audio_distance(pair_stft, metric='MAG2')[0,1].item()
        edcD = compute_audio_distance(pair_stft, wavs=None, edc_curves=pair_edc, metric='EDC')[0,1].item()

    # Room-style metrics: need numpy; ship minimal data to CPU
    gt = _as_2d_numpy_cpu(wav_a); xx = _as_2d_numpy_cpu(wav_b)
    t60_gt, t60_x = compute_t60(gt, xx, fs=fs, advanced=True)
    t60_gt = np.atleast_1d(t60_gt).astype(float); t60_x = np.atleast_1d(t60_x).astype(float)
    with np.errstate(divide='ignore', invalid='ignore'):
        t60_diff = np.abs(t60_x - t60_gt) / (np.abs(t60_gt) + 1e-12)
    # Failed T60 estimates (< -0.5) count as a full 100 % error.
    invalid_mask = (t60_gt < -0.5) | (t60_x < -0.5)
    t60_diff[invalid_mask] = 1.0
    t60_err_pct = float(np.mean(t60_diff) * 100.0)

    # NOTE(review): EDT/C50 helpers are called as (pred, gt) but unpacked as (gt, pred).
    # This is harmless because only |difference| is used; confirm NeRAF_helper's order.
    edt_gt, edt_x = evaluate_edt(xx, gt, fs=fs); edt_mae = float(np.mean(np.abs(edt_x - edt_gt)))
    c50_gt, c50_x = evaluate_clarity(xx, gt, fs=fs); c50_mae = float(np.mean(np.abs(c50_x - c50_gt)))

    return {'MSE': mse, 'SPL': spl, 'MAG': mag, 'MAG2': mag2, 'EDC': edcD,
            'EDT': edt_mae, 'C50': c50_mae, 'T60': t60_err_pct}

def _load_results(pattern, desc="scan"):
    """Load saved renders matching ``pattern`` into ``{audio_idx: {"path", "pred_stft"}}``.

    Each file is read with ``np.load(...).item()`` and must provide ``"audio_idx"`` and
    ``"pred_stft"``; files that fail to load or lack those keys are skipped with a warning.
    Paths are visited in sorted order so that, when several files carry the same
    audio_idx, the one kept (the last visited) is deterministic; such duplicates are reported.

    Args:
        pattern: glob pattern for the ``.npy`` render files.
        desc: label shown in the progress bar.

    Returns:
        dict ordered by ascending integer audio_idx.

    NOTE: ``allow_pickle=True`` can execute arbitrary code — only load trusted result dumps.
    """
    # sorted(): glob order is filesystem-dependent; make duplicate resolution reproducible
    paths_all = sorted(glob.glob(pattern))
    if not paths_all:
        print(f"[WARN] No files match {pattern}")
    out = {}
    for p in tqdm(paths_all, desc=f"Loading {desc}", unit="file", leave=False):
        try:
            d = np.load(p, allow_pickle=True).item()
            idx = int(d["audio_idx"])
            entry = {"path": p, "pred_stft": d["pred_stft"]}
        except Exception as e:
            # best-effort loading: one corrupt/foreign file should not abort the scan
            print(f"[WARN] Skipping {p}: {e}")
            continue
        if idx in out:
            print(f"[WARN] Duplicate audio_idx={idx}: {p} replaces {out[idx]['path']}")
        out[idx] = entry
    # sort by numeric audio_idx
    return dict(sorted(out.items(), key=lambda kv: kv[0]))

def _ensure_logmag_to_wav(stft_logmag_t):
    """Reconstruct a waveform from a log-magnitude STFT via the module-level ``istft``.

    ``exp(x) - 1e-3`` undoes a ``log(mag + 1e-3)`` compression (presumably how the
    STFTs were stored — confirm against the dataset/renderer). The magnitude is fed to
    ``istft`` (GriffinLim) as a batch of one, under autocast when ``use_amp`` is set.
    NOTE(review): entries below log(1e-3) yield negative magnitudes; nothing clamps them.
    """
    with torch.cuda.amp.autocast(enabled=use_amp):
        linear_mag = torch.exp(stft_logmag_t) - 1e-3
        batch_of_one = linear_mag[None, ...]
        waveform = istft(batch_of_one).squeeze(0)
    return waveform

def _make_edc(wav_t, T_target=60):
    """Energy-decay curve of ``wav_t`` from ``compute_edc_db`` with NaNs replaced by 0.

    ``T_target`` is forwarded to ``compute_edc_db`` (presumably the curve length).
    The waveform is cast to float32 first.
    """
    curve_db = compute_edc_db(wav_t.float(), T_target=T_target)
    return curve_db.nan_to_num(nan=0.0)

def _collect_vmin_vmax(*arrays):
    vmin = min([a.min() for a in arrays])
    vmax = max([a.max() for a in arrays])
    return float(vmin), float(vmax)

# ==== LOAD DATASETS / RESULTS (with tqdm) ====
ret = _load_results(EVAL_PATTERN_RET, desc="retrieval")
bas = _load_results(EVAL_PATTERN_BASE, desc="baseline")

# Fair comparison: keep only audio_idx rendered by *both* methods, ascending,
# optionally capped at MAX_EVAL_FILES (None keeps everything).
common_idx = sorted(ret.keys() & bas.keys())
if MAX_EVAL_FILES is not None:
    del common_idx[MAX_EVAL_FILES:]

# Ground truth comes from the test split of the dataset.
ds_test = FurnishedRoomSTFTDataset(
    root_dir=ROOT_DIR,
    split="test",
    sample_rate=SAMPLE_RATE,
    return_wav=True,
    mode="normal",
)

# ==== PASS 1: compute metrics and stash plotting payloads (tqdm + GPU) ====
# Metrics are computed on `device`. Each file's plotting payload is copied to the CPU
# before it is stored: `records` keeps a reference to every stored tensor, so leaving
# them on the GPU would grow VRAM with the number of files regardless of the `del` below.
records = []
for idx in tqdm(common_idx, desc="Scoring (GPU)", unit="file"):
    item = ds_test[idx]

    # --- Ground truth (to GPU) ---
    stft_gt  = item['stft'].squeeze(0).to(device, non_blocking=True)   # [F,T]
    wav_gt   = item['wav'].squeeze().to(device, non_blocking=True)

    # --- GT EDC: the dataset curve if present (its length fixes T_edc), else a
    #     60-point curve computed from the GT waveform ---
    edc_ds = item.get('edc')
    if edc_ds is not None and edc_ds.numel() > 0:
        T_edc  = int(edc_ds.shape[0])
        edc_gt = torch.nan_to_num(edc_ds.to(device, non_blocking=True), nan=0.0)
    else:
        T_edc  = 60
        edc_gt = _make_edc(wav_gt, T_target=T_edc).to(device, non_blocking=True)

    # --- Retrieval (stft_best) ---
    pred_stft_R = torch.from_numpy(ret[idx]["pred_stft"]).float().squeeze(0).to(device, non_blocking=True)
    wav_R       = _ensure_logmag_to_wav(pred_stft_R)
    L_R         = min(wav_gt.shape[0], wav_R.shape[0])
    wav_R       = wav_R[:L_R]; wav_gt_R = wav_gt[:L_R]  # aligned for R metrics
    edc_R       = _make_edc(wav_R, T_target=T_edc).to(device, non_blocking=True)

    # --- Baseline ---
    pred_stft_B = torch.from_numpy(bas[idx]["pred_stft"]).float().squeeze(0).to(device, non_blocking=True)
    wav_B       = _ensure_logmag_to_wav(pred_stft_B)
    L_B         = min(wav_gt.shape[0], wav_B.shape[0])
    wav_B       = wav_B[:L_B]; wav_gt_B = wav_gt[:L_B]  # aligned for B metrics
    edc_B       = _make_edc(wav_B, T_target=T_edc).to(device, non_blocking=True)

    # --- Metrics (vs GT) on GPU (with brief CPU hops inside helper) ---
    metrics_R = _pair_metrics_with_edc(stft_gt, wav_gt_R, pred_stft_R, wav_R, edc_gt, edc_R, fs=SAMPLE_RATE)
    metrics_B = _pair_metrics_with_edc(stft_gt, wav_gt_B, pred_stft_B, wav_B, edc_gt, edc_B, fs=SAMPLE_RATE)

    # Ranking score: improvement (baseline_error - retrieval_error); bigger is better
    if RANK_METRIC not in metrics_R:
        raise ValueError(f"RANK_METRIC '{RANK_METRIC}' not in computed metrics: {list(metrics_R.keys())}")
    improve = float(metrics_B[RANK_METRIC] - metrics_R[RANK_METRIC])

    # Store plotting payloads on the CPU. The key is still "plot_gpu" so the plotting
    # loop keeps working; it converts with .cpu()/.numpy(), which is fine for CPU tensors.
    records.append({
        "idx": idx,
        "plot_gpu": {
            "wav":  {"gt": wav_gt.cpu(), "ret": wav_R.cpu(), "bas": wav_B.cpu()},
            "stft": {"gt": stft_gt.cpu(), "ret": pred_stft_R.cpu(), "bas": pred_stft_B.cpu()},
        },
        "metrics_R": metrics_R,
        "metrics_B": metrics_B,
        "improve": improve
    })

    # keep VRAM tidy (effective now that `records` holds no GPU tensors)
    del stft_gt, wav_gt, pred_stft_R, wav_R, pred_stft_B, wav_B, edc_gt, edc_R, edc_B, wav_gt_R, wav_gt_B, edc_ds
    if device.type == "cuda":
        torch.cuda.empty_cache()

# ==== RANK & SELECT TOP-N ====
records_sorted = sorted(records, key=lambda r: r["improve"], reverse=True)
top = records_sorted[:TOP_N]

# ==== PLOTTING (with tqdm; tensors -> CPU only here) ====
def _title_from_metrics(name, m):
    # Only show the 5 requested: SPL, EDC, EDT, T60, C50
    return (f"{name}\n"
            f"ΔSTFT:{m['SPL']:.3f}  EDC:{m['EDC']:.3f}  "
            f"EDT:{m['EDT']:.3f}  T60%:{m['T60']:.2f}  C50:{m['C50']:.3f}")

for rec in tqdm(top, desc="Plotting", unit="sample"):
    idx = rec["idx"]
    wav_payload, stft_payload = rec["plot_gpu"]["wav"], rec["plot_gpu"]["stft"]

    # Stored tensors (wherever they live) -> float32 NumPy arrays for matplotlib.
    wav_gt, wav_R, wav_B = (wav_payload[k].detach().cpu().float().numpy() for k in ("gt", "ret", "bas"))
    S_gt, S_R, S_B = (stft_payload[k].detach().cpu().float().numpy() for k in ("gt", "ret", "bas"))

    # Per-signal time axes in seconds (retrieval/baseline waveforms were truncated to the GT overlap).
    t_gt, t_R, t_B = (np.arange(len(w)) / SAMPLE_RATE for w in (wav_gt, wav_R, wav_B))

    # One colour scale shared by the three spectrograms of this sample.
    vmin, vmax = _collect_vmin_vmax(S_gt, S_R, S_B)

    fig, axes = plt.subplots(2, 3, figsize=FIGSIZE_PER_SAMPLE, dpi=DPI,
                             gridspec_kw={'height_ratios': [1, 1]})
    fig.suptitle(f"audio_idx={idx}", y=1.02)

    # Column layout: Retrieval | GT | Baseline. The GT panels carry no metrics.
    panels = (
        (t_R,  wav_R,  S_R,  "Retrieval (waveform)", "Retrieval (STFT log-mag)", rec["metrics_R"]),
        (t_gt, wav_gt, S_gt, "GT (waveform)",        "GT (STFT log-mag)",        None),
        (t_B,  wav_B,  S_B,  "Baseline (waveform)",  "Baseline (STFT log-mag)",  rec["metrics_B"]),
    )
    for col, (t_axis, wave, spec, wave_name, spec_name, metrics) in enumerate(panels):
        wave_ax, spec_ax = axes[0, col], axes[1, col]

        # Row 1: time-domain waveform
        wave_ax.plot(t_axis, wave, linewidth=0.75)
        wave_ax.set_title(wave_name if metrics is None else _title_from_metrics(wave_name, metrics))
        wave_ax.set_xlabel("Time [s]"); wave_ax.set_ylabel("Amp")

        # Row 2: log-magnitude STFT on the shared colour scale
        spec_ax.imshow(spec, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
        spec_ax.set_title(spec_name if metrics is None else _title_from_metrics(spec_name, metrics))
        spec_ax.set_xlabel("Frames"); spec_ax.set_ylabel("Bins")

    plt.tight_layout()
    plt.show()


print(f"Plotted top {len(top)} samples ranked by improvement in {RANK_METRIC} on {device.type.upper()}.")

In [None]:
# --- CLEAN, SORTED, GPU GALLERY ---
# Layout: one row per sample; columns GT | Baseline | Design 1 | Design 2.
# Produces TWO figures: waveforms and STFT log-magnitudes.
#
# NOTE(review): the header used to call Design 1 "stft_best", but EVAL_PATTERN_D1 below
# reads the "stft_no_edc" run — confirm which run is intended.

# ==== KNOBS ====
# Render dumps: each eval_*.npy holds one prediction keyed by "audio_idx" / "pred_stft".
EVAL_PATTERN_BASE = "../eval_results/baseline/renders/eval_*.npy"      # Baseline (NeRAF)
EVAL_PATTERN_D1   = "../eval_results/stft_no_edc/renders/eval_*.npy"   # Design 1: Feature Fusion
EVAL_PATTERN_D2   = "../eval_results/refine_best/renders/eval_*.npy"   # Design 2: Output Modification
ROOT_DIR          = "../data/RAF/FurnishedRoom"
SAMPLE_RATE       = 48_000   # Hz

# Selection
MAX_EVAL_FILES = None    # cap on the shared audio_idx list before sorting (None -> no cap)
NUM_SAMPLES    = 10      # number of rows shown in each figure
SORT_BY_METRIC = "SPL"   # Baseline error used for ordering: "SPL", "T60", "C50", "EDT" or "EDC"

# Figure layout
FIG_DPI         = 130
FIG_W           = 14     # figure width (inches)
FIG_H_WAVE      = 1.1    # height per row, waveform figure (inches)
FIG_H_STFT      = 1.1    # height per row, STFT figure (inches)
COL_TITLE_SIZE  = 12     # bold column headers on the top row
METRIC_FONTSIZE = 8      # one-line metric captions (normal weight)

# ==== IMPORTS ====
import os, glob, numpy as np, torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torchaudio.transforms import GriffinLim

from evaluator import compute_audio_distance, compute_edc_db
from retriever import FurnishedRoomSTFTDataset

import sys
sys.path.append('../NeRAF')
from NeRAF_helper import compute_t60, evaluate_edt, evaluate_clarity

# ==== DEVICE / ISTFT ====
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")          # autocast is only enabled on the GPU
torch.backends.cudnn.benchmark = True
# NOTE: this turns autograd off for the whole kernel session, not just this cell.
torch.set_grad_enabled(False)

# torchaudio's GriffinLim starts from a random phase by default (rand_init=True), so the
# reconstructed waveforms — and every waveform-based metric and sort order below — change
# between runs unless torch's RNG is seeded. Seeding here makes re-running this cell
# repeatable (up to nondeterministic GPU kernels).
GL_SEED = 0
torch.manual_seed(GL_SEED)

# STFT geometry: 513 one-sided bins <=> n_fft = (513 - 1) * 2 = 1024; 512-sample window, hop 256.
# power=1: GriffinLim is given a magnitude (not power) spectrogram.
n_fft = (513 - 1) * 2; win_length = 512; hop_length = 256; power = 1
istft = GriffinLim(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=power).to(device)

# ==== HELPERS ====
def _load_results(pattern, desc):
    """Load saved renders matching ``pattern`` into ``{audio_idx: {"pred_stft": ...}}``.

    Each file is read with ``np.load(...).item()`` and must provide ``"audio_idx"`` and
    ``"pred_stft"``; files that fail to load or lack those keys are skipped with a warning.
    Paths are visited in sorted order so that, when several files carry the same
    audio_idx, the kept one (the last visited) is deterministic; duplicates are reported.

    Args:
        pattern: glob pattern for the ``.npy`` render files.
        desc: label shown in the progress bar.

    Returns:
        dict ordered by ascending integer audio_idx.

    NOTE: ``allow_pickle=True`` can execute arbitrary code — only load trusted result dumps.
    """
    # sorted(): glob order is filesystem-dependent; make duplicate resolution reproducible
    paths = sorted(glob.glob(pattern))
    if not paths:
        print(f"[WARN] no files match {pattern}")
    out = {}
    for p in tqdm(paths, desc=f"Loading {desc}", unit="file", leave=False):
        try:
            d = np.load(p, allow_pickle=True).item()
            idx = int(d["audio_idx"])
            entry = {"pred_stft": d["pred_stft"]}
        except Exception as e:
            # best-effort loading: one corrupt/foreign file should not abort the scan
            print(f"[WARN] skip {p}: {e}")
            continue
        if idx in out:
            print(f"[WARN] duplicate audio_idx={idx}: {p} replaces an earlier file")
        out[idx] = entry
    return dict(sorted(out.items(), key=lambda kv: kv[0]))

def _ensure_logmag_to_wav(stft_logmag_t):
    """Griffin-Lim reconstruction of a waveform from a log-magnitude STFT.

    Applies ``exp(x) - 1e-3`` (the inverse of a presumed ``log(mag + 1e-3)`` storage
    transform — confirm) and runs the module-level ``istft`` (GriffinLim) on a batch of
    one, under autocast when ``use_amp`` is set. Returns the un-batched waveform.
    NOTE(review): inputs below log(1e-3) give negative magnitudes; they are not clamped.
    """
    with torch.cuda.amp.autocast(enabled=use_amp):
        linear_mag = torch.exp(stft_logmag_t).sub(1e-3)
        return istft(linear_mag.unsqueeze(dim=0)).squeeze(dim=0)

def _make_edc(wav_t, T_target=60):
    """``compute_edc_db`` of the float32-cast waveform, with NaN entries set to 0.

    ``T_target`` is passed straight through (presumably the number of curve points).
    """
    as_float = wav_t.float()
    return compute_edc_db(as_float, T_target=T_target).nan_to_num(nan=0.0)

def _as_2d_numpy_cpu(x):
    a = x.detach().cpu().numpy()
    return a[None, :] if a.ndim == 1 else a

def _metrics_vs_gt(stft_gt, wav_gt, stft_pred, wav_pred, edc_gt, fs=SAMPLE_RATE):
    """Error metrics of one prediction against ground truth.

    Args:
        stft_gt, stft_pred: STFT tensors of equal shape (stacked into a 2-item batch).
        wav_gt, wav_pred: 1-D waveform tensors; truncated to their common length first.
        edc_gt: unused (callers pass None); kept so the call signature stays stable.
        fs: sample rate in Hz.

    Returns:
        dict with
          'SPL' — ``compute_audio_distance(metric='SPL')`` between GT and prediction,
          'T60' — mean relative T60 error in %, where an estimate below -0.5 on either
                  side (treated as invalid) counts as 100 %,
          'C50', 'EDT' — mean absolute differences.
    """
    # align: the reconstructed waveform length may differ from the GT's
    L = min(wav_gt.shape[0], wav_pred.shape[0])
    wav_gt = wav_gt[:L]; wav_pred = wav_pred[:L]
    pair_stft = torch.stack([stft_gt, stft_pred], dim=0)
    pair_wav  = torch.stack([wav_gt, wav_pred], dim=0)

    with torch.cuda.amp.autocast(enabled=use_amp):
        spl = compute_audio_distance(pair_stft, wavs=pair_wav, metric='SPL', fs=fs)[0,1].item()

    # EDT/C50/T60 take NumPy arrays, so they do not need the autocast context. Copy each
    # waveform to the host once instead of once per metric call.
    gt_np   = _as_2d_numpy_cpu(wav_gt)
    pred_np = _as_2d_numpy_cpu(wav_pred)

    edt_gt, edt_x = evaluate_edt(pred_np, gt_np, fs=fs)
    c50_gt, c50_x = evaluate_clarity(pred_np, gt_np, fs=fs)
    t60_gt, t60_x = compute_t60(gt_np, pred_np, fs=fs, advanced=True)

    t60_gt = np.atleast_1d(t60_gt).astype(float); t60_x = np.atleast_1d(t60_x).astype(float)
    with np.errstate(divide='ignore', invalid='ignore'):
        t60_diff = np.abs(t60_x - t60_gt) / (np.abs(t60_gt) + 1e-12)
    # estimates below -0.5 are treated as failures and scored as 100 % error
    t60_diff[(t60_gt < -0.5) | (t60_x < -0.5)] = 1.0
    t60_err = float(np.mean(t60_diff) * 100.0)
    edt_mae = float(np.mean(np.abs(edt_x - edt_gt)))
    c50_mae = float(np.mean(np.abs(c50_x - c50_gt)))
    return {'SPL': float(spl), 'T60': t60_err, 'C50': c50_mae, 'EDT': edt_mae}

def _fmt_one_line(m):  # single line metrics
    return f"ΔSTFT:{m['SPL']:.3f}  T60%:{m['T60']:.2f}  C50:{m['C50']:.3f}  EDT:{m['EDT']:.3f}"

def _collect_vmin_vmax(*arrs):
    vmin = min(a.min() for a in arrs); vmax = max(a.max() for a in arrs)
    return float(vmin), float(vmax)

# Sort metrics that can be computed on the fly with compute_audio_distance ("CAD").
# A precomputed value in row["B"]["m"] is always preferred (for SPL it is the very same
# CAD call on the same aligned tensors), so CAD only runs for metrics not stored there.
_DIRECT_BY_CAD = {"SPL", "EDC"}

def _score_for_sort(row, metric, fs=SAMPLE_RATE):
    """Baseline-vs-GT error used to order gallery rows (callers sort descending: worst first).

    Args:
        row: gallery row; reads row["B"]["m"] and, if needed, the "wav"/"stft" tensors of
             row["GT"] and row["B"] (moved to ``device`` here, wherever they are stored).
        metric: "SPL", "T60", "C50", "EDT" or "EDC".
        fs: sample rate in Hz.

    Returns:
        float error value.

    Raises:
        KeyError: metric is neither precomputed nor in ``_DIRECT_BY_CAD``.
    """
    precomputed = row["B"].get("m", {})
    if metric in precomputed:
        return float(precomputed[metric])
    if metric not in _DIRECT_BY_CAD:
        raise KeyError(f"metric {metric!r} is neither precomputed nor computable via CAD")

    stft_gt = row["GT"]["stft"].to(device, non_blocking=True)
    wav_gt  = row["GT"]["wav"].to(device, non_blocking=True)
    stft_b  = row["B"]["stft"].to(device, non_blocking=True)
    wav_b   = row["B"]["wav"].to(device, non_blocking=True)
    # Align to the common length before any waveform-derived quantity
    L = min(wav_gt.shape[0], wav_b.shape[0])
    wav_gt = wav_gt[:L]; wav_b = wav_b[:L]
    pair_stft = torch.stack([stft_gt, stft_b], dim=0)

    if metric == "EDC":
        # EDC is scored from explicit energy-decay curves passed as `edc_curves`
        # (wavs=None), the same calling convention used for the paired EDC metric.
        # NOTE(review): confirm compute_audio_distance's EDC path expects edc_curves.
        pair_edc = torch.stack([_make_edc(wav_gt).to(device), _make_edc(wav_b).to(device)], dim=0)
        cad_kwargs = {"wavs": None, "edc_curves": pair_edc}
    else:
        cad_kwargs = {"wavs": torch.stack([wav_gt, wav_b], dim=0)}

    with torch.cuda.amp.autocast(enabled=use_amp):
        val = compute_audio_distance(pair_stft, metric=metric, fs=fs, **cad_kwargs)[0, 1].item()
    return float(val)

# ==== LOAD ====
# NOTE(review): the D1 label says "stft_best" while EVAL_PATTERN_D1 reads the
# "stft_no_edc" run — confirm the label (or the pattern) is the intended one.
B   = _load_results(EVAL_PATTERN_BASE, "baseline")
D1  = _load_results(EVAL_PATTERN_D1,   "design-1 (stft_best)")
D2  = _load_results(EVAL_PATTERN_D2,   "design-2 (refine_best)")

# Only audio_idx rendered by all three models can be compared row by row.
common = sorted(B.keys() & D1.keys() & D2.keys())
if not common:
    raise RuntimeError("No audio_idx is shared by the baseline, design-1 and design-2 results; "
                       "check the EVAL_PATTERN_* globs.")
# `is not None` rather than truthiness: None means "no cap", while 0 really means zero files.
if MAX_EVAL_FILES is not None:
    common = common[:MAX_EVAL_FILES]

ds = FurnishedRoomSTFTDataset(root_dir=ROOT_DIR, split="test",
                              sample_rate=SAMPLE_RATE, return_wav=True, mode="normal")

# ==== BUILD + SCORE (GPU) ====
# Validate the sort key up front so a typo fails before the slow Griffin-Lim loop.
if SORT_BY_METRIC not in ("SPL","T60","C50","EDT","EDC"):
    raise ValueError("SORT_BY_METRIC must be one of: SPL, T60, C50, EDT, EDC")

# Scoring runs on `device`. Each row's tensors are then copied to the CPU, because
# `rows` keeps a reference to them and `del` + empty_cache() alone cannot release that
# VRAM. Downstream code moves tensors explicitly (.to(device) / .cpu()) where needed.
rows = []
for idx in tqdm(common, desc="Preparing (GPU)", unit="file"):
    item = ds[idx]
    stft_gt = item['stft'].squeeze(0).to(device, non_blocking=True)
    wav_gt  = item['wav'].squeeze().to(device, non_blocking=True)

    stft_b  = torch.from_numpy(B[idx]["pred_stft"]).float().squeeze(0).to(device, non_blocking=True)
    stft_d1 = torch.from_numpy(D1[idx]["pred_stft"]).float().squeeze(0).to(device, non_blocking=True)
    stft_d2 = torch.from_numpy(D2[idx]["pred_stft"]).float().squeeze(0).to(device, non_blocking=True)

    wav_b  = _ensure_logmag_to_wav(stft_b)
    wav_d1 = _ensure_logmag_to_wav(stft_d1)
    wav_d2 = _ensure_logmag_to_wav(stft_d2)

    # metrics vs GT
    mB  = _metrics_vs_gt(stft_gt, wav_gt, stft_b,  wav_b,  None)
    mD1 = _metrics_vs_gt(stft_gt, wav_gt, stft_d1, wav_d1, None)
    mD2 = _metrics_vs_gt(stft_gt, wav_gt, stft_d2, wav_d2, None)

    rows.append({
        "idx": idx,
        "GT": {"wav": wav_gt.cpu(), "stft": stft_gt.cpu()},    # no metrics shown for GT
        "B":  {"wav": wav_b.cpu(),  "stft": stft_b.cpu(),  "m": mB},
        "D1": {"wav": wav_d1.cpu(), "stft": stft_d1.cpu(), "m": mD1},
        "D2": {"wav": wav_d2.cpu(), "stft": stft_d2.cpu(), "m": mD2},
    })

    # VRAM hygiene (effective now that `rows` holds no GPU tensors)
    del stft_gt, wav_gt, stft_b, stft_d1, stft_d2, wav_b, wav_d1, wav_d2
    if device.type == "cuda": torch.cuda.empty_cache()

# ==== SORT by Baseline's chosen metric (worst→best), then take top N ====
rows_sorted = sorted(rows, key=lambda r: _score_for_sort(r, SORT_BY_METRIC), reverse=True)
rows = rows_sorted[:NUM_SAMPLES]

# ==== FIGURES: rows = samples | cols = GT | Baseline | D1 | D2 (no ticks, no idx labels) ====
col_keys   = ["GT", "B", "D1", "D2"]
col_titles = ["Ground Truth", "Baseline (NeRAF)", "Design 1: Feature Fusion", "Design 2: Output Modification"]

if not rows:
    # plt.subplots(0, 4) would fail with an opaque error; say what actually went wrong.
    raise RuntimeError("Nothing to plot: no audio_idx shared by all result sets "
                       "(or MAX_EVAL_FILES / NUM_SAMPLES is 0).")


def _gallery_caption(row, key):
    """One-line metrics caption for column ``key`` of ``row``; None for the GT column (no metrics)."""
    return None if key == "GT" else _fmt_one_line(row[key]["m"])


def _decorate_gallery_axis(ax, is_top_row, col_title, caption):
    """Hide ticks and labels, thin the spines, and attach the header and/or metrics caption.

    Top row: bold column title (``pad=18``) with the caption drawn as axes-relative text
    just above the plot, so the two do not collide. Other rows: the caption becomes the
    axes title. A ``None`` caption (GT column) adds nothing beyond the header.
    """
    ax.set_xticks([]); ax.set_yticks([]); ax.set_xlabel(""); ax.set_ylabel("")
    for sp in ax.spines.values():
        sp.set_linewidth(0.6); sp.set_alpha(0.8)
    if is_top_row:
        ax.set_title(col_title, fontweight="bold", fontsize=COL_TITLE_SIZE, pad=18)
        if caption is not None:
            ax.text(0.5, 1.02, caption,
                    transform=ax.transAxes, ha="center", va="bottom",
                    fontsize=METRIC_FONTSIZE, fontweight="normal")
    elif caption is not None:
        ax.set_title(caption, fontsize=METRIC_FONTSIZE, pad=4)


# ---- Figure A: waveforms ----
# squeeze=False always yields a 2-D axes array, even for a single row.
figA, axsA = plt.subplots(len(rows), 4, figsize=(FIG_W, FIG_H_WAVE*len(rows)), dpi=FIG_DPI, squeeze=False)
for r, row in enumerate(rows):
    for c, key in enumerate(col_keys):
        ax = axsA[r, c]
        w = row[key]["wav"].detach().cpu().float().numpy()
        t = np.arange(len(w))/SAMPLE_RATE
        ax.plot(t, w, linewidth=0.8)
        _decorate_gallery_axis(ax, r == 0, col_titles[c], _gallery_caption(row, key))

figA.suptitle("Waveforms — rows: samples | cols: GT | Baseline | D1 | D2", y=0.995, fontsize=10)
plt.tight_layout()
plt.show()

# ---- Figure B: STFT log-mags, one shared colour scale per row ----
figB, axsB = plt.subplots(len(rows), 4, figsize=(FIG_W, FIG_H_STFT*len(rows)), dpi=FIG_DPI, squeeze=False)
for r, row in enumerate(rows):
    Sarr = [row[k]["stft"].detach().cpu().float().numpy() for k in col_keys]
    vmin, vmax = _collect_vmin_vmax(*Sarr)
    for c, key in enumerate(col_keys):
        ax = axsB[r, c]
        ax.imshow(Sarr[c], origin="lower", aspect="auto", vmin=vmin, vmax=vmax)
        _decorate_gallery_axis(ax, r == 0, col_titles[c], _gallery_caption(row, key))

figB.suptitle("STFT log-mag — rows: samples | cols: GT | Baseline | D1 | D2", y=0.995, fontsize=10)
plt.tight_layout()
plt.show()
