# TCN Disruption Prediction — Training

Per-timestep binary classification using a Temporal Convolutional Network.
Adapted from [disruptcnn](https://github.com/rmchurch/disruptcnn) (Churchill et al. 2019).

Uses `ECEiTCNDataset` with pre-decimated data for fast I/O.

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import weight_norm

from dataset_ecei_tcn import ECEiTCNDataset, create_loaders

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Every tensor/model `.to(DEVICE)` call in the notebook uses this handle.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')

## 1. Configuration

In [None]:
# All hyper-parameters for the notebook live in this cell; later cells only
# read these names.
#
# ── Data ──────────────────────────────────────────────────────────────
ROOT           = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt'
DECIMATED_ROOT = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt_decimated'

# Passed to the dataset as `data_step`; the plotting code treats the model's
# time axis as 1 MHz / DATA_STEP (100 kHz for the value below).
DATA_STEP       = 10
TWARN           = 300_000      # 300 ms at 1 MHz
BASELINE_LEN    = 40_000       # 40 ms (matches disruptcnn)
NSUB            = 781_250      # ~781 ms (matches disruptcnn)
# NSUB - STRIDE = 299,990, i.e. consecutive windows share just under 300k
# samples.
# NOTE(review): with the model settings below, build_model reports a receptive
# field of 30,017 model-rate samples, and the loss skips the first
# NRECEPT - 1 outputs of every window. If the dataset expresses NSUB/STRIDE
# in 1 MHz samples, the overlap (29,999 decimated samples) is slightly
# shorter than that skipped region — confirm against ECEiTCNDataset.
STRIDE          = 481_260      # overlap by receptive field

BATCH_SIZE      = 12
NUM_WORKERS     = 4

# ── Model (matches disruptcnn run.sh) ─────────────────────────────────
INPUT_CHANNELS  = 160          # 20 × 8 flattened
N_CLASSES       = 1            # binary per-timestep
LEVELS          = 4
NHID            = 80           # hidden channels per level
KERNEL_SIZE     = 15
# Levels 0..LEVELS-2 use dilation DILATION_BASE**i; the last level's dilation
# is chosen by build_model to hit the receptive-field target.
DILATION_BASE   = 10
DROPOUT         = 0.1

# ── Training ──────────────────────────────────────────────────────────
EPOCHS          = 50
LR              = 2e-3
CLIP            = 0.3          # max gradient norm; values <= 0 disable clipping
WARMUP_EPOCHS   = 3
WARMUP_FACTOR   = 8            # start LR = LR / WARMUP_FACTOR

# Checkpoints (best.pt / last.pt) are written relative to the working directory.
CHECKPOINT_DIR  = Path('checkpoints_tcn')
CHECKPOINT_DIR.mkdir(exist_ok=True)

## 2. TCN Model

Temporal Convolutional Network from [Bai et al. 2018](https://arxiv.org/abs/1803.01271),
with modifications for arbitrary dilation factors (from disruptcnn).

In [None]:
class Chomp1d(nn.Module):
    """Remove trailing padding to enforce causality.

    ``nn.Conv1d`` pads both ends of the time axis; dropping the last
    ``chomp_size`` samples of the output leaves only left-padding, so each
    output step depends on current and past inputs only.

    Args:
        chomp_size: Number of samples to drop from the end of the last axis.
            ``0`` leaves the input untouched.
    """
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` (N, C, L) without its last ``chomp_size`` time steps."""
        if self.chomp_size == 0:
            # ``x[..., :-0]`` is ``x[..., :0]`` — an EMPTY tensor — so a zero
            # chomp (e.g. kernel_size == 1 → padding == 0) needs its own path.
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """One residual TCN level: two (conv → chomp → ReLU → dropout) stages plus a skip path.

    Both convolutions share the same kernel size, stride, dilation and
    padding. Each is followed by a ``Chomp1d(padding)`` that removes the
    trailing ``padding`` samples the convolution produced. The block output
    is ``ReLU(net(x) + residual)``, where the residual is ``x`` itself or a
    1×1 convolution of ``x`` when the channel counts differ.

    Args:
        n_inputs: Channels of the incoming tensor.
        n_outputs: Channels produced by both convolutions.
        kernel_size: Convolution kernel width.
        stride: Convolution stride (``TemporalConvNet`` always passes 1).
        dilation: Dilation of both convolutions.
        padding: Padding given to both convolutions and the number of
            trailing samples each ``Chomp1d`` removes.
        dropout: Dropout probability after each ReLU.
    """
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        # NOTE: the attribute creation order below determines state_dict /
        # parameter order (and random-init order); keep it stable so existing
        # checkpoints and optimizer states remain loadable.
        self.conv1 = weight_norm(nn.Conv1d(
            n_inputs, n_outputs, kernel_size,
            stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(
            n_outputs, n_outputs, kernel_size,
            stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # The same module objects are also registered above under their own
        # names, so state_dict exposes them under both prefixes.
        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2)
        # 1×1 conv so the residual can be added when channel counts differ.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw convolution weights from N(0, 0.01); biases keep their defaults."""
        # NOTE(review): conv1/conv2 are wrapped by torch.nn.utils.weight_norm,
        # which (as far as I can tell) recomputes `.weight` from weight_g and
        # weight_v before every forward pass. If so, these two in-place writes
        # do not affect training and only the downsample init below does —
        # confirm before relying on this initialisation.
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Apply the two conv stages and add the (possibly projected) input."""
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    """Sequential stack of ``TemporalBlock`` levels.

    Args:
        num_inputs: Channels of the input tensor.
        num_channels: Output channels of each level; its length sets the
            number of levels.
        dilation_size: Either a scalar base ``b`` (level ``i`` then uses
            dilation ``b ** i``) or an indexable sequence holding one
            dilation per level.
        kernel_size: Kernel width shared by every level.
        dropout: Dropout probability shared by every level.
    """
    def __init__(self, num_inputs, num_channels, dilation_size=2,
                 kernel_size=2, dropout=0.2):
        super().__init__()
        depth = len(num_channels)

        # Expand a scalar base into one dilation per level; a sequence is
        # used as given.
        if np.isscalar(dilation_size):
            dilations = [dilation_size ** level for level in range(depth)]
        else:
            dilations = dilation_size

        # widths[k] feeds level k, widths[k + 1] is what level k emits.
        widths = [num_inputs, *num_channels]

        blocks = []
        for level in range(depth):
            dilation = dilations[level]
            blocks.append(
                TemporalBlock(
                    widths[level],
                    widths[level + 1],
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    # Together with Chomp1d this keeps the sequence length
                    # unchanged for stride 1.
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every level in order."""
        return self.network(x)


class TCN(nn.Module):
    """Temporal conv stack followed by a per-timestep linear head and a sigmoid.

    Args:
        input_size: Channels of the input signal.
        output_size: Width of the linear head applied at every time step.
        num_channels: Hidden channel count for each TCN level.
        kernel_size: Convolution kernel width.
        dropout: Dropout probability inside the TCN levels.
        dilation_size: Scalar base or per-level dilations, forwarded to
            ``TemporalConvNet``.
    """
    def __init__(self, input_size, output_size, num_channels,
                 kernel_size, dropout, dilation_size):
        super().__init__()
        self.tcn = TemporalConvNet(
            input_size,
            num_channels,
            dilation_size=dilation_size,
            kernel_size=kernel_size,
            dropout=dropout,
        )
        head_in = num_channels[-1]
        self.linear = nn.Linear(head_in, output_size)

    def forward(self, x):
        """x: (N, C_in, L) → probabilities (N, L) when ``output_size`` is 1."""
        features = self.tcn(x)                   # (N, C_hid, L)
        per_step = features.transpose(1, 2)      # (N, L, C_hid) for nn.Linear
        scores = self.linear(per_step)           # (N, L, output_size)
        # squeeze(-1) only drops the last axis when it has size 1.
        return scores.squeeze(-1).sigmoid()

## 3. Build model & compute receptive field

In [None]:
def calc_receptive_field(kernel_size, dilation_sizes):
    """Return the receptive field (in samples) of a TCN.

    Each level contains two convolutions of width ``kernel_size`` with that
    level's dilation, so every level adds ``2 * (kernel_size - 1) * dilation``
    samples on top of the single current sample.
    """
    total_dilation = int(np.sum(dilation_sizes))
    samples_per_unit_dilation = 2 * (kernel_size - 1)
    return 1 + samples_per_unit_dilation * total_dilation


def build_model(input_channels, n_classes, levels, nhid,
                kernel_size, dilation_base, dropout, nrecept_target=30_000):
    """Build a TCN whose receptive field is at least ``nrecept_target`` samples.

    Levels ``0 .. levels - 2`` use dilation ``dilation_base ** i``. The last
    level's dilation is the smallest integer (minimum 1) that lifts the
    receptive field to ``nrecept_target`` or beyond, following disruptcnn.

    Returns:
        ``(model, nrecept)`` — the un-moved ``TCN`` instance and its actual
        receptive field in samples. A short summary is printed as well.
    """
    hidden_widths = [nhid] * levels

    # Geometric dilations for every level except the last.
    leading = [dilation_base ** lvl for lvl in range(levels - 1)]
    rf_leading = calc_receptive_field(kernel_size, leading)

    # Each unit of dilation on the last level adds 2 * (kernel_size - 1) samples.
    gain_per_dilation = 2.0 * (kernel_size - 1)
    final = int(np.ceil((nrecept_target - rf_leading) / gain_per_dilation))
    if final < 1:
        final = 1
    dilation_sizes = leading + [final]

    nrecept = calc_receptive_field(kernel_size, dilation_sizes)

    model = TCN(
        input_channels,
        n_classes,
        hidden_widths,
        kernel_size=kernel_size,
        dropout=dropout,
        dilation_size=dilation_sizes,
    )

    n_params = 0
    for param in model.parameters():
        n_params += param.numel()

    print(f'Dilation sizes : {dilation_sizes}')
    # The millisecond figure assumes a 100 kHz sample rate (1 MHz / 10).
    print(f'Receptive field: {nrecept:,} samples '
          f'({nrecept / (1e6 / 10) * 1e3:.1f} ms at 100 kHz)')
    print(f'Parameters     : {n_params:,}')
    return model, nrecept


# Instantiate the network from the configuration constants. NRECEPT is the
# actual receptive field returned by build_model; the training, evaluation
# and plotting cells use it to skip the first NRECEPT - 1 outputs of a window.
model, NRECEPT = build_model(
    INPUT_CHANNELS, N_CLASSES, LEVELS, NHID,
    KERNEL_SIZE, DILATION_BASE, DROPOUT,
    nrecept_target=30_000,
)
# Move parameters to the selected device before the optimizer is created.
model = model.to(DEVICE)

## 4. Data

In [None]:
# Build the windowed dataset from the configuration constants.
# NOTE(review): the meaning of each keyword is defined in dataset_ecei_tcn,
# which is not visible here — check units (raw vs. decimated samples) there.
ds = ECEiTCNDataset(
    root            = ROOT,
    decimated_root  = DECIMATED_ROOT,
    Twarn           = TWARN,
    baseline_length = BASELINE_LEN,
    data_step       = DATA_STEP,
    nsub            = NSUB,
    stride          = STRIDE,
    normalize       = True,
)
ds.summary()

# `loaders` is used as a mapping from split name to DataLoader; later cells
# look for the keys 'train' and 'test'.
loaders = create_loaders(ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
for name, loader in loaders.items():
    print(f'  {name:>5s}: {len(loader.dataset)} subseqs, {len(loader)} batches')

## 5. Training loop

In [None]:
def train_one_epoch(model, loader, optimizer, nrecept, device, clip=0.3):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network mapping ``(B, C, T)`` inputs to ``(B, T)`` probabilities.
        loader: Iterable yielding ``(X, target, weight)`` batches; ``X`` is
            flattened to ``(B, -1, T)`` before it is fed to the model.
        optimizer: Optimizer stepped once per batch.
        nrecept: Receptive field in samples; the first ``nrecept - 1`` time
            steps of every window are left out of the loss.
        device: Device the batch tensors are moved to.
        clip: Max gradient norm; clipping is skipped when ``clip <= 0``.

    Returns:
        Average of the per-batch weighted BCE losses (``0.0`` for an empty
        loader).
    """
    model.train()
    first_valid = nrecept - 1
    batch_losses = []

    for X, target, weight in loader:
        # Collapse every axis between batch and time into one channel axis,
        # e.g. (B, 20, 8, T) → (B, 160, T).
        inputs = X.view(X.shape[0], -1, X.shape[-1]).to(device)
        target, weight = target.to(device), weight.to(device)

        optimizer.zero_grad()
        probs = model(inputs)                     # (B, T)

        # Only outputs that have seen a full receptive field are scored.
        loss = F.binary_cross_entropy(
            probs[:, first_valid:],
            target[:, first_valid:],
            weight=weight[:, first_valid:],
        )
        loss.backward()

        if clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / max(len(batch_losses), 1)


@torch.no_grad()
def evaluate(model, loader, nrecept, device, thresholds=None):
    """Compute loss, accuracy, and F1 over a loader.

    Predictions are binarised at every value in ``thresholds``; precision,
    recall, accuracy and F1 are reported for the threshold with the highest
    F1 (first one on ties). Only time steps from index ``nrecept - 1`` on are
    scored, mirroring the training loss.

    Args:
        model: Network mapping ``(B, C, T)`` inputs to ``(B, T)`` probabilities.
        loader: Iterable yielding ``(X, target, weight)`` batches.
        nrecept: Receptive field in samples.
        device: Device the batch tensors are moved to.
        thresholds: Candidate decision thresholds; defaults to 19 evenly
            spaced values from 0.05 to 0.95.

    Returns:
        Dict with keys ``loss`` (mean of the per-batch weighted BCE losses),
        ``accuracy``, ``f1``, ``precision``, ``recall`` and ``threshold``.
        All values are built-in ``float`` objects (not NumPy scalars) so they
        can be stored in checkpoints that ``torch.load`` reads with
        ``weights_only=True``.

    Raises:
        ValueError: If ``thresholds`` is empty.
    """
    model.eval()
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.95, 19)
    n_thresholds = len(thresholds)
    if n_thresholds == 0:
        raise ValueError('thresholds must contain at least one value')

    total_loss = 0.0
    n_batches = 0
    total = 0
    # Confusion counts per threshold; rows are TP, FP, FN, TN.
    counts = np.zeros((4, n_thresholds), dtype=np.int64)

    for X, target, weight in loader:
        B = X.shape[0]
        X = X.view(B, -1, X.shape[-1]).to(device)
        target = target.to(device)
        weight = weight.to(device)

        output = model(X)
        out_v = output[:, nrecept - 1:]
        tgt_v = target[:, nrecept - 1:]
        wgt_v = weight[:, nrecept - 1:]

        loss = F.binary_cross_entropy(out_v, tgt_v, weight=wgt_v)
        total_loss += loss.item()
        n_batches += 1
        total += tgt_v.numel()

        # Label masks do not depend on the threshold, so build them once.
        positives = tgt_v == 1
        negatives = tgt_v == 0
        tp_per_th = []
        fp_per_th = []
        for th in thresholds:
            predicted_pos = out_v >= float(th)
            tp_per_th.append((predicted_pos & positives).sum())
            fp_per_th.append((predicted_pos & negatives).sum())
        tp = torch.stack(tp_per_th)
        fp = torch.stack(fp_per_th)
        # FN / TN follow from the label totals. Stacking everything allows a
        # single device→host transfer per batch instead of one per count.
        batch_counts = torch.stack(
            [tp, fp, positives.sum() - tp, negatives.sum() - fp])
        counts += batch_counts.cpu().numpy()

    avg_loss = total_loss / max(n_batches, 1)

    TPs, FPs, FNs, TNs = counts.astype(np.float64)

    # best F1 across thresholds (the epsilons guard against 0/0)
    precision = TPs / (TPs + FPs + 1e-10)
    recall    = TPs / (TPs + FNs + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    best_idx = int(np.argmax(f1))
    accuracy = (TPs[best_idx] + TNs[best_idx]) / max(total, 1)

    return {
        'loss': float(avg_loss),
        'accuracy': float(accuracy),
        'f1': float(f1[best_idx]),
        'precision': float(precision[best_idx]),
        'recall': float(recall[best_idx]),
        'threshold': float(thresholds[best_idx]),
    }

In [None]:
# SGD with Nesterov momentum; LR is the learning rate reached after warmup.
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, nesterov=True)

# Fall back to the first available split when the expected keys are missing.
train_split = 'train' if 'train' in loaders else next(iter(loaders))
# NOTE(review): model selection (best F1 / threshold) is done on the 'test'
# split when it exists. If create_loaders also provides a dedicated
# validation split, prefer that one to keep the test set untouched.
val_split   = 'test'  if 'test'  in loaders else train_split

# linear warmup then ReduceLROnPlateau
warmup_iters = WARMUP_EPOCHS * len(loaders[train_split])


def warmup_lambda(it):
    """LR multiplier: 1/WARMUP_FACTOR at it=0, rising linearly to 1 at warmup_iters."""
    if warmup_iters <= 0:
        # Warmup disabled: LambdaLR applies this factor at construction and is
        # never stepped, so it must be 1 or the LR stays at LR / WARMUP_FACTOR.
        return 1.0
    return (1 - 1/WARMUP_FACTOR) / max(warmup_iters, 1) * it + 1/WARMUP_FACTOR


scheduler_warmup  = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)
scheduler_plateau = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

history = {'train_loss': [], 'val_loss': [], 'val_f1': [], 'val_acc': [], 'lr': []}
# Start below any achievable F1 so the first epoch always writes best.pt;
# otherwise a run whose F1 stays at 0 leaves no (or a stale) best checkpoint
# for the visualisation cell to load.
best_f1 = float('-inf')
iteration = 0

for epoch in range(1, EPOCHS + 1):
    t0 = time.perf_counter()

    # ── train ──
    # evaluate() switches the model to eval mode, so re-enable train mode
    # at the start of every epoch.
    model.train()
    epoch_loss = 0.0
    n_batches = 0
    for X, target, weight in loaders[train_split]:
        # Flatten the channel grid: (B, 20, 8, T) → (B, 160, T).
        B = X.shape[0]
        X = X.view(B, -1, X.shape[-1]).to(DEVICE)
        target = target.to(DEVICE)
        weight = weight.to(DEVICE)

        optimizer.zero_grad()
        output = model(X)
        # Score only outputs that have seen a full receptive field.
        out_v = output[:, NRECEPT - 1:]
        tgt_v = target[:, NRECEPT - 1:]
        wgt_v = weight[:, NRECEPT - 1:]
        loss = F.binary_cross_entropy(out_v, tgt_v, weight=wgt_v)
        loss.backward()
        if CLIP > 0:
            nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1
        iteration += 1

        # warmup scheduler steps per iteration. step() without the deprecated
        # epoch argument advances its internal counter by one, which equals
        # `iteration` here because it is called on every warmup iteration.
        if iteration <= warmup_iters:
            scheduler_warmup.step()

    train_loss = epoch_loss / max(n_batches, 1)

    # ── validate ──
    val_metrics = evaluate(model, loaders[val_split], NRECEPT, DEVICE)
    # Built-in floats keep history and checkpoints free of NumPy scalars,
    # which torch.load(weights_only=True) refuses to unpickle.
    val_loss = float(val_metrics['loss'])
    val_f1 = float(val_metrics['f1'])

    # plateau scheduler steps per epoch on val loss
    if iteration > warmup_iters:
        scheduler_plateau.step(val_loss)

    lr_now = optimizer.param_groups[0]['lr']
    elapsed = time.perf_counter() - t0

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(val_f1)
    history['val_acc'].append(float(val_metrics['accuracy']))
    history['lr'].append(lr_now)

    # ── checkpoint ──
    is_best = val_f1 > best_f1
    if is_best:
        best_f1 = val_f1
        torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_f1': best_f1,
            'threshold': float(val_metrics['threshold']),
            'nrecept': NRECEPT,
        }, CHECKPOINT_DIR / 'best.pt')

    # Always overwrite the rolling checkpoint of the most recent epoch.
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_f1': best_f1,
        'nrecept': NRECEPT,
    }, CHECKPOINT_DIR / 'last.pt')

    star = ' *' if is_best else ''
    print(f'Epoch {epoch:3d}/{EPOCHS}  '
          f'train_loss={train_loss:.4e}  '
          f'val_loss={val_metrics["loss"]:.4e}  '
          f'acc={val_metrics["accuracy"]:.4f}  '
          f'F1={val_metrics["f1"]:.4f} (th={val_metrics["threshold"]:.2f})  '
          f'lr={lr_now:.2e}  '
          f'{elapsed:.1f}s{star}')

## 6. Training curves

In [None]:
# Three side-by-side panels summarising the run recorded in `history`:
# loss curves, validation F1/accuracy, and the learning-rate schedule.
fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# One x-position per completed epoch (1-based), taken from the history length
# so the plot also works after an interrupted run.
epochs_range = np.arange(1, len(history['train_loss']) + 1)

# Loss
axes[0].plot(epochs_range, history['train_loss'], label='Train')
axes[0].plot(epochs_range, history['val_loss'], label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('BCE Loss')
axes[0].set_title('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# F1 & Accuracy (both measured at the per-epoch best-F1 threshold)
axes[1].plot(epochs_range, history['val_f1'], label='F1', color='firebrick')
axes[1].plot(epochs_range, history['val_acc'], label='Accuracy', color='steelblue')
axes[1].set_xlabel('Epoch')
axes[1].set_title('Validation F1 & Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim(0, 1)

# Learning rate (value at the end of each epoch; log scale makes the
# plateau halvings evenly spaced)
axes[2].plot(epochs_range, history['lr'], color='darkorange')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('LR')
axes[2].set_title('Learning Rate')
axes[2].set_yscale('log')
axes[2].grid(True, alpha=0.3)

plt.suptitle(f'Best val F1 = {best_f1:.4f}', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

## 7. Visualise predictions on validation samples

In [None]:
# Load best checkpoint
# weights_only=False: the checkpoint is written by this notebook (trusted) and
# may contain NumPy scalars (F1 / threshold), which the weights_only=True
# default of PyTorch >= 2.6 refuses to unpickle.
ckpt = torch.load(CHECKPOINT_DIR / 'best.pt', map_location=DEVICE,
                  weights_only=False)
model.load_state_dict(ckpt['state_dict'])
# Older checkpoints may lack a stored threshold; fall back to 0.5.
best_th = float(ckpt.get('threshold', 0.5))
print(f'Loaded best checkpoint (epoch {ckpt["epoch"]}, F1={ckpt["best_f1"]:.4f}, th={best_th:.2f})')

model.eval()

# grab a batch from validation
val_loader = loaders[val_split]
X_b, target_b, weight_b = next(iter(val_loader))

# Flatten the channel grid exactly as during training: (B, 20, 8, T) → (B, 160, T).
B = X_b.shape[0]
X_flat = X_b.view(B, -1, X_b.shape[-1]).to(DEVICE)

with torch.no_grad():
    pred_b = model(X_flat).cpu().numpy()

target_np = target_b.numpy()
T_sub = pred_b.shape[-1]
# Model-rate samples per millisecond are 1e6 / DATA_STEP / 1000.
t_ax_ms = np.arange(T_sub) / (1e6 / DATA_STEP / 1000)  # ms

# Index of the first time step included in the loss, clamped so a window
# shorter than the receptive field cannot raise an IndexError.
rf_end = min(NRECEPT - 1, T_sub - 1)

n_show = min(4, B)
fig, axes = plt.subplots(n_show, 1, figsize=(16, 3 * n_show), sharex=True)
if n_show == 1:
    # plt.subplots returns a bare Axes (not an array) for a single panel.
    axes = [axes]

for i, ax in enumerate(axes):
    ax.plot(t_ax_ms, pred_b[i], color='steelblue', linewidth=1, label='Prediction')
    ax.plot(t_ax_ms, target_np[i], color='firebrick', linewidth=1, linestyle='--', label='Target')
    ax.axhline(best_th, color='gray', linestyle=':', alpha=0.5, label=f'Threshold={best_th:.2f}')
    # Shade the leading region that is excluded from the loss and metrics.
    ax.axvspan(0, t_ax_ms[rf_end], alpha=0.08, color='gray')
    ax.set_ylabel(f'Sample {i}')
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True, alpha=0.2)
    if i == 0:
        ax.legend(loc='upper left', fontsize=9)
        ax.set_title('TCN predictions vs targets (gray = receptive field, not in loss)')

axes[-1].set_xlabel('Time (ms) within subsequence')
plt.tight_layout()
plt.show()