In [1]:
from pathlib import Path
import json, math
import numpy as np

# --- inputs ---
RX_POS_TXT = Path("../NeRAF/data/RAF/EmptyRoom/metadata/all_rx_pos.txt")
TX_POS_TXT = Path("../NeRAF/data/RAF/EmptyRoom/metadata/all_tx_pos.txt")
RUN_DIR_JSON = Path("./outputs/EmptyRoom/images-jpeg-1k/nerfacto/2025-11-04_172911/dataparser_transforms.json")

OUT_JSON = Path("camera_path.json")

# Viewer header
DEFAULT_FOV = 100.0
ASPECT = 1.5
RENDER_H, RENDER_W = 256,256
SECONDS = 999
IS_CYCLE = False
SMOOTHNESS = 0
CAMERA_TYPE = "perspective"

GLOBAL_UP = np.array([0.0, 0.0, 1.0], dtype=float)  # Z-up

# Control how many you keep
LIMIT = None     # None = all
STRIDE = 1       # keep every STRIDE-th pair (e.g., 5)

# ---------- helpers ----------
def float_no_sci(x, ndigits=12):
    s = f"{x:.{ndigits}f}"
    if "." in s:
        s = s.rstrip("0").rstrip(".")
        if s == "-0":
            s = "0"
        if "." not in s:
            s += ".0"
    return float(s)

def matrix_row_major_list(m4):
    return [float_no_sci(m4[i, j]) for i in range(4) for j in range(4)]

def load_dataparser_transform(run_dir_json):
    with open(run_dir_json, "r") as f:
        meta = json.load(f)
    T_3x4 = np.array(meta["transform"], dtype=float)  # (3,4)
    s = float(meta["scale"])
    R = T_3x4[:, :3]
    t = T_3x4[:, 3]
    return R, t, s

def normalize(v, eps=1e-9):
    n = np.linalg.norm(v)
    return v / n if n > eps else v * 0.0

def build_upright_look_at(cam_pos, target_pos, global_up=np.array([0,0,1.0])):
    f = normalize(target_pos - cam_pos)
    r = np.cross(global_up, f)
    if np.linalg.norm(r) < 1e-6:
        r = np.cross(np.array([1.0,0.0,0.0]), f)
        if np.linalg.norm(r) < 1e-6:
            r = np.array([0.0,1.0,0.0])
    r = normalize(r)
    u = np.cross(f, r)
    c2w = np.eye(4, dtype=float)
    c2w[:3, 0] = r
    c2w[:3, 1] = u
    c2w[:3, 2] = f
    c2w[:3, 3] = cam_pos
    return c2w

def parse_rx_all(path):
    """Yield RX positions: each line is y,z,x -> xyz; skip NaNs."""
    with open(path, "r") as f:
        for line in f:
            parts = [p.strip() for p in line.strip().split(",")]
            if len(parts) < 3: 
                continue
            try:
                y = float(parts[0]); z = float(parts[1]); x = float(parts[2])
            except ValueError:
                continue
            if all(math.isfinite(v) for v in (x,y,z)):
                yield np.array([x,y,z], dtype=float)

def parse_tx_all(path):
    """Yield TX quaternion (HypA yzxW -> xyzW) + TX position (yzx -> xyz)."""
    with open(path, "r") as f:
        for line in f:
            parts = [p.strip() for p in line.strip().split(",")]
            if len(parts) < 7:
                continue
            try:
                qy = float(parts[0]); qz = float(parts[1]); qx = float(parts[2]); qw = float(parts[3])  # yzxW
                py = float(parts[4]); pz = float(parts[5]); px = float(parts[6])                      # yzx
            except ValueError:
                continue
            if not all(math.isfinite(v) for v in (qx,qy,qz,qw,px,py,pz)):
                continue
            quat_xyzW = (qx, qy, qz, qw)
            pos_xyz   = np.array([px, py, pz], dtype=float)
            yield quat_xyzW, pos_xyz

def world_to_nerf_point(p_world, R, t, s):
    return s * (R @ p_world + t)

# ---------- main ----------
R_dp, t_dp, s_dp = load_dataparser_transform(RUN_DIR_JSON)

rx_iter = parse_rx_all(RX_POS_TXT)
tx_iter = parse_tx_all(TX_POS_TXT)

camera_entries = []
count_total = 0
count_kept = 0

for idx, (rx_p, tx_tuple) in enumerate(zip(rx_iter, tx_iter)):
    count_total += 1
    if idx % STRIDE != 0:
        continue
    _, tx_p = tx_tuple  # quat parsed but unused for upright mode
    # world -> NeRF
    rx_n = world_to_nerf_point(rx_p, R_dp, t_dp, s_dp)
    tx_n = world_to_nerf_point(tx_p, R_dp, t_dp, s_dp)
    # upright look-at
    c2w = build_upright_look_at(rx_n, tx_n, GLOBAL_UP)
    camera_entries.append({
        "camera_to_world": matrix_row_major_list(c2w),
        "fov": float_no_sci(DEFAULT_FOV),
        "aspect": float_no_sci(ASPECT),
    })
    count_kept += 1
    if LIMIT is not None and count_kept >= LIMIT:
        break

camera_path = {
    "default_fov": float_no_sci(DEFAULT_FOV),
    "default_transition_sec": 2,
    "camera_type": CAMERA_TYPE,
    "render_height": RENDER_H,
    "render_width": RENDER_W,
    "seconds": SECONDS,
    "is_cycle": IS_CYCLE,
    "smoothness_value": SMOOTHNESS,
    "camera_path": camera_entries,
}

with open(OUT_JSON, "w") as f:
    json.dump(camera_path, f, ensure_ascii=False, indent=2)

print(f"Wrote {OUT_JSON.resolve()} with {count_kept} / {count_total} paired cameras (stride={STRIDE}, limit={LIMIT}).")

Wrote /media/scratch/projects/labuser/msc_user/MoNezami/ReverbRAG/camera_path.json with 47484 / 47484 paired cameras (stride=1, limit=None).


In [23]:
import os, json, math, numpy as np, torch, torchaudio
from tqdm import tqdm

SCENE_ROOT = "../NeRAF/data/RAF/EmptyRoom"
DATA_DIR   = os.path.join(SCENE_ROOT, "data")
FEATS_DIR  = os.path.join(SCENE_ROOT, "feats")
os.makedirs(FEATS_DIR, exist_ok=True)

# STFT params (RAF)
sr = 48000
n_fft, win_length, hop_length = 1024, 512, 256
T = 60
F = n_fft//2 + 1
DT_STFT = np.float16
MAX_SHARD_MB = 1024

stft_tf = torchaudio.transforms.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length,
    power=None, center=True, pad_mode="reflect"
)
def _logmag(x): return torch.log(x.abs() + 1e-3)

def _collect_sids(meta_json):
    with open(meta_json, "r") as f:
        splits = json.load(f)
    sids = []
    for v in splits.values():
        block = v[0] if (isinstance(v, list) and v and isinstance(v[0], list)) else v
        for sid in block:
            sids.append(f"{int(sid):06d}" if str(sid).isdigit() else str(sid))
    return sorted(set(sids))

def _load_wav(sid):
    p = os.path.join(DATA_DIR, sid, "rir.wav")
    wav, r = torchaudio.load(p)
    if r != sr: wav = torchaudio.functional.resample(wav, r, sr)
    wav = wav[:, : int(0.32 * sr)]
    return wav

def _stft60(wav):
    spec = stft_tf(wav)  # [1,F,T_full]
    if spec.shape[-1] > T:
        spec = spec[:, :, :T]
    elif spec.shape[-1] < T:
        minval = float(spec.abs().min())
        spec = torch.nn.functional.pad(spec, (0, T - spec.shape[-1]), value=minval)
    return _logmag(spec).squeeze(0)  # [F,T]

# Discover SIDs and (maybe) an existing index.json
SPLIT_JSON = os.path.join(SCENE_ROOT, "metadata", "data-split.json")
sids = _collect_sids(SPLIT_JSON)
idx_path = os.path.join(FEATS_DIR, "index.json")
index_meta = {"shards": [], "sid_to_ptr": {}}
if os.path.exists(idx_path):
    with open(idx_path, "r") as f:
        index_meta = json.load(f)

# Compute shard sizing
bytes_per_stft = F*T*np.dtype(DT_STFT).itemsize
items_per_shard = max(1, (MAX_SHARD_MB*1024*1024)//bytes_per_stft)
N = len(sids)
num_shards = math.ceil(N/items_per_shard)
print(f"[STFT] Items/shard≈{items_per_shard} → #shards={num_shards}")

def shard_paths(k):
    base = os.path.join(FEATS_DIR, f"shard_{k:03d}")
    return base+"_stft.npy"

# Build shards
for k in range(num_shards):
    start = k*items_per_shard
    end   = min(N, (k+1)*items_per_shard)
    n_k   = end - start
    st_p  = shard_paths(k)

    if os.path.exists(st_p):
        print(f"[STFT] shard {k} exists, skipping write.")
        st_mm = np.memmap(st_p, dtype=DT_STFT, mode="r+", shape=(n_k, F, T))
    else:
        print(f"[STFT] writing shard {k}: {n_k} items → {st_p}")
        st_mm = np.memmap(st_p, dtype=DT_STFT, mode="w+", shape=(n_k, F, T))

    for i, sid in tqdm(list(enumerate(sids[start:end], start=0)), total=n_k, desc=f"[STFT] {k+1}/{num_shards}"):
        if sid in index_meta["sid_to_ptr"]:
            continue
        x = _stft60(_load_wav(sid)).cpu().numpy().astype(DT_STFT)
        st_mm[i] = x
        index_meta["sid_to_ptr"][sid] = [k, i]
    del st_mm

# Write/merge index
# ensure one entry per shard with at least 'stft' path
existing = {sh["id"]: sh for sh in index_meta.get("shards", [])}
for k in range(num_shards):
    st_p = shard_paths(k)
    if k in existing:
        existing[k]["stft"] = st_p
        existing[k]["count"] = existing[k].get("count", 0) or sum(1 for v in index_meta["sid_to_ptr"].values() if v[0]==k)
        existing[k]["F"] = F; existing[k]["T"] = T
        existing[k].setdefault("dtypes", {})["stft"] = str(DT_STFT)
    else:
        existing[k] = {"id": k, "stft": st_p, "count": sum(1 for v in index_meta["sid_to_ptr"].values() if v[0]==k),
                       "F": F, "T": T, "dtypes": {"stft": str(DT_STFT)}}
index_meta["shards"] = [existing[k] for k in sorted(existing.keys())]

with open(idx_path, "w") as f:
    json.dump(index_meta, f)
print("[STFT] Done. Index saved to", idx_path)


[STFT] Items/shard≈17442 → #shards=3
[STFT] writing shard 0: 17442 items → ../NeRAF/data/RAF/EmptyRoom/feats/shard_000_stft.npy


[STFT] 1/3: 100%|██████████| 17442/17442 [01:08<00:00, 255.23it/s]


[STFT] writing shard 1: 17442 items → ../NeRAF/data/RAF/EmptyRoom/feats/shard_001_stft.npy


[STFT] 2/3: 100%|██████████| 17442/17442 [01:08<00:00, 253.44it/s]


[STFT] writing shard 2: 12600 items → ../NeRAF/data/RAF/EmptyRoom/feats/shard_002_stft.npy


[STFT] 3/3: 100%|██████████| 12600/12600 [00:48<00:00, 260.68it/s]


[STFT] Done. Index saved to ../NeRAF/data/RAF/EmptyRoom/feats/index.json


In [24]:
import os, json, numpy as np, torch, torchaudio
from tqdm import tqdm

SCENE_ROOT = "../NeRAF/data/RAF/EmptyRoom"
DATA_DIR   = os.path.join(SCENE_ROOT, "data")
FEATS_DIR  = os.path.join(SCENE_ROOT, "feats")
IDX_PATH   = os.path.join(FEATS_DIR, "index.json")

assert os.path.exists(IDX_PATH), "Missing feats/index.json — build STFT shards first."

# EDC settings
sr = 48000
T = 60
DT_EDC = np.float32

def _load_wav(sid):
    p = os.path.join(DATA_DIR, sid, "rir.wav")
    wav, r = torchaudio.load(p)
    if r != sr: wav = torchaudio.functional.resample(wav, r, sr)
    return wav.squeeze(0)[: int(0.32 * sr)]  # [S]

def _edc_db_60(w1d):
    e = (w1d.float()**2)
    edc = torch.flip(torch.cumsum(torch.flip(e, [0]), 0), [0])
    edc = edc / (edc[0] + 1e-12)
    edc_db = 10*torch.log10(edc + 1e-12)
    idx = torch.linspace(0, edc_db.numel()-1, steps=T).long()
    return edc_db[idx]  # [T]

with open(IDX_PATH, "r") as f:
    idx = json.load(f)

sid_to_ptr = {k: tuple(v) for k, v in idx["sid_to_ptr"].items()}
# Gather SIDs per shard using the EXISTING mapping
shard_to_sidrows = {}
for sid, (k, row) in sid_to_ptr.items():
    shard_to_sidrows.setdefault(k, []).append((sid, row))

# Build/overwrite EDC shard files using the STFT shard counts
for sh in idx["shards"]:
    k = sh["id"]
    count = int(sh["count"])
    edc_path = os.path.join(FEATS_DIR, f"shard_{k:03d}_edc.npy")

    # Create memmap with shape matching the STFT shard
    ed_mm = np.memmap(edc_path, dtype=DT_EDC, mode="w+", shape=(count, T))
    rows = shard_to_sidrows.get(k, [])
    print(f"[EDC] shard {k}: writing {len(rows)} rows into {count}-row memmap -> {edc_path}")

    for sid, row in tqdm(rows, total=len(rows), desc=f"[EDC] {k:03d}"):
        # guard: row must be < count (if not, your index.json is already inconsistent with STFT shards)
        if not (0 <= row < count):
            raise RuntimeError(f"Index mismatch: SID {sid} maps to row {row} but shard {k} has count {count}")
        w = _load_wav(sid)
        ed_mm[row] = _edc_db_60(w).cpu().numpy().astype(DT_EDC)

    del ed_mm

    # annotate shard record with EDC path & dtype
    sh["edc"] = edc_path
    sh.setdefault("dtypes", {})["edc"] = str(DT_EDC)

with open(IDX_PATH, "w") as f:
    json.dump(idx, f)

print("[EDC] Repair complete. index.json updated.")


[EDC] shard 0: writing 17442 rows into 17442-row memmap -> ../NeRAF/data/RAF/EmptyRoom/feats/shard_000_edc.npy


[EDC] 000: 100%|██████████| 17442/17442 [00:33<00:00, 523.85it/s]


[EDC] shard 1: writing 17442 rows into 17442-row memmap -> ../NeRAF/data/RAF/EmptyRoom/feats/shard_001_edc.npy


[EDC] 001: 100%|██████████| 17442/17442 [00:35<00:00, 491.62it/s]


[EDC] shard 2: writing 12600 rows into 12600-row memmap -> ../NeRAF/data/RAF/EmptyRoom/feats/shard_002_edc.npy


[EDC] 002: 100%|██████████| 12600/12600 [00:24<00:00, 510.68it/s]


[EDC] Repair complete. index.json updated.
