# DisruptCNN with decimated ECEi data

This notebook adapts the decimated dataset into the original DisruptCNN file/metadata format, then runs the original dataloader and training loop with minimal changes.

Decimated data root used here:
`/home/idies/workspace/Storage/yhuang2/persistent/ecei/dsrpt_decimated`


In [None]:
from pathlib import Path
import os
import sys
import shutil
import time
from types import SimpleNamespace

import h5py
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torch.nn.functional as F

# Location of the decimated ECEi dataset (H5 files plus meta.csv).
DECIMATED_ROOT = Path("/home/idies/workspace/Storage/yhuang2/persistent/ecei/dsrpt_decimated")
# Output tree in the layout the original loader expects; can be overridden with
# the DISRUPTCNN_COMPAT_ROOT environment variable.
COMPAT_ROOT = Path(os.environ.get("DISRUPTCNN_COMPAT_ROOT", DECIMATED_ROOT.parent / "disruptcnn_compat"))
DATA_ROOT = COMPAT_ROOT / "data"

# Source sub-directories searched (together with DECIMATED_ROOT) for <shot>.h5.
DISRUPT_DIR = DECIMATED_ROOT / "disrupt"
CLEAR_DIR = DECIMATED_ROOT / "clear"

USE_SYMLINKS = True  # set False to copy files into COMPAT_ROOT
# NOTE: with USE_SYMLINKS = True, writing offsets goes through the symlinks
# into the original files under DECIMATED_ROOT.
WRITE_OFFSETS = True  # writes 'offsets' dataset into H5 files

OFFSET_WINDOW_MS = (-50.0, -10.0)  # same window as create_offsets.py
NORMALIZE = True  # compute normalization.npz and pass normalize=True to EceiDataset
NORM_MAX_SHOTS = None  # set int to limit for quick stats

BATCH_SIZE = 12
DATA_STEP = 1  # decimated data already reduces rate; keep 1 to preserve timing
# Passed to EceiDataset and the model args as nrecept / nsub.
# NOTE(review): presumably receptive-field and subsequence lengths in samples;
# confirm against the loader, since the decimated sample rate differs from the
# original data.
NRECEPT = 30000
NSUB = 78125


## Import original DisruptCNN modules

We keep the original code intact and only adapt the data to its expected format.


In [None]:
# Locate the DisruptCNN repository. Resolution order: the current directory if
# it is named "disruptcnn", then the DISRUPTCNN_REPO environment variable, then
# a placeholder path that must be edited by hand.
repo_root = Path.cwd()
if repo_root.name != "disruptcnn":
    repo_root = Path(os.environ.get("DISRUPTCNN_REPO", "/path/to/disruptcnn"))  # update if needed

# Put the repo's parent first on sys.path so `import disruptcnn.<module>` works.
# Remove an earlier copy first so re-running this cell does not keep growing
# sys.path with duplicates.
_repo_parent = str(repo_root.parent)
if _repo_parent in sys.path:
    sys.path.remove(_repo_parent)
sys.path.insert(0, _repo_parent)

try:
    from disruptcnn.loader import EceiDataset, data_generator
    import disruptcnn.main as disrupt_main
except ImportError:
    # Fall back to flat modules (e.g. when the notebook runs inside the repo).
    # Only import failures trigger the fallback, so genuine errors raised while
    # the package modules execute surface directly instead of being masked.
    from loader import EceiDataset, data_generator
    import main as disrupt_main


## Load decimated metadata and map columns

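This maps `meta.csv` columns into the original DisruptCNN shot-list format. When optional columns are missing, `tlast` is derived from each shot's H5 sample count and `tdisrupt` from a disrupted-flag column.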


In [None]:
def pick_col(df, candidates, required=True):
    """Return the first name in ``candidates`` that is a column of ``df``.

    Candidates are tried in the given order. If none is present, raise
    ``KeyError`` when ``required`` is true, otherwise return ``None``.
    """
    present = [name for name in candidates if name in df.columns]
    if present:
        return present[0]
    if not required:
        return None
    raise KeyError(f"None of {candidates} found in meta.csv")

# Load the per-shot metadata table that accompanies the decimated H5 files.
META_PATH = DECIMATED_ROOT / "meta.csv"
if not META_PATH.exists():
    raise FileNotFoundError(f"meta.csv not found at {META_PATH}")

meta = pd.read_csv(META_PATH)

# Resolve column names, accepting several spellings. shot, tstart and dt are
# mandatory (pick_col raises KeyError if none of the spellings exist); the
# remaining columns are optional and are derived in the next section if absent.
shot_col = pick_col(meta, ["shot", "shot_id", "shotnum", "shot_number"])
tstart_col = pick_col(meta, ["tstart", "t_start", "t_start_ms"])

# NOTE(review): every time value below is treated as milliseconds (the *_ms
# names and OFFSET_WINDOW_MS assume it); confirm that columns such as "dt" or
# "sample_dt" in meta.csv are not stored in seconds.
dt_col = pick_col(meta, ["dt", "dt_ms", "sample_dt", "sample_dt_ms"])

tlast_col = pick_col(
    meta,
    ["tlast", "t_last", "tstop", "t_stop", "t_end", "t_segment_end", "t_segment_end_ms"],
    required=False,
)

tdisrupt_col = pick_col(
    meta,
    ["tdisrupt", "t_disrupt", "t_disrupt_ms", "tdisrupt_ms"],
    required=False,
)

tflat_start_col = pick_col(meta, ["t_flat_start", "tflat_start", "t_flat_start_ms"], required=False)
tflat_stop_col = pick_col(meta, ["t_flat_stop", "tflat_stop", "t_flat_last", "tflat_last", "t_flat_end"], required=False)
tflat_dur_col = pick_col(meta, ["t_flat_duration", "tflat_duration"], required=False)

# Per-shot arrays, aligned row-for-row with meta.
shots = meta[shot_col].to_numpy().astype(int)
tstart_ms = meta[tstart_col].to_numpy().astype(float)
dt_ms = meta[dt_col].to_numpy().astype(float)

def shot_to_h5(shot):
    """Locate the decimated H5 file ``<shot>.h5`` for a shot number.

    Searches DISRUPT_DIR, CLEAR_DIR and DECIMATED_ROOT in that order and
    returns the first path that exists.

    Raises:
        FileNotFoundError: if none of the search directories contains the file.
    """
    filename = f"{int(shot)}.h5"
    candidates = (base / filename for base in (DISRUPT_DIR, CLEAR_DIR, DECIMATED_ROOT))
    found = next((candidate for candidate in candidates if candidate.exists()), None)
    if found is None:
        raise FileNotFoundError(f"H5 file for shot {shot} not found in {DECIMATED_ROOT}")
    return found

# End time per shot. When meta.csv has no end-time column, read the sample
# count of each shot's "LFS" dataset and extrapolate from tstart with the
# per-shot dt.
if tlast_col is None:
    tlast_ms = np.zeros_like(tstart_ms)
    for i, shot in enumerate(shots):
        path = shot_to_h5(shot)
        with h5py.File(path, "r") as f:
            n_samples = f["LFS"].shape[-1]
        # NOTE(review): this is the time one step past the last sample (the last
        # sample sits at tstart + (n_samples - 1) * dt); confirm which
        # convention the loader's index arithmetic expects.
        tlast_ms[i] = tstart_ms[i] + n_samples * dt_ms[i]
else:
    tlast_ms = meta[tlast_col].to_numpy().astype(float)

# Disruption time. Without an explicit column, disrupted shots use tlast and
# the others get the -1000.0 sentinel. Any value <= 0 (or NaN) is classified
# as non-disruptive by is_disrupt below.
if tdisrupt_col is None:
    disrupted_col = pick_col(meta, ["disrupted", "is_disrupt", "is_disruptive"], required=True)
    tdisrupt_ms = np.where(meta[disrupted_col].to_numpy().astype(bool), tlast_ms, -1000.0)
else:
    tdisrupt_ms = meta[tdisrupt_col].to_numpy().astype(float)

# Flat-top window, kept as (absolute start, duration), because the shot-list
# writer stores the duration in its t_flat_last column.
if tflat_start_col and tflat_dur_col:
    tflat_start_ms = meta[tflat_start_col].to_numpy().astype(float)
    tflat_duration_ms = meta[tflat_dur_col].to_numpy().astype(float)
elif tflat_start_col and tflat_stop_col:
    tflat_start_ms = meta[tflat_start_col].to_numpy().astype(float)
    tflat_stop_ms = meta[tflat_stop_col].to_numpy().astype(float)
    tflat_duration_ms = tflat_stop_ms - tflat_start_ms
else:
    # No flat-top information: treat the whole recording [tstart, tlast] as
    # flat-top. The start must be tstart (not 0) so that start + duration ends
    # at tlast instead of a window shifted by tstart.
    tflat_start_ms = tstart_ms.copy()
    tflat_duration_ms = tlast_ms - tstart_ms

is_disrupt = tdisrupt_ms > 0


## Create DisruptCNN-compatible data layout and shot lists

This builds `data/disrupt`, `data/clear`, and the shot list files expected by the original loader.


In [None]:
def ensure_dir(path):
    """Create directory ``path`` (a ``pathlib.Path``) and any missing parents.

    Does nothing if the directory already exists.
    """
    path.mkdir(exist_ok=True, parents=True)


def link_or_copy(src, dst, use_symlinks=None):
    """Materialize ``src`` at ``dst`` as a symlink or as a file copy.

    Args:
        src: Existing source file.
        dst: Destination ``pathlib.Path``.
        use_symlinks: ``True`` to create a symlink, ``False`` to copy with
            ``shutil.copy2``. ``None`` (default) uses the module-level
            ``USE_SYMLINKS`` setting.

    An existing ``dst`` is left untouched, so re-running the cell is cheap.
    A dangling symlink at ``dst`` (e.g. left over from a moved source tree) is
    replaced. ``Path.exists()`` follows links and reports such a link as
    missing, so without the replacement ``os.symlink`` would raise
    ``FileExistsError`` and ``shutil.copy2`` would write through the stale link.
    """
    if use_symlinks is None:
        use_symlinks = USE_SYMLINKS
    if dst.exists():
        return
    if dst.is_symlink():
        # exists() is False but the link itself is present: it is dangling.
        dst.unlink()
    if use_symlinks:
        os.symlink(src, dst)
    else:
        shutil.copy2(src, dst)


def write_shot_file(path, shot_indices):
    """Write a DisruptCNN shot-list text file for selected metadata rows.

    Writes one header line, then one tab-separated row per index into the
    module-level metadata arrays (shots, tstart_ms, tlast_ms, ...). Every shot
    is written as a single segment with a placeholder SNR of 0.00.
    """
    # The loader treats the 8th column as duration (t_flat_start + duration),
    # so the "t_flat_last" column deliberately holds tflat_duration_ms.
    header = "# Shot\t# segments\ttstart\ttlast\tdt\tSNR min\tt_flat_start\tt_flat_last\ttdisrupt\n"
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(header)
        handle.writelines(
            f"{shots[k]}\t1\t{tstart_ms[k]:.3f}\t{tlast_ms[k]:.3f}\t{dt_ms[k]:.6f}\t0.00\t"
            f"{tflat_start_ms[k]:.3f}\t{tflat_duration_ms[k]:.3f}\t{tdisrupt_ms[k]:.3f}\n"
            for k in shot_indices
        )


# Output directories for the loader's data/disrupt and data/clear layout.
ensure_dir(DATA_ROOT / "disrupt")
ensure_dir(DATA_ROOT / "clear")

# Place every shot's H5 file under data/disrupt or data/clear according to
# is_disrupt (symlink or copy, depending on USE_SYMLINKS). Existing entries are
# skipped by link_or_copy.
for i, shot in enumerate(shots):
    src = shot_to_h5(shot)
    dst_dir = DATA_ROOT / ("disrupt" if is_disrupt[i] else "clear")
    dst = dst_dir / f"{int(shot)}.h5"
    link_or_copy(src, dst)

# Shot-list files later passed to EceiDataset as clear_file / disrupt_file.
CLEAR_FILE = COMPAT_ROOT / "d3d_clear_ecei.final.txt"
DISRUPT_FILE = COMPAT_ROOT / "d3d_disrupt_ecei.final.txt"

# Split the metadata row indices by label.
write_shot_file(CLEAR_FILE, np.where(~is_disrupt)[0])
write_shot_file(DISRUPT_FILE, np.where(is_disrupt)[0])

print(f"Wrote {CLEAR_FILE} and {DISRUPT_FILE}")
print(f"Data root: {DATA_ROOT}")


## Offsets (per-shot baseline)

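Creates `offsets` datasets if missing. This matches `create_offsets.py`, but converts the millisecond window to sample indices using each shot's `tstart` and `dt`, so it works at the decimated sampling rate. With `USE_SYMLINKS = True`, the files under `data/` are symlinks, so the new datasets are written into the original decimated H5 files.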


In [None]:
def compute_offset_indices(tstart, dt, window_ms, n_samples):
    """Convert a time window in ms into a clamped ``[start, end)`` sample range.

    Args:
        tstart: Time of sample 0 (ms).
        dt: Sample spacing (ms).
        window_ms: ``(start_ms, end_ms)`` window to convert.
        n_samples: Number of samples available.

    Returns:
        ``(start_idx, end_idx)`` with ``0 <= start_idx <= end_idx <= n_samples``.
        Whenever ``n_samples > 0`` the range covers at least one sample, so a
        window lying entirely before or after the recording falls back to the
        first or last sample.
    """
    def to_index(t_ms):
        # Nearest sample index; np.round rounds exact halves to even.
        return int(np.round((t_ms - tstart) / dt))

    window_start, window_end = window_ms
    lo = max(to_index(window_start), 0)
    hi = min(max(to_index(window_end), lo + 1), n_samples)
    lo = min(lo, max(hi - 1, 0))
    return lo, hi

# Compute a per-shot baseline ("offsets") as the mean of LFS over
# OFFSET_WINDOW_MS, skipping files that already contain one.
# NOTE: files are opened "r+" via their DATA_ROOT path. With USE_SYMLINKS = True
# those paths are symlinks, so the dataset is written into the original files
# under DECIMATED_ROOT.
if WRITE_OFFSETS:
    for i, shot in enumerate(shots):
        path = DATA_ROOT / ("disrupt" if is_disrupt[i] else "clear") / f"{int(shot)}.h5"
        with h5py.File(path, "r+") as f:
            if "offsets" in f:
                continue
            n_samples = f["LFS"].shape[-1]
            start_idx, end_idx = compute_offset_indices(tstart_ms[i], dt_ms[i], OFFSET_WINDOW_MS, n_samples)
            # Average over the last (time) axis, keeping the leading
            # (presumably channel) axes.
            data = f["LFS"][..., start_idx:end_idx]
            offsets = data.mean(axis=-1).astype("float32")
            f.create_dataset("offsets", data=offsets)

    print("Offsets ensured.")
else:
    print("WRITE_OFFSETS is False; skipping offsets creation.")


## Normalization (per-channel mean/std)

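Writes `normalization.npz` to the data root in the exact format expected by the original loader. Per-channel mean and standard deviation are accumulated over every sample of every shot after subtracting its `offsets`. No separate flat-top statistics are computed, so the same values are stored under both the `*_all` and `*_flat` keys.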


In [None]:
def update_running_stats(mean, M2, count, x):
    """Merge a chunk's statistics into running ``(mean, M2, count)``.

    Statistics are taken along the last axis of ``x``. ``M2`` is the running
    sum of squared deviations, so the population variance is ``M2 / count``.
    When ``count`` is 0 the incoming ``mean`` / ``M2`` are ignored and the
    chunk's own statistics are returned. Otherwise the two partial results
    are combined with the pairwise (Chan et al.) update.
    """
    n_new = x.shape[-1]
    chunk_mean = x.mean(axis=-1)
    chunk_m2 = x.var(axis=-1) * n_new
    if not count:
        return chunk_mean, chunk_m2, n_new
    combined = count + n_new
    diff = chunk_mean - mean
    merged_mean = mean + diff * n_new / combined
    merged_m2 = M2 + chunk_m2 + diff * diff * count * n_new / combined
    return merged_mean, merged_m2, combined

if NORMALIZE:
    # Disrupt shots first, then clear shots. NORM_MAX_SHOTS truncates this
    # combined list, so a small limit samples disrupt shots only.
    paths = sorted((DATA_ROOT / "disrupt").glob("*.h5")) + sorted((DATA_ROOT / "clear").glob("*.h5"))
    if NORM_MAX_SHOTS is not None:
        paths = paths[: int(NORM_MAX_SHOTS)]
    if not paths:
        # Fail clearly here rather than with a TypeError on M2 = None below.
        raise FileNotFoundError(f"No .h5 files under {DATA_ROOT} to compute normalization stats")

    mean = None
    M2 = None
    count = 0
    chunk_size = 20000  # samples read per slice, bounding memory per step

    for path in paths:
        with h5py.File(path, "r") as f:
            # Stats are computed on offset-subtracted data; this requires the
            # 'offsets' dataset created by the offsets cell.
            offsets = f["offsets"][...]
            n_samples = f["LFS"].shape[-1]
            for start in range(0, n_samples, chunk_size):
                end = min(start + chunk_size, n_samples)
                x = f["LFS"][..., start:end] - offsets[..., np.newaxis]
                if mean is None:
                    mean, M2, count = update_running_stats(0, 0, 0, x)
                else:
                    mean, M2, count = update_running_stats(mean, M2, count, x)

    if mean is None:
        raise ValueError(f"No samples found in the {len(paths)} files selected for normalization")

    # Population standard deviation per channel.
    # NOTE(review): a channel with zero variance yields std == 0; confirm the
    # loader guards against dividing by it.
    std = np.sqrt(M2 / max(count, 1))

    norm_path = DATA_ROOT / "normalization.npz"
    # No separate flat-top statistics are computed, so the whole-shot values
    # are stored under both the *_all and *_flat keys.
    np.savez(norm_path, mean_all=mean, std_all=std, mean_flat=mean, std_flat=std)
    print(f"Saved {norm_path}")
else:
    print("NORMALIZE is False; skipping normalization stats.")


## Instantiate the original dataloader

We now use the untouched `EceiDataset` and `data_generator` from DisruptCNN.


In [None]:
# Build the original EceiDataset on top of the compatibility layout and shot
# lists written above; the argument semantics belong to the loader.
dataset = EceiDataset(
    root=str(DATA_ROOT) + "/",  # NOTE(review): trailing slash kept; presumably the loader joins paths by string concatenation
    clear_file=str(CLEAR_FILE),
    disrupt_file=str(DISRUPT_FILE),
    flattop_only=False,
    Twarn=300,  # NOTE(review): presumably the pre-disruption warning window in ms; confirm in the loader
    label_balance="const",
    normalize=NORMALIZE,
    data_step=DATA_STEP,
    nsub=NSUB,
    nrecept=NRECEPT,
)

dataset.train_val_test_split()
train_loader, val_loader, test_loader = data_generator(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    undersample=1.0,
)

# Sanity check: fetch one item and show the shapes of input, target and weight.
x, y, idx, w = dataset[0]
print("Sample shapes:", x.shape, y.shape, w.shape)


## Training (original loop, single GPU)

This mirrors the original `main.py` logic but runs in a single process.


In [None]:
# Hyperparameters handed to disrupt_main.create_model / train_seq / evaluate.
args = SimpleNamespace(
    input_channels=160,
    n_classes=1,
    dropout=0.1,
    clip=0.3,
    kernel_size=15,
    dilation_size=10,
    levels=4,
    nhid=80,
    nrecept=NRECEPT,
    nsub=NSUB,
    lr=0.5,
    cuda=torch.cuda.is_available(),
    thresholds=np.linspace(0.05, 0.95, 19),
    distributed=False,
    backend="gloo",
    rank=0,
    plot=False,
)
args.tstart = time.time()
# Validate (and step the plateau scheduler) once per pass over train_loader.
args.iterations_valid = len(train_loader)
# Linear LR warm-up over the first 5 epochs, from lr / multiplier_warmup toward lr.
args.iterations_warmup = 5 * len(train_loader)
args.multiplier_warmup = 8

model = disrupt_main.create_model(args)
if args.cuda:
    model = model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True)

lambda1 = (
    lambda iteration: (1.0 - 1.0 / args.multiplier_warmup) / args.iterations_warmup * iteration
    + 1.0 / args.multiplier_warmup
)
scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
scheduler_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5)

EPOCHS = 1  # set higher for full training
steps = 0
total_loss = 0.0  # training loss summed since the last validation
best_acc = 0.0

for epoch in range(EPOCHS):
    if hasattr(train_loader.sampler, "set_epoch"):
        train_loader.sampler.set_epoch(epoch)
    for batch_idx, (data, target, global_index, weight) in enumerate(train_loader):
        iteration = epoch * len(train_loader) + batch_idx
        args.iteration = iteration

        if iteration < args.iterations_warmup:
            # Explicit iteration argument; PyTorch warns that passing an epoch
            # to scheduler.step() is deprecated, but it is still honored.
            scheduler_warmup.step(iteration)

        loss = disrupt_main.train_seq(data, target, weight, model, optimizer, args)
        total_loss += float(loss)
        steps += data.shape[0] * data.shape[-1]

        if batch_idx % 50 == 0:
            lr_epoch = optimizer.param_groups[0]["lr"]
            # total_loss is reset only at the end of each validation period,
            # which coincides with the end of an epoch, so this is the mean
            # batch loss so far in the current epoch.
            print(
                f"Epoch {epoch} batch {batch_idx}/{len(train_loader)} "
                f"loss={total_loss / max(batch_idx + 1, 1):.6e} steps={steps} lr={lr_epoch:.2e}"
            )

        # Validate after the last batch of each validation period. Because
        # iterations_valid == len(train_loader), testing (iteration + 1) makes
        # this fire at the end of every epoch, so even a single-epoch run
        # produces a validation score.
        if (iteration + 1) % args.iterations_valid == 0:
            valid_loss, valid_acc, valid_f1, TP, TN, FP, FN, threshold = disrupt_main.evaluate(
                val_loader, model, args
            )
            best_acc = max(best_acc, valid_acc)
            # Once warm-up has finished, let the plateau scheduler react to the
            # training loss accumulated over the period just completed.
            if iteration + 1 >= args.iterations_warmup:
                scheduler_plateau.step(total_loss)
            total_loss = 0.0

print(f"Best validation accuracy: {best_acc:.4f}")
