# Preprocessing & DataLoader Visualisation

Walk through every preprocessing step of `ECEiTCNDataset` and inspect the
resulting tensors, labels, and class balance.

In [None]:
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path

from dataset_ecei_tcn import (
    ECEiTCNDataset, create_loaders,
    read_raw_shot, remove_offset, normalize_array, decimate,
)

# Data directory: the cells below read 'meta.csv' and per-shot '<shot>.h5'
# files from here.
ROOT = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt'
# Sampling rate in Hz; used below to convert sample counts to milliseconds.
FS   = 1_000_000   # 1 MHz

## 1. Metadata overview

In [None]:
# Read the per-shot metadata table and print a short overview of it.
meta_csv = Path(ROOT) / 'meta.csv'
meta = pd.read_csv(meta_csv)

print(f'Columns: {list(meta.columns)}')
print(f'Total shots: {len(meta)}')
# Number of shots in each split.
shots_per_split = meta.groupby('split').size()
print(shots_per_split)

# Last expression of the cell: the notebook renders the first ten rows.
meta.head(10)

In [None]:
# Two-panel overview of meta.csv:
#   left  — histogram of t_disruption (ms), one overlaid series per split
#   right — number of shots in each split
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# shot duration distribution
dur_ms = meta['t_disruption']  # already in ms
for sp in meta['split'].unique():
    mask = meta['split'] == sp
    # The diagnosis cells further down treat NaN t_disruption as a real
    # possibility (non-disruptive shots), so drop NaNs explicitly here, the
    # same way the composition histogram does, instead of relying on how the
    # installed matplotlib/numpy versions treat NaN when binning.
    axes[0].hist(dur_ms[mask].dropna(), bins=30, alpha=0.6, label=sp)
axes[0].set_xlabel('t_disruption (ms)')
axes[0].set_ylabel('Count')
axes[0].set_title('Shot duration distribution')
axes[0].legend()

# shots per split
counts = meta.groupby('split').size()
axes[1].bar(counts.index, counts.values, color=['steelblue', 'firebrick'][:len(counts)])
axes[1].set_ylabel('# Shots')
axes[1].set_title('Shots per split')
# Write the count just above each bar.
for i, v in enumerate(counts.values):
    axes[1].text(i, v + 1, str(v), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

## 1b. Data composition diagnosis

**Critical check**: Does the dataset contain non-disruptive (clean) shots?

Churchill et al. (2019) trained on **2,747 shots (42% disruptive, 58% non-disruptive)**.
If our dataset only contains disruptive shots, the model never sees truly healthy
plasma — all "clear" labels come from the early parts of shots that eventually disrupt.
This makes the classification task fundamentally harder and limits the achievable F1.

In [None]:
# ═══════════════════════════════════════════════════════════════════════
#  DATA COMPOSITION DIAGNOSIS
# ═══════════════════════════════════════════════════════════════════════
# Prints a text report on whether meta.csv contains non-disruptive shots,
# based on (1) an explicit label column if one exists, (2) the t_disruption
# values, and (3) a spot check of h5 lengths against t_disruption.
# The counters defined here (t_dis, n_nan, ...) are reused by the plotting
# cell that follows.

print('=' * 70)
print('  DATA COMPOSITION DIAGNOSIS')
print('=' * 70)

# ── 1. Check for a disruption-status column ──────────────────────────
# Remember WHICH candidate matched, so the summary in step 4 counts from the
# same column that is reported here.
label_col = None
for col in ['is_disruptive', 'disruptive', 'label', 'disrupted', 'type']:
    if col in meta.columns:
        label_col = col
        print(f'\n  Found label column: "{col}"')
        print(f'  Value counts:')
        print(meta[col].value_counts().to_string().replace('\n', '\n    '))
        break
has_label_col = label_col is not None

if not has_label_col:
    print('\n  No explicit disruption-status column found.')
    print(f'  Columns available: {list(meta.columns)}')

# ── 2. Analyse t_disruption values ──────────────────────────────────
print(f'\n  t_disruption statistics (ms):')
t_dis = meta['t_disruption']
print(f'    count   = {t_dis.count()} / {len(meta)} (non-NaN)')
print(f'    NaN     = {t_dis.isna().sum()}')
print(f'    min     = {t_dis.min():.1f} ms')
print(f'    max     = {t_dis.max():.1f} ms')
print(f'    median  = {t_dis.median():.1f} ms')

n_nan  = int(t_dis.isna().sum())
n_inf  = int(np.isinf(t_dis.values.astype(float)).sum()) if n_nan == 0 else 0
n_zero = int((t_dis == 0).sum())
n_neg  = int((t_dis < 0).sum())
n_valid = len(meta) - n_nan

# Heuristic: shots with t_disruption=NaN, 0, negative, or very large (>20s)
# are likely non-disruptive
n_suspect_nondisrupt = n_nan + n_zero + n_neg
very_long_threshold = 20_000  # 20 seconds — unusually long for a disruptive shot
n_very_long = int((t_dis > very_long_threshold).sum())

print(f'\n  ── Classification heuristic ──')
print(f'    NaN t_disruption (likely non-disruptive) : {n_nan}')
print(f'    Zero t_disruption                        : {n_zero}')
print(f'    Negative t_disruption                    : {n_neg}')
print(f'    t_disruption > {very_long_threshold/1000:.0f}s (very long)          : {n_very_long}')
print(f'    Remaining (clearly disruptive)            : {n_valid - n_zero - n_neg - n_very_long}')

# ── 3. Check actual h5 file lengths vs t_disruption ─────────────────
print(f'\n  ── Spot-checking shot durations vs t_disruption ──')
# At most 10 shots, sampled with a fixed seed so the report is reproducible.
n_check = min(10, len(meta))
sample_shots = meta.sample(n_check, random_state=42) if len(meta) > n_check else meta

for _, row in sample_shots.iterrows():
    shot = int(row['shot'])
    t_dis_ms = row['t_disruption']
    h5_path = Path(ROOT) / f'{shot}.h5'
    if h5_path.exists():
        # Only the dataset shape is read; the signal itself is not loaded.
        with h5py.File(h5_path, 'r') as f:
            T_total = f['LFS'].shape[-1]
        dur_ms = T_total / FS * 1000
        # If the shot is much longer than t_disruption, data extends past disruption
        # If t_disruption ≈ shot length, it disrupted near the end (typical)
        ratio = t_dis_ms / dur_ms * 100 if dur_ms > 0 else 0
        # A NaN ratio fails the comparison and is therefore flagged as well.
        flag = '' if 50 < ratio < 105 else '  ⚠️'
        print(f'    shot {shot:>8d}: length={dur_ms:>8.1f} ms, '
              f't_dis={t_dis_ms:>8.1f} ms, '
              f'ratio={ratio:>5.1f}%{flag}')
    else:
        print(f'    shot {shot:>8d}: h5 file not found')

# ── 4. Summary verdict ──────────────────────────────────────────────
print()
print('  ' + '─' * 66)
all_disruptive = (n_nan == 0 and n_zero == 0 and n_neg == 0)

if all_disruptive and not has_label_col:
    pct_dis = 100.0
    print(f'  ⚠️  ALL {len(meta)} shots appear DISRUPTIVE (t_disruption is valid for all)')
    print(f'  ⚠️  No non-disruptive shots detected.')
    print(f'  ⚠️  Churchill et al. used 42% disruptive / 58% non-disruptive.')
    print(f'  ⚠️  This limits the "clear" class to early regions of disruptive shots,')
    print(f'  ⚠️  making the classification task fundamentally harder.')
    print(f'  ⚠️  RECOMMENDATION: Add non-disruptive shots to improve performance.')
elif n_nan > 0 or has_label_col:
    if has_label_col:
        # Count from the column found in step 1 (previously only three of the
        # five candidate names were looked up here, so a 'disrupted' column
        # was detected but then ignored). Rows equal to 0 count as
        # non-disruptive; a non-numeric column therefore yields 0.
        n_nondis = int((meta[label_col] == 0).sum())
    else:
        n_nondis = n_nan
    n_dis_shots = len(meta) - n_nondis if n_nondis > 0 else n_valid
    # Guard the percentage against an empty metadata table.
    pct_dis = n_dis_shots / len(meta) * 100 if len(meta) else 0.0
    pct_nondis = 100 - pct_dis
    print(f'  ✓  Dataset composition:')
    print(f'       Disruptive    : {n_dis_shots:>5d} ({pct_dis:.1f}%)')
    print(f'       Non-disruptive: {n_nondis:>5d} ({pct_nondis:.1f}%)')
    print(f'       (Churchill et al.: 42% / 58%)')
    if pct_nondis < 30:
        print(f'  ⚠️  Non-disruptive fraction is low ({pct_nondis:.0f}%). '
              f'Consider adding more clean shots.')
else:
    print(f'  ✓  {len(meta)} shots with valid t_disruption.')
    print(f'      Could not determine disruptive/non-disruptive split automatically.')

print('  ' + '─' * 66)
print('=' * 70)

In [None]:
# ── Visualise the composition ────────────────────────────────────────
# Three panels: (a) t_disruption histogram, (b) disruptive vs NaN-t_disruption
# shot counts per split, (c) overall pie chart. Relies on `t_dis` and `n_nan`
# from the diagnosis cell.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (a) t_disruption distribution — look for bimodal / NaN gaps
ax = axes[0]
valid_tdis = t_dis.dropna()
ax.hist(valid_tdis, bins=40, color='firebrick', alpha=0.6, edgecolor='k', linewidth=0.5)
if n_nan > 0:
    # Annotate how many shots are missing from the histogram.
    ax.axvline(0, color='gray', ls='--', lw=2)
    note_box = dict(boxstyle='round', fc='yellow', alpha=0.7)
    ax.text(0.05, 0.95, f'{n_nan} NaN\n(non-disruptive?)',
            transform=ax.transAxes, va='top', fontsize=10, bbox=note_box)
ax.set_title('Distribution of disruption times')
ax.set_xlabel('t_disruption (ms)')
ax.set_ylabel('# Shots')
ax.grid(True, alpha=0.2)

# (b) Per-split composition
ax = axes[1]
splits = meta['split'].unique()
x_pos = np.arange(len(splits))
width = 0.35

for i, sp in enumerate(splits):
    sp_tdis = meta.loc[meta['split'] == sp, 't_disruption']
    n_nan_sp = int(sp_tdis.isna().sum())
    n_dis_sp = len(sp_tdis) - n_nan_sp

    # Only the first pair of bars carries a legend label.
    first = (i == 0)
    left_x, right_x = i - width/2, i + width/2
    ax.bar(left_x, n_dis_sp, width, color='firebrick', alpha=0.7,
           label='Disruptive' if first else '')
    ax.bar(right_x, n_nan_sp, width, color='steelblue', alpha=0.7,
           label='Non-disruptive' if first else '')
    for x_txt, n_txt in ((left_x, n_dis_sp), (right_x, n_nan_sp)):
        ax.text(x_txt, n_txt + 0.5, str(n_txt), ha='center', fontsize=9)

ax.set_xticks(x_pos)
ax.set_xticklabels(splits)
ax.set_title('Disruptive vs Non-disruptive per split')
ax.set_ylabel('# Shots')
ax.legend()
ax.grid(True, alpha=0.2, axis='y')

# (c) Pie chart of overall composition
ax = axes[2]
n_dis_total = len(meta) - n_nan
labels_pie = [f'Disruptive\n({n_dis_total})', f'Non-disruptive\n({n_nan})']
sizes = [n_dis_total, max(n_nan, 0.001)]  # avoid zero-size
colors_pie = ['firebrick', 'steelblue']
explode = (0, 0.05)

if n_nan <= 0:
    # No NaN entries: a single solid wedge plus a warning caption.
    ax.pie([1], labels=[f'ALL Disruptive\n({len(meta)} shots)'],
           colors=['firebrick'], autopct='',
           startangle=90, textprops={'fontsize': 12, 'fontweight': 'bold'})
    ax.text(0, -0.15, '⚠️ No non-disruptive shots', ha='center', fontsize=11,
            color='darkorange', fontweight='bold')
else:
    ax.pie(sizes, explode=explode, labels=labels_pie, colors=colors_pie,
           autopct='%1.1f%%', startangle=90, textprops={'fontsize': 11})

ax.set_title('Overall dataset composition')

plt.suptitle('Data Composition Diagnosis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# ── Look for sibling directories that might contain non-disruptive data ──
# Lists every directory next to ROOT and marks whether it holds a meta.csv
# (with shot / NaN-t_disruption counts) and at least one .h5 file.
parent = Path(ROOT).parent
print(f'Scanning {parent} for related data directories...\n')

found_dirs = []
if parent.exists():
    for p in sorted(parent.iterdir()):
        if p.is_dir():
            # any() stops at the first match instead of listing the whole
            # directory just to learn whether one .h5 file exists.
            has_h5 = any(p.glob('*.h5'))
            has_meta = (p / 'meta.csv').exists()
            marker = ''
            if has_meta:
                marker += ' [meta.csv ✓]'
                try:
                    m = pd.read_csv(p / 'meta.csv')
                    n_shots = len(m)
                    n_nan_t = int(m['t_disruption'].isna().sum()) if 't_disruption' in m.columns else '?'
                    marker += f'  {n_shots} shots, {n_nan_t} with NaN t_disruption'
                except Exception as exc:
                    # Best-effort scan: keep going, but say that the file could
                    # not be read rather than hiding the failure.
                    marker += f'  (could not read meta.csv: {type(exc).__name__})'
            if has_h5:
                marker += ' [h5 ✓]'
            print(f'  {p.name:40s}{marker}')
            found_dirs.append(p)

if not found_dirs:
    print('  (no sibling directories found)')

print(f'\n  Current data root: {ROOT}')
print(f'  If a "nondsrpt" or "clean" directory exists above, it may contain')
print(f'  non-disruptive shots that should be merged into training.')

## 2. Raw signal inspection

Load one full shot and show the raw (20, 8) ECEi grid.

In [None]:
# Use the first row of meta.csv as the example shot for the rest of the notebook.
SHOT = meta['shot'].iloc[0]
T_DIS_MS = meta['t_disruption'].iloc[0]
# Convert ms -> samples with the sampling rate instead of a literal 1000, so
# the conversion follows FS (FS // 1000 == 1000 for the 1 MHz data used here).
# NOTE(review): int() raises if this shot's t_disruption is NaN.
T_DIS = int(T_DIS_MS * (FS // 1000))  # in samples

raw = read_raw_shot(ROOT, SHOT)
print(f'Shot {SHOT}: shape = {raw.shape}, t_disruption = {T_DIS_MS:.1f} ms')

In [None]:
# Plot a few channels across the full shot
n_samples = raw.shape[-1]
time_ms = np.arange(n_samples) / (FS / 1000)  # ms
channels = [(0, 0), (10, 4), (19, 7)]  # (row, col) in the 20x8 grid

# One stacked panel per selected channel, all sharing the time axis.
n_panels = len(channels)
fig, axes = plt.subplots(n_panels, 1, figsize=(16, 3 * n_panels), sharex=True)
for panel, (r, c) in zip(axes, channels):
    trace = raw[r, c, :]
    panel.plot(time_ms, trace, linewidth=0.3, color='k')
    # Red dashed line marks the disruption time from meta.csv.
    panel.axvline(T_DIS_MS, color='red', linestyle='--', linewidth=1.5, label='t_disruption')
    panel.set_ylabel(f'Ch ({r},{c})')
    panel.legend(loc='upper right', fontsize=9)
    panel.grid(True, alpha=0.2)

axes[0].set_title(f'Raw signal — shot {SHOT}')
axes[-1].set_xlabel('Time (ms)')
plt.tight_layout()
plt.show()

In [None]:
# Spatial snapshot at a single time instant (middle of the shot)
t_snap = raw.shape[-1] // 2
snapshot = raw[:, :, t_snap]          # all channels at sample index t_snap
snap_ms = t_snap / FS * 1e3           # same instant expressed in ms

fig, ax = plt.subplots(figsize=(6, 8))
im = ax.imshow(snapshot, aspect='auto', cmap='RdBu_r')
ax.set_title(f'Spatial snapshot at t = {snap_ms:.1f} ms')
ax.set_xlabel('Radial channel')
ax.set_ylabel('Vertical channel')
plt.colorbar(im, ax=ax, label='Raw amplitude')
plt.tight_layout()
plt.show()

## 3. Step 1 — DC offset removal

Mean of the first 40 ms (40 000 samples) is subtracted per channel (matches disruptcnn).

In [None]:
# Remove the DC offset from the example shot and compare one channel
# before (top) and after (bottom).
BASELINE_LEN = 40_000
corrected, offset = remove_offset(raw, baseline_length=BASELINE_LEN)

r, c = 10, 4
baseline_end_ms = BASELINE_LEN / FS * 1e3   # right edge of the baseline window

fig, axes = plt.subplots(2, 1, figsize=(16, 6), sharex=True)
ax_before, ax_after = axes

# Top: raw trace with the returned offset and the baseline window marked.
ax_before.plot(time_ms, raw[r, c, :], linewidth=0.3, color='k')
ax_before.axhline(offset[r, c], color='orange', linestyle='--', label=f'Offset = {offset[r,c]:.1f}')
ax_before.axvspan(0, baseline_end_ms, alpha=0.15, color='orange', label='Baseline window')
ax_before.set_title(f'Before offset removal — Ch ({r},{c})')
ax_before.legend(loc='upper right')
ax_before.grid(True, alpha=0.2)

# Bottom: corrected trace with a zero reference line.
ax_after.plot(time_ms, corrected[r, c, :], linewidth=0.3, color='steelblue')
ax_after.axhline(0, color='gray', linestyle='--', alpha=0.5)
ax_after.set_title(f'After offset removal — Ch ({r},{c})')
ax_after.set_xlabel('Time (ms)')
ax_after.grid(True, alpha=0.2)

plt.tight_layout()
plt.show()

In [None]:
# Offset heatmap across all channels
fig, ax = plt.subplots(figsize=(6, 8))
im = ax.imshow(offset, aspect='auto', cmap='coolwarm')
plt.colorbar(im, ax=ax, label='Offset value')
ax.set_title(f'DC offset per channel — shot {SHOT}')
ax.set_ylabel('Vertical channel')
ax.set_xlabel('Radial channel')
plt.tight_layout()
plt.show()

## 4. Step 2 — Temporal decimation (10×)

Every 10th sample is kept → effective 100 kHz.

In [None]:
# Decimate the offset-corrected example shot and overlay both sampling rates
# on a short window.
DATA_STEP = 10
decimated = decimate(corrected, DATA_STEP)
n_dec = decimated.shape[-1]
time_dec_ms = np.arange(n_dec) / (FS / DATA_STEP / 1000)

print(f'Before decimation: {corrected.shape}  →  After: {decimated.shape}')

# Zoom into a 10-ms window to see the effect
t0, t1 = 500.0, 510.0  # ms
mask_full = (time_ms >= t0) & (time_ms < t1)
mask_dec  = (time_dec_ms >= t0) & (time_dec_ms < t1)

r, c = 10, 4
fig, ax = plt.subplots(figsize=(16, 3))
# Original-rate trace, drawn faint underneath.
ax.plot(
    time_ms[mask_full], corrected[r, c, mask_full],
    linewidth=0.5, color='k', alpha=0.4, label='1 MHz (original)',
)
# Decimated trace on top, with a marker at every retained sample.
ax.plot(
    time_dec_ms[mask_dec], decimated[r, c, mask_dec],
    linewidth=1.2, color='steelblue', marker='.', markersize=3,
    label='100 kHz (decimated)',
)
ax.set_title(f'Decimation comparison — Ch ({r},{c}), {t0}–{t1} ms')
ax.set_xlabel('Time (ms)')
ax.legend()
ax.grid(True, alpha=0.2)
plt.tight_layout()
plt.show()

### 4b. Save decimated data to disk

Run once to create offset-removed + 10×-decimated h5 files.
Subsequent dataset loads from this directory skip offset removal and decimation entirely.

In [None]:
import shutil
from tqdm import tqdm

# Write an offset-removed, decimated copy of every shot in `meta` to
# DECIMATED_ROOT as '<shot>.h5' (dataset 'LFS', float32, gzip).
#
# The cell is resumable: shots whose output file already exists are skipped,
# each file is written under a temporary name and renamed once complete, and
# meta.csv is copied LAST so its presence marks a finished export. Previously
# meta.csv was copied first, so an interrupted run looked complete next time.
DECIMATED_ROOT = Path('/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt_decimated')
BASELINE_LEN = 40_000
DATA_STEP    = 10

shot_ids = [int(s) for s in meta['shot']]
pending = [s for s in shot_ids if not (DECIMATED_ROOT / f'{s}.h5').exists()]

if not pending and (DECIMATED_ROOT / 'meta.csv').exists():
    n_existing = len(list(DECIMATED_ROOT.glob('*.h5')))
    print(f'Decimated data already exists at {DECIMATED_ROOT}  ({n_existing} h5 files)')
else:
    DECIMATED_ROOT.mkdir(parents=True, exist_ok=True)

    for shot in tqdm(pending, desc='Saving decimated shots'):
        # Loop-local names: do not overwrite the notebook-level `raw` and
        # `corrected` that belong to the example shot.
        shot_raw = read_raw_shot(ROOT, shot)
        shot_corrected, _ = remove_offset(shot_raw, baseline_length=BASELINE_LEN)
        dec = decimate(shot_corrected, DATA_STEP)

        out_path = DECIMATED_ROOT / f'{shot}.h5'
        # '.part' does not match '*.h5', so a half-written file is never
        # mistaken for a finished one.
        tmp_path = DECIMATED_ROOT / f'{shot}.h5.part'
        with h5py.File(tmp_path, 'w') as f:
            f.create_dataset('LFS', data=dec.astype(np.float32),
                             compression='gzip', compression_opts=4)
            # Record how the file was produced.
            f.attrs['data_step']       = DATA_STEP
            f.attrs['baseline_length'] = BASELINE_LEN
            f.attrs['source_fs_hz']    = FS
            f.attrs['effective_fs_hz'] = FS // DATA_STEP
        tmp_path.replace(out_path)

    # Copy meta.csv unchanged (t_disruption stays in ms)
    shutil.copy(Path(ROOT) / 'meta.csv', DECIMATED_ROOT / 'meta.csv')

    print(f'Saved {len(pending)} decimated shots to {DECIMATED_ROOT} '
          f'({len(shot_ids)} listed in meta.csv)')

## 5. Step 3 — Z-score normalisation

Per-channel mean/std computed from training shots; then applied to all splits.

In [None]:
# Build the dataset — uses decimated data when available for speed
DECIMATED_ROOT = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt_decimated'

ds = ECEiTCNDataset(
    root=ROOT,
    Twarn=300_000,
    baseline_length=40_000,
    data_step=10,
    nsub=781_250,      # ~781 ms (matches disruptcnn)
    stride=481_090,    # (nsub/step - nrecept + 1) * step
    normalize=True,
    decimated_root=DECIMATED_ROOT,
)
ds.summary()

# Compute normalisation from train split (cache to disk)
NORM_STATS_PATH = Path('norm_stats.npz')

if not NORM_STATS_PATH.exists():
    # No cache yet: compute from the train split, then store for next time.
    norm_mean, norm_std = ds.compute_norm_stats(split='train', max_shots=100)
    ds.save_norm_stats(str(NORM_STATS_PATH))
    print(f'Computed and saved norm stats to {NORM_STATS_PATH}')
else:
    # Reuse the cached statistics.
    ds.load_norm_stats(str(NORM_STATS_PATH))
    norm_mean = ds.norm_mean
    norm_std = ds.norm_std
    print(f'Loaded cached norm stats from {NORM_STATS_PATH}')

for stat_name, stat in (('mean', norm_mean), ('std ', norm_std)):
    print(f'  {stat_name} range: [{stat.min():.4f}, {stat.max():.4f}]')

In [None]:
# Side-by-side heatmaps of the normalisation statistics (mean, std).
panels = [
    (norm_mean, 'coolwarm', 'Per-channel mean (train)'),
    (norm_std,  'viridis',  'Per-channel std (train)'),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 8))
for panel, (stat, cmap_name, title) in zip(axes, panels):
    im = panel.imshow(stat, aspect='auto', cmap=cmap_name)
    panel.set_title(title)
    panel.set_xlabel('Radial')
    panel.set_ylabel('Vertical')
    plt.colorbar(im, ax=panel)

plt.tight_layout()
plt.show()

In [None]:
# Apply normalisation to the example shot
normalised = normalize_array(decimated, norm_mean, norm_std)

r, c = 10, 4
before = decimated[r, c, :]
after = normalised[r, c, :]

# Column 0 = before normalisation, column 1 = after.
# Row 0 = time series, row 1 = value histogram.
fig, axes = plt.subplots(2, 2, figsize=(16, 7))
columns = [
    (before, 'k',         'Amplitude', f'Before normalisation — Ch ({r},{c})', 'gray',      'Value distribution (before)'),
    (after,  'steelblue', 'z-score',   f'After normalisation — Ch ({r},{c})',  'steelblue', 'Value distribution (after)'),
]

for col_idx, (series, line_color, unit, ts_title, hist_color, hist_title) in enumerate(columns):
    # Time series comparison
    ts_ax = axes[0, col_idx]
    ts_ax.plot(time_dec_ms, series, linewidth=0.3, color=line_color)
    ts_ax.set_title(ts_title)
    ts_ax.set_ylabel(unit)
    ts_ax.grid(True, alpha=0.2)

    # Histograms
    hist_ax = axes[1, col_idx]
    hist_ax.hist(series, bins=100, color=hist_color, alpha=0.7)
    hist_ax.set_xlabel(unit)
    hist_ax.set_title(hist_title)

plt.tight_layout()
plt.show()

In [None]:
# Distribution across ALL channels before vs. after normalisation
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Flatten all channels, subsample time for speed
step = 100
vals_before = decimated[:, :, ::step].flatten()
vals_after  = normalised[:, :, ::step].flatten()

panels = [
    (vals_before, 'gray',      'All-channel amplitude (before norm)', 'Amplitude'),
    (vals_after,  'steelblue', 'All-channel z-score (after norm)',    'z-score'),
]
for panel, (vals, color, title, xlabel) in zip(axes, panels):
    panel.hist(vals, bins=200, color=color, alpha=0.7)
    panel.set_title(title)
    panel.set_xlabel(xlabel)
    # Clip the x-range to the central 99% so outliers do not squash the plot.
    panel.set_xlim(np.percentile(vals, [0.5, 99.5]))

plt.tight_layout()
plt.show()

## 6. Per-timestep labels & weights

The label is 0 (clear) up to `t_disruption − Twarn`, and 1 (disruptive) from that point onward.

In [None]:
# Show the label on the full shot time axis
Twarn_ms = 300  # ms
# Warning window in raw samples, derived from Twarn_ms so that the label
# boundary, the orange marker and the title cannot drift apart
# (300 ms * 1000 samples/ms = 300_000 at FS = 1 MHz).
twarn_samples = Twarn_ms * (FS // 1000)

fig, axes = plt.subplots(2, 1, figsize=(16, 5), sharex=True)

# Top panel: normalised signal of one channel with both boundaries marked.
r, c = 10, 4
axes[0].plot(time_dec_ms, normalised[r, c, :], linewidth=0.3, color='k')
axes[0].set_ylabel('z-score')
axes[0].set_title(f'Normalised signal — Ch ({r},{c}), shot {SHOT}')
axes[0].axvline(T_DIS_MS, color='red', ls='--', lw=1.5, label='t_disruption')
axes[0].axvline(T_DIS_MS - Twarn_ms, color='orange', ls='--', lw=1.5, label=f't_dis − Twarn ({Twarn_ms} ms)')
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.2)

# Per-timestep label
label_full = np.zeros(len(time_dec_ms), dtype=np.float32)
# Index of the label boundary on the decimated time axis, clamped to the
# array bounds so a boundary outside the recording does not wrap around.
d_idx_dec = int((T_DIS - twarn_samples) / DATA_STEP)
d_idx_dec = max(0, min(d_idx_dec, len(label_full)))
label_full[d_idx_dec:] = 1.0

# Bottom panel: the two classes as complementary shaded regions.
axes[1].fill_between(time_dec_ms, 0, label_full, color='firebrick', alpha=0.4, label='Disruptive (1)')
axes[1].fill_between(time_dec_ms, 0, 1 - label_full, color='steelblue', alpha=0.2, label='Clear (0)')
axes[1].set_ylabel('Label')
axes[1].set_xlabel('Time (ms)')
axes[1].set_title(f'Per-timestep binary label (Twarn = {Twarn_ms} ms)')
axes[1].set_ylim(-0.05, 1.15)
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.2)

plt.tight_layout()
plt.show()

## 7. Subsequence tiling

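Each shot is split into fixed-length windows of `nsub` samples (781 250 samples ≈ 781 ms
at 1 MHz with the settings above). Because `stride` is smaller than `nsub`, consecutive
windows overlap. Windows that straddle the disruption boundary have a label transition inside.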

In [None]:
# Find all subsequences for this shot
shot_meta_idx = np.where(ds.shots == SHOT)[0][0]
seq_mask = ds.seq_shot_idx == shot_meta_idx
seq_starts = ds.seq_start[seq_mask]
seq_stops  = ds.seq_stop[seq_mask]
seq_disrupt = ds.seq_disrupt_local[seq_mask]

print(f'Shot {SHOT}: {seq_mask.sum()} subsequences')

# One horizontal bar per window on a shared row; start/stop are converted
# from sample indices to ms with FS.
to_ms = 1e3 / FS
fig, ax = plt.subplots(figsize=(16, 3))
for win_start, win_stop, local_dis in zip(seq_starts, seq_stops, seq_disrupt):
    # A non-negative local index is drawn in red, a negative one in blue.
    if local_dis >= 0:
        bar_color = 'firebrick'
    else:
        bar_color = 'steelblue'
    ax.barh(0, (win_stop - win_start) / FS * 1e3, left=win_start / FS * 1e3, height=0.6,
            color=bar_color, alpha=0.5, edgecolor='k', linewidth=0.5)

ax.axvline(T_DIS_MS, color='red', ls='--', lw=2, label='t_disruption')
ax.axvline(T_DIS_MS - Twarn_ms, color='orange', ls='--', lw=1.5, label='label boundary')
ax.set_title(f'Subsequence windows for shot {SHOT} (red = contains disruption label)')
ax.set_xlabel('Time (ms)')
ax.set_yticks([])
ax.legend(loc='upper left')
plt.tight_layout()
plt.show()

## 8. DataLoader end-to-end test

Pull a batch from the train loader and verify shapes + label distribution.

In [None]:
# Build the loaders and pull a single batch as a smoke test.
loaders = create_loaders(ds, batch_size=4, num_workers=0)
available = list(loaders.keys())
print('Splits available:', available)

# Prefer the train split; otherwise take whichever split is listed first.
if 'train' in loaders:
    split_name = 'train'
else:
    split_name = available[0]

batch = next(iter(loaders[split_name]))
X_b, target_b, weight_b = batch

print(f'\nBatch from "{split_name}":')
print(f'  X      : {X_b.shape}  dtype={X_b.dtype}')
print(f'  target : {target_b.shape}  dtype={target_b.dtype}')
print(f'  weight : {weight_b.shape}  dtype={weight_b.dtype}')
# Mean label per sample = fraction of its time steps labelled disruptive.
label_fracs = [f"{t.mean():.2f}" for t in target_b]
print(f'  label frac per sample: {label_fracs}')

In [None]:
# Visualise one sample from the batch
idx = 0
x_np = X_b[idx].numpy()          # (20, 8, T_sub)
t_np = target_b[idx].numpy()     # (T_sub,)
w_np = weight_b[idx].numpy()     # (T_sub,)
T_sub = x_np.shape[-1]
t_ax = np.arange(T_sub) / (FS / ds.data_step / 1000)  # ms

# Four stacked panels sharing the time axis; the two signal panels are taller.
fig, axes = plt.subplots(4, 1, figsize=(16, 10), sharex=True,
                         gridspec_kw={'height_ratios': [3, 3, 1, 1]})
ax_signal, ax_heat, ax_label, ax_weight = axes

# Signal (one channel)
ax_signal.plot(t_ax, x_np[10, 4, :], linewidth=0.4, color='k')
ax_signal.set_title('Preprocessed signal — Ch (10, 4)')
ax_signal.set_ylabel('z-score')
ax_signal.grid(True, alpha=0.2)

# All 160 channels as heatmap
flat = x_np.reshape(160, -1)
# Colour limits at the 1st/99th percentile so outliers do not wash out the map.
lo, hi = np.percentile(flat, 1), np.percentile(flat, 99)
ax_heat.imshow(flat, aspect='auto', cmap='RdBu_r', vmin=lo, vmax=hi,
               extent=[t_ax[0], t_ax[-1], 159, 0])
ax_heat.set_title('All 160 channels (20×8 flattened)')
ax_heat.set_ylabel('Channel')

# Per-timestep label
ax_label.fill_between(t_ax, 0, t_np, color='firebrick', alpha=0.5)
ax_label.set_ylim(-0.05, 1.15)
ax_label.set_ylabel('Label')

# Weight
ax_weight.plot(t_ax, w_np, color='darkorange', linewidth=1)
ax_weight.set_xlabel('Time (ms) within subsequence')
ax_weight.set_ylabel('Weight')

plt.tight_layout()
plt.show()

## 9. Class balance statistics

In [None]:
# Class balance: (left) clear vs disruptive subsequence counts per split,
# (right) histogram of the disruptive fraction inside each subsequence.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
ax_counts, ax_frac = axes
split_names = np.unique(ds.splits)

for i, sp in enumerate(split_names):
    idx = ds.get_split_indices(sp)
    n_dis  = int(ds.seq_has_disrupt[idx].sum())
    n_clr  = len(idx) - n_dis

    # Each split occupies a 3-unit slot: clear bar at 3i, disruptive at 3i+1.
    x_clr = i * 3
    x_dis = i * 3 + 1
    first = (i == 0)   # only the first pair of bars gets a legend label
    ax_counts.bar(x_clr, n_clr, color='steelblue', width=0.8, label='Clear' if first else '')
    ax_counts.bar(x_dis, n_dis, color='firebrick', width=0.8, label='Disruptive' if first else '')
    ax_counts.text(x_clr, n_clr + 1, str(n_clr), ha='center', fontsize=9)
    ax_counts.text(x_dis, n_dis + 1, str(n_dis), ha='center', fontsize=9)

ax_counts.set_xticks([k * 3 + 0.5 for k in range(len(split_names))])
ax_counts.set_xticklabels(split_names)
ax_counts.set_title('Subsequences per split (clear vs disruptive)')
ax_counts.set_ylabel('# Subsequences')
ax_counts.legend()

# Fraction of disruptive time steps per subsequence
T_sub = ds.nsub // ds.data_step
# A negative local index contributes 0; otherwise the fraction is the share of
# decimated steps from the (clamped) boundary index to the end of the window.
fracs = np.array([
    0.0 if dl < 0 else (T_sub - min(dl // ds.data_step, T_sub)) / T_sub
    for dl in ds.seq_disrupt_local
])

ax_frac.hist(fracs, bins=50, color='gray', alpha=0.7, edgecolor='k')
ax_frac.set_title('Distribution of disruptive fraction per subsequence')
ax_frac.set_xlabel('Fraction of disruptive time steps')
ax_frac.set_ylabel('# Subsequences')

plt.tight_layout()
plt.show()

print(f'\nClass weights: pos_weight = {ds.pos_weight:.3f}, neg_weight = {ds.neg_weight:.3f}')