In [1]:
from pathlib import Path
import json, math
import numpy as np

# --- inputs ---
RAF_METADATA_DIR = Path("../NeRAF/data/RAF/EmptyRoom/metadata")
RX_POS_TXT = RAF_METADATA_DIR / "all_rx_pos.txt"
TX_POS_TXT = RAF_METADATA_DIR / "all_tx_pos.txt"
RUN_DIR_JSON = Path("./outputs/EmptyRoom/images-jpeg-1k/nerfacto/2025-11-04_172911/dataparser_transforms.json")

# --- output ---
OUT_JSON = Path("camera_path.json")

# --- camera_path.json header (viewer / render settings) ---
DEFAULT_FOV = 100.0           # header "default_fov", also written as every camera's "fov"
ASPECT = 1.5                  # written as every camera's "aspect"
# NOTE(review): ASPECT (1.5) differs from RENDER_W / RENDER_H (256 / 256 = 1.0); confirm intended.
RENDER_H = 256
RENDER_W = 256
SECONDS = 999
IS_CYCLE = False
SMOOTHNESS = 0
CAMERA_TYPE = "perspective"

GLOBAL_UP = np.array([0.0, 0.0, 1.0], dtype=float)  # world up axis (Z-up)

# --- subsampling: how many cameras to keep ---
LIMIT = None   # maximum number of cameras written; None = all
STRIDE = 1     # keep every STRIDE-th pair (e.g. 5)

# ---------- helpers ----------
def float_no_sci(x, ndigits=12):
    """Round ``x`` to ``ndigits`` decimals through fixed-point text and return a float.

    When ``ndigits > 0``, values that round to zero come back as +0.0 (a
    negative zero is normalized). Non-finite inputs pass through unchanged.

    NOTE: the result is still a float, so ``json.dump`` serializes it with
    ``repr`` and may use scientific notation (e.g. ``1e-05``) despite the name.
    """
    fixed = format(x, f".{ndigits}f")
    if "." not in fixed:
        # ndigits == 0, or "nan"/"inf": nothing to trim.
        return float(fixed)
    trimmed = fixed.rstrip("0").rstrip(".")
    return 0.0 if trimmed == "-0" else float(trimmed)

def matrix_row_major_list(m4):
    """Flatten the top-left 4x4 block of ``m4`` row by row into 16 values via float_no_sci."""
    flat = []
    for row in range(4):
        flat.extend(float_no_sci(m4[row, col]) for col in range(4))
    return flat

def load_dataparser_transform(run_dir_json):
    """Read the dataparser transform JSON and split it into rotation, translation and scale.

    Args:
        run_dir_json: path to ``dataparser_transforms.json`` containing a
            ``"transform"`` (3x4 nested list) and a ``"scale"`` entry.

    Returns:
        ``(R, t, s)``: ``R`` is the (3, 3) float array ``transform[:, :3]``,
        ``t`` the (3,) array ``transform[:, 3]``, and ``s`` a Python float.
    """
    with open(run_dir_json, "r") as fh:
        payload = json.load(fh)
    affine = np.array(payload["transform"], dtype=float)  # (3, 4)
    scale = float(payload["scale"])
    return affine[:, :3], affine[:, 3], scale

def normalize(v, eps=1e-9):
    """Return ``v`` scaled to unit length, or ``v * 0.0`` when its norm is not above ``eps``.

    A NaN norm also takes the ``v * 0.0`` branch (``NaN > eps`` is False), so
    NaN components propagate.
    """
    length = np.linalg.norm(v)
    if not (length > eps):
        return v * 0.0
    return v / length

def build_upright_look_at(cam_pos, target_pos, global_up=np.array([0,0,1.0])):
    """Build a 4x4 camera-to-world matrix at ``cam_pos`` looking at ``target_pos``.

    The matrix follows the OpenGL convention used by nerfstudio camera paths:
    column 0 is camera +X (right), column 1 is camera +Y (up), and column 2 is
    camera +Z, which points *back*, i.e. the camera looks along -Z toward the
    target. Roll is fixed so that image-up lies in the plane spanned by
    ``global_up`` and the viewing direction.

    Args:
        cam_pos: camera position, shape (3,).
        target_pos: point to look at, shape (3,).
        global_up: world up vector (Z-up by default).

    Returns:
        (4, 4) float ndarray with ``[0, 0, 0, 1]`` as the last row.

    Note:
        If ``cam_pos == target_pos`` the viewing direction is undefined and the
        rotation part is degenerate.
    """
    forward = normalize(target_pos - cam_pos)
    # forward x up gives a right-handed (right, up, back) frame.
    right = np.cross(forward, global_up)
    if np.linalg.norm(right) < 1e-6:
        # Looking (anti)parallel to global_up: use another reference axis.
        right = np.cross(forward, np.array([1.0, 0.0, 0.0]))
        if np.linalg.norm(right) < 1e-6:
            right = np.array([0.0, 1.0, 0.0])
    right = normalize(right)
    up = np.cross(right, forward)
    c2w = np.eye(4, dtype=float)
    c2w[:3, 0] = right
    c2w[:3, 1] = up
    c2w[:3, 2] = -forward  # OpenGL: camera looks along -Z
    c2w[:3, 3] = cam_pos
    return c2w

def parse_rx_all(path, skip_invalid=True):
    """Yield RX positions, one per line of ``path``.

    Each line holds comma-separated ``y,z,x`` values; they are reordered into an
    ``[x, y, z]`` float array.

    Args:
        path: text file with one receiver position per line.
        skip_invalid: if True (default), lines with fewer than 3 fields,
            unparsable numbers or non-finite values are dropped. If False,
            ``None`` is yielded in their place so the n-th item always
            corresponds to the n-th line (needed to pair RX/TX files by line).
    """
    with open(path, "r") as f:
        for line in f:
            parts = [p.strip() for p in line.strip().split(",")]
            pos = None
            if len(parts) >= 3:
                try:
                    y, z, x = (float(v) for v in parts[:3])
                except ValueError:
                    pass
                else:
                    if all(math.isfinite(v) for v in (x, y, z)):
                        pos = np.array([x, y, z], dtype=float)
            if pos is not None or not skip_invalid:
                yield pos

def parse_tx_all(path, skip_invalid=True):
    """Yield ``(quat_xyzW, pos_xyz)`` per line of ``path``.

    Each line holds at least 7 comma-separated values: a quaternion stored as
    ``y,z,x,w`` (HypA) followed by a position stored as ``y,z,x``. They are
    reordered to ``(qx, qy, qz, qw)`` (tuple of floats) and ``[x, y, z]``
    (float ndarray).

    Args:
        path: text file with one transmitter pose per line.
        skip_invalid: True (default) drops lines that are short, unparsable or
            non-finite; False yields ``None`` for them to keep line alignment.
    """
    with open(path, "r") as f:
        for line in f:
            parts = [p.strip() for p in line.strip().split(",")]
            item = None
            if len(parts) >= 7:
                try:
                    qy, qz, qx, qw = (float(v) for v in parts[0:4])  # yzxW
                    py, pz, px = (float(v) for v in parts[4:7])      # yzx
                except ValueError:
                    pass
                else:
                    if all(math.isfinite(v) for v in (qx, qy, qz, qw, px, py, pz)):
                        item = ((qx, qy, qz, qw), np.array([px, py, pz], dtype=float))
            if item is not None or not skip_invalid:
                yield item

def world_to_nerf_point(p_world, R, t, s):
    """Map a world-space point into the dataparser frame: ``s * (R @ p + t)``.

    ``R``, ``t``, ``s`` are the rotation, translation and scale returned by
    ``load_dataparser_transform``.
    """
    return s * (R @ p_world + t)

# ---------- main ----------
assert STRIDE >= 1, "STRIDE must be a positive integer"
R_dp, t_dp, s_dp = load_dataparser_transform(RUN_DIR_JSON)

# Parse both files WITHOUT dropping bad lines so RX line i stays paired with
# TX line i: a bad line in either file then drops only that pair instead of
# shifting every later pair by one.
rx_rows = list(parse_rx_all(RX_POS_TXT, skip_invalid=False))
tx_rows = list(parse_tx_all(TX_POS_TXT, skip_invalid=False))
if len(rx_rows) != len(tx_rows):
    print(f"[warn] {RX_POS_TXT} has {len(rx_rows)} lines but {TX_POS_TXT} has {len(tx_rows)}; "
          f"pairing the first {min(len(rx_rows), len(tx_rows))} only.")

camera_entries = []
count_total = 0     # valid (RX, TX) pairs visited
count_kept = 0      # pairs written to the camera path
count_invalid = 0   # line pairs dropped because either side failed to parse

for rx_p, tx_tuple in zip(rx_rows, tx_rows):
    if rx_p is None or tx_tuple is None:
        count_invalid += 1
        continue
    idx = count_total  # stride is applied over valid pairs
    count_total += 1
    if idx % STRIDE != 0:
        continue
    _, tx_p = tx_tuple  # quat parsed but unused for upright mode
    # world -> NeRF
    rx_n = world_to_nerf_point(rx_p, R_dp, t_dp, s_dp)
    tx_n = world_to_nerf_point(tx_p, R_dp, t_dp, s_dp)
    # upright look-at: camera at the receiver, looking at the transmitter
    c2w = build_upright_look_at(rx_n, tx_n, GLOBAL_UP)
    camera_entries.append({
        "camera_to_world": matrix_row_major_list(c2w),
        "fov": float_no_sci(DEFAULT_FOV),
        "aspect": float_no_sci(ASPECT),
    })
    count_kept += 1
    if LIMIT is not None and count_kept >= LIMIT:
        break

camera_path = {
    "default_fov": float_no_sci(DEFAULT_FOV),
    "default_transition_sec": 2,
    "camera_type": CAMERA_TYPE,
    "render_height": RENDER_H,
    "render_width": RENDER_W,
    "seconds": SECONDS,
    "is_cycle": IS_CYCLE,
    "smoothness_value": SMOOTHNESS,
    "camera_path": camera_entries,
}

with open(OUT_JSON, "w") as f:
    json.dump(camera_path, f, ensure_ascii=False, indent=2)

print(f"Wrote {OUT_JSON.resolve()} with {count_kept} / {count_total} paired cameras "
      f"(stride={STRIDE}, limit={LIMIT}, invalid_pairs_skipped={count_invalid}).")

Wrote /media/scratch/projects/labuser/msc_user/MoNezami/ReverbRAG/camera_path.json with 47484 / 47484 paired cameras (stride=1, limit=None).


In [23]:
import os, json, math, numpy as np, torch, torchaudio
from tqdm import tqdm

SCENE_ROOT = "../NeRAF/data/RAF/EmptyRoom"
DATA_DIR = os.path.join(SCENE_ROOT, "data")
FEATS_DIR = os.path.join(SCENE_ROOT, "feats")
os.makedirs(FEATS_DIR, exist_ok=True)

# STFT params (RAF)
sr = 48000
n_fft = 1024
win_length = 512
hop_length = 256
T = 60                  # frames kept per STFT
F = n_fft // 2 + 1      # one-sided frequency bins
DT_STFT = np.float16    # on-disk dtype of the STFT shards
MAX_SHARD_MB = 1024     # size budget of one shard file

# power=None keeps the complex STFT; the magnitude is taken by _logmag.
stft_tf = torchaudio.transforms.Spectrogram(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    power=None,
    center=True,
    pad_mode="reflect",
)
def _logmag(x): return torch.log(x.abs() + 1e-3)

def _collect_sids(meta_json):
    with open(meta_json, "r") as f:
        splits = json.load(f)
    sids = []
    for v in splits.values():
        block = v[0] if (isinstance(v, list) and v and isinstance(v[0], list)) else v
        for sid in block:
            sids.append(f"{int(sid):06d}" if str(sid).isdigit() else str(sid))
    return sorted(set(sids))

def _load_wav(sid):
    """Load ``DATA_DIR/<sid>/rir.wav`` at ``sr`` Hz, cut to the first 0.32 s.

    Returns the channels-first tensor from torchaudio.load, i.e. [C, S] with
    S <= int(0.32 * sr).
    """
    wav, file_rate = torchaudio.load(os.path.join(DATA_DIR, sid, "rir.wav"))
    if file_rate != sr:
        wav = torchaudio.functional.resample(wav, file_rate, sr)
    max_len = int(0.32 * sr)
    return wav[:, :max_len]

def _stft60(wav):
    spec = stft_tf(wav)  # [1,F,T_full]
    if spec.shape[-1] > T:
        spec = spec[:, :, :T]
    elif spec.shape[-1] < T:
        minval = float(spec.abs().min())
        spec = torch.nn.functional.pad(spec, (0, T - spec.shape[-1]), value=minval)
    return _logmag(spec).squeeze(0)  # [F,T]

# Discover SIDs and (maybe) an existing index.json
SPLIT_JSON = os.path.join(SCENE_ROOT, "metadata", "data-split.json")
sids = _collect_sids(SPLIT_JSON)
idx_path = os.path.join(FEATS_DIR, "index.json")
index_meta = {"shards": [], "sid_to_ptr": {}}
if os.path.exists(idx_path):
    with open(idx_path, "r") as f:
        index_meta = json.load(f)
# An older or hand-edited index.json may lack either key; the build loop
# indexes index_meta["sid_to_ptr"] directly, so make sure both exist.
index_meta.setdefault("shards", [])
index_meta.setdefault("sid_to_ptr", {})

# Shard sizing: as many [F, T] STFTs (stored as DT_STFT) per shard as fit in MAX_SHARD_MB.
bytes_per_stft = F*T*np.dtype(DT_STFT).itemsize
items_per_shard = max(1, (MAX_SHARD_MB*1024*1024)//bytes_per_stft)
N = len(sids)
num_shards = math.ceil(N/items_per_shard)
print(f"[STFT] Items/shard≈{items_per_shard} → #shards={num_shards}")

def shard_paths(k):
    """Path of the STFT shard file for shard ``k`` (three-digit, zero-padded id).

    NOTE: despite the ``.npy`` suffix, the shard is written with np.memmap, so
    the file holds raw array bytes with no .npy header. Read it with np.memmap
    using the count/F/T/dtype stored in index.json.
    """
    return os.path.join(FEATS_DIR, "shard_{:03d}_stft.npy".format(k))

# Build shards. A SID is skipped only when its shard file already existed AND
# index.json maps it to exactly this (shard, row). Otherwise the row is
# (re)computed, so a deleted shard file or a changed shard layout can no longer
# leave rows of a freshly created shard silently zero-filled.
for k in range(num_shards):
    start = k*items_per_shard
    end   = min(N, (k+1)*items_per_shard)
    n_k   = end - start
    st_p  = shard_paths(k)
    shard_existed = os.path.exists(st_p)

    if shard_existed:
        print(f"[STFT] shard {k} exists, skipping write.")
        st_mm = np.memmap(st_p, dtype=DT_STFT, mode="r+", shape=(n_k, F, T))
    else:
        print(f"[STFT] writing shard {k}: {n_k} items → {st_p}")
        st_mm = np.memmap(st_p, dtype=DT_STFT, mode="w+", shape=(n_k, F, T))

    for i, sid in tqdm(list(enumerate(sids[start:end])), total=n_k, desc=f"[STFT] {k+1}/{num_shards}"):
        if shard_existed and index_meta["sid_to_ptr"].get(sid) == [k, i]:
            continue  # already written to this exact row by a previous run
        x = _stft60(_load_wav(sid)).cpu().numpy().astype(DT_STFT)
        st_mm[i] = x
        index_meta["sid_to_ptr"][sid] = [k, i]
    st_mm.flush()
    del st_mm

# Write/merge index: one entry per shard with its stft path, row count, grid and dtype.
existing = {sh["id"]: sh for sh in index_meta.get("shards", [])}
for k in range(num_shards):
    entry = existing.setdefault(k, {"id": k})
    entry["stft"] = shard_paths(k)
    # Row count of the shard memmap (the EDC cell sizes its memmaps from it);
    # recomputed every run so a stale stored count is never reused.
    entry["count"] = min(N, (k + 1) * items_per_shard) - k * items_per_shard
    entry["F"] = F; entry["T"] = T
    entry.setdefault("dtypes", {})["stft"] = str(DT_STFT)
index_meta["shards"] = [existing[k] for k in sorted(existing.keys())]

with open(idx_path, "w") as f:
    json.dump(index_meta, f)
print("[STFT] Done. Index saved to", idx_path)


[STFT] Items/shard≈17442 → #shards=3
[STFT] writing shard 0: 17442 items → ../NeRAF/data/RAF/EmptyRoom/feats/shard_000_stft.npy


[STFT] 1/3: 100%|██████████| 17442/17442 [01:08<00:00, 255.23it/s]


[STFT] writing shard 1: 17442 items → ../NeRAF/data/RAF/EmptyRoom/feats/shard_001_stft.npy


[STFT] 2/3: 100%|██████████| 17442/17442 [01:08<00:00, 253.44it/s]


[STFT] writing shard 2: 12600 items → ../NeRAF/data/RAF/EmptyRoom/feats/shard_002_stft.npy


[STFT] 3/3: 100%|██████████| 12600/12600 [00:48<00:00, 260.68it/s]


[STFT] Done. Index saved to ../NeRAF/data/RAF/EmptyRoom/feats/index.json


In [24]:
import os, json, numpy as np, torch, torchaudio
from tqdm import tqdm

SCENE_ROOT = "../NeRAF/data/RAF/EmptyRoom"
DATA_DIR, FEATS_DIR = (os.path.join(SCENE_ROOT, sub) for sub in ("data", "feats"))
IDX_PATH = os.path.join(FEATS_DIR, "index.json")

# The STFT cell writes index.json; this cell only adds EDC shards next to it.
assert os.path.exists(IDX_PATH), "Missing feats/index.json — build STFT shards first."

# EDC settings
sr = 48000
T = 60                # EDC samples kept per curve
DT_EDC = np.float32   # on-disk dtype of the EDC shards

def _load_wav(sid):
    """Load ``DATA_DIR/<sid>/rir.wav`` at ``sr`` Hz as ``wav.squeeze(0)``, cut to 0.32 s.

    NOTE: this redefines the earlier ``_load_wav``, which returned [C, S].
    This version squeezes the channel axis, so it returns [S] for mono files.
    """
    wav, file_rate = torchaudio.load(os.path.join(DATA_DIR, sid, "rir.wav"))
    if file_rate != sr:
        wav = torchaudio.functional.resample(wav, file_rate, sr)
    mono = wav.squeeze(0)
    return mono[: int(0.32 * sr)]  # [S]

def _edc_db_60(w1d):
    e = (w1d.float()**2)
    edc = torch.flip(torch.cumsum(torch.flip(e, [0]), 0), [0])
    edc = edc / (edc[0] + 1e-12)
    edc_db = 10*torch.log10(edc + 1e-12)
    idx = torch.linspace(0, edc_db.numel()-1, steps=T).long()
    return edc_db[idx]  # [T]

# Load the index written by the STFT cell: "sid_to_ptr" maps SID -> [shard, row]
# and "shards" holds one record per STFT shard (id, count, ...).
with open(IDX_PATH, "r") as f:
    idx = json.load(f)

sid_to_ptr = {k: tuple(v) for k, v in idx["sid_to_ptr"].items()}
# Gather SIDs per shard using the EXISTING mapping
shard_to_sidrows = {}
for sid, (k, row) in sid_to_ptr.items():
    shard_to_sidrows.setdefault(k, []).append((sid, row))

# Build/overwrite EDC shard files using the STFT shard counts.
# mode="w+" recreates each EDC file, so rows not referenced by sid_to_ptr stay zero.
for sh in idx["shards"]:
    k = sh["id"]
    count = int(sh["count"])
    # NOTE: raw memmap bytes (no .npy header) despite the suffix; read with np.memmap((count, T), DT_EDC).
    edc_path = os.path.join(FEATS_DIR, f"shard_{k:03d}_edc.npy")

    # Create memmap with shape matching the STFT shard
    ed_mm = np.memmap(edc_path, dtype=DT_EDC, mode="w+", shape=(count, T))
    rows = shard_to_sidrows.get(k, [])
    print(f"[EDC] shard {k}: writing {len(rows)} rows into {count}-row memmap -> {edc_path}")

    for sid, row in tqdm(rows, total=len(rows), desc=f"[EDC] {k:03d}"):
        # guard: row must be < count (if not, your index.json is already inconsistent with STFT shards)
        if not (0 <= row < count):
            raise RuntimeError(f"Index mismatch: SID {sid} maps to row {row} but shard {k} has count {count}")
        w = _load_wav(sid)
        ed_mm[row] = _edc_db_60(w).cpu().numpy().astype(DT_EDC)

    # Dropping the last reference to the memmap flushes it to disk.
    del ed_mm

    # annotate shard record with EDC path & dtype
    sh["edc"] = edc_path
    sh.setdefault("dtypes", {})["edc"] = str(DT_EDC)

# Persist the annotated shard records (sid_to_ptr is written back unchanged).
with open(IDX_PATH, "w") as f:
    json.dump(idx, f)

print("[EDC] Repair complete. index.json updated.")


[EDC] shard 0: writing 17442 rows into 17442-row memmap -> ../NeRAF/data/RAF/EmptyRoom/feats/shard_000_edc.npy


[EDC] 000: 100%|██████████| 17442/17442 [00:33<00:00, 523.85it/s]


[EDC] shard 1: writing 17442 rows into 17442-row memmap -> ../NeRAF/data/RAF/EmptyRoom/feats/shard_001_edc.npy


[EDC] 001: 100%|██████████| 17442/17442 [00:35<00:00, 491.62it/s]


[EDC] shard 2: writing 12600 rows into 12600-row memmap -> ../NeRAF/data/RAF/EmptyRoom/feats/shard_002_edc.npy


[EDC] 002: 100%|██████████| 12600/12600 [00:24<00:00, 510.68it/s]


[EDC] Repair complete. index.json updated.


In [23]:
# CELL 1 — Global cache (fast):
#   • EDC via precomputed curves + broadcast distances (L1/L2, globally consistent)
#   • SPL via GPU hub (compute_audio_distance)
#   • Decays (T60/EDT/C50/DR) via one-shot scalars + broadcasting
#
# Saves: ./dist_cache/{METRIC}.npz with RAW & global-Z matrices aligned across metrics

import os, sys, glob, json, math
import numpy as np
import torch
from tqdm.auto import tqdm

# ---------------- KNOBS ----------------
# Data
SCENE_ROOT      = "../NeRAF/data/RAF/FurnishedRoom"                      # adjust
RENDERS_GLOB    = "../NeRAF/eval_results/furnishedroom/renders/eval_*.npy"
K_QUERIES       = 20             # the first K renders (sorted by name) become the queries
REF_SPLIT_TRY   = "reference"    # falls back to "train" if this split cannot be loaded
DEVICE          = "cuda" if torch.cuda.is_available() else "cpu"

# Metrics to compute and cache (one .npz each)
METRICS = ["EDC", "SPL", "T60", "EDT", "C50", "DR"]

# EDC options (defaults mimic the hub normalization)
EVAL_FS         = 48000
EDC_BINS        = 60        # fixed number of EDC samples for every signal
EDC_ANCHOR_ZERO = True      # subtract the first bin (curve starts at 0 dB)
EDC_ZSCORE      = True      # per-curve z-score after anchoring
EDC_DIST_MODE   = "l1"      # "l1" or "l2"
# Distance scaling so magnitudes are comparable across modes
EDC_L2_SCALE    = "sqrtN"   # "sqrtN": divide L2 by sqrt(EDC_BINS); None disables
EDC_L1_SCALE    = "mean"    # "mean": divide L1 by EDC_BINS; None disables

# SPL only: reference STFTs moved to DEVICE per batch
REF_BATCH       = 2048

# Output directory of the {METRIC}.npz caches
DIST_CACHE_DIR  = "./dist_cache"

# ---------------------------------------
# Project imports
for _p in (".", "/mnt/data"):
    if _p not in sys.path:  # re-running the cell should not keep growing sys.path
        sys.path.append(_p)
from evaluator import compute_audio_distance, compute_edc_db
try:
    from evaluator import compute_t60, evaluate_edt, evaluate_clarity
    HAVE_DECAY_HELPERS = True
except Exception:
    # Optional helpers: without them the T60/EDT/C50/DR scalars cannot be computed.
    HAVE_DECAY_HELPERS = False

try:
    from data import RAFDataset
except Exception as e:
    # Chain the original error so the real import failure stays visible.
    raise RuntimeError("Cannot import RAFDataset; fix sys.path to your repo.") from e

# ---------- Utils ----------
def _ensure_dir(p): os.makedirs(p, exist_ok=True)

def _global_z(M: np.ndarray):
    m = np.isfinite(M)
    if not m.any():
        mu, sd = 0.0, 1.0
        Z = np.zeros_like(M, dtype=np.float32)
    else:
        mu = float(np.mean(M[m]))
        sd = float(np.std(M[m]) + 1e-6)
        Z = (M - mu) / sd
        Z[~np.isfinite(Z)] = 0.0
    return Z.astype(np.float32), float(mu), float(sd)

@torch.no_grad()
def _pad_batch_1d(wavs_list, device, dtype=torch.float32):
    if len(wavs_list) == 0:
        return torch.zeros(0, 1, dtype=dtype, device=device)
    Tmax = max(int(w.numel()) for w in wavs_list)
    out = torch.zeros(len(wavs_list), Tmax, dtype=dtype, device=device)
    for i, w in enumerate(wavs_list):
        t = int(w.numel())
        out[i, :t] = w.to(device=device, dtype=dtype)
    return out

def _dr_db(x_1d_np: np.ndarray) -> float:
    rms = np.sqrt(np.mean(x_1d_np**2) + 1e-12)
    peak = np.max(np.abs(x_1d_np)) + 1e-12
    return float(20.0 * np.log10(peak / max(rms, 1e-12)))

# ---------- Scan renders (queries) ----------
render_files = sorted(glob.glob(RENDERS_GLOB))
assert len(render_files) > 0, f"No NPY renders found at: {RENDERS_GLOB}"
if K_QUERIES is not None:
    render_files = render_files[:int(K_QUERIES)]
query_ids = [os.path.basename(fp) for fp in render_files]
Q = len(query_ids)
print(f"[info] Queries (K) = {Q}")

# ---------- Reference bank ----------
# Any failure to build the REF_SPLIT_TRY split falls back to "train".
try:
    ds_ref = RAFDataset(scene_root=SCENE_ROOT, split=REF_SPLIT_TRY,
                        model_kind="neraf", sample_rate=EVAL_FS, dataset_mode="full")
    ref_split_name = REF_SPLIT_TRY
except Exception as e:
    print(f"[warn] '{REF_SPLIT_TRY}' split missing ({e}); using 'train' as reference bank.")
    ds_ref = RAFDataset(scene_root=SCENE_ROOT, split="train",
                        model_kind="neraf", sample_rate=EVAL_FS, dataset_mode="full")
    ref_split_name = "train"

ref_ids    = list(ds_ref.ids)
id2idx_ref = ds_ref.id2idx
R = len(ref_ids)
print(f"[info] Reference bank: split='{ref_split_name}' | #items={R}")
# The STFT grid below is read from the first reference, so the bank must be non-empty.
assert R > 0, f"Reference bank '{ref_split_name}' is empty."

# ---------- Preload references (CPU) ----------
REF_STFTS, REF_WAVS = [], []
for sid in tqdm(ref_ids, desc="Load refs", unit="ref"):
    it = ds_ref[id2idx_ref[sid]]
    REF_STFTS.append(it["stft"].squeeze(0).contiguous())  # [F,60]
    REF_WAVS.append(it["wav"].squeeze(0).contiguous())    # [T]
F, T = REF_STFTS[0].shape
print(f"[info] Reference STFT grid: F={F}, T={T}")

# ---------- Load queries (CPU) ----------
# NOTE(review): pack["data"] is labelled GT, so retrieval distances are driven
# by ground-truth STFTs (and by pack["waveform"] for EDC/decays) rather than by
# pack["pred_stft"]; confirm this oracle setting is intended.
# np.load(..., allow_pickle=True) can execute pickled code: only load trusted renders.
Q_STFTS, Q_WAVS = [], []
for fp in tqdm(render_files, desc="Load queries", unit="file"):
    pack = np.load(fp, allow_pickle=True).item()
    Q_STFTS.append(torch.from_numpy(pack["data"]).float().squeeze(0).contiguous())      # [F,60] GT
    Q_WAVS.append(torch.from_numpy(pack["waveform"]).float().squeeze(0).contiguous())   # [T]

# ---------- EDC CURVES: compute ONCE for all signals ----------
@torch.no_grad()
def _edc_curve(wav_1d: torch.Tensor, bins: int = EDC_BINS) -> torch.Tensor:
    """EDC (dB) of one waveform with ``bins`` samples, via evaluator.compute_edc_db.

    The waveform is cast to float32 and stays on its current device; autograd
    is disabled.
    """
    wav_f32 = wav_1d.float()
    return compute_edc_db(wav_f32, T_target=bins)

def _normalize_edc_curves(mat: np.ndarray) -> np.ndarray:
    # mat: (N, BINS), in dB
    X = mat.copy()
    if EDC_ANCHOR_ZERO:
        X = X - X[:, :1]   # anchor at 0 dB at start
    if EDC_ZSCORE:
        mu = X.mean(axis=1, keepdims=True)
        sd = X.std(axis=1, keepdims=True) + 1e-6
        X = (X - mu) / sd
    return X.astype(np.float32)

def _edc_scalers():
    # Return scaling factor for distances to make magnitudes comparable
    if EDC_DIST_MODE.lower() == "l2":
        if EDC_L2_SCALE == "sqrtN":
            return 1.0 / math.sqrt(EDC_BINS)
        return 1.0
    elif EDC_DIST_MODE.lower() == "l1":
        if EDC_L1_SCALE == "mean":
            return 1.0 / float(EDC_BINS)
        return 1.0
    else:
        raise ValueError("EDC_DIST_MODE must be 'l1' or 'l2'")

# One EDC curve per query and per reference (computed once, reused for all pairs).
print("\n[edc] Computing EDC curves for queries & refs (once each) ...")
Q_EDC = np.zeros((Q, EDC_BINS), dtype=np.float32)
R_EDC = np.zeros((R, EDC_BINS), dtype=np.float32)
for i, w in enumerate(tqdm(Q_WAVS, desc="EDC(Q)", unit="q")):
    edc = _edc_curve(w.to(DEVICE)).detach().cpu().numpy()
    Q_EDC[i] = edc
for j, w in enumerate(tqdm(REF_WAVS, desc="EDC(R)", unit="ref")):
    edc = _edc_curve(w.to(DEVICE)).detach().cpu().numpy()
    R_EDC[j] = edc

# Per-curve normalization (anchor / z-score per the EDC_* knobs): (Q,B) and (R,B)
Q_EDC, R_EDC = _normalize_edc_curves(Q_EDC), _normalize_edc_curves(R_EDC)

# ---------- DECAY SCALARS: compute ONCE ----------
def _decay_scalars_for_wavs(wavs_list, tag):
    if not HAVE_DECAY_HELPERS:
        raise RuntimeError("Decay helpers missing in evaluator.py; cannot compute T60/EDT/C50/DR.")
    out = np.zeros((len(wavs_list), 4), dtype=np.float32)
    for i, w in enumerate(tqdm(wavs_list, desc=f"Decays({tag})", unit="wav")):
        w_np = w.detach().cpu().numpy()[None, :]
        t60, _ = compute_t60(w_np, w_np, fs=EVAL_FS, advanced=True)
        edt, _ = evaluate_edt(w_np, w_np, fs=EVAL_FS)
        c50, _ = evaluate_clarity(w_np, w_np, fs=EVAL_FS)
        t60 = float(np.atleast_1d(t60)[0]); edt = float(np.atleast_1d(edt)[0]); c50 = float(np.atleast_1d(c50)[0])
        dr = _dr_db(w.detach().cpu().numpy())
        out[i, :] = [t60, c50, edt, dr]
    return out  # (N,4)

# Decay scalars feed only the T60/EDT/C50/DR matrices. Skip the per-signal pass,
# which raises when the evaluator helpers are missing, if none of them is requested.
if any(m in METRICS for m in ("T60", "EDT", "C50", "DR")):
    DEC_Q = _decay_scalars_for_wavs(Q_WAVS, tag="Q")   # (Q,4)
    DEC_R = _decay_scalars_for_wavs(REF_WAVS, tag="R") # (R,4)

# ---------- helper: save a matrix with global-Z ----------
def _save_matrix(metric_name: str, M: np.ndarray, meta_base: dict):
    Z, mu, sd = _global_z(M)
    os.makedirs(DIST_CACHE_DIR, exist_ok=True)
    path = os.path.join(DIST_CACHE_DIR, f"{metric_name}.npz")
    meta = {
        **meta_base, "metric": metric_name, "global_mu": mu, "global_sd": sd,
        "edc": {
            "bins": EDC_BINS, "anchor0": EDC_ANCHOR_ZERO, "zscore": EDC_ZSCORE,
            "dist_mode": EDC_DIST_MODE, "l2_scale": EDC_L2_SCALE, "l1_scale": EDC_L1_SCALE
        } if metric_name == "EDC" else None
    }
    np.savez_compressed(
        path,
        Z=Z.astype(np.float32),
        RAW=M.astype(np.float32),
        query_ids=np.array(query_ids, dtype=object),
        ref_ids=np.array(ref_ids, dtype=object),
        meta=json.dumps(meta)
    )
    print(f"[saved] {metric_name}: {path} | mu={mu:.6f} sd={sd:.6f}")

# ---------- META ----------
# Provenance shared by every cache file; _save_matrix adds metric-specific keys.
meta = {
    "scene_root": SCENE_ROOT,
    "renders_glob": RENDERS_GLOB,
    "fs": EVAL_FS,
    "ref_split": ref_split_name,
    "device": DEVICE,
    "notes": ("Rows=first K npys, Cols=all refs. RAW + global-Z. "
              "EDC via per-signal curves + broadcast distances; "
              "SPL via compute_audio_distance; Decays via broadcasted |ref-query|.")
}

# ---------- EDC matrix (K,R) via broadcasting over precomputed curves ----------
if "EDC" in METRICS:
    print(f"\n[metric] EDC({EDC_DIST_MODE.upper()}): broadcasting (K={Q}, R={R}, B={EDC_BINS}) ...")
    # 1/sqrt(B) for L2 "sqrtN", 1/B for L1 "mean", else 1.0 (see _edc_scalers).
    scale = _edc_scalers()

    # Q_EDC: (Q,B), R_EDC: (R,B) —> pairwise distances (Q,R)
    # Use memory-efficient trick: process refs in chunks if RAM is tight
    M = np.zeros((Q, R), dtype=np.float32)
    # NOTE: CH bounds refs per chunk only; the temporary `diff` still holds
    # Q * CH * EDC_BINS floats, so peak memory grows with Q.
    CH = max(1, 131072 // EDC_BINS)  # rough chunking heuristic to bound RAM
    for start in range(0, R, CH):
        end = min(start + CH, R)
        Rblk = R_EDC[start:end, :]                 # (b,B)
        # Expand and compute |Q[:,None,:] - Rblk[None,:,:]|
        diff = Q_EDC[:, None, :] - Rblk[None, :, :]   # (Q,b,B)
        # The *_SCALE checks repeat _edc_scalers' own conditions, so `scale` is
        # applied under exactly the settings where it is not 1.0 by construction.
        if EDC_DIST_MODE.lower() == "l2":
            d = np.linalg.norm(diff, axis=2)          # (Q,b)
            if EDC_L2_SCALE == "sqrtN":
                d = d * scale
        elif EDC_DIST_MODE.lower() == "l1":
            d = np.sum(np.abs(diff), axis=2)          # (Q,b)
            if EDC_L1_SCALE == "mean":
                d = d * scale
        else:
            raise ValueError("EDC_DIST_MODE must be 'l1' or 'l2'")
        M[:, start:end] = d.astype(np.float32)

    _save_matrix("EDC", M, meta)

# ---------- SPL matrix (K,R) via GPU hub ----------
# For each query, stack [query, ref batch] and keep row 0, columns 1: of the
# result (query vs. each ref). NOTE(review): this assumes compute_audio_distance
# returns a square pairwise matrix over the stacked inputs; confirm in evaluator.py.
if "SPL" in METRICS:
    print(f"\n[metric] SPL: computing (K={Q}, R={R}) via hub ...")
    M = np.zeros((Q, R), dtype=np.float32)
    for qi in tqdm(range(Q), desc="SPL per-query", unit="q"):
        q_stft = Q_STFTS[qi].to(DEVICE, non_blocking=True)
        row_parts = []
        for start in range(0, R, REF_BATCH):
            end = min(start + REF_BATCH, R)
            ref_stfts_b = torch.stack(REF_STFTS[start:end], dim=0).to(DEVICE, non_blocking=True) # [B,F,T]
            stft_blk = torch.cat([q_stft.unsqueeze(0), ref_stfts_b], dim=0)                       # [1+B,F,T]
            with torch.no_grad():
                D = compute_audio_distance(stft_blk, wavs=None, edc_curves=None,
                                           decay_feats=None, metric="SPL", fs=EVAL_FS)
            row_parts.append(D[0, 1:].detach().cpu().numpy())
            # Release the batch tensors and cached device memory before the next batch.
            del ref_stfts_b, stft_blk, D
            torch.cuda.empty_cache()
        M[qi, :] = np.concatenate(row_parts, axis=0)
    _save_matrix("SPL", M, meta)

# ---------- Decay matrices (K,R) via broadcasted |ref - query| ----------
# Column layout of DEC_Q / DEC_R as filled by _decay_scalars_for_wavs: [T60, C50, EDT, DR].
col = {"T60":0, "C50":1, "EDT":2, "DR":3}
for m in ["T60","EDT","C50","DR"]:
    if m not in METRICS: continue
    print(f"\n[metric] {m}: broadcasting (K={Q}, R={R}) ...")
    # NOTE: `qi` is rebound here from the SPL loop's int index to a (Q,1) array.
    qi = DEC_Q[:, col[m]].reshape(Q, 1)     # (Q,1)
    rj = DEC_R[:, col[m]].reshape(1, R)     # (1,R)
    M = np.abs(rj - qi)                      # (Q,R)
    _save_matrix(m, M, meta)

# NOTE: the message hard-codes "./dist_cache/"; files actually go to DIST_CACHE_DIR.
print("\n[done] All requested metrics cached globally in ./dist_cache/")


[info] Queries (K) = 20
[info] Reference bank: split='reference' | #items=20000


Load refs:   0%|          | 0/20000 [00:00<?, ?ref/s]

[info] Reference STFT grid: F=513, T=60


Load queries:   0%|          | 0/20 [00:00<?, ?file/s]


[edc] Computing EDC curves for queries & refs (once each) ...


EDC(Q):   0%|          | 0/20 [00:00<?, ?q/s]

EDC(R):   0%|          | 0/20000 [00:00<?, ?ref/s]

Decays(Q):   0%|          | 0/20 [00:00<?, ?wav/s]

Decays(R):   0%|          | 0/20000 [00:00<?, ?wav/s]


[metric] EDC(L1): broadcasting (K=20, R=20000, B=60) ...
[saved] EDC: ./dist_cache/EDC.npz | mu=0.141915 sd=0.065945

[metric] SPL: computing (K=20, R=20000) via hub ...


SPL per-query:   0%|          | 0/20 [00:00<?, ?q/s]

[saved] SPL: ./dist_cache/SPL.npz | mu=0.587531 sd=0.340065

[metric] T60: broadcasting (K=20, R=20000) ...
[saved] T60: ./dist_cache/T60.npz | mu=0.169088 sd=0.135457

[metric] EDT: broadcasting (K=20, R=20000) ...
[saved] EDT: ./dist_cache/EDT.npz | mu=0.256786 sd=0.219765

[metric] C50: broadcasting (K=20, R=20000) ...
[saved] C50: ./dist_cache/C50.npz | mu=8.121808 sd=6.199946

[metric] DR: broadcasting (K=20, R=20000) ...
[saved] DR: ./dist_cache/DR.npz | mu=5.458077 sd=4.075886

[done] All requested metrics cached globally in ./dist_cache/


In [None]:
# CELL 2 — Mix cached global-z metrics (first K queries), lazy-evaluate Top-k only
# --------------------------------------------------------------------------------
# - Loads ./dist_cache/*.npz built by Cell 1
# - Restricts to first K queries (same K used in Cell 1)
# - Combines any subset of metrics with normalized weights
# - Retrieves Top-k per query
# - Optional: fast evaluation summary that loads ONLY the Top-k refs (lazy), not all refs

import os, sys, glob, json, math
import numpy as np
import torch
from tqdm.auto import tqdm

# ---------------- KNOBS ----------------
CACHE_DIR = "./dist_cache"

# Metric weights. Only metrics with a positive weight AND a cache file on disk
# are used, and their weights are renormalized to sum to 1.
# Available metrics: EDC, SPL, T60, EDT, C50, DR.
WEIGHTS = {"EDC": 0.6, "SPL": 0.4}

TOPK            = 3
DO_EVAL_SUMMARY = True   # False: only print the Top-k ids, skip the evaluator

# Use only the first K queries (should match the K of the cache builder); None => full cache size.
K_LIMIT         = 20

# ---------------------------------------
# Repo imports for evaluation only
for _p in (".", "/mnt/data"):
    if _p not in sys.path:  # re-running the cell should not keep growing sys.path
        sys.path.append(_p)
from evaluator import UnifiedEvaluator
try:
    from data import RAFDataset
except Exception as e:
    # Chain the original error so the real import failure stays visible.
    raise RuntimeError("Cannot import RAFDataset; adjust sys.path to your repo.") from e

def _load_metric(name: str):
    path = os.path.join(CACHE_DIR, f"{name}.npz")
    if not os.path.exists(path):
        return None
    pack = np.load(path, allow_pickle=True)
    Z = pack["Z"].astype(np.float32)                     # (Q,R)
    query_ids = list(pack["query_ids"])
    ref_ids   = list(pack["ref_ids"])
    meta = json.loads(str(pack["meta"]))
    return {"Z": Z, "query_ids": query_ids, "ref_ids": ref_ids, "meta": meta}

# ---- Active metrics: positive weight in WEIGHTS and a cache file on disk
candidates = [m for m, w in WEIGHTS.items() if w > 0.0]
present = [m for m in candidates
           if os.path.exists(os.path.join(CACHE_DIR, f"{m}.npz"))]
assert len(present) > 0, f"No active metrics available in {CACHE_DIR} for weights={candidates}"

# The first present metric defines the query/ref ordering and shape for all others.
root = _load_metric(present[0])
Z0 = root["Z"]
query_ids_cache = root["query_ids"]
ref_ids_cache = root["ref_ids"]
meta0 = root["meta"]
Q_full, R = Z0.shape

# ---- Number of query rows used: K_LIMIT capped at the cache size (None = all rows)
K = Q_full if K_LIMIT is None else int(min(K_LIMIT, Q_full))
# Slice helper for matrices and query id list
def _take_first_k_rows(M): return M[:K, :] if M.shape[0] >= K else M
query_ids = query_ids_cache[:K]

# Load the remaining metrics; row/column ordering must be identical across caches.
MATS = {present[0]: _take_first_k_rows(Z0)}
for m in present[1:]:
    dat = _load_metric(m)
    assert dat is not None, f"{m} vanished from cache."
    # Consistency checks (ordering must be identical across metrics)
    assert dat["ref_ids"] == ref_ids_cache, f"ref_ids mismatch between metrics ({present[0]} vs {m})"
    assert dat["query_ids"] == query_ids_cache, f"query_ids mismatch between metrics ({present[0]} vs {m})"
    MATS[m] = _take_first_k_rows(dat["Z"])

# Combine with weights renormalized over the 'present' metrics (they sum to 1)
ws = np.array([WEIGHTS[m] for m in present], dtype=np.float64)
ws = ws / (ws.sum() + 1e-12)

COMB = np.zeros((K, R), dtype=np.float32)
for i, m in enumerate(present):
    COMB += float(ws[i]) * MATS[m]

# ---- Top-k per query (smallest combined score first)
# Clamp to the pool size: argpartition raises when kth is out of range (TOPK >= R).
k_eff = max(0, min(int(TOPK), R))
topk_indices = []
topk_ids     = []
for qi in range(K):
    row = COMB[qi]
    if k_eff > 0:
        idx = np.argpartition(row, k_eff - 1)[:k_eff]   # the k_eff smallest, unordered
        idx = idx[np.argsort(row[idx], kind="stable")]  # order them by score (stable on ties)
    else:
        idx = np.empty(0, dtype=np.intp)
    topk_indices.append(idx.tolist())
    topk_ids.append([ref_ids_cache[j] for j in idx.tolist()])

print(f"Active metrics (global-z): {present} | weights={ws.round(6).tolist()}")
for qi in range(K):
    print(f"[q={query_ids[qi]}] TOP{TOPK}: {topk_ids[qi]}  scores={[float(COMB[qi,j]) for j in topk_indices[qi]]}")

# ---------------- Optional: evaluation summary (Baseline vs Top-k) ----------------
if DO_EVAL_SUMMARY:
    # Pull scene_root and renders_glob from cache meta to avoid mismatches with Cell 1
    SCENE_ROOT   = meta0.get("scene_root", "../NeRAF/data/RAF/FurnishedRoom")
    RENDERS_GLOB = meta0.get("renders_glob", "../NeRAF/eval_results/furnishedroom/renders/eval_*.npy")
    EVAL_FS      = int(meta0.get("fs", 48000))
    DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"

    # Resolve each cached query filename to its path. A query whose file is
    # missing is skipped: substituting files by glob position would score
    # Top-k row qi against the wrong render.
    name2path = {os.path.basename(p): p for p in sorted(glob.glob(RENDERS_GLOB))}
    eval_queries = []  # (row index into topk_indices, render path)
    for qi, nm in enumerate(query_ids):
        if nm in name2path:
            eval_queries.append((qi, name2path[nm]))
        else:
            print(f"[warn] {nm} not found under {RENDERS_GLOB}; skipping this query.")
    render_files = [fp for _, fp in eval_queries]

    # Only load the references that occur in some Top-k list, not all R refs.
    needed_ref_idx = sorted({j for row in topk_indices for j in row})
    print(f"[info] Will load only {len(needed_ref_idx)} referenced STFTs out of R={R}.")

    ds_ref = RAFDataset(scene_root=SCENE_ROOT, split=meta0.get("ref_split","reference"),
                        model_kind="neraf", sample_rate=EVAL_FS, dataset_mode="full")
    id2idx_ref = ds_ref.id2idx

    # STFTs of the needed refs only (CPU), keyed by global column index in the cache
    SMALL_REF_STFTS = {}
    for gidx in tqdm(needed_ref_idx, desc="Load Top-k refs", unit="ref"):
        sid = ref_ids_cache[gidx]
        it = ds_ref[id2idx_ref[sid]]
        SMALL_REF_STFTS[gidx] = it["stft"].squeeze(0).contiguous()

    evaluator = UnifiedEvaluator(fs=EVAL_FS, edc_bins=60, edc_dist="l2")

    agg_keys = ["stft","edc","t60","edt","c50"]
    RANKS = tuple(range(1, TOPK + 1))  # one Top-r / ΔTop-r row per retrieved rank

    def _zero_agg():
        """[sum, count] accumulators for the baseline, each Top-r and each Δ(Top-r − baseline)."""
        keys = [f"base_{x}" for x in agg_keys]
        for r in RANKS:
            keys += [f"t{r}_{x}" for x in agg_keys] + [f"d{r}_{x}" for x in agg_keys]
        return {k: [0.0, 0] for k in keys}

    def _add(agg, k, v):
        """Accumulate ``v`` into ``agg[k]``; non-finite values are ignored."""
        vv = float(v)
        if math.isfinite(vv):
            agg[k][0] += vv; agg[k][1] += 1

    def _mean(pair):
        """Mean of an accumulator, NaN when nothing finite was added."""
        s, n = pair
        return s / n if n > 0 else float("nan")

    @torch.no_grad()
    def _eval_pair(stft_a_2d, stft_b_2d):
        """Run the evaluator on two [F, T] STFTs reshaped to [1, 1, F, T] on DEVICE."""
        return evaluator.evaluate(
            stft_a_2d.view(1,1,*stft_a_2d.shape).to(DEVICE),
            stft_b_2d.view(1,1,*stft_b_2d.shape).to(DEVICE)
        )

    # Score the prediction (baseline) and each Top-k ref against the GT STFT
    agg = _zero_agg()
    for qi, fp in tqdm(eval_queries, desc="Eval summary", unit="q"):
        # NOTE: allow_pickle=True can execute pickled code; only load trusted renders.
        pack = np.load(fp, allow_pickle=True).item()
        stft_gt  = torch.from_numpy(pack["data"]).float().squeeze(0)
        stft_pr  = torch.from_numpy(pack["pred_stft"]).float().squeeze(0)

        base = _eval_pair(stft_pr, stft_gt)
        for k in agg_keys: _add(agg, f"base_{k}", base[k])

        # Top-k refs of this query row (indices into ref_ids_cache, best first)
        for rank, gidx in enumerate(topk_indices[qi], start=1):
            ref_st = SMALL_REF_STFTS[gidx].to(DEVICE)
            top = _eval_pair(ref_st, stft_gt)
            for k in agg_keys:
                _add(agg, f"t{rank}_{k}", top[k])
                _add(agg, f"d{rank}_{k}", float(top[k]-base[k]))

    # Print compact summary
    print("\n===== SUMMARY (means over processed files) =====")
    print(f"Files processed: {len(eval_queries)} | Reference pool: {R} | Active metrics: {present}")
    print(f"{'':12s}{'stft':>10s}{'edc':>10s}{'t60':>10s}{'edt':>10s}{'c50':>10s}")

    def _line(title, keys):
        """Print one summary row: the title followed by the mean of each accumulator."""
        vals = [_mean(agg[k]) for k in keys]
        print(f"{title:<12s}" + "".join([f"{v:10.6f}" for v in vals]))

    _line("Baseline",  [f"base_{k}" for k in agg_keys])
    for r in RANKS:
        _line(f"Top{r}",       [f"t{r}_{k}" for k in agg_keys])
        _line(f"ΔTop{r}-Base", [f"d{r}_{k}"  for k in agg_keys])
    print()


Active metrics (global-z): ['EDC', 'SPL'] | weights=[0.6, 0.4]
[q=eval_000000.npy] TOP3: ['036096', '036106', '036105']  scores=[-1.6228878498077393, -1.5719029903411865, -1.551689624786377]
[q=eval_000001.npy] TOP3: ['036438', '036441', '036443']  scores=[-1.5854880809783936, -1.525365948677063, -1.5044968128204346]
[q=eval_000002.npy] TOP3: ['029754', '029744', '029737']  scores=[-1.611403465270996, -1.5437387228012085, -1.472914218902588]
[q=eval_000003.npy] TOP3: ['010023', '027124', '019589']  scores=[-1.394456386566162, -1.3839342594146729, -1.3662970066070557]
[q=eval_000004.npy] TOP3: ['033850', '033871', '033873']  scores=[-1.6325452327728271, -1.5399271249771118, -1.5194473266601562]
[info] Will load only 15 referenced STFTs out of R=20000.


Load Top-k refs:   0%|          | 0/15 [00:00<?, ?ref/s]

Eval summary:   0%|          | 0/5 [00:00<?, ?q/s]


===== SUMMARY (means over processed files) =====
Files processed: 5 | Reference pool: 20000 | Active metrics: ['EDC', 'SPL']
                  stft       edc       t60       edt       c50
Baseline      0.170139  0.118459  8.222797  0.034875  0.358303
Top1          0.213854  0.064593  3.754443  0.019775  0.255973
ΔTop1-Base    0.043714 -0.053866 -4.468353 -0.015100 -0.102330
Top2          0.213429  0.077529  3.190934  0.031950  0.619630
ΔTop2-Base    0.043290 -0.040930 -5.031863 -0.002925  0.261327
Top3          0.217224  0.112896  5.471308  0.030350  0.463162
ΔTop3-Base    0.047085 -0.005563 -2.751488 -0.004525  0.104859



In [38]:
# === CELL 1: Build per-metric distance caches (Numba + tqdm + chunking) ===
import os, glob, json, math, numpy as np, torch, numba
from tqdm import tqdm
from data import RAFDataset

# ---------------- KNOBS ----------------
SCENE_ROOT  = "../NeRAF/data/RAF/FurnishedRoom"
RENDS_GLOB  = "../NeRAF/eval_results/furnishedroom/renders/eval_*.npy"
REF_SPLIT   = "reference"       # reference bank (RAFDataset split whose items are the retrieval candidates)
# NOTE(review): the recorded output of this cell prints "Load refs[train]", i.e. the caches currently in
# DIST_DIR were built with a different REF_SPLIT than the one set here -- re-run before trusting them.
K_QUERIES   = 50            # first-K npy files (in sorted filename order); None or <= 0 => use all
EDC_BINS    = 60            # number of points each EDC curve is sampled to
DIST_DIR    = "./dist_cache"
os.makedirs(DIST_DIR, exist_ok=True)

# Chunk sizes for progress + lower peak RAM (tune if needed)
Q_CHUNK = 64     # rows per chunk (queries)
R_CHUNK = None   # NOTE(review): not read anywhere in this cell -- every q-chunk is always compared against all refs, so setting it has no effect

# ---------------- Numba kernels ----------------
@numba.njit(fastmath=True)
def _zscore_rowwise(M):
    """Standardize each row of the 2-D array M independently.

    Every row is shifted by its mean and divided by its population standard
    deviation plus 1e-12 (so a constant row maps to zeros instead of NaN).
    Returns a new array with the same shape and dtype as M.
    """
    n_rows = M.shape[0]
    out = np.empty_like(M)
    for r in range(n_rows):
        row = M[r]
        centre = np.mean(row)
        spread = np.std(row) + 1e-12
        out[r] = (row - centre) / spread
    return out

@numba.njit(fastmath=True, parallel=True)
def _pairwise_l1(A, B):
    """Dense L1 (Manhattan) distance matrix between the rows of A and the rows of B.

    Args:
        A: 2-D array (Q, n_feat) of query vectors.
        B: 2-D array (R, n_feat) of reference vectors.

    Returns:
        float64 array (Q, R) with out[i, j] = sum_k |A[i, k] - B[j, k]|.

    Raises:
        ValueError: if A and B have a different number of columns.

    Query rows run in parallel (prange). The per-pair sum is accumulated in a
    scalar loop. The previous `np.sum(np.abs(A[i] - B[j]))` form allocated a
    fresh temporary array for every (i, j) pair, which is expensive when rows
    are flattened STFTs.
    """
    Q, R = A.shape[0], B.shape[0]
    n_feat = A.shape[1]
    if B.shape[1] != n_feat:
        raise ValueError("_pairwise_l1: A and B must have the same number of columns")
    D = np.empty((Q, R), np.float64)
    for i in numba.prange(Q):
        ai = A[i]
        for j in range(R):
            bj = B[j]
            s = 0.0
            for k in range(n_feat):
                s += abs(ai[k] - bj[k])
            D[i, j] = s
    return D

@numba.njit(fastmath=True, parallel=True)
def _pairwise_l2(A, B):
    """Dense Euclidean (L2) distance matrix between the rows of A and the rows of B.

    Args:
        A: 2-D array (Q, n_feat) of query vectors.
        B: 2-D array (R, n_feat) of reference vectors.

    Returns:
        float64 array (Q, R) with out[i, j] = sqrt(sum_k (A[i, k] - B[j, k])**2).

    Raises:
        ValueError: if A and B have a different number of columns.

    Query rows run in parallel (prange). Squared differences are accumulated in
    a scalar loop instead of materializing `A[i] - B[j]` and `diff * diff`. The
    old form allocated two temporary arrays per pair, which dominated run time
    for flattened STFT rows.
    """
    Q, R = A.shape[0], B.shape[0]
    n_feat = A.shape[1]
    if B.shape[1] != n_feat:
        raise ValueError("_pairwise_l2: A and B must have the same number of columns")
    D = np.empty((Q, R), np.float64)
    for i in numba.prange(Q):
        ai = A[i]
        for j in range(R):
            bj = B[j]
            s = 0.0
            for k in range(n_feat):
                d = ai[k] - bj[k]
                s += d * d
            D[i, j] = math.sqrt(s)
    return D

@numba.njit(fastmath=True)
def _edc_db(wav, bins=60):
    """Energy decay curve of a waveform in dB, sampled at `bins` points.

    Computes the backward cumulative sum of squared samples (Schroeder-style
    integral) and normalizes it by its first value, so the curve starts near
    0 dB. It is then converted to dB. The 1e-12 terms guard against division
    by zero and log10(0), so silent input yields a flat floor instead of
    NaN/-inf.

    Sample positions are np.linspace(0, len - 1, bins) truncated to int64, so
    the first bin is the first sample and the last bin is the final sample.
    Callers in this notebook pass one row of a 2-D waveform batch.
    """
    e = wav * wav                          # per-sample energy
    edc = np.cumsum(e[::-1])[::-1]         # energy remaining from each sample to the end
    edc = edc / (edc[0] + 1e-12)           # normalize by total energy
    edc_db = 10.0 * np.log10(edc + 1e-12)
    idx = np.linspace(0, edc_db.size - 1, bins).astype(np.int64)  # truncating, evenly spaced sample indices
    return edc_db[idx]

@numba.njit(fastmath=True)
def _decay_from_edc(edc_db):
    """Crude decay descriptors from one sampled EDC curve in dB (as produced by _edc_db).

    Returns a 4-tuple of floats (t60, edt, c50, dr):
      t60 -- index of the first bin whose drop from bin 0 reaches 60 dB, or n (the number
             of bins) if that drop is never reached.
      edt -- 6 x the index of the first bin whose drop reaches 10 dB. This extrapolates the
             10 dB decay to 60 dB and is still in bin units, not ms. If 10 dB is never
             reached, the bin index falls back to min(n, int(t60 / 6)) before the x6.
      c50 -- 10*log10 of the summed linear EDC values over the first n // 12 bins, divided by
             the sum over the remaining bins. NOTE(review): this uses the cumulative EDC, not
             raw energy, so it is a proxy rather than standard C50.
      dr  -- 10*log10 of max(linear EDC) over mean(linear EDC).
    """
    # EDT via -10 dB crossing, T60 via -60 dB crossing (fallbacks if not reached)
    n = edc_db.size; s0 = edc_db[0]
    edt_f = -1; t60_f = -1
    for i in range(n):
        drop = s0 - edc_db[i]
        # The EDT test runs before the T60 test on the same bin, so EDT is always set whenever T60 is.
        if edt_f < 0 and drop >= 10.0: edt_f = i
        if t60_f < 0 and drop >= 60.0:
            t60_f = i; break
    if t60_f < 0: t60_f = n
    if edt_f < 0: edt_f = min(n, int(t60_f / 6))
    edc_lin = 10.0 ** (edc_db / 10.0)
    # "Early" window = first n//12 bins. NOTE(review): with 60 bins spanning ~0.32 s this is 5 bins,
    # about 27 ms, not the 50 ms that C50 implies (50 ms would be ~n * 0.05 / 0.32). Confirm whether
    # the split or the label is intended.
    split = n // 12
    c50_db = 10.0 * np.log10((np.sum(edc_lin[:split]) + 1e-12) / (np.sum(edc_lin[split:]) + 1e-12))
    dr_db = 10.0 * np.log10((np.max(edc_lin) + 1e-12) / (np.mean(edc_lin) + 1e-12))
    # return T60 (bins), EDT (6 x the -10 dB bin, i.e. extrapolated to 60 dB, in bins), C50 (dB), DR (dB)
    return float(t60_f), float(edt_f * 6), float(c50_db), float(dr_db)

@numba.njit(parallel=True, fastmath=True)
def _batch_edc_decay(wavs, bins=60):
    """Run _edc_db and _decay_from_edc on every row of `wavs`, in parallel over rows.

    Returns:
        curves: float64 (N, bins) EDC curves in dB.
        decays: float64 (N, 4) with columns [T60, EDT, C50, DR] from _decay_from_edc.
    """
    n_sig = wavs.shape[0]
    curves = np.empty((n_sig, bins), np.float64)
    decays = np.empty((n_sig, 4), np.float64)  # columns: T60, EDT, C50, DR
    for s in numba.prange(n_sig):
        curve = _edc_db(wavs[s], bins)
        curves[s] = curve
        feats = _decay_from_edc(curve)
        decays[s, 0] = feats[0]
        decays[s, 1] = feats[1]
        decays[s, 2] = feats[2]
        decays[s, 3] = feats[3]
    return curves, decays

def _save_small(name, Z, qids, rids):
    """Write one z-scored distance matrix and its row/column ids to <DIST_DIR>/<name>.npz.

    Stored members:
      Z         -- Z cast to float32;
      query_ids -- row ids as an object array (loading it needs allow_pickle=True);
      ref_ids   -- column ids as an object array;
      meta      -- JSON string with the metric name, normalization, EDC_BINS and the matrix shape.
    Prints a one-line confirmation with the output path and shape.
    """
    meta = {
        "metric": name,
        "norm": "row_zscore",
        "bins": EDC_BINS,
        "K": int(Z.shape[0]),
        "R": int(Z.shape[1]),
    }
    path = os.path.join(DIST_DIR, f"{name}.npz")
    np.savez_compressed(
        path,
        Z=Z.astype(np.float32),
        query_ids=np.array(qids, dtype=object),
        ref_ids=np.array(rids, dtype=object),
        meta=json.dumps(meta),
    )
    print(f"[saved] {name:8s} -> {path} | shape={Z.shape}")

def _iter_chunks(n, chunk):
    if chunk is None or chunk >= n:
        yield 0, n
    else:
        for s in range(0, n, chunk):
            e = min(s + chunk, n)
            yield s, e

# ---------------- Load queries/refs ----------------
files_all = sorted(glob.glob(RENDS_GLOB))
assert files_all, f"No npy renders under {RENDS_GLOB}"
files = files_all if (K_QUERIES is None or K_QUERIES <= 0) else files_all[:K_QUERIES]
query_ids = [os.path.basename(f) for f in files]   # ids stored in every cache = render file basenames

# Each render .npy holds a pickled dict (hence allow_pickle / .item(); only load files you produced).
# "data" is used as the query STFT and "waveform" as the query waveform.
# NOTE(review): Cell 2 treats pack["data"] as ground truth and pack["pred_stft"] as the model prediction,
# so this retrieval is keyed on the GT signal, not the model output -- confirm the oracle setup is intended.
# The torch round trip casts to float32 (.float()) before float64, and squeeze(0) drops a leading
# singleton axis only if present.
Q_STFT, Q_WAV = [], []
for f in tqdm(files, desc="Load queries"):
    d = np.load(f, allow_pickle=True).item()
    # If your npy keys differ, adjust here
    q_stft = torch.from_numpy(d["data"]).float().squeeze(0).numpy().astype(np.float64)
    q_wav  = torch.from_numpy(d["waveform"]).float().squeeze(0).numpy().astype(np.float64)
    Q_STFT.append(q_stft)
    Q_WAV.append(q_wav)
Q_STFT = np.stack(Q_STFT)                         # (K, F, T)
Q_WAV  = np.stack(Q_WAV)                          # (K, Tw)

# Reference bank: every item of REF_SPLIT, in the dataset's own id order (this order defines the cache columns).
ds_ref = RAFDataset(scene_root=SCENE_ROOT, split=REF_SPLIT,
                    model_kind="neraf", sample_rate=48000, dataset_mode="full")
ref_ids = list(ds_ref.ids)
R = len(ref_ids)

REF_STFT, REF_WAV = [], []
for sid in tqdm(ref_ids, desc=f"Load refs[{REF_SPLIT}]"):
    it = ds_ref[ds_ref.id2idx[sid]]
    REF_STFT.append(it["stft"].squeeze(0).numpy().astype(np.float64))
    REF_WAV.append(it["wav"].squeeze(0).numpy().astype(np.float64))
REF_STFT = np.stack(REF_STFT)                     # (R, F, T)
REF_WAV  = np.stack(REF_WAV)                      # (R, Tw)

K, F, T = Q_STFT.shape
print(f"[info] Queries: K={K}, STFT=({F},{T}), Refs: R={R}, WAV_T={REF_WAV.shape[1]}")

# ---------------- Precompute EDC & decays ----------------
print("Precompute EDC/decays ...")
Q_EDC, Q_DEC = _batch_edc_decay(Q_WAV, bins=EDC_BINS)    # Q_DEC: (K, 4) = [T60, EDT, C50, DR]
R_EDC, R_DEC = _batch_edc_decay(REF_WAV, bins=EDC_BINS)  # R_DEC: (R, 4)

# ---------------- Helper: chunked pairwise with tqdm ----------------
def _chunked_rowwise_zscore_and_save(name, row_builder):
    """
    Build the (K, R) distance matrix for metric `name` chunk-by-chunk, z-score it per row, and save it.

    row_builder(i0, i1) -> np.ndarray[(i1-i0), R] of float64 distances between queries i0..i1-1
    and every reference. Each chunk is z-scored row-wise, then written into a float32 Z, which
    _save_small persists as <DIST_DIR>/<name>.npz together with query_ids / ref_ids.

    Raises:
        ValueError: if a chunk has the wrong shape, or if the chunks do not cover all K rows.
            Z is allocated with np.empty, so uncovered rows would otherwise be saved as garbage.
    """
    Z = np.empty((K, R), dtype=np.float32)
    chunks = list(_iter_chunks(K, Q_CHUNK))
    rows_written = 0
    for i0, i1 in tqdm(chunks, desc=name):
        D_chunk = row_builder(i0, i1)               # float64 (rows, R)
        if D_chunk.shape != (i1 - i0, R):
            raise ValueError(f"{name}: row_builder({i0}, {i1}) returned shape {D_chunk.shape}, "
                             f"expected {(i1 - i0, R)}")
        Z[i0:i1] = _zscore_rowwise(D_chunk).astype(np.float32)
        rows_written += i1 - i0
    if rows_written != K:
        raise ValueError(f"{name}: chunks covered {rows_written} of {K} query rows")
    _save_small(name, Z, query_ids, ref_ids)

# ---------------- STFT distances (L2 over log-mag) ----------------
# L2 over the flattened STFT arrays (presumably log-magnitude as labelled above -- confirm in RAFDataset).
# Every block below builds one (K, R) matrix chunk-by-chunk, z-scores it per row, and saves <DIST_DIR>/<name>.npz.
print("Pairwise STFT L2 ...")
Q_STFT_F = Q_STFT.reshape(K, -1)                  # (K, F*T)
R_STFT_F = REF_STFT.reshape(R, -1)                # (R, F*T)
def _stft_row_builder(i0, i1):
    """L2 distances between flattened query STFTs i0..i1-1 and every reference STFT."""
    return _pairwise_l2(Q_STFT_F[i0:i1], R_STFT_F)
_chunked_rowwise_zscore_and_save("STFT", _stft_row_builder)

# ---------------- EDC distances (L1 + L2) ----------------
print("Pairwise EDC (L1/L2) ...")

def _edc_l1_row_builder(i0, i1):
    """L1 distances between sampled EDC curves of queries i0..i1-1 and every reference."""
    return _pairwise_l1(Q_EDC[i0:i1], R_EDC)
_chunked_rowwise_zscore_and_save("EDC_L1", _edc_l1_row_builder)

def _edc_l2_row_builder(i0, i1):
    """L2 distances between sampled EDC curves of queries i0..i1-1 and every reference."""
    return _pairwise_l2(Q_EDC[i0:i1], R_EDC)
_chunked_rowwise_zscore_and_save("EDC_L2", _edc_l2_row_builder)

# ---------------- Decay distances (each metric separate) ----------------
print("Decay |Δ| matrices ...")
names = ["DEC_T60", "DEC_EDT", "DEC_C50", "DEC_DR"]   # same column order as Q_DEC / R_DEC
for c, name in enumerate(names):
    # `c=c` binds the current column when the function is defined (avoids Python's late-binding closure pitfall).
    def _decay_row_builder(i0, i1, c=c):
        """|query - ref| of decay column c for queries i0..i1-1 against every reference (broadcast)."""
        q = Q_DEC[i0:i1, c].reshape(-1, 1)   # (chunk,1)
        r = R_DEC[:, c].reshape(1, -1)       # (1,R)
        return np.abs(q - r)                 # (chunk,R)
    _chunked_rowwise_zscore_and_save(name, _decay_row_builder)

print("All distance caches built.")

Load queries: 100%|██████████| 50/50 [00:00<00:00, 2292.42it/s]
Load refs[train]: 100%|██████████| 31305/31305 [01:00<00:00, 519.15it/s]


[info] Queries: K=50, STFT=(513,60), Refs: R=31305, WAV_T=15360
Precompute EDC/decays ...
Pairwise STFT L2 ...


STFT: 100%|██████████| 1/1 [00:09<00:00,  9.59s/it]


[saved] STFT     -> ./dist_cache/STFT.npz | shape=(50, 31305)
Pairwise EDC (L1/L2) ...


EDC_L1: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]


[saved] EDC_L1   -> ./dist_cache/EDC_L1.npz | shape=(50, 31305)


EDC_L2: 100%|██████████| 1/1 [00:00<00:00, 24.21it/s]


[saved] EDC_L2   -> ./dist_cache/EDC_L2.npz | shape=(50, 31305)
Decay |Δ| matrices ...


DEC_T60: 100%|██████████| 1/1 [00:00<00:00, 83.66it/s]


[saved] DEC_T60  -> ./dist_cache/DEC_T60.npz | shape=(50, 31305)


DEC_EDT: 100%|██████████| 1/1 [00:00<00:00, 53.19it/s]


[saved] DEC_EDT  -> ./dist_cache/DEC_EDT.npz | shape=(50, 31305)


DEC_C50: 100%|██████████| 1/1 [00:00<00:00, 87.56it/s]


[saved] DEC_C50  -> ./dist_cache/DEC_C50.npz | shape=(50, 31305)


DEC_DR: 100%|██████████| 1/1 [00:00<00:00, 80.09it/s]


[saved] DEC_DR   -> ./dist_cache/DEC_DR.npz | shape=(50, 31305)
All distance caches built.


In [None]:
# === CELL 2: Mix metrics, Top-K retrieval, summary ===========================
import os, json, glob, math, numpy as np, torch
from tqdm import tqdm

from data import RAFDataset
from evaluator import UnifiedEvaluator  # uses your Griffin-Lim/RAF helpers

DIST_DIR   = "./dist_cache"
TOPK       = 3
# Per-metric mixing weights over the caches written by Cell 1 (<DIST_DIR>/<name>.npz).
# Metrics with weight <= 0 are not loaded. The remaining weights are renormalized to sum to 1 over
# the metrics whose cache file actually exists.
# Example: to mix T60 & EDT only, set DEC_T60 and DEC_EDT to 0.5 and everything else to 0.0.
# NOTE: the summary below evaluates at most the first 3 ranks, regardless of TOPK.
WEIGHTS = {
    "STFT":   0.5,
    "EDC_L1": 0.5,
    "EDC_L2": 0.0,
    "DEC_T60":0.0,
    "DEC_EDT":0.0,
    "DEC_C50":0.0,
    "DEC_DR": 0.0,
}

def _load_npz(name):
    p = os.path.join(DIST_DIR, f"{name}.npz")
    if not os.path.exists(p): return None
    d = np.load(p, allow_pickle=True)
    return d["Z"].astype(np.float64), list(d["query_ids"]), list(d["ref_ids"]), json.loads(str(d["meta"]))

# ---- Load available metrics (only those with a positive weight)
present = []
stacks  = []
query_ids = ref_ids = None
for m, w in WEIGHTS.items():
    if w <= 0: continue
    obj = _load_npz(m)
    if obj is None:
        # Weights are renormalized over present metrics, so a missing cache silently changes the mix -- report it.
        print(f"[warn] metric {m!r} has weight {w} but {os.path.join(DIST_DIR, m + '.npz')} is missing; skipping")
        continue
    Z, qids, rids, meta = obj
    if query_ids is None:
        query_ids, ref_ids = qids, rids
    else:
        assert qids == query_ids and rids == ref_ids, f"ID mismatch in {m}"
    present.append(m); stacks.append(Z)

assert present, "No active metrics found in cache."

ws = np.array([WEIGHTS[m] for m in present], dtype=np.float64)
ws = ws / (ws.sum() + 1e-12)

# ---- Combine: weighted sum of the per-row z-scored distance matrices (lower = closer)
COMB = np.zeros_like(stacks[0])
for w, Z in zip(ws, stacks): COMB += w * Z

# ---- Top-k indices per query, sorted by combined score (ascending).
# k is clamped to the pool size; kth=k-1 keeps argpartition valid even when k == number of refs.
k_eff = min(TOPK, COMB.shape[1])
topk_idx = []
for i in range(COMB.shape[0]):
    row = COMB[i]
    idx = np.argpartition(row, k_eff - 1)[:k_eff]
    idx = idx[np.argsort(row[idx])]
    topk_idx.append(idx.tolist())

print(f"Active metrics: {present} | weights (norm) = {ws.round(3).tolist()}")

# ==== Evaluation Summary (Baseline vs Top-1/2/3) ============================
# K and R are read from the `meta` of the first active metric cache. SCENE_ROOT and RENDS_GLOB are NOT
# stored in the cache and must be kept in sync with Cell 1 by hand.
with np.load(os.path.join(DIST_DIR, f"{present[0]}.npz"), allow_pickle=True) as _npz0:
    meta0 = json.loads(str(_npz0["meta"]))
SCENE_ROOT = "../NeRAF/data/RAF/FurnishedRoom"
RENDS_GLOB = "../NeRAF/eval_results/furnishedroom/renders/eval_*.npy"
K = int(meta0.get("K", len(query_ids)))
R = int(meta0.get("R", len(ref_ids)))

# Build robust item fetcher across splits to avoid KeyErrors: one lazily built RAFDataset per split.
_ds_cache = {}
def _get_ds(split):
    """Return the RAFDataset for `split` of SCENE_ROOT, constructing and memoizing it on first request."""
    ds = _ds_cache.get(split)
    if ds is None:
        ds = RAFDataset(scene_root=SCENE_ROOT, split=split,
                        model_kind="neraf", sample_rate=48000, dataset_mode="full")
        _ds_cache[split] = ds
    return ds

def fetch_ref_stft_by_id(sid, splits=("reference", "test", "train")):
    """Return the STFT of sample `sid` as a float64 numpy array, searching `splits` in order.

    The reference split is searched first because Cell 1 builds ref_ids from REF_SPLIT = "reference".
    This avoids constructing (and probing) the other split datasets when the id is found there.
    Pass `splits` to change the search order.

    Raises:
        KeyError: if `sid` is not present in any of `splits`.
    """
    for split in splits:
        ds = _get_ds(split)
        if sid in ds.id2idx:
            it = ds[ds.id2idx[sid]]
            return it["stft"].squeeze(0).numpy().astype(np.float64)
    raise KeyError(f"Ref id {sid} not found in any split")

# Resolve npy paths for the same first-K queries (query_ids are render file basenames, as written by Cell 1).
render_all = sorted(glob.glob(RENDS_GLOB))
name2path  = {os.path.basename(p): p for p in render_all}
# NOTE(review): render_files is not used by the evaluation loop below, which indexes name2path directly.
# Queries whose file is missing are silently dropped here.
render_files = []
for nm in query_ids[:K]:
    if nm in name2path: render_files.append(name2path[nm])

# edc_bins=60 is hard-coded; Cell 1 used EDC_BINS = 60, and meta0["bins"] records the cache's value.
evaluator = UnifiedEvaluator(fs=48000, edc_bins=60, edc_dist="l2")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

agg_keys = ["stft","edc","t60","edt","c50"]
def _zero():
    return {k:[0.0,0] for k in (
        ["base_"+x for x in agg_keys] +
        sum([[f"t{i}_{x}", f"d{i}_{x}"] for i in (1,2,3) for x in agg_keys], [])
    )}

def _add(agg,k,v):
    vv = float(v)
    if math.isfinite(vv): agg[k][0]+=vv; agg[k][1]+=1

def _mean(pair): s,n = pair; return (s/max(n,1)) if n>0 else float("nan")

@torch.no_grad()
def _eval_pair(stft_a_2d, stft_b_2d):
    """Score STFT `stft_a_2d` against `stft_b_2d` with the shared `evaluator`.

    Both inputs are 2-D (F, T) numpy arrays; callers pass float32 renders and float64 reference STFTs.
    Each is converted to a float32 tensor of shape (1, 1, F, T) on DEVICE.
    Returns {metric: float} for every name in agg_keys.
    """
    A, B = (torch.from_numpy(x).float()[None, None].to(DEVICE) for x in (stft_a_2d, stft_b_2d))
    scores = evaluator.evaluate(A, B)
    return {key: float(scores[key]) for key in agg_keys}

# Evaluate: baseline = pack["pred_stft"] vs GT; TopN = N-th retrieved reference STFT vs GT; dN = TopN - baseline.
agg = _zero()
n_missing_refs = 0   # retrieved refs not found in any split (skipped, but reported after the loop)
for qi, qname in enumerate(tqdm(query_ids[:K], desc="Summary eval")):
    pack = np.load(name2path[qname], allow_pickle=True).item()
    stft_gt  = pack["data"].astype(np.float32).squeeze(0)       # (F,T)
    stft_pr  = pack["pred_stft"].astype(np.float32).squeeze(0)

    base = _eval_pair(stft_pr, stft_gt)
    for k in agg_keys: _add(agg, f"base_{k}", base[k])

    sel = topk_idx[qi]
    for rank, tag in enumerate(("t1","t2","t3"), start=1):
        if len(sel) < rank: continue
        gidx = sel[rank-1]
        sid  = ref_ids[gidx]
        try:
            ref_st = fetch_ref_stft_by_id(sid)
        except KeyError:
            # Best-effort: a ref missing from every split is skipped, but counted so that the Top-k
            # means (which then average over fewer files) are not silently biased.
            n_missing_refs += 1
            continue
        top = _eval_pair(ref_st, stft_gt)
        for k in agg_keys:
            _add(agg, f"{tag}_{k}", top[k])
            _add(agg, f"d{rank}_{k}", float(top[k]-base[k]))

if n_missing_refs:
    print(f"[warn] {n_missing_refs} retrieved reference(s) were not found in any split and were skipped.")

# Print summary: one row per aggregate (baseline, each top-k rank, and its delta vs baseline) and one
# column per metric in agg_keys. Each cell is the mean over the finite values accumulated (nan if none).
print("\n===== SUMMARY (means over processed files) =====")
print(f"Files processed: {K} | Reference pool: {R} | Active metrics: {present}")
print(f"{'':14s}" + "".join(f"{col:>10s}" for col in ("stft", "edc", "t60", "edt", "c50")))

def _line(title, keys):
    """Print `title` left-aligned in 14 columns, then mean(agg[k]) for each k in `keys` as 10.6f."""
    cells = [f"{_mean(agg[k]):10.6f}" for k in keys]
    print(f"{title:<14s}{''.join(cells)}")

_line("Baseline", ["base_" + k for k in agg_keys])
for r, tag in enumerate(("t1", "t2", "t3"), start=1):
    _line(f"Top{r}", [tag + "_" + k for k in agg_keys])
    _line(f"ΔTop{r}-Base", [f"d{r}_" + k for k in agg_keys])
print()

Active metrics: ['STFT', 'EDC_L1'] | weights (norm) = [0.5, 0.5]


Summary eval: 100%|██████████| 50/50 [00:52<00:00,  1.05s/it]


===== SUMMARY (means over processed files) =====
Files processed: 50 | Reference pool: 31305 | Active metrics: ['STFT', 'EDC_L1']
                    stft       edc       t60       edt       c50
Baseline        0.177603  0.123689  6.534929  0.017785  0.585330
Top1            0.213988  0.110675  4.692591  0.012855  0.409413
ΔTop1-Base      0.036386 -0.013014 -1.842338 -0.004930 -0.175917
Top2            0.218874  0.100802  5.119940  0.012923  0.432830
ΔTop2-Base      0.041272 -0.022888 -1.414988 -0.004862 -0.152501
Top3            0.221983  0.114606  5.694140  0.020872  0.564453
ΔTop3-Base      0.044381 -0.009084 -0.840789  0.003088 -0.020877






In [1]:
# === CELL: Build top-10 retrieval JSON (CPU + Numba, per-row zscore) ===
import os, json, math, glob
import numpy as np
import torch, numba
from tqdm import tqdm
from data import RAFDataset

# ---------------- KNOBS ----------------
SCENE_ROOT      = "../NeRAF/data/RAF/EmptyRoom"
SPLIT_JSON_PATH = os.path.join(SCENE_ROOT, "metadata", "data-split.json")
OUT_JSON_PATH   = "./references_empty.json"

REF_SPLIT   = "reference"   # which split is your reference bank; every other split in SPLIT_JSON_PATH becomes a query
EDC_BINS    = 60            # downsample EDC curve to this many bins
EDC_METRIC  = "L1"          # "L1" (case-insensitive) selects L1 on EDC vectors; any other value uses L2
W_EDC, W_STFT = 0.6, 0.4    # fusion weights applied to the row-z-scored EDC and STFT distances

# Performance / memory knobs
Q_CHUNK     = 256            # process queries in batches of this many for distance computation
DRY_RUN_K   = None            # e.g., 10 to sanity-check; set to None to use ALL queries

# ---------------- Helpers: Numba kernels ----------------
@numba.njit(fastmath=True)
def _zscore_rowwise(M):
    """Standardize each row of the 2-D array M with explicit scalar loops.

    Row mean and population std (variance divided by n_cols + 1e-12) are
    accumulated by hand. 1e-12 is added to the std so constant rows map to
    zeros. Returns a new array with M's shape and dtype.
    """
    n_rows, n_cols = M.shape
    out = np.empty_like(M)
    for r in range(n_rows):
        total = 0.0
        for c in range(n_cols):
            total += M[r, c]
        mean = total / n_cols
        sq_sum = 0.0
        for c in range(n_cols):
            dev = M[r, c] - mean
            sq_sum += dev * dev
        scale = math.sqrt(sq_sum / (n_cols + 1e-12)) + 1e-12
        for c in range(n_cols):
            out[r, c] = (M[r, c] - mean) / scale
    return out

@numba.njit(fastmath=True, parallel=True)
def _pairwise_l1(A, B):
    """(Q, R) float64 matrix of L1 distances between every row of A and every row of B.

    Rows of A are processed in parallel (prange). Each pair is reduced with a
    scalar loop over A's row length, so no temporary arrays are created.
    """
    n_a = A.shape[0]
    n_b = B.shape[0]
    out = np.empty((n_a, n_b), np.float64)
    for qa in numba.prange(n_a):
        row_a = A[qa]
        for rb in range(n_b):
            row_b = B[rb]
            acc = 0.0
            for t in range(row_a.size):
                acc += abs(row_a[t] - row_b[t])
            out[qa, rb] = acc
    return out

@numba.njit(fastmath=True, parallel=True)
def _pairwise_l2(A, B):
    """(Q, R) float64 matrix of Euclidean distances between every row of A and every row of B.

    Rows of A are processed in parallel (prange). Squared differences are
    accumulated in a scalar loop over A's row length, then square-rooted.
    """
    n_a = A.shape[0]
    n_b = B.shape[0]
    out = np.empty((n_a, n_b), np.float64)
    for qa in numba.prange(n_a):
        row_a = A[qa]
        for rb in range(n_b):
            row_b = B[rb]
            acc = 0.0
            for t in range(row_a.size):
                gap = row_a[t] - row_b[t]
                acc += gap * gap
            out[qa, rb] = math.sqrt(acc)
    return out

@numba.njit(fastmath=True)
def _edc_db(wav, bins=60):
    """Energy decay curve of a waveform in dB, sampled at `bins` points.

    Squared samples are integrated backwards (Schroeder-style), normalized by
    the first value so the curve starts near 0 dB, converted to dB, and
    sampled at int64-truncated np.linspace(0, len - 1, bins) positions.
    The 1e-12 terms keep silent input finite (no division by zero or log10(0)).
    """
    # Schroeder integral (linear energy -> cumulative backward), normalize, convert to dB, then sample to bins
    e = wav * wav
    edc = np.cumsum(e[::-1])[::-1]         # energy remaining from each sample to the end
    edc = edc / (edc[0] + 1e-12)
    edc_db = 10.0 * np.log10(edc + 1e-12)
    idx = np.linspace(0, edc_db.size - 1, bins).astype(np.int64)
    return edc_db[idx]

@numba.njit(parallel=True, fastmath=True)
def _batch_edc_db(wavs, bins=60):
    """Apply _edc_db to every row of `wavs` (in parallel over rows); returns a float64 (N, bins) array."""
    n_wavs = wavs.shape[0]
    curves = np.empty((n_wavs, bins), np.float64)
    for w in numba.prange(n_wavs):
        curves[w] = _edc_db(wavs[w], bins)
    return curves

# ---------------- Load splits & datasets ----------------
# Query ids = every split except REF_SPLIT (compared case-insensitively); reference ids = the REF_SPLIT entry.
with open(SPLIT_JSON_PATH, "r") as f:
    splits = json.load(f)

def _split_id_list(v):
    """Id list of one split entry; some writers nest it one level deep ([[ids...]]), so unwrap a leading sublist."""
    return v[0] if (isinstance(v, list) and len(v) > 0 and isinstance(v[0], list)) else v

query_id_list = []
query_splits = []
for k, v in splits.items():
    if k.lower() == REF_SPLIT.lower():
        continue
    for sid in _split_id_list(v):
        query_id_list.append(sid)
        query_splits.append(k)

# Resolve the reference key with the same case-insensitive rule used to exclude it from the queries above.
# An exact match wins if several spellings exist.
if REF_SPLIT in splits:
    ref_key = REF_SPLIT
else:
    ref_key = next((k for k in splits if k.lower() == REF_SPLIT.lower()), None)
    if ref_key is None:
        raise KeyError(f"Reference split {REF_SPLIT!r} not found in {SPLIT_JSON_PATH} (keys: {list(splits)})")
ref_ids = list(_split_id_list(splits[ref_key]))

# RAFDataset is used to read tensors; every split (reference and query) is opened through this one
# factory so all of them share identical settings and indexing.
def _make_ds(split):
    """Construct a RAFDataset for `split` of SCENE_ROOT (model_kind "neraf", 48000 Hz, dataset_mode "full")."""
    ds_kwargs = dict(scene_root=SCENE_ROOT, split=split, model_kind="neraf",
                     sample_rate=48000, dataset_mode="full")
    return RAFDataset(**ds_kwargs)

print("[info] Loading reference set…")
ds_ref = _make_ds(REF_SPLIT)
id2idx_ref = ds_ref.id2idx

# Queries are read from the dataset of their own split; each split's RAFDataset is built lazily, once.
ds_cache = {}
def _get_ds_for_split(split_name):
    """Return the cached RAFDataset for `split_name`, constructing it on first use."""
    cached = ds_cache.get(split_name)
    if cached is None:
        cached = _make_ds(split_name)
        ds_cache[split_name] = cached
    return cached

# Optional dry-run trimming: keep only the first DRY_RUN_K queries (ids and split labels stay aligned).
if isinstance(DRY_RUN_K, int) and DRY_RUN_K > 0:
    query_id_list, query_splits = query_id_list[:DRY_RUN_K], query_splits[:DRY_RUN_K]
print(f"[info] #Queries={len(query_id_list)} (dry-run={DRY_RUN_K or 'OFF'}), #Refs={len(ref_ids)}")

# ---------------- Collect features to RAM ----------------
# Refs: STFT (F,T) -> flatten, WAV -> EDC
# NOTE(review): all reference STFTs and waveforms are held in float64 at once (plus the Python lists while
# stacking) -- this is the peak-memory step for large reference banks.
REF_STFT, REF_WAV = [], []
for sid in tqdm(ref_ids, desc="Load refs (stft+wav)"):
    it = ds_ref[id2idx_ref[sid]]
    REF_STFT.append(it["stft"].squeeze(0).numpy().astype(np.float64))
    REF_WAV.append(it["wav"].squeeze(0).numpy().astype(np.float64))
REF_STFT = np.stack(REF_STFT)              # (R, F, T)
R, F, T = REF_STFT.shape
REF_WAV  = np.stack(REF_WAV)               # (R, Tw)
REF_STFT_F = REF_STFT.reshape(R, F*T)      # (R, F*T); np.stack returns a contiguous array, so this is a view
REF_EDC    = _batch_edc_db(REF_WAV, bins=EDC_BINS)  # (R, EDC_BINS)

# Queries: gather STFT/WAV in the original per-split datasets.
# Features go straight into preallocated arrays and the EDC is computed one query at a time, so query
# waveforms are never stacked. Each query STFT must flatten to F*T (the reference shape), else the row
# assignment raises.
Q_STFT_F = np.empty((len(query_id_list), F*T), dtype=np.float64)
Q_EDC    = np.empty((len(query_id_list), EDC_BINS), dtype=np.float64)
Q_IDS    = []
for qi, (sid, split_name) in enumerate(tqdm(zip(query_id_list, query_splits),
                                            total=len(query_id_list), desc="Load queries")):
    ds = _get_ds_for_split(split_name)
    it = ds[ds.id2idx[sid]]
    st = it["stft"].squeeze(0).numpy().astype(np.float64).reshape(-1)  # (F*T,)
    Q_STFT_F[qi] = st
    wav = it["wav"].squeeze(0).numpy().astype(np.float64)
    Q_EDC[qi] = _edc_db(wav, bins=EDC_BINS)
    Q_IDS.append(sid)

# ---------------- Chunked distance + rowwise z-score + fusion ----------------
def _iter_chunks(n, chunk):
    if chunk is None or chunk >= n:
        yield 0, n
    else:
        for s in range(0, n, chunk):
            e = min(s + chunk, n)
            yield s, e

def _compute_topk_for_chunk(i0, i1, Ktop=10):
    """Rank reference items for queries i0..i1-1 by a weighted mix of row-z-scored distances.

    Distances: STFT L2 between the flattened query/reference STFTs, and EDC L1 or L2 (per EDC_METRIC)
    between the sampled EDC curves. Each distance matrix is z-scored per row and fused as
    W_STFT * Z_stft + W_EDC * Z_edc (lower is better).

    Returns:
        out_idx: int array (i1 - i0, min(Ktop, R)) of reference row indices, best first.
                 Ktop is clamped to the reference pool size R.
        Z_mix:   float64 array (i1 - i0, R) of fused scores.
    """
    # The Numba kernels already return float64, so no extra .astype copy of these (q, R) matrices is needed.
    D_stft = _pairwise_l2(Q_STFT_F[i0:i1], REF_STFT_F)   # (q, R)
    if EDC_METRIC.upper() == "L1":
        D_edc = _pairwise_l1(Q_EDC[i0:i1], REF_EDC)
    else:
        D_edc = _pairwise_l2(Q_EDC[i0:i1], REF_EDC)

    # Rowwise z-score for each metric independently, then weighted fusion.
    Z_mix = W_STFT * _zscore_rowwise(D_stft) + W_EDC * _zscore_rowwise(D_edc)

    k = min(Ktop, Z_mix.shape[1])
    if k <= 0:
        return np.empty((Z_mix.shape[0], 0), dtype=np.intp), Z_mix
    # Unsorted k smallest per row. kth=k-1 stays valid even when k equals the number of references.
    part = np.argpartition(Z_mix, k - 1, axis=1)[:, :k]
    # Sort those k candidates by fused score, row by row.
    order = np.argsort(np.take_along_axis(Z_mix, part, axis=1), axis=1)
    out_idx = np.take_along_axis(part, order, axis=1)
    return out_idx, Z_mix

# Build mapping query_id -> top-KTOP ref IDs, excluding self if present in ref bank
result = {}
KTOP = 10
if len(set(Q_IDS)) != len(Q_IDS):
    # An id listed under several query splits would silently overwrite its own entry in `result`.
    print(f"[warn] {len(Q_IDS) - len(set(Q_IDS))} duplicate query id(s); later occurrences overwrite earlier ones in the JSON")
num_chunks = sum(1 for _ in _iter_chunks(len(Q_IDS), Q_CHUNK))
with tqdm(total=num_chunks, desc="Compute & rank (chunked)") as pbar:
    for i0, i1 in _iter_chunks(len(Q_IDS), Q_CHUNK):
        # Over-fetch 5 extra candidates so that dropping a self-match still leaves KTOP refs.
        topk_idx_chunk, _ = _compute_topk_for_chunk(i0, i1, Ktop=KTOP + 5)
        for row, qi in enumerate(range(i0, i1)):
            qid = Q_IDS[qi]
            # map indices to ref IDs and filter out self if it appears
            cand = [ref_ids[j] for j in topk_idx_chunk[row]]
            cand_noself = [c for c in cand if c != qid]
            result[qid] = cand_noself[:KTOP]
        pbar.update(1)

# ---------------- Save JSON ----------------
with open(OUT_JSON_PATH, "w") as f:
    json.dump(result, f, indent=2)
print(f"[saved] Top-{KTOP} retrieval JSON -> {OUT_JSON_PATH}")
if result:
    print("Example item:", next(iter(result.items())))

[info] Loading reference set…
[info] #Queries=47484 (dry-run=OFF), #Refs=37987


Load refs (stft+wav): 100%|██████████| 37987/37987 [28:28<00:00, 22.24it/s]  
Load queries: 100%|██████████| 47484/47484 [07:45<00:00, 101.99it/s]
Compute & rank (chunked): 100%|██████████| 186/186 [40:34<00:00, 13.09s/it]


[saved] Top-10 retrieval JSON -> ./references_empty.json
Example item: ('026251', ['026256', '005596', '005587', '026258', '026244', '026253', '026261', '026870', '005585', '005589'])
