# TCN Disruption Prediction — Evaluation

Loads the best checkpoint and performs a thorough evaluation on the **test** split:

1. **Global metrics** — accuracy, F1, precision, recall (plus ROC rates) across thresholds
2. **ROC & Precision-Recall curves**
3. **Confusion matrix** at the optimal threshold
4. **Per-shot analysis** — which shots are well-predicted vs. missed
5. **Prediction time-series** — overlay predictions vs. targets for individual subsequences
6. **Prediction distribution** — histogram of model outputs for positive vs. negative timesteps

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.utils.data import DataLoader, Subset

from dataset_ecei_tcn import ECEiTCNDataset, create_loaders

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)
print(f'Device: {DEVICE}')

## 1. Configuration

Must match the training config exactly.

In [None]:
# ── Data (SciServer paths) ────────────────────────────────────────────
# The four data roots are siblings under one persistent-storage directory.
_ECEI_BASE           = '/home/idies/workspace/Storage/yhuang2/persistent/ecei'
ROOT                 = _ECEI_BASE + '/dsrpt'
DECIMATED_ROOT       = _ECEI_BASE + '/dsrpt_decimated'
CLEAR_ROOT           = _ECEI_BASE + '/clear'
CLEAR_DECIMATED_ROOT = _ECEI_BASE + '/clear_decimated'

# Dataset windowing parameters (passed straight to ECEiTCNDataset in section 4).
DATA_STEP    = 10
TWARN        = 300_000
BASELINE_LEN = 40_000
NSUB         = 781_250

# DataLoader settings.
BATCH_SIZE  = 12
NUM_WORKERS = 4

# ── Model ─────────────────────────────────────────────────────────────
INPUT_CHANNELS = 160
N_CLASSES      = 1
LEVELS         = 4
NHID           = 80
KERNEL_SIZE    = 15
DILATION_BASE  = 10
DROPOUT        = 0.1

# Point to the correct checkpoint directory:
#   'checkpoints_tcn'      — single-GPU training (train_tcn.ipynb)
#   'checkpoints_tcn_ddp'  — multi-GPU DDP training (train_tcn_ddp.py)
CHECKPOINT_DIR  = Path('checkpoints_tcn_ddp')
CHECKPOINT_PATH = CHECKPOINT_DIR.joinpath('best.pt')   # change to 'last.pt' if desired

# Which dataset split to evaluate.
EVAL_SPLIT = 'test'

## 2. TCN Model definition

Identical to `train_tcn.ipynb`.

In [None]:
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps along the final axis of a 3-D tensor.

    Parameters
    ----------
    chomp_size : int
        Number of trailing steps to remove. ``0`` leaves the input untouched.
    """
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` without its last ``chomp_size`` steps, as a contiguous tensor."""
        # ``x[:, :, :-0]`` is an *empty* slice, so a zero chomp (e.g. a block
        # built with kernel_size == 1 and therefore padding == 0) must be
        # passed through explicitly instead of being sliced.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block of two weight-normalised dilated 1-D convolutions.

    Each convolution is followed by ``Chomp1d(padding)``, ReLU and dropout.
    The block output is ``relu(net(x) + res)``, where ``res`` is ``x`` itself
    or, when the channel counts differ, ``x`` passed through a 1x1 convolution.

    NOTE: the attribute names (``conv1``, ``net``, ``downsample``, ...) define
    the state-dict keys, so they must stay in sync with the training code for
    checkpoints to load.

    Parameters
    ----------
    n_inputs, n_outputs : int
        Input / output channel counts.
    kernel_size, stride, dilation, padding : int
        Forwarded to both ``nn.Conv1d`` layers; ``padding`` is also the number
        of trailing steps removed after each convolution.
    dropout : float
        Dropout probability applied after each ReLU.
    """
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        # First conv stage: n_inputs -> n_outputs channels.
        self.conv1 = weight_norm(nn.Conv1d(
            n_inputs, n_outputs, kernel_size,
            stride=stride, padding=padding, dilation=dilation))
        # Conv1d pads both ends; chomping the trailing `padding` steps restores
        # the input length when padding == (kernel_size - 1) * dilation and
        # stride == 1 (the values TemporalConvNet passes).
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second conv stage: keeps n_outputs channels.
        self.conv2 = weight_norm(nn.Conv1d(
            n_outputs, n_outputs, kernel_size,
            stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # The same module objects are also registered under `net`, so the
        # state dict exposes them under both `conv1.*` and `net.0.*` style keys.
        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2)
        # 1x1 conv on the skip path only when the channel count changes.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the convolution weights from N(0, 0.01)."""
        # NOTE(review): with `torch.nn.utils.parametrizations.weight_norm`,
        # `conv.weight` appears to be recomputed on every access, so these two
        # in-place draws may not reach the stored parameters — confirm against
        # the training notebook before changing (irrelevant once a checkpoint
        # is loaded).
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Return ``relu(net(x) + residual)``."""
        out = self.net(x)
        # Skip connection: identity, or the 1x1 projection when channels differ.
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    """Sequential stack of ``TemporalBlock`` layers, one per entry of ``num_channels``.

    Parameters
    ----------
    num_inputs : int
        Channel count of the input to the first block.
    num_channels : sequence of int
        Output channel count of each block.
    dilation_size : scalar or sequence
        A scalar ``b`` expands to the schedule ``b**0, b**1, ...``; a sequence
        is used as given (one dilation per block).
    kernel_size : int
        Kernel size shared by all blocks.
    dropout : float
        Dropout probability shared by all blocks.
    """
    def __init__(self, num_inputs, num_channels, dilation_size=2,
                 kernel_size=2, dropout=0.2):
        super().__init__()
        depth = len(num_channels)

        # Scalar -> geometric schedule; anything else is taken as per-level values.
        if np.isscalar(dilation_size):
            dilations = [dilation_size ** level for level in range(depth)]
        else:
            dilations = dilation_size

        # widths[k] feeds block k, widths[k + 1] is what it produces.
        widths = [num_inputs, *num_channels]
        blocks = [
            TemporalBlock(
                widths[level], widths[level + 1], kernel_size, stride=1,
                dilation=dilations[level],
                # Padding equal to the kernel's dilated reach; each block
                # chomps the same amount off the end again.
                padding=(kernel_size - 1) * dilations[level],
                dropout=dropout)
            for level in range(depth)
        ]
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every block in order."""
        return self.network(x)


class TCN(nn.Module):
    """Temporal conv net with a per-timestep linear head and sigmoid output.

    Parameters
    ----------
    input_size : int
        Number of input channels.
    output_size : int
        Width of the linear head.
    num_channels : sequence of int
        Channel count per temporal block; the last entry feeds the head.
    kernel_size, dropout, dilation_size
        Forwarded to ``TemporalConvNet``.
    """
    def __init__(self, input_size, output_size, num_channels,
                 kernel_size, dropout, dilation_size):
        super().__init__()
        self.tcn = TemporalConvNet(
            input_size, num_channels,
            dilation_size=dilation_size,
            kernel_size=kernel_size,
            dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)

    def forward(self, x):
        """Return sigmoid outputs; a trailing size-1 axis is squeezed away."""
        features = self.tcn(x)
        # The linear head acts on the last axis, so swap axes 1 and 2 first.
        scores = self.linear(features.transpose(1, 2))
        return scores.squeeze(-1).sigmoid()

In [None]:
def calc_receptive_field(kernel_size, dilation_sizes):
    """Return ``1 + 2 * (kernel_size - 1) * sum(dilation_sizes)`` as an int.

    The factor 2 matches the two convolutions each temporal block applies per
    dilation level.
    """
    total_dilation = int(np.sum(dilation_sizes))
    reach_per_conv = kernel_size - 1
    return 2 * reach_per_conv * total_dilation + 1


def build_model(input_channels, n_classes, levels, nhid,
                kernel_size, dilation_base, dropout, nrecept_target=30_000):
    """Build a TCN whose receptive field reaches at least ``nrecept_target``.

    The first ``levels - 1`` blocks use dilations ``dilation_base**i``; the
    dilation of the last block is chosen (rounded up, minimum 1) so the total
    receptive field meets the target.

    Returns
    -------
    (model, nrecept)
        The constructed ``TCN`` and its receptive field in samples.
    """
    # Geometric schedule for every level except the last.
    geometric = [dilation_base ** level for level in range(levels - 1)]

    # Each unit of dilation adds 2 * (kernel_size - 1) samples of receptive
    # field, so divide the remaining gap by that and round up.
    rf_so_far = calc_receptive_field(kernel_size, geometric)
    samples_per_dilation = 2.0 * (kernel_size - 1)
    final_dilation = max(
        int(np.ceil((nrecept_target - rf_so_far) / samples_per_dilation)), 1)

    dilation_sizes = geometric + [final_dilation]
    nrecept = calc_receptive_field(kernel_size, dilation_sizes)

    model = TCN(input_channels, n_classes, [nhid] * levels,
                kernel_size=kernel_size, dropout=dropout,
                dilation_size=dilation_sizes)
    n_params = sum(p.numel() for p in model.parameters())

    print(f'Dilation sizes : {dilation_sizes}',
          f'Receptive field: {nrecept:,} samples',
          f'Parameters     : {n_params:,}',
          sep='\n')
    return model, nrecept


# Instantiate the network from the section-1 configuration; build_model also
# reports the receptive field it actually achieved.
model, NRECEPT = build_model(
    input_channels=INPUT_CHANNELS,
    n_classes=N_CLASSES,
    levels=LEVELS,
    nhid=NHID,
    kernel_size=KERNEL_SIZE,
    dilation_base=DILATION_BASE,
    dropout=DROPOUT,
    nrecept_target=30_000,
)
model = model.to(DEVICE)

# Stride handed to the dataset: the decimated subsequence length minus the
# receptive field (plus one), scaled back up by the decimation step.
_decimated_len = NSUB // DATA_STEP
STRIDE = DATA_STEP * (_decimated_len - NRECEPT + 1)
print(f'Stride (raw)   : {STRIDE:,}')

## 3. Load checkpoint

In [None]:
# Fail with a real exception: `assert` statements are stripped under `python -O`.
if not CHECKPOINT_PATH.exists():
    raise FileNotFoundError(f'Checkpoint not found: {CHECKPOINT_PATH}')

# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, so only
# load checkpoints produced by your own training runs.
ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)

# A state dict saved from a DistributedDataParallel wrapper carries a
# 'module.' prefix on every key; strip it so it loads into the bare model.
# Dicts saved from the unwrapped model pass through unchanged.
_state_dict = ckpt['state_dict']
if _state_dict and all(key.startswith('module.') for key in _state_dict):
    _state_dict = {key[len('module.'):]: value for key, value in _state_dict.items()}
model.load_state_dict(_state_dict)
model.eval()

# Optional metadata; fall back to placeholders when the checkpoint lacks a field.
ckpt_epoch = ckpt.get('epoch', '?')
ckpt_f1    = ckpt.get('best_f1', float('nan'))
ckpt_th    = ckpt.get('threshold', 0.5)
ckpt_nrecept = ckpt.get('nrecept', NRECEPT)

print(f'Loaded checkpoint: {CHECKPOINT_PATH}')
print(f'  epoch      = {ckpt_epoch}')
print(f'  best F1    = {ckpt_f1:.4f}')
print(f'  threshold  = {ckpt_th:.2f}')
print(f'  nrecept    = {ckpt_nrecept:,}')

# The valid-region slicing below uses NRECEPT from this notebook, so flag a
# checkpoint that was trained with a different receptive field.
if ckpt_nrecept != NRECEPT:
    print(f'WARNING: checkpoint nrecept ({ckpt_nrecept:,}) differs from the model '
          f'built here ({NRECEPT:,}); check the configuration in section 1.')

## 4. Load dataset & build evaluation loader

In [None]:
import inspect

# Build the dataset kwargs, adding the clear-shot roots only when this version
# of ECEiTCNDataset accepts them (checked via its __init__ signature).
_sig = inspect.signature(ECEiTCNDataset.__init__)
_kw = dict(
    root=ROOT,
    decimated_root=DECIMATED_ROOT,
    Twarn=TWARN,
    baseline_length=BASELINE_LEN,
    data_step=DATA_STEP,
    nsub=NSUB,
    stride=STRIDE,
    normalize=True,
)
if 'clear_root' in _sig.parameters:
    _kw['clear_root'] = CLEAR_ROOT
    _kw['clear_decimated_root'] = CLEAR_DECIMATED_ROOT
ds = ECEiTCNDataset(**_kw)
ds.summary()

# Restrict the dataset to the subsequences of the requested split.
eval_idx = ds.get_split_indices(EVAL_SPLIT)
print(f'\nEvaluating on split={EVAL_SPLIT!r}: {len(eval_idx)} subsequences')
n_dis = int(ds.seq_has_disrupt[eval_idx].sum())
print(f'  {n_dis} disruptive, {len(eval_idx) - n_dis} clear')

eval_subset = Subset(ds, eval_idx)
eval_loader = DataLoader(
    eval_subset,
    batch_size  = BATCH_SIZE,
    shuffle     = False,          # keep loader order aligned with eval_idx
    num_workers = NUM_WORKERS,
    # Pinned host memory only helps host-to-GPU copies; requesting it on the
    # CPU fallback just triggers a warning.
    pin_memory  = (DEVICE.type == 'cuda'),
    drop_last   = False,          # every subsequence must be evaluated
)

## 5. Collect all predictions

In [None]:
# Per-batch accumulators; each list is merged into one array after the loop.
_valid_pred_chunks, _valid_tgt_chunks = [], []   # valid region only
_full_pred_chunks, _full_tgt_chunks = [], []     # whole subsequences

# First output index whose receptive field lies fully inside the subsequence.
_first_valid = NRECEPT - 1

model.eval()
with torch.no_grad():
    for X, target, _weight in tqdm(eval_loader, desc='Inference'):
        B = X.shape[0]
        # Collapse every axis between batch and time into a single channel axis.
        X_flat = X.view(B, -1, X.shape[-1]).to(DEVICE)
        output = model(X_flat).cpu()

        _full_pred_chunks.append(output.numpy())
        _full_tgt_chunks.append(target.numpy())
        _valid_pred_chunks.append(output[:, _first_valid:].numpy())
        _valid_tgt_chunks.append(target[:, _first_valid:].numpy())

# Valid-region values as flat 1-D arrays (one entry per timestep).
all_preds   = np.concatenate([chunk.reshape(-1) for chunk in _valid_pred_chunks])
all_targets = np.concatenate([chunk.reshape(-1) for chunk in _valid_tgt_chunks])

# Full subsequences stacked along the batch axis: (N_subseqs, T_sub).
all_preds_full   = np.concatenate(_full_pred_chunks, axis=0)
all_targets_full = np.concatenate(_full_tgt_chunks, axis=0)

# Per-subsequence metadata, in the same order as the (unshuffled) loader:
# the owning shot and the start index reported by the dataset.
subseq_shot_ids = np.array([ds.shots[ds.seq_shot_idx[g]] for g in eval_idx])
subseq_starts   = np.array([ds.seq_start[g] for g in eval_idx])

# Class balance over the valid-region timesteps.
n_pos = int((all_targets == 1).sum())
n_neg = int((all_targets == 0).sum())
_n_total = len(all_targets)
print(f'\nTotal timesteps (valid region): {_n_total:,}')
print(f'  Positive (disruptive): {n_pos:,} ({n_pos/_n_total*100:.1f}%)')
print(f'  Negative (clear)     : {n_neg:,} ({n_neg/_n_total*100:.1f}%)')

## 5b. Subsequence-level prediction (majority voting)

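Utility to get **one prediction per subsequence** by aggregating per-timestep outputs: `method='mean'` (the default, used below) thresholds the mean output, while `method='majority'` takes a strict majority vote over the thresholded timesteps. Works the same on SciServer and main branches — no path or branch logic.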

In [None]:
def subsequence_prediction_majority_vote(
    preds_full,
    threshold=0.5,
    valid_start=None,
    method='mean',
):
    """
    Collapse per-timestep outputs into one 0/1 prediction per subsequence.

    Parameters
    ----------
    preds_full : array-like
        Shape (N_subseqs, T_sub), model sigmoid outputs per timestep.
        Anything ``np.asarray`` accepts (ndarray, nested lists) works.
    threshold : float
        method='mean': the subsequence is positive when mean(preds) >= threshold.
        method='majority': each timestep votes ``pred >= threshold``.
    valid_start : int or None
        Index where the valid region starts (e.g. NRECEPT - 1). If None, the
        full sequence is used.
    method : str
        'mean' — label = 1 if mean(preds) >= threshold else 0 (the default,
        despite the function name).
        'majority' — label = 1 only if strictly more than half of the
        timesteps vote positive.

    Returns
    -------
    pred_subseq : np.ndarray
        Shape (N_subseqs,) int64 0/1, one prediction per subsequence.

    Raises
    ------
    ValueError
        If ``method`` is unknown, or if ``valid_start`` leaves no timesteps to
        aggregate.
    """
    # Reject a bad method before doing any array work.
    if method not in ('mean', 'majority'):
        raise ValueError(f"method must be 'mean' or 'majority', got {method!r}")

    # Plain nested lists do not support the 2-D slice below, so coerce first.
    preds = np.asarray(preds_full)
    if valid_start is not None:
        preds = preds[:, valid_start:]  # (N, T_valid)

    # An empty time axis would make mean() return NaN (with only a warning),
    # which then silently compares as "negative" for every subsequence.
    if preds.shape[-1] == 0:
        raise ValueError(
            f'no timesteps left to aggregate (valid_start={valid_start!r})')

    if method == 'mean':
        return (preds.mean(axis=-1) >= threshold).astype(np.int64)

    # Strict majority: a 50/50 split counts as negative.
    votes = preds >= threshold
    return (votes.mean(axis=-1) > 0.5).astype(np.int64)


# Threshold choice: the F1-optimal value once section 6 has been run in this
# session, otherwise the one stored in the checkpoint.
_threshold = best_th if 'best_th' in globals() else ckpt_th

# Aggregate over the valid region only (after the receptive-field warm-up).
valid_start = NRECEPT - 1
subseq_pred_majority = subsequence_prediction_majority_vote(
    all_preds_full,
    threshold=_threshold,
    valid_start=valid_start,
    method='mean',
)

# Subsequence-level ground truth comes from the dataset's per-subsequence flag.
subseq_gt = ds.seq_has_disrupt[eval_idx].astype(np.int64)

# Confusion counts at subsequence level.
_pred_is_pos, _pred_is_neg = subseq_pred_majority == 1, subseq_pred_majority == 0
_gt_is_pos, _gt_is_neg = subseq_gt == 1, subseq_gt == 0
tp = (_pred_is_pos & _gt_is_pos).sum()
tn = (_pred_is_neg & _gt_is_neg).sum()
fp = (_pred_is_pos & _gt_is_neg).sum()
fn = (_pred_is_neg & _gt_is_pos).sum()

# Denominators are floored so an empty class yields 0 instead of a division error.
acc_sub = (tp + tn) / max(tp + tn + fp + fn, 1)
prec_sub = tp / max(tp + fp, 1)
rec_sub = tp / max(tp + fn, 1)
f1_sub = 2 * prec_sub * rec_sub / max(prec_sub + rec_sub, 1e-10)

print('Subsequence-level (majority vote, method=mean):')
print(f'  Threshold = {_threshold:.3f}')
print(f'  Accuracy  = {acc_sub:.4f}')
print(f'  Precision = {prec_sub:.4f}  Recall = {rec_sub:.4f}  F1 = {f1_sub:.4f}')
print(f'  TP={tp}, TN={tn}, FP={fp}, FN={fn}')

## 5b. Subsequence-level prediction (majority voting)

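This cell repeats section 5b: it gets **one prediction per subsequence** by aggregating per-timestep outputs (majority vote or mean-threshold). Re-run it after section 6 to use the F1-optimal `best_th` instead of the checkpoint threshold. Works the same on SciServer and main branches — no path or branch logic.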

In [None]:
def subsequence_prediction_majority_vote(
    preds_full,
    threshold=0.5,
    valid_start=None,
    method='mean',
):
    """
    Collapse per-timestep outputs into one 0/1 prediction per subsequence.

    Parameters
    ----------
    preds_full : array-like
        Shape (N_subseqs, T_sub), model sigmoid outputs per timestep.
        Anything ``np.asarray`` accepts (ndarray, nested lists) works.
    threshold : float
        method='mean': the subsequence is positive when mean(preds) >= threshold.
        method='majority': each timestep votes ``pred >= threshold``.
    valid_start : int or None
        Index where the valid region starts (e.g. NRECEPT - 1). If None, the
        full sequence is used.
    method : str
        'mean' — label = 1 if mean(preds) >= threshold else 0 (the default,
        despite the function name).
        'majority' — label = 1 only if strictly more than half of the
        timesteps vote positive.

    Returns
    -------
    pred_subseq : np.ndarray
        Shape (N_subseqs,) int64 0/1, one prediction per subsequence.

    Raises
    ------
    ValueError
        If ``method`` is unknown, or if ``valid_start`` leaves no timesteps to
        aggregate.
    """
    # Reject a bad method before doing any array work.
    if method not in ('mean', 'majority'):
        raise ValueError(f"method must be 'mean' or 'majority', got {method!r}")

    # Plain nested lists do not support the 2-D slice below, so coerce first.
    preds = np.asarray(preds_full)
    if valid_start is not None:
        preds = preds[:, valid_start:]  # (N, T_valid)

    # An empty time axis would make mean() return NaN (with only a warning),
    # which then silently compares as "negative" for every subsequence.
    if preds.shape[-1] == 0:
        raise ValueError(
            f'no timesteps left to aggregate (valid_start={valid_start!r})')

    if method == 'mean':
        return (preds.mean(axis=-1) >= threshold).astype(np.int64)

    # majority = 1 if more than half of timesteps are 1 (a tie is negative)
    votes = preds >= threshold
    return (votes.mean(axis=-1) > 0.5).astype(np.int64)


# Apply to the current eval predictions, valid region only.
# The threshold is ckpt_th (section 3) on a first pass; once section 6 has
# defined best_th, re-running this cell switches to the F1-optimal value.
valid_start = NRECEPT - 1
_threshold = best_th if 'best_th' in globals() else ckpt_th
subseq_pred_majority = subsequence_prediction_majority_vote(
    all_preds_full,
    threshold=_threshold,
    valid_start=valid_start,
    method='mean',
)

# Subsequence-level ground truth, taken from the dataset's per-subsequence flag.
subseq_gt = ds.seq_has_disrupt[eval_idx].astype(np.int64)


def _count_subseq(pred_label, true_label):
    """Count subsequences predicted `pred_label` whose ground truth is `true_label`."""
    return ((subseq_pred_majority == pred_label) & (subseq_gt == true_label)).sum()


# Confusion counts and the metrics derived from them (denominators floored
# so an empty class gives 0 rather than a division error).
tp = _count_subseq(1, 1)
tn = _count_subseq(0, 0)
fp = _count_subseq(1, 0)
fn = _count_subseq(0, 1)
acc_sub = (tp + tn) / max(tp + tn + fp + fn, 1)
prec_sub = tp / max(tp + fp, 1)
rec_sub = tp / max(tp + fn, 1)
f1_sub = 2 * prec_sub * rec_sub / max(prec_sub + rec_sub, 1e-10)

print('Subsequence-level (majority vote, method=mean):')
print(f'  Threshold = {_threshold:.3f}')
print(f'  Accuracy  = {acc_sub:.4f}')
print(f'  Precision = {prec_sub:.4f}  Recall = {rec_sub:.4f}  F1 = {f1_sub:.4f}')
print(f'  TP={tp}, TN={tn}, FP={fp}, FN={fn}')

## 6. Global metrics across thresholds

In [None]:
thresholds = np.linspace(0.01, 0.99, 199)

# One slot per threshold for each metric (fpr/tpr feed the ROC curve later).
precision_arr, recall_arr, f1_arr, accuracy_arr, fpr_arr, tpr_arr = (
    np.zeros(len(thresholds)) for _ in range(6))

# The label masks do not depend on the threshold, so build them once.
_label_pos = all_targets == 1
_label_neg = all_targets == 0

for i, th in enumerate(thresholds):
    _alarm = all_preds >= th
    _silent = ~_alarm
    TP = (_alarm & _label_pos).sum()
    TN = (_silent & _label_neg).sum()
    FP = (_alarm & _label_neg).sum()
    FN = (_silent & _label_pos).sum()

    # Denominators are floored so an empty class gives 0, not a division error.
    _prec = TP / max(TP + FP, 1)
    _rec = TP / max(TP + FN, 1)
    precision_arr[i] = _prec
    recall_arr[i]    = _rec
    f1_arr[i]        = 2 * _prec * _rec / max(_prec + _rec, 1e-10)
    accuracy_arr[i]  = (TP + TN) / max(TP + TN + FP + FN, 1)
    fpr_arr[i]       = FP / max(FP + TN, 1)
    tpr_arr[i]       = _rec   # TPR is recall by definition

# NOTE: the "best" threshold is selected on the evaluated split itself.
best_idx = np.argmax(f1_arr)
best_th  = thresholds[best_idx]

_bar = '=' * 70
_rule = '  ────────────────────────────────────────'


def _print_metrics_at(idx):
    """Print the F1 / P / R / Acc line for the threshold at grid index `idx`."""
    print(f'    F1={f1_arr[idx]:.4f}  P={precision_arr[idx]:.4f}  '
          f'R={recall_arr[idx]:.4f}  Acc={accuracy_arr[idx]:.4f}')


print(_bar)
print('  GLOBAL EVALUATION METRICS')
print(_bar)
print(f'  Checkpoint threshold : {ckpt_th:.2f}')
print(f'  Best threshold (F1)  : {best_th:.3f}')
print(_rule)
print(f'  F1        = {f1_arr[best_idx]:.4f}')
print(f'  Precision = {precision_arr[best_idx]:.4f}')
print(f'  Recall    = {recall_arr[best_idx]:.4f}')
print(f'  Accuracy  = {accuracy_arr[best_idx]:.4f}')
print(_rule)
# Metrics at the grid point closest to the checkpoint's saved threshold
ckpt_idx = np.argmin(np.abs(thresholds - ckpt_th))
print(f'  Metrics at checkpoint th={ckpt_th:.2f}:')
_print_metrics_at(ckpt_idx)
print(_rule)
# Metrics at the grid point closest to th=0.5
idx_50 = np.argmin(np.abs(thresholds - 0.5))
print(f'  Metrics at th=0.50:')
_print_metrics_at(idx_50)
print(_bar)

## 7. ROC & Precision-Recall curves

In [None]:
# np.trapz is deprecated in NumPy 2 in favour of np.trapezoid; use whichever
# this NumPy provides (np.trapz is only touched when np.trapezoid is missing).
_trapezoid = getattr(np, 'trapezoid', None) or np.trapz

# Sort for proper curve plotting
roc_order = np.argsort(fpr_arr)

# The threshold grid spans 0.01–0.99, so the sampled ROC curve never reaches
# its end points and the integral would miss both tails. Those end points are
# exact, not extrapolated: threshold 0 flags every timestep (FPR = TPR = 1 when
# the class is present) and a threshold above 1 flags none (FPR = TPR = 0).
# The rates follow the metric loop's convention of 0 for an absent class.
_tpr_end = 1.0 if n_pos > 0 else 0.0
_fpr_end = 1.0 if n_neg > 0 else 0.0
roc_fpr = np.concatenate(([0.0], fpr_arr[roc_order], [_fpr_end]))
roc_tpr = np.concatenate(([0.0], tpr_arr[roc_order], [_tpr_end]))

# AUC (ROC) by the trapezoidal rule over the anchored curve
auc_roc = np.abs(_trapezoid(roc_tpr, roc_fpr))
# Approximate AUC (PR) — integrated only over the recall range the grid covers
pr_order = np.argsort(recall_arr)
auc_pr = np.abs(_trapezoid(precision_arr[pr_order], recall_arr[pr_order]))

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# ── ROC Curve ──
ax = axes[0]
ax.plot(roc_fpr, roc_tpr, color='steelblue', linewidth=2,
        label=f'ROC (AUC={auc_roc:.4f})')
ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='Random')
# mark best F1 threshold
ax.plot(fpr_arr[best_idx], tpr_arr[best_idx], 'ro', markersize=8,
        label=f'Best F1 th={best_th:.2f}')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
ax.set_xlim(-0.02, 1.02)
ax.set_ylim(-0.02, 1.02)

# ── Precision-Recall Curve ──
ax = axes[1]
ax.plot(recall_arr[pr_order], precision_arr[pr_order], color='firebrick',
        linewidth=2, label=f'PR (AUC={auc_pr:.4f})')
ax.plot(recall_arr[best_idx], precision_arr[best_idx], 'bo', markersize=8,
        label=f'Best F1 th={best_th:.2f}')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve')
ax.legend(loc='lower left')
ax.grid(True, alpha=0.3)
ax.set_xlim(-0.02, 1.02)
ax.set_ylim(-0.02, 1.02)

# ── F1 / Precision / Recall vs Threshold ──
ax = axes[2]
ax.plot(thresholds, f1_arr, color='black', linewidth=2, label='F1')
ax.plot(thresholds, precision_arr, color='teal', linewidth=1.5, linestyle='--', label='Precision')
ax.plot(thresholds, recall_arr, color='darkorange', linewidth=1.5, linestyle='--', label='Recall')
ax.axvline(best_th, color='red', linestyle=':', alpha=0.6, label=f'Best th={best_th:.2f}')
ax.set_xlabel('Threshold')
ax.set_ylabel('Score')
ax.set_title('Metrics vs Threshold')
ax.legend(loc='best', fontsize=9)
ax.grid(True, alpha=0.3)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1.02)

plt.suptitle(f'Evaluation on {EVAL_SPLIT} split  |  Best F1={f1_arr[best_idx]:.4f}  |  AUC-ROC={auc_roc:.4f}',
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

## 8. Confusion matrix

In [None]:
# Binarise at the F1-optimal threshold and compare with integer labels.
pred_best = (all_preds >= best_th).astype(int)
target_int = all_targets.astype(int)


def _cm_cell(pred_label, true_label):
    """Count timesteps predicted `pred_label` whose true label is `true_label`."""
    return int(((pred_best == pred_label) & (target_int == true_label)).sum())


TP, TN, FP, FN = _cm_cell(1, 1), _cm_cell(0, 0), _cm_cell(1, 0), _cm_cell(0, 1)

# Rows are the true class, columns the predicted class.
cm = np.array([[TN, FP], [FN, TP]])
cm_pct = cm / cm.sum() * 100

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cm_pct, cmap='Blues', vmin=0)

labels = ['Clear (0)', 'Disruptive (1)']
_ticks = [0, 1]
ax.set_xticks(_ticks)
ax.set_yticks(_ticks)
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
ax.set_xlabel('Predicted', fontsize=12)
ax.set_ylabel('True', fontsize=12)

# Annotate each cell with its count and share; switch to white text on dark cells.
for i, j in np.ndindex(2, 2):
    _share = cm_pct[i, j]
    color = 'white' if _share > 40 else 'black'
    ax.text(j, i, f'{cm[i, j]:,}\n({_share:.1f}%)',
            ha='center', va='center', fontsize=13, color=color,
            fontweight='bold')

ax.set_title(f'Confusion Matrix @ threshold={best_th:.3f}\n'
             f'Acc={accuracy_arr[best_idx]:.4f}  F1={f1_arr[best_idx]:.4f}',
             fontsize=12)
fig.colorbar(im, ax=ax, label='% of total timesteps')
plt.tight_layout()
plt.show()

print(f'TP={TP:>10,}   FP={FP:>10,}')
print(f'FN={FN:>10,}   TN={TN:>10,}')

## 9. Prediction distribution (positive vs negative)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# ── Histogram of per-timestep outputs, split by true label ──
ax = axes[0]
bins = np.linspace(0, 1, 101)
ax.hist(all_preds[all_targets == 0], bins=bins, alpha=0.6, density=True,
        label='Clear (y=0)', color='steelblue')
ax.hist(all_preds[all_targets == 1], bins=bins, alpha=0.6, density=True,
        label='Disruptive (y=1)', color='firebrick')
ax.axvline(best_th, color='black', linestyle='--', linewidth=1.5,
           label=f'Best th={best_th:.2f}')
ax.set_xlabel('Model output (sigmoid)', fontsize=11)
ax.set_ylabel('Density', fontsize=11)
ax.set_title('Prediction Distribution by True Label')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.2)

# ── Box plot of mean prediction per subsequence ──
ax = axes[1]
# A subsequence counts as disruptive when any valid-region timestep is labelled 1.
subseq_mean_pred = all_preds_full[:, NRECEPT - 1:].mean(axis=-1)
subseq_labels    = (all_targets_full[:, NRECEPT - 1:].sum(axis=-1) > 0).astype(int)

data_box = [subseq_mean_pred[subseq_labels == 0],
            subseq_mean_pred[subseq_labels == 1]]
# boxplot's `labels=` keyword is deprecated since Matplotlib 3.9 (renamed
# `tick_labels`); labelling the ticks afterwards works on every version.
bp = ax.boxplot(data_box, widths=0.5, patch_artist=True)
ax.set_xticklabels(['Clear', 'Disruptive'])
bp['boxes'][0].set_facecolor('steelblue')
bp['boxes'][0].set_alpha(0.4)
bp['boxes'][1].set_facecolor('firebrick')
bp['boxes'][1].set_alpha(0.4)
ax.set_ylabel('Mean prediction (valid region)', fontsize=11)
ax.set_title('Per-Subsequence Mean Prediction')
ax.grid(True, alpha=0.2)

plt.suptitle(f'Prediction distributions — {EVAL_SPLIT} split', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

## 10. Per-shot analysis

For each shot in the eval split, compute:
- F1 (plus precision, recall and accuracy) pooled over all valid-region timesteps of its subsequences
- Whether the disruption was *detected* (at least one above-threshold prediction on a timestep labelled disruptive)

In [None]:
unique_shots = np.unique(subseq_shot_ids)

shot_results = []

for shot in unique_shots:
    mask = subseq_shot_ids == shot
    idx_local = np.where(mask)[0]

    # Valid-region outputs/labels of every subsequence belonging to this shot.
    preds_shot  = all_preds_full[idx_local][:, NRECEPT - 1:]   # (n_sub, T_valid)
    targets_shot = all_targets_full[idx_local][:, NRECEPT - 1:]

    pred_bin = (preds_shot >= best_th).astype(float)

    # Confusion counts pooled over all timesteps of the shot.
    tp = ((pred_bin == 1) & (targets_shot == 1)).sum()
    fp = ((pred_bin == 1) & (targets_shot == 0)).sum()
    fn = ((pred_bin == 0) & (targets_shot == 1)).sum()
    tn = ((pred_bin == 0) & (targets_shot == 0)).sum()

    # Floored denominators: a shot without positive labels gets
    # precision = recall = F1 = 0 rather than a division error.
    prec = tp / max(tp + fp, 1)
    rec  = tp / max(tp + fn, 1)
    f1   = 2 * prec * rec / max(prec + rec, 1e-10)
    acc  = (tp + tn) / max(tp + tn + fp + fn, 1)

    # Plain Python bools keep the stored results free of NumPy scalar types.
    has_disrupt = bool(targets_shot.sum() > 0)
    # Detected = at least one above-threshold output on a positive-labelled step.
    detected    = has_disrupt and bool(pred_bin[targets_shot == 1].sum() > 0)

    # Mean prediction in the disruption region vs clear region (NaN when absent)
    mean_pred_pos = preds_shot[targets_shot == 1].mean() if has_disrupt else float('nan')
    mean_pred_neg = preds_shot[targets_shot == 0].mean() if (targets_shot == 0).sum() > 0 else float('nan')

    shot_results.append({
        'shot': shot,
        'n_subseqs': int(mask.sum()),
        'has_disrupt': has_disrupt,
        'detected': detected,
        'f1': f1,
        'precision': prec,
        'recall': rec,
        'accuracy': acc,
        'mean_pred_pos': mean_pred_pos,
        'mean_pred_neg': mean_pred_neg,
    })

# ── Print table ──
print(f"{'Shot':>8s}  {'#Sub':>5s}  {'Disrupt':>7s}  {'Detect':>6s}  "
      f"{'F1':>6s}  {'Prec':>6s}  {'Rec':>6s}  {'Acc':>6s}  "
      f"{'Pred(+)':>8s}  {'Pred(-)':>8s}")
print('─' * 88)
for r in shot_results:
    det_str = 'YES' if r['detected'] else 'NO' if r['has_disrupt'] else 'n/a'
    dis_str = 'YES' if r['has_disrupt'] else 'NO'
    # Pad the two flag columns to the header widths (7 and 6) so every row
    # lines up with the header.
    print(f"{r['shot']:>8d}  {r['n_subseqs']:>5d}  {dis_str:>7s}  {det_str:>6s}  "
          f"{r['f1']:>6.3f}  {r['precision']:>6.3f}  {r['recall']:>6.3f}  {r['accuracy']:>6.3f}  "
          f"{r['mean_pred_pos']:>8.4f}  {r['mean_pred_neg']:>8.4f}")

# Overall shot-level detection rate
n_disruptive_shots = sum(1 for r in shot_results if r['has_disrupt'])
n_detected = sum(1 for r in shot_results if r['detected'])
print(f'\nDisruption detection rate: {n_detected}/{n_disruptive_shots} '
      f'({n_detected/max(n_disruptive_shots,1)*100:.0f}%)')

In [None]:
# ── Bar chart of per-shot F1 ──
# Gather label, height and colour for every bar in a single pass.
shot_labels, shot_f1s, shot_colors = [], [], []
for _res in shot_results:
    shot_labels.append(str(_res['shot']))
    shot_f1s.append(_res['f1'])
    shot_colors.append('firebrick' if _res['has_disrupt'] else 'steelblue')

_n_shots = len(shot_results)
_bar_x = range(_n_shots)
_global_f1 = f1_arr[best_idx]

# Figure width grows with the number of shots, with a floor of 8 inches.
fig, ax = plt.subplots(figsize=(max(8, _n_shots * 0.6), 5))
bars = ax.bar(_bar_x, shot_f1s, color=shot_colors, alpha=0.7, edgecolor='black', linewidth=0.5)
ax.set_xticks(_bar_x)
ax.set_xticklabels(shot_labels, rotation=45, ha='right', fontsize=9)
ax.set_ylabel('F1 Score')
ax.set_title(f'Per-Shot F1 — {EVAL_SPLIT} split  (red = disruptive, blue = clear)')
ax.set_ylim(0, 1.05)
# Dashed reference line at the global (all-timestep) F1.
ax.axhline(_global_f1, color='black', linestyle='--', alpha=0.5, label=f'Global F1={_global_f1:.3f}')
ax.legend()
ax.grid(True, alpha=0.2, axis='y')
plt.tight_layout()
plt.show()

## 11. Prediction time-series visualisation

Plot the model's per-timestep output overlaid with the true target for several subsequences.

In [None]:
T_sub = all_preds_full.shape[-1]
# Sample index -> milliseconds. NOTE(review): the 1e6 factor presumably is the
# raw sampling rate in Hz before decimation by DATA_STEP — confirm.
t_ax_ms = np.arange(T_sub) / (1e6 / DATA_STEP / 1000)  # ms

# Indices of disruptive / clear subsequences (labels come from section 9).
dis_idxs   = np.flatnonzero(subseq_labels == 1)
clear_idxs = np.flatnonzero(subseq_labels == 0)

# Show up to n_show_each of each kind; slicing already copes with shorter arrays.
n_show_each = 4
show_dis   = dis_idxs[:n_show_each]
show_clear = clear_idxs[:n_show_each]
show_all   = np.concatenate([show_dis, show_clear])

n_plots = len(show_all)
# squeeze=False always returns a 2-D grid, so one panel needs no special case.
fig, axes = plt.subplots(n_plots, 1, figsize=(16, 2.8 * n_plots),
                         sharex=True, squeeze=False)
axes = axes[:, 0]

for i, (ax, idx) in enumerate(zip(axes, show_all)):
    shot_id = subseq_shot_ids[idx]
    label_str = 'disruptive' if subseq_labels[idx] else 'clear'

    ax.plot(t_ax_ms, all_preds_full[idx], color='steelblue', linewidth=1, label='Prediction')
    ax.plot(t_ax_ms, all_targets_full[idx], color='firebrick', linewidth=1, linestyle='--', label='Target')
    ax.axhline(best_th, color='gray', linestyle=':', alpha=0.5, label=f'Threshold={best_th:.2f}')
    # Shade the warm-up span that precedes the first valid output.
    ax.axvspan(0, t_ax_ms[NRECEPT - 1], alpha=0.06, color='gray')
    ax.set_ylabel(f'Shot {shot_id}\n({label_str})', fontsize=9)
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True, alpha=0.2)
    # Legend and title only on the top panel.
    if i == 0:
        ax.legend(loc='upper left', fontsize=9)
        ax.set_title('Per-timestep predictions vs targets (gray = receptive field warm-up)',
                     fontsize=12)

axes[-1].set_xlabel('Time (ms) within subsequence', fontsize=11)
plt.tight_layout()
plt.show()

## 12. Early-warning analysis

For disruptive subsequences: how early, relative to the first timestep labelled
disruptive in the valid region, does the model first cross the threshold?

In [None]:
# For every disruptive subsequence, compare the first above-threshold output
# in the valid region with the first timestep labelled disruptive.
warning_times_ms = []   # lead of the first alarm over the label onset (ms, positive = early)
missed = 0
_first_valid = NRECEPT - 1

for idx in dis_idxs:
    # Restrict both series to the valid region.
    pred_valid = all_preds_full[idx][_first_valid:]
    tgt_valid  = all_targets_full[idx][_first_valid:]

    # Onset = first positive label; skip subsequences that have none.
    _labelled = np.flatnonzero(tgt_valid == 1)
    if _labelled.size == 0:
        continue

    # No above-threshold output anywhere in the valid region -> missed.
    _alarms = np.flatnonzero(pred_valid >= best_th)
    if _alarms.size == 0:
        missed += 1
        continue

    # Positive when the alarm precedes the onset; samples -> ms.
    dt_samples = _labelled[0] - _alarms[0]
    warning_times_ms.append(dt_samples / (1e6 / DATA_STEP) * 1000)

warning_times_ms = np.array(warning_times_ms)

print(f'Disruptive subsequences: {len(dis_idxs)}')
print(f'  Detected (alarm fired)  : {len(warning_times_ms)}')
print(f'  Missed (no alarm)       : {missed}')

if len(warning_times_ms) == 0:
    print('  No detections — skipping warning-time plot.')
else:
    # Strictly positive lead = early; zero or negative = at/after the onset.
    early = warning_times_ms[warning_times_ms > 0]
    late  = warning_times_ms[warning_times_ms <= 0]

    print(f'\n  Early warnings (alarm BEFORE disruption onset): {len(early)}')
    if len(early) > 0:
        print(f'    Mean lead time : {early.mean():.1f} ms')
        print(f'    Median         : {np.median(early):.1f} ms')
        print(f'    Min            : {early.min():.1f} ms')
        print(f'    Max            : {early.max():.1f} ms')

    print(f'  Late warnings (alarm AFTER disruption onset)  : {len(late)}')
    if len(late) > 0:
        print(f'    Mean delay     : {-late.mean():.1f} ms')

    # Histogram of all lead times, with the onset marked at zero
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.hist(warning_times_ms, bins=30, color='teal', alpha=0.7, edgecolor='black', linewidth=0.5)
    ax.axvline(0, color='red', linestyle='--', linewidth=1.5, label='Disruption onset')
    ax.set_xlabel('Warning time (ms, positive = early)', fontsize=11)
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title('Distribution of early-warning lead times')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.show()

## 13. Summary

In [None]:
# Final recap of the figures computed in the sections above.
print('\n' + '=' * 70)
print('  EVALUATION SUMMARY')
print('=' * 70)
print(f'  Checkpoint     : {CHECKPOINT_PATH}')
print(f'  Epoch          : {ckpt_epoch}')
print(f'  Eval split     : {EVAL_SPLIT}')
print(f'  Subsequences   : {len(eval_idx)}')
print(f'  Timesteps      : {len(all_targets):,} (valid region)')
print(f'  ─────────────────────────────────────')
print(f'  Best threshold : {best_th:.3f}')
print(f'  F1             : {f1_arr[best_idx]:.4f}')
print(f'  Precision      : {precision_arr[best_idx]:.4f}')
print(f'  Recall         : {recall_arr[best_idx]:.4f}')
print(f'  Accuracy       : {accuracy_arr[best_idx]:.4f}')
print(f'  AUC-ROC        : {auc_roc:.4f}')
print(f'  AUC-PR         : {auc_pr:.4f}')
print(f'  ─────────────────────────────────────')
# The shot-level detection rate (section 10) does not depend on the
# subsequence-level alarm list (section 12), so report it unconditionally —
# a 0/N result is exactly the case that must not be hidden.
print(f'  Disruption detection rate : {n_detected}/{n_disruptive_shots}')
# `early` only exists when at least one alarm fired; the first test
# short-circuits before it is touched otherwise.
if len(warning_times_ms) > 0 and len(early) > 0:
    print(f'  Mean early-warning lead  : {early.mean():.1f} ms')
print('=' * 70)