# Interpretable Debug: TCN Training Pipeline

This notebook walks through the **rolled-back** training pipeline with rich visualizations:

1. **Input** — Data layout, how subsequences are formed, time parameters  
2. **Normalization & labels** — Preprocessing, label definition, **label segments** for disruptive data  
3. **Model** — TCN structure, **per-layer receptive field**  
4. **Prediction** — How output is produced and cropped  
5. **Weighting** — How loss weights are applied (and why we use only `batch_weights`)

Set `USE_REAL_DATA = True` and fill in the data paths in the next cell to use the real dataset; otherwise synthetic data is used.

In [None]:
# Setup: paths and constants (match train_tcn_ddp.py)
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd()))

USE_REAL_DATA = False  # Set True and set paths below to use real dataset
ROOT = "/path/to/dsrpt"
DECIMATED_ROOT = "/path/to/dsrpt_decimated"
CLEAR_ROOT = "/path/to/clear"
CLEAR_DECIMATED_ROOT = "/path/to/clear_decimated"
NORM_STATS = None  # e.g. "norm_stats_pca1.npz"
PCA_COMPONENTS = 0  # 0, 1, 4, 8, 16

# Training params
DATA_STEP = 10
NSUB_RAW = 781_250
TWARN = 300_000
BASELINE_LEN = 40_000
EXCLUDE_LAST_MS = 0.0
INPUT_CHANNELS = 160
LEVELS = 4
NHID = 80
KERNEL_SIZE = 15
DILATION_BASE = 10
DROPOUT = 0.2
NRECEPT_TARGET = 30_000
T_sub = NSUB_RAW // DATA_STEP

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches

%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 4)
plt.rcParams['font.size'] = 11

---
## 1. Input — Data layout and how subsequences are formed

- **Raw data:** `root/` has `meta.csv` (shot, split, t_disruption) and `{shot}.h5` with key `LFS` → shape `(20, 8, T)` at 1 MHz (or `(C, T)` for PCA).
- **Clear shots:** optional `clear_root/`; whole shot = label 0.
- **Decimated:** when `decimated_root` is used, LFS is already offset-removed and decimated (100 kHz); no decimation in `__getitem__`.
- **Tiling:** each shot is split into windows of length `nsub` that advance by `stride`. For each window we store `disrupt_local` (index of the Twarn start within the window, or -1 if clear) and `positive_end_local` (end of the label-1 region).
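
The tiling rule above can be sketched as follows. This is a minimal illustration, not the dataset's actual code: `tile_shot` and its arguments are hypothetical names. Keeping only full windows and assigning `disrupt_local = 0` to a window that starts inside the Twarn region are both assumptions.

```python
def tile_shot(n_samples, nsub, stride, twarn_start=None):
    """Return (window_start, disrupt_local) pairs for one shot.

    All quantities are in raw samples. `twarn_start` is the index where the
    Twarn (label-1) region begins, or None for a clear shot.
    Hypothetical helper; only full windows are kept (an assumption).
    """
    windows = []
    for start in range(0, n_samples - nsub + 1, stride):
        end = start + nsub
        if twarn_start is not None and twarn_start < end:
            # Twarn begins inside (or before) this window; clamp to 0 (assumption)
            disrupt_local = max(twarn_start - start, 0)
        else:
            disrupt_local = -1  # no Twarn samples in this window
        windows.append((start, disrupt_local))
    return windows

# Toy numbers, not the real pipeline values
wins = tile_shot(n_samples=100, nsub=40, stride=30, twarn_start=75)
print(wins)  # [(0, -1), (30, -1), (60, 15)]
```

In the pipeline the stride is chosen as `T_sub - nrecept + 1` decimated samples, so consecutive windows overlap by `nrecept - 1` samples and no usable output step is lost to the crop.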

In [None]:
# Time parameters (all in raw 1 MHz samples unless noted)
fig, ax = plt.subplots(1, 1, figsize=(10, 2.5))
ax.set_xlim(0, 10)
ax.set_ylim(0, 4)
ax.axis('off')

params = [
    ('Twarn', TWARN, f'{TWARN/1e6*1000:.1f} ms before disruption'),
    ('baseline_length', BASELINE_LEN, 'DC offset window'),
    ('data_step', DATA_STEP, 'decimation → 100 kHz'),
    ('nsub', NSUB_RAW, f'~{NSUB_RAW/1e6*1000:.1f} ms window'),
    ('T_sub', T_sub, 'nsub // data_step (output length)'),
]
y = 3.5
for name, val, desc in params:
    ax.text(0.2, y, f'{name}', fontfamily='monospace', fontsize=12, fontweight='bold')
    ax.text(2.2, y, f'= {val:,}', fontfamily='monospace')
    ax.text(5, y, desc, color='gray')
    y -= 0.65
ax.set_title('Time parameters', fontsize=14)
plt.tight_layout()
plt.show()
print(f"T_sub = {T_sub:,}  (samples per subsequence)")

In [None]:
# Schematic: one shot tiled into overlapping windows (conceptual)
from train_tcn_ddp import build_model
_, nrecept, _ = build_model(INPUT_CHANNELS, 1, LEVELS, NHID, KERNEL_SIZE, DILATION_BASE, DROPOUT, nrecept_target=NRECEPT_TARGET)
stride_raw = (NSUB_RAW // DATA_STEP - nrecept + 1) * DATA_STEP
stride_data = stride_raw // DATA_STEP  # in decimated space
nsub_data = NSUB_RAW // DATA_STEP

fig, ax = plt.subplots(1, 1, figsize=(12, 2))
L = 5 * nsub_data  # show 5 windows
ax.set_xlim(0, L)
ax.set_ylim(-0.3, 1.5)
ax.set_xlabel('Time (decimated samples)')
ax.set_yticks([])
for i in range(5):
    start = i * stride_data  # window i covers [start, start + nsub_data)
    ax.add_patch(Rectangle((start, 0.2), nsub_data, 0.6, facecolor='steelblue', alpha=0.7, edgecolor='navy', linewidth=1.5))
    ax.text(start + nsub_data/2, 0.5, f'win {i}', ha='center', va='center', color='white', fontweight='bold')
ax.axhline(0.5, color='k', linewidth=0.5, linestyle='--')
ax.set_title(f'Subsequence tiling: window length = {nsub_data:,}, stride = {stride_data:,} (raw stride {stride_raw:,})')
plt.tight_layout()
plt.show()

---
## 2. Normalization and label definition

**Preprocessing in `__getitem__`:**
1. **DC offset:** `baseline = mean(X[..., :baseline_length])`; `X = X - baseline`
2. **Decimation:** `X = X[..., ::data_step]` (1 MHz → 100 kHz)
3. **Z-score:** `X = (X - norm_mean) / norm_std` (per-channel, from training split)
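
The three steps above can be sketched on a synthetic array. `preprocess` is a hypothetical helper, and the toy statistics below (zero mean, unit std) stand in for the per-channel values that would come from the training split.

```python
import numpy as np

def preprocess(X, baseline_length, data_step, norm_mean, norm_std):
    """Apply DC-offset removal, decimation, and z-scoring along the last axis.

    X has shape (..., T); norm_mean and norm_std must broadcast against it.
    Hypothetical helper mirroring the three steps listed above.
    """
    baseline = X[..., :baseline_length].mean(axis=-1, keepdims=True)
    X = X - baseline                    # 1. remove per-channel DC offset
    X = X[..., ::data_step]             # 2. keep every data_step-th sample
    return (X - norm_mean) / norm_std   # 3. z-score with precomputed stats

rng = np.random.default_rng(0)
raw = rng.normal(loc=2.0, scale=0.5, size=(4, 1000))  # 4 channels, offset ~2
mean = np.zeros((4, 1))  # toy stats; real ones come from the training split
std = np.ones((4, 1))
out = preprocess(raw, baseline_length=200, data_step=10,
                 norm_mean=mean, norm_std=std)
print(out.shape)  # (4, 100)
```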

**Labels (output space):**  
- `d = (disrupt_local + 1) // step`, `e = (positive_end_local + step - 1) // step`  
- `target[d:e] = 1` (Twarn window), elsewhere 0  
- `weight[0:d] = neg_weight`, `weight[d:e] = pos_weight`, `weight[e:T] = 0` (excluded)
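
As a sketch of the index arithmetic above, the helper below builds `target` and `weight` for one window. The function name is hypothetical, the treatment of clear windows (all label 0, `neg_weight` everywhere) is an assumption, and the weights 0.56 and 4.5 are only the example values used in the plot further down.

```python
import numpy as np

def make_target_weight(disrupt_local, positive_end_local, step, T,
                       neg_weight, pos_weight):
    """Build per-timestep target and weight in output (decimated) space."""
    target = np.zeros(T, dtype=np.float32)
    weight = np.full(T, neg_weight, dtype=np.float32)
    if disrupt_local >= 0:
        d = (disrupt_local + 1) // step
        e = (positive_end_local + step - 1) // step  # ceiling division
        d, e = min(d, T), min(e, T)
        target[d:e] = 1.0          # Twarn window
        weight[d:e] = pos_weight
        weight[e:] = 0.0           # excluded tail
    return target, weight

# Raw indices 350_000 and 380_000 with step 10 give d = 35_000, e = 38_000
target, weight = make_target_weight(350_000, 380_000, step=10, T=78_125,
                                    neg_weight=0.56, pos_weight=4.5)
print(int(target.sum()), int((weight == 0).sum()))  # 3000 40125
```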

In [None]:
# Normalization: synthetic signal before and after z-score (one channel)
np.random.seed(42)
t = np.linspace(0, 1, 500)
sig = 0.5 * np.sin(2 * np.pi * 5 * t) + 0.3 * np.random.randn(500) + 2.0  # offset ~2
sig_z = (sig - sig.mean()) / (sig.std() + 1e-8)
fig, axes = plt.subplots(2, 1, figsize=(10, 4), sharex=True)
axes[0].plot(sig, color='steelblue')
axes[0].set_ylabel('Raw (with offset)')
axes[0].set_title('Before normalization (raw signal with DC offset)')
axes[1].plot(sig_z, color='green')
axes[1].set_ylabel('Normalized')
axes[1].set_xlabel('Sample')
axes[1].set_title('After: (x - mean) / std')
plt.tight_layout()
plt.show()

In [None]:
# Label segments for a disruptive subsequence (rich visualization)
T_ex = 78_125
d_ex, e_ex = 35_000, 38_000  # example: clear [0, d), Twarn [d, e), excluded [e, T)
t_target = np.zeros(T_ex)
t_target[d_ex:e_ex] = 1.0
t_weight = np.zeros(T_ex)
t_weight[:d_ex] = 0.56
t_weight[d_ex:e_ex] = 4.5
# Plot a window around the transition for clarity
win = slice(33_000, 42_000)
r = np.arange(T_ex)[win]
fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)
ax0, ax1 = axes
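# Label-0 samples have target 0, so shade the region itself rather than filling under the curve
ax0.axvspan(r[0], d_ex, alpha=0.15, color='green', label='label 0 (clear)')
ax0.fill_between(r, 0, t_target[win], where=(t_target[win] > 0.5), color='red', alpha=0.5, label='label 1 (Twarn)')
ax0.axvspan(d_ex, e_ex, alpha=0.25, color='red')
ax0.axvspan(e_ex, r[-1], alpha=0.15, color='gray', label='excluded (weight 0)')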
ax0.set_ylabel('Target')
ax0.set_ylim(-0.05, 1.2)
ax0.legend(loc='upper right')
ax0.set_title('Label segments (zoom around transition)')
ax0.axvline(d_ex, color='red', linestyle='--', linewidth=1.5, alpha=0.9)
ax0.axvline(e_ex, color='red', linestyle='--', linewidth=1.5, alpha=0.9)
ax1.fill_between(r, 0, t_weight[win], color='steelblue', alpha=0.5)
ax1.axvspan(d_ex, e_ex, alpha=0.2, color='red')
ax1.set_xlabel('Output step index')
ax1.set_ylabel('Weight')
ax1.set_title('Per-timestep BCE weight')
ax1.axvline(d_ex, color='red', linestyle='--', linewidth=1.5, alpha=0.9)
ax1.axvline(e_ex, color='red', linestyle='--', linewidth=1.5, alpha=0.9)
plt.tight_layout()
plt.show()
print(f"  [0, {d_ex})   → label 0 (clear),  count = {d_ex}")
print(f"  [{d_ex}, {e_ex}) → label 1 (Twarn), count = {e_ex - d_ex}")
print(f"  [{e_ex}, {T_ex}) → excluded,       count = {T_ex - e_ex}")

In [None]:
# Optional: load real dataset and show one disruptive sample's target & weight
if USE_REAL_DATA and Path(ROOT).exists():
    from dataset_ecei_tcn import ECEiTCNDataset
    from train_tcn_ddp import build_model
    _, nrecept_ds, _ = build_model(INPUT_CHANNELS if PCA_COMPONENTS == 0 else PCA_COMPONENTS, 1, LEVELS, NHID,
                                  KERNEL_SIZE, DILATION_BASE, DROPOUT, nrecept_target=NRECEPT_TARGET)
    _stride_raw = (NSUB_RAW // DATA_STEP - nrecept_ds + 1) * DATA_STEP
    ds = ECEiTCNDataset(root=ROOT, decimated_root=DECIMATED_ROOT or None, clear_root=CLEAR_ROOT or None,
                        clear_decimated_root=CLEAR_DECIMATED_ROOT or None, Twarn=TWARN, baseline_length=BASELINE_LEN,
                        data_step=DATA_STEP, nsub=NSUB_RAW, stride=_stride_raw, normalize=True,
                        norm_stats_path=NORM_STATS, exclude_last_ms=EXCLUDE_LAST_MS, ignore_twarn=False,
                        n_input_channels=PCA_COMPONENTS if PCA_COMPONENTS > 0 else None)
    idx = np.where(ds.seq_has_disrupt)[0][0]
    X, target, weight = ds[idx]
    target = target.numpy()
    weight = weight.numpy()
    fig, ax = plt.subplots(1, 1, figsize=(14, 3))
    ax.plot(target, label='target', color='green', alpha=0.8)
    ax.plot(weight, label='weight', color='steelblue', alpha=0.7)
    ax.set_xlabel('Output step')
    ax.legend()
    ax.set_title(f'Real sample index {idx} (disruptive): target and weight')
    plt.tight_layout()
    plt.show()
else:
    print('Using synthetic example above. Set USE_REAL_DATA=True and paths to plot a real sample.')

---
## 3. Model and per-layer receptive field

- **TCN:** stack of `TemporalBlock`s (dilated causal convs) → linear → sigmoid → `(B, T_sub)`.
- **Causal:** output at time `t` only sees input `[0..t]`. The first `(nrecept - 1)` outputs do not have full context, so training uses only `output[nrecept-1:]` (and target/weight cropped the same way).
- **Receptive field:** `RF = 1 + 2 * (kernel_size - 1) * sum(dilations)`. Each block adds `2*(k-1)*dilation_i`.
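
A worked example of the formula above. The dilations are assumed to be powers of `DILATION_BASE = 10` for four levels; `build_model` may pick different values to approach `nrecept_target`, so the numbers are illustrative. The factor 2 is read here as two dilated convolutions per block, as in the standard TCN design.

```python
def receptive_field(kernel_size, dilations):
    """RF = 1 + 2 * (kernel_size - 1) * sum(dilations)."""
    return 1 + 2 * (kernel_size - 1) * sum(dilations)

# Assumed dilations (illustrative; the real ones come from build_model)
dilations = [10 ** i for i in range(4)]   # [1, 10, 100, 1000]
rf = receptive_field(kernel_size=15, dilations=dilations)
print(rf)  # 31109

# Cumulative RF after each block: block i adds 2 * (k - 1) * dilation_i
cum, total = [], 1
for d in dilations:
    total += 2 * (15 - 1) * d
    cum.append(total)
print(cum)  # [29, 309, 3109, 31109]
```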

In [None]:
from train_tcn_ddp import build_model, calc_receptive_field

model, nrecept, dilation_sizes = build_model(
    INPUT_CHANNELS, 1, LEVELS, NHID, KERNEL_SIZE, DILATION_BASE, DROPOUT, nrecept_target=NRECEPT_TARGET,
)
cum_rf = []
cum = 0
for i, d in enumerate(dilation_sizes):
    add = 2 * (KERNEL_SIZE - 1) * d
    cum += add
    cum_rf.append(1 + cum)

fig, ax = plt.subplots(1, 1, figsize=(10, 4))
x = np.arange(len(dilation_sizes))
bars = ax.bar(x, cum_rf, color='steelblue', edgecolor='navy', linewidth=1.2)
ax.axhline(nrecept, color='red', linestyle='--', label=f'Total RF = {nrecept:,}')
ax.set_xticks(x)
ax.set_xticklabels([f'Level {i}\ndilation={d}' for i, d in enumerate(dilation_sizes)])
ax.set_ylabel('Cumulative receptive field (samples)')
ax.set_title('Cumulative receptive field after each level (causal)')
ax.legend()
for i, (v, d) in enumerate(zip(cum_rf, dilation_sizes)):
    ax.text(i, v + 500, f'{v:,}', ha='center', fontsize=10)
plt.tight_layout()
plt.show()
print(f'Target RF: {NRECEPT_TARGET:,}  →  Achieved: {nrecept:,}')
print(f'Usable length for loss: T_sub - (nrecept - 1) = {T_sub - (nrecept - 1):,}')

---
## 4. How the prediction is made

- Forward: `(B, C, T_sub)` → TCN → linear → sigmoid → `(B, T_sub)`.
- Crop: `out_v = output[:, nrecept-1:]`, `tgt_v = target[:, nrecept-1:]`.
- Prediction: `pred = (out_v >= 0.5).float()`.
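
The crop and threshold can be illustrated on a tiny NumPy array. The values and `nrecept = 5` are toy numbers chosen for readability, not pipeline settings.

```python
import numpy as np

nrecept = 5  # toy receptive field
output = np.array([[0.9, 0.8, 0.1, 0.2, 0.4, 0.5, 0.7, 0.3]])  # (B=1, T_sub=8)
target = np.array([[0., 0., 0., 0., 0., 1., 1., 0.]])

out_v = output[:, nrecept - 1:]   # drop the first nrecept-1 steps (no full context)
tgt_v = target[:, nrecept - 1:]   # crop the target the same way
pred = (out_v >= 0.5).astype(np.float32)

print(out_v.shape)    # (1, 4)
print(pred.tolist())  # [[0.0, 1.0, 1.0, 0.0]]
```

Note that the first two outputs (0.9 and 0.8) would count as false positives if they were not cropped; they are exactly the steps computed without full context.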

In [None]:
# Dummy forward: T must exceed nrecept - 1, otherwise the crop below is empty
T_forward = min(T_sub, nrecept + 1_000)
X_dummy = torch.randn(2, INPUT_CHANNELS, T_forward) * 0.1
with torch.no_grad():
    out = model(X_dummy)
out_crop = out[:, nrecept - 1:]
print(f'Input: {X_dummy.shape}  →  Output: {out.shape}  →  Crop: {out_crop.shape}')
print('Prediction: pred = (out_crop >= 0.5).float()')

---
## 5. Loss weighting (current rolled-back version)

- The dataset returns `(X, target, weight)`, where the third tensor holds per-timestep pos/neg weights. **The current code does not use it in the loss**: the loader yields it as `_weight`, and the training step ignores it.
- Training uses **only** `batch_weights(tgt_v)` for BCE weights:
  - `n_pos = tgt_v.sum()`, `n_neg = n_total - n_pos`
  - Positive steps: weight = `0.5 * n_total / n_pos`
  - Negative steps: weight = `0.5 * n_total / n_neg`  
  So total weight on positives = total on negatives = 0.5×n_total (50/50 balance per batch).
- **Do not** multiply dataset weight by `batch_weights` — double weighting causes collapse to predicting 1 everywhere.
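
The formula above can be sketched in NumPy. `batch_weights_sketch` is an assumed re-implementation, not the code in `train_tcn_ddp.py`, and the `max(..., 1)` guards for empty classes are an assumption. The last lines show how multiplying by dataset weights (the example values 0.56 and 4.5 from section 2) skews the balance toward positives.

```python
import numpy as np

def batch_weights_sketch(tgt):
    """Per-element weights that give positives and negatives equal total weight."""
    n_total = tgt.size
    n_pos = float(tgt.sum())
    n_neg = n_total - n_pos
    pos_w = 0.5 * n_total / max(n_pos, 1.0)
    neg_w = 0.5 * n_total / max(n_neg, 1.0)
    return np.where(tgt > 0.5, pos_w, neg_w)

tgt = np.array([0, 0, 0, 0, 0, 1, 1, 0, 0, 0], dtype=np.float32)
w = batch_weights_sketch(tgt)
print(w[tgt == 1].sum(), w[tgt == 0].sum())  # 5.0 5.0

# Double weighting: total weight on positives vs negatives is no longer 1:1
ds_w = np.where(tgt > 0.5, 4.5, 0.56)
both = w * ds_w
ratio = both[tgt == 1].sum() / both[tgt == 0].sum()
print(round(float(ratio), 2))  # 8.04
```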

In [None]:
from train_tcn_ddp import batch_weights

# Example: 10 timesteps, 2 positive
tgt = torch.tensor([0., 0, 0, 0, 0, 1., 1, 0, 0, 0.], dtype=torch.float32)
w = batch_weights(tgt)
n_total, n_pos = tgt.numel(), int(tgt.sum().item())
n_neg = n_total - n_pos
pw = 0.5 * n_total / max(n_pos, 1)
nw = 0.5 * n_total / max(n_neg, 1)

fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=True)
ax0, ax1 = axes
x = np.arange(len(tgt))
colors = ['green' if v == 0 else 'red' for v in tgt.tolist()]
ax0.bar(x, tgt.numpy(), color=colors, edgecolor='black')
ax0.set_ylabel('Target')
ax0.set_title('Target (0 = clear, 1 = disruptive)')
ax0.set_ylim(-0.05, 1.2)

ax1.bar(x, w.numpy(), color=colors, edgecolor='black')
ax1.axhline(pw, color='red', linestyle='--', alpha=0.8, label=f'pos weight = {pw:.3f}')
ax1.axhline(nw, color='green', linestyle='--', alpha=0.8, label=f'neg weight = {nw:.3f}')
ax1.set_xlabel('Timestep')
ax1.set_ylabel('BCE weight')
ax1.set_title('batch_weights(tgt): 50/50 total weight on pos vs neg')
ax1.legend()
plt.tight_layout()
plt.show()
print(f'n_pos={n_pos}, n_neg={n_neg}  →  pos weight={pw:.3f}, neg weight={nw:.3f}')
print(f'Sum on pos: {w[tgt==1].sum().item():.3f},  Sum on neg: {w[tgt==0].sum().item():.3f}')

---
## Training step (conceptual)

1. Get batch `(X, target, _weight)`; **\_weight is ignored**.
2. `out = model(X)` → `(B, T_sub)`.
3. `out_v = out[:, nrecept-1:]`, `tgt_v = target[:, nrecept-1:]`.
4. `wgt_v = batch_weights(tgt_v)`.
5. `loss = BCE(out_v, tgt_v, weight=wgt_v)`.
6. `loss.backward()`, `clip_grad_norm_`, `optimizer.step()`.
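
The six steps can be sketched end to end in PyTorch. Everything here is a stand-in: `TinyModel` is a 1x1 convolution rather than the TCN, `batch_weights_sketch` is an assumed re-implementation of `batch_weights`, and the Adam optimizer, learning rate, and `max_norm=1.0` are placeholder choices, not the settings in `train_tcn_ddp.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nrecept = 5  # toy receptive field

class TinyModel(nn.Module):
    """Stand-in for the TCN: (B, C, T) -> (B, T) probabilities."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x)).squeeze(1)

def batch_weights_sketch(tgt):
    """Equal total weight on positives and negatives (assumed formula)."""
    n_total = tgt.numel()
    n_pos = tgt.sum().clamp(min=1.0)
    n_neg = (n_total - tgt.sum()).clamp(min=1.0)
    return torch.where(tgt > 0.5, 0.5 * n_total / n_pos, 0.5 * n_total / n_neg)

model = TinyModel(channels=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

X = torch.randn(2, 3, 20)                 # 1. batch; dataset weight is ignored
target = torch.zeros(2, 20)
target[:, 12:16] = 1.0

out = model(X)                            # 2. (B, T_sub)
out_v = out[:, nrecept - 1:]              # 3. crop output and target
tgt_v = target[:, nrecept - 1:]
wgt_v = batch_weights_sketch(tgt_v)       # 4. balanced per-batch weights
loss = F.binary_cross_entropy(out_v, tgt_v, weight=wgt_v)  # 5. weighted BCE

optimizer.zero_grad()
loss.backward()                           # 6. backward, clip, step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(out_v.shape, loss.item())
```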