# Explore: Original DisruptCNN-Style Dataset

This notebook walks from **shot list** → **segment bounds** → **subsequence tiling** → (optional) **full dataset** for the pipeline used by `train_tcn_ddp_original.py` and `EceiDatasetOriginal`.

**Start here:** The section **⏱ Times, flattop, and CLEAR vs DISRUPTION** defines the times used (all in **ms**), what **flattop** means, and which portion of each segment is **CLEAR** (label 0) vs **DISRUPTION** (label 1). Section **2b** explains how **meta.csv** is used for decimated samples (alternative to the shot list).

**Goals:**
- See how each shot becomes a segment (flattop-only vs full).
- See how segments are tiled into fixed-length windows (overlap, tail fallback).
- Inspect sequence counts, label balance, and file-length filtering.
- Optionally load the built dataset and visualize example windows + labels.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

from disruptcnn.dataset_original import (
    segment_info_for_comparison,
    subsequences_original_tiling,
    EceiDatasetOriginal,
)

# ── Config (edit for your environment) ─────────────────────────────
SHOT_LIST = Path('disruptcnn/shots/d3d_disrupt_ecei.final.txt')
FLATTOP_ONLY = True
DATA_ROOT = Path('/home/idies/workspace/Storage/yhuang2/persistent/ecei')  # or None to skip H5
DECIMATED_ROOT = DATA_ROOT / 'dsrpt_decimated' if DATA_ROOT else None
NSUB_RAW = 781_250   # window in raw samples (~781 ms at 1 MHz)
NRECEPT_RAW = 300_000  # receptive field in raw (~300 ms)
DATA_STEP = 10       # decimation: 1 MHz → 100 kHz
NSUB = NSUB_RAW // DATA_STEP   # window in decimated space
NRECEPT = NRECEPT_RAW // DATA_STEP

print(f'Shot list: {SHOT_LIST}')
print(f'Flattop only: {FLATTOP_ONLY}')
print(f'nsub (decimated): {NSUB}, nrecept: {NRECEPT}')

---
## ⏱ Times, flattop, and CLEAR vs DISRUPTION (read this first)

All times in the shot list and in meta.csv are in **milliseconds (ms)**.

### What is “flattop”?
**Flattop** is the period of the discharge when the plasma current is **steady** (flat). In the shot list:
- **t_flat_start** [ms] = start of flattop
- **t_flat_last** [ms] = duration of flattop  
- So flattop runs from **t_flat_start** to **t_flat_start + t_flat_last**.  
Shots with `nan` for t_flat_* have no flattop. When **flattop_only = True**, we use **only this window** as the segment and drop shots without flattop.

### Which time range is used (the segment)?
- **Segment start** [ms]: flattop_only → **t_flat_start**; else 0 (relative to tstart).
- **Segment end** [ms]: **tend** = max(**tdisrupt**, min(tlast, t_flat_stop)) with t_flat_stop = t_flat_start + t_flat_last.  
So the segment is from segment start to tend (capped by the data).
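
As a worked example of the rule above, the following minimal sketch computes the segment bounds in ms for one shot. The helper name `segment_bounds_ms` and the numbers are illustrative assumptions, not part of `disruptcnn` and not taken from the shot list. It only covers shots with finite flattop values.

```python
import math

def segment_bounds_ms(tlast, t_flat_start, t_flat_last, tdisrupt, flattop_only=True):
    """Return (segment_start, segment_end) in ms, or None when flattop is missing.

    Hypothetical helper; it only covers shots whose flattop values are finite.
    """
    if math.isnan(t_flat_start) or math.isnan(t_flat_last):
        return None  # such shots are dropped when flattop_only=True
    t_flat_stop = t_flat_start + t_flat_last
    tend = max(tdisrupt, min(tlast, t_flat_stop))
    seg_start = t_flat_start if flattop_only else 0.0
    return seg_start, tend

# Made-up shot: flattop 1000..4000 ms, data until 5000 ms, disruption at 4200 ms.
print(segment_bounds_ms(5000.0, 1000.0, 3000.0, 4200.0))
print(segment_bounds_ms(5000.0, 1000.0, 3000.0, 4200.0, flattop_only=False))
print(segment_bounds_ms(5000.0, float("nan"), float("nan"), 4200.0))
```

In this made-up case the disruption at 4200 ms falls after the flattop end at 4000 ms, so the `max` extends the segment end to 4200 ms.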

### CLEAR vs DISRUPTION — which portion is which?
- **Twarn = 300 ms** (warning window before disruption).
- **CLEAR (label 0)** = from **segment start** up to **tdisrupt − 300 ms**.  
  → The “pre‑warning” part of the segment. No disruption yet.
- **DISRUPTION (label 1)** = from **tdisrupt − 300 ms** to **segment end**.  
  → The last 300 ms before disruption plus the disruption; we treat this as “disruptive.”

So in each window (subsequence): timesteps before **(tdisrupt − Twarn)** are **clear**; timesteps at and after **(tdisrupt − Twarn)** are **disruption**.
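
The per-timestep rule can be sketched in a few lines. The function `window_labels` and its arguments are hypothetical names chosen for illustration, and the times are made up. The real dataset works on sample indices rather than on a list of timestamps.

```python
def window_labels(t0_ms, n_samples, dt_ms, tdisrupt_ms, twarn_ms=300.0):
    """Label each timestep of a window: 0 = clear, 1 = disruption.

    t0_ms is the time of the first sample; dt_ms is the sampling period.
    """
    t_switch = tdisrupt_ms - twarn_ms  # first instant that counts as disruptive
    return [1 if t0_ms + i * dt_ms >= t_switch else 0 for i in range(n_samples)]

# Made-up window of 5 samples at 1 ms spacing, straddling tdisrupt - Twarn = 3900 ms.
labels = window_labels(t0_ms=3898.0, n_samples=5, dt_ms=1.0, tdisrupt_ms=4200.0)
print(labels)  # [0, 0, 1, 1, 1]
```

The sample exactly at tdisrupt − Twarn is already labeled 1, matching the "at and after" wording above.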

In [None]:
# Visual timeline for one shot: segment, flattop, CLEAR vs DISRUPTION
TWARN_MS = 300.0
segments = segment_info_for_comparison(str(SHOT_LIST), FLATTOP_ONLY)
if segments:
    s = segments[0]
    tstart, tlast = s['tstart'], s['tlast']
    t_flat_start, t_flat_last = s['t_flat_start'], s['t_flat_last']
    tdisrupt = s['tdisrupt']
    t_flat_stop = t_flat_start + t_flat_last
    tend = max(tdisrupt, min(tlast, t_flat_stop))
    seg_start_ms = t_flat_start if FLATTOP_ONLY else 0
    t_clear_end = tdisrupt - TWARN_MS  # CLEAR = [seg_start_ms, t_clear_end), DISRUPT = [t_clear_end, tend]

    fig, ax = plt.subplots(1, 1, figsize=(12, 2.5))
    xlo, xhi = seg_start_ms - 200, tend + 100
    ax.set_xlim(xlo, xhi)
    ax.set_ylim(0, 3)
    # Segment (full used range)
    ax.axvspan(seg_start_ms, tend, ymin=0.5, ymax=0.9, color='lightgray', alpha=0.8, label='Segment (used)')
    # Flattop
    ax.axvspan(t_flat_start, t_flat_stop, ymin=0.55, ymax=0.85, color='blue', alpha=0.4, label='Flattop')
    # CLEAR (label 0)
    ax.axvspan(seg_start_ms, t_clear_end, ymin=0.1, ymax=0.45, color='green', alpha=0.7, label='CLEAR (label 0)')
    # DISRUPTION (label 1)
    ax.axvspan(t_clear_end, tend, ymin=0.1, ymax=0.45, color='red', alpha=0.7, label='DISRUPTION (label 1)')
    ax.axvline(tdisrupt, color='black', ls='--', linewidth=2, label='tdisrupt')
    ax.axvline(t_clear_end, color='orange', ls=':', linewidth=1.5, label='tdisrupt − 300 ms')
    ax.set_yticks([0.275, 0.7])
    ax.set_yticklabels(['CLEAR vs DISRUPT\n(labels)', 'Segment / Flattop'])
    ax.set_xlabel('Time (ms)')
    ax.set_title(f"Shot {s['shot']}: which portion is CLEAR vs DISRUPTION")
    ax.legend(loc='upper right', fontsize=8)
    plt.tight_layout()
    plt.show()
else:
    print('No segments (e.g. flattop_only=True and no flattop in list).')

---
### Sampled visualization: 10 random shots — clear vs disrupt on the signal

Below we **sample 10 random shots** and plot each segment with **CLEAR** (label 0) and **DISRUPTION** (label 1) portions marked on the actual ECEi signal (when H5 is available). Use this to confirm that subsequence/label generation matches the segment bounds.

In [None]:
# Sample 10 random shots; plot segment signal with CLEAR vs DISRUPTION portions
import h5py

TWARN_MS = 300.0
rng = np.random.default_rng(42)
segments = segment_info_for_comparison(str(SHOT_LIST), FLATTOP_ONLY)
n_sample = min(10, len(segments))
if n_sample == 0:
    print('No segments to sample.')
else:
    idx = rng.choice(len(segments), size=n_sample, replace=False)
    sampled = [segments[i] for i in idx]

    # Resolve root for H5: decimated (flat) or raw (disrupt subdir)
    def path_for_shot(shot):
        if DECIMATED_ROOT and (DECIMATED_ROOT / f'{shot}.h5').exists():
            return DECIMATED_ROOT / f'{shot}.h5', 'decimated'
        if DATA_ROOT and (Path(DATA_ROOT) / 'disrupt' / f'{shot}.h5').exists():
            return Path(DATA_ROOT) / 'disrupt' / f'{shot}.h5', 'raw'
        if DATA_ROOT and (Path(DATA_ROOT) / 'dsrpt' / f'{shot}.h5').exists():
            return Path(DATA_ROOT) / 'dsrpt' / f'{shot}.h5', 'raw'
        return None, None

    fig, axes = plt.subplots(2, 5, figsize=(16, 6))
    axes = axes.flatten()
    for k, s in enumerate(sampled):
        ax = axes[k]
        shot = s['shot']
        start_idx = s['start_idx']
        stop_idx = s['stop_idx']
        disrupt_idx = s['disrupt_idx']
        dt = s['dt']
        h5_path, mode = path_for_shot(shot)
        if h5_path is not None and disrupt_idx >= 0:
            with h5py.File(h5_path, 'r') as f:
                LFS = f['LFS'][...]
            if mode == 'decimated':
                step = DATA_STEP
                s0, s1 = start_idx // step, stop_idx // step
                s1 = min(s1, LFS.shape[-1])
                sig = LFS[:, :, s0:s1] if LFS.ndim == 3 else LFS[:, s0:s1]
                disrupt_local = (disrupt_idx - start_idx) // step
            else:
                s0, s1 = start_idx, min(stop_idx + 1, LFS.shape[-1])
                sig = LFS[:, :, s0:s1] if LFS.ndim == 3 else LFS[:, s0:s1]
                disrupt_local = disrupt_idx - start_idx
            if sig.ndim == 3:
                sig = sig.reshape(-1, sig.shape[-1]).mean(axis=0)
            else:
                sig = sig.mean(axis=0)
            T = sig.shape[-1]
            disrupt_local = np.clip(disrupt_local, 0, T)
            # Time axis in ms: one decimated sample = DATA_STEP * dt ms; raw = dt ms
            ms_per_samp = (DATA_STEP if mode == 'decimated' else 1) * float(dt)
            t_axis = np.arange(T) * ms_per_samp
            ax.plot(t_axis, sig, color='black', linewidth=0.5, alpha=0.9)
            t_cut = t_axis[disrupt_local] if disrupt_local < len(t_axis) else t_axis[-1]
            ax.axvspan(0, t_cut, alpha=0.25, color='green', label='CLEAR')
            ax.axvspan(t_cut, t_axis[-1], alpha=0.25, color='red', label='DISRUPT')
            ax.axvline(t_cut, color='orange', ls='--', linewidth=1)
            ax.set_xlabel('Time (ms)')
        else:
            # No H5 or clear shot: timeline only
            seg_start_ms = s['t_flat_start'] if FLATTOP_ONLY else 0
            t_flat_stop = s['t_flat_start'] + s['t_flat_last']
            tend = max(s['tdisrupt'], min(s['tlast'], t_flat_stop))
            t_clear_end = s['tdisrupt'] - TWARN_MS
            ax.axvspan(seg_start_ms, t_clear_end, alpha=0.5, color='green', label='CLEAR')
            ax.axvspan(t_clear_end, tend, alpha=0.5, color='red', label='DISRUPT')
            ax.axvline(s['tdisrupt'], color='black', ls='--')
            ax.axvline(t_clear_end, color='orange', ls=':')
            ax.set_xlabel('Time (ms)')
        ax.set_title(f"Shot {shot}")
        ax.set_ylabel('LFS (mean)' if (h5_path is not None and disrupt_idx >= 0) else 'Segment')
        if k == 0:
            ax.legend(loc='upper right', fontsize=7)
    for k in range(n_sample, len(axes)):
        axes[k].set_visible(False)
    plt.suptitle(f'{n_sample} random shots: CLEAR (green) vs DISRUPTION (red) portions', fontsize=12)
    plt.tight_layout()
    plt.show()

---
## 1. Shot list (raw)

Columns: **Shot**, # segments, **tstart**, **tlast**, **dt**, SNR min, **t_flat_start**, **t_flat_last**, **tdisrupt**.

- Time columns are in **ms** (tstart, tlast, t_flat_*, tdisrupt).
- **dt** = sampling period in ms (e.g. 0.001 ms per sample, i.e. a 1 MHz sampling rate, which matches the raw data).
- **Flattop**: t_flat_start … t_flat_start + t_flat_last. Some rows have `nan` for t_flat_* (no flattop).
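
To make the unit of **dt** concrete, here is a small sketch that converts a sampling period in ms to a sampling rate in Hz. The function name `sampling_rate_hz` is an assumption for illustration; the factor 10 is the `DATA_STEP` from the config cell.

```python
def sampling_rate_hz(dt_ms):
    """Convert a sampling period in ms to a sampling rate in Hz."""
    return 1000.0 / dt_ms

raw_rate = sampling_rate_hz(0.001)             # raw data: about 1e6 Hz (1 MHz)
decimated_rate = sampling_rate_hz(0.001 * 10)  # after DATA_STEP = 10: about 1e5 Hz
print(f"{raw_rate:.0f} Hz raw, {decimated_rate:.0f} Hz decimated")
```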

In [None]:
# Load shot list as table
raw = np.loadtxt(SHOT_LIST, skiprows=1)
if raw.ndim == 1:
    raw = raw[np.newaxis, :]
cols = ['Shot', 'n_seg', 'tstart', 'tlast', 'dt', 'SNR_min', 't_flat_start', 't_flat_last', 'tdisrupt']
df = pd.DataFrame(raw, columns=cols)
df['Shot'] = df['Shot'].astype(int)
print(f'Total rows: {len(df)}')
print(df.dtypes)
df.head(10)

In [None]:
# Basic stats
print('Segment length (tlast - tstart) ms:', (df['tlast'] - df['tstart']).describe())
print('\ntdisrupt (ms):', df['tdisrupt'].describe())
print('\nFlattop: t_flat_start NaN count:', df['t_flat_start'].isna().sum())
print('Flattop: t_flat_last NaN count:', df['t_flat_last'].isna().sum())
if FLATTOP_ONLY:
    kept = df[df['t_flat_start'].notna()]
    print(f'\nWith flattop_only=True: {len(kept)} shots kept (drop {len(df) - len(kept)} with NaN flattop)')

---
## 2. Segment logic (per shot)

For each shot we define a **segment** [start_idx, stop_idx] in **sample space** (index 0 corresponds to tstart; each sample spans dt ms, so there are 1/dt samples per ms).

- **Flattop-only:** segment start = t_flat_start (relative to tstart), segment end = **tend** = max(tdisrupt, min(tlast, t_flat_stop)) with t_flat_stop = t_flat_start + t_flat_last.
- **Twarn = 300 ms:** label as disruptive from sample **disrupt_idx** = ceil((tdisrupt - 300 - tstart) / dt) to end of segment.

`segment_info_for_comparison()` returns these bounds (in raw sample indices).
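
The ms-to-index conversion for **disrupt_idx** can be sketched as follows. `disrupt_index` is a hypothetical helper that restates the formula above, and the shot values are made up. It is not the code inside `segment_info_for_comparison()`.

```python
import math

TWARN_MS = 300.0

def disrupt_index(tdisrupt_ms, tstart_ms, dt_ms, twarn_ms=TWARN_MS):
    """Raw sample index at which the label switches from 0 to 1."""
    return math.ceil((tdisrupt_ms - twarn_ms - tstart_ms) / dt_ms)

# Made-up shot: recording starts at -50 ms, disruption at 4200 ms, 1 MHz sampling.
idx_raw = disrupt_index(tdisrupt_ms=4200.0, tstart_ms=-50.0, dt_ms=0.001)
idx_decimated = idx_raw // 10  # same boundary after decimating by DATA_STEP = 10
print(idx_raw, idx_decimated)
```

With these numbers the boundary sits roughly 3.95 million raw samples after tstart. Because the division is done in floating point, a result that lands a hair above an integer can be rounded up by one sample.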

In [None]:
segments = segment_info_for_comparison(str(SHOT_LIST), FLATTOP_ONLY)
print(f'Segments (after flattop filter): {len(segments)}')
# Show first 5 in a table
pd.DataFrame(segments[:5])

---
## 3. Subsequence tiling (shot → windows)

Each segment is split into **overlapping windows** of length **nsub** (decimated). Consecutive windows start **nsub − nrecept + 1** samples apart, so neighbors overlap by **nrecept − 1** samples.

- If the segment is **shorter than nsub** (common with decimated H5), the dataset adds **one window** from the tail: [file_len - nsub, file_len].
- Logic: `subsequences_original_tiling(segment, nsub=NSUB, nrecept=NRECEPT, data_step=1)` (in decimated space we use step=1).

Below: for every segment we compare how much of it is clear (before the Twarn boundary) with how much is labeled disruptive (from the boundary to the segment end).
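
The tiling idea can be illustrated with a minimal sketch. This is not the body of `subsequences_original_tiling`: the helper `tile_windows` is hypothetical, it assumes window starts advance by nsub − nrecept + 1 samples, it clamps the tail window at index 0, and it leaves out any special handling of a remainder at the end of a long segment.

```python
def tile_windows(seg_start, seg_stop, file_len, nsub, nrecept):
    """Return half-open (start, stop) windows in decimated samples."""
    stride = nsub - nrecept + 1  # assumed step between consecutive window starts
    if seg_stop - seg_start < nsub:
        # Tail fallback: a single window taken from the end of the file.
        return [(max(0, file_len - nsub), file_len)]
    return [(s, s + nsub) for s in range(seg_start, seg_stop - nsub + 1, stride)]

# Toy sizes: nsub=10, nrecept=4 gives a stride of 7 and an overlap of 3 samples.
long_windows = tile_windows(0, 30, file_len=30, nsub=10, nrecept=4)
short_windows = tile_windows(0, 6, file_len=8, nsub=10, nrecept=4)
print(long_windows)   # [(0, 10), (7, 17), (14, 24)]
print(short_windows)  # [(0, 8)]
```

With the notebook's own sizes (`NSUB = 78125`, `NRECEPT = 30000`) the same assumption gives a stride of 48126 decimated samples.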

In [None]:
# How much of each segment is "clear" (before Twarn) vs "disrupt" (after)?
# disrupt_idx is the sample where label becomes 1; start_idx is segment start.
seg_df = pd.DataFrame(segments)
seg_df['clear_samples'] = np.maximum(0, seg_df['disrupt_idx'] - seg_df['start_idx'])
seg_df['disrupt_samples'] = np.maximum(0, seg_df['stop_idx'] - seg_df['disrupt_idx'])
seg_df['clear_ms'] = seg_df['clear_samples'] * seg_df['dt']
seg_df['disrupt_ms'] = seg_df['disrupt_samples'] * seg_df['dt']
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.scatter(seg_df['clear_ms'], seg_df['disrupt_ms'], alpha=0.3, s=10)
ax.set_xlabel('Clear (pre-Twarn) length (ms)')
ax.set_ylabel('Disrupt (Twarn to end) length (ms)')
ax.set_title('Per-shot: clear vs disrupt extent in segment')
ax.axhline(seg_df['disrupt_ms'].median(), color='red', ls='--', alpha=0.7, label='median disrupt')
ax.axvline(seg_df['clear_ms'].median(), color='blue', ls='--', alpha=0.7, label='median clear')
ax.legend()
plt.tight_layout()
plt.show()

---
## 4. Build full dataset (optional)

If **DECIMATED_ROOT** exists and contains H5 files, we build **EceiDatasetOriginal** and run **train_val_test_split** (the build cell comes after the meta.csv aside in section 2b). This shows:

- How many shots are **dropped** (no H5 file).
- Total **sequences** and per-split counts.
- **Tail fallback**: shots with segment shorter than nsub get one window each.
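
Before building the dataset it can help to check which shots would be dropped for lack of a file. The sketch below assumes the flat `<shot>.h5` layout used for the decimated directory earlier in this notebook. `split_by_file_presence` is a hypothetical helper and the shot numbers are made up.

```python
from pathlib import Path
import tempfile

def split_by_file_presence(shots, root):
    """Return (kept, dropped) shot lists based on whether root/<shot>.h5 exists."""
    root = Path(root)
    kept = [s for s in shots if (root / f"{s}.h5").exists()]
    kept_set = set(kept)
    dropped = [s for s in shots if s not in kept_set]
    return kept, dropped

# Demo on a temporary directory that holds a file for only one of two shots.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "100001.h5").touch()
    kept, dropped = split_by_file_presence([100001, 100002], tmp)
print(kept, dropped)  # [100001] [100002]
```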

---
## 2b. How **meta.csv** is used for decimated samples

The **original** pipeline above uses the **shot list** (`.txt`) and **EceiDatasetOriginal**.  
A **second** pipeline uses **meta.csv** with **ECEiTCNDataset** when loading from a **decimated** directory (e.g. `dsrpt_decimated/`).

### meta.csv columns (all times in **ms**)
| Column | Meaning |
|--------|--------|
| **shot** | Shot ID (integer). |
| **split** | `train` / `val` / `test` (which split the shot belongs to). |
| **t_disruption** | Time of disruption [ms]. Same role as **tdisrupt** in the shot list. |
| **t_last** or **t_segment_end** (optional) | End of the segment [ms]. If present, the segment ends here instead of at file end. |

### How the times are used (same CLEAR vs DISRUPTION rule)
- **Twarn = 300 ms** (same as above).
- **CLEAR (label 0)** = from segment start up to **t_disruption − 300 ms**.
- **DISRUPTION (label 1)** = from **t_disruption − 300 ms** to segment end.
- Segment start is typically the first sample after baseline (e.g. offset removal). Segment end = **t_last** or **t_segment_end** from meta.csv if present, else end of the H5 file.

So for decimated samples that use **meta.csv**, the labeling rule is unchanged; only the **source** of the times differs (meta.csv instead of the `.txt` shot list). **Flattop** is a shot-list concept (t_flat_start, t_flat_last) and has no column in meta.csv, so that pipeline often uses the full file or a segment-end column. When both sources exist, the shot list drives **EceiDatasetOriginal** and meta.csv drives **ECEiTCNDataset** with decimated roots.
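
A minimal sketch of how one meta.csv row maps to label bounds, using only the columns described above. The rows are made up, `label_bounds` is a hypothetical helper, and `file_end_ms` stands in for the end time of the H5 file. This is not the actual `ECEiTCNDataset` code.

```python
import csv
import io

TWARN_MS = 300.0

# Made-up meta.csv content; the second row has no t_last value.
SAMPLE = """shot,split,t_disruption,t_last
100001,train,4200.0,4250.0
100002,val,3100.0,
"""

def label_bounds(row, file_end_ms):
    """Return the CLEAR end and segment end (both in ms) for one meta.csv row."""
    t_dis = float(row["t_disruption"])
    t_end = row.get("t_last") or row.get("t_segment_end")
    seg_end = float(t_end) if t_end else file_end_ms  # fall back to end of file
    return {"shot": int(row["shot"]),
            "clear_end_ms": t_dis - TWARN_MS,
            "segment_end_ms": seg_end}

rows = [label_bounds(r, file_end_ms=5000.0) for r in csv.DictReader(io.StringIO(SAMPLE))]
for r in rows:
    print(r)
```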

In [None]:
# If decimated dir has meta.csv, show how it's used (same CLEAR / DISRUPTION rule)
META_PATH = DECIMATED_ROOT / 'meta.csv' if DECIMATED_ROOT else None
if META_PATH and META_PATH.exists():
    meta = pd.read_csv(META_PATH)
    print('meta.csv columns:', list(meta.columns))
    print('Total shots:', len(meta))
    if 't_disruption' in meta.columns:
        t_dis = meta['t_disruption'].dropna()
        t_clear_end = t_dis - 300  # CLEAR up to this [ms], DISRUPT after
        print('\nTimes [ms]: t_disruption (min, max):', t_dis.min(), t_dis.max())
        print('CLEAR portion ends at t_disruption - 300 ms; DISRUPTION from there to segment end.')
    print(meta.head(10))
else:
    print('No meta.csv at DECIMATED_ROOT; skipping. (Original pipeline uses shot list only.)')

In [None]:
dataset = None
if DECIMATED_ROOT and DECIMATED_ROOT.exists():
    norm_path = Path('norm_stats.npz')
    if not norm_path.exists():
        norm_path = None  # dataset will look in root/decimated_root
    dataset = EceiDatasetOriginal(
        root=str(DATA_ROOT) if DATA_ROOT else str(SHOT_LIST.parent.parent),
        disrupt_file=str(SHOT_LIST),
        clear_file=None,
        flattop_only=FLATTOP_ONLY,
        normalize=True,
        data_step=DATA_STEP,
        nsub=NSUB_RAW,
        nrecept=NRECEPT_RAW,
        decimated_root=str(DECIMATED_ROOT),
        norm_stats_path=str(norm_path) if norm_path and norm_path.exists() else None,
    )
    dataset.train_val_test_split()
    print(f'Shots in dataset: {len(dataset.shot)}')
    print(f'Total sequences: {len(dataset.shot_idxi)}')
    print(f'Train: {len(dataset.train_inds)}, Val: {len(dataset.val_inds)}, Test: {len(dataset.test_inds)}')
    print(f'Sequences with disrupt: {dataset.disruptedi.sum()}, without: {(~dataset.disruptedi).sum()}')
else:
    print('DECIMATED_ROOT not set or missing; skipping dataset build.')

---
## 5. Sequence-level exploration

Distributions: **sequences per shot**, **has_disrupt** fraction, **sequence length** (stop_idxi - start_idxi).

In [None]:
if dataset is not None:
    seq_len = dataset.stop_idxi - dataset.start_idxi
    per_shot = pd.Series(dataset.shot_idxi).value_counts().sort_index()
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes[0,0].hist(per_shot.values, bins=min(50, per_shot.max()), edgecolor='k', alpha=0.7)
    axes[0,0].set_xlabel('Sequences per shot')
    axes[0,0].set_ylabel('Number of shots')
    axes[0,0].set_title('Sequences per shot')
    axes[0,1].hist(seq_len, bins=50, edgecolor='k', alpha=0.7)
    axes[0,1].axvline(NSUB, color='red', ls='--', label=f'nsub={NSUB}')
    axes[0,1].set_xlabel('Sequence length (decimated samples)')
    axes[0,1].set_ylabel('Count')
    axes[0,1].set_title('Sequence length (tail windows < nsub)')
    axes[0,1].legend()
    # Train/val/test split composition
    split_names = ['train', 'val', 'test']
    split_inds = [dataset.train_inds, dataset.val_inds, dataset.test_inds]
    counts = [len(inds) for inds in split_inds]
    n_disrupt = [dataset.disruptedi[inds].sum() for inds in split_inds]
    x = np.arange(len(split_names))
    axes[1,0].bar(x, counts, color='steelblue')
    axes[1,0].set_xticks(x)
    axes[1,0].set_xticklabels(split_names)
    for i, (c, d) in enumerate(zip(counts, n_disrupt)):
        axes[1,0].text(i, c, f'{int(d)} disrupt', ha='center', fontsize=9)
    axes[1,0].set_ylabel('Number of sequences')
    axes[1,0].set_title('Split sizes')
    # Label balance in train
    train_dis = dataset.disruptedi[dataset.train_inds].sum()
    axes[1,1].pie([train_dis, len(dataset.train_inds) - train_dis], labels=['has_disrupt', 'no_disrupt'], autopct='%1.1f%%')
    axes[1,1].set_title('Train: disruptive vs clear windows')
    plt.tight_layout()
    plt.show()
else:
    print('Build dataset first (section 4).')

---
## 6. Example window: signal + labels

Load one sequence from the dataset and plot a few LFS channels and the per-timestep label (0 = clear, 1 = disruptive after Twarn).

In [None]:
if dataset is not None:
    idx = dataset.train_inds[0]  # first training sequence
    X, target, _, weight = dataset[idx]
    X = X.numpy() if hasattr(X, 'numpy') else np.asarray(X)
    target = target.numpy() if hasattr(target, 'numpy') else np.asarray(target)
    # LFS is (20, 8, T) or (160, T); flatten channel dim for 1D indexing
    if X.ndim == 3:
        X_flat = X.reshape(-1, X.shape[-1])  # (160, T)
    else:
        X_flat = X
    T = X_flat.shape[-1]
    fig, axes = plt.subplots(3, 1, figsize=(14, 7), sharex=True)
    ch = [0, X_flat.shape[0] // 2]  # first and middle channel; the third axis shows the labels
    for ax, c in zip(axes[:2], ch):
        ax.plot(np.arange(T), X_flat[c], color='black', alpha=0.8)
        ax.set_ylabel(f'Ch {c}')
    axes[2].fill_between(np.arange(T), 0, target, alpha=0.5, color='red', label='target (1=disrupt)')
    axes[2].set_ylabel('Label')
    axes[2].set_xlabel('Time (decimated samples)')
    axes[2].legend()
    axes[0].set_title(f'Shot index {dataset.shot_idxi[idx]}, seq len={T} (disrupt in window: {dataset.disruptedi[idx]})')
    plt.tight_layout()
    plt.show()
else:
    print('Build dataset first (section 4).')

---
## 7. Summary

- **Times**: All in **ms**. **Flattop** = t_flat_start to t_flat_start + t_flat_last (steady current). **CLEAR (label 0)** = segment start → **tdisrupt − 300 ms**. **DISRUPTION (label 1)** = **tdisrupt − 300 ms** → segment end.
- **Shot list** → segment bounds (flattop or full) and **disrupt_idx** (Twarn = 300 ms before tdisrupt). **meta.csv** (for decimated ECEiTCNDataset): shot, split, t_disruption [ms]; optional t_last / t_segment_end; same CLEAR vs DISRUPTION rule.
- **Tiling**: overlapping windows of length **nsub**; if a segment is shorter than nsub, one **tail window** is added for that shot.
- **Label balance**: with a disrupt-only shot list every shot ends in a disruption, so the "clear" windows all come from the early part of disruptive shots (before Twarn). Check the train split balance (section 5) and the class weights if performance is limited.
- **Recommendations**: (1) Inspect segment_length vs nsub — many tail-only windows can dominate. (2) Try including clear shots (non-disruptive) if available. (3) Tune Twarn or exclude_last_ms if needed.
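
When checking class weights, one common choice is inverse-frequency weighting. The sketch below is an illustration only: the counts are made up, `inverse_frequency_weights` is a hypothetical helper, and the training script may weight classes differently.

```python
def inverse_frequency_weights(n_clear, n_disrupt):
    """Return (w_clear, w_disrupt) so that both classes carry equal total weight."""
    total = n_clear + n_disrupt
    return total / (2.0 * n_clear), total / (2.0 * n_disrupt)

# Made-up label counts with a 9:1 imbalance toward clear.
w_clear, w_disrupt = inverse_frequency_weights(n_clear=900, n_disrupt=100)
print(w_clear, w_disrupt)
```

With these weights, count times weight is the same for both classes (500 each here), so the rarer disruptive label is not swamped in the loss.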