In [None]:
# =====================================================================
# Phase 6 (BioRob RQ) — SSL → Safety-Aware Tri-Modal LOSO Trainer
# ---------------------------------------------------------------------

# =====================================================================

from __future__ import annotations
import time
import os
import math
import json
import random
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from sklearn.metrics import (
    accuracy_score,
    f1_score,
    balanced_accuracy_score,
    roc_auc_score,
    average_precision_score,
    confusion_matrix,
)

import csv


# ---------------- CONFIG ----------------

class CFG:
    """Global experiment configuration: paths, loaders, SSL and supervised hyper-parameters."""

    # ===== Data locations (must line up with the Phase 5 export layout) =====
    ROOT_DIR = Path(r"/home/tsultan1/BioRob(Final)/Data")
    DATASET_DIR = ROOT_DIR.joinpath("_dataset_icml_v1")
    print("Dataset dir:", DATASET_DIR)

    # Directory-name prefixes written by Phase 5
    SSL_PREFIX = "exports_v1_ssl"              # unbalanced windows -> SSL stage
    BALANCED_PREFIX = "exports_v1_balanced"    # balanced windows -> supervised LOSO

    # Task codes kept for supervised training (rest class included)
    TASK_CODES = [0, 1, 2, 3, 4, 5]

    # ===== Runtime =====
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    NUM_WORKERS = 2

    # ===== Stage 1: self-supervised pre-training =====
    USE_SSL = True
    SSL_EPOCHS = 30               # raised from 20 for a stronger backbone
    SSL_BATCH = 64
    SSL_LR = 1e-3
    SSL_MAX_WINDOWS = 3000        # per-fold cap on SSL windows (was 1500)
    SSL_MASK_PROB = 0.15
    SSL_TEMPERATURE = 0.1
    SSL_MODALITY_DROPOUT = 0.25

    # ===== Stage 2: supervised fine-tuning =====
    SUP_EPOCHS = 40               # raised from 30
    SUP_BATCH = 64
    SUP_LR = 1e-4
    WEIGHT_DECAY = 1e-4
    PATIENCE = 7                  # early-stopping patience (was 5)
    BACKBONE_LR_SCALE = 0.1

    # Loss mix between the action and task heads
    ALPHA_ACTION = 0.4
    BETA_TASK = 0.6

    # ===== Model architecture =====
    D_MODEL = 128
    DROPOUT = 0.2
    N_HEADS_FUSE = 4
    N_LAYERS_FUSE = 2

    # Temporal stride applied to encoder outputs before fusion
    POOL_STRIDE = 2               # was 4; finer time resolution for task decoding

    # ===== SSL objective weights =====
    LAMBDA_MASK = 1.0
    LAMBDA_CONTRAST = 0.5         # was 0.3; stronger cross-view consistency
    LAMBDA_ORDER = 0.0
    LAMBDA_GATE = 0.0
    LAMBDA_XMOD = 0.3             # was 0.2; stronger cross-modal link

    # Eye-tracking modality dropout during supervised training
    SUP_ET_DROPOUT = 0.1          # was 0.2

    # ===== Evaluation =====
    TOPK = (1, 3)

    # Decision thresholds swept for policy P2 (0.10 .. 0.90 in 0.05 steps)
    P2_THRESHOLDS = [round(v, 2) for v in np.linspace(0.1, 0.9, num=17)]

    # ===== Reproducibility =====
    SEED = 42

    # ===== Phase 5.5 window-level features =====
    USE_EEG_PSD_FEATURES = True   # EEG PSD/Hjorth vectors
    USE_EMG_FEATURES = True       # EMG feature vectors


# Modality ablations for the multimodal-synergy (MSG) comparison: each entry
# lists which input streams the model is allowed to use.
ABLATIONS = {
    "all": dict(use_eeg=True, use_emg=True, use_et=True),
    "eeg": dict(use_eeg=True, use_emg=False, use_et=False),
    "emg": dict(use_eeg=False, use_emg=True, use_et=False),
    "et": dict(use_eeg=False, use_emg=False, use_et=True),
}

# Robustness scenarios S0-S3: which sensor (if any) is simulated as failed.
SCENARIOS = {
    "S0": dict(drop_eeg=False, drop_emg=False, drop_et=False,
               description="All sensors OK"),
    "S1": dict(drop_eeg=True, drop_emg=False, drop_et=False,
               description="EEG failure (EEG dropped)"),
    "S2": dict(drop_eeg=False, drop_emg=True, drop_et=False,
               description="EMG failure (EMG dropped)"),
    "S3": dict(drop_eeg=False, drop_emg=False, drop_et=True,
               description="Eye-tracking failure (ET dropped)"),
}

# Decision policies: P0 naive, P1 gating-aware, P2 safety-first.
POLICIES = [f"P{i}" for i in range(3)]


def set_seed(seed: int = 42):
    """
    Seed every RNG this script draws from.

    Seeds, in order, Python's ``random``, NumPy's global generator, PyTorch's
    CPU generator and all CUDA generators with the same ``seed``.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


# Seed Python, NumPy and PyTorch RNGs once at import time so that subsequent
# draws from these generators are reproducible across runs.
set_seed(CFG.SEED)


# ---------------- POSITIONAL ENCODING (for fusion only) ----------------

class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding added to a (B, T, d_model) sequence.

    Even feature columns hold sin(pos * w_k) and odd columns cos(pos * w_k), with
    w_k = exp(-2k * ln(10000) / d_model). Odd ``d_model`` is supported: the last
    frequency is used only for a sine column. The table is precomputed for
    ``max_len`` positions and stored as the buffer ``pe`` of shape
    (1, max_len, d_model).
    """
    def __init__(self, d_model: int, max_len: int = 4000):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        log_scale = -math.log(10000.0) / d_model
        freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * log_scale)

        # Number of cosine (odd-index) columns; one fewer than sine when d_model is odd.
        n_cos = d_model // 2

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs[:n_cos])
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first ``x.size(1)`` positional rows to ``x`` (broadcast over batch)."""
        return x + self.pe[:, :x.size(1)]


# ---------------- BIO-ROB TCN+GRU ENCODERS ----------------

class EEGTCNGRUEncoder(nn.Module):
    """
    EEG encoder: three dilated 1-D convolutions (dilations 1, 2, 4, kernel 5,
    length-preserving padding) with a residual skip around the last two,
    followed by a bidirectional GRU and a linear projection.

    Maps (B, T, in_ch) -> (B, T, d_model).
    """
    def __init__(self, in_ch: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        width = d_model

        # Registration order is kept fixed: it determines seeded weight init
        # and the state_dict layout.
        self.conv1 = nn.Conv1d(in_ch, width, kernel_size=5, padding=2, dilation=1)
        self.bn1 = nn.BatchNorm1d(width)
        self.conv2 = nn.Conv1d(width, width, kernel_size=5, padding=4, dilation=2)
        self.bn2 = nn.BatchNorm1d(width)
        self.conv3 = nn.Conv1d(width, width, kernel_size=5, padding=8, dilation=4)
        self.bn3 = nn.BatchNorm1d(width)

        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

        # Bidirectional with half-width hidden state -> concatenated output of `width`.
        self.gru = nn.GRU(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.proj = nn.Linear(width, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv stack runs channel-first: (B, T, C) -> (B, C, T).
        h = self.relu(self.bn1(self.conv1(x.transpose(1, 2))))

        # Residual around the two dilated convs stabilises the deeper receptive field.
        skip = h
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.relu(self.bn3(self.conv3(h)) + skip)

        # Back to time-major (B, T, hidden) for the recurrent layer.
        h = self.dropout(h).transpose(1, 2)
        seq, _ = self.gru(h)
        return self.proj(seq)


class EMGTCNGRUEncoder(nn.Module):
    """
    EMG encoder: two length-preserving convolutions (kernel 7, dilations 1 and 2)
    with a residual skip, then a bidirectional GRU and a linear projection.

    Maps (B, T, in_ch) -> (B, T, d_model).
    """
    def __init__(self, in_ch: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        width = d_model

        # Registration order is kept fixed (seeded init / state_dict layout).
        self.conv1 = nn.Conv1d(in_ch, width, kernel_size=7, padding=3, dilation=1)
        self.bn1 = nn.BatchNorm1d(width)
        self.conv2 = nn.Conv1d(width, width, kernel_size=7, padding=6, dilation=2)
        self.bn2 = nn.BatchNorm1d(width)

        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

        self.gru = nn.GRU(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.proj = nn.Linear(width, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First conv output doubles as the residual branch.
        skip = self.relu(self.bn1(self.conv1(x.transpose(1, 2))))
        h = self.relu(self.bn2(self.conv2(skip))) + skip

        # Channel-first -> time-major for the GRU.
        h = self.dropout(h).transpose(1, 2)
        seq, _ = self.gru(h)
        return self.proj(seq)


class EyeTinyGRUEncoder(nn.Module):
    """
    Eye-tracking encoder: per-timestep Linear+ReLU+Dropout, a bidirectional GRU,
    and a final linear projection.

    Maps (B, T, in_ch) -> (B, T, d_model).
    """
    def __init__(self, in_ch: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_ch, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        # Two directions of d_model // 2 concatenate back to d_model features.
        self.gru = nn.GRU(
            input_size=d_model,
            hidden_size=d_model // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embedded = self.mlp(x)
        recurrent = self.gru(embedded)[0]
        return self.proj(recurrent)


# ---------------- GATED CROSS-MODAL SELF-ATTENTION ----------------

class GatedSelfAttentionBlock(nn.Module):
    """
    Post-norm transformer layer whose keys and values are gated element-wise.

    Inputs:
        x:        (B, L, D) token sequence; used unchanged as the queries.
        g_tokens: (B, L, D) gate tensor multiplied into the keys and values.

    Returns:
        (out, attn) where ``out`` is (B, L, D) and ``attn`` is the head-averaged
        (B, L, L) attention map when ``return_attn`` is True, otherwise None.
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(
        self,
        x: torch.Tensor,
        g_tokens: torch.Tensor,
        return_attn: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Gating scales what each token contributes, not what it asks for.
        # Keys and values are kept as separate tensors, as in the original.
        keys = x * g_tokens
        values = x * g_tokens

        mixed, weights = self.mha(
            x, keys, values,
            need_weights=return_attn,
            average_attn_weights=True,
        )
        weights = weights if return_attn else None

        h = self.norm1(x + self.dropout(mixed))
        return self.norm2(h + self.dropout(self.ffn(h))), weights


class GatedCrossModalEncoder(nn.Module):
    """
    Sinusoidal positional encoding followed by ``n_layers`` gated self-attention
    blocks that share the same gate tensor.

    ``forward`` returns ``(tokens_out, attn)``; ``attn`` is the last layer's
    attention map when ``return_attn`` is True, otherwise None.
    """
    def __init__(self, d_model: int, n_heads: int, n_layers: int, dropout: float = 0.1):
        super().__init__()
        self.pos_encoding = PositionalEncoding(d_model)
        self.layers = nn.ModuleList(
            [GatedSelfAttentionBlock(d_model, n_heads, dropout) for _ in range(n_layers)]
        )

    def forward(
        self,
        tokens: torch.Tensor,
        g_tokens: torch.Tensor,
        return_attn: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        h = self.pos_encoding(tokens)
        final_idx = len(self.layers) - 1
        captured = None
        for depth, block in enumerate(self.layers):
            # Only the final layer is asked for its attention weights.
            h, weights = block(h, g_tokens, return_attn=return_attn and depth == final_idx)
            if weights is not None:
                captured = weights
        return h, (captured if return_attn else None)


# ---------------- TRI-MODAL SAFETY-AWARE MODEL ----------------

class TriModalSafetyTransformer(nn.Module):
    """
    Tri-modal (EEG / EMG / eye-tracking) backbone with action and task heads.

    Pipeline (see ``forward_backbone``):
      - Per-modality encoders: EEGTCNGRUEncoder, EMGTCNGRUEncoder,
        EyeTinyGRUEncoder, each mapping (B, T, C_m) -> (B, T, d_model).
      - Optional strided temporal subsampling by ``CFG.POOL_STRIDE``.
      - Feature-wise modality gates g_EEG, g_EMG, g_ET from ``gate_mlp``,
        softmax-normalised across the three modalities for every feature.
      - Gated cross-modal self-attention over [CLS | EEG | EMG | ET] tokens;
        each modality's gate scales the keys/values of its tokens.
      - Per-modality summaries cls_eeg / cls_emg / cls_et: the MEAN of that
        modality's post-fusion tokens (not the output at the learnable
        ``cls_token`` position, which is never read back).
      - Fused embedding z_cls = Σ_m g_m ⊙ cls_m.
      - Optional Phase 5.5 window features (EEG-PSD, EMG) z-scored with stored
        mean/std, projected to d_model, averaged and added to z_cls.
      - Heads: binary action head, task head, SSL reconstruction decoders,
        temporal-order head and cross-modal linear prediction heads.

    Args:
        eeg_ch, emg_ch, et_ch: Input channel counts per modality.
        num_task_classes: Output size of the task head.
        d_model: Shared embedding width.
        dropout: Dropout used in encoders, fusion blocks and heads.
        use_eeg_psd, use_emg_feat: Enable the Phase 5.5 feature branches
            (only effective when the matching ``*_dim`` is > 0).
        eeg_psd_dim, emg_feat_dim: Feature vector lengths.
        eeg_psd_mean, eeg_psd_std, emg_feat_mean, emg_feat_std: Optional
            normalisation stats registered as buffers; default to (0, 1).
    """

    def __init__(
        self,
        eeg_ch: int,
        emg_ch: int,
        et_ch: int,
        num_task_classes: int,
        d_model: int = 128,
        dropout: float = 0.2,
        use_eeg_psd: bool = False,
        use_emg_feat: bool = False,
        eeg_psd_dim: int = 0,
        emg_feat_dim: int = 0,
        eeg_psd_mean: Optional[torch.Tensor] = None,
        eeg_psd_std: Optional[torch.Tensor] = None,
        emg_feat_mean: Optional[torch.Tensor] = None,
        emg_feat_std: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.d_model = d_model

        # A feature branch is built only when enabled AND given a non-zero width.
        self.use_eeg_psd = use_eeg_psd and (eeg_psd_dim > 0)
        self.use_emg_feat = use_emg_feat and (emg_feat_dim > 0)
        self.eeg_psd_dim = eeg_psd_dim
        self.emg_feat_dim = emg_feat_dim

        # Encoders
        self.eeg_enc = EEGTCNGRUEncoder(eeg_ch, d_model, dropout)
        self.emg_enc = EMGTCNGRUEncoder(emg_ch, d_model, dropout)
        self.et_enc  = EyeTinyGRUEncoder(et_ch, d_model, dropout)

        # Fusion encoder (head/layer counts are read from CFG, not constructor args).
        # cls_token is prepended to the token sequence and takes part in
        # attention, but its output position is not used in forward_backbone.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.fuse_encoder = GatedCrossModalEncoder(
            d_model=d_model,
            n_heads=CFG.N_HEADS_FUSE,
            n_layers=CFG.N_LAYERS_FUSE,
            dropout=dropout,
        )

        # Projections for window-level EEG/EMG features (Phase 5.5)
        if self.use_eeg_psd:
            self.eeg_psd_proj = nn.Sequential(
                nn.Linear(eeg_psd_dim, d_model),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            # Fold-wise stats (if provided) or default (0,1); registered as
            # non-trainable buffers, reshaped to (1, F) for broadcasting.
            if eeg_psd_mean is not None and eeg_psd_std is not None:
                self.register_buffer("eeg_psd_mean", eeg_psd_mean.view(1, -1))
                self.register_buffer("eeg_psd_std", eeg_psd_std.view(1, -1))
            else:
                self.register_buffer("eeg_psd_mean", torch.zeros(1, eeg_psd_dim))
                self.register_buffer("eeg_psd_std", torch.ones(1, eeg_psd_dim))

        if self.use_emg_feat:
            self.emg_feat_proj = nn.Sequential(
                nn.Linear(emg_feat_dim, d_model),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            # Fold-wise stats (if provided) or default (0,1); same layout as above.
            if emg_feat_mean is not None and emg_feat_std is not None:
                self.register_buffer("emg_feat_mean", emg_feat_mean.view(1, -1))
                self.register_buffer("emg_feat_std", emg_feat_std.view(1, -1))
            else:
                self.register_buffer("emg_feat_mean", torch.zeros(1, emg_feat_dim))
                self.register_buffer("emg_feat_std", torch.ones(1, emg_feat_dim))

        # Gating MLP → feature-wise modality gates
        # Input: concatenated pooled (EEG, EMG, ET) features (3*d);
        # output: 3*d logits reshaped to (B, 3, d) in forward_backbone.
        self.gate_mlp = nn.Sequential(
            nn.Linear(d_model * 3, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model * 3),
        )

        # Classification heads (both read the fused embedding z_cls)
        self.action_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2),
        )
        self.task_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_task_classes),
        )

        # SSL decoders: per-timestep linear maps back to each modality's channels
        self.dec_eeg = nn.Linear(d_model, eeg_ch)
        self.dec_emg = nn.Linear(d_model, emg_ch)
        self.dec_et  = nn.Linear(d_model, et_ch)

        # Temporal order (correct vs shuffled segments)
        self.order_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2),
        )

        # Cross-modal prediction heads: d_model -> d_model linear maps between
        # per-modality embeddings (EEG→EMG, EEG→ET, EMG→EEG, ET→EEG).
        # Not called inside this class; presumably used by the SSL cross-modal
        # loss (CFG.LAMBDA_XMOD) — confirm in the training loop.
        self.cross_eeg2emg = nn.Linear(d_model, d_model)
        self.cross_eeg2et  = nn.Linear(d_model, d_model)
        self.cross_emg2eeg = nn.Linear(d_model, d_model)
        self.cross_et2eeg  = nn.Linear(d_model, d_model)

    # --------- Backbone (shared) ---------

    def forward_backbone(
        self,
        x_eeg: torch.Tensor,
        x_emg: torch.Tensor,
        x_et: torch.Tensor,
        eeg_psd: Optional[torch.Tensor] = None,
        emg_feat: Optional[torch.Tensor] = None,
        return_attn: bool = False,
    ):
        """
        Encode, gate and fuse the three modalities.

        Args:
            x_eeg, x_emg, x_et: (B, T, C_m) windows per modality.
            eeg_psd, emg_feat: Optional Phase 5.5 feature vectors (assumed
                (B, F); the projections map F -> d_model). Each is used only
                when the matching ``*_proj`` branch was built.
            return_attn: Also return the last fusion layer's attention map.

        Returns:
            ``(z_eeg_post, z_emg_post, z_et_post, z_cls, gates,
            cls_eeg, cls_emg, cls_et, attn)``:
              - ``*_post``: post-fusion token sequences per modality, at the
                subsampled length (ceil(T / POOL_STRIDE)).
              - ``z_cls``: (B, d) fused embedding.
              - ``gates``: (B, 3, d) modality gates (EEG, EMG, ET order).
              - ``cls_*``: (B, d) mean of each modality's post-fusion tokens.
              - ``attn``: None unless ``return_attn`` is True.
        """
        # (B, T, C) → (B, T, d)
        z_eeg = self.eeg_enc(x_eeg)
        z_emg = self.emg_enc(x_emg)
        z_et  = self.et_enc(x_et)

        # Optional temporal "pooling" to shorten sequences before attention:
        # strided subsampling (keep every pool_s-th step), not averaging.
        pool_s = getattr(CFG, "POOL_STRIDE", 1)
        if pool_s > 1:
            z_eeg = z_eeg[:, ::pool_s, :]
            z_emg = z_emg[:, ::pool_s, :]
            z_et  = z_et[:,  ::pool_s, :]

        B, T_eeg, _ = z_eeg.shape
        _, T_emg, _ = z_emg.shape
        _, T_et,  _ = z_et.shape

        # Pre-fusion pooled features
        p_eeg = z_eeg.mean(dim=1)   # (B, d)
        p_emg = z_emg.mean(dim=1)
        p_et  = z_et.mean(dim=1)

        # Softmax over dim=1 (modalities): for each feature the three gates sum to 1.
        gates_logits = self.gate_mlp(torch.cat([p_eeg, p_emg, p_et], dim=-1))  # (B, 3d)
        gates_logits = gates_logits.view(-1, 3, self.d_model)                 # (B,3,d)
        gates = torch.softmax(gates_logits, dim=1)                             # (B,3,d)

        g_eeg = gates[:, 0, :].unsqueeze(1)   # (B,1,d)
        g_emg = gates[:, 1, :].unsqueeze(1)
        g_et  = gates[:, 2, :].unsqueeze(1)

        # Broadcast each modality's gate to all of its time-step tokens.
        g_eeg_tokens = g_eeg.expand(-1, T_eeg, -1)
        g_emg_tokens = g_emg.expand(-1, T_emg, -1)
        g_et_tokens  = g_et.expand(-1, T_et,  -1)
        g_tokens_no_cls = torch.cat([g_eeg_tokens, g_emg_tokens, g_et_tokens], dim=1)

        # Concatenate tokens and add CLS
        z_cat = torch.cat([z_eeg, z_emg, z_et], dim=1)   # (B, T_all, d)
        cls = self.cls_token.expand(B, 1, self.d_model)
        tokens = torch.cat([cls, z_cat], dim=1)          # (B, 1+T_all, d)

        # CLS keys/values are left ungated (gate = 1).
        g_cls = torch.ones(B, 1, self.d_model, device=tokens.device)
        g_tokens = torch.cat([g_cls, g_tokens_no_cls], dim=1)  # (B, 1+T_all, d)

        z_fused_all, attn = self.fuse_encoder(tokens, g_tokens, return_attn=return_attn)

        # Slice modality tokens after fusion (index 0 = CLS position, unused below)
        idx_eeg_start = 1
        idx_eeg_end   = 1 + T_eeg
        idx_emg_start = idx_eeg_end
        idx_emg_end   = idx_emg_start + T_emg
        idx_et_start  = idx_emg_end
        idx_et_end    = idx_et_start + T_et

        z_eeg_post = z_fused_all[:, idx_eeg_start:idx_eeg_end, :]
        z_emg_post = z_fused_all[:, idx_emg_start:idx_emg_end, :]
        z_et_post  = z_fused_all[:, idx_et_start:idx_et_end, :]

        # Per-modality summaries: mean over that modality's post-fusion tokens.
        cls_eeg = z_eeg_post.mean(dim=1)
        cls_emg = z_emg_post.mean(dim=1)
        cls_et  = z_et_post.mean(dim=1)

        g_eeg_feat = gates[:, 0, :]
        g_emg_feat = gates[:, 1, :]
        g_et_feat  = gates[:, 2, :]

        # Gate-weighted (per-feature convex) combination of the three summaries.
        z_cls = (
            g_eeg_feat * cls_eeg
            + g_emg_feat * cls_emg
            + g_et_feat  * cls_et
        )

        # Optionally fuse EEG/EMG window-level features (Phase 5.5) with z-scoring.
        # Skipped when the input is None or the branch was not built.
        feat_list = []
        if eeg_psd is not None and hasattr(self, "eeg_psd_proj"):
            if hasattr(self, "eeg_psd_mean") and hasattr(self, "eeg_psd_std"):
                eeg_psd_norm = (eeg_psd - self.eeg_psd_mean) / (self.eeg_psd_std + 1e-6)
            else:
                eeg_psd_norm = eeg_psd
            feat_list.append(self.eeg_psd_proj(eeg_psd_norm))   # (B, d_model)

        if emg_feat is not None and hasattr(self, "emg_feat_proj"):
            if hasattr(self, "emg_feat_mean") and hasattr(self, "emg_feat_std"):
                emg_feat_norm = (emg_feat - self.emg_feat_mean) / (self.emg_feat_std + 1e-6)
            else:
                emg_feat_norm = emg_feat
            feat_list.append(self.emg_feat_proj(emg_feat_norm)) # (B, d_model)

        if feat_list:
            # Simple additive fusion: average projected features and add to z_cls
            feat_fused = torch.stack(feat_list, dim=0).mean(dim=0)  # (B, d_model)
            z_cls = z_cls + feat_fused

        return z_eeg_post, z_emg_post, z_et_post, z_cls, gates, cls_eeg, cls_emg, cls_et, attn

    # --------- Supervised forward ---------

    def forward_supervised(
        self,
        x_eeg: torch.Tensor,
        x_emg: torch.Tensor,
        x_et: torch.Tensor,
        eeg_psd: Optional[torch.Tensor] = None,
        emg_feat: Optional[torch.Tensor] = None,
        return_attn: bool = False,
    ):
        """
        Run the backbone and both classification heads on ``z_cls``.

        Returns:
            ``(logits_action, logits_task, aux)`` where ``logits_action`` is
            (B, 2), ``logits_task`` is (B, num_task_classes) and ``aux`` holds
            ``gates``, ``cls_eeg``, ``cls_emg``, ``cls_et`` and ``attn``.
        """
        z_eeg, z_emg, z_et, z_cls, gates, cls_eeg, cls_emg, cls_et, attn = self.forward_backbone(
            x_eeg, x_emg, x_et,
            eeg_psd=eeg_psd,
            emg_feat=emg_feat,
            return_attn=return_attn,
        )

        logits_action = self.action_head(z_cls)
        logits_task   = self.task_head(z_cls)
        return logits_action, logits_task, {
            "gates": gates, "cls_eeg": cls_eeg, "cls_emg": cls_emg,
            "cls_et": cls_et, "attn": attn,
        }

    # --------- SSL forwards ---------

    def forward_ssl(
        self,
        x_eeg_masked: torch.Tensor,
        x_emg_masked: torch.Tensor,
        x_et_masked: torch.Tensor,
    ):
        """
        Reconstruct each modality from (masked) inputs for SSL.

        The reconstructions have the backbone's subsampled temporal length
        (ceil(T / POOL_STRIDE)), not the input length T. Phase 5.5 features are
        not passed, so ``z_cls`` here excludes the feature branches.

        Returns:
            ``(x_hat_eeg, x_hat_emg, x_hat_et, aux)`` with ``aux`` holding
            ``z_cls``, ``gates`` and the per-modality ``cls_*`` summaries.
        """
        z_eeg, z_emg, z_et, z_cls, gates, cls_eeg, cls_emg, cls_et, _ = self.forward_backbone(
            x_eeg_masked, x_emg_masked, x_et_masked, return_attn=False
        )
        x_hat_eeg = self.dec_eeg(z_eeg)
        x_hat_emg = self.dec_emg(z_emg)
        x_hat_et  = self.dec_et(z_et)
        return x_hat_eeg, x_hat_emg, x_hat_et, {
            "z_cls": z_cls,
            "gates": gates,
            "cls_eeg": cls_eeg,
            "cls_emg": cls_emg,
            "cls_et":  cls_et,
        }

    def forward_order(
        self,
        x_eeg: torch.Tensor,
        x_emg: torch.Tensor,
        x_et: torch.Tensor,
    ):
        """
        Return (B, 2) temporal-order logits from ``order_head`` on ``z_cls``.

        Which class index means "correct order" is decided by the caller
        (not visible here).
        """
        _, _, _, z_cls, _, _, _, _, _ = self.forward_backbone(
            x_eeg, x_emg, x_et, return_attn=False
        )
        logits_order = self.order_head(z_cls)
        return logits_order

    
    
# ---------------- DATASET: PHASE 5 EXPORTS ----------------

class ShardWindowDataset(Dataset):
    """
    In-memory dataset of windows read from Phase-5 NPZ shards.

    Reads every ``*_shard_*.npz`` file (sorted by name) under::

        _dataset_icml_v1/{prefix}_foldK/{split}/

    Each shard must provide the keys
        X_EEG: (N, T, C_eeg), X_EMG: (N, T, C_emg), X_ET: (N, T, C_et),
        y_action (N,), y_task (N,).

    In supervised mode (``ssl_mode=False``) only windows whose ``y_task`` is in
    ``CFG.TASK_CODES`` are kept; in SSL mode every window is kept.

    Args:
        fold_dir: Fold directory containing the split sub-directories.
        split: Sub-directory to read (e.g. "train", "val", "test").
        ssl_mode: Keep all windows regardless of task code.
        max_windows: If set, randomly subsample (seeded with ``CFG.SEED``) to
            at most this many windows.
        task2idx: Raw task code -> class index. Pass the training split's
            mapping to val/test so indices agree across splits. When None, it
            is built from the ``CFG.TASK_CODES`` present in this split, in
            ``CFG.TASK_CODES`` order.

    Raises:
        FileNotFoundError: The split directory is missing or has no shards.
        RuntimeError: No windows survived filtering.

    Phase 5.5 feature arrays (``X_eeg_psd``, ``X_emg_feat``) may be attached
    afterwards; ``__getitem__`` then also returns ``eeg_psd`` / ``emg_feat``.
    """

    def __init__(
        self,
        fold_dir: Path,
        split: str,
        ssl_mode: bool = False,
        max_windows: Optional[int] = None,
        task2idx: Optional[Dict[int, int]] = None,
    ):
        super().__init__()
        self.fold_dir = Path(fold_dir)
        self.split = split
        self.ssl_mode = ssl_mode

        split_dir = self.fold_dir / split
        if not split_dir.exists():
            raise FileNotFoundError(f"Split dir not found: {split_dir}")

        shard_paths = sorted(split_dir.glob("*_shard_*.npz"))
        if not shard_paths:
            raise FileNotFoundError(f"No shard npz files in {split_dir}")

        X_eeg_list, X_emg_list, X_et_list = [], [], []
        y_action_list, y_task_list = [], []

        for shard_path in shard_paths:
            # `with` closes the zip archive deterministically. Indexing an
            # NpzFile reads the whole array into memory, so the arrays stay
            # valid after the handle is closed.
            # NOTE(review): allow_pickle=True lets a crafted shard execute code
            # on load; drop it if the shards only contain numeric arrays.
            with np.load(shard_path, allow_pickle=True) as z:
                X_eeg = z["X_EEG"]
                X_emg = z["X_EMG"]
                X_et = z["X_ET"]
                y_action = z["y_action"]
                y_task = z["y_task"]

            n_windows = y_task.shape[0]
            if n_windows == 0:
                continue

            if ssl_mode:
                # SSL pre-training uses every window, whatever its label.
                mask_keep = np.ones(n_windows, dtype=bool)
            else:
                mask_keep = np.isin(y_task, CFG.TASK_CODES)

            if not mask_keep.any():
                continue

            X_eeg_list.append(X_eeg[mask_keep])
            X_emg_list.append(X_emg[mask_keep])
            X_et_list.append(X_et[mask_keep])
            y_action_list.append(y_action[mask_keep])
            y_task_list.append(y_task[mask_keep])

        if not X_eeg_list:
            raise RuntimeError(
                f"No windows loaded for {fold_dir} / {split} (ssl_mode={ssl_mode})"
            )

        self.X_eeg = np.concatenate(X_eeg_list, axis=0)
        self.X_emg = np.concatenate(X_emg_list, axis=0)
        self.X_et = np.concatenate(X_et_list, axis=0)
        self.y_action = np.concatenate(y_action_list, axis=0)
        self.y_task = np.concatenate(y_task_list, axis=0)

        if max_windows is not None and self.X_eeg.shape[0] > max_windows:
            # Fixed seed -> the same subset on every run.
            rng = np.random.RandomState(CFG.SEED)
            idx = rng.choice(self.X_eeg.shape[0], size=max_windows, replace=False)
            self.X_eeg = self.X_eeg[idx]
            self.X_emg = self.X_emg[idx]
            self.X_et = self.X_et[idx]
            self.y_action = self.y_action[idx]
            self.y_task = self.y_task[idx]

        self.eeg_ch = self.X_eeg.shape[-1]
        self.emg_ch = self.X_emg.shape[-1]
        self.et_ch = self.X_et.shape[-1]

        if task2idx is not None:
            self.task2idx = dict(task2idx)
        else:
            present = set(np.unique(self.y_task).tolist())
            present_desired = [t for t in CFG.TASK_CODES if t in present]
            self.task2idx = {t: i for i, t in enumerate(present_desired)}
        self.num_task_classes = len(self.task2idx)

        print(
            f"[ShardWindowDataset] {split_dir} | N={self.X_eeg.shape[0]}, "
            f"ssl_mode={self.ssl_mode}, tasks={sorted(np.unique(self.y_task))}"
        )
        print(
            f"  Shapes: X_eeg={self.X_eeg.shape}, X_emg={self.X_emg.shape}, X_et={self.X_et.shape}"
        )

    def __len__(self) -> int:
        return self.X_eeg.shape[0]

    def __getitem__(self, idx: int):
        """
        Return one window as a dict of tensors.

        Keys: ``eeg``, ``emg``, ``et`` (float32, (T, C_m)), ``action`` (1 iff
        the raw action label is 1, else 0), ``task`` (class index, or -1 when
        the raw code is not in ``task2idx``), ``y_task_raw``, and optionally
        ``eeg_psd`` / ``emg_feat`` when Phase 5.5 features are attached.
        """
        x_eeg = torch.from_numpy(self.X_eeg[idx]).float()
        x_emg = torch.from_numpy(self.X_emg[idx]).float()
        x_et = torch.from_numpy(self.X_et[idx]).float()

        y_action_raw = int(self.y_action[idx])
        y_task_raw = int(self.y_task[idx])

        action_label = 1 if y_action_raw == 1 else 0
        # -1 flags a task code unseen in the (training) mapping.
        task_label = self.task2idx.get(y_task_raw, -1)

        sample = {
            "eeg": x_eeg,
            "emg": x_emg,
            "et": x_et,
            "action": torch.tensor(action_label, dtype=torch.long),
            "task": torch.tensor(task_label, dtype=torch.long),
            "y_task_raw": torch.tensor(y_task_raw, dtype=torch.long),
        }

        # Optional Phase 5.5 features (attached only for balanced supervised folds)
        if hasattr(self, "X_eeg_psd"):
            sample["eeg_psd"] = torch.from_numpy(self.X_eeg_psd[idx]).float()
        if hasattr(self, "X_emg_feat"):
            sample["emg_feat"] = torch.from_numpy(self.X_emg_feat[idx]).float()

        return sample


def attach_features_to_dataset(ds, fold_id: int, split: str):
    """
    Attach Phase 5.5 EEG-PSD and EMG feature vectors to a balanced dataset.

    Reads ``CFG.DATASET_DIR / features_v1_eeg_psd_full_fold{fold_id}_{split}.npz``
    with keys ``X_psd`` (N, F_psd) and ``X_emg`` (N, F_emg). On success it sets
    ``ds.X_eeg_psd``, ``ds.X_emg_feat``, ``ds.eeg_psd_dim`` and ``ds.emg_feat_dim``.

    This is best-effort: a missing file, non-2-D arrays, or a row count that
    differs from ``len(ds)`` prints a message and leaves ``ds`` untouched.

    NOTE(review): rows are aligned by position only. This assumes Phase 5.5
    wrote features in the same order the dataset concatenated its shards.
    """
    feat_path = CFG.DATASET_DIR / f"features_v1_eeg_psd_full_fold{fold_id}_{split}.npz"
    if not feat_path.exists():
        print(f"[attach_features] No feature file for fold={fold_id}, split={split}: {feat_path}")
        return

    # Close the archive deterministically; indexing materialises each array.
    # NOTE(review): allow_pickle=True is unsafe for untrusted files.
    with np.load(feat_path, allow_pickle=True) as z:
        X_psd = z["X_psd"]   # (N, F_psd)
        X_emg = z["X_emg"]   # (N, F_emg)

    if X_psd.ndim != 2 or X_emg.ndim != 2:
        print(
            f"[attach_features] Expected 2-D feature arrays for fold={fold_id}, "
            f"split={split}, got X_psd={X_psd.shape}, X_emg={X_emg.shape}. Skipping."
        )
        return

    if X_psd.shape[0] != len(ds) or X_emg.shape[0] != len(ds):
        print(
            f"[attach_features] MISMATCH fold={fold_id}, split={split}: "
            f"len(ds)={len(ds)}, X_psd={X_psd.shape}, X_emg={X_emg.shape}. Skipping."
        )
        return

    ds.X_eeg_psd = X_psd
    ds.X_emg_feat = X_emg
    ds.eeg_psd_dim = X_psd.shape[1]
    ds.emg_feat_dim = X_emg.shape[1]

    print(
        f"[attach_features] Attached features for fold={fold_id}, split={split}: "
        f"X_psd={X_psd.shape}, X_emg={X_emg.shape}"
    )


# ---------------- FOLD DISCOVERY & DATALOADERS ----------------

def discover_folds(prefix: str, root: Optional[Path] = None) -> List[int]:
    """
    Return sorted, unique fold IDs K for directories named ``{prefix}_fold{K}``.

    Args:
        prefix: Export prefix, e.g. ``CFG.SSL_PREFIX``.
        root: Directory to scan; defaults to ``CFG.DATASET_DIR``.

    Only directories are considered. The suffix after ``{prefix}_fold`` must
    be a canonical integer ("3", not "03" or "3.json"), so callers can rebuild
    the path as ``root / f"{prefix}_fold{K}"``.
    """
    base = CFG.DATASET_DIR if root is None else Path(root)
    stem = f"{prefix}_fold"
    fold_ids = set()
    for path in base.glob(f"{stem}*"):
        if not path.is_dir():
            continue
        suffix = path.name[len(stem):]
        try:
            fid = int(suffix)
        except ValueError:
            continue
        # Reject non-canonical spellings whose reconstructed name would not exist.
        if str(fid) == suffix:
            fold_ids.add(fid)
    return sorted(fold_ids)


def make_dataloader(
    ds: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: Optional[int] = None,
    pin_memory: Optional[bool] = None,
) -> DataLoader:
    """
    Build a DataLoader with the project's worker and pinning defaults.

    Args:
        ds: Dataset to wrap.
        batch_size: Samples per batch.
        shuffle: Reshuffle every epoch.
        num_workers: Worker processes; defaults to ``CFG.NUM_WORKERS``.
        pin_memory: Page-lock host batches; defaults to True only when
            ``CFG.DEVICE`` is a CUDA device. Pinning has no benefit on
            CPU-only runs and triggers a warning there.
    """
    if num_workers is None:
        num_workers = CFG.NUM_WORKERS
    if pin_memory is None:
        pin_memory = str(CFG.DEVICE).startswith("cuda")
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def make_ssl_dataset() -> Tuple[ConcatDataset, int, int, int]:
    """
    Build the SSL pre-training set from every ``{CFG.SSL_PREFIX}_fold*`` train split.

    Each fold's train split is loaded in SSL mode (all windows kept) and
    subsampled to at most ``CFG.SSL_MAX_WINDOWS`` windows. The folds are then
    concatenated.

    Returns:
        ``(dataset, eeg_ch, emg_ch, et_ch)``.

    Raises:
        SystemExit: No SSL folds were found.
        RuntimeError: Folds disagree on a modality's channel count. This is an
            explicit raise because an ``assert`` would be stripped under
            ``python -O``.
    """
    folds = discover_folds(CFG.SSL_PREFIX)
    if not folds:
        raise SystemExit(
            f"No SSL folds found with prefix {CFG.SSL_PREFIX}_fold* in {CFG.DATASET_DIR}"
        )

    ds_list = []
    ref_channels = None  # (eeg_ch, emg_ch, et_ch) of the first fold
    for fid in folds:
        fold_dir = CFG.DATASET_DIR / f"{CFG.SSL_PREFIX}_fold{fid}"
        ds = ShardWindowDataset(
            fold_dir=fold_dir,
            split="train",
            ssl_mode=True,
            max_windows=CFG.SSL_MAX_WINDOWS,   # subsample each SSL fold
            task2idx=None,
        )
        channels = (ds.eeg_ch, ds.emg_ch, ds.et_ch)
        if ref_channels is None:
            ref_channels = channels
        elif channels != ref_channels:
            raise RuntimeError(
                f"Channel mismatch across SSL folds: fold {fid} has "
                f"(eeg, emg, et)={channels}, expected {ref_channels}"
            )
        ds_list.append(ds)

    eeg_ch, emg_ch, et_ch = ref_channels
    return ConcatDataset(ds_list), eeg_ch, emg_ch, et_ch


def make_supervised_datasets_for_fold(
    fold_id: int,
) -> Tuple[ShardWindowDataset, ShardWindowDataset, ShardWindowDataset]:
    """
    Load the balanced train/val/test splits of one LOSO fold.

    val/test reuse the training split's ``task2idx`` so class indices agree.
    Phase 5.5 features are attached to each split when available.

    Returns:
        ``(train_ds, val_ds, test_ds)``.

    Raises:
        FileNotFoundError: The fold directory does not exist.
        RuntimeError: Splits disagree on a modality's channel count. This is an
            explicit raise because an ``assert`` would be stripped under
            ``python -O``.
    """
    fold_dir = CFG.DATASET_DIR / f"{CFG.BALANCED_PREFIX}_fold{fold_id}"
    if not fold_dir.exists():
        raise FileNotFoundError(f"Balanced fold directory not found: {fold_dir}")

    train_ds = ShardWindowDataset(
        fold_dir=fold_dir,
        split="train",
        ssl_mode=False,
        max_windows=None,
        task2idx=None,
    )
    val_ds = ShardWindowDataset(
        fold_dir=fold_dir,
        split="val",
        ssl_mode=False,
        max_windows=None,
        task2idx=train_ds.task2idx,
    )
    test_ds = ShardWindowDataset(
        fold_dir=fold_dir,
        split="test",
        ssl_mode=False,
        max_windows=None,
        task2idx=train_ds.task2idx,
    )

    # Attach Phase 5.5 EEG-PSD + EMG features (if available)
    attach_features_to_dataset(train_ds, fold_id, "train")
    attach_features_to_dataset(val_ds,   fold_id, "val")
    attach_features_to_dataset(test_ds,  fold_id, "test")

    splits = {"train": train_ds, "val": val_ds, "test": test_ds}
    for attr in ("eeg_ch", "emg_ch", "et_ch"):
        dims = {name: getattr(ds, attr) for name, ds in splits.items()}
        if len(set(dims.values())) != 1:
            raise RuntimeError(f"Channel mismatch in fold {fold_id} for {attr}: {dims}")

    # Attachment is best-effort per split. If it succeeded for only some splits,
    # the model silently sees different inputs in training vs evaluation.
    has_feats = {name: hasattr(ds, "X_eeg_psd") for name, ds in splits.items()}
    if len(set(has_feats.values())) > 1:
        print(
            f"[make_supervised_datasets_for_fold] WARNING fold={fold_id}: "
            f"Phase 5.5 features attached for only some splits {has_feats}; "
            f"train/val/test will see different model inputs."
        )

    return train_ds, val_ds, test_ds


# ---------------- SSL HELPERS ----------------

def apply_ssl_mask(x: torch.Tensor, mask_prob: float):
    """
    Randomly zero whole time steps of ``x`` for masked reconstruction.

    Each (sample, time step) pair is masked independently with probability
    ``mask_prob``. All channels of a masked step are zeroed together.

    Args:
        x: (B, T, C) tensor.
        mask_prob: Per-step masking probability.

    Returns:
        ``(x_masked, mask)`` where ``mask`` is a float (B, T, 1) tensor holding
        1.0 at masked steps.
    """
    batch, steps, _ = x.shape
    drawn = torch.rand(batch, steps, 1, device=x.device)
    mask = drawn.lt(mask_prob).float()
    return x * (1.0 - mask), mask


def apply_modality_dropout(x: torch.Tensor, drop_prob: float) -> torch.Tensor:
    """
    Zero each sample's entire modality tensor with probability ``drop_prob``.

    One Bernoulli draw is made per sample (dim 0) and broadcast over every
    other dimension, so a sample's modality is dropped or kept as a whole.
    Works for any input of rank >= 1, e.g. (B, T, C) windows or (B, F)
    feature vectors.

    Args:
        x: Tensor whose first dimension is the batch.
        drop_prob: Drop probability; values <= 0 return ``x`` unchanged.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if drop_prob <= 0.0:
        return x
    # Mask shape (B, 1, ..., 1) matching x's rank. A fixed (B, 1, 1) mask
    # would silently broadcast a 2-D (B, F) input to (B, B, F).
    mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
    mask = (torch.rand(mask_shape, device=x.device) < drop_prob).float()
    return x * (1.0 - mask)


def ssl_loss_reconstruction(
    x_hat: torch.Tensor, x_true: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Masked MSE between a (possibly time-pooled) reconstruction and its target.

    Args:
        x_hat:  (B, T_hat, C) decoder output, possibly shorter in time than
                the input because of temporal pooling.
        x_true: (B, T_true, C) original, un-pooled signal.
        mask:   (B, T_true, 1) 1.0 where a time step of ``x_true`` was masked.

    When ``T_hat != T_true`` the target and mask are first brought to length
    ``T_hat``: by strided sub-sampling if ``T_true`` is a multiple of
    ``T_hat`` (e.g. 500 -> 125 with stride 4), otherwise by linear
    interpolation of the target and nearest-neighbour interpolation of the
    mask.

    Returns:
        Mean squared error over masked entries only (sum of squared errors
        at masked positions / number of masked entries, with a 1e-8 guard).
        Returns a zero tensor when the signal has no channels.
    """
    if x_true.size(-1) == 0:
        return torch.tensor(0.0, device=x_true.device)

    _, t_hat, _ = x_hat.shape
    _, t_true, _ = x_true.shape

    if t_hat != t_true:
        if t_true % t_hat == 0:
            # Integer pooling factor: keep every ``step``-th target/mask step.
            step = t_true // t_hat
            x_true = x_true[:, ::step, :]
            mask = mask[:, ::step, :]
        else:
            # Non-integer ratio: resample along time (interpolate wants (B, C, T)).
            x_true = F.interpolate(
                x_true.transpose(1, 2), size=t_hat, mode="linear", align_corners=False
            ).transpose(1, 2)
            mask = F.interpolate(
                mask.transpose(1, 2), size=t_hat, mode="nearest"
            ).transpose(1, 2)

    weights = mask.expand_as(x_hat)
    squared_error = (x_hat - x_true).pow(2)
    return (squared_error * weights).sum() / (weights.sum() + 1e-8)



def contrastive_loss(
    z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
    """
    Symmetric cross-view contrastive loss.

    Rows of ``z1`` and ``z2`` are L2-normalised. Row ``i`` of ``z1`` is
    treated as the positive for row ``i`` of ``z2`` (and vice versa); every
    other row in the batch is a negative. The loss is the average of the
    cross-entropies in both directions over the cosine-similarity matrix
    divided by ``temperature``.

    Returns a zero tensor for batches with fewer than two rows, where no
    negatives exist.
    """
    n = z1.size(0)
    if n < 2:
        return torch.tensor(0.0, device=z1.device)

    a = F.normalize(z1, dim=-1)
    b = F.normalize(z2, dim=-1)
    sim = torch.matmul(a, b.t()) / temperature
    targets = torch.arange(n, device=z1.device)
    forward = F.cross_entropy(sim, targets)
    backward = F.cross_entropy(sim.t(), targets)
    return (forward + backward) / 2


def gating_regularization(gates: torch.Tensor) -> torch.Tensor:
    """
    Penalty that is low when gate mass is spread evenly across modalities.

    ``gates`` is averaged over its last dimension to get one value per
    modality (the original comment gives the shape as (B, 3, D)). The
    Shannon entropy of those per-modality values is normalised by
    ``log(num_modalities)``, and the result is ``1 - mean normalised
    entropy``. High entropy gives a low loss.

    ``pretrain_ssl`` only evaluates this term when ``CFG.LAMBDA_GATE > 0``.
    The original note says it is currently disabled via ``LAMBDA_GATE=0.0``.
    """
    eps = 1e-8
    per_modality = gates.mean(dim=-1)  # (B, num_modalities)
    entropy = -(per_modality * torch.log(per_modality + eps)).sum(dim=1)
    normalised = entropy / (math.log(per_modality.size(1)) + eps)
    return 1.0 - normalised.mean()


def temporal_order_loss(
    model: TriModalSafetyTransformer,
    x_eeg: torch.Tensor,
    x_emg: torch.Tensor,
    x_et: torch.Tensor,
    segments: int = 4,
) -> torch.Tensor:
    """
    Temporal-order SSL objective: was this window shuffled or not?

    The first ``T_used = (T // segments) * segments`` time steps of each
    modality are split into ``segments`` equal segments; any trailing steps
    are dropped. Each sample independently (probability 0.5):

      * keeps its original segment order  -> label 0, or
      * gets a random *non-identity* segment permutation -> label 1.

    The same per-sample permutation is applied to EEG, EMG and ET so the
    modalities stay time-aligned. ``model.forward_order`` is expected to
    return (B, 2) logits, and the loss is their cross-entropy against the
    per-sample labels.

    Labels are drawn per sample on purpose. A single batch-wide shuffle makes
    every label 1 except when the shuffle happens to be the identity (1/24
    for 4 segments). The head could then minimise the loss by always
    predicting "shuffled".

    Randomness comes from the global torch RNG (``torch.randint`` /
    ``torch.randperm``).

    Returns:
        Scalar loss tensor. It is zero when ``segments < 2`` (nothing to
        shuffle) or when the window is too short (``T < 2 * segments``).
    """
    B, T, _ = x_eeg.shape
    device = x_eeg.device
    if segments < 2 or T < segments * 2:
        return torch.tensor(0.0, device=device)

    seg_len = T // segments
    T_used = seg_len * segments

    # Per-sample label: 0 = original order, 1 = shuffled.
    y_order = torch.randint(0, 2, (B,), device=device)

    identity = torch.arange(segments)
    perms = identity.unsqueeze(0).repeat(B, 1)  # (B, segments); identity rows by default
    for b in torch.nonzero(y_order, as_tuple=False).flatten().tolist():
        p = torch.randperm(segments)
        # Reject the identity so that label 1 always means "actually shuffled".
        while torch.equal(p, identity):
            p = torch.randperm(segments)
        perms[b] = p
    perms = perms.to(device)
    batch_idx = torch.arange(B, device=device).unsqueeze(1)  # (B, 1), broadcasts with perms

    def _permute(x: torch.Tensor) -> torch.Tensor:
        # Reorder the segments of every sample by that sample's own permutation.
        channels = x.shape[-1]
        seg = x[:, :T_used, :].reshape(B, segments, seg_len, channels)
        return seg[batch_idx, perms].reshape(B, T_used, channels)

    logits_order = model.forward_order(_permute(x_eeg), _permute(x_emg), _permute(x_et))
    return F.cross_entropy(logits_order, y_order)


def cross_modal_prediction_loss(
    model: TriModalSafetyTransformer,
    cls_eeg: torch.Tensor,
    cls_emg: torch.Tensor,
    cls_et: torch.Tensor,
) -> torch.Tensor:
    """
    Cross-modal CLS prediction loss.

    Four predictors on ``model`` map one modality's CLS embedding to
    another's: EEG->EMG, EEG->ET, EMG->EEG and ET->EEG. Each prediction is
    compared by MSE against the *detached* target embedding, so gradients
    flow only through the source side. The four MSEs are averaged.

    Returns a zero tensor for an empty batch.
    """
    if cls_eeg.size(0) == 0:
        return torch.tensor(0.0, device=cls_eeg.device)

    # (predictor head, source embedding, target embedding), in evaluation order.
    pairs = (
        (model.cross_eeg2emg, cls_eeg, cls_emg),
        (model.cross_eeg2et,  cls_eeg, cls_et),
        (model.cross_emg2eeg, cls_emg, cls_eeg),
        (model.cross_et2eeg,  cls_et,  cls_eeg),
    )

    total = 0.0
    for head, source, target in pairs:
        total += F.mse_loss(head(source), target.detach())
    return total / 4.0


# ---------------- STAGE 1 — SSL PRETRAINING ----------------

def pretrain_ssl(
    eeg_ch: int, emg_ch: int, et_ch: int, ssl_dataset: Dataset
) -> Dict[str, torch.Tensor]:
    """
    Stage 1: self-supervised pretraining of a TriModalSafetyTransformer.

    Every batch is turned into two augmented views:
      * view 1: independent time-step masking of EEG, EMG and ET
        (``apply_ssl_mask`` with ``CFG.SSL_MASK_PROB``);
      * view 2: a fresh masking plus per-sample modality dropout
        (``CFG.SSL_MODALITY_DROPOUT``).

    The per-batch loss is the weighted sum of:
      * masked reconstruction of the clean inputs from view 1, with EEG and
        EMG weighted 1.5 and ET 1.0, scaled by ``CFG.LAMBDA_MASK``;
      * contrastive agreement between the view-1 and view-2 ``z_cls``
        embeddings, scaled by ``CFG.LAMBDA_CONTRAST``;
      * gate-entropy regularisation (``CFG.LAMBDA_GATE``);
      * a temporal-order objective on the clean inputs (``CFG.LAMBDA_ORDER``);
      * cross-modal CLS prediction (``CFG.LAMBDA_XMOD``).
    Each auxiliary term is only computed when its lambda is > 0.

    Optimisation uses AdamW (``CFG.SSL_LR``, ``CFG.WEIGHT_DECAY``) for
    ``CFG.SSL_EPOCHS`` epochs, with gradient-norm clipping at 1.0. The mean
    loss per epoch is printed.

    Args:
        eeg_ch, emg_ch, et_ch: channel counts used to build the model.
        ssl_dataset: dataset passed to ``make_dataloader``. Its batches are
            indexed by "eeg", "emg" and "et".

    Returns:
        ``model.state_dict()`` of the pretrained model (model on ``CFG.DEVICE``).
    """
    print("\n================ STAGE 1 — SSL PRETRAINING ================")
    # Placeholder task-head size; the original note says the task head is unused in SSL.
    # NOTE(review): this 6 is independent of per-fold class counts; confirm the
    # caller does not load task-head weights into a model with a different size.
    dummy_num_task_classes = 6  # just for the task head; not used in SSL

    model = TriModalSafetyTransformer(
        eeg_ch=eeg_ch,
        emg_ch=emg_ch,
        et_ch=et_ch,
        num_task_classes=dummy_num_task_classes,
        d_model=CFG.D_MODEL,
        dropout=CFG.DROPOUT,
    ).to(CFG.DEVICE)

    loader = make_dataloader(ssl_dataset, batch_size=CFG.SSL_BATCH, shuffle=True)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CFG.SSL_LR,
        weight_decay=CFG.WEIGHT_DECAY,
    )

    for epoch in range(1, CFG.SSL_EPOCHS + 1):
        model.train()
        total_loss = 0.0
        total_batches = 0

        for batch in loader:
            x_eeg = batch["eeg"].to(CFG.DEVICE)
            x_emg = batch["emg"].to(CFG.DEVICE)
            x_et  = batch["et"].to(CFG.DEVICE)

            # View 1: masked
            x_eeg_v1, m_eeg_v1 = apply_ssl_mask(x_eeg, CFG.SSL_MASK_PROB)
            x_emg_v1, m_emg_v1 = apply_ssl_mask(x_emg, CFG.SSL_MASK_PROB)
            x_et_v1,  m_et_v1  = apply_ssl_mask(x_et,  CFG.SSL_MASK_PROB)

            # View 2: masked + modality dropout
            # (fresh masks; view-2 masks are not needed because view 2 is only used for z_cls)
            x_eeg_v2, _ = apply_ssl_mask(x_eeg, CFG.SSL_MASK_PROB)
            x_emg_v2, _ = apply_ssl_mask(x_emg, CFG.SSL_MASK_PROB)
            x_et_v2,  _ = apply_ssl_mask(x_et,  CFG.SSL_MASK_PROB)

            x_eeg_v2 = apply_modality_dropout(x_eeg_v2, CFG.SSL_MODALITY_DROPOUT)
            x_emg_v2 = apply_modality_dropout(x_emg_v2, CFG.SSL_MODALITY_DROPOUT)
            x_et_v2  = apply_modality_dropout(x_et_v2,  CFG.SSL_MODALITY_DROPOUT)

            # View 1 through the SSL head: per-modality reconstructions plus aux outputs.
            x_hat_eeg, x_hat_emg, x_hat_et, aux1 = model.forward_ssl(
                x_eeg_v1, x_emg_v1, x_et_v1
            )
            z_cls1 = aux1["z_cls"]
            gates1 = aux1["gates"]
            cls_eeg = aux1["cls_eeg"]
            cls_emg = aux1["cls_emg"]
            cls_et  = aux1["cls_et"]

            # View 2 through the backbone; only the 4th output (z_cls) is used here.
            _, _, _, z_cls2, _, _, _, _, _ = model.forward_backbone(
                x_eeg_v2, x_emg_v2, x_et_v2
            )

            # Reconstruct the clean inputs, scored only where view 1 was masked.
            loss_eeg = ssl_loss_reconstruction(x_hat_eeg, x_eeg, m_eeg_v1)
            loss_emg = ssl_loss_reconstruction(x_hat_emg, x_emg, m_emg_v1)
            loss_et  = ssl_loss_reconstruction(x_hat_et,  x_et,  m_et_v1)

            # Upweight EEG & EMG reconstruction in SSL
            w_eeg, w_emg, w_et = 1.5, 1.5, 1.0   # per-modality reconstruction weights
            loss_mask = w_eeg * loss_eeg + w_emg * loss_emg + w_et * loss_et

            # Auxiliary losses are evaluated only when their lambda is positive;
            # otherwise a constant zero tensor is used.
            loss_contrast = contrastive_loss(z_cls1, z_cls2, CFG.SSL_TEMPERATURE) \
                if CFG.LAMBDA_CONTRAST > 0 else torch.tensor(0.0, device=CFG.DEVICE)
            loss_gate = gating_regularization(gates1) \
                if CFG.LAMBDA_GATE > 0 else torch.tensor(0.0, device=CFG.DEVICE)
            loss_order = temporal_order_loss(model, x_eeg, x_emg, x_et) \
                if CFG.LAMBDA_ORDER > 0 else torch.tensor(0.0, device=CFG.DEVICE)
            loss_xmod = cross_modal_prediction_loss(model, cls_eeg, cls_emg, cls_et) \
                if CFG.LAMBDA_XMOD > 0 else torch.tensor(0.0, device=CFG.DEVICE)

            loss = (
                CFG.LAMBDA_MASK     * loss_mask
                + CFG.LAMBDA_CONTRAST * loss_contrast
                + CFG.LAMBDA_GATE   * loss_gate
                + CFG.LAMBDA_ORDER  * loss_order
                + CFG.LAMBDA_XMOD   * loss_xmod
            )

            optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to 1.0 before the AdamW step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += float(loss.item())
            total_batches += 1

        # max(1, ...) avoids division by zero for an empty loader.
        avg_loss = total_loss / max(1, total_batches)
        print(f"[SSL] Epoch {epoch:02d}/{CFG.SSL_EPOCHS} loss={avg_loss:.4f}")

    print("[SSL] Pretraining complete.")
    return model.state_dict()


# ---------------- METRIC HELPERS ----------------

def compute_per_class_metrics_from_cm(cm: np.ndarray):
    """
    One-vs-rest metrics for every class of a square confusion matrix.

    Row sums are treated as each class's support (true-label counts) and
    column sums as its prediction counts. For class ``k``:

      * precision = TP / predicted, recall = TP / support
      * accuracy = (TP + TN) / total
      * balanced_accuracy = mean of recall and specificity TN / (TN + FP)
      * f1 = harmonic mean of precision and recall

    Any ratio with a zero denominator is reported as 0.0.

    Returns:
        ``{k: {"class_index", "precision", "recall", "f1", "accuracy",
        "balanced_accuracy", "support"}}`` keyed by integer class index.
    """
    grand_total = cm.sum()
    results = {}
    for k in range(cm.shape[0]):
        tp = cm[k, k]
        support = cm[k, :].sum()
        predicted = cm[:, k].sum()
        fn = support - tp
        fp = predicted - tp
        tn = grand_total - (tp + fn + fp)
        negatives = tn + fp

        precision = float(tp / predicted) if predicted > 0 else 0.0
        recall = float(tp / support) if support > 0 else 0.0
        accuracy = float(tp + tn) / float(grand_total) if grand_total > 0 else 0.0
        specificity = float(tn) / float(negatives) if negatives > 0 else 0.0
        if support > 0 or negatives > 0:
            balanced = 0.5 * (recall + specificity)
        else:
            balanced = 0.0
        pr_sum = precision + recall
        f1 = float(2 * precision * recall / pr_sum) if pr_sum > 0 else 0.0

        results[int(k)] = {
            "class_index": int(k),
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "accuracy": accuracy,
            "balanced_accuracy": balanced,
            "support": int(support),
        }
    return results


def compute_topk_accuracies_from_logits(
    logits: np.ndarray,
    targets: np.ndarray,
    ks=(1, 3),
) -> Dict[int, float]:
    """
    Top-k accuracy for each ``k`` in ``ks``.

    Args:
        logits: (N, num_classes) score array.
        targets: (N,) integer class indices.
        ks: iterable of k values.

    Returns:
        ``{k: fraction of rows whose target is among the k highest logits}``.
        An empty input gives 0.0 for every k.

    ``k`` larger than ``num_classes`` is clamped to ``num_classes`` (so it
    counts every class). Previously ``torch.topk`` raised in that case, e.g.
    top-3 on a fold with only two task classes.
    """
    if logits.shape[0] == 0:
        return {k: 0.0 for k in ks}

    t_logits = torch.from_numpy(logits)
    t_targets = torch.from_numpy(targets).reshape(-1)

    num_classes = t_logits.shape[1]
    maxk = min(max(ks), num_classes)
    _, pred = t_logits.topk(maxk, dim=1, largest=True, sorted=True)  # (N, maxk)
    hits = pred.eq(t_targets.unsqueeze(1))                            # (N, maxk)

    n = t_targets.size(0)
    # Slicing past maxk is harmless: it just takes all maxk columns.
    return {k: hits[:, :k].any(dim=1).float().sum().item() / n for k in ks}


# ---------------- CLASS-WEIGHTED LOSSES ----------------

def _collect_labels_from_dataset(train_ds):
    """
    Extract action labels, raw task codes and the task-code -> index mapping.

    (Previously this function was defined twice with identical bodies, and
    the second definition silently shadowed the first. Only one remains.)

    Args:
        train_ds: a ``ShardWindowDataset``, a ``torch.utils.data.Subset``
            over one, or any map-style dataset whose items are dicts with
            ``"action"`` and ``"y_task_raw"`` entries.

    Returns:
        ``(y_action, y_task_raw, task2idx)``: two int arrays with one entry
        per window, and a dict mapping raw task code -> class index. For the
        generic fallback, ``task2idx`` enumerates the sorted unique task
        codes found in the data.
    """
    if isinstance(train_ds, ShardWindowDataset):
        return (
            np.array(train_ds.y_action, dtype=int),
            np.array(train_ds.y_task, dtype=int),
            dict(train_ds.task2idx),
        )

    if isinstance(train_ds, torch.utils.data.Subset) and isinstance(
        train_ds.dataset, ShardWindowDataset
    ):
        # Fast path: index the base dataset's label arrays directly.
        base = train_ds.dataset
        indices = np.array(train_ds.indices, dtype=int)
        return (
            np.array(base.y_action, dtype=int)[indices],
            np.array(base.y_task, dtype=int)[indices],
            dict(base.task2idx),
        )

    # Generic fallback (e.g. ConcatDataset, or a Subset over another dataset):
    # iterate item by item. Slow, but works for any dict-yielding dataset.
    acts, tasks_raw = [], []
    for i in range(len(train_ds)):
        item = train_ds[i]
        acts.append(int(item["action"]))
        tasks_raw.append(int(item["y_task_raw"]))
    y_action = np.array(acts, dtype=int)
    y_task_raw = np.array(tasks_raw, dtype=int)
    # np.unique returns sorted values; keys are plain ints.
    task2idx = {int(t): i for i, t in enumerate(np.unique(y_task_raw))}
    return y_action, y_task_raw, task2idx


# ---------- NEW: class-weighted action loss ----------
def build_action_criterion(train_ds) -> nn.Module:
    """
    Class-weighted cross-entropy for the binary action head (REST=0, ACTION=1).

    Weights are inverse class frequencies over ``train_ds``:
    ``w_c = N / (2 * count_c)``, where a tiny 1e-6 is added to each count
    so an absent class does not divide by zero. The weights are printed
    and placed on ``CFG.DEVICE``.
    """
    y_action = _collect_labels_from_dataset(train_ds)[0]   # y_action ∈ {0,1}
    counts = np.bincount(y_action, minlength=2)
    n_windows = len(y_action)
    weights = n_windows / (2.0 * (counts + 1e-6))
    print("[action weights]", weights)

    weight_tensor = torch.tensor(weights, dtype=torch.float32, device=CFG.DEVICE)
    return nn.CrossEntropyLoss(weight=weight_tensor)




# ---------- NEW: class-weighted TASK loss ----------
def build_task_criterion(train_ds,
                         task2idx: Dict[int, int],
                         num_task_classes: int) -> nn.Module:
    """
    Class-weighted cross-entropy for the task head.

    Only ACTION windows (``y_action == 1``) are counted. This follows the
    stated design that the task head is trained on ACTION windows only in
    ``train_one_fold()``. Raw task codes are mapped through ``task2idx``
    (an unknown code raises ``KeyError``), and the weights are inverse
    frequencies ``N / (num_task_classes * count_c)``. A class with no ACTION
    windows gets a very large weight (its count is floored at 1e-6); that
    weight only matters if such a class appears as a target.

    Falls back to an unweighted ``nn.CrossEntropyLoss()`` when ``train_ds``
    has no ACTION windows.
    """
    y_action, y_task_raw, _ = _collect_labels_from_dataset(train_ds)
    action_codes = y_task_raw[y_action == 1]

    if action_codes.size == 0:
        print("[task weights] No ACTION windows in train_ds; using unweighted CE.")
        return nn.CrossEntropyLoss()

    # Raw task codes (e.g. 0..5) -> contiguous class indices.
    class_idx = np.fromiter(
        (task2idx[int(code)] for code in action_codes),
        dtype=int,
        count=action_codes.size,
    )
    counts = np.bincount(class_idx, minlength=num_task_classes)
    weights = len(class_idx) / (num_task_classes * (counts + 1e-6))
    print("[task weights]", weights)

    weight_tensor = torch.tensor(weights, dtype=torch.float32, device=CFG.DEVICE)
    return nn.CrossEntropyLoss(weight=weight_tensor)


# ---------------- ACTION THRESHOLD TUNING ----------------

def tune_action_threshold(
    model: TriModalSafetyTransformer,
    val_loader: DataLoader,
    use_eeg: bool,
    use_emg: bool,
    use_et: bool,
    device: str,
    num_thresholds: int = 17,
) -> float:
    """
    Choose the P(action) threshold that maximises validation balanced accuracy.

    The model runs in eval mode without gradients over ``val_loader``.
    Modalities switched off by the ``use_*`` flags are replaced with zeros,
    and the EEG-PSD / EMG feature tensors are zeroed along with their
    modality when present. P(action) is the softmax probability of action
    class 1. ``num_thresholds`` evenly spaced thresholds in [0.1, 0.9] are
    scanned; ties keep the earliest (lowest) threshold.

    Returns:
        The best threshold. Returns 0.5 when the loader yields no batches,
        or when no threshold reaches a balanced accuracy above 0.0.
    """
    model.eval()
    prob_parts = []
    label_parts = []

    def _maybe_to_device(t):
        return None if t is None else t.to(device)

    with torch.no_grad():
        for batch in val_loader:
            eeg = batch["eeg"].to(device)
            emg = batch["emg"].to(device)
            et = batch["et"].to(device)
            action_labels = batch["action"].to(device)
            psd = _maybe_to_device(batch.get("eeg_psd", None))
            emg_extra = _maybe_to_device(batch.get("emg_feat", None))

            # Ablation: zero out modalities (and their features) that are not used.
            if not use_eeg:
                eeg = torch.zeros_like(eeg)
                if psd is not None:
                    psd = torch.zeros_like(psd)
            if not use_emg:
                emg = torch.zeros_like(emg)
                if emg_extra is not None:
                    emg_extra = torch.zeros_like(emg_extra)
            if not use_et:
                et = torch.zeros_like(et)

            logits_action, _, _ = model.forward_supervised(
                eeg, emg, et,
                eeg_psd=psd,
                emg_feat=emg_extra,
            )
            p_action = F.softmax(logits_action, dim=1)[:, 1]

            prob_parts.append(p_action.cpu().numpy())
            label_parts.append(action_labels.cpu().numpy())

    if len(prob_parts) == 0:
        return 0.5

    probs = np.concatenate(prob_parts)
    labels = np.concatenate(label_parts)

    best_t, best_bal = 0.5, 0.0
    for t in np.linspace(0.1, 0.9, num_thresholds):
        score = balanced_accuracy_score(labels, (probs >= t).astype(int))
        if score > best_bal:
            best_t, best_bal = float(t), score

    print(f"[tune_action_threshold] best τ={best_t:.3f}, val_bal_acc={best_bal:.3f}")
    return best_t


# ---------------- POLICY EVALUATION (S0–S3, P0–P2) ----------------
# P0 thresholds the raw action probability; P1 scales it by the gate mass on
# healthy modalities; P2 additionally scales by (1 - mean normalised task/gate entropy).

def _policy_decisions(
    p_eff: np.ndarray,
    tau: float,
    logits_task: np.ndarray,
    rest_idx: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Turn effective action scores into move decisions and task predictions.

    A window moves when ``p_eff >= tau``. For moving windows the REST logit
    is set to -1e9 so the argmax picks a non-rest task. Non-moving windows
    are assigned ``rest_idx``.

    Returns:
        ``(do_move, pred_task)``: int arrays of shape (N,).
    """
    do_move = (p_eff >= tau).astype(int)
    masked_logits = logits_task.copy()
    masked_logits[do_move == 1, rest_idx] = -1e9
    pred_task = masked_logits.argmax(axis=1)
    pred_task[do_move == 0] = rest_idx
    return do_move, pred_task


def _safety_metrics(
    gt_a: np.ndarray,
    gt_t: np.ndarray,
    do_move: np.ndarray,
    pred_task_idx: np.ndarray,
) -> Dict:
    """
    Safety metrics for one policy's decisions.

    Outcome categories:
      * correct_move: moved on an ACTION window with the correct task;
      * unsafe_move: moved on a REST window, or on an ACTION window with
        the wrong task;
      * missed_intent: did not move on an ACTION window;
      * safe_idle: did not move on a REST window.

    Rates (each denominator floored at 1):
      * SAE = unsafe / (correct + unsafe);
      * UAR = correct / #ACTION;
      * MIR = missed / #ACTION;
      * NRA = safe_idle / #REST.
    """
    total = len(gt_a)
    total_action = int((gt_a == 1).sum())
    total_rest = int((gt_a == 0).sum())

    moved = do_move == 1
    correct_move = moved & (gt_a == 1) & (pred_task_idx == gt_t)
    unsafe_move = moved & ((gt_a == 0) | ((gt_a == 1) & (pred_task_idx != gt_t)))
    missed_intent = (do_move == 0) & (gt_a == 1)
    safe_idle = (do_move == 0) & (gt_a == 0)

    n_correct = int(correct_move.sum())
    n_unsafe = int(unsafe_move.sum())
    n_missed = int(missed_intent.sum())
    n_safeidle = int(safe_idle.sum())

    sae = n_unsafe / max(1, n_correct + n_unsafe)
    denom_intent = max(1, total_action)
    uar = n_correct / denom_intent
    mir = n_missed / denom_intent
    nra = n_safeidle / max(1, total_rest)

    return {
        "SAE": float(sae),
        "UAR": float(uar),
        "MIR": float(mir),
        "NRA": float(nra),
        "counts": {
            "total": int(total),
            "total_action": int(total_action),
            "total_rest": int(total_rest),
            "correct_move": n_correct,
            "unsafe_move": n_unsafe,
            "missed_intent": n_missed,
            "safe_idle": n_safeidle,
        },
    }


def evaluate_policies_for_scenarios(
    model: TriModalSafetyTransformer,
    test_loader: DataLoader,
    base_threshold: float,
    num_task_classes: int,
    use_eeg: bool,
    use_emg: bool,
    use_et: bool,
) -> Dict:
    """
    Evaluate action-gating policies P0–P2 under every scenario in ``SCENARIOS``.

    For each scenario, the test set is re-run with zeros in place of the
    dropped modalities. A modality is dropped if the scenario drops it or
    the ``use_*`` ablation flags disable it; EEG-PSD / EMG feature tensors
    are zeroed with their modality. From the outputs the function computes:

      * ``r_reliab``: per-window gate mass on the still-present ("healthy")
        modalities;
      * ``H_task_norm`` / ``H_gate_norm``: task-softmax and gate entropies,
        normalised by ``log(num_task_classes)`` and ``log(3)``.

    Policies, all thresholded at ``base_threshold``:
      * P0: raw P(action);
      * P1: P(action) * r_reliab;
      * P2: P(action) * r_reliab * (1 - 0.5 * (H_task_norm + H_gate_norm)).

    P2 is also swept over ``CFG.P2_THRESHOLDS`` to produce SAE/MIR curves.
    Scenarios that yield no batches, or have no healthy modality, are skipped.

    Returns:
        ``{"base_threshold": ..., "scenarios": {name: {...}}}``. Each scenario
        entry holds its description, gate statistics, per-policy "classic"
        and "safety" metrics, and the P2 curve.
    """
    model.eval()
    # Task index treated as REST. NOTE(review): assumes task2idx maps the
    # rest task to index 0; confirm against the dataset.
    rest_idx = 0
    eps = 1e-8

    scenario_results = {}

    for scen_name, scen_cfg in SCENARIOS.items():
        drop_eeg = scen_cfg["drop_eeg"]
        drop_emg = scen_cfg["drop_emg"]
        drop_et = scen_cfg["drop_et"]

        all_gt_action = []
        all_gt_task = []
        all_p_act = []
        all_logits_task = []
        all_gates_mean = []

        with torch.no_grad():
            for batch in test_loader:
                x_eeg = batch["eeg"].to(CFG.DEVICE)
                x_emg = batch["emg"].to(CFG.DEVICE)
                x_et = batch["et"].to(CFG.DEVICE)
                y_action = batch["action"].to(CFG.DEVICE)
                y_task = batch["task"].to(CFG.DEVICE)

                eeg_psd = batch.get("eeg_psd", None)
                if eeg_psd is not None:
                    eeg_psd = eeg_psd.to(CFG.DEVICE)
                emg_feat = batch.get("emg_feat", None)
                if emg_feat is not None:
                    emg_feat = emg_feat.to(CFG.DEVICE)

                # Ablation + scenario-level drops
                if (not use_eeg) or drop_eeg:
                    x_eeg = torch.zeros_like(x_eeg)
                    if eeg_psd is not None:
                        eeg_psd = torch.zeros_like(eeg_psd)
                if (not use_emg) or drop_emg:
                    x_emg = torch.zeros_like(x_emg)
                    if emg_feat is not None:
                        emg_feat = torch.zeros_like(emg_feat)
                if (not use_et) or drop_et:
                    x_et = torch.zeros_like(x_et)

                logits_action, logits_task, aux = model.forward_supervised(
                    x_eeg, x_emg, x_et,
                    eeg_psd=eeg_psd,
                    emg_feat=emg_feat,
                )

                probs_action = F.softmax(logits_action, dim=1)[:, 1]
                gates = aux["gates"]  # (B,3,D)
                g_mean = gates.mean(dim=-1)  # (B,3)

                all_gt_action.append(y_action.cpu().numpy())
                all_gt_task.append(y_task.cpu().numpy())
                all_p_act.append(probs_action.cpu().numpy())
                all_logits_task.append(logits_task.cpu().numpy())
                all_gates_mean.append(g_mean.cpu().numpy())

        if not all_gt_action:
            continue

        gt_action = np.concatenate(all_gt_action)
        gt_task = np.concatenate(all_gt_task)
        p_act = np.concatenate(all_p_act)
        logits_task = np.concatenate(all_logits_task)
        gates_mean = np.concatenate(all_gates_mean)  # (N,3)

        # 1.0 for modalities present in this scenario/ablation, 0.0 otherwise.
        scen_mask = np.array([
            0.0 if (drop_eeg or not use_eeg) else 1.0,
            0.0 if (drop_emg or not use_emg) else 1.0,
            0.0 if (drop_et or not use_et) else 1.0,
        ], dtype=np.float32)
        healthy_modalities = scen_mask.sum()
        if healthy_modalities <= 0:
            print(f"[WARN] Scenario {scen_name}: no healthy modalities, skipping.")
            continue

        # Reliability: per-window gate mass on healthy modalities.
        r_reliab = (gates_mean * scen_mask[None, :]).sum(axis=1)  # (N,)

        p_task = torch.softmax(torch.from_numpy(logits_task), dim=1).numpy()
        H_task = -(p_task * np.log(p_task + eps)).sum(axis=1)
        H_task_norm = H_task / (math.log(num_task_classes) + eps)

        gates_prob = gates_mean / (gates_mean.sum(axis=1, keepdims=True) + eps)
        H_gate = -(gates_prob * np.log(gates_prob + eps)).sum(axis=1)
        H_gate_norm = H_gate / (math.log(3.0) + eps)

        # NOTE: gmue uses the *unnormalised* gate entropy (nats, at most ln 3).
        gmue = float(H_gate.mean())
        gate_mean = gates_prob.mean(axis=0)
        gate_std = gates_prob.std(axis=0)

        # Top-k task accuracy depends only on the raw task logits, not on the
        # policy, so compute it once per scenario.
        task_topk = compute_topk_accuracies_from_logits(
            logits_task, gt_task, ks=CFG.TOPK
        )

        def compute_policy_metrics(do_move: np.ndarray,
                                   pred_task_idx: np.ndarray,
                                   score_vec: np.ndarray):
            # Classic classification metrics plus safety metrics for one policy.
            gt_a = gt_action
            gt_t = gt_task

            safety = _safety_metrics(gt_a, gt_t, do_move, pred_task_idx)

            action_pred = do_move.astype(int)
            acc_action = accuracy_score(gt_a, action_pred)
            f1_action_macro = f1_score(gt_a, action_pred, average="macro")
            f1_action_micro = f1_score(gt_a, action_pred, average="micro")
            bal_action = balanced_accuracy_score(gt_a, action_pred)
            # Best-effort: ranking metrics are undefined (sklearn raises) when
            # gt_a holds only one class; report None instead of failing.
            try:
                auc_roc = roc_auc_score(gt_a, score_vec)
            except Exception:
                auc_roc = None
            try:
                auprc = average_precision_score(gt_a, score_vec)
            except Exception:
                auprc = None

            cm_action = confusion_matrix(gt_a, action_pred, labels=[0, 1])
            per_class_action = compute_per_class_metrics_from_cm(cm_action)

            acc_task = accuracy_score(gt_t, pred_task_idx)
            f1_task_macro = f1_score(gt_t, pred_task_idx, average="macro")
            f1_task_micro = f1_score(gt_t, pred_task_idx, average="micro")
            bal_task = balanced_accuracy_score(gt_t, pred_task_idx)
            cm_task = confusion_matrix(gt_t, pred_task_idx, labels=list(range(num_task_classes)))
            per_class_task = compute_per_class_metrics_from_cm(cm_task)

            return {
                "classic": {
                    "action_acc": float(acc_action),
                    "action_macro_f1": float(f1_action_macro),
                    "action_micro_f1": float(f1_action_micro),
                    "action_bal_acc": float(bal_action),
                    "action_auc_roc": float(auc_roc) if auc_roc is not None else None,
                    "action_auprc": float(auprc) if auprc is not None else None,
                    "cm_action": cm_action.tolist(),
                    "per_class_action": per_class_action,
                    "task_acc": float(acc_task),
                    "task_macro_f1": float(f1_task_macro),
                    "task_micro_f1": float(f1_task_micro),
                    "task_bal_acc": float(bal_task),
                    "cm_task": cm_task.tolist(),
                    "per_class_task": per_class_task,
                    "task_topk": {
                        f"top_{k}": float(task_topk.get(k, 0.0)) for k in CFG.TOPK
                    },
                },
                "safety": safety,
            }

        # Effective action scores for each policy (see docstring).
        unc = 0.5 * (H_task_norm + H_gate_norm)
        safety_factor = r_reliab * (1.0 - unc)
        policy_scores = {
            "P0": p_act.copy(),
            "P1": p_act * r_reliab,
            "P2": p_act * safety_factor,
        }

        policy_entries = {}
        for pol_name, p_eff in policy_scores.items():
            do_move, pred_task = _policy_decisions(
                p_eff, base_threshold, logits_task, rest_idx
            )
            policy_entries[pol_name] = compute_policy_metrics(do_move, pred_task, p_eff)

        # P2 threshold sweep: only the safety rates are needed here, so skip
        # the (expensive) sklearn classic metrics.
        SAE_curve = []
        MIR_curve = []
        thresholds = CFG.P2_THRESHOLDS
        for t in thresholds:
            do_move_t, pred_task_t = _policy_decisions(
                policy_scores["P2"], t, logits_task, rest_idx
            )
            safety_t = _safety_metrics(gt_action, gt_task, do_move_t, pred_task_t)
            SAE_curve.append(safety_t["SAE"])
            MIR_curve.append(safety_t["MIR"])

        scenario_results[scen_name] = {
            "description": scen_cfg["description"],
            "gmue": gmue,
            "gates_mean": gate_mean.tolist(),
            "gates_std": gate_std.tolist(),
            "policies": policy_entries,
            "P2_curve": {
                "thresholds": thresholds,
                "SAE": SAE_curve,
                "MIR": MIR_curve,
            },
        }

    return {
        "base_threshold": base_threshold,
        "scenarios": scenario_results,
    }



# ---------------- STAGE 2 — SUPERVISED LOSO TRAINING ----------------

def _fold_batch_to_device(
    batch: Dict[str, torch.Tensor],
    use_eeg: bool,
    use_emg: bool,
    use_et: bool,
    et_dropout: float = 0.0,
):
    """
    Move one supervised batch to ``CFG.DEVICE`` and apply the modality ablation.

    The optional keys ``"eeg_psd"`` / ``"emg_feat"`` stay ``None`` when absent.
    When ``use_et`` is True and ``et_dropout > 0``, ``apply_modality_dropout``
    is applied to the ET stream *before* ablation masking (training only).
    Each disabled modality is replaced by zeros of the same shape, and so is
    its feature vector when one exists.

    Returns:
        (x_eeg, x_emg, x_et, eeg_psd, emg_feat, y_action, y_task)
    """
    x_eeg = batch["eeg"].to(CFG.DEVICE)
    x_emg = batch["emg"].to(CFG.DEVICE)
    x_et  = batch["et"].to(CFG.DEVICE)
    y_action = batch["action"].to(CFG.DEVICE)
    y_task   = batch["task"].to(CFG.DEVICE)

    eeg_psd = batch.get("eeg_psd", None)
    if eeg_psd is not None:
        eeg_psd = eeg_psd.to(CFG.DEVICE)
    emg_feat = batch.get("emg_feat", None)
    if emg_feat is not None:
        emg_feat = emg_feat.to(CFG.DEVICE)

    # Optional ET dropout during supervised training
    if use_et and et_dropout > 0.0:
        x_et = apply_modality_dropout(x_et, et_dropout)

    if not use_eeg:
        x_eeg = torch.zeros_like(x_eeg)
        if eeg_psd is not None:
            eeg_psd = torch.zeros_like(eeg_psd)
    if not use_emg:
        x_emg = torch.zeros_like(x_emg)
        if emg_feat is not None:
            emg_feat = torch.zeros_like(emg_feat)
    if not use_et:
        x_et  = torch.zeros_like(x_et)

    return x_eeg, x_emg, x_et, eeg_psd, emg_feat, y_action, y_task


def _fold_supervised_loss(
    logits_action: torch.Tensor,
    logits_task: torch.Tensor,
    y_action: torch.Tensor,
    y_task: torch.Tensor,
    crit_action,
    crit_task,
) -> torch.Tensor:
    """
    Compute the weighted action and task loss for one batch.

    The task loss uses only samples whose action label is 1. If the batch
    contains no such sample, the task term is 0. The result is
    ``CFG.ALPHA_ACTION * loss_action + CFG.BETA_TASK * loss_task``.
    """
    loss_action = crit_action(logits_action, y_action)
    action_mask = (y_action == 1)
    if action_mask.any():
        loss_task = crit_task(logits_task[action_mask], y_task[action_mask])
    else:
        loss_task = torch.tensor(0.0, device=CFG.DEVICE)
    return CFG.ALPHA_ACTION * loss_action + CFG.BETA_TASK * loss_task


def train_one_fold(
    fold_id: int,
    base_state: Optional[Dict[str, torch.Tensor]],
    eeg_ch: int,
    emg_ch: int,
    et_ch: int,
    ablation_name: str,
    use_eeg: bool,
    use_emg: bool,
    use_et: bool,
) -> Dict:
    """
    Train, validate and test one LOSO fold for a given modality ablation.

    Steps:
      1. Build the fold's train/val/test datasets.
      2. Compute train-split mean/std for the optional EEG-PSD / EMG feature
         vectors (only when the dataset provides them, the CFG flags enable
         them, and the corresponding modality is in use).
      3. Build ``TriModalSafetyTransformer``. When ``CFG.USE_SSL`` is set and
         ``base_state`` is given, load the SSL weights non-strictly.
      4. Train with AdamW. Parameters outside ``action_head.`` / ``task_head.``
         use ``CFG.SUP_LR * CFG.BACKBONE_LR_SCALE``. Early stopping is driven
         by validation loss with patience ``CFG.PATIENCE``.
      5. Restore the best-validation checkpoint and tune the action threshold
         on the validation loader.
      6. Evaluate on the test split (classic action/task metrics). For the
         ``"all"`` ablation, also run the safety-policy scenarios.

    Args:
        fold_id: LOSO fold identifier.
        base_state: SSL-pretrained state dict, or None for random init.
        eeg_ch, emg_ch, et_ch: channel counts of each modality.
        ablation_name: label used in logs and the result dict. ``"all"``
            additionally triggers the safety evaluation.
        use_eeg, use_emg, use_et: which modalities are fed to the model.
            Disabled ones are zero-filled.

    Returns:
        Dict with the fold id, ablation flags, ``task2idx``, ``"classic"``
        metrics, the tuned ``base_threshold``, ``"safety"`` results (None
        unless ablation is ``"all"``) and the S0 ``"gmue"``.
    """

    print(
        f"\n================ STAGE 2 — Supervised LOSO (FOLD {fold_id}, "
        f"ablation={ablation_name}) ================"
    )

    train_ds, val_ds, test_ds = make_supervised_datasets_for_fold(fold_id)
    num_task_classes = train_ds.num_task_classes

    print(
        f"[Fold {fold_id}][{ablation_name}] num_task_classes={num_task_classes}, "
        f"tasks={sorted(train_ds.task2idx.keys())}"
    )

    eeg_psd_dim = getattr(train_ds, "eeg_psd_dim", 0)
    emg_feat_dim = getattr(train_ds, "emg_feat_dim", 0)

    use_eeg_psd = (eeg_psd_dim > 0) and CFG.USE_EEG_PSD_FEATURES and use_eeg
    use_emg_feat = (emg_feat_dim > 0) and CFG.USE_EMG_FEATURES and use_emg

    # ---- Fold-wise normalization stats for Phase 5.5 features (train split only) ----
    eeg_psd_mean_t = eeg_psd_std_t = None
    emg_feat_mean_t = emg_feat_std_t = None

    if use_eeg_psd:
        Xp = train_ds.X_eeg_psd.astype(np.float32)   # (N_train, F_psd)
        eeg_psd_mean = Xp.mean(axis=0)
        eeg_psd_std  = Xp.std(axis=0) + 1e-6         # epsilon avoids /0 on constant features
        eeg_psd_mean_t = torch.from_numpy(eeg_psd_mean).float().to(CFG.DEVICE)
        eeg_psd_std_t  = torch.from_numpy(eeg_psd_std).float().to(CFG.DEVICE)

    if use_emg_feat:
        Xe = train_ds.X_emg_feat.astype(np.float32)  # (N_train, F_emg)
        emg_feat_mean = Xe.mean(axis=0)
        emg_feat_std  = Xe.std(axis=0) + 1e-6
        emg_feat_mean_t = torch.from_numpy(emg_feat_mean).float().to(CFG.DEVICE)
        emg_feat_std_t  = torch.from_numpy(emg_feat_std).float().to(CFG.DEVICE)

    if use_eeg_psd or use_emg_feat:
        print(
            f"[Fold {fold_id}][{ablation_name}] Using features: "
            f"EEG_PSD_dim={eeg_psd_dim if use_eeg_psd else 0}, "
            f"EMG_feat_dim={emg_feat_dim if use_emg_feat else 0}"
        )
    else:
        print(f"[Fold {fold_id}][{ablation_name}] No feature vectors used.")

    model = TriModalSafetyTransformer(
        eeg_ch=eeg_ch,
        emg_ch=emg_ch,
        et_ch=et_ch,
        num_task_classes=num_task_classes,
        d_model=CFG.D_MODEL,
        dropout=CFG.DROPOUT,
        use_eeg_psd=use_eeg_psd,
        use_emg_feat=use_emg_feat,
        eeg_psd_dim=eeg_psd_dim,
        emg_feat_dim=emg_feat_dim,
        eeg_psd_mean=eeg_psd_mean_t,
        eeg_psd_std=eeg_psd_std_t,
        emg_feat_mean=emg_feat_mean_t,
        emg_feat_std=emg_feat_std_t,
    ).to(CFG.DEVICE)

    if CFG.USE_SSL and base_state is not None:
        # Non-strict: supervised-only modules may be absent from the SSL state.
        missing, unexpected = model.load_state_dict(base_state, strict=False)
        print(
            f"[Fold {fold_id}][{ablation_name}] Loaded SSL backbone. "
            f"Missing={len(missing)}, Unexpected={len(unexpected)}"
        )
    else:
        print(f"[Fold {fold_id}][{ablation_name}] Random init (no SSL).")

    train_loader = make_dataloader(train_ds, batch_size=CFG.SUP_BATCH, shuffle=True)
    val_loader   = make_dataloader(val_ds,   batch_size=CFG.SUP_BATCH, shuffle=False)
    test_loader  = make_dataloader(test_ds,  batch_size=CFG.SUP_BATCH, shuffle=False)

    crit_action = build_action_criterion(train_ds)
    crit_task   = build_task_criterion(train_ds, task2idx=train_ds.task2idx,
                                       num_task_classes=num_task_classes)

    # Two param groups: heads at the full LR, everything else at a scaled LR.
    head_params = []
    backbone_params = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.startswith("action_head.") or name.startswith("task_head."):
            head_params.append(p)
        else:
            backbone_params.append(p)

    optimizer = torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": CFG.SUP_LR * CFG.BACKBONE_LR_SCALE},
            {"params": head_params,     "lr": CFG.SUP_LR},
        ],
        weight_decay=CFG.WEIGHT_DECAY,
    )

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    # ---------------- TRAINING LOOP ----------------
    for epoch in range(1, CFG.SUP_EPOCHS + 1):
        model.train()
        total_loss = 0.0
        total_samples = 0
        total_action_correct = 0

        for batch in train_loader:
            (x_eeg, x_emg, x_et, eeg_psd, emg_feat,
             y_action, y_task) = _fold_batch_to_device(
                batch, use_eeg, use_emg, use_et,
                et_dropout=CFG.SUP_ET_DROPOUT,
            )

            logits_action, logits_task, _ = model.forward_supervised(
                x_eeg, x_emg, x_et,
                eeg_psd=eeg_psd,
                emg_feat=emg_feat,
            )

            loss = _fold_supervised_loss(
                logits_action, logits_task, y_action, y_task,
                crit_action, crit_task,
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            B = y_action.size(0)
            total_loss += float(loss.item()) * B
            total_samples += B

            preds_action = logits_action.argmax(dim=1)
            total_action_correct += (preds_action == y_action).sum().item()

        train_loss = total_loss / max(1, total_samples)
        train_action_acc = total_action_correct / max(1, total_samples)

        # ---------------- VALIDATION ----------------
        model.eval()
        val_loss = 0.0
        val_samples = 0
        val_action_correct = 0

        with torch.no_grad():
            for batch in val_loader:
                # No ET dropout at validation time.
                (x_eeg, x_emg, x_et, eeg_psd, emg_feat,
                 y_action, y_task) = _fold_batch_to_device(
                    batch, use_eeg, use_emg, use_et,
                )

                logits_action, logits_task, _ = model.forward_supervised(
                    x_eeg, x_emg, x_et,
                    eeg_psd=eeg_psd,
                    emg_feat=emg_feat,
                )

                loss = _fold_supervised_loss(
                    logits_action, logits_task, y_action, y_task,
                    crit_action, crit_task,
                )

                B = y_action.size(0)
                val_loss += float(loss.item()) * B
                val_samples += B

                preds_action = logits_action.argmax(dim=1)
                val_action_correct += (preds_action == y_action).sum().item()

        val_loss /= max(1, val_samples)
        val_action_acc = val_action_correct / max(1, val_samples)

        print(
            f"[Fold {fold_id}][{ablation_name}] Epoch {epoch:02d}/{CFG.SUP_EPOCHS} "
            f"train_loss={train_loss:.4f} train_acc={train_action_acc:.3f} "
            f"val_loss={val_loss:.4f} val_acc={val_action_acc:.3f}"
        )

        if val_loss < best_val_loss - 1e-4:
            best_val_loss = val_loss
            # state_dict() returns live references to the parameter tensors;
            # without cloning, later optimizer steps would overwrite this
            # snapshot and the restore below would reload last-epoch weights.
            best_state = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= CFG.PATIENCE:
                print(f"[Fold {fold_id}][{ablation_name}] Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    base_threshold = tune_action_threshold(
        model, val_loader,
        use_eeg=use_eeg, use_emg=use_emg, use_et=use_et,
        device=CFG.DEVICE,
    )

    # ---------------- TEST EVALUATION (S0, P0 ONLY FOR CLASSIC METRICS) ----------------
    model.eval()
    all_action_true, all_action_pred = [], []
    all_action_scores = []
    all_task_true, all_task_pred = [], []
    all_task_logits = []

    # NOTE(review): assumes task index 0 is the rest class — confirm against
    # how train_ds.task2idx is built.
    rest_idx = 0

    with torch.no_grad():
        for batch in test_loader:
            (x_eeg, x_emg, x_et, eeg_psd, emg_feat,
             y_action, y_task) = _fold_batch_to_device(
                batch, use_eeg, use_emg, use_et,
            )

            logits_action, logits_task, _ = model.forward_supervised(
                x_eeg, x_emg, x_et,
                eeg_psd=eeg_psd,
                emg_feat=emg_feat,
            )

            probs_action = F.softmax(logits_action, dim=1)[:, 1]
            preds_action = (probs_action >= base_threshold).long()

            # Predicted "move" windows cannot be labelled rest; predicted
            # "no move" windows are forced to rest.
            logits_task_mod = logits_task.clone()
            move_idx = (preds_action == 1)
            logits_task_mod[move_idx, rest_idx] = -1e9
            preds_task = logits_task_mod.argmax(dim=1)
            preds_task[preds_action == 0] = rest_idx

            all_action_true.append(y_action.cpu().numpy())
            all_action_pred.append(preds_action.cpu().numpy())
            all_action_scores.append(probs_action.cpu().numpy())
            all_task_true.append(y_task.cpu().numpy())
            all_task_pred.append(preds_task.cpu().numpy())
            all_task_logits.append(logits_task_mod.cpu().numpy())

    all_action_true = np.concatenate(all_action_true)
    all_action_pred = np.concatenate(all_action_pred)
    all_action_scores = np.concatenate(all_action_scores)
    all_task_true = np.concatenate(all_task_true)
    all_task_pred = np.concatenate(all_task_pred)
    all_task_logits = np.concatenate(all_task_logits)

    acc_action = accuracy_score(all_action_true, all_action_pred)
    f1_action_macro = f1_score(all_action_true, all_action_pred, average="macro")
    f1_action_micro = f1_score(all_action_true, all_action_pred, average="micro")
    bal_action = balanced_accuracy_score(all_action_true, all_action_pred)
    # AUC/AUPRC are undefined when the test split has a single action class.
    try:
        auc_roc = roc_auc_score(all_action_true, all_action_scores)
    except Exception:
        auc_roc = None
    try:
        auprc = average_precision_score(all_action_true, all_action_scores)
    except Exception:
        auprc = None
    cm_action = confusion_matrix(all_action_true, all_action_pred, labels=[0, 1])
    per_class_action = compute_per_class_metrics_from_cm(cm_action)

    print(
        f"[Fold {fold_id}][{ablation_name}] TEST P0/S0 — Action: "
        f"acc={acc_action:.3f}, macro-F1={f1_action_macro:.3f}, "
        f"micro-F1={f1_action_micro:.3f}, bal-acc={bal_action:.3f}"
    )

    acc_task = accuracy_score(all_task_true, all_task_pred)
    f1_task_macro = f1_score(all_task_true, all_task_pred, average="macro")
    f1_task_micro = f1_score(all_task_true, all_task_pred, average="micro")
    bal_task = balanced_accuracy_score(all_task_true, all_task_pred)
    cm_task = confusion_matrix(all_task_true, all_task_pred,
                               labels=list(range(num_task_classes)))
    per_class_task = compute_per_class_metrics_from_cm(cm_task)
    task_topk = compute_topk_accuracies_from_logits(
        all_task_logits, all_task_true, ks=CFG.TOPK
    )

    print(
        f"[Fold {fold_id}][{ablation_name}] TEST P0/S0 — Task: "
        f"acc={acc_task:.3f}, macro-F1={f1_task_macro:.3f}, "
        f"micro-F1={f1_task_micro:.3f}, bal-acc={bal_task:.3f}"
    )

    print(f"[Fold {fold_id}][{ablation_name}] Per-class ACTION metrics (P0/S0):")
    for cls_idx, m in sorted(per_class_action.items()):
        label = "REST" if cls_idx == 0 else "ACTION"
        print(
            f"  class {cls_idx} ({label}): "
            f"prec={m['precision']:.3f}, rec={m['recall']:.3f}, "
            f"f1={m['f1']:.3f}, bal_acc={m['balanced_accuracy']:.3f}, "
            f"support={m['support']}"
        )

    print(f"[Fold {fold_id}][{ablation_name}] Per-class TASK metrics (P0/S0):")
    idx2task = {idx: code for code, idx in train_ds.task2idx.items()}
    for cls_idx, m in sorted(per_class_task.items()):
        task_code = idx2task.get(cls_idx, None)
        print(
            f"  class {cls_idx} (task_code={task_code}): "
            f"prec={m['precision']:.3f}, rec={m['recall']:.3f}, "
            f"f1={m['f1']:.3f}, bal_acc={m['balanced_accuracy']:.3f}, "
            f"support={m['support']}"
        )

    # Safety-policy scenarios are only evaluated for the full tri-modal model.
    safety_results = None
    gmue_fold = None
    if ablation_name == "all":
        safety_results = evaluate_policies_for_scenarios(
            model=model,
            test_loader=test_loader,
            base_threshold=base_threshold,
            num_task_classes=num_task_classes,
            use_eeg=use_eeg,
            use_emg=use_emg,
            use_et=use_et,
        )
        scen0 = safety_results["scenarios"].get("S0", None)
        if scen0 is not None:
            gmue_fold = scen0.get("gmue", None)

    return {
        "fold": fold_id,
        "ablation": ablation_name,
        "use_eeg": use_eeg,
        "use_emg": use_emg,
        "use_et": use_et,
        "num_task_classes": int(num_task_classes),
        "task2idx": train_ds.task2idx,
        "classic": {
            "action_acc": float(acc_action),
            "action_macro_f1": float(f1_action_macro),
            "action_micro_f1": float(f1_action_micro),
            "action_bal_acc": float(bal_action),
            "action_auc_roc": float(auc_roc) if auc_roc is not None else None,
            "action_auprc": float(auprc) if auprc is not None else None,
            "cm_action": cm_action.tolist(),
            "per_class_action": per_class_action,
            "task_acc": float(acc_task),
            "task_macro_f1": float(f1_task_macro),
            "task_micro_f1": float(f1_task_micro),
            "task_bal_acc": float(bal_task),
            "cm_task": cm_task.tolist(),
            "per_class_task": per_class_task,
            "task_topk": {
                f"top_{k}": float(task_topk.get(k, 0.0)) for k in CFG.TOPK
            },
        },
        "base_threshold": float(base_threshold),
        "safety": safety_results,
        "gmue": gmue_fold,
    }

#=================================

# ---------------- RUNTIME / LATENCY BENCHMARK HELPERS (NEW) ----------------

def count_parameters(model: nn.Module) -> int:
    """Return the number of scalar elements in all trainable (``requires_grad``) parameters of *model*."""
    trainable = (param for param in model.parameters() if param.requires_grad)
    return sum(param.numel() for param in trainable)


def _measure_latency_one_device(
    model: TriModalSafetyTransformer,
    batch: Dict[str, torch.Tensor],
    device: str,
    drop_et: bool,
    n_warmup: int = 5,
    n_iters: int = 20,
) -> Tuple[float, float]:
    """
    Returns (latency_ms_per_window, control_rate_hz) on a single device.
    Uses forward_supervised only (no policy logic), which dominates runtime.
    """
    model = model.to(device)
    model.eval()

    x_eeg = batch["eeg"].to(device)
    x_emg = batch["emg"].to(device)
    x_et  = batch["et"].to(device)

    eeg_psd = batch.get("eeg_psd")
    if eeg_psd is not None:
        eeg_psd = eeg_psd.to(device)

    emg_feat = batch.get("emg_feat")
    if emg_feat is not None:
        emg_feat = emg_feat.to(device)

    if drop_et:
        # S3-style runtime (ET failed → zeroed)
        x_et = torch.zeros_like(x_et)

    B = x_eeg.size(0)
    if B == 0:
        return 0.0, 0.0

    def _forward():
        with torch.no_grad():
            _ = model.forward_supervised(
                x_eeg, x_emg, x_et,
                eeg_psd=eeg_psd,
                emg_feat=emg_feat,
            )

    # Warm-up
    for _ in range(n_warmup):
        _forward()
        if device.startswith("cuda"):
            torch.cuda.synchronize()

    # Timed runs
    start = time.perf_counter()
    for _ in range(n_iters):
        _forward()
        if device.startswith("cuda"):
            torch.cuda.synchronize()
    end = time.perf_counter()

    avg_batch_time = (end - start) / float(n_iters)
    latency_ms = (avg_batch_time / float(B)) * 1000.0
    ctrl_hz = 1000.0 / latency_ms if latency_ms > 0 else 0.0

    return latency_ms, ctrl_hz


def benchmark_runtime_latency(
    eeg_ch: int,
    emg_ch: int,
    et_ch: int,
    folds: List[int],
) -> List[Dict]:
    """
    Build rows for 'Table Y – Computational cost and latency of TriSaFe-Trans'.

    The first balanced fold provides realistic window shapes. One test batch
    (at most 32 windows) is timed with freshly initialised weights, because
    weights do not affect parameter count or latency. CPU timings are always
    measured. GPU timings are added when CUDA is available; otherwise those
    columns stay ``""``.

    Args:
        eeg_ch, emg_ch, et_ch: channel counts of each modality.
        folds: available LOSO fold ids; only ``folds[0]`` is used.

    Returns:
        Two row dicts (S0: all modalities, S3: ET dropped). Returns ``[]`` if
        ``folds`` is empty or the first fold's test split yields no batch.
    """
    if not folds:
        return []

    fold0 = folds[0]
    train_ds0, _, test_ds0 = make_supervised_datasets_for_fold(fold0)

    eeg_psd_dim = getattr(train_ds0, "eeg_psd_dim", 0)
    emg_feat_dim = getattr(train_ds0, "emg_feat_dim", 0)
    use_eeg_psd = (eeg_psd_dim > 0) and CFG.USE_EEG_PSD_FEATURES
    use_emg_feat = (emg_feat_dim > 0) and CFG.USE_EMG_FEATURES

    # Small test batch (works on CPU / Jetson too)
    runtime_loader = make_dataloader(
        test_ds0,
        batch_size=min(32, CFG.SUP_BATCH),
        shuffle=False,
    )
    # An empty test split would otherwise surface as a bare StopIteration.
    runtime_batch = next(iter(runtime_loader), None)
    if runtime_batch is None:
        print(
            f"[Runtime] Test split of fold {fold0} yielded no batch; "
            f"skipping latency benchmark."
        )
        return []

    def _build_model():
        # Normalization stats are left unset: weights don't matter for
        # latency or parameter count.
        return TriModalSafetyTransformer(
            eeg_ch=eeg_ch,
            emg_ch=emg_ch,
            et_ch=et_ch,
            num_task_classes=train_ds0.num_task_classes,
            d_model=CFG.D_MODEL,
            dropout=CFG.DROPOUT,
            use_eeg_psd=use_eeg_psd,
            use_emg_feat=use_emg_feat,
            eeg_psd_dim=eeg_psd_dim,
            emg_feat_dim=emg_feat_dim,
            eeg_psd_mean=None,
            eeg_psd_std=None,
            emg_feat_mean=None,
            emg_feat_std=None,
        )

    model_cpu = _build_model().cpu()
    params_m = count_parameters(model_cpu) / 1e6

    # CPU latency (representative of Jetson-class hardware)
    cpu_lat_s0, cpu_rate_s0 = _measure_latency_one_device(
        model_cpu, runtime_batch, device="cpu", drop_et=False,
    )
    cpu_lat_s3, cpu_rate_s3 = _measure_latency_one_device(
        model_cpu, runtime_batch, device="cpu", drop_et=True,
    )

    def _row(scenario, lat_cpu, rate_cpu):
        # Key order is kept fixed for downstream table/CSV export.
        return {
            "model": "TriSaFe-Trans",
            "scenario": scenario,
            "params_m": params_m,
            "flops_per_window": "",  # optional, fill manually if you want
            "latency_ms_gpu": "",
            "control_rate_hz_gpu": "",
            "latency_ms_cpu": lat_cpu,
            "control_rate_hz_cpu": rate_cpu,
        }

    rows = [
        _row("S0: EEG+EMG+ET", cpu_lat_s0, cpu_rate_s0),
        _row("S3: EEG+EMG (ET dropped)", cpu_lat_s3, cpu_rate_s3),
    ]

    # Optional GPU timings (desktop / cluster)
    if torch.cuda.is_available():
        model_gpu = _build_model().to("cuda")

        gpu_lat_s0, gpu_rate_s0 = _measure_latency_one_device(
            model_gpu, runtime_batch, device="cuda", drop_et=False,
        )
        gpu_lat_s3, gpu_rate_s3 = _measure_latency_one_device(
            model_gpu, runtime_batch, device="cuda", drop_et=True,
        )

        rows[0]["latency_ms_gpu"] = gpu_lat_s0
        rows[0]["control_rate_hz_gpu"] = gpu_rate_s0
        rows[1]["latency_ms_gpu"] = gpu_lat_s3
        rows[1]["control_rate_hz_gpu"] = gpu_rate_s3

    # NOTE: If you later add a TriModal-GRU baseline class, you can call the
    # same helpers here and append another row with model="TriModal-GRU".
    return rows


#================
# ---------------- MAIN RUNNER (BioRob RQ) ----------------

def run_biorob_phase6():
    print("========== Phase 6 (BioRob) — SSL → Safety-Aware LOSO ==========")
    print(f"Device: {CFG.DEVICE}")
    print(f"Dataset dir: {CFG.DATASET_DIR}")
    print(f"Task codes (with rest): {CFG.TASK_CODES}")

    if CFG.USE_SSL:
        ssl_dataset, eeg_ch, emg_ch, et_ch = make_ssl_dataset()
        print(
            f"SSL dataset size: {len(ssl_dataset)}, "
            f"channels: EEG={eeg_ch}, EMG={emg_ch}, ET={et_ch}"
        )
        base_state = pretrain_ssl(eeg_ch, emg_ch, et_ch, ssl_dataset)
    else:
        print("[Phase 6] USE_SSL=False → skipping SSL pretraining.")
        base_state = None
        folds_tmp = discover_folds(CFG.BALANCED_PREFIX)
        if not folds_tmp:
            raise SystemExit(
                f"No balanced folds found with prefix {CFG.BALANCED_PREFIX}_fold*"
            )
        first_fold = folds_tmp[0]
        tmp_train_ds, _, _ = make_supervised_datasets_for_fold(first_fold)
        eeg_ch = tmp_train_ds.eeg_ch
        emg_ch = tmp_train_ds.emg_ch
        et_ch  = tmp_train_ds.et_ch
        print(
            f"[Phase 6] Channels inferred from balanced fold {first_fold}: "
            f"EEG={eeg_ch}, EMG={emg_ch}, ET={et_ch}"
        )

    folds = discover_folds(CFG.BALANCED_PREFIX)
    if not folds:
        raise SystemExit(
            f"No balanced folds found with prefix {CFG.BALANCED_PREFIX}_fold*"
        )
    print(f"LOSO folds: {folds}")

    all_results: Dict[str, List[Dict]] = {name: [] for name in ABLATIONS.keys()}

    for ablation_name, flags in ABLATIONS.items():
        use_eeg = flags["use_eeg"]
        use_emg = flags["use_emg"]
        use_et  = flags["use_et"]
        print(
            f"\n\n############################################################\n"
            f"### Ablation: {ablation_name} (EEG={use_eeg}, EMG={use_emg}, ET={use_et})\n"
            f"############################################################"
        )

        for fid in folds:
            res = train_one_fold(
                fold_id=fid,
                base_state=base_state,
                eeg_ch=eeg_ch,
                emg_ch=emg_ch,
                et_ch=et_ch,
                ablation_name=ablation_name,
                use_eeg=use_eeg,
                use_emg=use_emg,
                use_et=use_et,
            )
            all_results[ablation_name].append(res)

    # ---------------- SUMMARY (MSG, MRS, GMUE, Safety) ----------------

    summary = {
        "folds": folds,
        "cfg": {
            "use_ssl": CFG.USE_SSL,
            "ssl_epochs": CFG.SSL_EPOCHS,
            "sup_epochs": CFG.SUP_EPOCHS,
            "d_model": CFG.D_MODEL,
            "dropout": CFG.DROPOUT,
            "topk": CFG.TOPK,
            "task_codes": CFG.TASK_CODES,
            "scenarios": SCENARIOS,
            "policies": POLICIES,
        },
        "ablations": {},
        "MSG": {},
        "MRS": {},
        "GMUE": {},
        "safety": {},
        "gating": {},
        "P2_tradeoff_curves": {},
        "confusion_matrices": {},
    }
    # --- Per-ablation classic metrics (for your "one small table") ---
    table1_rows = []

    for ablation_name, res_list in all_results.items():
        # Collect per-fold metrics
        action_accs       = [r["classic"]["action_acc"]       for r in res_list]
        action_bal        = [r["classic"]["action_bal_acc"]   for r in res_list]
        action_f1_macro   = [r["classic"]["action_macro_f1"]  for r in res_list]
        action_f1_micro   = [r["classic"]["action_micro_f1"]  for r in res_list]
        action_auc        = [r["classic"]["action_auc_roc"]   for r in res_list
                             if r["classic"]["action_auc_roc"] is not None]
        action_auprc      = [r["classic"]["action_auprc"]     for r in res_list
                             if r["classic"]["action_auprc"] is not None]

        task_accs         = [r["classic"]["task_acc"]         for r in res_list]
        task_bal          = [r["classic"]["task_bal_acc"]     for r in res_list]
        task_f1_macro     = [r["classic"]["task_macro_f1"]    for r in res_list]
        task_f1_micro     = [r["classic"]["task_micro_f1"]    for r in res_list]
        top1_vals         = [r["classic"]["task_topk"].get("top_1", 0.0)
                             for r in res_list]
        top3_vals         = [r["classic"]["task_topk"].get("top_3", 0.0)
                             for r in res_list]

        def _mean_std(xs):
            if len(xs) == 0:
                return (None, None)
            return (float(np.mean(xs)), float(np.std(xs)))

        a_acc_mean, a_acc_std     = _mean_std(action_accs)
        a_bal_mean, a_bal_std     = _mean_std(action_bal)
        a_f1M_mean, a_f1M_std     = _mean_std(action_f1_macro)
        a_f1m_mean, a_f1m_std     = _mean_std(action_f1_micro)
        a_auc_mean, a_auc_std     = _mean_std(action_auc)
        a_auprc_mean, a_auprc_std = _mean_std(action_auprc)

        t_acc_mean, t_acc_std     = _mean_std(task_accs)
        t_bal_mean, t_bal_std     = _mean_std(task_bal)
        t_f1M_mean, t_f1M_std     = _mean_std(task_f1_macro)
        t_f1m_mean, t_f1m_std     = _mean_std(task_f1_micro)
        t_top1_mean, t_top1_std   = _mean_std(top1_vals)
        t_top3_mean, t_top3_std   = _mean_std(top3_vals)

        summary["ablations"][ablation_name] = {
            "action": {
                "mean_acc": a_acc_mean,
                "std_acc": a_acc_std,
                "mean_bal_acc": a_bal_mean,
                "std_bal_acc": a_bal_std,
                "macro_f1_mean": a_f1M_mean,
                "macro_f1_std": a_f1M_std,
                "micro_f1_mean": a_f1m_mean,
                "micro_f1_std": a_f1m_std,
                "auc_mean": a_auc_mean,
                "auc_std": a_auc_std,
                "auprc_mean": a_auprc_mean,
                "auprc_std": a_auprc_std,
            },
            "task": {
                "mean_acc": t_acc_mean,
                "std_acc": t_acc_std,
                "mean_bal_acc": t_bal_mean,
                "std_bal_acc": t_bal_std,
                "macro_f1_mean": t_f1M_mean,
                "macro_f1_std": t_f1M_std,
                "micro_f1_mean": t_f1m_mean,
                "micro_f1_std": t_f1m_std,
                "top1_mean": t_top1_mean,
                "top1_std": t_top1_std,
                "top3_mean": t_top3_mean,
                "top3_std": t_top3_std,
            },
        }

        print(
            f"\n[Ablation={ablation_name}] Action acc={a_acc_mean:.3f}±{a_acc_std:.3f}, "
            f"bal-acc={a_bal_mean:.3f}±{a_bal_std:.3f}"
        )
        print(
            f"[Ablation={ablation_name}] Task   acc={t_acc_mean:.3f}±{t_acc_std:.3f}, "
            f"bal-acc={t_bal_mean:.3f}±{t_bal_std:.3f}"
        )

        # Row for Table 1
        table1_rows.append([
            ablation_name,
            a_acc_mean, a_acc_std,
            a_f1M_mean, a_f1M_std,
            a_f1m_mean, a_f1m_std,
            a_bal_mean, a_bal_std,
            a_auc_mean, a_auc_std,
            a_auprc_mean, a_auprc_std,
            t_acc_mean, t_acc_std,
            t_f1M_mean, t_f1M_std,
            t_f1m_mean, t_f1m_std,
            t_bal_mean, t_bal_std,
            t_top1_mean, t_top1_std,
            t_top3_mean, t_top3_std,
        ])

    # --- MSG (Multimodal Synergy Gain) from ablations (using balanced acc) ---
    if all(k in summary["ablations"] for k in ("all", "eeg", "emg", "et")):
        ba_all_task = summary["ablations"]["all"]["task"]["mean_bal_acc"]
        ba_eeg_task = summary["ablations"]["eeg"]["task"]["mean_bal_acc"]
        ba_emg_task = summary["ablations"]["emg"]["task"]["mean_bal_acc"]
        ba_et_task  = summary["ablations"]["et"]["task"]["mean_bal_acc"]
        msg_task = float(ba_all_task - max(ba_eeg_task, ba_emg_task, ba_et_task))

        ba_all_action = summary["ablations"]["all"]["action"]["mean_bal_acc"]
        ba_eeg_action = summary["ablations"]["eeg"]["action"]["mean_bal_acc"]
        ba_emg_action = summary["ablations"]["emg"]["action"]["mean_bal_acc"]
        ba_et_action  = summary["ablations"]["et"]["action"]["mean_bal_acc"]
        msg_action = float(ba_all_action - max(ba_eeg_action, ba_emg_action, ba_et_action))

        summary["MSG"] = {
            "task_bal_acc": msg_task,
            "action_bal_acc": msg_action,
        }
        print(
            f"\nMSG (task bal-acc)   : {msg_task:.3f}\n"
            f"MSG (action bal-acc) : {msg_action:.3f}"
        )
    else:
        msg_action = None
        msg_task = None

    # --- GMUE (mean gating entropy, S0) ---
    gmue_vals = []
    for r in all_results["all"]:
        g = r.get("gmue", None)
        if g is not None:
            gmue_vals.append(g)
    if gmue_vals:
        gmue_mean = float(np.mean(gmue_vals))
        gmue_std  = float(np.std(gmue_vals))
        summary["GMUE"] = {"mean": gmue_mean, "std": gmue_std}
        print(f"\nGMUE (gating entropy, S0, all-fold mean±std): {gmue_mean:.4f}±{gmue_std:.4f}")

    # --- Safety metrics (SAE, UAR, MIR, NRA) across folds for each scenario/policy ---
    safety_summary = {}
    for scen_name in SCENARIOS.keys():
        scen_entry = {"policies": {}}
        for policy in POLICIES:
            SAE_vals = []
            UAR_vals = []
            MIR_vals = []
            NRA_vals = []

            for r in all_results["all"]:
                s = r.get("safety", None)
                if s is None:
                    continue
                scen_dict = s["scenarios"].get(scen_name, None)
                if scen_dict is None:
                    continue
                p_entry = scen_dict["policies"].get(policy, None)
                if p_entry is None:
                    continue
                saf = p_entry["safety"]
                SAE_vals.append(saf["SAE"])
                UAR_vals.append(saf["UAR"])
                MIR_vals.append(saf["MIR"])
                NRA_vals.append(saf["NRA"])

            if SAE_vals:
                scen_entry["policies"][policy] = {
                    "SAE_mean": float(np.mean(SAE_vals)),
                    "SAE_std": float(np.std(SAE_vals)),
                    "UAR_mean": float(np.mean(UAR_vals)),
                    "UAR_std": float(np.std(UAR_vals)),
                    "MIR_mean": float(np.mean(MIR_vals)),
                    "MIR_std": float(np.std(MIR_vals)),
                    "NRA_mean": float(np.mean(NRA_vals)),
                    "NRA_std": float(np.std(NRA_vals)),
                }

        safety_summary[scen_name] = scen_entry
    summary["safety"] = safety_summary

    print("\nSafety metrics (SAE/UAR/MIR/NRA) — mean across folds:")
    for scen_name, scen_entry in safety_summary.items():
        print(f"  Scenario {scen_name} ({SCENARIOS[scen_name]['description']}):")
        for policy, vals in scen_entry["policies"].items():
            print(
                f"    {policy}: SAE={vals['SAE_mean']:.3f}, UAR={vals['UAR_mean']:.3f}, "
                f"MIR={vals['MIR_mean']:.3f}, NRA={vals['NRA_mean']:.3f}"
            )

    # --- Gating summary (for Fig. 8) ---
    gating_summary = {}
    for scen_name in SCENARIOS.keys():
        gate_means = []
        for r in all_results["all"]:
            s = r.get("safety", None)
            if s is None:
                continue
            scen_dict = s["scenarios"].get(scen_name, None)
            if scen_dict is None:
                continue
            gmean = scen_dict.get("gates_mean", None)
            if gmean is not None:
                gate_means.append(np.array(gmean, dtype=float))
        if gate_means:
            gate_means_arr = np.stack(gate_means, axis=0)
            mean_gate = gate_means_arr.mean(axis=0)
            std_gate = gate_means_arr.std(axis=0)
            gating_summary[scen_name] = {
                "mean_gate": mean_gate.tolist(),   # [EEG, EMG, ET]
                "std_gate": std_gate.tolist(),
            }
    summary["gating"] = gating_summary

    # --- MRS (Modality-Robustness Score) from S0 vs S1–S3 for P0 ---
    mrs_task = {}
    mrs_action = {}
    for drop, scen in zip(["drop_eeg", "drop_emg", "drop_et"], ["S1", "S2", "S3"]):
        base_ba_action = []
        drop_ba_action = []
        base_ba_task   = []
        drop_ba_task   = []
        for r in all_results["all"]:
            s = r.get("safety", None)
            if s is None:
                continue
            scen0 = s["scenarios"].get("S0", None)
            scenD = s["scenarios"].get(scen, None)
            if scen0 is None or scenD is None:
                continue
            p0_0 = scen0["policies"]["P0"]["classic"]
            p0_D = scenD["policies"]["P0"]["classic"]
            base_ba_action.append(p0_0["action_bal_acc"])
            drop_ba_action.append(p0_D["action_bal_acc"])
            base_ba_task.append(p0_0["task_bal_acc"])
            drop_ba_task.append(p0_D["task_bal_acc"])
        if base_ba_action:
            mrs_action[drop] = float(np.mean(np.array(base_ba_action) - np.array(drop_ba_action)))
            mrs_task[drop]   = float(np.mean(np.array(base_ba_task)   - np.array(drop_ba_task)))

    summary["MRS"] = {
        "action": mrs_action,
        "task": mrs_task,
    }
    print("\nMRS (Modality-Robustness Score, bal-acc drop under P0):")
    print("  Action:", mrs_action)
    print("  Task  :", mrs_task)

    # --- P2 trade-off curves (safety vs missed intent) averaged across folds ---
    p2_curves_summary = {}
    for scen_name in SCENARIOS.keys():
        sae_mat = []
        mir_mat = []
        thresholds = None
        for r in all_results["all"]:
            s = r.get("safety", None)
            if s is None:
                continue
            scen_dict = s["scenarios"].get(scen_name, None)
            if scen_dict is None:
                continue
            curve = scen_dict.get("P2_curve", None)
            if curve is None:
                continue
            th = curve["thresholds"]
            sae = curve["SAE"]
            mir = curve["MIR"]
            if thresholds is None:
                thresholds = th
            sae_mat.append(sae)
            mir_mat.append(mir)
        if thresholds is not None and sae_mat:
            sae_arr = np.array(sae_mat)
            mir_arr = np.array(mir_mat)
            p2_curves_summary[scen_name] = {
                "thresholds": thresholds,
                "SAE_mean": sae_arr.mean(axis=0).tolist(),
                "SAE_std": sae_arr.std(axis=0).tolist(),
                "MIR_mean": mir_arr.mean(axis=0).tolist(),
                "MIR_std": mir_arr.std(axis=0).tolist(),
            }

    summary["P2_tradeoff_curves"] = p2_curves_summary

    # --- Aggregate confusion matrices across folds (tri-modal, S0/P0) ---
    cm_action_agg = None
    cm_task_agg = None
    for r in all_results["all"]:
        cm_a = np.array(r["classic"]["cm_action"], dtype=int)
        cm_t = np.array(r["classic"]["cm_task"], dtype=int)
        if cm_action_agg is None:
            cm_action_agg = cm_a
            cm_task_agg = cm_t
        else:
            cm_action_agg += cm_a
            cm_task_agg += cm_t

    if cm_action_agg is not None:
        summary["confusion_matrices"]["S0P0_all"] = {
            "action_cm": cm_action_agg.tolist(),
            "task_cm": cm_task_agg.tolist(),
        }

    # --- Create table & fig-data CSVs ---
    tables_dir = CFG.DATASET_DIR / "phase6_tables"
    figdata_dir = CFG.DATASET_DIR / "phase6_figdata"
    tables_dir.mkdir(exist_ok=True, parents=True)
    figdata_dir.mkdir(exist_ok=True, parents=True)

    # Table 1 – Main LOSO performance (S0 / P0, all ablations)
    table1_path = tables_dir / "phase6_table1_loso_main_metrics.csv"
    with open(table1_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "ablation",
            "action_acc_mean", "action_acc_std",
            "action_macro_f1_mean", "action_macro_f1_std",
            "action_micro_f1_mean", "action_micro_f1_std",
            "action_bal_acc_mean", "action_bal_acc_std",
            "action_auc_roc_mean", "action_auc_roc_std",
            "action_auprc_mean", "action_auprc_std",
            "task_acc_mean", "task_acc_std",
            "task_macro_f1_mean", "task_macro_f1_std",
            "task_micro_f1_mean", "task_micro_f1_std",
            "task_bal_acc_mean", "task_bal_acc_std",
            "task_top1_mean", "task_top1_std",
            "task_top3_mean", "task_top3_std",
            "MSG_action_bal_acc", "MSG_task_bal_acc",
        ])
        for row in table1_rows:
            # Extend with MSG columns (blank for ablation rows)
            w.writerow(row + ["", ""])
        # MSG row at bottom
        w.writerow([
            "MSG", "", "", "", "", "", "", "", "",
            "", "", "", "", "", "", "", "",
            "", "", "", "", "", "", "",
            msg_action if msg_action is not None else "",
            msg_task if msg_task is not None else "",
        ])

    # Table 2 – Safety metrics per scenario & policy
    table2_path = tables_dir / "phase6_table2_safety_policies.csv"
    with open(table2_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "scenario", "policy",
            "SAE_mean", "SAE_std",
            "UAR_mean", "UAR_std",
            "MIR_mean", "MIR_std",
            "NRA_mean", "NRA_std",
        ])
        for scen_name, scen_entry in safety_summary.items():
            for policy, vals in scen_entry["policies"].items():
                w.writerow([
                    scen_name, policy,
                    vals["SAE_mean"], vals["SAE_std"],
                    vals["UAR_mean"], vals["UAR_std"],
                    vals["MIR_mean"], vals["MIR_std"],
                    vals["NRA_mean"], vals["NRA_std"],
                ])

    # Table 3 – MRS + GMUE
    table3_path = tables_dir / "phase6_table3_mrs_gmue.csv"
    gmue_mean = summary.get("GMUE", {}).get("mean", None)
    gmue_std  = summary.get("GMUE", {}).get("std", None)
    drop2mod = {"drop_eeg": "EEG", "drop_emg": "EMG", "drop_et": "ET"}
    with open(table3_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "modality_dropped",
            "MRS_action_bal_acc_drop",
            "MRS_task_bal_acc_drop",
            "GMUE_mean_S0",
            "GMUE_std_S0",
        ])
        for drop_key, m_action in mrs_action.items():
            m_task = mrs_task.get(drop_key, None)
            w.writerow([
                drop2mod.get(drop_key, drop_key),
                m_action,
                m_task,
                gmue_mean,
                gmue_std,
            ])

    # Table A2 – Per-task metrics (tri-modal, S0/P0)
    tableA2_path = tables_dir / "phase6_tableA2_per_task_metrics_all.csv"
    # Use task2idx mapping from first tri-modal result
    idx2task = {}
    if all_results["all"]:
        task2idx = all_results["all"][0]["task2idx"]
        idx2task = {idx: code for code, idx in task2idx.items()}

    per_class_agg = {}
    for r in all_results["all"]:
        per_class = r["classic"]["per_class_task"]
        for cls_idx, m in per_class.items():
            cls_idx = int(cls_idx)
            if cls_idx not in per_class_agg:
                per_class_agg[cls_idx] = {
                    "precision": [], "recall": [], "f1": [],
                    "balanced_accuracy": [], "support": 0,
                }
            per_class_agg[cls_idx]["precision"].append(m["precision"])
            per_class_agg[cls_idx]["recall"].append(m["recall"])
            per_class_agg[cls_idx]["f1"].append(m["f1"])
            per_class_agg[cls_idx]["balanced_accuracy"].append(m["balanced_accuracy"])
            per_class_agg[cls_idx]["support"] += m["support"]

    with open(tableA2_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "class_index", "task_code",
            "precision_mean", "recall_mean",
            "f1_mean", "balanced_accuracy_mean",
            "support_total",
        ])
        for cls_idx in sorted(per_class_agg.keys()):
            agg = per_class_agg[cls_idx]
            prec_mean = float(np.mean(agg["precision"])) if agg["precision"] else 0.0
            rec_mean  = float(np.mean(agg["recall"])) if agg["recall"] else 0.0
            f1_mean   = float(np.mean(agg["f1"])) if agg["f1"] else 0.0
            bal_mean  = float(np.mean(agg["balanced_accuracy"])) if agg["balanced_accuracy"] else 0.0
            support   = int(agg["support"])
            task_code = idx2task.get(cls_idx, None)
            w.writerow([
                cls_idx, task_code,
                prec_mean, rec_mean,
                f1_mean, bal_mean,
                support,
            ])

    # Table A3 – Action & task bal-acc under S0–S3 (tri-modal, P0)
    tableA3_path = tables_dir / "phase6_tableA3_bal_acc_per_scenario_P0.csv"
    # Table Y2 – Safety error breakdown for S0 / P2 (NEW)
    tableY2_path = tables_dir / "phase6_tableY2_error_breakdown_S0P2.csv"
    total_counts = {
        "total": 0,
        "total_action": 0,
        "total_rest": 0,
        "correct_move": 0,
        "unsafe_move": 0,
        "missed_intent": 0,
        "safe_idle": 0,
    }
    for r in all_results["all"]:
        s = r.get("safety", None)
        if s is None:
            continue
        scen0 = s["scenarios"].get("S0", None)
        if scen0 is None:
            continue
        p2 = scen0["policies"].get("P2", None)
        if p2 is None:
            continue
        counts = p2["safety"]["counts"]
        for k in total_counts.keys():
            if k in counts:
                total_counts[k] += counts[k]

    with open(tableY2_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "scenario",
            "policy",
            "total",
            "total_action",
            "total_rest",
            "correct_move",
            "unsafe_move",
            "missed_intent",
            "safe_idle",
        ])
        w.writerow([
            "S0",
            "P2",
            total_counts["total"],
            total_counts["total_action"],
            total_counts["total_rest"],
            total_counts["correct_move"],
            total_counts["unsafe_move"],
            total_counts["missed_intent"],
            total_counts["safe_idle"],
        ])
    print(f"Safety error breakdown table saved to: {tableY2_path}")

    with open(tableA3_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "scenario",
            "action_bal_acc_mean", "action_bal_acc_std",
            "task_bal_acc_mean", "task_bal_acc_std",
        ])
        for scen_name in ["S0", "S1", "S2", "S3"]:
            act_ba = []
            task_ba = []
            for r in all_results["all"]:
                s = r.get("safety", None)
                if s is None:
                    continue
                scen = s["scenarios"].get(scen_name, None)
                if scen is None:
                    continue
                classic = scen["policies"]["P0"]["classic"]
                act_ba.append(classic["action_bal_acc"])
                task_ba.append(classic["task_bal_acc"])
            if act_ba:
                w.writerow([
                    scen_name,
                    float(np.mean(act_ba)), float(np.std(act_ba)),
                    float(np.mean(task_ba)), float(np.std(task_ba)),
                ])

    # Fig. 6 – P2 trade-off curves (one CSV per scenario)
    for scen_name, curve in summary["P2_tradeoff_curves"].items():
        path = figdata_dir / f"phase6_fig6_P2_tradeoff_{scen_name}.csv"
        with open(path, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["threshold", "SAE_mean", "SAE_std", "MIR_mean", "MIR_std"])
            for th, sae_m, sae_s, mir_m, mir_s in zip(
                curve["thresholds"],
                curve["SAE_mean"], curve["SAE_std"],
                curve["MIR_mean"], curve["MIR_std"],
            ):
                w.writerow([th, sae_m, sae_s, mir_m, mir_s])

    # Fig. 7 – Confusion matrices (aggregated across folds)
    if cm_action_agg is not None:
        np.savetxt(
            figdata_dir / "phase6_fig7_cm_action_S0P0_all.csv",
            cm_action_agg, fmt="%d", delimiter=","
        )
        np.savetxt(
            figdata_dir / "phase6_fig7_cm_task_S0P0_all.csv",
            cm_task_agg, fmt="%d", delimiter=","
        )

    # Fig. 8 – Gating distribution across modalities (per scenario)
    fig8_path = figdata_dir / "phase6_fig8_gating_distribution.csv"
    with open(fig8_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "scenario",
            "EEG_mean_gate", "EMG_mean_gate", "ET_mean_gate",
            "EEG_std_gate", "EMG_std_gate", "ET_std_gate",
        ])
        for scen_name, gvals in summary["gating"].items():
            mg = gvals["mean_gate"]
            sg = gvals["std_gate"]
            w.writerow([
                scen_name,
                mg[0], mg[1], mg[2],
                sg[0], sg[1], sg[2],
            ])
    # Table Y – Runtime / latency of TriSaFe-Trans (NEW)
    try:
        runtime_rows = benchmark_runtime_latency(eeg_ch, emg_ch, et_ch, folds)
        if runtime_rows:
            tableY_path = tables_dir / "phase6_tableY_runtime_latency.csv"
            with open(tableY_path, "w", newline="") as f:
                w = csv.writer(f)
                w.writerow([
                    "model",
                    "scenario",
                    "params_m",
                    "flops_per_window",
                    "latency_ms_gpu",
                    "control_rate_hz_gpu",
                    "latency_ms_cpu",
                    "control_rate_hz_cpu",
                ])
                for r in runtime_rows:
                    w.writerow([
                        r["model"],
                        r["scenario"],
                        f"{r['params_m']:.3f}",
                        r["flops_per_window"],
                        "" if r["latency_ms_gpu"] == "" else f"{r['latency_ms_gpu']:.2f}",
                        "" if r["control_rate_hz_gpu"] == "" else f"{r['control_rate_hz_gpu']:.1f}",
                        f"{r['latency_ms_cpu']:.2f}",
                        f"{r['control_rate_hz_cpu']:.1f}",
                    ])
            print(f"Runtime / latency table saved to: {tableY_path}")
    except Exception as e:
        print(f"[Runtime] Skipping runtime table due to error: {e}")

    # --- Save main JSON summary ---
    out_path = CFG.DATASET_DIR / "phase6_biorob_safety_summary.json"
    with open(out_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\nSaved BioRob Phase-6 safety summary to: {out_path}")
    print(f"Table CSVs saved to: {tables_dir}")
    print(f"Fig-data CSVs saved to: {figdata_dir}")


if __name__ == "__main__":
    # Script entry point: run the Phase-6 pipeline only when this file is
    # executed directly, so importing the module does not start training.
    run_biorob_phase6()

