# TCN Disruption Prediction — Training

Per-timestep binary classification using a Temporal Convolutional Network.
Adapted from [disruptcnn](https://github.com/rmchurch/disruptcnn) (Churchill et al. 2019).

Uses `ECEiTCNDataset` with pre-decimated data for fast I/O.

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import weight_norm

from dataset_ecei_tcn import ECEiTCNDataset, create_loaders

# Prefer CUDA when a GPU is visible, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {DEVICE}')

## 1. Configuration

In [None]:
# ── Data ──────────────────────────────────────────────────────────────
# Raw and pre-decimated shot directories on the shared project filesystem.
ROOT = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt'
DECIMATED_ROOT = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt_decimated'

# Lengths below are counted in raw samples (1 MHz, per the original comments).
DATA_STEP = 10                 # decimation factor handed to the dataset
TWARN = 300_000                # 300 ms at 1 MHz
BASELINE_LEN = 40_000          # 40 ms (matches disruptcnn)
NSUB = 781_250                 # ~781 ms (matches disruptcnn)
STRIDE = NSUB - 299_990        # == 481_260; overlap by receptive field

BATCH_SIZE = 12
NUM_WORKERS = 4

# ── Model (matches disruptcnn run.sh) ─────────────────────────────────
INPUT_CHANNELS = 20 * 8        # 20 × 8 flattened (== 160)
N_CLASSES = 1                  # binary per-timestep
LEVELS = 4
NHID = 80                      # hidden channels per level
KERNEL_SIZE = 15
DILATION_BASE = 10
DROPOUT = 0.1

# ── Training ──────────────────────────────────────────────────────────
EPOCHS = 50
LR = 2e-3
CLIP = 0.3                     # max total gradient L2 norm (clip_grad_norm_)
WARMUP_EPOCHS = 3
WARMUP_FACTOR = 8              # start LR = LR / WARMUP_FACTOR

CHECKPOINT_DIR = Path('checkpoints_tcn')
CHECKPOINT_DIR.mkdir(exist_ok=True)

## 2. TCN Model

Temporal Convolutional Network from [Bai et al. 2018](https://arxiv.org/abs/1803.01271),
with modifications for arbitrary dilation factors (from disruptcnn).

In [None]:
class Chomp1d(nn.Module):
    """Remove trailing padding to enforce causality.

    ``TemporalBlock`` pads its convolutions by ``padding`` samples and then
    trims the same number of samples from the right with this module.

    Args:
        chomp_size: number of trailing time steps to drop (>= 0).
    """
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """x: (N, C, L) -> (N, C, L - chomp_size), contiguous."""
        # x[:, :, :-0] is an EMPTY slice, so a zero chomp (kernel_size == 1,
        # i.e. zero padding) must pass the sequence through unchanged.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block of two weight-normalised dilated convolutions.

    Each conv is padded by ``padding`` and chomped on the right by ``Chomp1d``.
    With ``padding = (kernel_size - 1) * dilation`` and stride 1 (as
    ``TemporalConvNet`` builds it), each output step only depends on current
    and past inputs. When the channel counts differ, a 1x1 conv projects the
    residual path.

    Args:
        n_inputs, n_outputs: input / output channel counts.
        kernel_size, stride, dilation: passed to both ``nn.Conv1d`` layers.
        padding: conv padding, trimmed again from the right by ``Chomp1d``.
        dropout: dropout probability applied after each ReLU.
    """
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        self.conv1 = weight_norm(nn.Conv1d(
            n_inputs, n_outputs, kernel_size,
            stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(
            n_outputs, n_outputs, kernel_size,
            stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Initialise conv weights from N(0, 0.01).

        ``conv1``/``conv2`` are wrapped by ``weight_norm``, which rebuilds
        ``weight`` as ``weight_g * weight_v / ||weight_v||`` before every
        forward pass. Writing to ``conv.weight.data`` directly is therefore
        discarded on the first forward, and the convs silently keep PyTorch's
        default init.

        Instead, this initialises the direction ``weight_v`` and sets the
        magnitude ``weight_g`` to its per-output-channel norm, so the
        effective weight equals the N(0, 0.01) draw.
        """
        for conv in (self.conv1, self.conv2):
            conv.weight_v.data.normal_(0, 0.01)
            # weight_norm(dim=0) stores one magnitude per output channel: (C_out, 1, 1)
            conv.weight_g.data.copy_(
                conv.weight_v.data.norm(2, dim=(1, 2), keepdim=True))
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """x: (N, C_in, L) -> (N, C_out, L) via ReLU(branch(x) + residual(x))."""
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    """Sequential stack of ``TemporalBlock``s, one per entry of ``num_channels``.

    Args:
        num_inputs: channel count of the input sequence.
        num_channels: output channels of each level; its length is the depth.
        dilation_size: scalar base ``b`` (level ``i`` then uses ``b ** i``) or an
            explicit per-level sequence of dilation factors.
        kernel_size: temporal width of every convolution.
        dropout: dropout probability inside each block.
    """
    def __init__(self, num_inputs, num_channels, dilation_size=2,
                 kernel_size=2, dropout=0.2):
        super().__init__()
        depth = len(num_channels)
        dilations = ([dilation_size ** lvl for lvl in range(depth)]
                     if np.isscalar(dilation_size) else dilation_size)
        # level 0 reads the raw input; every later level reads its predecessor
        widths_in = [num_inputs, *num_channels[:-1]]
        self.network = nn.Sequential(*(
            TemporalBlock(c_in, c_out, kernel_size, stride=1,
                          padding=(kernel_size - 1) * dilations[lvl],
                          dilation=dilations[lvl], dropout=dropout)
            for lvl, (c_in, c_out) in enumerate(zip(widths_in, num_channels))
        ))

    def forward(self, x):
        """Apply every level in order; x: (N, num_inputs, L)."""
        return self.network(x)


class TCN(nn.Module):
    """``TemporalConvNet`` backbone + per-time-step linear read-out + sigmoid.

    Args:
        input_size: number of input channels ``C_in``.
        output_size: outputs per time step (1 for binary classification).
        num_channels: hidden channels of each TCN level.
        kernel_size, dropout, dilation_size: forwarded to ``TemporalConvNet``.
    """
    def __init__(self, input_size, output_size, num_channels,
                 kernel_size, dropout, dilation_size):
        super().__init__()
        self.tcn = TemporalConvNet(input_size, num_channels,
                                   dilation_size=dilation_size,
                                   kernel_size=kernel_size,
                                   dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)

    def forward(self, x):
        """x: (N, C_in, L) → probabilities (N, L) when ``output_size == 1``."""
        features = self.tcn(x).transpose(1, 2)   # (N, L, C_hid)
        logits = self.linear(features)           # (N, L, output_size)
        return logits.squeeze(-1).sigmoid()

## 3. Build model & compute receptive field

In [None]:
def calc_receptive_field(kernel_size, dilation_sizes):
    """Return the receptive field (in samples) of a TCN.

    Each level holds two convolutions of width ``kernel_size`` sharing that
    level's dilation ``d``; each one widens the field by ``(kernel_size - 1) * d``.
    """
    span_per_conv = (kernel_size - 1) * int(np.sum(dilation_sizes))
    return 2 * span_per_conv + 1


def build_model(input_channels, n_classes, levels, nhid,
                kernel_size, dilation_base, dropout, nrecept_target=30_000,
                data_step=10):
    """Build TCN and compute actual receptive field (matches disruptcnn logic).

    The first ``levels - 1`` levels use dilations ``dilation_base ** i``. The
    last level's dilation is rounded up (and is at least 1) so that the total
    receptive field reaches ``nrecept_target``.

    Args:
        input_channels: channels of the flattened input.
        n_classes: outputs per time step.
        levels: number of TCN levels (>= 1).
        nhid: hidden channels per level.
        kernel_size: conv kernel width (>= 2).
        dilation_base: dilation growth factor of the first ``levels - 1`` levels.
        dropout: dropout probability inside each block.
        nrecept_target: desired receptive field, in (decimated) samples.
        data_step: decimation factor relative to the 1 MHz raw rate; only used
            to report the receptive field in milliseconds.

    Returns:
        ``(model, nrecept)``: the (CPU) ``TCN`` and its receptive field in samples.

    Raises:
        ValueError: if ``levels < 1`` or ``kernel_size < 2``.
    """
    if levels < 1:
        raise ValueError(f'levels must be >= 1, got {levels}')
    if kernel_size < 2:
        # kernel_size == 1 has no temporal extent: the last-level dilation
        # below would divide by zero.
        raise ValueError(f'kernel_size must be >= 2, got {kernel_size}')

    channel_sizes = [nhid] * levels

    # adjust last-level dilation so receptive field ≈ nrecept_target
    base_dilations = [dilation_base ** i for i in range(levels - 1)]
    rf_without_last = calc_receptive_field(kernel_size, base_dilations)
    last_dilation = int(np.ceil(
        (nrecept_target - rf_without_last) / (2.0 * (kernel_size - 1))))
    last_dilation = max(last_dilation, 1)
    dilation_sizes = base_dilations + [last_dilation]

    nrecept = calc_receptive_field(kernel_size, dilation_sizes)

    model = TCN(input_channels, n_classes, channel_sizes,
                kernel_size=kernel_size, dropout=dropout,
                dilation_size=dilation_sizes)

    sample_rate_hz = 1e6 / data_step   # decimated rate, assuming 1 MHz raw data
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Dilation sizes : {dilation_sizes}')
    print(f'Receptive field: {nrecept:,} samples '
          f'({nrecept / sample_rate_hz * 1e3:.1f} ms at {sample_rate_hz / 1e3:g} kHz)')
    print(f'Parameters     : {n_params:,}')
    return model, nrecept


model, NRECEPT = build_model(
    input_channels=INPUT_CHANNELS,
    n_classes=N_CLASSES,
    levels=LEVELS,
    nhid=NHID,
    kernel_size=KERNEL_SIZE,
    dilation_base=DILATION_BASE,
    dropout=DROPOUT,
    nrecept_target=30_000,
)
model = model.to(DEVICE)   # move parameters to the training device

## 4. Data

In [None]:
ds = ECEiTCNDataset(
    root=ROOT,
    decimated_root=DECIMATED_ROOT,
    Twarn=TWARN,
    baseline_length=BASELINE_LEN,
    data_step=DATA_STEP,
    nsub=NSUB,
    stride=STRIDE,
    normalize=True,
)
ds.summary()

# One DataLoader per split; report subsequence and batch counts for each.
loaders = create_loaders(ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
for name, loader in loaders.items():
    n_subseqs, n_batches = len(loader.dataset), len(loader)
    print(f'  {name:>5s}: {n_subseqs} subseqs, {n_batches} batches')

## 5. Training loop

In [None]:
LOG_EVERY_N_BATCHES = 5   # print a line every N batches during training


def _grad_norm(model):
    """Total L2 norm of all gradients."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.data.norm(2).item() ** 2
    return total ** 0.5


def train_one_epoch(model, loader, optimizer, nrecept, device, epoch,
                    n_epochs, clip=0.3, log_every=LOG_EVERY_N_BATCHES):
    """Train for one epoch with verbose per-batch logging.

    The BCE loss only covers outputs from index ``nrecept - 1`` onward, i.e.
    time steps with a full receptive field.

    Args:
        model: expected to map ``(B, C, L)`` inputs to per-step probabilities
            ``(B, L)`` (as ``TCN`` does).
        loader: yields ``(X, target, weight)``. ``X`` is flattened to
            ``(B, -1, L)`` before the forward pass.
        optimizer: zeroed and stepped once per batch.
        nrecept: receptive field in samples; earlier outputs are excluded.
        device: device each batch is moved to.
        epoch, n_epochs: only used in progress-bar / log labels.
        clip: max total gradient L2 norm; ``<= 0`` disables clipping.
        log_every: write a detailed line every ``log_every`` batches (and on
            the last batch); ``<= 0`` logs only the last batch.

    Returns:
        dict with the mean batch ``loss``, ``accuracy`` at threshold 0.5, the
        positive-target fraction ``pos_frac``, and per-batch ``batch_losses``.
    """
    model.train()
    n_batches = len(loader)

    running_loss = 0.0
    running_correct = 0
    running_total = 0
    running_pos = 0
    batch_losses = []

    pbar = tqdm(enumerate(loader), total=n_batches,
                desc=f'Epoch {epoch}/{n_epochs} [train]',
                leave=True, dynamic_ncols=True)

    for batch_idx, (X, target, weight) in pbar:
        B = X.shape[0]
        # reshape (unlike view) also accepts non-contiguous batches
        X = X.reshape(B, -1, X.shape[-1]).to(device)
        target = target.to(device)
        weight = weight.to(device)

        optimizer.zero_grad()
        output = model(X)

        # loss only on valid region (after receptive field)
        out_v = output[:, nrecept - 1:]
        tgt_v = target[:, nrecept - 1:]
        wgt_v = weight[:, nrecept - 1:]

        loss = F.binary_cross_entropy(out_v, tgt_v, weight=wgt_v)
        loss.backward()

        log_now = ((log_every > 0 and (batch_idx + 1) % log_every == 0)
                   or (batch_idx + 1) == n_batches)

        # Grad norms are only reported on logged batches, so only pay for them
        # there. clip_grad_norm_ already returns the pre-clip total norm.
        grad_norm_before = None
        if clip > 0:
            grad_norm_before = nn.utils.clip_grad_norm_(model.parameters(), clip)
        if log_now:
            grad_norm_before = (_grad_norm(model) if grad_norm_before is None
                                else float(grad_norm_before))
            grad_norm_after = _grad_norm(model)

        optimizer.step()

        # ── accumulate stats ──
        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        running_loss += batch_loss

        with torch.no_grad():
            pred = (out_v >= 0.5).float()
            running_correct += (pred == tgt_v).sum().item()
            running_total += tgt_v.numel()
            running_pos += tgt_v.sum().item()

        # ── update progress bar ──
        avg_loss = running_loss / (batch_idx + 1)
        avg_acc = running_correct / max(running_total, 1)
        pos_frac = running_pos / max(running_total, 1)
        lr_now = optimizer.param_groups[0]['lr']

        pbar.set_postfix({
            'loss': f'{avg_loss:.4e}',
            'acc': f'{avg_acc:.3f}',
            'pos%': f'{pos_frac:.3f}',
            'lr': f'{lr_now:.2e}',
        })

        # ── detailed log every N batches ──
        if log_now:
            tqdm.write(
                f'  [{epoch}/{n_epochs}] '
                f'batch {batch_idx+1:>4d}/{n_batches}  '
                f'loss={batch_loss:.4e}  '
                f'avg_loss={avg_loss:.4e}  '
                f'acc={avg_acc:.4f}  '
                f'pos%={pos_frac:.3f}  '
                f'|grad|={grad_norm_before:.3f}->{grad_norm_after:.3f}  '
                f'lr={lr_now:.2e}'
            )

    epoch_loss = running_loss / max(n_batches, 1)
    epoch_acc = running_correct / max(running_total, 1)
    return {
        'loss': epoch_loss,
        'accuracy': epoch_acc,
        'pos_frac': running_pos / max(running_total, 1),
        'batch_losses': batch_losses,
    }


@torch.no_grad()
def evaluate(model, loader, nrecept, device, epoch=0, n_epochs=0,
             split_name='val', thresholds=None):
    """Compute loss, accuracy, and F1 over a loader with verbose output.

    Only outputs from index ``nrecept - 1`` onward are scored, matching the
    training loss. Precision/recall/F1 are evaluated at every threshold, and
    the threshold with the best F1 is reported.

    Args:
        model: expected to map ``(B, C, L)`` to per-step probabilities ``(B, L)``.
        loader: yields ``(X, target, weight)`` batches.
        nrecept: receptive field in samples.
        device: device batches (and the confusion counters) live on.
        epoch, n_epochs, split_name: only used in progress / log labels.
        thresholds: probability thresholds to sweep (default 0.05..0.95, 19 steps).

    Returns:
        dict of plain Python numbers: ``loss``, ``accuracy`` (at the best
        threshold), ``acc_at_50``, ``f1``, ``precision``, ``recall``,
        ``threshold``, ``pos_frac``, ``n_timesteps``. Python floats (not NumPy
        scalars) keep checkpoints built from them loadable with
        ``torch.load(weights_only=True)``.
    """
    model.eval()
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.95, 19)
    n_th = len(thresholds)

    total_loss = 0.0
    n_batches = len(loader)
    total = 0
    correct_50 = 0
    total_pos = 0
    # Confusion counts stay on-device and are synced once at the end; TN/FN are
    # derived exactly from the per-class totals.
    tp_acc = torch.zeros(n_th, dtype=torch.long, device=device)
    fp_acc = torch.zeros(n_th, dtype=torch.long, device=device)
    n_pos_acc = torch.zeros((), dtype=torch.long, device=device)
    n_neg_acc = torch.zeros((), dtype=torch.long, device=device)

    pbar = tqdm(enumerate(loader), total=n_batches,
                desc=f'Epoch {epoch}/{n_epochs} [{split_name}]',
                leave=True, dynamic_ncols=True)

    for batch_idx, (X, target, weight) in pbar:
        B = X.shape[0]
        X = X.reshape(B, -1, X.shape[-1]).to(device)
        target = target.to(device)
        weight = weight.to(device)

        output = model(X)
        out_v = output[:, nrecept - 1:]
        tgt_v = target[:, nrecept - 1:]
        wgt_v = weight[:, nrecept - 1:]

        loss = F.binary_cross_entropy(out_v, tgt_v, weight=wgt_v)
        total_loss += loss.item()
        total += tgt_v.numel()
        total_pos += tgt_v.sum().item()

        pred_50 = (out_v >= 0.5).float()
        correct_50 += (pred_50 == tgt_v).sum().item()

        is_pos = tgt_v == 1
        is_neg = tgt_v == 0
        n_pos_acc += is_pos.sum()
        n_neg_acc += is_neg.sum()
        tp_batch, fp_batch = [], []
        for th in thresholds:
            pred = out_v >= th
            tp_batch.append((pred & is_pos).sum())
            fp_batch.append((pred & is_neg).sum())
        tp_acc += torch.stack(tp_batch)
        fp_acc += torch.stack(fp_batch)

        avg_loss = total_loss / (batch_idx + 1)
        avg_acc  = correct_50 / max(total, 1)
        pbar.set_postfix({'loss': f'{avg_loss:.4e}', 'acc@0.5': f'{avg_acc:.3f}'})

    avg_loss = total_loss / max(n_batches, 1)

    TPs = tp_acc.cpu().numpy().astype(np.float64)
    FPs = fp_acc.cpu().numpy().astype(np.float64)
    FNs = n_pos_acc.item() - TPs
    TNs = n_neg_acc.item() - FPs

    # best F1 across thresholds
    precision = TPs / (TPs + FPs + 1e-10)
    recall    = TPs / (TPs + FNs + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    best_idx = int(np.argmax(f1))
    accuracy  = float((TPs[best_idx] + TNs[best_idx]) / max(total, 1))

    metrics = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'acc_at_50': correct_50 / max(total, 1),
        'f1': float(f1[best_idx]),
        'precision': float(precision[best_idx]),
        'recall': float(recall[best_idx]),
        'threshold': float(thresholds[best_idx]),
        'pos_frac': total_pos / max(total, 1),
        'n_timesteps': total,
    }

    tqdm.write(
        f'  [{split_name}] '
        f'loss={avg_loss:.4e}  '
        f'acc@best_th={accuracy:.4f}  acc@0.5={metrics["acc_at_50"]:.4f}  '
        f'F1={f1[best_idx]:.4f}  P={precision[best_idx]:.4f}  R={recall[best_idx]:.4f}  '
        f'th={thresholds[best_idx]:.2f}  '
        f'pos%={metrics["pos_frac"]:.3f}  '
        f'n_ts={total:,}'
    )
    return metrics

In [None]:
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, nesterov=True)

# linear warmup then ReduceLROnPlateau
train_split = 'train' if 'train' in loaders else list(loaders.keys())[0]
val_split   = 'test'  if 'test'  in loaders else train_split

n_train_batches = len(loaders[train_split])
warmup_iters = WARMUP_EPOCHS * n_train_batches


def warmup_lambda(it):
    """LR multiplier: linear ramp from 1/WARMUP_FACTOR (it=0) to 1 (it=warmup_iters).

    Returns 1.0 from ``warmup_iters`` on. LambdaLR applies ``warmup_lambda(0)``
    as soon as it is constructed, so with ``WARMUP_EPOCHS == 0`` a bare ramp
    would leave the LR stuck at ``LR / WARMUP_FACTOR`` for the whole run.
    """
    if it >= warmup_iters:
        return 1.0
    return (1 - 1/WARMUP_FACTOR) / warmup_iters * it + 1/WARMUP_FACTOR


scheduler_warmup  = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)
scheduler_plateau = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

# ── history dict ──
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_f1': [], 'val_acc': [], 'val_precision': [],
    'val_recall': [], 'val_threshold': [], 'lr': [],
}
# Start below any achievable F1 (>= 0) so the first epoch always writes
# best.pt, even while the model still predicts no positives (F1 == 0).
best_f1 = -1.0
global_step = 0

# ── pretty header ──
print('=' * 110)
print(f'  Training config:  {EPOCHS} epochs | batch={BATCH_SIZE} | '
      f'lr={LR} | clip={CLIP} | warmup={WARMUP_EPOCHS} ep')
print(f'  Splits:  train={len(loaders[train_split].dataset)} subseqs '
      f'({len(loaders[train_split])} batches) | '
      f'val={len(loaders[val_split].dataset)} subseqs '
      f'({len(loaders[val_split])} batches)')
print(f'  Receptive field:  {NRECEPT:,} samples')
print(f'  Device: {DEVICE}')
print('=' * 110)

for epoch in range(1, EPOCHS + 1):
    t0 = time.perf_counter()
    print(f'\n{"─" * 110}')
    print(f'  EPOCH {epoch}/{EPOCHS}   (global step = {global_step:,},'
          f'  lr = {optimizer.param_groups[0]["lr"]:.2e})')
    print(f'{"─" * 110}')

    # ═══════════════════════════════════ TRAIN ═══════════════════════════
    train_metrics = train_one_epoch(
        model, loaders[train_split], optimizer, NRECEPT, DEVICE,
        epoch=epoch, n_epochs=EPOCHS, clip=CLIP,
    )
    global_step += n_train_batches

    # Warmup is applied in coarse per-epoch jumps: the LR is advanced to the
    # value for `global_step` only after each epoch, not per batch. Stepping
    # once per training batch replaces the deprecated `scheduler.step(epoch)`
    # call and lands on the same LR.
    if global_step <= warmup_iters:
        for _ in range(n_train_batches):
            scheduler_warmup.step()

    # ═══════════════════════════════════ VALIDATE ═══════════════════════
    val_metrics = evaluate(
        model, loaders[val_split], NRECEPT, DEVICE,
        epoch=epoch, n_epochs=EPOCHS, split_name='val',
    )

    if global_step > warmup_iters:
        scheduler_plateau.step(val_metrics['loss'])

    lr_now = optimizer.param_groups[0]['lr']
    elapsed = time.perf_counter() - t0

    # ── record history ──
    history['train_loss'].append(train_metrics['loss'])
    history['train_acc'].append(train_metrics['accuracy'])
    history['val_loss'].append(val_metrics['loss'])
    history['val_f1'].append(val_metrics['f1'])
    history['val_acc'].append(val_metrics['accuracy'])
    history['val_precision'].append(val_metrics['precision'])
    history['val_recall'].append(val_metrics['recall'])
    history['val_threshold'].append(val_metrics['threshold'])
    history['lr'].append(lr_now)

    # ── checkpoint ──
    # Metadata is stored as plain Python floats: NumPy scalars make the file
    # unloadable with torch.load(weights_only=True), the PyTorch >= 2.6 default.
    is_best = val_metrics['f1'] > best_f1
    if is_best:
        best_f1 = float(val_metrics['f1'])
        torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_f1': best_f1,
            'threshold': float(val_metrics['threshold']),
            'nrecept': NRECEPT,
        }, CHECKPOINT_DIR / 'best.pt')

    # last.pt also carries scheduler state so training can be resumed.
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler_warmup': scheduler_warmup.state_dict(),
        'scheduler_plateau': scheduler_plateau.state_dict(),
        'best_f1': best_f1,
        'nrecept': NRECEPT,
    }, CHECKPOINT_DIR / 'last.pt')

    # ── epoch summary ──
    star = '  ** NEW BEST **' if is_best else ''
    print(f'\n  EPOCH {epoch}/{EPOCHS} SUMMARY  ({elapsed:.1f}s){star}')
    print(f'  ┌─────────────────────────────────────────────────────────────┐')
    print(f'  │  Train loss    : {train_metrics["loss"]:.6e}                │')
    print(f'  │  Train acc@0.5 : {train_metrics["accuracy"]:.4f}  '
          f'  pos%: {train_metrics["pos_frac"]:.3f}           │')
    print(f'  │  ─────────────────────────────────────────────────────────  │')
    print(f'  │  Val   loss    : {val_metrics["loss"]:.6e}                │')
    print(f'  │  Val   acc@th  : {val_metrics["accuracy"]:.4f}  '
          f'  acc@0.5: {val_metrics["acc_at_50"]:.4f}          │')
    print(f'  │  Val   F1      : {val_metrics["f1"]:.4f}  '
          f'  P={val_metrics["precision"]:.4f}  R={val_metrics["recall"]:.4f}  '
          f'th={val_metrics["threshold"]:.2f}  │')
    print(f'  │  LR            : {lr_now:.2e}                           │')
    print(f'  │  Best F1 so far: {best_f1:.4f}                              │')
    print(f'  └─────────────────────────────────────────────────────────────┘')

print(f'\n{"═" * 110}')
print(f'  TRAINING COMPLETE — {EPOCHS} epochs, best val F1 = {best_f1:.4f}')
print(f'{"═" * 110}')

## 6. Training curves

In [None]:
def _finish_axis(ax, title, ylabel, *, xlabel=None, legend=True, ylim=None, log_y=False):
    """Apply the shared labels / legend / grid / limits of the summary figure."""
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if legend:
        ax.legend()
    ax.grid(True, alpha=0.3)
    if ylim is not None:
        ax.set_ylim(*ylim)
    if log_y:
        ax.set_yscale('log')


fig, axes = plt.subplots(2, 3, figsize=(20, 8))
epochs_range = np.arange(1, len(history['train_loss']) + 1)
ax_loss, ax_acc, ax_f1 = axes[0]
ax_pr, ax_th, ax_lr = axes[1]

# ── Row 1: loss, accuracy, F1 ──
for key, label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    ax_loss.plot(epochs_range, history[key], label=label, linewidth=1.5)
_finish_axis(ax_loss, 'Loss', 'BCE Loss')

for key, label in (('train_acc', 'Train'), ('val_acc', 'Val (best th)')):
    ax_acc.plot(epochs_range, history[key], label=label, linewidth=1.5)
_finish_axis(ax_acc, 'Accuracy', 'Accuracy', ylim=(0, 1))

ax_f1.plot(epochs_range, history['val_f1'], color='firebrick', linewidth=2,
           label='Val F1')
_finish_axis(ax_f1, f'Validation F1  (best = {best_f1:.4f})', 'F1 Score', ylim=(0, 1))

# ── Row 2: precision/recall, chosen threshold, learning rate ──
for key, label, colour in (('val_precision', 'Precision', 'teal'),
                           ('val_recall', 'Recall', 'darkorange')):
    ax_pr.plot(epochs_range, history[key], label=label,
               linewidth=1.5, color=colour)
_finish_axis(ax_pr, 'Val Precision & Recall', 'Score', xlabel='Epoch', ylim=(0, 1))

ax_th.plot(epochs_range, history['val_threshold'], color='purple',
           linewidth=1.5, marker='.', markersize=4)
_finish_axis(ax_th, 'Best threshold per epoch', 'Threshold', xlabel='Epoch',
             legend=False, ylim=(0, 1))

ax_lr.plot(epochs_range, history['lr'], color='darkorange', linewidth=1.5)
_finish_axis(ax_lr, 'Learning Rate', 'LR', xlabel='Epoch', legend=False, log_y=True)

plt.suptitle(f'Training Summary — Best val F1 = {best_f1:.4f}',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 7. Visualise predictions on validation samples

In [None]:
# Load best checkpoint; fall back to last.pt if no best.pt was written.
ckpt_path = CHECKPOINT_DIR / 'best.pt'
if not ckpt_path.exists():
    ckpt_path = CHECKPOINT_DIR / 'last.pt'
# weights_only=False: this is our own locally written checkpoint, and its
# metadata may contain NumPy scalars, which the weights_only=True default of
# PyTorch >= 2.6 refuses to unpickle. Never do this with untrusted files.
ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['state_dict'])
best_th = float(ckpt.get('threshold', 0.5))
print(f'Loaded checkpoint {ckpt_path.name} (epoch {ckpt["epoch"]}, '
      f'F1={ckpt["best_f1"]:.4f}, th={best_th:.2f})')

model.eval()

# grab a batch from validation
val_loader = loaders[val_split]
X_b, target_b, weight_b = next(iter(val_loader))

B = X_b.shape[0]
X_flat = X_b.reshape(B, -1, X_b.shape[-1]).to(DEVICE)

with torch.no_grad():
    pred_b = model(X_flat).cpu().numpy()

target_np = target_b.numpy()
T_sub = pred_b.shape[-1]
t_ax_ms = np.arange(T_sub) / (1e6 / DATA_STEP / 1000)  # ms
# End of the shaded receptive-field span; clamp in case the subsequence is
# shorter than the receptive field (t_ax_ms[NRECEPT - 1] would raise).
rf_end_ms = t_ax_ms[min(NRECEPT, T_sub) - 1]

n_show = min(4, B)
fig, axes = plt.subplots(n_show, 1, figsize=(16, 3 * n_show), sharex=True)
if n_show == 1:
    axes = [axes]

for i, ax in enumerate(axes):
    ax.plot(t_ax_ms, pred_b[i], color='steelblue', linewidth=1, label='Prediction')
    ax.plot(t_ax_ms, target_np[i], color='firebrick', linewidth=1, linestyle='--', label='Target')
    ax.axhline(best_th, color='gray', linestyle=':', alpha=0.5, label=f'Threshold={best_th:.2f}')
    ax.axvspan(0, rf_end_ms, alpha=0.08, color='gray')
    ax.set_ylabel(f'Sample {i}')
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True, alpha=0.2)
    if i == 0:
        ax.legend(loc='upper left', fontsize=9)
        ax.set_title('TCN predictions vs targets (gray = receptive field, not in loss)')

axes[-1].set_xlabel('Time (ms) within subsequence')
plt.tight_layout()
plt.show()