In [None]:
#!/usr/bin/env python3
# ============================================================
# Phase 6 FINAL — Train ONE model on ALL subjects (single ckpt)

# ============================================================

from __future__ import annotations

import os, math, json, time, random, argparse
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

# ---------------- CONFIG ----------------

class CFG:
    """Static, module-wide configuration for the final all-subjects training run.

    The class is used purely as a namespace: attributes are read directly as
    ``CFG.NAME`` by the dataset, model, SSL and evaluation code in this file
    and it is never instantiated there.
    """
    # Root of the data tree; DATASET_DIR is where fold directories
    # ("{prefix}_fold{N}") and the optional feature .npz files are looked up.
    ROOT_DIR    = Path(r"/home/tsultan1/BioRob(Final)/Data")
    DATASET_DIR = ROOT_DIR / "_dataset_icml_v1"

    SSL_PREFIX      = "exports_v1_ssl"         # Phase-5 SSL shards
    BAL_PREFIX      = "exports_v1_balanced"    # Phase-5 balanced shards

    # Raw task codes kept in supervised mode; list position is the class index
    # used by the task head (see the task2idx mapping in the dataset).
    TASK_CODES      = [0, 1, 2, 3, 4, 5]       # must include REST=0

    # Evaluated once at import time.
    DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
    NUM_WORKERS = 2

    # ---- FINAL training split inside "all data" ----
    VAL_FRAC    = 0.08      # small internal val for early-stop + threshold tune
    SEED        = 42

    # ---- SSL (Stage 1) ----
    USE_SSL     = True
    SSL_EPOCHS  = 35
    SSL_BATCH   = 64
    SSL_LR      = 1e-3
    SSL_MAX_WINDOWS = 6000      # subsample to speed up if needed
    SSL_MASK_PROB   = 0.15      # per-time-step masking probability
    SSL_TEMP        = 0.1       # temperature of the contrastive loss
    SSL_MODAL_DROPOUT = 0.25    # per-sample chance of zeroing a whole modality in view 2

    # Weights of the three SSL loss terms (masked recon, contrastive, cross-modal).
    LAMBDA_MASK     = 1.0
    LAMBDA_CONTRAST = 0.6
    LAMBDA_XMOD     = 0.25

    # ---- Supervised (Stage 2) ----
    SUP_EPOCHS   = 55
    SUP_BATCH    = 64
    SUP_LR       = 2e-4
    WEIGHT_DECAY = 1e-4         # also used by the SSL optimizer
    BACKBONE_LR_SCALE = 0.15

    # Warm-start stability
    FREEZE_BACKBONE_EPOCHS = 4  # train heads first, then unfreeze

    # Loss weights
    ALPHA_ACTION = 0.35
    BETA_TASK    = 0.65

    # Model
    D_MODEL      = 128
    DROPOUT      = 0.20
    N_HEADS_FUSE = 4
    N_LAYERS_FUSE = 2
    POOL_STRIDE   = 2           # encoder outputs are subsampled every POOL_STRIDE steps

    # Regularization to avoid over-relying on ET
    SUP_ET_DROPOUT = 0.10

    # Training quality
    GRAD_CLIP  = 1.0
    USE_AMP    = True
    USE_EMA    = True
    EMA_DECAY  = 0.995

    # Threshold tuning
    # 17 evenly spaced candidates from 0.10 to 0.90 (step 0.05), rounded to 2 decimals.
    THRESH_GRID = [round(x, 2) for x in np.linspace(0.10, 0.90, 17)]

    # Phase 5.5 features (optional)
    USE_EEG_PSD_FEATURES = True
    USE_EMG_FEATURES     = True


def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed every RNG once at import time so dataset subsampling, the train/val
# split and weight initialisation start from CFG.SEED.
set_seed(CFG.SEED)
# NOTE(review): cuDNN benchmark mode auto-selects convolution kernels, which
# presumably trades bit-exact run-to-run reproducibility for speed — confirm
# this is acceptable given the explicit seeding above.
torch.backends.cudnn.benchmark = True


# ---------------- SMALL UTILITIES ----------------

def discover_folds(prefix: str, dataset_dir: Optional[Path] = None) -> List[int]:
    """Return the sorted, de-duplicated fold ids found for ``prefix``.

    A fold is any directory named ``{prefix}_fold{N}`` (``N`` an integer)
    directly under the dataset directory.

    Args:
        prefix: Directory name prefix, e.g. ``CFG.SSL_PREFIX``.
        dataset_dir: Directory to scan. Defaults to ``CFG.DATASET_DIR`` so
            existing single-argument callers are unaffected.

    Returns:
        Ascending list of fold ids; empty when nothing matches.
    """
    root = CFG.DATASET_DIR if dataset_dir is None else Path(dataset_dir)
    found = set()
    for candidate in root.glob(f"{prefix}_fold*"):
        # Folds are later opened as directories (fold_dir / split), so stray
        # files that happen to match the pattern are not folds.
        if not candidate.is_dir():
            continue
        suffix = candidate.name.split("fold")[-1]
        try:
            found.add(int(suffix))
        except ValueError:
            # Non-numeric suffix (e.g. "fold0_backup"): not a fold directory.
            continue
    return sorted(found)


def stratified_split_indices(y: np.ndarray, val_frac: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split sample indices into train/val while stratifying on the labels ``y``.

    For every distinct label the member indices are shuffled with a
    ``RandomState(seed)`` and the first ``max(1, round(n * val_frac))`` of
    them go to validation (so each label contributes at least one validation
    index); the remainder go to training. Both index arrays are shuffled
    again before being returned as ``(train_idx, val_idx)``.
    No sklearn dependency.
    """
    labels = np.asarray(y).astype(int)
    rng = np.random.RandomState(seed)

    kept_parts, held_out_parts = [], []
    for label in np.unique(labels):
        members = np.nonzero(labels == label)[0]
        rng.shuffle(members)
        cut = max(1, int(round(len(members) * val_frac)))
        held_out_parts.append(members[:cut])
        kept_parts.append(members[cut:])

    if kept_parts:
        train_idx = np.concatenate(kept_parts)
    else:
        train_idx = np.array([], dtype=int)
    if held_out_parts:
        val_idx = np.concatenate(held_out_parts)
    else:
        val_idx = np.array([], dtype=int)

    # Final shuffles remove the per-class ordering left by concatenation.
    rng.shuffle(train_idx)
    rng.shuffle(val_idx)
    return train_idx, val_idx


# ---------------- DATASET LOADER (ALL SPLITS OF ONE FOLD) ----------------

class AllSplitsShardDataset(Dataset):
    """
    Loads a SINGLE fold's shards from split dirs and concatenates them in memory:
      {fold_dir}/{train,val,test}/*_shard_*.npz

    Each shard must provide the arrays X_EEG, X_EMG, X_ET, y_action, y_task.

    Supports:
      - ssl_mode=True  -> keep all windows
      - ssl_mode=False -> keep windows with y_task in CFG.TASK_CODES
      - optional Phase 5.5 features per split:
          features_v1_eeg_psd_full_fold{fid}_{split}.npz  (X_psd, X_emg)

    Features are exposed (``has_features=True``) only when EVERY split that
    contributed windows also contributed a feature file whose row count equals
    that split's window count; otherwise rows could not be matched to windows
    and features are disabled with a warning.

    Args:
        fold_dir: Directory containing one sub-directory per split.
        fold_id: Fold number, used to locate the feature files.
        splits: Split names to load, in concatenation order.
        ssl_mode: See above.
        max_windows: If set and exceeded, a seeded (CFG.SEED) random subset of
            this many windows is kept.
        attach_features: Whether to look for the per-split feature files.

    Raises:
        FileNotFoundError: A split directory is missing or holds no shards.
        RuntimeError: No window survived loading/filtering.
    """
    def __init__(
        self,
        fold_dir: Path,
        fold_id: int,
        splits: List[str],
        ssl_mode: bool,
        max_windows: Optional[int] = None,
        attach_features: bool = True,
    ):
        super().__init__()
        self.fold_dir = Path(fold_dir)
        self.fold_id = int(fold_id)
        self.splits = list(splits)
        self.ssl_mode = bool(ssl_mode)
        self.has_features = False

        X_eeg_list, X_emg_list, X_et_list = [], [], []
        y_action_list, y_task_list = [], []
        X_psd_list, X_emgfeat_list = [], []

        for split in self.splits:
            split_dir = self.fold_dir / split
            if not split_dir.exists():
                raise FileNotFoundError(f"Missing split dir: {split_dir}")

            shard_paths = sorted(split_dir.glob("*_shard_*.npz"))
            if not shard_paths:
                raise FileNotFoundError(f"No shards in: {split_dir}")

            # (Optional) load features for this split (Phase 5.5)
            X_psd_split = X_emg_split = None
            if attach_features:
                X_psd_split, X_emg_split = self._load_split_features(split)

            split_arrays = self._load_split_windows(shard_paths)
            if split_arrays is None:
                # Every shard was empty or fully filtered out.
                continue
            Xeeg, Xemg, Xet, ya, yt = split_arrays

            X_eeg_list.append(Xeeg)
            X_emg_list.append(Xemg)
            X_et_list.append(Xet)
            y_action_list.append(ya)
            y_task_list.append(yt)

            # If features exist for this split, they MUST align with split windows count
            if X_psd_split is not None and X_emg_split is not None:
                if X_psd_split.shape[0] != Xeeg.shape[0] or X_emg_split.shape[0] != Xeeg.shape[0]:
                    print(f"[WARN] Feature mismatch fold{self.fold_id} {split}: "
                          f"features N={X_psd_split.shape[0]} but windows N={Xeeg.shape[0]} -> skipping features.")
                else:
                    X_psd_list.append(X_psd_split)
                    X_emgfeat_list.append(X_emg_split)

        if not X_eeg_list:
            raise RuntimeError(f"No data found in fold={fold_id} splits={splits} (ssl_mode={ssl_mode}).")

        self.X_eeg = np.concatenate(X_eeg_list, axis=0)
        self.X_emg = np.concatenate(X_emg_list, axis=0)
        self.X_et  = np.concatenate(X_et_list,  axis=0)
        self.y_action = np.concatenate(y_action_list, axis=0)
        self.y_task   = np.concatenate(y_task_list,   axis=0)
        n_total = self.X_eeg.shape[0]

        # Features are only row-aligned with the windows when every split that
        # produced windows also produced features. This is validated BEFORE any
        # subsampling so a partial feature set can never be indexed with
        # window indices (which would raise or silently misalign rows).
        X_psd_all = X_emgfeat_all = None
        if X_psd_list:
            if len(X_psd_list) == len(X_eeg_list):
                X_psd_all = np.concatenate(X_psd_list, axis=0).astype(np.float32)
                X_emgfeat_all = np.concatenate(X_emgfeat_list, axis=0).astype(np.float32)
            if (
                X_psd_all is None
                or X_psd_all.shape[0] != n_total
                or X_emgfeat_all.shape[0] != n_total
            ):
                print("[WARN] After concat, features N mismatch -> disabling features.")
                X_psd_all = X_emgfeat_all = None

        if max_windows is not None and n_total > max_windows:
            rng = np.random.RandomState(CFG.SEED)
            idx = rng.choice(n_total, size=max_windows, replace=False)
            self.X_eeg = self.X_eeg[idx]
            self.X_emg = self.X_emg[idx]
            self.X_et  = self.X_et[idx]
            self.y_action = self.y_action[idx]
            self.y_task   = self.y_task[idx]
            if X_psd_all is not None:
                # If we subsample, subsample features with the same indices
                X_psd_all = X_psd_all[idx]
                X_emgfeat_all = X_emgfeat_all[idx]

        self.has_features = X_psd_all is not None
        if self.has_features:
            self.X_psd = X_psd_all
            self.X_emgfeat = X_emgfeat_all

        self.eeg_ch = self.X_eeg.shape[-1]
        self.emg_ch = self.X_emg.shape[-1]
        self.et_ch  = self.X_et.shape[-1]

        # Fixed mapping to keep inference stable
        self.task2idx = {t: i for i, t in enumerate(CFG.TASK_CODES)}
        self.num_task_classes = len(self.task2idx)

        self.eeg_psd_dim = int(self.X_psd.shape[1]) if self.has_features else 0
        self.emg_feat_dim = int(self.X_emgfeat.shape[1]) if self.has_features else 0

        print(f"[AllSplitsShardDataset] fold={fold_id} splits={splits} ssl_mode={ssl_mode}")
        print(f"  N={len(self)} shapes: EEG={self.X_eeg.shape} EMG={self.X_emg.shape} ET={self.X_et.shape}")
        if self.has_features:
            print(f"  features: PSD={self.X_psd.shape} EMGfeat={self.X_emgfeat.shape}")

    def _load_split_features(self, split: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Return (X_psd, X_emg) float32 arrays for ``split``, or (None, None) if the file is absent."""
        feat_path = CFG.DATASET_DIR / f"features_v1_eeg_psd_full_fold{self.fold_id}_{split}.npz"
        if not feat_path.exists():
            return None, None
        # NOTE(review): allow_pickle=True is kept for compatibility with the
        # existing exports; it is unsafe on untrusted files.
        # The context manager closes the archive's file handle once the arrays are read.
        with np.load(feat_path, allow_pickle=True) as zf:
            return zf["X_psd"].astype(np.float32), zf["X_emg"].astype(np.float32)

    def _load_split_windows(self, shard_paths: List[Path]) -> Optional[Tuple[np.ndarray, ...]]:
        """Load and filter all shards of one split (in the given order).

        Returns (X_eeg, X_emg, X_et, y_action, y_task) concatenated over the
        shards, or None when no window was kept.
        """
        task_codes = np.array(CFG.TASK_CODES, dtype=np.int64)
        X_eeg_s, X_emg_s, X_et_s = [], [], []
        y_action_s, y_task_s = [], []

        for sp in shard_paths:
            # NOTE(review): allow_pickle=True kept for compatibility (see above).
            with np.load(sp, allow_pickle=True) as z:
                Xeeg = z["X_EEG"]
                Xemg = z["X_EMG"]
                Xet  = z["X_ET"]
                ya   = z["y_action"].astype(np.int64)
                yt   = z["y_task"].astype(np.int64)

            N = yt.shape[0]
            if N == 0:
                continue

            if self.ssl_mode:
                keep = np.ones(N, dtype=bool)
            else:
                keep = np.isin(yt, task_codes)

            if not keep.any():
                continue

            X_eeg_s.append(Xeeg[keep])
            X_emg_s.append(Xemg[keep])
            X_et_s.append(Xet[keep])
            y_action_s.append(ya[keep])
            y_task_s.append(yt[keep])

        if not X_eeg_s:
            return None

        return (
            np.concatenate(X_eeg_s, axis=0).astype(np.float32),
            np.concatenate(X_emg_s, axis=0).astype(np.float32),
            np.concatenate(X_et_s,  axis=0).astype(np.float32),
            np.concatenate(y_action_s, axis=0).astype(np.int64),
            np.concatenate(y_task_s,   axis=0).astype(np.int64),
        )

    def __len__(self) -> int:
        return int(self.X_eeg.shape[0])

    def __getitem__(self, idx: int):
        """Return one window as a dict of tensors.

        Keys: "eeg", "emg", "et" (float), "action" (1 iff y_action == 1, else 0),
        "task" (index of the raw code in CFG.TASK_CODES; unknown codes map to 0),
        "y_task_raw", plus "eeg_psd"/"emg_feat" when features are attached.
        """
        x_eeg = torch.from_numpy(self.X_eeg[idx]).float()
        x_emg = torch.from_numpy(self.X_emg[idx]).float()
        x_et  = torch.from_numpy(self.X_et[idx]).float()

        ya = int(self.y_action[idx])
        yt_raw = int(self.y_task[idx])
        action_label = 1 if ya == 1 else 0
        task_label = self.task2idx.get(yt_raw, 0)

        sample = {
            "eeg": x_eeg,
            "emg": x_emg,
            "et":  x_et,
            "action": torch.tensor(action_label, dtype=torch.long),
            "task": torch.tensor(task_label, dtype=torch.long),
            "y_task_raw": torch.tensor(yt_raw, dtype=torch.long),
        }
        if self.has_features:
            sample["eeg_psd"] = torch.from_numpy(self.X_psd[idx]).float()
            sample["emg_feat"] = torch.from_numpy(self.X_emgfeat[idx]).float()
        return sample


def make_loader(ds: Dataset, batch: int, shuffle: bool) -> DataLoader:
    """Build a DataLoader over ``ds`` using the worker count from CFG.

    Args:
        ds: Dataset to iterate.
        batch: Batch size.
        shuffle: Whether to reshuffle every epoch.

    Returns:
        A DataLoader that keeps the last partial batch (``drop_last=False``).
    """
    # Pinned host memory only helps host->GPU copies; on a CPU-only run it
    # buys nothing, so it is enabled only for CUDA devices.
    pin = str(CFG.DEVICE).startswith("cuda")
    return DataLoader(
        ds,
        batch_size=batch,
        shuffle=shuffle,
        num_workers=CFG.NUM_WORKERS,
        pin_memory=pin,
        drop_last=False,
    )


# ---------------- MODEL (same spirit as your current) ----------------

class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to a (batch, time, d_model) input.

    Even feature columns hold sines and odd columns hold cosines; the table
    covers ``max_len`` positions and is stored as the buffer ``pe`` with shape
    (1, max_len, d_model).
    """

    def __init__(self, d_model: int, max_len: int = 4000):
        super().__init__()
        steps = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)
        freqs = torch.exp(even_dims * (-math.log(10000.0) / d_model))

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(steps * freqs)
        # With an odd d_model there is one fewer cosine column than sine columns.
        cos_freqs = freqs[:-1] if d_model % 2 == 1 else freqs
        table[:, 1::2] = torch.cos(steps * cos_freqs)
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]


class EEGTCNGRUEncoder(nn.Module):
    """EEG encoder: three dilated 1-D convolutions with a residual link, then a bi-GRU.

    Input is (batch, time, in_ch); output is (batch, time, d_model). The
    convolutions use kernel 5 with dilations 1/2/4 and matching padding, so
    the time length is preserved.
    """

    def __init__(self, in_ch: int, d_model: int, dropout: float):
        super().__init__()
        width = d_model
        self.conv1 = nn.Conv1d(in_ch, width, kernel_size=5, padding=2, dilation=1)
        self.bn1 = nn.BatchNorm1d(width)
        self.conv2 = nn.Conv1d(width, width, kernel_size=5, padding=4, dilation=2)
        self.bn2 = nn.BatchNorm1d(width)
        self.conv3 = nn.Conv1d(width, width, kernel_size=5, padding=8, dilation=4)
        self.bn3 = nn.BatchNorm1d(width)
        self.drop = nn.Dropout(dropout)
        self.act = nn.ReLU()
        self.gru = nn.GRU(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Linear(width, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv1d wants channels first: (B, T, C) -> (B, C, T).
        channels_first = x.transpose(1, 2)
        stem = self.act(self.bn1(self.conv1(channels_first)))
        deep = self.act(self.bn2(self.conv2(stem)))
        deep = self.bn3(self.conv3(deep))
        # Residual connection around the two dilated layers.
        merged = self.drop(self.act(deep + stem))
        recurrent, _ = self.gru(merged.transpose(1, 2))
        return self.proj(recurrent)


class EMGTCNGRUEncoder(nn.Module):
    """EMG encoder: two kernel-7 convolutions (dilation 1 and 2) with a residual sum, then a bi-GRU.

    Input is (batch, time, in_ch); output is (batch, time, d_model).
    """

    def __init__(self, in_ch: int, d_model: int, dropout: float):
        super().__init__()
        width = d_model
        self.conv1 = nn.Conv1d(in_ch, width, kernel_size=7, padding=3, dilation=1)
        self.bn1 = nn.BatchNorm1d(width)
        self.conv2 = nn.Conv1d(width, width, kernel_size=7, padding=6, dilation=2)
        self.bn2 = nn.BatchNorm1d(width)
        self.drop = nn.Dropout(dropout)
        self.act = nn.ReLU()
        self.gru = nn.GRU(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Linear(width, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, C) -> (B, C, T) for the convolutions.
        channels_first = x.transpose(1, 2)
        stem = self.act(self.bn1(self.conv1(channels_first)))
        deep = self.act(self.bn2(self.conv2(stem)))
        # Residual sum happens after the activation, then dropout.
        merged = self.drop(deep + stem)
        recurrent, _ = self.gru(merged.transpose(1, 2))
        return self.proj(recurrent)


class EyeTinyGRUEncoder(nn.Module):
    """Eye-tracking encoder: per-step Linear+ReLU+Dropout, a bi-GRU, then a linear projection.

    Maps (batch, time, in_ch) to (batch, time, d_model).
    """

    def __init__(self, in_ch: int, d_model: int, dropout: float):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_ch, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.gru = nn.GRU(
            input_size=d_model,
            hidden_size=d_model // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embedded = self.mlp(x)
        recurrent, _ = self.gru(embedded)
        return self.proj(recurrent)


class GatedSelfAttentionBlock(nn.Module):
    """Post-norm transformer block whose keys and values are scaled by a gate tensor.

    Queries are the raw tokens ``x``; keys and values are ``x * g``. Attention
    and feed-forward outputs are each added back residually and followed by
    LayerNorm.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.drop = nn.Dropout(dropout)
        self.n1 = nn.LayerNorm(d_model)
        self.n2 = nn.LayerNorm(d_model)
        hidden = d_model * 4
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        # Keys and values are computed as two separate products, as in the
        # original, so the attention module sees distinct tensors.
        gated_keys = x * g
        gated_values = x * g
        attended, _ = self.mha(x, gated_keys, gated_values, need_weights=False)
        after_attn = self.n1(x + self.drop(attended))
        return self.n2(after_attn + self.drop(self.ffn(after_attn)))


class GatedCrossModalEncoder(nn.Module):
    """Stack of gated self-attention blocks applied after sinusoidal positional encoding.

    The same ``gates`` tensor is handed to every layer.
    """

    def __init__(self, d_model: int, n_heads: int, n_layers: int, dropout: float):
        super().__init__()
        self.pe = PositionalEncoding(d_model)
        blocks = [GatedSelfAttentionBlock(d_model, n_heads, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, tokens: torch.Tensor, gates: torch.Tensor) -> torch.Tensor:
        hidden = self.pe(tokens)
        for block in self.layers:
            hidden = block(hidden, gates)
        return hidden


class TriModalSafetyTransformer(nn.Module):
    """Tri-modal (EEG / EMG / eye-tracking) network with gated fusion.

    Each modality is encoded separately, the token sequences are concatenated
    behind a learnable CLS token and passed through the gated fusion encoder.
    A pooled vector ``z_cls`` feeds the binary action head and the task head;
    linear decoders and cross-modal linear heads are provided for SSL.

    Optional per-window feature vectors (EEG PSD, EMG features) are
    standardised with the registered mean/std buffers, projected to
    ``d_model`` and added to ``z_cls``.
    """
    def __init__(
        self,
        eeg_ch: int, emg_ch: int, et_ch: int,
        num_task_classes: int,
        d_model: int, dropout: float,
        use_eeg_psd: bool, use_emg_feat: bool,
        eeg_psd_dim: int, emg_feat_dim: int,
        eeg_psd_mean: Optional[torch.Tensor], eeg_psd_std: Optional[torch.Tensor],
        emg_feat_mean: Optional[torch.Tensor], emg_feat_std: Optional[torch.Tensor],
    ):
        super().__init__()
        self.d_model = d_model
        # A feature branch is only active when requested AND its dimension is non-zero.
        self.use_eeg_psd = use_eeg_psd and (eeg_psd_dim > 0)
        self.use_emg_feat = use_emg_feat and (emg_feat_dim > 0)

        self.eeg_enc = EEGTCNGRUEncoder(eeg_ch, d_model, dropout)
        self.emg_enc = EMGTCNGRUEncoder(emg_ch, d_model, dropout)
        self.et_enc  = EyeTinyGRUEncoder(et_ch, d_model, dropout)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.fuse = GatedCrossModalEncoder(d_model, CFG.N_HEADS_FUSE, CFG.N_LAYERS_FUSE, dropout)

        # Maps the three concatenated pooled modality vectors to 3 * d_model gate logits.
        self.gate_mlp = nn.Sequential(
            nn.Linear(d_model * 3, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model * 3),
        )

        # Normalisation statistics default to mean 0 / std 1 when not supplied.
        if self.use_eeg_psd:
            self.eeg_psd_proj = nn.Sequential(nn.Linear(eeg_psd_dim, d_model), nn.ReLU(), nn.Dropout(dropout))
            self.register_buffer("eeg_psd_mean", (eeg_psd_mean if eeg_psd_mean is not None else torch.zeros(1, eeg_psd_dim)).view(1, -1))
            self.register_buffer("eeg_psd_std",  (eeg_psd_std  if eeg_psd_std  is not None else torch.ones(1, eeg_psd_dim)).view(1, -1))
        if self.use_emg_feat:
            self.emg_feat_proj = nn.Sequential(nn.Linear(emg_feat_dim, d_model), nn.ReLU(), nn.Dropout(dropout))
            self.register_buffer("emg_feat_mean", (emg_feat_mean if emg_feat_mean is not None else torch.zeros(1, emg_feat_dim)).view(1, -1))
            self.register_buffer("emg_feat_std",  (emg_feat_std  if emg_feat_std  is not None else torch.ones(1, emg_feat_dim)).view(1, -1))

        # Two-way action head and num_task_classes-way task head, both on z_cls.
        self.action_head = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model, 2)
        )
        self.task_head = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model, num_task_classes)
        )

        # SSL decoders + cross-modal heads
        self.dec_eeg = nn.Linear(d_model, eeg_ch)
        self.dec_emg = nn.Linear(d_model, emg_ch)
        self.dec_et  = nn.Linear(d_model, et_ch)

        self.cross_eeg2emg = nn.Linear(d_model, d_model)
        self.cross_eeg2et  = nn.Linear(d_model, d_model)
        self.cross_emg2eeg = nn.Linear(d_model, d_model)
        self.cross_et2eeg  = nn.Linear(d_model, d_model)

    def forward_backbone(self, x_eeg, x_emg, x_et, eeg_psd=None, emg_feat=None):
        """Encode, gate and fuse the three modalities.

        Returns the tuple
        ``(ze, zm, zt, z_cls, gates, cls_eeg, cls_emg, cls_et)`` where
        ``ze/zm/zt`` are the fused per-token outputs of each modality,
        ``cls_*`` their means over time, ``gates`` has shape (B, 3, d_model)
        and ``z_cls`` is the gate-weighted sum of the ``cls_*`` vectors (plus
        the optional projected feature vectors).
        """
        z_eeg = self.eeg_enc(x_eeg)
        z_emg = self.emg_enc(x_emg)
        z_et  = self.et_enc(x_et)

        # Temporal subsampling: keep every POOL_STRIDE-th encoder step.
        if CFG.POOL_STRIDE > 1:
            s = CFG.POOL_STRIDE
            z_eeg = z_eeg[:, ::s, :]
            z_emg = z_emg[:, ::s, :]
            z_et  = z_et[:,  ::s, :]

        B = z_eeg.size(0)
        Te, Tm, Tt = z_eeg.size(1), z_emg.size(1), z_et.size(1)

        # Mean over time gives one summary vector per modality for the gate MLP.
        p_eeg = z_eeg.mean(1)
        p_emg = z_emg.mean(1)
        p_et  = z_et.mean(1)

        # Softmax over the modality axis: for each feature dimension the three gates sum to 1.
        gates_logits = self.gate_mlp(torch.cat([p_eeg, p_emg, p_et], dim=-1)).view(B, 3, self.d_model)
        gates = torch.softmax(gates_logits, dim=1)  # (B,3,D)

        # Broadcast each modality's gate over its own time steps.
        g_eeg = gates[:, 0, :].unsqueeze(1).expand(-1, Te, -1)
        g_emg = gates[:, 1, :].unsqueeze(1).expand(-1, Tm, -1)
        g_et  = gates[:, 2, :].unsqueeze(1).expand(-1, Tt, -1)

        # Token layout: [CLS | EEG tokens | EMG tokens | ET tokens].
        z_cat = torch.cat([z_eeg, z_emg, z_et], dim=1)
        cls = self.cls_token.expand(B, 1, self.d_model)
        tokens = torch.cat([cls, z_cat], dim=1)

        # The CLS position gets a gate of all ones (i.e. it is left ungated).
        g_cls = torch.ones(B, 1, self.d_model, device=tokens.device)
        g_tok = torch.cat([g_cls, torch.cat([g_eeg, g_emg, g_et], dim=1)], dim=1)

        out = self.fuse(tokens, g_tok)

        # slice back
        # Index 0 is the CLS position; its fused output is not read below.
        e0, e1 = 1, 1 + Te
        m0, m1 = e1, e1 + Tm
        t0, t1 = m1, m1 + Tt
        ze = out[:, e0:e1, :]
        zm = out[:, m0:m1, :]
        zt = out[:, t0:t1, :]

        cls_eeg = ze.mean(1)
        cls_emg = zm.mean(1)
        cls_et  = zt.mean(1)

        g_e = gates[:, 0, :]
        g_m = gates[:, 1, :]
        g_t = gates[:, 2, :]

        # Pooled representation: gate-weighted sum of the per-modality means.
        z_cls = g_e * cls_eeg + g_m * cls_emg + g_t * cls_et

        # Optional Phase 5.5 features
        # Each provided feature vector is standardised, projected to d_model,
        # and the mean of the projections is added to z_cls.
        add_feats = []
        if self.use_eeg_psd and eeg_psd is not None:
            x = (eeg_psd - self.eeg_psd_mean) / (self.eeg_psd_std + 1e-6)
            add_feats.append(self.eeg_psd_proj(x))
        if self.use_emg_feat and emg_feat is not None:
            x = (emg_feat - self.emg_feat_mean) / (self.emg_feat_std + 1e-6)
            add_feats.append(self.emg_feat_proj(x))
        if add_feats:
            z_cls = z_cls + torch.stack(add_feats, dim=0).mean(0)

        return ze, zm, zt, z_cls, gates, cls_eeg, cls_emg, cls_et

    def forward_supervised(self, x_eeg, x_emg, x_et, eeg_psd=None, emg_feat=None):
        """Return ``(action_logits, task_logits, gates)`` for a batch."""
        _, _, _, z_cls, gates, _, _, _ = self.forward_backbone(x_eeg, x_emg, x_et, eeg_psd, emg_feat)
        return self.action_head(z_cls), self.task_head(z_cls), gates

    def forward_ssl(self, x_eeg, x_emg, x_et):
        """Run the backbone without extra features and decode each modality's tokens.

        Returns ``(eeg_recon, emg_recon, et_recon, z_cls, gates, cls_eeg, cls_emg, cls_et)``.
        The reconstructions have the (possibly subsampled) token length, not
        necessarily the input length.
        """
        ze, zm, zt, z_cls, gates, ce, cm, ct = self.forward_backbone(x_eeg, x_emg, x_et, None, None)
        return self.dec_eeg(ze), self.dec_emg(zm), self.dec_et(zt), z_cls, gates, ce, cm, ct


# ---------------- EMA ----------------

class EMA:
    """Exponential moving average of a model's floating-point state.

    ``shadow`` maps state-dict keys to averaged tensors; non-floating entries
    of the state dict are not tracked.
    """

    def __init__(self, model: nn.Module, decay: float):
        self.decay = float(decay)
        self.shadow = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
            if torch.is_floating_point(tensor)
        }

    @torch.no_grad()
    def update(self, model: nn.Module):
        """In place: shadow <- decay * shadow + (1 - decay) * current value."""
        current = model.state_dict()
        keep = self.decay
        for name, averaged in self.shadow.items():
            live = current[name]
            if not torch.is_floating_point(live):
                continue
            averaged.mul_(keep).add_(live.detach(), alpha=1.0 - keep)

    def copy_to(self, model: nn.Module):
        """Overwrite matching floating-point entries of ``model`` with the averaged values."""
        target = model.state_dict()
        for name, averaged in self.shadow.items():
            if name not in target:
                continue
            if torch.is_floating_point(target[name]):
                target[name].copy_(averaged)


# ---------------- SSL HELPERS ----------------

def apply_ssl_mask(x: torch.Tensor, p: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Zero out whole time steps of a (B, T, C) tensor, each with probability ``p``.

    Returns ``(masked_x, mask)`` where ``mask`` is float, shaped (B, T, 1),
    and equals 1 at the zeroed steps.
    """
    n_batch, n_steps, _ = x.shape
    dropped = torch.rand(n_batch, n_steps, 1, device=x.device) < p
    mask = dropped.float()
    visible = 1.0 - mask
    return x * visible, mask

def apply_modality_dropout(x: torch.Tensor, p: float) -> torch.Tensor:
    """Zero the entire input of each batch item independently with probability ``p``.

    When ``p <= 0`` the input tensor itself is returned unchanged.
    """
    if not p > 0:
        return x
    batch_size = x.size(0)
    dropped = (torch.rand(batch_size, 1, 1, device=x.device) < p).float()
    keep = 1.0 - dropped
    return x * keep

def ssl_recon_loss(x_hat: torch.Tensor, x_true: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean squared reconstruction error over masked positions only.

    If the prediction has fewer time steps than the target (pooled backbone),
    the target and mask are first brought to the prediction length: by
    striding when the lengths divide evenly, otherwise by interpolation
    (linear for the target, nearest for the mask).
    """
    n_pred, n_ref = x_hat.size(1), x_true.size(1)
    if n_pred != n_ref:
        if n_ref % n_pred != 0:
            # interpolate works on (B, C, T), hence the transposes.
            x_true = F.interpolate(
                x_true.transpose(1, 2), size=n_pred, mode="linear", align_corners=False
            ).transpose(1, 2)
            mask = F.interpolate(mask.transpose(1, 2), size=n_pred, mode="nearest").transpose(1, 2)
        else:
            stride = n_ref // n_pred
            x_true = x_true[:, ::stride, :]
            mask = mask[:, ::stride, :]

    weights = mask.expand_as(x_hat)
    squared_error = (x_hat - x_true) ** 2
    # The epsilon keeps the loss finite (zero) when nothing was masked.
    return (squared_error * weights).sum() / (weights.sum() + 1e-8)

def contrastive_loss(z1: torch.Tensor, z2: torch.Tensor, temp: float) -> torch.Tensor:
    """Symmetric cross-entropy over cosine similarities between two views.

    Row ``i`` of ``z1`` is the positive for row ``i`` of ``z2``; all other rows
    in the batch act as negatives. Returns a constant 0 tensor for batches
    with fewer than two items.
    """
    batch_size = z1.size(0)
    if batch_size < 2:
        return torch.tensor(0.0, device=z1.device)

    view_a = F.normalize(z1, dim=-1)
    view_b = F.normalize(z2, dim=-1)
    similarity = (view_a @ view_b.t()) / temp
    targets = torch.arange(batch_size, device=z1.device)

    a_to_b = F.cross_entropy(similarity, targets)
    b_to_a = F.cross_entropy(similarity.t(), targets)
    return 0.5 * (a_to_b + b_to_a)

def xmod_loss(model: TriModalSafetyTransformer, ce, cm, ct) -> torch.Tensor:
    """Average MSE of four cross-modal predictions between pooled modality vectors.

    The pairs are EEG->EMG, EEG->ET, EMG->EEG and ET->EEG; targets are
    detached so gradients only flow through the predicting side.
    """
    eeg_target = ce.detach()
    emg_target = cm.detach()
    et_target = ct.detach()

    # (prediction head, source vector, detached target), in the original order.
    pairs = (
        (model.cross_eeg2emg, ce, emg_target),
        (model.cross_eeg2et,  ce, et_target),
        (model.cross_emg2eeg, cm, eeg_target),
        (model.cross_et2eeg,  ct, eeg_target),
    )
    total = 0.0
    for head, source, target in pairs:
        total += F.mse_loss(head(source), target)
    return total / 4.0


# ---------------- TRAINING: SSL ----------------

def pretrain_ssl(fold_id: int, eeg_ch: int, emg_ch: int, et_ch: int) -> Dict[str, torch.Tensor]:
    """Stage 1: self-supervised pretraining on all splits of one SSL fold.

    Trains a fresh TriModalSafetyTransformer (feature branches disabled) for
    CFG.SSL_EPOCHS epochs with three objectives: masked reconstruction,
    a contrastive loss between two masked views, and cross-modal prediction.

    Args:
        fold_id: Fold whose "{SSL_PREFIX}_fold{fold_id}" directory is loaded.
        eeg_ch, emg_ch, et_ch: Channel counts used to build the model.

    Returns:
        The trained model's ``state_dict()``.
    """
    print("\n================ Stage 1: SSL pretraining (final) ================")
    fold_dir = CFG.DATASET_DIR / f"{CFG.SSL_PREFIX}_fold{fold_id}"
    # All three splits are used and labels are ignored (ssl_mode=True keeps every window).
    ssl_ds = AllSplitsShardDataset(
        fold_dir=fold_dir, fold_id=fold_id, splits=["train","val","test"],
        ssl_mode=True, max_windows=CFG.SSL_MAX_WINDOWS, attach_features=False
    )
    loader = make_loader(ssl_ds, CFG.SSL_BATCH, shuffle=True)

    # The task head is not trained here; its size just mirrors the task code list.
    dummy_task_classes = len(CFG.TASK_CODES)
    model = TriModalSafetyTransformer(
        eeg_ch, emg_ch, et_ch, dummy_task_classes,
        CFG.D_MODEL, CFG.DROPOUT,
        use_eeg_psd=False, use_emg_feat=False,
        eeg_psd_dim=0, emg_feat_dim=0,
        eeg_psd_mean=None, eeg_psd_std=None,
        emg_feat_mean=None, emg_feat_std=None
    ).to(CFG.DEVICE)

    opt = torch.optim.AdamW(model.parameters(), lr=CFG.SSL_LR, weight_decay=CFG.WEIGHT_DECAY)
    # Loss scaling is active only when AMP is requested and the device is CUDA.
    scaler = torch.cuda.amp.GradScaler(enabled=(CFG.USE_AMP and CFG.DEVICE.startswith("cuda")))

    for ep in range(1, CFG.SSL_EPOCHS + 1):
        model.train()
        tot = 0.0
        n = 0
        for batch in loader:
            x_eeg = batch["eeg"].to(CFG.DEVICE)
            x_emg = batch["emg"].to(CFG.DEVICE)
            x_et  = batch["et"].to(CFG.DEVICE)

            # View 1: time-step masking; its masks define where reconstruction is scored.
            x1_eeg, m1_eeg = apply_ssl_mask(x_eeg, CFG.SSL_MASK_PROB)
            x1_emg, m1_emg = apply_ssl_mask(x_emg, CFG.SSL_MASK_PROB)
            x1_et,  m1_et  = apply_ssl_mask(x_et,  CFG.SSL_MASK_PROB)

            # View 2: an independent masking (masks discarded) ...
            x2_eeg, _ = apply_ssl_mask(x_eeg, CFG.SSL_MASK_PROB)
            x2_emg, _ = apply_ssl_mask(x_emg, CFG.SSL_MASK_PROB)
            x2_et,  _ = apply_ssl_mask(x_et,  CFG.SSL_MASK_PROB)

            # ... plus per-sample dropout of whole modalities.
            x2_eeg = apply_modality_dropout(x2_eeg, CFG.SSL_MODAL_DROPOUT)
            x2_emg = apply_modality_dropout(x2_emg, CFG.SSL_MODAL_DROPOUT)
            x2_et  = apply_modality_dropout(x2_et,  CFG.SSL_MODAL_DROPOUT)

            opt.zero_grad(set_to_none=True)

            # NOTE(review): device_type is hard-coded to "cuda"; on CPU the context
            # is entered with enabled=False — confirm this is warning-free on the
            # PyTorch version in use.
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=scaler.is_enabled()):
                xh_eeg, xh_emg, xh_et, z1, gates, ce, cm, ct = model.forward_ssl(x1_eeg, x1_emg, x1_et)
                # Only the pooled vector z2 of the second view is needed (contrastive pair).
                _, _, _, z2, _, _, _, _ = model.forward_backbone(x2_eeg, x2_emg, x2_et, None, None)

                # upweight EEG+EMG recon
                # Targets are the UNMASKED inputs, scored at view-1 masked steps.
                loss_mask = (
                    1.6 * ssl_recon_loss(xh_eeg, x_eeg, m1_eeg)
                    + 1.6 * ssl_recon_loss(xh_emg, x_emg, m1_emg)
                    + 1.0 * ssl_recon_loss(xh_et,  x_et,  m1_et)
                )
                loss_ctr = contrastive_loss(z1, z2, CFG.SSL_TEMP)
                loss_xm  = xmod_loss(model, ce, cm, ct)

                loss = CFG.LAMBDA_MASK * loss_mask + CFG.LAMBDA_CONTRAST * loss_ctr + CFG.LAMBDA_XMOD * loss_xm

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
            scaler.step(opt)
            scaler.update()

            tot += float(loss.item())
            n += 1

        # Mean loss per batch for the epoch.
        print(f"[SSL] ep {ep:02d}/{CFG.SSL_EPOCHS} loss={tot/max(1,n):.4f}")

    return model.state_dict()


# ---------------- LOSS BUILDERS ----------------

def build_action_weights(y_action: np.ndarray) -> torch.Tensor:
    """Inverse-frequency weights for the two action classes.

    Weight of class c is ``N / (2 * (count_c + 1e-6))``; the result is a
    float32 tensor on CFG.DEVICE with at least two entries.
    """
    labels = y_action.astype(int)
    per_class = np.bincount(labels, minlength=2).astype(np.float64)
    n_samples = len(y_action)
    weights = n_samples / (2.0 * (per_class + 1e-6))
    return torch.tensor(weights, dtype=torch.float32, device=CFG.DEVICE)

def build_task_weights(y_task: np.ndarray, y_action: np.ndarray, ncls: int) -> torch.Tensor:
    """Inverse-frequency task weights computed only over windows with ``y_action == 1``.

    Weight of class c is ``N_action / (ncls * (count_c + 1e-6))``. If no window
    has action 1, a vector of ones of length ``ncls`` is returned. The result
    is a float32 tensor on CFG.DEVICE.
    """
    is_action = y_action.astype(int) == 1
    action_tasks = y_task[is_action].astype(int)
    if action_tasks.size == 0:
        return torch.ones(ncls, dtype=torch.float32, device=CFG.DEVICE)

    per_class = np.bincount(action_tasks, minlength=ncls).astype(np.float64)
    weights = len(action_tasks) / (ncls * (per_class + 1e-6))
    return torch.tensor(weights, dtype=torch.float32, device=CFG.DEVICE)


# ---------------- METRICS + THRESHOLD ----------------

def balanced_acc_binary(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Balanced accuracy for binary labels: mean recall over the classes present in ``y_true``.

    A class with no ground-truth samples is left out of the average instead of
    contributing a recall of 0 (which used to halve the score on single-class
    inputs). This matches how the multiclass task metric in this file averages
    only over classes that occur.

    Args:
        y_true: Ground-truth labels (0 or 1).
        y_pred: Predicted labels (0 or 1), same length as ``y_true``.

    Returns:
        Value in [0, 1]; 0.0 for empty input.
    """
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)

    recalls = []
    for cls in (0, 1):
        members = (y_true == cls)
        support = int(members.sum())
        if support == 0:
            # Recall is undefined for an absent class; skip it.
            continue
        hits = int(np.sum(members & (y_pred == cls)))
        recalls.append(hits / support)

    if not recalls:
        return 0.0
    return float(np.mean(recalls))

@torch.no_grad()
def eval_on_loader(model: nn.Module, loader: DataLoader, base_threshold: float) -> Dict[str, float]:
    """Evaluate action and task predictions over ``loader``.

    A window is predicted as action when the softmax probability of class 1 is
    at least ``base_threshold``; windows predicted as no-action have their task
    prediction forced to index 0.

    Returns a dict with:
      - "action_bal_acc": binary balanced accuracy of the action prediction;
      - "task_bal_acc_action_only": mean per-class recall of the task
        prediction, restricted to windows whose ground-truth action is 1
        (0.0 when there are none).
    """
    def _optional_to_device(batch, key):
        # Feature tensors are only present when the dataset attached them.
        value = batch.get(key)
        return None if value is None else value.to(CFG.DEVICE)

    model.eval()
    action_true, action_pred = [], []
    task_true, task_pred = [], []

    for batch in loader:
        x_eeg = batch["eeg"].to(CFG.DEVICE)
        x_emg = batch["emg"].to(CFG.DEVICE)
        x_et = batch["et"].to(CFG.DEVICE)
        labels_action = batch["action"].cpu().numpy()
        labels_task = batch["task"].cpu().numpy()

        logits_a, logits_t, _ = model.forward_supervised(
            x_eeg, x_emg, x_et,
            eeg_psd=_optional_to_device(batch, "eeg_psd"),
            emg_feat=_optional_to_device(batch, "emg_feat"),
        )

        move_prob = torch.softmax(logits_a, dim=1)[:, 1].detach().cpu().numpy()
        moves = (move_prob >= base_threshold).astype(int)

        # task pred (force rest if no-move)
        guessed_task = logits_t.detach().cpu().numpy().argmax(axis=1).astype(int)
        guessed_task[moves == 0] = 0

        action_true.append(labels_action)
        action_pred.append(moves)
        task_true.append(labels_task)
        task_pred.append(guessed_task)

    yA, pA = np.concatenate(action_true), np.concatenate(action_pred)
    yT, pT = np.concatenate(task_true), np.concatenate(task_pred)

    a_ba = balanced_acc_binary(yA, pA)

    # Task balanced accuracy is computed on ground-truth ACTION windows only.
    t_ba = 0.0
    moving = (yA == 1)
    if moving.any():
        yt, pt = yT[moving], pT[moving]
        recalls = [
            np.sum((yt == c) & (pt == c)) / (np.sum(yt == c) + 1e-8)
            for c in np.unique(yt)
        ]
        t_ba = float(np.mean(recalls)) if recalls else 0.0

    return {"action_bal_acc": a_ba, "task_bal_acc_action_only": t_ba}

@torch.no_grad()
def tune_threshold(model: nn.Module, loader: DataLoader) -> float:
    """Pick the action-probability threshold from CFG.THRESH_GRID with the best balanced accuracy.

    Collects the class-1 softmax probability for every window in ``loader``,
    scores each candidate threshold, prints the winner and returns it. Ties
    keep the earliest candidate in the grid.
    """
    model.eval()
    prob_chunks, label_chunks = [], []

    for batch in loader:
        x_eeg = batch["eeg"].to(CFG.DEVICE)
        x_emg = batch["emg"].to(CFG.DEVICE)
        x_et = batch["et"].to(CFG.DEVICE)
        labels = batch["action"].cpu().numpy()

        # Optional feature tensors; passed as None when the batch has none.
        extras = {}
        for key in ("eeg_psd", "emg_feat"):
            value = batch.get(key)
            extras[key] = value.to(CFG.DEVICE) if value is not None else None

        logits_a, _, _ = model.forward_supervised(x_eeg, x_emg, x_et, **extras)
        prob_chunks.append(torch.softmax(logits_a, dim=1)[:, 1].cpu().numpy())
        label_chunks.append(labels)

    p = np.concatenate(prob_chunks)
    y = np.concatenate(label_chunks)

    best_t, best_ba = 0.5, -1.0
    for candidate in CFG.THRESH_GRID:
        score = balanced_acc_binary(y, (p >= candidate).astype(int))
        if score > best_ba:
            best_t = float(candidate)
            best_ba = score

    print(f"[threshold] best τ={best_t:.2f} val_action_bal_acc={best_ba:.3f}")
    return best_t


# ---------------- LR SCHEDULER (warmup + cosine) ----------------

class WarmupCosine:
    """Per-step LR schedule: linear warm-up followed by cosine decay.

    Each call to :meth:`step` rescales every param group's ``"lr"`` relative
    to the value it had when the scheduler was constructed:

    * steps ``1 .. warmup``: scale grows linearly from ``1/warmup`` to 1.0;
    * steps ``warmup+1 .. total``: scale follows a half cosine from 1.0 down
      to ``min_lr_scale``;
    * steps beyond ``total``: scale stays at ``min_lr_scale``.

    NOTE: the optimizer's learning rates are not modified until the first
    ``step()`` call; construction only records the base values.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``"lr"`` entry), e.g. a ``torch.optim`` optimizer.
        warmup_steps: Number of warm-up steps (values below 1 are treated as 1).
        total_steps: Total number of scheduled steps (forced to be at least
            ``warmup + 1`` so the decay phase has non-zero length).
        min_lr_scale: Final LR as a fraction of the base LR.
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int, min_lr_scale: float = 0.05):
        self.opt = optimizer
        self.warmup = max(1, warmup_steps)
        self.total = max(self.warmup + 1, total_steps)
        self.min_lr_scale = float(min_lr_scale)
        self.step_i = 0
        # Base LRs are captured once so repeated rescaling never compounds.
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]

    def step(self):
        """Advance the schedule by one step and write the new LRs into the optimizer."""
        self.step_i += 1
        if self.step_i <= self.warmup:
            scale = self.step_i / self.warmup
        else:
            # Clamp progress to 1.0: without it, stepping past `total` makes
            # cos(pi * t) rise again and the LR would climb back up.
            t = min(1.0, (self.step_i - self.warmup) / (self.total - self.warmup))
            scale = self.min_lr_scale + 0.5 * (1.0 - self.min_lr_scale) * (1.0 + math.cos(math.pi * t))
        for lr0, g in zip(self.base_lrs, self.opt.param_groups):
            g["lr"] = lr0 * scale


# ---------------- FINAL TRAIN (ONE MODEL) ----------------

def train_final_allsubjects(fold_id: int, out_dir: Path) -> Tuple[Path, Path]:
    """Train one model on the union of a fold's train/val/test shards and save it.

    Pipeline:
      1. Load all three splits of the balanced fold as a single dataset and
         carve out an internal validation subset (stratified by task label,
         ``CFG.VAL_FRAC`` of the data, seeded by ``CFG.SEED``).
      2. Optionally compute feature normalisation statistics from the
         training indices only.
      3. Build the model, optionally warm-start it from ``pretrain_ssl``.
      4. Train for ``CFG.SUP_EPOCHS`` epochs with AdamW, warm-up + cosine LR,
         optional AMP and optional EMA; the backbone is frozen for the first
         ``CFG.FREEZE_BACKBONE_EPOCHS`` epochs.
      5. After each epoch, score the (EMA, if enabled) weights on the
         validation subset; on improvement, re-tune the action threshold
         and snapshot the weights.
      6. Write the best weights (``.pt``) and a JSON stats file.

    Args:
        fold_id: Fold whose balanced shards are used as the data source.
        out_dir: Directory for the checkpoint and stats JSON; created if
            missing.

    Returns:
        ``(ckpt_path, stats_path)`` of the files written.
    """
    print("\n================ FINAL: Train ONE all-subjects model ================")
    print("Fold used:", fold_id)

    # Create the output directory up front so that a bad/unwritable path
    # fails immediately instead of after the whole training run.
    out_dir.mkdir(parents=True, exist_ok=True)

    fold_bal = CFG.DATASET_DIR / f"{CFG.BAL_PREFIX}_fold{fold_id}"
    ds_all = AllSplitsShardDataset(
        fold_dir=fold_bal, fold_id=fold_id, splits=["train","val","test"],
        ssl_mode=False, max_windows=None, attach_features=True
    )
    eeg_ch, emg_ch, et_ch = ds_all.eeg_ch, ds_all.emg_ch, ds_all.et_ch

    # Build train/val split (stratify by task)
    y_task = ds_all.y_task.copy()
    tr_idx, va_idx = stratified_split_indices(y_task, CFG.VAL_FRAC, CFG.SEED)
    ds_tr = Subset(ds_all, tr_idx.tolist())
    ds_va = Subset(ds_all, va_idx.tolist())

    tr_loader = make_loader(ds_tr, CFG.SUP_BATCH, shuffle=True)
    va_loader = make_loader(ds_va, CFG.SUP_BATCH, shuffle=False)

    # Features stats from TRAIN only
    use_feats = ds_all.has_features and CFG.USE_EEG_PSD_FEATURES and CFG.USE_EMG_FEATURES
    eeg_psd_dim = ds_all.eeg_psd_dim if use_feats else 0
    emg_feat_dim = ds_all.emg_feat_dim if use_feats else 0

    eeg_psd_mean_t = eeg_psd_std_t = None
    emg_feat_mean_t = emg_feat_std_t = None

    if use_feats:
        Xpsd_tr = ds_all.X_psd[tr_idx].astype(np.float32)
        Xemg_tr = ds_all.X_emgfeat[tr_idx].astype(np.float32)
        # Small epsilon keeps the std strictly positive for constant features.
        eeg_psd_mean_t = torch.from_numpy(Xpsd_tr.mean(0)).float().to(CFG.DEVICE)
        eeg_psd_std_t  = torch.from_numpy(Xpsd_tr.std(0) + 1e-6).float().to(CFG.DEVICE)
        emg_feat_mean_t = torch.from_numpy(Xemg_tr.mean(0)).float().to(CFG.DEVICE)
        emg_feat_std_t  = torch.from_numpy(Xemg_tr.std(0) + 1e-6).float().to(CFG.DEVICE)

    def build_model() -> nn.Module:
        # Single construction site so the training model and the EMA
        # evaluation copy are guaranteed to share the same architecture.
        return TriModalSafetyTransformer(
            eeg_ch, emg_ch, et_ch, num_task_classes=len(CFG.TASK_CODES),
            d_model=CFG.D_MODEL, dropout=CFG.DROPOUT,
            use_eeg_psd=use_feats, use_emg_feat=use_feats,
            eeg_psd_dim=eeg_psd_dim, emg_feat_dim=emg_feat_dim,
            eeg_psd_mean=eeg_psd_mean_t, eeg_psd_std=eeg_psd_std_t,
            emg_feat_mean=emg_feat_mean_t, emg_feat_std=emg_feat_std_t,
        ).to(CFG.DEVICE)

    # Build model
    model = build_model()

    # Optional SSL init
    if CFG.USE_SSL:
        print("[FINAL] SSL pretraining enabled…")
        base_state = pretrain_ssl(fold_id, eeg_ch, emg_ch, et_ch)
        # strict=False: the SSL state need not contain every supervised key.
        miss, unexp = model.load_state_dict(base_state, strict=False)
        print(f"[FINAL] Loaded SSL backbone. Missing={len(miss)} Unexpected={len(unexp)}")
    else:
        print("[FINAL] SSL disabled → random init.")

    # Freeze backbone warm-start
    def set_backbone_trainable(flag: bool):
        # Heads always train; everything else follows `flag`.
        for name, p in model.named_parameters():
            if name.startswith(("action_head", "task_head")):
                p.requires_grad = True
            else:
                p.requires_grad = flag

    # Loss weights from TRAIN subset (using underlying arrays)
    y_action_tr = ds_all.y_action[tr_idx]
    y_task_tr   = np.array([ds_all.task2idx.get(int(t), 0) for t in ds_all.y_task[tr_idx]], dtype=np.int64)

    w_action = build_action_weights(y_action_tr)
    w_task   = build_task_weights(y_task_tr, y_action_tr, ncls=len(CFG.TASK_CODES))

    crit_action = nn.CrossEntropyLoss(weight=w_action)
    crit_task   = nn.CrossEntropyLoss(weight=w_task)

    # Two param groups so the backbone can use a smaller learning rate than
    # the heads. This runs before any freezing, while requires_grad is still
    # whatever the model was constructed with.
    head_params, backbone_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("action_head.") or n.startswith("task_head."):
            head_params.append(p)
        else:
            backbone_params.append(p)

    opt = torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": CFG.SUP_LR * CFG.BACKBONE_LR_SCALE},
            {"params": head_params,     "lr": CFG.SUP_LR},
        ],
        weight_decay=CFG.WEIGHT_DECAY,
    )

    # Steps for scheduler (stepped once per batch)
    steps_per_epoch = max(1, len(tr_loader))
    total_steps = CFG.SUP_EPOCHS * steps_per_epoch
    warmup_steps = int(0.08 * total_steps)
    sched = WarmupCosine(opt, warmup_steps=warmup_steps, total_steps=total_steps, min_lr_scale=0.08)

    scaler = torch.cuda.amp.GradScaler(enabled=(CFG.USE_AMP and CFG.DEVICE.startswith("cuda")))
    ema = EMA(model, CFG.EMA_DECAY) if CFG.USE_EMA else None

    # One reusable evaluation copy for the EMA weights, instead of building
    # (and moving to the device) a brand-new model after every epoch.
    ema_eval_model = build_model() if ema is not None else None

    best_score = -1e9
    best_state = None
    best_tau = 0.5

    for ep in range(1, CFG.SUP_EPOCHS + 1):
        model.train()

        # backbone freeze warm-start: heads only for the first epochs
        set_backbone_trainable(ep > CFG.FREEZE_BACKBONE_EPOCHS)

        t0 = time.time()
        run_loss = 0.0
        n_samples = 0

        for batch in tr_loader:
            x_eeg = batch["eeg"].to(CFG.DEVICE)
            x_emg = batch["emg"].to(CFG.DEVICE)
            x_et  = batch["et"].to(CFG.DEVICE)
            y_a   = batch["action"].to(CFG.DEVICE)
            y_t   = batch["task"].to(CFG.DEVICE)

            eeg_psd = batch.get("eeg_psd")
            emg_feat = batch.get("emg_feat")
            if eeg_psd is not None: eeg_psd = eeg_psd.to(CFG.DEVICE)
            if emg_feat is not None: emg_feat = emg_feat.to(CFG.DEVICE)

            # supervised ET dropout
            if CFG.SUP_ET_DROPOUT > 0:
                x_et = apply_modality_dropout(x_et, CFG.SUP_ET_DROPOUT)

            opt.zero_grad(set_to_none=True)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=scaler.is_enabled()):
                la, lt, _ = model.forward_supervised(x_eeg, x_emg, x_et, eeg_psd=eeg_psd, emg_feat=emg_feat)
                loss_a = crit_action(la, y_a)

                # Task loss is computed on action-positive samples only.
                mask = (y_a == 1)
                if mask.any():
                    loss_t = crit_task(lt[mask], y_t[mask])
                else:
                    loss_t = torch.tensor(0.0, device=CFG.DEVICE)

                loss = CFG.ALPHA_ACTION * loss_a + CFG.BETA_TASK * loss_t

            scaler.scale(loss).backward()
            # Unscale before clipping so the clip threshold applies to true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
            scaler.step(opt)
            scaler.update()
            sched.step()

            if ema is not None:
                ema.update(model)

            B = int(y_a.size(0))
            run_loss += float(loss.item()) * B
            n_samples += B

        # Validate using EMA weights if enabled
        val_model = model
        if ema is not None:
            # Sync everything from the live model first, then let the EMA
            # object overwrite what it tracks.
            ema_eval_model.load_state_dict(model.state_dict(), strict=True)
            ema.copy_to(ema_eval_model)
            val_model = ema_eval_model

        # quick threshold=0.5 for metric scoring
        metrics = eval_on_loader(val_model, va_loader, base_threshold=0.5)
        # Model-selection score; the 0.35/0.65 weights currently equal
        # CFG.ALPHA_ACTION / CFG.BETA_TASK but are kept independent of them.
        score = 0.35 * metrics["action_bal_acc"] + 0.65 * metrics["task_bal_acc_action_only"]

        print(
            f"[FINAL] ep {ep:02d}/{CFG.SUP_EPOCHS} "
            f"train_loss={run_loss/max(1,n_samples):.4f} "
            f"val_action_BA={metrics['action_bal_acc']:.3f} "
            f"val_task_BA(action)={metrics['task_bal_acc_action_only']:.3f} "
            f"score={score:.3f} "
            f"time={time.time()-t0:.1f}s"
        )

        if score > best_score + 1e-4:
            best_score = score
            # tune threshold on best model snapshot
            tau = tune_threshold(val_model, va_loader)
            best_tau = float(tau)
            # clone(): the evaluation model is reused across epochs, so the
            # snapshot must not alias its tensors.
            best_state = {k: v.detach().cpu().clone() for k, v in val_model.state_dict().items()}

    # Fallback when no epoch ever improved the score (e.g. zero epochs).
    if best_state is None:
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    ckpt_path = out_dir / f"best_model_allsubjects_fold{fold_id}.pt"
    stats_path = out_dir / f"stats_allsubjects_fold{fold_id}.json"

    torch.save(best_state, ckpt_path)

    # Everything needed to rebuild the model and reproduce inference settings.
    stats = {
        "fold_id_used": int(fold_id),
        "base_threshold": float(best_tau),
        "task2idx": {str(k): int(v) for k, v in ds_all.task2idx.items()},
        "idx2task": {str(v): int(k) for k, v in ds_all.task2idx.items()},
        "eeg_ch": int(ds_all.eeg_ch),
        "emg_ch": int(ds_all.emg_ch),
        "et_ch": int(ds_all.et_ch),
        "use_features": bool(use_feats),
        "eeg_psd_dim": int(eeg_psd_dim),
        "emg_feat_dim": int(emg_feat_dim),
        "cfg": {
            "d_model": CFG.D_MODEL,
            "dropout": CFG.DROPOUT,
            "pool_stride": CFG.POOL_STRIDE,
            "n_heads_fuse": CFG.N_HEADS_FUSE,
            "n_layers_fuse": CFG.N_LAYERS_FUSE,
        }
    }
    if use_feats:
        stats["eeg_psd_mean"] = eeg_psd_mean_t.detach().cpu().numpy().tolist()
        stats["eeg_psd_std"]  = eeg_psd_std_t.detach().cpu().numpy().tolist()
        stats["emg_feat_mean"] = emg_feat_mean_t.detach().cpu().numpy().tolist()
        stats["emg_feat_std"]  = emg_feat_std_t.detach().cpu().numpy().tolist()

    stats_path.write_text(json.dumps(stats, indent=2), encoding="utf-8")
    print("\n[SAVED]")
    print("  ckpt :", ckpt_path)
    print("  stats:", stats_path)
    print(f"  best τ={best_tau:.2f}")

    return ckpt_path, stats_path


def main():
    """Command-line entry point: pick a fold and launch the final training run.

    Options:
        --fold  Fold index to use as the data source; a negative value (the
                default) selects the first fold returned by ``discover_folds``.
        --out   Directory that receives the checkpoint and stats JSON.

    Exits via ``SystemExit`` when no balanced folds are found or the
    requested fold is not among them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--fold",
        type=int,
        default=-1,
        help="Which fold to use as the 'all-data' source (train+val+test).",
    )
    parser.add_argument(
        "--out",
        type=str,
        default=str(CFG.DATASET_DIR / "checkpoints_allsubjects"),
        help="Output directory for final checkpoint + stats json.",
    )
    # parse_known_args: unrecognised argv entries are ignored, not rejected.
    known, _unused = parser.parse_known_args()

    available = discover_folds(CFG.BAL_PREFIX)
    if not available:
        raise SystemExit(f"No balanced folds found under {CFG.DATASET_DIR} with prefix {CFG.BAL_PREFIX}_fold*")

    # any fold works; union train+val+test covers all subjects
    chosen = known.fold if known.fold >= 0 else available[0]
    if chosen not in available:
        raise SystemExit(f"Requested fold={chosen} not found. Available folds: {available}")

    train_final_allsubjects(fold_id=chosen, out_dir=Path(known.out))


# Run the CLI entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
