# Preprocess subsequences to mmap_ninja (original DisruptCNN setting)

Pre-save **normalized** subsequences from decimated H5 using the same segment/tiling/label logic as `EceiDatasetOriginal` (shot lists, Twarn=300 ms). All normalization and preprocessing is done at save time so training/validation can load directly from memory-mapped files.

**Two variants (set `FLATTOP_ONLY` in config):**
- **`FLATTOP_ONLY = True`** (default): segment from t_flat_start to tend (flattop only). Output: `subseqs_original_mmap/`
- **`FLATTOP_ONLY = False`**: full segment from 0 to tend. Output: `subseqs_original_mmap_full/`

**Output dir contents:**
- `X/`, `target/`, `weight/` — RaggedMmap (mmap_ninja)
- `labels.npy` — per-subsequence disruption flag (1 if the subsequence is marked disrupted in the dataset, else 0), one entry per index
- `train_inds.npy`, `test_inds.npy`, `val_inds.npy` — split indices
- `meta.json` — nsub, nrecept, data_step, pos_weight, neg_weight, flattop_only, n_sequences, n_train, n_val, n_test

**Requires:** `pip install mmap_ninja`

In [None]:
import json
import sys
from pathlib import Path

import numpy as np
import h5py
from tqdm import tqdm

# Project root (soen_fusion_zero)
ROOT = Path.cwd()
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from disruptcnn.dataset_original import EceiDatasetOriginal

try:
    from mmap_ninja import RaggedMmap
    HAS_MMAP_NINJA = True
except ImportError:
    HAS_MMAP_NINJA = False
    print("Install mmap_ninja: pip install mmap_ninja")

# ── Config (edit for your environment) ─────────────────────────────
# Input locations: the ECEi data root, its decimated-H5 subdirectory, the two
# shot-list files (relative paths) and the normalization statistics file.
ROOT_DATA = Path("/home/idies/workspace/Storage/yhuang2/persistent/ecei")
DECIMATED_ROOT = ROOT_DATA.joinpath("dsrpt_decimated")
DISRUPT_FILE = "disruptcnn/shots/d3d_disrupt_ecei.final.txt"
CLEAR_FILE = "disruptcnn/shots/d3d_clear_ecei.final.txt"  # or None for disrupt-only
NORM_STATS = ROOT.joinpath("norm_stats.npz")

# Segment selection:
#   FLATTOP_ONLY=True  → t_flat_start to tend (flattop only), written to subseqs_original_mmap
#   FLATTOP_ONLY=False → 0 to tend (full segment), written to subseqs_original_mmap_full
FLATTOP_ONLY = True  # set False for full-segment version
if FLATTOP_ONLY:
    OUTPUT_DIR = ROOT / "subseqs_original_mmap"
else:
    OUTPUT_DIR = ROOT / "subseqs_original_mmap_full"

# Passed to EceiDatasetOriginal as nsub / nrecept / data_step.
NSUB_RAW = 781_250
NRECEPT_RAW = 30_000
DATA_STEP = 10
# Seed handed to the train/val/test split.
RANDOM_SEED = 42
# Number of subsequences buffered in memory before each flush to disk.
MMAP_BATCH_SIZE = 512

# Echo the effective configuration, one setting per line.
print(
    f"Decimated root: {DECIMATED_ROOT}",
    f"Flattop only: {FLATTOP_ONLY}  →  output: {OUTPUT_DIR.name}",
    f"Output dir: {OUTPUT_DIR}",
    f"mmap_ninja: {HAS_MMAP_NINJA}",
    sep="\n",
)

## Build dataset and split indices

Use `EceiDatasetOriginal` to get the exact same subsequence indices (shot_idxi, start_idxi, stop_idxi, disrupt_idxi) and the train/val/test split. We only need the indices and normalization stats; we will read H5 and write mmap ourselves so all preprocessing is baked in.

In [None]:
# Resolve the optional clear-shot list. CLEAR_FILE may be None (disrupt-only),
# in which case it must not be joined onto ROOT (Path / None raises TypeError).
# A configured path that does not exist also falls back to None.
_clear_file = CLEAR_FILE
if _clear_file is not None and not (ROOT / _clear_file).exists():
    _clear_file = None

# Build the reference dataset; it supplies the subsequence index arrays, the
# normalization statistics and the class weights used when writing the mmap.
inner = EceiDatasetOriginal(
    root=str(ROOT_DATA),
    clear_file=_clear_file,
    disrupt_file=DISRUPT_FILE,
    train=True,
    flattop_only=FLATTOP_ONLY,
    Twarn=300,
    test=0,
    normalize=True,
    data_step=DATA_STEP,
    nsub=NSUB_RAW,
    nrecept=NRECEPT_RAW,
    decimated_root=str(DECIMATED_ROOT),
    norm_stats_path=str(NORM_STATS),
)
# 80/10/10 split with a fixed seed so the saved index files are reproducible.
inner.train_val_test_split(sizes=(0.8, 0.1, 0.1), random_seed=RANDOM_SEED)

n_seq = len(inner.shot_idxi)
train_inds = inner.train_inds
test_inds = inner.test_inds
val_inds = inner.val_inds
print(f"Total subsequences: {n_seq}")
print(f"Train: {len(train_inds)}, val: {len(val_inds)}, test: {len(test_inds)}")
print(f"nsub (decimated): {inner.nsub}, step in getitem: {inner._step_in_getitem}")
print(f"pos_weight: {inner.pos_weight:.4f}, neg_weight: {inner.neg_weight:.4f}")

## Save normalized subsequences to mmap_ninja

For each sequence index we: read the slice from decimated H5, apply offset (if any) and normalization, compute target and weight as in `EceiDatasetOriginal.__getitem__`, then append to RaggedMmap. All preprocessing is done here so training loads pre-normalized data.

In [None]:
# Fail fast when the optional dependency was not importable.
if not HAS_MMAP_NINJA:
    raise RuntimeError("mmap_ninja required; pip install mmap_ninja")

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# In-memory buffers that accumulate samples between flushes to disk.
X_batch = []
target_batch = []
weight_batch = []

# Handles to the on-disk stores; they stay None until a flush opens them.
x_mmap = None
t_mmap = None
w_mmap = None
# Becomes True once the first flush has created the stores.
mmap_initialized = False

def flush_mmap():
    """Write the buffered samples to the on-disk RaggedMmap stores and empty the buffers.

    The first non-empty flush creates the ``X``, ``target`` and ``weight``
    stores under ``OUTPUT_DIR`` with ``RaggedMmap.from_lists``. Later flushes
    open each store once, keep the handle in a module-level global, and append
    with ``extend``. Does nothing when ``X_batch`` is empty.
    """
    global mmap_initialized, x_mmap, t_mmap, w_mmap
    if not X_batch:
        return
    if not mmap_initialized:
        RaggedMmap.from_lists(str(OUTPUT_DIR / "X"), X_batch)
        RaggedMmap.from_lists(str(OUTPUT_DIR / "target"), target_batch)
        RaggedMmap.from_lists(str(OUTPUT_DIR / "weight"), weight_batch)
        mmap_initialized = True
    else:
        # Compare against None instead of relying on truthiness: RaggedMmap is
        # a sized container, so `handle or RaggedMmap(...)` would call len()
        # on it and reopen the store whenever it reports length 0.
        if x_mmap is None:
            x_mmap = RaggedMmap(str(OUTPUT_DIR / "X"))
        if t_mmap is None:
            t_mmap = RaggedMmap(str(OUTPUT_DIR / "target"))
        if w_mmap is None:
            w_mmap = RaggedMmap(str(OUTPUT_DIR / "weight"))
        x_mmap.extend(X_batch)
        t_mmap.extend(target_batch)
        w_mmap.extend(weight_batch)
    # Clear in place so the module-level lists keep their identity.
    X_batch.clear()
    target_batch.clear()
    weight_batch.clear()

labels_all = []

for idx in tqdm(range(n_seq), desc="Writing subsequences"):
    # Reproduces EceiDatasetOriginal._read_data and __getitem__ for one index.
    shot_index = inner.shot_idxi[idx]
    start_i, stop_i = inner.start_idxi[idx], inner.stop_idxi[idx]
    step = inner._step_in_getitem

    with h5py.File(inner._filename(shot_index), "r") as f:
        # Fill in this shot's offsets from the file when they are still all zero.
        offsets_unset = np.all(inner.offsets[..., shot_index] == 0)
        if offsets_unset and "offsets" in f:
            inner.offsets[..., shot_index] = f["offsets"][...]
        # Slice the window along the last axis, stride it, subtract the offsets.
        raw = f["LFS"][..., start_i:stop_i][..., ::step]
        shot_offsets = inner.offsets[..., shot_index][..., np.newaxis]
        signal = (raw - shot_offsets).astype(np.float32)

    if inner.normalize:
        mean = inner.normalize_mean[..., np.newaxis]
        std = inner.normalize_std[..., np.newaxis]
        signal = (signal - mean) / std

    # Per-timestep target and weight: negative everywhere by default.
    n_samples = signal.shape[-1]
    target = np.zeros((n_samples,), dtype=np.float32)
    weight = np.full((n_samples,), inner.neg_weight, dtype=np.float32)

    is_disrupted = bool(inner.disruptedi[idx])
    if is_disrupted:
        # Position of the first disruptive sample, in strided (post-step) units.
        first_disrupt = int((inner.disrupt_idxi[idx] - start_i + 1) / step)
        target[first_disrupt:] = 1.0
        weight[first_disrupt:] = inner.pos_weight

    labels_all.append(int(is_disrupted))
    X_batch.append(np.ascontiguousarray(signal))
    target_batch.append(target)
    weight_batch.append(weight.astype(np.float32))

    # Spill to disk once the buffer reaches the configured batch size.
    if len(X_batch) >= MMAP_BATCH_SIZE:
        flush_mmap()

# Write whatever is still buffered after the loop.
flush_mmap()

# Per-sequence labels and split indices are stored next to the mmap stores.
np.save(OUTPUT_DIR / "labels.npy", np.array(labels_all, dtype=np.int64))
for _fname, _inds in (
    ("train_inds.npy", train_inds),
    ("test_inds.npy", test_inds),
    ("val_inds.npy", val_inds),
):
    np.save(OUTPUT_DIR / _fname, _inds)

# Summary of the settings and counts this export was produced with.
meta = dict(
    nsub=int(inner.nsub),
    nrecept=int(inner.nrecept),
    data_step=int(DATA_STEP),
    pos_weight=float(inner.pos_weight),
    neg_weight=float(inner.neg_weight),
    flattop_only=FLATTOP_ONLY,
    n_sequences=n_seq,
    n_train=len(train_inds),
    n_val=len(val_inds),
    n_test=len(test_inds),
)
_meta_text = json.dumps(meta, indent=2)
(OUTPUT_DIR / "meta.json").write_text(_meta_text)

print(f"Saved to {OUTPUT_DIR}")
print(_meta_text)

## Verify: load back and compare one sample

Load the mmap and compare one subsequence with the same index from `EceiDatasetOriginal` to ensure preprocessing matches.

In [None]:
from mmap_ninja import RaggedMmap
import torch

# Reopen the stores that were just written, plus the saved labels.
x_m = RaggedMmap(str(OUTPUT_DIR / "X"))
t_m = RaggedMmap(str(OUTPUT_DIR / "target"))
w_m = RaggedMmap(str(OUTPUT_DIR / "weight"))
labels = np.load(OUTPUT_DIR / "labels.npy")

# Every store and the label array must hold exactly one entry per subsequence.
assert len(x_m) == n_seq, f"X has {len(x_m)} entries, expected {n_seq}"
assert len(t_m) == n_seq, f"target has {len(t_m)} entries, expected {n_seq}"
assert len(w_m) == n_seq, f"weight has {len(w_m)} entries, expected {n_seq}"
assert len(labels) == n_seq, f"labels has {len(labels)} entries, expected {n_seq}"

# Compare one saved sample against the same index read live from the dataset.
idx = 0
X_saved = np.ascontiguousarray(x_m[idx])
target_saved = t_m[idx]
weight_saved = w_m[idx]
X_live, target_live, _, weight_live = inner[idx]
X_live = X_live.numpy()
target_live = target_live.numpy()
weight_live = weight_live.numpy()

x_ok = np.allclose(X_saved, X_live)
target_ok = np.allclose(target_saved, target_live)
weight_ok = np.allclose(weight_saved, weight_live)

print(f"Shape saved: {X_saved.shape}, live: {X_live.shape}")
print(f"X match: {x_ok}")
print(f"target match: {target_ok}")
print(f"weight match: {weight_ok}")

# Only report success when every comparison passed; previously the OK line
# was printed even when a comparison was False.
if not (x_ok and target_ok and weight_ok):
    raise RuntimeError(
        f"Verification FAILED for idx={idx}: "
        f"X match={x_ok}, target match={target_ok}, weight match={weight_ok}"
    )
print("Verification OK.")

## Train using prebuilt mmap

Run training without loading decimated H5 on the fly; use the pre-saved subsequences instead.

**Flattop-only** (default output):
```bash
torchrun --nproc_per_node=4 train_tcn_ddp_original.py --prebuilt-mmap-dir subseqs_original_mmap --flattop-only ...
# or
bash run_tcn_baseline_160_original_prenorm.sh --prebuilt-mmap-dir subseqs_original_mmap
```

**Full segment** (when you ran with `FLATTOP_ONLY = False`):
```bash
torchrun --nproc_per_node=4 train_tcn_ddp_original.py --prebuilt-mmap-dir subseqs_original_mmap_full ...
# or
bash run_tcn_baseline_160_original_prenorm.sh --prebuilt-mmap-dir subseqs_original_mmap_full
```
(Do **not** pass `--flattop-only` when using the full-segment mmap.)