# In [None]:  (notebook cell marker, commented out so the file parses as Python)
# =============================================================================
# BioRob Phase-2 Labeler — HSMM (Single-Action)
# =============================================================================

from __future__ import annotations
import argparse, re, glob, json, warnings
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.signal import butter, sosfiltfilt
from scipy.stats import median_abs_deviation, norm
from matplotlib import gridspec
from scipy import signal as _sig

warnings.filterwarnings('ignore')

# =========================== USER PATHS ===========================
# Machine-specific locations of the synchronized recordings.
ROOT_DIR   = r"/home/tsultan1/BioRob(Final)/Data"
SYNC_SUBPATH = r"cleaned/synchronized_proper_lite_union_v3"
DEFAULT_SUBJECT_SYNC_DIR = Path(
    r"/home/tsultan1/BioRob(Final)/Data/Sub-23/cleaned/synchronized_proper_lite_union_v3"
)
# Glob pattern for the per-trial CSV files.
FILE_GLOB = "*_synchronized_corrected.csv"
# Fallback sampling rate (Hz) returned by median_fs() when the timestamp
# spacing is not finite or not positive.
FS_HINT   = 250.0

# Plot switches. ADD_SPECTROGRAM adds one extra row to plot_results_extended().
SAVE_PLOTS_STANDARD  = False
SAVE_PLOTS_EXTENDED  = False
ADD_SPECTROGRAM      = False

# When FORCE_MODE is 'physical' or 'imagery' it overrides every other mode
# decision (see choose_mode); any other value lets MODE_POLICY decide.
FORCE_MODE = 'physical'   # 'physical'|'imagery'|None
# 'filename' -> use the T/M token in the file stem, 'infer' -> use the signals,
# 'auto' -> filename token first, signal inference as fallback.
MODE_POLICY = 'auto'      # 'auto'|'filename'|'infer'
# When True, _fit_two_gaussians() tries a 2-component sklearn GaussianMixture
# before falling back to the histogram-based split.
USE_SKLEARN_GMM = False

# =============================== BASE PARAMS =========================
# Central tuning table. Keys ending in `_s` are durations in seconds; the
# `smooth_win_*` keys are window lengths in samples.
BASE = dict(
    # moving-average window (samples) applied to the fused score, chosen per mode
    smooth_win_default = 75,     # ~300 ms @ 250 Hz
    smooth_win_physical = 50,
    smooth_win_imagery  = 125,

    # below this eye-validity rate the eye weight is multiplied by 0.10
    # (see adaptive_weights)
    eye_valid_min   = 0.60,

    # (low, high) action durations in seconds: used as the 2.5 % / 97.5 %
    # quantiles of the log-normal duration prior and as the HSMM duration bounds
    target_dur_physical = (0.8, 8.0),
    target_dur_imagery  = (0.5, 6.0),

    # NOTE(review): not referenced in the visible part of this file
    prob_scale_frac = 0.25,

    # boundary refiners (pre HSMM result)
    # read by refine_boundaries(): search spans around the onset/offset,
    # z-score thresholds, and the minimum time a threshold crossing must hold
    phys_backseek_s   = 0.75,  phys_forward_s = 0.75,
    phys_on_grad_thr  = 0.60,  phys_on_env_thr  = 0.30,
    phys_off_env_thr  = 0.25,  phys_min_sustain_s = 0.10,

    imag_backseek_s   = 1.00,  imag_forward_s = 0.75,
    imag_on_slope_thr = -0.40, imag_on_eeg_thr = -0.20,
    imag_off_abs_thr  = 0.20,  imag_min_sustain_s = 0.12,
)

# Defaults for refine_active_mask(): active runs shorter than MIN_BOUT_S are
# dropped, rest gaps of at most MAX_GAP_S between two active runs are filled.
MIN_BOUT_S = 0.20
MAX_GAP_S  = 0.10

# =================== SHIELDS & HARD CLAMPS ===================
# Lengths (s) of the head / tail slices used as baselines in compute_rest_shields().
PRE_BASELINE_S      = 2.0
POST_BASELINE_S     = 1.25

# Thresholds are median + K * MAD of the respective baseline slice.
ON_Z_K              = 2.1
OFF_Z_K             = 1.8
# How long (s) a threshold condition must hold to count as activation / rest.
MIN_SUSTAIN_ON_S    = 0.18
MIN_SUSTAIN_OFF_S   = 0.25
# Offsets (s) applied to the detected crossing: subtracted at onset, added at offset.
BACKOFF_ON_S        = 0.12
BACKOFF_OFF_S       = 0.06

# ---- HARD RULE WINDOWS ----
PRE_REST_MIN_S = 3.0     # earliest allowed end of pre-rest
PRE_REST_MAX_S = 4.0     # latest  allowed end of pre-rest
POST_REST_MIN_S = 1.0    # nominal min end-rest duration (can relax to 0.5s)
POST_REST_MAX_S = 4.0    # max end-rest duration

# ---- Tail quiet-window recovery (EMG-driven) ----
# Consumed by extend_action_tail_emg_quiet().
TAIL_EXTEND_MAX_S        = 2.5   # allow extending up to this much after HSMM offset
TAIL_QUIET_S             = 0.30  # quiet must persist this long
TAIL_BACKOFF_S           = 0.06  # step back a hair before the quiet starts
REST_K                   = 1.3   # rest envelope threshold = median + K*MAD (pre-rest baseline)
GRAD_QUIET_THR           = 0.45  # EMG gradient (z) must be below this to call it quiet
# NOTE(review): TAIL_OVERRIDE_MIN_REST_S is not referenced in the visible part of this file.
TAIL_OVERRIDE_MIN_REST_S = 0.5   # if action pushes near end, relax min tail-rest to 0.5s

# =============================== HELPERS =============================
def mad(x):
    """Return the median absolute deviation of ``x`` scaled by 1.4826.

    NaN entries are skipped (``nan_policy='omit'``).
    """
    raw_mad = median_abs_deviation(x, nan_policy='omit')
    return 1.4826 * raw_mad

def robust_z(x):
    """Return a median/MAD z-score of ``x``.

    The input is centred on its NaN-aware median and divided by the scaled
    MAD from :func:`mad`. If that scale is non-finite or below 1e-12, an
    all-zero array shaped like the input is returned instead.
    """
    arr = np.asarray(x)
    center = np.nanmedian(arr)
    scale = mad(arr)
    if scale < 1e-12 or not np.isfinite(scale):
        return np.zeros_like(arr)
    return (arr - center) / scale

def rolling_mean(x, n):
    """Return a centred moving average of ``x`` over ``int(n)`` samples.

    For ``n <= 1`` the input is returned as an array without smoothing.
    Edges are handled by ``np.convolve(..., mode='same')``, i.e. samples
    outside the signal count as zeros.
    """
    values = np.asarray(x)
    if n <= 1:
        return values
    width = int(n)
    kernel = np.ones(width, float) / max(1, width)
    return np.convolve(values, kernel, mode='same')

def rolling_rms(x, fs, win_ms):
    """Return the moving root-mean-square of ``x``.

    The window is ``win_ms`` milliseconds at sampling rate ``fs``, truncated
    to whole samples with a minimum of one. A one-sample window returns
    ``|x|``. Otherwise NaNs are replaced by zero before squaring, and the
    mean is taken with a centred ``'same'`` convolution.
    """
    width = max(1, int(win_ms/1000*fs))
    if width == 1:
        return np.abs(x)
    power = np.square(np.nan_to_num(x))
    kernel = np.ones(width)/width
    return np.sqrt(np.convolve(power, kernel, 'same'))

def median_fs(t):
    """Estimate the sampling rate as the inverse of the median timestamp step.

    Returns ``FS_HINT`` when the median step is not finite or not positive.
    """
    step = np.median(np.diff(t))
    if np.isfinite(step) and step > 0:
        return 1.0/step
    return FS_HINT

def _first_true_in_window(mask: np.ndarray, lo: int, hi: int) -> Optional[int]:
    lo = max(0, lo); hi = min(len(mask)-1, hi)
    idx = np.flatnonzero(mask[lo:hi+1])
    return (lo + int(idx[0])) if idx.size else None

def _last_true_in_window(mask: np.ndarray, lo: int, hi: int) -> Optional[int]:
    lo = max(0, lo); hi = min(len(mask)-1, hi)
    idx = np.flatnonzero(mask[lo:hi+1])
    return (lo + int(idx[-1])) if idx.size else None

import re as _re
def canonicalize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return ``df`` with its column labels mapped to canonical names.

    Each label has BOM characters removed, whitespace collapsed and the ends
    stripped. It is then mapped as follows (first match wins):

    * the three time columns get fixed spellings (case-insensitive match),
    * ``Ch<N> EMG RAW`` / ``Ch<N> EMG`` become ``EMG_Ch<N>``,
    * ``Ch<N>`` with 1 <= N <= 32 becomes ``EEG_Ch<N>``,
    * labels starting with ``et_`` (any case) get an ``ET_`` prefix and the
      lower-case x/y suffix on the four gaze columns,
    * anything else keeps its cleaned label.
    """
    time_names = {
        "timestamp_seconds": "Timestamp_seconds",
        "timestamp_ms": "Timestamp_ms",
        "et_timesignal": "ET_TimeSignal",
    }
    emg_patterns = (r'(?i)^ch\s*(\d+)\s*emg\s*raw$', r'(?i)^ch\s*(\d+)\s*emg$')
    gaze_fixes = (("GazeLeftX", "GazeLeftx"), ("GazeLeftY", "GazeLefty"),
                  ("GazeRightX", "GazeRightx"), ("GazeRightY", "GazeRighty"))

    def _canonical(raw):
        label = _re.sub(r"\s+", " ", str(raw).replace("\ufeff", "")).strip()
        lowered = label.lower()
        if lowered in time_names:
            return time_names[lowered]
        for pattern in emg_patterns:
            hit = _re.match(pattern, label)
            if hit:
                return f"EMG_Ch{int(hit.group(1))}"
        hit = _re.match(r'(?i)^ch\s*(\d+)$', label)
        if hit:
            channel = int(hit.group(1))
            if 1 <= channel <= 32:
                return f"EEG_Ch{channel}"
        if lowered.startswith("et_"):
            name = label if label.startswith("ET_") else ("ET_" + label[3:])
            for old, new in gaze_fixes:
                name = name.replace(old, new)
            return name
        return label

    return df.rename(columns={raw: _canonical(raw) for raw in df.columns})

# ======================= FEATURE EXTRACTION ==========================
def _et_signal_audit(df: pd.DataFrame) -> dict:
    groups = {
        "pupil": ["ET_PupilLeft","ET_PupilRight"],
        "gaze_xy": ["ET_GazeLeftx","ET_GazeLefty","ET_GazeRightx","ET_GazeRighty"],
        "valid": ["ET_ValidityLeftEye","ET_ValidityRightEye"],
        "imu": ["ET_GyroX","ET_GyroY","ET_GyroZ","ET_AccX","ET_AccY","ET_AccZ",
                "ET_HeadRotationPitch","ET_HeadRotationYaw","ET_HeadRotationRoll"],
        "events": ["ET_Blink","ET_Fixation","ET_Worn"],
        "fallback_axis": ["ET_Gaze3dOpticalAxisXLeft","ET_Gaze3dOpticalAxisYLeft","ET_Gaze3dOpticalAxisZLeft"]
    }
    return {k: [c for c in v if c in df.columns] for k, v in groups.items()}

def _eye_activity(df, fs):
    if {'ET_GazeLeftx','ET_GazeLefty'}.issubset(df.columns):
        gx = df['ET_GazeLeftx'].astype(float).ffill().bfill().to_numpy()
        gy = df['ET_GazeLefty'].astype(float).ffill().bfill().to_numpy()
        vx = np.diff(gx, prepend=gx[0]) * fs
        vy = np.diff(gy, prepend=gy[0]) * fs
        spd = np.sqrt(vx**2 + vy**2)
        return robust_z(rolling_rms(spd, fs, 50))
    cols = ['ET_Gaze3dOpticalAxisXLeft','ET_Gaze3dOpticalAxisYLeft','ET_Gaze3dOpticalAxisZLeft']
    if set(cols).issubset(df.columns):
        X = df[cols].astype(float).ffill().bfill().to_numpy()
        nrm = np.linalg.norm(X, axis=1, keepdims=True) + 1e-12
        V = X / nrm
        dots = np.clip(np.sum(V[1:] * V[:-1], axis=1), -1.0, 1.0)
        dtheta = np.arccos(dots)
        ang_spd = np.r_[0.0, dtheta * fs]
        return robust_z(rolling_rms(ang_spd, fs, 50))
    if 'ET_Blink' in df.columns:
        b = df['ET_Blink'].astype(float).fillna(0.0).to_numpy()
        feat = rolling_rms(np.diff(b, prepend=b[0]) * fs, fs, 100)
        return robust_z(feat)
    return np.zeros(len(df), float)

def _eye_valid_rate(df) -> float:
    if {'ET_ValidityLeftEye','ET_ValidityRightEye'}.issubset(df.columns):
        vl = (df['ET_ValidityLeftEye'].values == 0)
        vr = (df['ET_ValidityRightEye'].values == 0)
        return float(np.mean(vl & vr))
    return 0.4 if any(c.startswith('ET_') for c in df.columns) else 0.0

def extract_features(df, fs):
    """Compute per-sample EEG, EMG and eye features plus reliability scores.

    Parameters
    ----------
    df : DataFrame with canonical column names. Must contain
        ``Timestamp_seconds``; ``EEG_Ch1..8``, ``EMG_Ch1..4`` and the ``ET_*``
        columns are optional.
    fs : sampling rate in Hz, used for filter design and window lengths.

    Returns
    -------
    dict with keys ``t``, ``eeg``, ``eeg_slope``, ``emg``, ``emg_grad``,
    ``eye`` (arrays, one value per row), ``R`` (dict of per-modality
    reliability scores) and ``eye_valid`` (float). A modality whose columns
    are absent yields all-zero arrays.
    """
    t = df['Timestamp_seconds'].to_numpy()
    nyq = fs/2

    # ---- EEG: slow-band slope plus inverted beta-band power, averaged over channels
    eeg_cols = [c for c in [f'EEG_Ch{i}' for i in range(1,9)] if c in df.columns]
    if eeg_cols:
        # Band edges are normalised by Nyquist: 0.05-3 Hz and 13-30 Hz.
        sos_mrcp  = butter(4, [0.05/nyq, 3.0/nyq], btype='band', output='sos')
        sos_beta  = butter(4, [13/nyq, 30/nyq], btype='band', output='sos')
        eeg_feats = []; eeg_slope_feats=[]
        for ch in eeg_cols:
            x = df[ch].astype(float).ffill().bfill().to_numpy()
            mrcp = sosfiltfilt(sos_mrcp, x)
            # Sign is flipped, so a falling slow-band signal gives a positive slope feature.
            mrcp_slope = -np.gradient(mrcp, 1/fs)
            eeg_slope_feats.append(robust_z(mrcp_slope))
            beta = sosfiltfilt(sos_beta, x)
            beta_power = rolling_rms(beta, fs, 200)
            # 70 % slope feature + 30 % negated beta power (a power drop raises the score).
            eeg_feats.append(0.7*eeg_slope_feats[-1] + 0.3*(-robust_z(beta_power)))
        eeg = np.mean(eeg_feats, axis=0)
        eeg_slope = np.mean(eeg_slope_feats, axis=0)
    else:
        eeg = eeg_slope = np.zeros_like(t)

    # ---- EMG: per-channel 50 ms RMS envelope (z-scored), then channel mean
    emg_cols = [c for c in [f'EMG_Ch{i}' for i in range(1,5)] if c in df.columns]
    if emg_cols:
        envs = []; raws = []
        for ch in emg_cols:
            x = df[ch].astype(float).fillna(0.0).to_numpy()
            raws.append(x); envs.append(robust_z(rolling_rms(x, fs, 50)))
        emg = np.mean(envs, axis=0)
        emg_raw_mean = np.mean(raws, axis=0)
        # Gradient feature: 30 ms RMS of the time derivative of the channel-mean raw EMG.
        emg_grad = robust_z(rolling_rms(np.gradient(emg_raw_mean, 1/fs), fs, 30))
    else:
        emg = emg_grad = np.zeros_like(t)

    eye = _eye_activity(df, fs)
    eye_valid = _eye_valid_rate(df)

    def rel_from_signal(sig):
        # Reliability = (95th percentile - median) / (MAD + 1e-9), floored at 0;
        # empty or non-finite results map to 0.
        if sig.size == 0: return 0.0
        spread = float(np.nanpercentile(sig,95) - np.nanpercentile(sig,50))
        denom = mad(sig) + 1e-9
        val = max(0.0, spread/denom)
        return val if np.isfinite(val) else 0.0

    # EEG reliability is measured on the slope feature, not the combined EEG
    # score; eye reliability is scaled by the validity rate.
    R = dict(emg=rel_from_signal(emg),
             eeg=rel_from_signal(eeg_slope),
             eye=rel_from_signal(eye) * eye_valid)

    return dict(t=t, eeg=eeg, eeg_slope=eeg_slope, emg=emg, emg_grad=emg_grad,
                eye=eye, R=R, eye_valid=eye_valid)

# ======================= MODE & FUSION ===============================
def mode_from_filename(stem: str) -> Optional[str]:
    """Derive the trial mode from a ``T<digits>`` / ``M<digits>`` token in ``stem``.

    The token must sit at the start of the stem or right after an underscore
    and carry at least two digits (case-insensitive). ``T`` maps to
    ``'physical'``, ``M`` to ``'imagery'``; ``None`` when no token is found.
    """
    hit = re.search(r'(?:^|_)(T|M)(\d{2,})', stem, flags=re.IGNORECASE)
    if hit is None:
        return None
    if hit.group(1).upper() == 'T':
        return 'physical'
    return 'imagery'


def infer_mode_from_signals(feats) -> str:
    """Guess ``'physical'`` or ``'imagery'`` from the EMG and EEG-slope features.

    The score is 0.7 * (EMG 95th percentile - 1.5 * EMG median) minus
    0.3 * median EEG slope over the first 60 % of samples (at least 5).
    A score above 0.4 means ``'physical'``; anything else means ``'imagery'``.
    """
    emg_env = feats['emg']
    slope = feats['eeg_slope']
    emg_p95 = np.nanpercentile(emg_env, 95)
    emg_med = np.nanmedian(emg_env)
    head_len = max(5, int(0.6*len(slope)))
    eeg_neg = np.nanmedian(slope[:head_len])
    score = 0.7*(emg_p95 - 1.5*emg_med) - 0.3*(eeg_neg)
    if score > 0.4:
        return 'physical'
    return 'imagery'

def choose_mode(stem: str, feats) -> Tuple[str, str]:
    """Pick the trial mode and report where the decision came from.

    Returns ``(mode, source)`` where ``source`` is ``'forced'``,
    ``'filename'`` or ``'infer'``:

    * ``FORCE_MODE`` set to a valid mode always wins.
    * ``MODE_POLICY == 'infer'`` uses signal inference only.
    * ``MODE_POLICY == 'filename'`` uses the filename token and defaults to
      ``'imagery'`` (reported with source ``'infer'``) when there is none.
    * Any other policy tries the filename token, then signal inference.
    """
    if FORCE_MODE in ('physical','imagery'):
        return FORCE_MODE, 'forced'
    if MODE_POLICY == 'infer':
        return infer_mode_from_signals(feats), 'infer'

    from_name = mode_from_filename(stem)
    if MODE_POLICY == 'filename':
        if from_name:
            return from_name, 'filename'
        return 'imagery', 'infer'

    if from_name:
        return from_name, 'filename'
    return infer_mode_from_signals(feats), 'infer'

def adaptive_weights(mode, R, eye_valid, eye_valid_min):
    """Return normalised fusion weights for the EMG, EEG and eye modalities.

    Parameters
    ----------
    mode : ``'physical'`` selects the EMG-heavy base weights; any other value
        selects the EEG-heavy (imagery) base weights.
    R : mapping of modality name to reliability score. Missing keys count as
        0.0 and scores are capped at 3.0.
    eye_valid : eye-tracker validity rate for the recording.
    eye_valid_min : validity rate below which the eye weight is cut to 10 %.

    Returns
    -------
    dict with keys ``emg``, ``eeg`` and ``eye`` whose values sum to 1
    (up to the 1e-12 regulariser).
    """
    if mode == 'physical':
        base = dict(emg=0.60, eeg=0.30, eye=0.10)
    else:
        base = dict(emg=0.15, eeg=0.75, eye=0.10)
    # Each base weight is scaled by a factor between 0.5 (R = 0) and 2.0 (R >= 3).
    w = {k: v*(0.5 + 0.5*min(3.0, R.get(k, 0.0))) for k, v in base.items()}
    # Honour the caller-supplied threshold; it was previously ignored in
    # favour of the module-level BASE['eye_valid_min'].
    if eye_valid < eye_valid_min:
        w['eye'] *= 0.10
    s = sum(w.values()) + 1e-12
    return {k: v/s for k, v in w.items()}

def fuse_score(feats, weights):
    """Return the weighted sum of the ``eeg``, ``emg`` and ``eye`` features."""
    eeg_part = weights['eeg']*feats['eeg']
    emg_part = weights['emg']*feats['emg']
    eye_part = weights['eye']*feats['eye']
    return (eeg_part + emg_part + eye_part)

# ======================= EMISSIONS & PRIOR ===========================
def _fit_two_gaussians(x):
    x = np.asarray(x); x = x[np.isfinite(x)]
    if x.size < 64:
        m = np.nanmedian(x); s = np.nanstd(x) + 1e-6
        return (m-0.5*s, 0.7*s), (m+0.7*s, 0.7*s)
    if USE_SKLEARN_GMM:
        try:
            from sklearn.mixture import GaussianMixture
            gm = GaussianMixture(n_components=2, covariance_type='spherical', random_state=0)
            gm.fit(x.reshape(-1,1))
            means = gm.means_.ravel()
            stds  = np.sqrt(gm.covariances_.ravel()) + 1e-9
            idx = np.argsort(means)
            return (float(means[idx[0]]), float(stds[idx[0]])), (float(means[idx[1]]), float(stds[idx[1]]))
        except Exception:
            pass
    hist, edges = np.histogram(x, bins=128)
    p = hist / max(1, hist.sum())
    omega = np.cumsum(p)
    centers = (edges[:-1] + edges[1:]) / 2.0
    mu = np.cumsum(p * centers); mu_t = mu[-1]
    sigma_b2 = (mu_t*omega - mu)**2 / (omega*(1.0 - omega) + 1e-12)
    idx = int(np.nanargmax(sigma_b2))
    thr = (edges[idx] + edges[idx+1]) / 2
    lo = x[x <= thr]; hi = x[x > thr]
    if lo.size < 8 or hi.size < 8:
        m = np.nanmedian(x); s = np.nanstd(x) + 1e-6
        return (m-0.5*s, 0.7*s), (m+0.7*s, 0.7*s)
    means = np.array([np.mean(lo), np.mean(hi)])
    stds  = np.array([np.std(lo)+1e-9, np.std(hi)+1e-9])
    order = np.argsort(means)
    return (float(means[order[0]]), float(stds[order[0]])), (float(means[order[1]]), float(stds[order[1]]))

def _logN(x, mu, s):
    v = (x - mu) / (s + 1e-12)
    return -0.5*(v**2) - np.log(s + 1e-12) - 0.5*np.log(2*np.pi)

def _lognorm_params_from_quantiles(q_lo, q_hi, p_lo=0.025, p_hi=0.975):
    a, b = np.log(q_lo+1e-9), np.log(q_hi+1e-9)
    z_lo, z_hi = norm.ppf(p_lo), norm.ppf(p_hi)
    sigma = (b - a) / max(1e-9, (z_hi - z_lo))
    mu = a - z_lo*sigma
    return mu, sigma

def duration_prior_log(d_range_s, mode, fs):
    """Return the log of a discretised log-normal duration prior.

    The target duration range for ``mode`` (from ``BASE``) is taken as the
    2.5 % / 97.5 % quantiles of a log-normal. Its density is evaluated at
    every duration in ``d_range_s`` (seconds), floored at 1e-12, and
    normalised to sum to 1 over the supplied durations. ``fs`` is accepted
    for interface symmetry and is not used.
    """
    range_key = 'target_dur_physical' if mode=='physical' else 'target_dur_imagery'
    lo, hi = BASE[range_key]
    mu, sig = _lognorm_params_from_quantiles(lo, hi, 0.025, 0.975)

    d_s = np.asarray(d_range_s)
    scale = 1.0/(d_s*sig*np.sqrt(2*np.pi)+1e-12)
    shape = np.exp(-(np.log(d_s+1e-12)-mu)**2/(2*sig*sig))
    pdf = np.clip(scale*shape, 1e-12, None)
    pdf /= np.sum(pdf)
    return np.log(pdf)

# ======================= HSMM SINGLE-SEGMENT =========================
def hsmm_single_segment(smooth, fs, mode):
    """Find the single best "action" segment in a smoothed score trace.

    A rest and an action Gaussian are fitted to ``smooth``. Each candidate
    segment ``[s, e)`` is scored by the summed log-likelihood ratio
    (action minus rest) over the segment plus the log duration prior of its
    length. The highest-scoring segment wins; ties keep the earliest end
    index and, for that end, the shortest duration.

    Parameters
    ----------
    smooth : 1-D score trace.
    fs : sampling rate in Hz, used to convert the duration range to samples.
    mode : ``'physical'`` or anything else (treated as imagery); selects the
        target duration range from ``BASE``.

    Returns
    -------
    ``(s_idx, e_idx, params, diag)``: inclusive start/end sample indices,
    the fitted Gaussian parameters, and diagnostic arrays (``delta``,
    ``csum``, ``log_p_d``, ``d_vals``).
    """
    x = np.asarray(smooth)
    (mu0, s0), (mu1, s1) = _fit_two_gaussians(x)
    ll0 = _logN(x, mu0, s0); ll1 = _logN(x, mu1, s1)
    # csum[k] is the summed log-likelihood ratio of the first k samples, so a
    # segment [s, e) scores csum[e] - csum[s].
    delta = ll1 - ll0; csum = np.cumsum(np.r_[0.0, delta])

    T = len(x)
    dur_lo_s, dur_hi_s = (BASE['target_dur_physical'] if mode == 'physical'
                          else BASE['target_dur_imagery'])
    d_max = int(min(T, dur_hi_s*fs)+1)
    d_min = int(max(1, dur_lo_s*fs))
    d_vals = np.arange(1, d_max+1, dtype=int); d_secs = d_vals / fs
    log_p_d = duration_prior_log(d_secs, mode, fs)

    best_score = -1e18; best_e = -1; best_d = -1
    for e in range(1, T+1):
        # Near the start of the trace fewer than d_min samples are available,
        # so the lower bound is relaxed to the samples seen so far.
        d_hi = min(e, d_max); d_lo = min(e, max(1, d_min))
        # All candidate durations for this end index are scored in one
        # vectorised step; this replaces the per-duration Python loop.
        d = np.arange(d_lo, d_hi+1)
        scores = (csum[e] - csum[e - d]) + log_p_d[d - 1]
        # NaN scores never won a strict '>' comparison in the scalar loop;
        # mapping them to -inf keeps that behaviour under argmax.
        scores = np.where(np.isnan(scores), -np.inf, scores)
        k = int(np.argmax(scores))  # first maximum -> shortest duration on ties
        if scores[k] > best_score:
            best_score = scores[k]; best_e = e; best_d = int(d[k])

    s_idx = max(0, best_e - best_d); e_idx = max(0, best_e - 1)
    params = dict(mu_rest=float(mu0), sd_rest=float(s0), mu_act=float(mu1), sd_act=float(s1))
    diag   = dict(delta=delta, csum=csum, log_p_d=log_p_d, d_vals=d_vals)
    return s_idx, e_idx, params, diag

# ======================= BOUNDARY REFINEMENT =========================
def refine_boundaries(feats, s_idx, e_idx, fs, mode):
    """Nudge the segment onset and offset using modality-specific evidence.

    Physical mode uses EMG: onset where the EMG gradient or envelope exceeds
    its threshold, offset where the envelope is still above the off
    threshold. Any other mode uses EEG: onset where slope and combined score
    are both below their thresholds, offset where ``|eeg|`` is above the off
    threshold. Each condition must hold for the mode's minimum sustain time.

    The onset moves to the first sustained hit in the back-seek window ending
    at ``s_idx``. The offset moves to the last sustained hit in the forward
    window starting at ``e_idx``. The result is clamped so that
    ``0 <= s_idx < e_idx <= T-1``.
    """
    T = len(feats['t'])

    def sustain_mask(mask, min_sec):
        width = max(1, int(min_sec*fs))
        if width == 1:
            return mask
        return np.convolve(mask.astype(int), np.ones(width, int), 'same') >= width

    if mode == 'physical':
        back = int(BASE['phys_backseek_s']*fs)
        fwd = int(BASE['phys_forward_s']*fs)
        hold_s = BASE['phys_min_sustain_s']
        emg_grad = feats['emg_grad']
        emg = feats['emg']
        onset_raw = (emg_grad > BASE['phys_on_grad_thr']) | (emg > BASE['phys_on_env_thr'])
        offset_raw = emg > BASE['phys_off_env_thr']
    else:
        back = int(BASE['imag_backseek_s']*fs)
        fwd = int(BASE['imag_forward_s']*fs)
        hold_s = BASE['imag_min_sustain_s']
        eeg_slope = feats['eeg_slope']
        eeg = feats['eeg']
        onset_raw = (eeg_slope < BASE['imag_on_slope_thr']) & (eeg < BASE['imag_on_eeg_thr'])
        offset_raw = np.abs(eeg) > BASE['imag_off_abs_thr']

    onset = _first_true_in_window(sustain_mask(onset_raw, hold_s), max(0, s_idx - back), s_idx)
    if onset is not None:
        s_idx = onset

    offset = _last_true_in_window(sustain_mask(offset_raw, hold_s), e_idx, min(T-1, e_idx + fwd))
    if offset is not None:
        e_idx = max(offset, s_idx+1)

    s_idx = int(max(0, min(s_idx, T-2)))
    e_idx = int(max(s_idx+1, min(e_idx, T-1)))
    return s_idx, e_idx

# ======================= SHIELDS (with clamps) =======================
def _sustain(mask: np.ndarray, fs: float, dur_s: float) -> np.ndarray:
    n = max(1, int(dur_s*fs))
    return np.convolve(mask.astype(int), np.ones(n, int), 'same') >= n

def compute_rest_shields(feats, fused, fs):
    """
    Learn head/tail baselines; detect earliest activation & latest rest.
    Then clamp: pre-rest ∈ [3,4] s; post-rest ∈ [1,4] s (nominal).

    Parameters
    ----------
    feats : feature dict; reads ``t``, ``emg`` and ``eye``.
    fused : fused score trace, robust-z'd here before mixing.
    fs : sampling rate in Hz.

    Returns
    -------
    ``(pre_end_idx, post_start_idx)`` as ints: the sample index where the
    leading rest period ends and the index where the trailing rest period
    starts. Returns ``(0, 0)`` for an empty recording.
    """
    T = len(feats['t'])
    if T == 0: return 0, 0

    # Detection signal: 60 % EMG envelope, 30 % z-scored fused score, 10 % eye activity.
    z_fused = robust_z(fused)
    z_combo = 0.6*feats['emg'] + 0.3*z_fused + 0.1*feats['eye']

    # ---- PRE (forward) in [3,4] s window
    pre_min_idx = int(min(T-1, PRE_REST_MIN_S * fs))
    pre_max_idx = int(min(T-1, PRE_REST_MAX_S * fs))

    # Baseline from the first PRE_BASELINE_S seconds; for recordings no longer
    # than that, fall back to the first sixth (at least 5 samples).
    base_n   = max(5, int(PRE_BASELINE_S * fs))
    base_pre = z_combo[:base_n] if base_n < T else z_combo[:max(5, T//6)]
    thr_on   = float(np.nanmedian(base_pre)) + ON_Z_K * float(mad(base_pre) + 1e-9)

    # Activation = 80 ms moving average above threshold, sustained for MIN_SUSTAIN_ON_S.
    ma = rolling_mean(z_combo, int(0.08 * fs))
    act = _sustain(ma > thr_on, fs, MIN_SUSTAIN_ON_S)

    first_cross = _first_true_in_window(act, lo=pre_min_idx, hi=pre_max_idx)
    if first_cross is None:
        pre_end_idx = int(min(T-1, round(3.5 * fs)))  # midpoint if no crossing
    else:
        pre_end_idx = first_cross - int(BACKOFF_ON_S * fs)

    # Hard clamp: the back-off may not push the boundary outside the allowed window.
    pre_end_idx = int(np.clip(pre_end_idx, pre_min_idx, pre_max_idx))

    # ---- POST (backward) yielding 1–4 s tail rest (nominal)
    post_earliest_idx = max(0, T - int(POST_REST_MAX_S * fs))  # ≥ T-4s
    post_latest_idx   = max(0, T - int(POST_REST_MIN_S * fs))  # ≤ T-1s

    # Tail baseline from the last POST_BASELINE_S seconds (at least 5 samples).
    tail_slice  = z_combo[max(0, T - max(5, int(POST_BASELINE_S*fs))):]
    thr_rest    = float(np.nanmedian(tail_slice)) + OFF_Z_K * float(mad(tail_slice) + 1e-9)

    # Rest = 100 ms moving average at or below threshold, sustained for MIN_SUSTAIN_OFF_S.
    ma2 = rolling_mean(z_combo, int(0.10 * fs))
    is_rest = _sustain(ma2 <= thr_rest, fs, MIN_SUSTAIN_OFF_S)

    # Search the reversed mask, i.e. walk backwards from the last sample, for
    # the first rest sample within the final POST_REST_MAX_S seconds.
    # NOTE(review): this locates the *latest* rest sample. When the recording
    # ends in rest the hit is the final sample and the clamps below decide the
    # result — confirm this is intended and that the start of the trailing
    # rest run is not what was meant.
    last_rest_rev = _first_true_in_window(is_rest[::-1], lo=0, hi=int(POST_REST_MAX_S * fs))
    if last_rest_rev is None:
        post_start_idx = post_latest_idx
    else:
        # Map the reversed index back to a forward index, then add the back-off.
        rr = T - 1 - (last_rest_rev - 1 if last_rest_rev > 0 else 0)
        post_start_idx = rr + int(BACKOFF_OFF_S * fs)

    # Hard clamp to [T - POST_REST_MAX_S, T - POST_REST_MIN_S] and to the array bounds.
    post_start_idx = max(post_start_idx, post_earliest_idx)
    post_start_idx = min(post_start_idx, post_latest_idx)
    post_start_idx = int(np.clip(post_start_idx, 0, T))

    # Short recordings can make the shields overlap; re-centre them so that a
    # nominal 0.5 s region remains between them.
    if pre_end_idx >= post_start_idx:
        mid = (pre_end_idx + post_start_idx)//2
        pre_end_idx = max(0, mid - int(0.25*fs))
        post_start_idx = min(T, pre_end_idx + int(0.50*fs))
    return int(pre_end_idx), int(post_start_idx)

# -------- Tail extension via EMG quiet-window (robust, no truncation) --------
def extend_action_tail_emg_quiet(feats, fs, s_idx, e_idx):
    """
    Push the action offset forward until the EMG goes quiet.

    A sample is "quiet" when the EMG envelope is at or below the pre-rest
    baseline threshold (median + REST_K*MAD) and the EMG gradient is at or
    below GRAD_QUIET_THR; the quiet state must be sustained for TAIL_QUIET_S
    seconds. The search runs from 50 ms after ``e_idx`` up to
    TAIL_EXTEND_MAX_S seconds after it.

    Returns the new offset index: TAIL_BACKOFF_S before the first sustained
    quiet sample (never earlier than ``e_idx``), or the end of the search
    window when no quiet is found. ``s_idx`` is accepted but not used.
    """
    n_samples = len(feats['t'])
    envelope = feats['emg']
    gradient = feats['emg_grad']

    # Baseline from the head of the recording: the shorter of PRE_REST_MIN_S
    # and 2 s, or the first sixth when the recording is no longer than that.
    head_len = int(max(1, min(PRE_REST_MIN_S*fs, 2.0*fs)))
    if head_len < n_samples:
        baseline = envelope[:head_len]
    else:
        baseline = envelope[:max(5, n_samples//6)]
    rest_thr = float(np.nanmedian(baseline)) + REST_K * float(mad(baseline) + 1e-9)

    is_quiet = (envelope <= rest_thr) & (gradient <= GRAD_QUIET_THR)
    sustained_quiet = _sustain(is_quiet, fs, TAIL_QUIET_S)

    last_idx = n_samples - 1
    window_lo = int(min(last_idx, e_idx + int(0.05*fs)))
    window_hi = int(min(last_idx, e_idx + int(TAIL_EXTEND_MAX_S*fs)))
    quiet_start = _first_true_in_window(sustained_quiet, window_lo, window_hi)

    if quiet_start is None:
        # No quiet stretch found: extend to the cap.
        return int(min(last_idx, window_hi))

    backed_off = quiet_start - int(TAIL_BACKOFF_S*fs)
    return int(min(last_idx, max(e_idx, backed_off)))

# ======================= FINALIZE HELPERS ============================
def refine_active_mask(mask: np.ndarray, fs: float,
                       min_bout_s: float = MIN_BOUT_S,
                       max_gap_s: float = MAX_GAP_S) -> np.ndarray:
    """Clean a binary activity mask in two passes.

    1. Active runs shorter than ``min_bout_s`` seconds are cleared.
    2. Rest gaps of at most ``max_gap_s`` seconds between two surviving
       active runs are filled.

    Returns a new ``uint8`` array; the input is not modified.
    """
    cleaned = np.asarray(mask, dtype=np.uint8).copy()
    if cleaned.size == 0:
        return cleaned
    min_len = max(1, int(round(min_bout_s * fs)))
    gap_len = max(1, int(round(max_gap_s  * fs)))

    def _active_runs(arr):
        # Half-open [start, stop) spans of consecutive ones.
        edges = np.diff(np.r_[0, arr, 0])
        return list(zip(np.where(edges == 1)[0], np.where(edges == -1)[0]))

    # Pass 1: drop bouts that are too short.
    for start, stop in _active_runs(cleaned):
        if (stop - start) < min_len:
            cleaned[start:stop] = 0

    # Pass 2: bridge short gaps between consecutive surviving bouts.
    survivors = _active_runs(cleaned)
    for (_, prev_stop), (next_start, _) in zip(survivors, survivors[1:]):
        if (next_start - prev_stop) <= gap_len:
            cleaned[prev_stop:next_start] = 1
    return cleaned

def segments_from_mask(mask: np.ndarray):
    """Return the half-open ``(start, stop)`` index pairs of every run of ones."""
    bits = np.asarray(mask, dtype=np.uint8)
    # Zero padding at both ends turns runs touching the array edges into
    # ordinary rising and falling edges.
    edges = np.diff(np.r_[0, bits, 0])
    rising = np.where(edges == 1)[0]
    falling = np.where(edges == -1)[0]
    return list(zip(rising, falling))

# ============================ PLOTTING ===============================
def plot_results_basic(df, feats, fused, start_s, end_s, smooth, hsmm_params, save_path=None, title_suffix=""):
    """Draw the standard 5-panel summary figure for one trial.

    Panels, top to bottom: fused/smoothed score with the decoded action span
    and the fitted rest/action means; per-modality features; raw EMG
    channels; up to three normalised EEG channels; the binary action label.

    Parameters
    ----------
    df : trial DataFrame (raw EMG/EEG columns and ``Timestamp_seconds``).
    feats : feature dict; reads ``t``, ``eeg``, ``emg`` and ``eye``.
    fused, smooth : fused score trace and its smoothed version.
    start_s, end_s : action onset/offset, on the same axis as ``feats['t']``.
    hsmm_params : dict providing ``mu_rest`` and ``mu_act``.
    save_path : if truthy, the figure is also written there at 300 dpi.
    title_suffix : text appended to the first panel's title.

    Returns
    -------
    The matplotlib Figure (not closed here).
    """
    t = feats['t']
    fig, axes = plt.subplots(5, 1, figsize=(16, 12))

    # Panel 0: fused score, decoded segment, and hatched rest regions on either side.
    axes[0].plot(t, fused, lw=2, alpha=0.9, label='Fused Score')
    axes[0].plot(t, smooth, alpha=0.35, label='Smoothed')
    axes[0].axvspan(start_s, end_s, color='purple', alpha=0.2, label='HSMM+Refined')
    axes[0].axhline(hsmm_params['mu_rest'], ls='--', alpha=0.6, label='μ_rest (unsup)')
    axes[0].axhline(hsmm_params['mu_act'],  ls='--', alpha=0.6, label='μ_action (unsup)')
    axes[0].axvspan(t[0], start_s, color='k', alpha=0.05, hatch='///', label='Pre-rest (learned)')
    axes[0].axvspan(end_s, t[-1], color='k', alpha=0.05, hatch='\\\\\\', label='Post-rest (learned)')
    axes[0].set_ylabel('Z-score'); axes[0].grid(True, alpha=0.3); axes[0].legend(loc='upper right')
    axes[0].set_title('ICML: HSMM Single-Action Decode ' + title_suffix, fontweight='bold')

    # Panel 1: the three per-modality feature traces.
    axes[1].plot(t, feats['eeg'], lw=1.5, label='EEG Activity')
    axes[1].plot(t, feats['emg'], lw=1.5, label='EMG Activity')
    axes[1].plot(t, feats['eye'], lw=1.5, label='EYE Activity')
    axes[1].axvspan(start_s, end_s, color='gray', alpha=0.15)
    axes[1].set_ylabel('Z-score'); axes[1].grid(True, alpha=0.3); axes[1].legend(loc='upper right')
    axes[1].set_title('Modality Contributions', fontweight='bold')

    # Panel 2: raw EMG channels (left empty when no EMG column exists).
    emg_cols = [c for c in ['EMG_Ch1','EMG_Ch2','EMG_Ch3','EMG_Ch4'] if c in df.columns]
    if emg_cols:
        for ch in emg_cols:
            axes[2].plot(df['Timestamp_seconds'], df[ch].to_numpy(), lw=0.8, alpha=0.85, label=ch)
        axes[2].axvspan(start_s, end_s, color='red', alpha=0.2)
        axes[2].set_ylabel('EMG (arb)'); axes[2].legend(ncol=min(2,len(emg_cols)), loc='upper right')
        axes[2].grid(True, alpha=0.3); axes[2].set_title('EMG Execution', fontweight='bold')

    # Panel 3: first three available EEG channels, standardised and stacked
    # with a vertical offset of 4 per channel.
    eeg_cols = [c for c in [f'EEG_Ch{i}' for i in range(1,9)] if c in df.columns][:3]
    if eeg_cols:
        for i, ch in enumerate(eeg_cols):
            x = df[ch].ffill().bfill().to_numpy()
            xn = (x - np.mean(x)) / (np.std(x)+1e-12)
            axes[3].plot(df['Timestamp_seconds'], xn + 4*i, lw=1.0, alpha=0.9, label=ch)
        axes[3].axvspan(start_s, end_s, color='blue', alpha=0.2)
        axes[3].set_ylabel('EEG (norm)'); axes[3].legend(loc='upper right')
        axes[3].grid(True, alpha=0.3); axes[3].set_title('EEG Preparation', fontweight='bold')

    # Panel 4: binary label, 1 inside [start_s, end_s] (both ends inclusive).
    lab = np.where((t>=start_s)&(t<=end_s),1,0)
    axes[4].step(t, lab, where='post', lw=3, label='Action Label')
    axes[4].set_yticks([0,0.5,1]); axes[4].set_yticklabels(['Rest','0.5','Action'])
    axes[4].set_ylim(-0.1,1.1); axes[4].grid(True, alpha=0.3)
    axes[4].set_xlabel('Time (s)'); axes[4].set_ylabel('Action / Prob.')
    axes[4].legend(loc='upper right')

    plt.tight_layout()
    if save_path: fig.savefig(save_path, dpi=300, bbox_inches='tight')
    return fig

def plot_results_extended(df, feats, fused, start_s, end_s, fs, smooth, hsmm_params, hsmm_diag,
                          pre_end_idx: int, post_start_idx: int,
                          save_path=None, title_suffix=""):
    """Draw the extended diagnostics figure (7 rows, 8 with ADD_SPECTROGRAM).

    Rows: decode overview, HSMM likelihood terms, EMG evidence, EEG evidence,
    eye evidence, modality reliability bars, duration prior, and optionally a
    1-40 Hz spectrogram of the first available EEG channel.

    Parameters
    ----------
    df : trial DataFrame (used only for the spectrogram row).
    feats : feature dict; reads ``t``, ``emg``, ``emg_grad``, ``eeg``,
        ``eeg_slope``, ``eye`` and ``R``.
    fused, smooth : fused score trace and its smoothed version.
    start_s, end_s : action onset/offset, on the same axis as ``feats['t']``.
    fs : sampling rate in Hz.
    hsmm_params : dict providing ``mu_rest`` and ``mu_act``.
    hsmm_diag : dict providing ``delta``, ``csum``, ``d_vals`` and ``log_p_d``.
    pre_end_idx, post_start_idx : sample indices bounding the region the
        likelihood terms are plotted against.
    save_path : if truthy, the figure is also written there at 300 dpi.
    title_suffix : text inserted into the first row's title.

    Returns
    -------
    The matplotlib Figure (not closed here).
    """
    t = feats['t']; n_rows = 7 + (1 if ADD_SPECTROGRAM else 0)
    fig = plt.figure(figsize=(16, 3*n_rows + 2))
    gs = gridspec.GridSpec(n_rows, 1, figure=fig, hspace=0.35); r = 0

    # NOTE(review): delta and csum[1:] are plotted against t_roi below, so their
    # length must equal post_start_idx - pre_end_idx. This presumably means the
    # HSMM was run on that slice — confirm against the caller.
    t_roi = t[pre_end_idx:post_start_idx]; delta = hsmm_diag['delta']; csum  = hsmm_diag['csum']

    # Row: decode overview.
    ax = fig.add_subplot(gs[r, 0]); r += 1
    ax.plot(t, fused, lw=2, label='Fused Score')
    ax.plot(t, smooth, alpha=0.35, label='Smoothed')
    ax.axvspan(start_s, end_s, color='purple', alpha=0.20, label='HSMM+Refined')
    ax.axhline(hsmm_params['mu_rest'],  ls='--', alpha=0.6, label='μ_rest')
    ax.axhline(hsmm_params['mu_act'],   ls='--', alpha=0.6, label='μ_act')
    ax.axvspan(t[0], start_s, color='k', alpha=0.05, hatch='///', label='Pre-rest')
    ax.axvspan(end_s, t[-1], color='k', alpha=0.05, hatch='\\\\\\', label='Post-rest')
    ax.set_ylabel('Z'); ax.set_title(f'HSMM Decode — {title_suffix}'); ax.grid(True, alpha=0.3); ax.legend(loc='upper right')

    # Row: per-sample log-likelihood ratio and its running sum (twin y-axis).
    ax = fig.add_subplot(gs[r, 0]); r += 1
    ax.plot(t_roi, delta, lw=1.1, label='Δ logL (act - rest)')
    ax2 = ax.twinx(); ax2.plot(t_roi, csum[1:], lw=1.1, alpha=0.6, label='cum Δ logL', linestyle='--')
    ax.axvspan(start_s, end_s, color='gray', alpha=0.12)
    ax.set_title('Likelihood Terms'); ax.grid(True, alpha=0.3)

    # Row: EMG envelope and gradient features.
    ax = fig.add_subplot(gs[r, 0]); r += 1
    ax.plot(t, feats['emg'], label='EMG envelope (z)', lw=1.5)
    ax.plot(t, feats['emg_grad'], label='EMG gradient (z)', lw=1.0, alpha=0.85)
    ax.axvspan(start_s, end_s, color='red', alpha=0.12)
    ax.set_title('EMG Evidence'); ax.grid(True, alpha=0.3); ax.legend(loc='upper right')

    # Row: EEG slope and combined features.
    ax = fig.add_subplot(gs[r, 0]); r += 1
    ax.plot(t, feats['eeg_slope'], label='EEG MRCP slope (z)', lw=1.2)
    ax.plot(t, feats['eeg'],       label='EEG combined (z)', lw=1.2, alpha=0.9)
    ax.axvspan(start_s, end_s, color='blue', alpha=0.12)
    ax.set_title('EEG Evidence'); ax.grid(True, alpha=0.3); ax.legend(loc='upper right')

    # Row: eye activity feature.
    ax = fig.add_subplot(gs[r, 0]); r += 1
    ax.plot(t, feats['eye'], label='Eye activity (z)', lw=1.2)
    ax.axvspan(start_s, end_s, color='green', alpha=0.10)
    ax.set_title('Oculomotor Evidence'); ax.grid(True, alpha=0.3); ax.legend(loc='upper right')

    # Row: reliability score per modality.
    ax = fig.add_subplot(gs[r, 0]); r += 1
    R = feats['R']; keys = ['EMG','EEG','Eye']; vals = [R['emg'], R['eeg'], R['eye']]
    ax.bar(keys, vals); ax.set_ylim(0, max(1.0, max(vals)+0.1))
    ax.set_title('Modality Reliability (relative)'); ax.grid(True, axis='y', alpha=0.3)

    # Row: duration prior, converted from log space and from samples to seconds.
    ax = fig.add_subplot(gs[r, 0]); r += 1
    d_vals = hsmm_diag['d_vals']; log_p = hsmm_diag['log_p_d']
    secs = d_vals / fs; ax.plot(secs, np.exp(log_p), lw=1.5)
    ax.set_xlabel('Duration (s)'); ax.set_ylabel('Prior p(d)'); ax.set_title('HSMM Duration Prior'); ax.grid(True, alpha=0.3)

    # Optional row: spectrogram of the first available EEG channel
    # (1 s windows, 90 % overlap, 1-40 Hz shown in dB).
    if ADD_SPECTROGRAM:
        eeg_cols = [c for c in [f'EEG_Ch{i}' for i in range(1,9)] if c in df.columns]
        if eeg_cols:
            ax = fig.add_subplot(gs[r, 0]); r += 1
            x = df[eeg_cols[0]].ffill().bfill().to_numpy()
            nper = int(fs*1.0); nover = int(fs*0.9)
            f, tt_spec, Sxx = _sig.spectrogram(x, fs=fs, nperseg=nper, noverlap=nover,
                                               scaling='spectrum', mode='psd')
            keep = (f>=1) & (f<=40)
            im = ax.pcolormesh(tt_spec, f[keep], 10*np.log10(Sxx[keep,:]+1e-12), shading='auto')
            ax.axvspan(start_s, end_s, color='purple', alpha=0.15)
            ax.set_ylabel('Hz'); ax.set_xlabel('Time (s)'); ax.set_title(f'EEG Spectrogram ({eeg_cols[0]})')
            cb = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02); cb.set_label('Power (dB)')

    fig.suptitle('ICML Extended Diagnostics (HSMM + Head Clamp + EMG Tail Recovery)', fontsize=14, fontweight='bold', y=0.995)
    fig.tight_layout()
    if save_path: fig.savefig(save_path, dpi=300, bbox_inches='tight')
    return fig

# ============================ FILE PROCESS ===========================
def parse_ids_from_stem(stem: str, input_dir: Path):
    """
    Extract subject, mode token, task and trial for one recording.

    subject_id is taken from the innermost folder of ``input_dir`` named like
    'Sub-6', 'sub6' or 'sub-6'; only the numeric part is stored → 6.

    Filename convention:
        ...T114...  → task=1, trial=14
        ...T216...  → task=2, trial=16
        ...M305...  → imagery: task=3, trial=5
    i.e., FIRST digit after T/M = task, REMAINING digits = trial.

    The code must sit at the start of the stem or right after an underscore
    and carry at least two digits, so a bare 'T2' is not recognised.

    Returns ``(subj, mode, task, trial)``; ``mode`` is 'T' or 'M'. Every
    element that cannot be determined is ``None``.
    """
    subj = None
    # Walk the path from the innermost folder outwards and stop at the first match.
    for part in input_dir.parts[::-1]:
        m = re.match(r'(?i)^sub-?(\d+)$', part)
        if m:
            subj = int(m.group(1))  # store as 6, not 'Sub-6'
            break

    mode = None
    task = None
    trial = None

    # Allow codes at start OR after an underscore:
    # matches: "T216...", "Sub-2_T216_synchronized_corrected", "M114", etc.
    m = re.search(r'(?:^|_)(T|M)(\d{2,})', stem, flags=re.I)
    if m:
        mode = m.group(1).upper()
        digits = m.group(2)
        # The pattern guarantees at least two digits, so both parts exist:
        # "216" -> task=2, trial=16; "114" -> task=1, trial=14.
        task  = int(digits[0])
        trial = int(digits[1:])

    return subj, mode, task, trial



def _attach_label_columns(df, active_raw, active, active_prob, subj, task, trial):
    """Copy `df` and append the label/meta columns in one fixed order.

    `active_raw`, `active` and `active_prob` may be scalars (assigned to every
    row) or per-row arrays. `subject_id`, `task` and `trial` are written only
    when the corresponding value is not None. `label_11` holds the task code on
    rows where `active == 1` and 0 elsewhere; `task_target` mirrors `label_11`.
    """
    out_df = df.copy()
    out_df['active_raw']   = active_raw
    out_df['active']       = active
    out_df['active_prob']  = active_prob
    out_df['label_action'] = out_df['active'].astype(int)  # 0=REST, 1=ACTION
    if subj is not None:  out_df['subject_id'] = subj
    if task is not None:  out_df['task'] = task
    if trial is not None: out_df['trial'] = trial
    task_code = int(task) if task is not None else 0
    out_df['label_11']    = np.where(out_df['active'] == 1, task_code, 0).astype(int)
    out_df['task_target'] = out_df['label_11']
    return out_df


def _write_label_outputs(out_df, stem, path, label_dir, plots_dir, subj, task, trial,
                         fs, t, segments, pre_end_idx, post_start_idx):
    """Write the `<stem>_onsets.json` sidecar and the labelled CSV; return the CSV path.

    Creates `label_dir` and `plots_dir` if they do not exist. `segments` is the
    list of segment dicts stored under "segments" (empty for all-rest files).
    """
    T = len(t)
    sidecar = {
        "file": str(Path(path).resolve()),
        "subject_id": subj,
        "task_code": int(task) if task is not None else 0,
        "trial_id": int(trial) if trial is not None else -1,
        "fs_hz": float(fs),
        "duration_s": float(t[-1]-t[0]) if T else 0.0,
        "n_active_segments": int(len(segments)),
        "segments": segments,
        "rest_shields": {
            # indices are capped at the last sample before looking up the time
            "pre_rest_end_s":  float(t[min(pre_end_idx,  T-1)]),
            "post_rest_start_s": float(t[min(post_start_idx, T-1)]),
        },
    }
    label_dir.mkdir(parents=True, exist_ok=True); plots_dir.mkdir(parents=True, exist_ok=True)
    with open(label_dir / f"{stem}_onsets.json", "w", encoding="utf-8") as f:
        json.dump(sidecar, f, indent=2)

    out_csv = str(label_dir / f"{stem}_icml_consensus_labels.csv")
    out_df.to_csv(out_csv, index=False)
    return out_csv


def process_file(path: str, label_dir: Path, plots_dir: Path) -> dict:
    """Label one synchronized CSV with a single action segment and write the outputs.

    Steps, in order:
      1. Read the CSV, canonicalize column names and make sure a
         'Timestamp_seconds' column exists (derived from 'Timestamp_ms' or
         'ET_TimeSignal' when missing).
      2. Parse subject/mode/task/trial from the file stem and `label_dir.parent`.
      3. Extract features, resolve the mode (FORCE_MODE, then MODE_POLICY), fuse
         the features into one score and box-smooth it.
      4. Task-0 'T' files are written as pure REST (quality "REST_ONLY").
      5. Otherwise compute the rest shields, clamp the ROI start, run the HSMM on
         the ROI, refine the boundaries, re-apply the head clamp, extend the tail,
         and enforce a trailing rest window. If the ROI has fewer than 8 samples
         the file is written as all-rest (quality "NO-ROOM").
      6. Write `<stem>_icml_consensus_labels.csv` and `<stem>_onsets.json` into
         `label_dir`, optionally save plots into `plots_dir`.

    Args:
        path: CSV file to process.
        label_dir: Output directory for the CSV and JSON sidecar; its parent is
            searched for the 'Sub-<n>' folder name.
        plots_dir: Output directory for diagnostic plots.

    Returns:
        Summary dict with keys file, subject_id, mode, mode_source, fs, start_s,
        end_s, duration_s, mu_rest, sd_rest, mu_act, sd_act, out_csv,
        out_png_basic, out_png_ext, quality. The two PNG entries are None when
        the corresponding plot was not saved.

    Raises:
        ValueError: if no usable time column is present.
    """
    df = pd.read_csv(path, low_memory=False)
    df = canonicalize_columns(df)

    if 'Timestamp_seconds' not in df.columns:
        if 'Timestamp_ms' in df.columns:
            df['Timestamp_seconds'] = df['Timestamp_ms'].astype(float)/1000.0
        elif 'ET_TimeSignal' in df.columns:
            df['Timestamp_seconds'] = df['ET_TimeSignal'].astype(float)
        else:
            raise ValueError(f"{Path(path).name}: missing time column")

    stem = Path(path).stem
    subj, mode_tok, task, trial = parse_ids_from_stem(stem, label_dir.parent)

    t  = df['Timestamp_seconds'].astype(float).to_numpy()
    fs = median_fs(t); T = len(t)

    et_present = _et_signal_audit(df)
    print(f"[ET] {Path(path).name} → pupil:{len(et_present['pupil'])>0} "
          f"gazeXY:{len(et_present['gaze_xy'])>0} valid:{len(et_present['valid'])>0} "
          f"axis3d:{len(et_present['fallback_axis'])>0} imu:{len(et_present['imu'])>0} events:{len(et_present['events'])>0}")

    feats = extract_features(df, fs)

    # ----------- MODE RESOLUTION: forced > filename token > inference -----------
    if FORCE_MODE in ('physical','imagery'):
        mode, mode_source = FORCE_MODE, 'forced'
    else:
        if MODE_POLICY in {'auto','filename'} and mode_tok is not None:
            mode = 'physical' if mode_tok == 'T' else 'imagery'; mode_source = 'filename'
        elif MODE_POLICY == 'infer':
            mode, mode_source = infer_mode_from_signals(feats), 'infer'
        else:
            mode, mode_source = choose_mode(stem, feats)

    smooth_win = BASE['smooth_win_physical'] if mode=='physical' else (
                 BASE['smooth_win_imagery'] if mode=='imagery' else BASE['smooth_win_default'])

    weights = adaptive_weights(mode, feats['R'], feats['eye_valid'], BASE['eye_valid_min'])
    fused   = fuse_score(feats, weights)
    smooth  = np.convolve(fused, np.ones(smooth_win)/smooth_win, 'same')  # moving average

    # >>> REST-ONLY BYPASS (T0x) <<<
    # If filename encodes Task 0 (e.g., ..._T02_... → task=0, trial=2), treat as pure REST.
    if (mode_tok == 'T' and task == 0):
        pre_end_idx, post_start_idx = compute_rest_shields(feats, fused, fs)
        out_df = _attach_label_columns(df, 0, 0, 0.0, subj, task, trial)
        out_csv = _write_label_outputs(out_df, stem, path, label_dir, plots_dir, subj, task, trial,
                                       fs, t, [], pre_end_idx, post_start_idx)
        # The bypass is driven by the 'T' filename token, so it is reported as
        # physical/forced regardless of the resolved mode above.
        return dict(file=Path(path).name, subject_id=subj, mode='physical', mode_source='forced',
                    fs=round(fs,2), start_s=None, end_s=None, duration_s=0.0,
                    mu_rest=np.nan, sd_rest=np.nan, mu_act=np.nan, sd_act=np.nan,
                    out_csv=out_csv, out_png_basic=None, out_png_ext=None, quality="REST_ONLY")

    # ----------- SHIELDS + CLAMP of ROI -----------
    pre_end_idx, post_start_idx = compute_rest_shields(feats, fused, fs)
    # Clamp ROI start to [PRE_REST_MIN_S, PRE_REST_MAX_S] (strict head rest)
    min_on = int(PRE_REST_MIN_S*fs); max_on = int(PRE_REST_MAX_S*fs)
    pre_end_idx = int(np.clip(pre_end_idx, min_on, max_on))
    post_start_idx = max(post_start_idx, pre_end_idx + int(0.5*fs))  # ensure some room

    # HSMM within legal region
    smooth_roi = smooth[pre_end_idx:post_start_idx].copy()
    if smooth_roi.size < 8:
        # Not enough samples between the shields: emit an all-rest labelling.
        out_df = _attach_label_columns(df, 0, 0, 0.0, subj, task, trial)
        out_csv = _write_label_outputs(out_df, stem, path, label_dir, plots_dir, subj, task, trial,
                                       fs, t, [], pre_end_idx, post_start_idx)
        # Report the mode that was actually resolved for this file.
        return dict(file=Path(path).name, subject_id=subj, mode=mode, mode_source=mode_source,
                    fs=round(fs,2), start_s=None, end_s=None, duration_s=0.0,
                    mu_rest=np.nan, sd_rest=np.nan, mu_act=np.nan, sd_act=np.nan,
                    out_csv=out_csv, out_png_basic=None, out_png_ext=None, quality="NO-ROOM")

    s_local, e_local, hsmm_params, hsmm_diag = hsmm_single_segment(smooth_roi, fs, mode)
    # HSMM indices are relative to the ROI; shift back to whole-file indices.
    s_idx = pre_end_idx + s_local; e_idx = pre_end_idx + e_local

    s_idx_ref, e_idx_ref = refine_boundaries(feats, s_idx, e_idx, fs, mode)
    s_idx_ref = max(s_idx_ref, pre_end_idx); e_idx_ref = min(e_idx_ref, post_start_idx-1)

    # ---- HARD HEAD RULE: force onset into the head-rest band
    s_idx_ref = int(np.clip(s_idx_ref, min_on, max_on))
    pre_end_idx = min(pre_end_idx, s_idx_ref)

    # ---- TAIL: EMG quiet-window recovery (robust)
    e_idx_ref = extend_action_tail_emg_quiet(feats, fs, s_idx_ref, e_idx_ref)

    # ---- Ensure legal tail rest; allow evidence override down to the relaxed minimum
    min_tail = int(POST_REST_MIN_S * fs)
    min_tail_relaxed = int(TAIL_OVERRIDE_MIN_REST_S * fs)
    post_start_idx = max(post_start_idx, e_idx_ref + min_tail)  # nominal
    if post_start_idx > T - 1:  # relax if evidence pushes near end
        post_start_idx = min(T-1, e_idx_ref + min_tail_relaxed)
    post_start_idx = min(post_start_idx, T-1)

    start_s = float(t[s_idx_ref]); end_s = float(t[e_idx_ref]); dur = max(0.0, end_s - start_s)

    # Logistic squash of the HSMM 'delta' diagnostic, scaled by the ROI's MAD.
    delta = hsmm_diag['delta']; scale = max(1e-6, mad(smooth_roi))
    p_active_roi = 1.0/(1.0 + np.exp(-delta / (BASE['prob_scale_frac']*scale + 1e-12)))

    active = np.zeros(T, dtype=int); active[s_idx_ref:e_idx_ref+1] = 1
    active[:pre_end_idx] = 0; active[post_start_idx:] = 0  # shields are always REST
    active_ref = refine_active_mask(active, fs, min_bout_s=MIN_BOUT_S, max_gap_s=MAX_GAP_S)

    active_prob = np.zeros(T, float)
    roi_len = post_start_idx - pre_end_idx
    pa = p_active_roi if np.ndim(p_active_roi)==1 else np.asarray(p_active_roi)
    # NOTE(review): when `pa` is shorter than the (possibly widened) ROI, the whole
    # ROI is filled with its mean rather than the per-sample values — confirm intended.
    active_prob[pre_end_idx:post_start_idx] = (pa[:roi_len] if pa.size >= roi_len
                                               else float(np.nanmean(pa)) if pa.size else 0.0)

    out_df = _attach_label_columns(df, active.astype(int), active_ref.astype(int), active_prob,
                                   subj, task, trial)

    segs = segments_from_mask(out_df['active'].to_numpy())
    onsets=[]
    for s_i, e_i in segs:
        s_i=int(s_i); e_i=int(e_i)
        # presumably e_i is an exclusive end index (hence e_i-1) — verify against segments_from_mask
        s_sec=float(t[s_i]); e_sec=float(t[min(e_i-1, T-1)])
        onsets.append({"start_idx":s_i,"end_idx":e_i,"start_s":round(s_sec,6),
                       "end_s":round(e_sec,6),"duration_s":round(max(0.0,e_sec-s_sec),6)})

    out_csv = _write_label_outputs(out_df, stem, path, label_dir, plots_dir, subj, task, trial,
                                   fs, t, onsets, pre_end_idx, post_start_idx)

    # Plot paths are reported only when the plot was requested and saved without error.
    out_png_basic = None
    if SAVE_PLOTS_STANDARD:
        out_png_basic = str(plots_dir / f"{stem}_icml_detection_plot.png")
        try:
            fig = plot_results_basic(out_df, feats, fused, start_s, end_s,
                                     smooth, hsmm_params,
                                     save_path=out_png_basic,
                                     title_suffix=f"(Mode: {mode.upper()} • {mode_source} • Head Clamp + EMG Tail Recovery)")
            plt.close(fig)
        except Exception as e:
            print(f"[warn] basic plot failed for {Path(path).name}: {e}"); out_png_basic=None

    out_png_ext = None
    if SAVE_PLOTS_EXTENDED:
        out_png_ext = str(plots_dir / f"{stem}_icml_detection_plot_extended.png")
        try:
            # NOTE(review): this title_suffix has a closing ')' but no opening '(' —
            # kept verbatim; confirm against how plot_results_extended formats it.
            fig2 = plot_results_extended(out_df, feats, fused, start_s, end_s, fs,
                                         smooth, hsmm_params, hsmm_diag,
                                         pre_end_idx, post_start_idx,
                                         save_path=out_png_ext,
                                         title_suffix=f"Mode: {mode.upper()} • {mode_source} • Head Clamp + EMG Tail Recovery)")
            plt.close(fig2)
        except Exception as e:
            print(f"[warn] extended plot failed for {Path(path).name}: {e}"); out_png_ext=None

    # Duration sanity flag; only physical trials are graded.
    phys_min_s, phys_max_s = BASE['target_dur_physical']
    quality = ("SHORT" if (dur < phys_min_s and mode=='physical') else
               "LONG"  if (dur > phys_max_s and mode=='physical')  else "GOOD")

    return dict(file=Path(path).name, subject_id=subj, mode=mode, mode_source=mode_source, fs=round(fs,2),
        start_s=round(start_s,3), end_s=round(end_s,3), duration_s=round(dur,3),
        mu_rest=round(hsmm_params['mu_rest'],4), sd_rest=round(hsmm_params['sd_rest'],4),
        mu_act=round(hsmm_params['mu_act'],4),   sd_act=round(hsmm_params['sd_act'],4),
        out_csv=out_csv, out_png_basic=out_png_basic, out_png_ext=out_png_ext, quality=quality)

# ============================ SUBJECT RUNNER =========================
def process_subject(subject_sync_dir: Path) -> pd.DataFrame:
    """Run `process_file` on every matching CSV of one subject and summarize.

    Files are matched with FILE_GLOB inside `subject_sync_dir`; labels go to
    `<subject_sync_dir>/label` and plots to `<subject_sync_dir>/plot`. A file
    that raises is recorded as an ERROR row instead of aborting the batch.

    Args:
        subject_sync_dir: Directory holding the subject's synchronized CSVs.

    Returns:
        One-row-per-file summary DataFrame (also written to
        `icml_consensus_batch_summary.csv` in `subject_sync_dir`), or an empty
        DataFrame when no file matches.
    """
    subject_sync_dir = Path(subject_sync_dir)
    label_dir = subject_sync_dir / "label"
    plots_dir = subject_sync_dir / "plot"
    files = sorted(glob.glob(str(subject_sync_dir/FILE_GLOB)))
    if not files:
        print(f"[skip] No files matched in: {subject_sync_dir} | pattern: {FILE_GLOB}")
        return pd.DataFrame()

    # Numeric subject id from the directory path — the same lookup process_file
    # performs — so ERROR rows and OK rows carry the same kind of subject_id.
    dir_subj = parse_ids_from_stem("", subject_sync_dir)[0]

    results = []
    for f in files:
        try:
            res = process_file(f, label_dir=label_dir, plots_dir=plots_dir)
        except Exception as e:
            # Batch boundary: keep going, but keep the exception type so that
            # terse messages (e.g. a bare KeyError key) remain diagnosable.
            res = dict(file=Path(f).name, subject_id=dir_subj,
                       error=f"{type(e).__name__}: {e}")
        results.append(res); print(res)

    rows = []
    for r in results:
        row = dict(file=r['file'], subject_id=r.get('subject_id'))
        if 'error' in r:
            row.update(status='ERROR', error=r['error'])
        else:
            row.update(status='OK',
                mode=r.get('mode','?'), mode_source=r.get('mode_source','?'), fs=r.get('fs'),
                start_s=r.get('start_s'), end_s=r.get('end_s'), duration_s=r.get('duration_s'),
                mu_rest=r.get('mu_rest'), sd_rest=r.get('sd_rest'),
                mu_act=r.get('mu_act'), sd_act=r.get('sd_act'),
                quality=r.get('quality'), out_csv=r.get('out_csv'))
        rows.append(row)

    summary = pd.DataFrame(rows)
    out_summary = str(subject_sync_dir/"icml_consensus_batch_summary.csv")
    summary.to_csv(out_summary, index=False)
    print(f"[ok] Subject summary saved: {out_summary}")
    return summary

# ---- Notebook-friendly entrypoint ----
def main(subject: str | None = None, subject_dir: str | None = None):
    if subject_dir: subject_sync_dir = Path(subject_dir)
    elif subject:   subject_sync_dir = Path(ROOT_DIR) / subject / Path(SYNC_SUBPATH)
    else:           subject_sync_dir = DEFAULT_SUBJECT_SYNC_DIR
    print(f"Processing one subject at: {subject_sync_dir}")
    if not subject_sync_dir.exists():
        print(f"[error] Subject sync dir not found: {subject_sync_dir}"); return
    _ = process_subject(subject_sync_dir)

# In Jupyter or script: process the default subject as soon as this cell/module runs.
# NOTE(review): the call is unguarded, so importing this file also triggers a full
# processing run — consider an `if __name__ == "__main__":` guard if it is ever imported.
main(subject_dir=DEFAULT_SUBJECT_SYNC_DIR)
