# Training Pipeline Diagnostic

Step-by-step inspection of every component before running training.
Includes an audit against the reference `disruptcnn` implementation.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from pathlib import Path
from collections import OrderedDict

from dataset_ecei_tcn import ECEiTCNDataset, create_loaders

# Run on the CUDA device when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f'Device: {DEVICE}')

---
## 1. Configuration — disruptcnn run.sh vs ours

| Parameter | disruptcnn `run.sh` | Our pipeline | Match? |
|-----------|-------------------|--------------|--------|
| batch_size | 12 | 12 | Yes |
| lr | **0.5** | **2e-3** | **NO — 250x smaller** |
| epochs | 1500 | 50 | Intentional |
| dropout | 0.1 | 0.1 | Yes |
| clip | 0.3 | 0.3 | Yes |
| data_step | 10 | 10 | Yes |
| nsub | 78125 (decimated) | 781250 (raw) / 10 = 78125 | Yes |
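| stride | nsub - nrecept + 1 = 78125 - 30017 + 1 = 48109 | **481260 / 10 = 48126** (= 78125 - 30000 + 1, i.e. computed from the RF *target* instead of the actual RF of 30017) | **NO — off by 17** |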
| levels | 4 | 4 | Yes |
| nhid | 80 | 80 | Yes |
| kernel_size | 15 | 15 | Yes |
| dilation_size | 10 | 10 | Yes |
| nrecept (target) | 30000 | 30000 | Yes |
| warmup | 5 epochs | **3 epochs** | **NO** |
| warmup multiplier | 8 | 8 | Yes |
| plateau patience | 10 (default) | **5** | **NO** |
| label +1 offset | Yes | **No** | **NO (off-by-one)** |
| sampling | undersample negatives | oversample positives | Different approach |

In [None]:
# ── Data paths ─────────────────────────────────────────────────────
ROOT = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt'
DECIMATED_ROOT = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt_decimated'

# ── Windowing / loader settings mirrored from disruptcnn run.sh ───
# Raw ECEi data is at 1 MHz, so 1 ms corresponds to 1_000 raw samples.
DATA_STEP = 10                     # decimation factor (1 MHz raw -> 100 kHz)
TWARN = 300 * 1_000                # warning window: 300 ms in raw samples
BASELINE_LEN = 40 * 1_000          # baseline window: 40 ms in raw samples
NSUB_RAW = 78_125 * DATA_STEP      # 78125 decimated samples, expressed in raw units
BATCH_SIZE = 12
NUM_WORKERS = 4

# ── Model hyper-parameters (run.sh) ───────────────────────────────
INPUT_CHANNELS = 20 * 8            # 20 x 8 ECEi grid flattened to 160 channels
N_CLASSES = 1                      # one sigmoid output per timestep
LEVELS = 4                         # number of TemporalBlocks
NHID = 80                          # hidden channels per block
KERNEL_SIZE = 15
DILATION_BASE = 10                 # levels use 1, 10, 100, ...; the last is re-fit
DROPOUT = 0.1
NRECEPT_TARGET = 30_000            # desired receptive field, in decimated samples

# ── Training (run.sh vs. our previous pipeline) ───────────────────
LR_DISRUPTCNN = 0.5                # learning rate from run.sh
LR_OURS = 2e-3                     # learning rate our pipeline previously used
CLIP = 0.3                         # `clip` value from run.sh

---
## 2. Receptive field & dilation schedule

In [None]:
def calc_rf(kernel_size, dilation_sizes):
    """Return the receptive field (in samples) of the stacked TCN.

    Every TemporalBlock contains two dilated convolutions with the same
    dilation ``d``; each of them widens the field by ``(kernel_size - 1) * d``.

    Args:
        kernel_size: convolution kernel width shared by all levels.
        dilation_sizes: per-level dilation factors (anything ``np.sum`` accepts).

    Returns:
        The receptive field length as an ``int``.
    """
    dilation_total = int(np.sum(dilation_sizes))
    span_per_conv = kernel_size - 1
    convs_per_block = 2
    return convs_per_block * span_per_conv * dilation_total + 1

# Build the dilation schedule the disruptcnn way: geometric dilations
# DILATION_BASE**i for every level but the last, then pick the last level's
# dilation so the total receptive field reaches NRECEPT_TARGET.
base_dilations = [DILATION_BASE ** lvl for lvl in range(LEVELS - 1)]  # [1, 10, 100]
rf_without_last = calc_rf(KERNEL_SIZE, base_dilations)
# One extra level adds 2 * (k - 1) * d samples: solve for d and round up (so the
# RF ends up >= the target), never going below a dilation of 1.
last_dilation = max(
    int(np.ceil((NRECEPT_TARGET - rf_without_last) / (2.0 * (KERNEL_SIZE - 1)))),
    1,
)
DILATION_SIZES = [*base_dilations, last_dilation]
NRECEPT = calc_rf(KERNEL_SIZE, DILATION_SIZES)

print('\n'.join([
    f'Base dilations (levels 0..{LEVELS-2}): {base_dilations}',
    f'RF without last level   : {rf_without_last:,}',
    f'Last-level dilation     : {last_dilation}',
    f'Final dilation schedule : {DILATION_SIZES}',
    f'Actual receptive field  : {NRECEPT:,} decimated samples',
    f'  = {NRECEPT / 1e5 * 1e3:.2f} ms at 100 kHz',
]))
print()

# ── Stride: should be nsub_dec - nrecept + 1 ──
# With this stride consecutive subsequences overlap by exactly NRECEPT - 1
# samples, so the leading timesteps of a subsequence that lack a full
# receptive field are also covered at the end of the previous subsequence.
NSUB_DEC = NSUB_RAW // DATA_STEP  # 78125
STRIDE_CORRECT_DEC = NSUB_DEC - NRECEPT + 1
STRIDE_CORRECT_RAW = DATA_STEP * STRIDE_CORRECT_DEC

# Previously hard-coded stride: 481_260 // 10 = 48_126 = 78_125 - 30_000 + 1,
# i.e. it was derived from NRECEPT_TARGET instead of the actual NRECEPT.
STRIDE_OLD_RAW = 481_260
STRIDE_OLD_DEC = STRIDE_OLD_RAW // DATA_STEP

print('\n'.join([
    f'nsub (decimated)        : {NSUB_DEC:,}',
    f'Correct stride (dec)    : {STRIDE_CORRECT_DEC:,}  (nsub - nrecept + 1 = {NSUB_DEC} - {NRECEPT} + 1)',
    f'Correct stride (raw)    : {STRIDE_CORRECT_RAW:,}',
    f'Old stride (dec)        : {STRIDE_OLD_DEC:,}  <- ERROR: off by {STRIDE_OLD_DEC - STRIDE_CORRECT_DEC}',
    f'Overlap per subseq      : {NSUB_DEC - STRIDE_CORRECT_DEC:,} decimated samples = {(NSUB_DEC - STRIDE_CORRECT_DEC) / 1e5 * 1e3:.1f} ms',
    f'  (overlap ≈ receptive field - 1 = {NRECEPT - 1:,})',
]))

# Use the correct stride from here on
STRIDE_RAW = STRIDE_CORRECT_RAW

---
## 3. Model architecture inspection

In [None]:
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` timesteps of a (batch, channels, time) tensor.

    ``nn.Conv1d`` pads both ends of the time axis by ``padding``; removing the
    same amount from the right end keeps the output causal and as long as the
    input when ``padding == (kernel_size - 1) * dilation``.

    Args:
        chomp_size: number of trailing timesteps to remove (0 = no-op).
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x[:, :, :-0] is an EMPTY slice, so a zero chomp (e.g. kernel_size=1,
        # hence no padding) must return the input unchanged.
        if self.chomp_size == 0:
            return x
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    """Residual block of two weight-normalised, dilated Conv1d layers.

    Each conv is followed by a Chomp1d (removes the right-hand padding so the
    convolution is causal), ReLU and Dropout. The block returns
    ``relu(net(x) + res)``, where ``res`` is ``x`` itself or a 1x1 Conv1d
    projection of ``x`` when the input and output channel counts differ.

    Args:
        n_inputs: input channels.
        n_outputs: output channels.
        kernel_size: conv kernel width.
        stride: conv stride (TemporalConvNet in this notebook passes 1).
        dilation: conv dilation.
        padding: padding given to both convs and chomped off afterwards;
            TemporalConvNet passes ``(kernel_size - 1) * dilation``.
        dropout: dropout probability after each ReLU.
    """
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        # NOTE(review): torch.nn.utils.weight_norm is deprecated in recent
        # PyTorch in favour of torch.nn.utils.parametrizations.weight_norm.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        # conv1/conv2 are registered both as attributes and inside `net`;
        # named_modules()/parameters() de-duplicate them.
        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        # 1x1 projection so the residual matches n_outputs channels.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()
    def init_weights(self):
        # Re-initialise conv weights from N(0, 0.01).
        # NOTE(review): with the hook-based weight_norm, the effective `weight`
        # of conv1/conv2 is recomputed from `weight_g`/`weight_v` in a forward
        # pre-hook, so these two in-place inits are likely overwritten on the
        # first forward (the downsample conv is not weight-normed, so its init
        # does stick). Presumably copied from the reference TCN; confirm before
        # relying on this initialisation.
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)
    def forward(self, x):
        # x: (batch, n_inputs, time) -> (batch, n_outputs, time)
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    """Sequential stack of TemporalBlocks, one per entry of ``num_channels``.

    Args:
        num_inputs: channel count of the input tensor.
        num_channels: output channels of each level; its length is the depth.
        dilation_size: either a scalar base ``b`` (level ``i`` uses ``b ** i``)
            or an explicit per-level sequence of dilations.
        kernel_size: conv kernel width shared by every level.
        dropout: dropout probability inside each block.
    """

    def __init__(self, num_inputs, num_channels, dilation_size=2, kernel_size=2, dropout=0.2):
        super().__init__()
        depth = len(num_channels)
        dilations = (
            [dilation_size ** lvl for lvl in range(depth)]
            if np.isscalar(dilation_size)
            else dilation_size
        )
        blocks = []
        for lvl, out_ch in enumerate(num_channels):
            in_ch = num_channels[lvl - 1] if lvl > 0 else num_inputs
            dil = dilations[lvl]
            # Pad by (k - 1) * d; the block's Chomp1d trims the right side so
            # each convolution stays causal and length-preserving.
            blocks.append(TemporalBlock(in_ch, out_ch, kernel_size, stride=1,
                                        padding=(kernel_size - 1) * dil,
                                        dilation=dil, dropout=dropout))
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        return self.network(x)

class TCN(nn.Module):
    """TCN backbone followed by a per-timestep Linear read-out and a sigmoid.

    Input is ``(batch, input_size, time)``. The output is
    ``(batch, time)`` when ``output_size == 1`` (the trailing singleton axis
    is squeezed), otherwise ``(batch, time, output_size)``.
    """

    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout, dilation_size):
        super().__init__()
        self.tcn = TemporalConvNet(input_size, num_channels,
                                   kernel_size=kernel_size,
                                   dropout=dropout,
                                   dilation_size=dilation_size)
        self.linear = nn.Linear(num_channels[-1], output_size)

    def forward(self, x):
        # Conv features are (batch, channels, time); Linear acts on the last
        # axis, so channels are moved there before the read-out.
        features = self.tcn(x).transpose(1, 2)
        logits = self.linear(features).squeeze(-1)
        return torch.sigmoid(logits)

In [None]:
channel_sizes = [NHID] * LEVELS
model = TCN(INPUT_CHANNELS, N_CLASSES, channel_sizes,
            kernel_size=KERNEL_SIZE, dropout=DROPOUT,
            dilation_size=DILATION_SIZES).to(DEVICE)

# ── Layer-by-layer summary ──
# Only leaf modules are listed. The third column is the Conv1d configuration
# (kernel / dilation / padding), not an output shape: no forward pass is run
# in this cell, so output shapes are not known here.
print(f'{"Layer":<45s} {"Type":<25s} {"Conv config (k/d/p)":<25s} {"Params":>10s}')
print('─' * 110)
total_params = 0
trainable_params = 0

for name, module in model.named_modules():
    if len(list(module.children())) > 0:
        continue  # skip container modules; their parameters live on the leaves
    n_params = sum(p.numel() for p in module.parameters())
    n_train  = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total_params += n_params
    trainable_params += n_train
    typ = module.__class__.__name__

    # extra info for Conv1d (weight-normed convs are still nn.Conv1d)
    extra = ''
    if isinstance(module, nn.Conv1d):
        extra = f'  k={module.kernel_size[0]} d={module.dilation[0]} p={module.padding[0]}'
    print(f'{name:<45s} {typ:<25s} {extra:<25s} {n_params:>10,}')

print('─' * 110)
# named_modules() de-duplicates modules registered under two names (conv1/conv2
# are both attributes and members of TemporalBlock.net), so the leaf sum must
# equal the model's own parameter count.
assert total_params == sum(p.numel() for p in model.parameters()), \
    f'Leaf parameter sum {total_params} != model parameter count'
print(f'Total parameters     : {total_params:>12,}')
print(f'Trainable parameters : {trainable_params:>12,}')
print(f'Model size (MB)      : {total_params * 4 / 1e6:>12.2f}  (float32)')

In [None]:
# ── Per-level breakdown ──
# Rebuild the receptive field level by level from the dilation schedule and
# cross-check it against NRECEPT; parameter counts come from the live blocks.
print(f'{"Level":<8s} {"In ch":<8s} {"Out ch":<8s} {"Dilation":<10s} '
      f'{"Padding":<10s} {"RF contrib":<12s} {"Cumul RF":<12s} {"Params":>10s}')
print('─' * 85)

cumul_rf = 1
for i in range(LEVELS):
    d = DILATION_SIZES[i]
    in_ch, out_ch = (NHID if i else INPUT_CHANNELS), NHID
    pad = d * (KERNEL_SIZE - 1)       # padding of each conv at this level
    rf_add = 2 * pad                  # two convs per TemporalBlock
    cumul_rf += rf_add

    block = list(model.tcn.network.children())[i]
    bp = sum(p.numel() for p in block.parameters())

    print(f'{i:<8d} {in_ch:<8d} {out_ch:<8d} {d:<10,d} '
          f'{pad:<10,d} {rf_add:<12,d} {cumul_rf:<12,d} {bp:>10,}')

print('─' * 85)
print(f'Final RF = {cumul_rf:,} decimated samples ({cumul_rf/1e5*1e3:.2f} ms at 100 kHz)')
assert cumul_rf == NRECEPT, f'Mismatch: {cumul_rf} != {NRECEPT}'

---
## 4. Dataset & subsequence structure

In [None]:
# Build the dataset with the run.sh windowing values and the stride computed
# above (nsub_dec - nrecept + 1, in raw units). normalize=True is expected to
# standardise each channel (checked by the statistics cells below).
ds = ECEiTCNDataset(
    root=ROOT,
    decimated_root=DECIMATED_ROOT,
    data_step=DATA_STEP,
    nsub=NSUB_RAW,
    stride=STRIDE_RAW,
    Twarn=TWARN,
    baseline_length=BASELINE_LEN,
    normalize=True,
)
ds.summary()

In [None]:
n_total = len(ds)
n_pos = int(ds.seq_has_disrupt.sum())   # subsequences flagged as containing a disruption
n_neg = n_total - n_pos                 # everything else

print('\n'.join([
    f'Total subsequences  : {n_total}',
    f'  disruptive (has label-1 region) : {n_pos} ({n_pos/n_total*100:.1f}%)',
    f'  clear      (all label-0)        : {n_neg} ({n_neg/n_total*100:.1f}%)',
    f'Subsequence length  : {ds._T_sub:,} decimated samples',
    f'  = {ds._T_sub / 1e5 * 1e3:.1f} ms at 100 kHz',
    f'Overlap             : {NSUB_DEC - STRIDE_CORRECT_DEC:,} decimated samples',
    f'  = receptive field - 1 = {NRECEPT - 1:,}',
    '',
    f'pos_weight = {ds.pos_weight:.4f}',
    f'neg_weight = {ds.neg_weight:.4f}',
    f'  ratio pos/neg weight = {ds.pos_weight / ds.neg_weight:.2f}x',
]))

# ── Per-shot breakdown ──
print('\nPer-shot subsequence counts:')
from collections import Counter
shot_counts = Counter(ds.seq_shot_idx.tolist())
for s_idx in sorted(shot_counts):
    count = shot_counts[s_idx]
    has_d = ds.seq_has_disrupt[ds.seq_shot_idx == s_idx]
    n_d = int(has_d.sum())
    shot, split = ds.shots[s_idx], ds.splits[s_idx]
    print(f'  shot {shot} ({split:>5s}): {count:3d} subseqs '
          f'({n_d} disruptive, {count - n_d} clear)')

---
## 5. Input data: shape, scale, statistics

In [None]:
# Sample a few subsequences and inspect
N_INSPECT = min(8, len(ds))
print(f'Inspecting {N_INSPECT} subsequences (indices 0..{N_INSPECT-1}):\n')

print(f'{"idx":<5s} {"X shape":<18s} {"X min":<12s} {"X max":<12s} '
      f'{"X mean":<12s} {"X std":<12s} {"tgt mean":<10s} {"pos%":<8s} '
      f'{"wgt_pos":<10s} {"wgt_neg":<10s}')
print('─' * 120)

X_samples = []
for i in range(N_INSPECT):
    X, target, weight = ds[i]
    X_np = X.numpy()
    tgt_np = target.numpy()
    wgt_np = weight.numpy()
    X_samples.append(X_np)

    # Classify timesteps once with the 0.5 threshold and reuse that mask for
    # the positive fraction and both weight means, so the guards below exactly
    # prevent taking the mean of an empty selection.
    pos_mask = tgt_np > 0.5
    pos_frac = pos_mask.mean()
    wgt_pos = wgt_np[pos_mask].mean() if pos_mask.any() else 0
    wgt_neg = wgt_np[~pos_mask].mean() if not pos_mask.all() else 0

    # 'pos%' is the percentage of timesteps with target > 0.5.
    print(f'{i:<5d} {str(X_np.shape):<18s} {X_np.min():<12.4f} {X_np.max():<12.4f} '
          f'{X_np.mean():<12.6f} {X_np.std():<12.4f} {tgt_np.mean():<10.4f} '
          f'{pos_frac * 100:<8.1f} {wgt_pos:<10.4f} {wgt_neg:<10.4f}')

X_all = np.stack(X_samples)  # (N, 20, 8, T)

In [None]:
# Three panels summarising the inspected inputs X_all (shape noted above as
# (N, 20, 8, T)): value histogram, per-channel mean, per-channel std.
fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# 1. Value distribution
vals = X_all.flatten()
# Only every 100th value goes into the histogram to keep it cheap; the mean
# line and the printed summary at the end use all values.
axes[0].hist(vals[::100], bins=200, color='steelblue', alpha=0.7, edgecolor='none')
axes[0].axvline(vals.mean(), color='red', linestyle='--', label=f'mean={vals.mean():.3f}')
axes[0].axvline(0, color='black', linestyle='-', alpha=0.3)
axes[0].set_title(f'Input value distribution (sampled from {N_INSPECT} subseqs)')
axes[0].set_xlabel('Normalized value')
axes[0].legend()
# View clipped to [-6, 6]; values beyond this range are not visible.
axes[0].set_xlim(-6, 6)

# 2. Per-channel mean
# Average over the sample axis (0) and the time axis (-1) -> one value per grid cell.
ch_means = X_all.mean(axis=(0, -1))  # (20, 8)
im = axes[1].imshow(ch_means, aspect='auto', cmap='RdBu_r')
axes[1].set_title('Per-channel mean (should be ≈ 0)')
axes[1].set_xlabel('Column (0-7)')
axes[1].set_ylabel('Row (0-19)')
plt.colorbar(im, ax=axes[1])

# 3. Per-channel std
ch_stds = X_all.std(axis=(0, -1))  # (20, 8)
im2 = axes[2].imshow(ch_stds, aspect='auto', cmap='viridis')
axes[2].set_title('Per-channel std (should be ≈ 1)')
axes[2].set_xlabel('Column (0-7)')
axes[2].set_ylabel('Row (0-19)')
plt.colorbar(im2, ax=axes[2])

plt.tight_layout()
plt.show()

# Global statistics over every inspected value (not subsampled).
print(f'Overall: mean={vals.mean():.4f}, std={vals.std():.4f}, '
      f'min={vals.min():.2f}, max={vals.max():.2f}')

In [None]:
# ── Visualise a single subsequence: signal + target + weight ──
# NOTE: this cell rebinds the globals X_np / tgt_np / wgt_np / T.
idx_show = 0
X_s, tgt_s, wgt_s = ds[idx_show]
X_np = X_s.numpy()
tgt_np = tgt_s.numpy()
wgt_np = wgt_s.numpy()
T = X_np.shape[-1]
# Decimated samples are at 100 kHz: sample index / 1e5 gives seconds, * 1e3 gives ms.
t_ms = np.arange(T) / 1e5 * 1e3  # ms

fig, axes = plt.subplots(3, 1, figsize=(16, 7), sharex=True,
                         gridspec_kw={'height_ratios': [3, 1, 1]})

# Signal (a few channels)
# (row, column) positions on the 20 x 8 grid; X_np is indexed [row, col, time].
for ch in [(0,0), (5,3), (10,4), (19,7)]:
    axes[0].plot(t_ms, X_np[ch[0], ch[1], :], linewidth=0.3,
                label=f'ch({ch[0]},{ch[1]})', alpha=0.7)
# Shade the first NRECEPT - 1 samples, i.e. the span the forward-pass check
# excludes from the loss. Assumes T >= NRECEPT (otherwise IndexError).
axes[0].axvspan(0, t_ms[NRECEPT-1], alpha=0.1, color='gray', label='Receptive field')
axes[0].set_ylabel('Normalized signal')
axes[0].set_title(f'Subsequence {idx_show}: (20, 8, {T}) → flattened to (160, {T})')
axes[0].legend(fontsize=8, ncol=5)
axes[0].grid(True, alpha=0.2)

# Target
axes[1].fill_between(t_ms, tgt_np, alpha=0.5, color='firebrick')
axes[1].set_ylabel('Target')
axes[1].set_ylim(-0.05, 1.05)
axes[1].grid(True, alpha=0.2)

# Weight
axes[2].plot(t_ms, wgt_np, color='darkorange', linewidth=1)
axes[2].set_ylabel('Weight')
axes[2].set_xlabel('Time (ms)')
axes[2].grid(True, alpha=0.2)

plt.tight_layout()
plt.show()

---
## 6. Batch structure & stratified sampling

In [None]:
# Diagnostic loaders run with num_workers=0 (in-process), not NUM_WORKERS.
loaders = create_loaders(ds, batch_size=BATCH_SIZE, num_workers=0,
                         stratified_train=True)

for split_name, loader in loaders.items():
    n_batches = len(loader.batch_sampler) if hasattr(loader.batch_sampler, '__len__') else len(loader)
    print(f'{split_name:>5s}: {n_batches} batches')

# ── Inspect first few training batches ──
# Fall back to the first available loader if there is no 'train' split.
train_loader = loaders.get('train', list(loaders.values())[0])

print('\nFirst 5 training batches:')
print(f'{"batch":<7s} {"X shape":<25s} {"tgt shape":<18s} '
      f'{"n_pos":<7s} {"n_neg":<7s} {"pos%":<8s} '
      f'{"X range":<20s} {"tgt mean":<10s}')
print('─' * 105)

for b_idx, (X, target, weight) in enumerate(train_loader):
    if b_idx >= 5:
        break
    # A subsequence counts as positive if any of its timesteps has target > 0.
    has_pos = (target.sum(dim=-1) > 0)  # per-subseq
    n_pos = has_pos.sum().item()
    n_neg = X.shape[0] - n_pos
    x_range = f'[{X.min():.2f}, {X.max():.2f}]'

    print(f'{b_idx:<7d} {str(tuple(X.shape)):<25s} {str(tuple(target.shape)):<18s} '
          f'{n_pos:<7d} {n_neg:<7d} {n_pos / X.shape[0] * 100:<8.1f} '
          f'{x_range:<20s} {target.mean():.4f}')

# Scan all batches for balance. Iterating the loader again starts a new pass
# from the first batch, so this already covers every batch; the five batches
# printed above must NOT be added on top (that double-counted them).
print('\nScanning all training batches for class balance...')
all_pos = np.array([
    (target.sum(dim=-1) > 0).sum().item()
    for _, target, _ in train_loader
])

print(f'  Batches: {len(all_pos)}')
if all_pos.size == 0:
    print('  (training loader yielded no batches)')
else:
    balanced = all_pos == BATCH_SIZE // 2
    print(f'  Pos subseqs per batch: mean={all_pos.mean():.1f}, '
          f'min={all_pos.min()}, max={all_pos.max()} '
          f'(target={BATCH_SIZE//2})')
    print(f'  Perfectly balanced batches: '
          f'{balanced.sum()}/{len(all_pos)} '
          f'({balanced.mean()*100:.0f}%)')

---
## 7. Forward pass sanity check

In [None]:
model.eval()  # disable dropout so the check is deterministic
X_sample, tgt_sample, wgt_sample = ds[0]
X_in = X_sample.unsqueeze(0).to(DEVICE)             # add batch axis: (1, 20, 8, T)
B = X_in.shape[0]
X_flat = X_in.view(B, -1, X_in.shape[-1])           # merge the grid: (1, 160, T)

print(f'Input shape  : {tuple(X_in.shape)}  →  flattened: {tuple(X_flat.shape)}')

with torch.no_grad():
    output = model(X_flat)                            # (1, T)

print('\n'.join([
    f'Output shape : {tuple(output.shape)}',
    f'Output range : [{output.min():.4f}, {output.max():.4f}]',
    f'Output mean  : {output.mean():.4f}  (untrained → should be ≈ 0.5)',
    '',
]))

# ── Loss region ──
# The first NRECEPT - 1 outputs lack a full receptive field of real input and
# are excluded from the loss.
T_out = output.shape[-1]
T_valid = T_out - (NRECEPT - 1)
print('\n'.join([
    f'Total timesteps       : {T_out:,}',
    f'Excluded (receptive f.): {NRECEPT - 1:,} ({(NRECEPT-1)/T_out*100:.1f}%)',
    f'Valid for loss         : {T_valid:,} ({T_valid/T_out*100:.1f}%)',
    '',
]))

# ── Check untrained loss ──
out_v = output[:, NRECEPT-1:].cpu()
tgt_v, wgt_v = (t[NRECEPT-1:].unsqueeze(0) for t in (tgt_sample, wgt_sample))
loss_val = F.binary_cross_entropy(out_v, tgt_v, weight=wgt_v).item()
loss_unweighted = F.binary_cross_entropy(out_v, tgt_v).item()
print('\n'.join([
    f'Untrained weighted BCE   : {loss_val:.4f}',
    f'Untrained unweighted BCE : {loss_unweighted:.4f}',
    f'Random baseline (ln2)    : {np.log(2):.4f}',
]))

---
## 8. Label off-by-one audit

**disruptcnn** uses `(disrupt_idxi - start_idxi + 1) / data_step` to compute the
disruption boundary — the `+1` shifts the label by one raw sample.

**Our code** uses `dl // data_step` (no `+1`).

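The difference is **at most 1 decimated timestep (10 μs)**. With $n$ = `dl` and $s$ = `data_step`,
$\lfloor (n+1)/s \rfloor - \lfloor n/s \rfloor$ is 1 only when $n+1$ is a multiple of $s$ and 0 otherwise.
So the `+1` moves the decimated boundary for only some subsequences. It is negligible for the model,
but documented here for completeness.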

In [None]:
# Pick the first subsequence whose disruption boundary falls strictly inside it.
trans_idx = np.where((ds.seq_disrupt_local > 0) &
                     (ds.seq_disrupt_local < ds._data_nsub))[0]
if len(trans_idx) == 0:
    print('No transition subsequences found.')
else:
    i = trans_idx[0]
    dl = int(ds.seq_disrupt_local[i])
    step = ds._step_in_getitem

    # floor((dl + 1) / step) - floor(dl / step) is 1 only when dl + 1 is a
    # multiple of step, otherwise 0, so the printed difference may be 0.
    # NOTE(review): disruptcnn applies the +1 in RAW samples before dividing by
    # data_step; if `step` here is not the raw->decimated factor (e.g. the data
    # file is already decimated), this comparison is in different units — confirm.
    our_boundary = dl // step
    disruptcnn_boundary = (dl + 1) // step  # with the +1

    # One decimated sample = 10 raw samples at 1 MHz = 10 μs.
    print('\n'.join([
        f'Subsequence {i}:',
        f'  disrupt_local (data-file space) : {dl}',
        f'  step_in_getitem                 : {step}',
        f'  Our label boundary (decimated)  : {our_boundary}',
        f'  disruptcnn boundary (decimated)  : {disruptcnn_boundary}',
        f'  Difference                      : {disruptcnn_boundary - our_boundary} '
        f'decimated sample(s) = {(disruptcnn_boundary - our_boundary) * 10} μs',
    ]))

---
## 9. Learning rate schedule comparison

In [None]:
import torch.optim as optim

n_train_batches = len(train_loader)
WARMUP_MULT = 8  # warm-up starts at lr / WARMUP_MULT and ramps linearly to lr

configs = [
    {'name': 'disruptcnn (lr=0.5, warmup=5ep, patience=10)',
     'lr': 0.5, 'warmup_epochs': 5, 'patience': 10},
    {'name': 'Ours old (lr=2e-3, warmup=3ep, patience=5)',
     'lr': 2e-3, 'warmup_epochs': 3, 'patience': 5},
    {'name': 'Ours corrected (lr=0.5, warmup=5ep, patience=10)',
     'lr': 0.5, 'warmup_epochs': 5, 'patience': 10},
]

fig, ax = plt.subplots(figsize=(14, 5))
colors = ['firebrick', 'steelblue', 'seagreen']
# 'disruptcnn' and 'Ours corrected' have identical settings, so their curves
# coincide exactly; distinct dash patterns keep both visible.
linestyles = ['-', ':', '--']

for cfg, color, ls in zip(configs, colors, linestyles):
    # simulate warmup only (plateau needs loss, skip)
    warmup_iters = cfg['warmup_epochs'] * n_train_batches
    iters = np.arange(0, 20 * n_train_batches + 1)
    # Linear ramp from 1/WARMUP_MULT to 1 over warmup_iters, then flat at 1.
    ramp = (1 - 1/WARMUP_MULT) / max(warmup_iters, 1) * iters + 1/WARMUP_MULT
    lrs = cfg['lr'] * np.where(iters < warmup_iters, ramp, 1.0)

    # max(..., 1) avoids dividing by zero when the training loader is empty.
    epochs_ax = iters / max(n_train_batches, 1)
    ax.plot(epochs_ax, lrs, label=cfg['name'], color=color,
            linestyle=ls, linewidth=1.5)

ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('LR warmup comparison (plateau decay not simulated)')
ax.set_yscale('log')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)
ax.set_xlim(0, 20)
plt.tight_layout()
plt.show()

# Summary derived from `configs`, so it cannot drift from the plotted curves.
for cfg in configs:
    print(f"{cfg['name']}: start LR = {cfg['lr'] / WARMUP_MULT:.6g}, "
          f"peak LR = {cfg['lr']:g} at epoch {cfg['warmup_epochs']}")

---
## 10. Summary of fixes needed for `train_tcn.ipynb`

| # | Issue | Fix |
|---|-------|-----|
| 1 | **LR = 2e-3** (should be 0.5) | `LR = 0.5` |
| 2 | **STRIDE hardcoded** (off by 17 dec samples) | Compute dynamically: `STRIDE = (NSUB // DATA_STEP - NRECEPT + 1) * DATA_STEP` |
| 3 | **Warmup = 3 epochs** (should be 5) | `WARMUP_EPOCHS = 5` |
| 4 | **Plateau patience = 5** (should be 10) | Remove explicit `patience=5` (use default 10) |
| 5 | **Label +1 offset** missing | Add `+1` in `__getitem__`: `d = min((dl + 1) // step, T)` |
| 6 | **Sampling** differs | Acceptable — both achieve balanced batches |

In [None]:
# Values to carry over into train_tcn.ipynb; computed ones come from the cells above.
print('Corrected values for train_tcn.ipynb:')
print(f'  LR             = {LR_DISRUPTCNN}')
print(f'  STRIDE         = {STRIDE_CORRECT_RAW}  (computed: nsub_dec - nrecept + 1 = {STRIDE_CORRECT_DEC}, × {DATA_STEP})')
print('  WARMUP_EPOCHS  = 5')
print('  Plateau patience = 10 (PyTorch default)')
print(f'  NRECEPT        = {NRECEPT}  (computed from model)')