# In [None]:
# =============================================================================
# Phase 4 — Deterministic Preprocessing (per file, once) [Jupyter-safe]
# -----------------------------------------------------------------------------
# =============================================================================

from __future__ import annotations
import argparse, json, warnings, re
from pathlib import Path
import numpy as np
import pandas as pd
from scipy import signal as sig
from scipy.signal import welch, iirnotch
from scipy.interpolate import interp1d

# Process-wide: silence every warning (presumably numpy/pandas/scipy chatter
# during batch runs). NOTE(review): this also hides genuinely useful warnings.
warnings.filterwarnings("ignore")

# ---------------- PROJECT PATHS ----------------
ROOT_DIR = Path("/home/tsultan1/BioRob(Final)/Data")
MANIFEST_DEFAULT = ROOT_DIR.joinpath("_dataset_icml_v1", "manifest_v1.csv")

# ---------------- CONFIG ----------------
# Sampling-rate policy: streams whose median rate differs from TARGET_FS by
# more than FS_TOL_FRAC (relative) are resampled to TARGET_FS.
TARGET_FS = 250.0
FS_TOL_FRAC = 0.01

# EEG filtering (band-pass edges in Hz; Butterworth order)
EEG_BAND, EEG_ORDER = (1.0, 40.0), 4
# Fixed mains notches (off by default) and their quality factor
APPLY_NOTCH_60, APPLY_NOTCH_50 = False, False
NOTCH_Q = 30.0
# Common-average reference statistic: "median" or "mean"
EEG_CAR_MODE = "median"

# Automatic notch (applied when a 60/50 Hz line stands EEG_NOTCH_DB above its
# flanks) and per-channel outlier repair ("none" | "hampel" | "zscore").
EEG_AUTO_NOTCH = True
EEG_NOTCH_DB = 30.0
EEG_REPAIR = "hampel"
EEG_REPAIR_WIN_MS = 120.0
EEG_REPAIR_SIGMA = 8.0
EEG_POLARITY_CHECK = True

# EMG envelope:
#   "hp_rect_lp" -> high-pass EMG_HP_HZ, rectify, low-pass EMG_LP_HZ
#   "bp_rect_ma" -> band-pass EMG_BP, rectify, moving average of MA_MS
EMG_ENVELOPE_MODE = "hp_rect_lp"
EMG_HP_HZ = 20.0
EMG_LP_HZ = 10.0
EMG_BP = (20.0, 100.0)
EMG_ORDER = 4
MA_MS = 50.0

# Eye-tracking cleanup
ET_INVALID_GT = 1        # a validity code above this marks the sample invalid
ET_GAP_MS = 150.0        # only gaps up to this length are interpolated
ET_SMOOTH_TAPS = 7       # rolling-mean length; forced odd where used

# Canonical channel names
EEG_CH = [f"EEG_Ch{n}" for n in range(1, 8 + 1)]
EMG_CH = [f"EMG_Ch{n}" for n in range(1, 5)]
ET_CONT_CH = [
    # Pupil* / Distance*
    "ET_PupilLeft", "ET_PupilRight", "ET_DistanceLeft", "ET_DistanceRight",
    # Gaze* / Camera* (2-D)
    "ET_GazeLeftx", "ET_GazeLefty", "ET_GazeRightx", "ET_GazeRighty",
    "ET_CameraLeftx", "ET_CameraLefty", "ET_CameraRightx", "ET_CameraRighty",
    # Gyro* / Acc* / HeadRotation*
    "ET_GyroX", "ET_GyroY", "ET_GyroZ",
    "ET_AccX", "ET_AccY", "ET_AccZ",
    "ET_HeadRotationPitch", "ET_HeadRotationYaw", "ET_HeadRotationRoll",
    # Gaze3dEyeball* (left then right)
    "ET_Gaze3dEyeballXLeft", "ET_Gaze3dEyeballYLeft", "ET_Gaze3dEyeballZLeft",
    "ET_Gaze3dEyeballXRight", "ET_Gaze3dEyeballYRight", "ET_Gaze3dEyeballZRight",
    # Gaze3dOpticalAxis* (left then right)
    "ET_Gaze3dOpticalAxisXLeft", "ET_Gaze3dOpticalAxisYLeft", "ET_Gaze3dOpticalAxisZLeft",
    "ET_Gaze3dOpticalAxisXRight", "ET_Gaze3dOpticalAxisYRight", "ET_Gaze3dOpticalAxisZRight",
]
ET_FLAG_CH = [
    "ET_Fixation",
    "ET_Worn",
    "ET_ValidityLeftEye",
    "ET_ValidityRightEye",
]

# Caching (<file>.csv + CACHE_SUFFIX next to each input)
SAVE_CACHE = True
CACHE_SUFFIX = ".preproc.npz"

# ---------------- FILTER HELPERS ----------------
def _butter_band_sos(low, high, fs, order=4):
    nyq = fs * 0.5
    lo = max(1e-6, low/nyq); hi = min(0.9999, high/nyq)
    return sig.butter(order, [lo, hi], btype="bandpass", output="sos")

def _butter_high_sos(cut, fs, order=4):
    nyq = fs * 0.5
    wn = max(1e-6, cut/nyq)
    return sig.butter(order, wn, btype="highpass", output="sos")

def _butter_low_sos(cut, fs, order=4):
    nyq = fs * 0.5
    wn = min(0.9999, cut/nyq)
    return sig.butter(order, wn, btype="lowpass", output="sos")

def _apply_sos(x, sos):
    return sig.sosfiltfilt(sos, x, axis=0)

def _iir_notch_norm(x, fs, f0, Q=30.0):
    b, a = sig.iirnotch(w0=f0/(fs/2.0), Q=Q)
    return sig.filtfilt(b, a, x, axis=0)

def _iir_notch_fs(x, fs, f0, Q=30.0):
    try:
        b, a = iirnotch(w0=f0, Q=Q, fs=fs)
        return sig.filtfilt(b, a, x, axis=0)
    except TypeError:
        return _iir_notch_norm(x, fs, f0, Q)

def _moving_average(x, taps:int):
    if taps <= 1:
        return x
    k = np.ones(int(taps), dtype=float) / float(taps)
    if x.ndim == 1:
        return np.convolve(x, k, mode="same")
    return np.apply_along_axis(lambda c: np.convolve(c, k, mode="same"), 0, x)

# ---------------- CORE HELPERS ----------------
def _median_fs(t: np.ndarray) -> float:
    dt = np.diff(t)
    dt = dt[np.isfinite(dt) & (dt > 0)]
    if dt.size == 0: return np.nan
    return 1.0 / np.median(dt)

def _resample_linear(t: np.ndarray, X: np.ndarray, target_fs: float) -> tuple[np.ndarray, np.ndarray]:
    t0, t1 = float(t[0]), float(t[-1])
    n_out = int(round((t1 - t0) * target_fs)) + 1
    t_new = np.linspace(t0, t1, n_out)
    if X.ndim == 1:
        f = interp1d(t, X, kind="linear", bounds_error=False, fill_value="extrapolate")
        return t_new, f(t_new)
    X_new = np.zeros((n_out, X.shape[1]), dtype=float)
    for i in range(X.shape[1]):
        f = interp1d(t, X[:, i], kind="linear", bounds_error=False, fill_value="extrapolate")
        X_new[:, i] = f(t_new)
    return t_new, X_new

def _coerce_cols(df: pd.DataFrame, cols: list[str]) -> tuple[np.ndarray, np.ndarray, list[str]]:
    out_cols, arrs, masks = [], [], []
    for c in cols:
        if c in df.columns:
            v = pd.to_numeric(df[c], errors="coerce").to_numpy()
            m = np.isfinite(v).astype(np.float32)
            v = np.nan_to_num(v, nan=0.0, posinf=0.0, neginf=0.0)
        else:
            v = np.zeros(len(df), dtype=float)
            m = np.zeros(len(df), dtype=np.float32)
        out_cols.append(c); arrs.append(v); masks.append(m)
    X = np.column_stack(arrs) if arrs else np.zeros((len(df), 0), dtype=float)
    M = np.column_stack(masks) if masks else np.zeros((len(df), 0), dtype=np.float32)
    return X, M, out_cols

def _et_interpolate_and_smooth(X: np.ndarray, valid_mask: np.ndarray, fs: float) -> tuple[np.ndarray, np.ndarray]:
    N, D = X.shape
    out = X.copy()
    feat_mask = np.zeros((N, D), dtype=np.float32)

    v = (valid_mask.astype(int) == 1)
    lim = max(1, int(round((ET_GAP_MS/1000.0) * fs)))

    for d in range(D):
        col = out[:, d].astype(float)
        m = np.isfinite(col) & v
        s = pd.Series(col)
        s[~m] = np.nan
        s_i = s.interpolate(method="linear", limit=lim, limit_direction="both")
        long_gap = s_i.isna().to_numpy()
        s_i = s_i.fillna(0.0)

        taps = ET_SMOOTH_TAPS if ET_SMOOTH_TAPS % 2 == 1 else ET_SMOOTH_TAPS + 1
        sm = pd.Series(s_i).rolling(window=taps, center=True, min_periods=1).mean().to_numpy()

        out[:, d] = sm
        feat_mask[:, d] = (~long_gap).astype(np.float32)

    return out, feat_mask

# ---------- EEG extras ----------
def _need_notch_auto(X, fs, db_thresh=30.0):
    try:
        rms = np.sqrt(np.nanmean(X**2, axis=0))
        if not np.isfinite(rms).any(): return None
        k = int(np.nanargmin(np.abs(rms - np.nanmedian(rms))))
        y = np.asarray(X[:, k], float)
        nper = min(len(y), 2048)
        if nper < 256: return None
        f, P = welch(y, fs=fs, nperseg=nper)
        Pdb = 10*np.log10(P + 1e-15)
        for f0 in (60.0, 50.0):
            if f0 >= 0.45*fs: continue
            band  = (f >= f0-0.5) & (f <= f0+0.5)
            left  = (f >= f0-4.0) & (f <  f0-1.5)
            right = (f >  f0+1.5) & (f <= f0+4.0)
            if not band.any() or not (left.any() and right.any()): continue
            peak = np.nanmax(Pdb[band])
            base = np.nanmean([np.nanmedian(Pdb[left]), np.nanmedian(Pdb[right])])
            if (np.isfinite(peak) and np.isfinite(base) and (peak - base) >= float(db_thresh)):
                return f0
    except Exception:
        pass
    return None

def _hampel_1d(y, fs, win_ms=120.0, n_sigma=8.0):
    y = np.asarray(y, float)
    k = max(1, int(round((win_ms/1000.0)*fs)))
    s = pd.Series(y)
    med = s.rolling(window=2*k+1, center=True, min_periods=max(1,k//2)).median().to_numpy()
    mad = (s-med).abs().rolling(window=2*k+1, center=True, min_periods=max(1,k//2)).median().to_numpy()
    sigma = 1.4826*mad + 1e-12
    bad = np.abs(y - med) > (n_sigma * sigma)
    if bad.any() and bad.mean() < 0.30:
        yc = y.copy(); yc[bad] = np.nan
        idx = np.arange(y.size); good = np.isfinite(yc)
        if good.sum() >= 2:
            y[~good] = np.interp(idx[~good], idx[good], yc[good])
    return y, float(bad.mean())

def _maybe_flip_polarity(y_raw, y_clean):
    xr = np.asarray(y_raw, float); xc = np.asarray(y_clean, float)
    m = np.isfinite(xr) & np.isfinite(xc)
    if m.sum() < 200: return y_clean, False
    corr = np.corrcoef(xr[m], xc[m])[0,1]
    ratio = (np.sqrt(np.mean(xc[m]**2))+1e-12) / (np.sqrt(np.mean(xr[m]**2))+1e-12)
    if corr <= -0.90 and abs(1.0 - ratio) <= 0.30:
        return -xc, True
    return xc, False

# ---------------- CANONICALIZATION ----------------
def canonicalize_columns_phase4(df: pd.DataFrame) -> pd.DataFrame:
    """
    Map common aliases to canonical names used by Phase-4:
      - 'Ch1 EMG raw' -> 'EMG_Ch1', ..., 'Ch4 EMG raw' -> 'EMG_Ch4'
      - 'Ch1'..'Ch8' (EEG) -> 'EEG_Ch1'..'EEG_Ch8'
      - 'EMG_Ch1' / 'EEG ch 1' style variants -> canonical EMG_/EEG_ names
      - Camera/gaze X/Y case variants -> ...Leftx/Lefty/Rightx/Righty
      - Timestamp_ms -> if only ms present, derive Timestamp_seconds

    If several source columns map to the same canonical name, only the first
    one (in original column order) is kept.

    Raises
    ------
    ValueError
        When neither Timestamp_seconds nor Timestamp_ms is present.
    """
    mapping = {}
    for c in df.columns:
        s = str(c).strip()

        # Time
        if s.lower() in {"timestamp_seconds", "timestamps_seconds"}:
            mapping[c] = "Timestamp_seconds"; continue
        if s.lower() in {"timestamp_ms", "timestamps_ms"}:
            mapping[c] = "Timestamp_ms"; continue

        # EMG raw aliases like "Ch1 EMG raw"
        m = re.match(r'(?i)ch\s*([1-4])\s*emg\s*raw', s)
        if m:
            mapping[c] = f"EMG_Ch{int(m.group(1))}"; continue

        # EEG aliases "Ch1".."Ch8"
        m = re.match(r'(?i)^ch\s*([1-8])$', s)
        if m:
            mapping[c] = f"EEG_Ch{int(m.group(1))}"; continue

        # Already-canonical EMG/EEG
        m = re.match(r'(?i)^emg[_\s]*ch\s*([1-4])$', s)
        if m:
            mapping[c] = f"EMG_Ch{int(m.group(1))}"; continue
        m = re.match(r'(?i)^eeg[_\s]*ch\s*([1-8])$', s)
        if m:
            mapping[c] = f"EEG_Ch{int(m.group(1))}"; continue

        # ET: normalize case of X/Y endings (LeftX->Leftx)
        if s.startswith("ET_"):
            s2 = s.replace("LeftX","Leftx").replace("LeftY","Lefty")
            s2 = s2.replace("RightX","Rightx").replace("RightY","Righty")
            mapping[c] = s2; continue

        mapping[c] = c

    out = df.rename(columns=mapping)

    # Different aliases can collapse onto one canonical name (e.g. both "Ch1" and
    # "EEG_Ch1" present). Duplicate labels make out[name] return a DataFrame, which
    # breaks every downstream pd.to_numeric(df[name]), so keep the first occurrence.
    dup = out.columns.duplicated(keep="first")
    if dup.any():
        out = out.loc[:, ~dup].copy()

    # Ensure Timestamp_seconds exists (fallback from ms)
    if "Timestamp_seconds" not in out.columns:
        if "Timestamp_ms" in out.columns:
            out["Timestamp_seconds"] = pd.to_numeric(out["Timestamp_ms"], errors="coerce")/1000.0
        else:
            raise ValueError("Missing time: neither Timestamp_seconds nor Timestamp_ms found.")
    return out

# ---------------- PREPROCESSOR ----------------
def _load_preproc_cache(npz_path: Path) -> dict:
    """Rebuild preprocess_file's result dict (same keys as a fresh run) from its .npz cache.

    Raises KeyError if an expected array is missing from the archive and
    ValueError if the stored log is not valid JSON.
    """
    # allow_pickle is needed because the *_ch name arrays were saved with
    # dtype=object; only caches written by this pipeline should be loaded.
    with np.load(npz_path, allow_pickle=True) as npz:
        out = {
            "fs": float(npz["fs"]),
            "t": npz["t"],
            "EEG": npz["EEG"],
            "EEG_mask": npz["EEG_mask"],
            "EEG_ch": [str(c) for c in npz["EEG_ch"]],
            "EMG_env": npz["EMG_env"],
            "EMG_mask": npz["EMG_mask"],
            "EMG_ch": [str(c) for c in npz["EMG_ch"]],
            "ET": npz["ET"],
            "ET_mask": npz["ET_mask"],
            "ET_ch": [str(c) for c in npz["ET_ch"]],
            "ET_valid": npz["ET_valid"],
            "log": json.loads(npz["log"].item()),
        }
    out["log"]["cache"] = str(npz_path)
    return out


def preprocess_file(csv_path: Path, rebuild: bool=False) -> dict:
    """
    Deterministic per-file preprocessing of one multimodal CSV.

    Pipeline:
      1. Canonicalize column names and estimate fs from the median timestamp step.
      2. If fs deviates from TARGET_FS by more than FS_TOL_FRAC, resample the
         signals linearly and the masks/flags by nearest neighbour.
      3. EEG: band-pass, optional fixed/auto notch, common-average reference,
         optional outlier repair and polarity check.
      4. EMG: envelope.
      5. ET: validity-assisted gap filling and smoothing.

    When SAVE_CACHE is set, the result is written next to the input as
    ``<name>.csv`` + CACHE_SUFFIX. On later calls it is loaded from there
    unless ``rebuild`` is true.

    Parameters
    ----------
    csv_path : Path
        Input CSV file.
    rebuild : bool
        Ignore an existing cache and recompute.

    Returns
    -------
    dict
        Keys: fs, t, EEG, EEG_mask, EEG_ch, EMG_env, EMG_mask, EMG_ch, ET,
        ET_mask, ET_ch, ET_valid, log. The same keys are returned whether the
        result was computed or read from cache.

    Raises
    ------
    ValueError
        Missing time column, fewer than 4 samples, or no usable sampling rate.
    """
    csv_path = Path(csv_path)
    out_npz = csv_path.with_suffix(csv_path.suffix + CACHE_SUFFIX)
    if out_npz.exists() and not rebuild:
        # Return the full result (not just fs/t/log) so callers can use the
        # cached arrays exactly like a fresh run.
        try:
            return _load_preproc_cache(out_npz)
        except (KeyError, ValueError):
            # Cache lacks an expected array or has an unreadable log: recompute it.
            pass

    df = pd.read_csv(csv_path, low_memory=False)
    df = canonicalize_columns_phase4(df)

    # ----- Timebase & fs -----
    t = pd.to_numeric(df["Timestamp_seconds"], errors="coerce").to_numpy()
    if len(t) < 4:
        raise ValueError(f"{csv_path.name}: too few samples")
    fs = _median_fs(t)
    if not np.isfinite(fs) or fs <= 0:
        raise ValueError(f"{csv_path.name}: invalid fs")

    # ----- Extract modalities -----
    EEG_X, EEG_M, EEG_cols = _coerce_cols(df, EEG_CH)
    EMG_X, EMG_M, EMG_cols = _coerce_cols(df, EMG_CH)
    ET_X,  ET_M,  ET_cols  = _coerce_cols(df, ET_CONT_CH)

    # ET flags (optional; used only to assist cleaning, not as features)
    vl = pd.to_numeric(df["ET_ValidityLeftEye"],  errors="coerce").to_numpy() if "ET_ValidityLeftEye"  in df.columns else None
    vr = pd.to_numeric(df["ET_ValidityRightEye"], errors="coerce").to_numpy() if "ET_ValidityRightEye" in df.columns else None
    worn = pd.to_numeric(df["ET_Worn"], errors="coerce").fillna(1).to_numpy() if "ET_Worn" in df.columns else np.ones(len(df))
    if vl is not None and vr is not None:
        ET_valid = ((vl <= ET_INVALID_GT) & (vr <= ET_INVALID_GT)).astype(int)
    else:
        ET_valid = (worn >= 0.5).astype(int)

    ET_Fix = pd.to_numeric(df["ET_Fixation"], errors="coerce").fillna(0).to_numpy() if "ET_Fixation" in df.columns else np.zeros(len(df))

    # ----- Resample (if needed) to TARGET_FS -----
    resampled = False
    if abs(fs/TARGET_FS - 1.0) > FS_TOL_FRAC:
        all_cols = EEG_cols + EMG_cols + ET_cols
        X_all = np.column_stack([EEG_X, EMG_X, ET_X]) if all_cols else np.zeros((len(t), 0))
        t_new, X_all = _resample_linear(t, X_all, TARGET_FS)
        fs = TARGET_FS; t = t_new; resampled = True

        # split back
        n_eeg, n_emg, n_et = len(EEG_cols), len(EMG_cols), len(ET_cols)
        EEG_X = X_all[:, :n_eeg] if n_eeg else np.zeros((len(t), 0))
        EMG_X = X_all[:, n_eeg:n_eeg+n_emg] if n_emg else np.zeros((len(t), 0))
        ET_X  = X_all[:, n_eeg+n_emg:] if n_et else np.zeros((len(t), 0))

        # nearest resample masks & flags
        t_old = pd.to_numeric(df["Timestamp_seconds"], errors="coerce").to_numpy()
        N_new = len(t)

        def _rs_mask(m_old):
            t0, t1 = float(t_old[0]), float(t_old[-1])
            t_new2 = np.linspace(t0, t1, N_new)
            if m_old.ndim == 1:
                f = interp1d(t_old, m_old, kind="nearest", bounds_error=False, fill_value=(m_old[0], m_old[-1]))
                return (f(t_new2) > 0.5).astype(np.float32)
            out = np.zeros((N_new, m_old.shape[1]), np.float32)
            for i in range(m_old.shape[1]):
                f = interp1d(t_old, m_old[:, i], kind="nearest", bounds_error=False, fill_value=(m_old[0, i], m_old[-1, i]))
                out[:, i] = (f(t_new2) > 0.5).astype(np.float32)
            return out

        EEG_M = _rs_mask(EEG_M) if EEG_M.size else EEG_M
        EMG_M = _rs_mask(EMG_M) if EMG_M.size else EMG_M
        ET_M  = _rs_mask(ET_M)  if ET_M.size  else ET_M

        def _rs1(x_old):
            t0, t1 = float(t_old[0]), float(t_old[-1])
            t_new3 = np.linspace(t0, t1, N_new)
            f = interp1d(t_old, x_old, kind="nearest", bounds_error=False, fill_value=(x_old[0], x_old[-1]))
            return f(t_new3)

        ET_valid = (_rs1(ET_valid) > 0.5).astype(int)
        ET_Fix   = (_rs1(ET_Fix)   > 0.5).astype(int)

    # ----- EEG: BPF -> (optional notch) -> auto-notch -> CAR -> repair/polarity -----
    eeg_auto_notch_hz = None
    eeg_repair_frac   = []
    if EEG_X.shape[1] > 0:
        sos_bp = _butter_band_sos(EEG_BAND[0], EEG_BAND[1], fs, EEG_ORDER)
        EEG_f = _apply_sos(EEG_X, sos_bp)

        if APPLY_NOTCH_60: EEG_f = _iir_notch_fs(EEG_f, fs, 60.0, Q=NOTCH_Q)
        if APPLY_NOTCH_50: EEG_f = _iir_notch_fs(EEG_f, fs, 50.0, Q=NOTCH_Q)

        if EEG_AUTO_NOTCH:
            f0_auto = _need_notch_auto(EEG_f, fs, db_thresh=EEG_NOTCH_DB)
            if f0_auto is not None:
                EEG_f = _iir_notch_fs(EEG_f, fs, f0_auto, Q=NOTCH_Q)
                eeg_auto_notch_hz = float(f0_auto)

        # NOTE(review): channels absent from the CSV are zero-filled by _coerce_cols
        # and still participate in this reference; confirm that is intended.
        ref = np.median(EEG_f, axis=1, keepdims=True) if EEG_CAR_MODE == "median" else np.mean(EEG_f, axis=1, keepdims=True)
        EEG_car = (EEG_f - ref).astype(np.float32)

        if EEG_REPAIR != "none":
            repaired = np.empty_like(EEG_car)
            for k in range(EEG_car.shape[1]):
                y = EEG_car[:, k]
                if EEG_REPAIR == "hampel":
                    y, fb = _hampel_1d(y, fs, win_ms=EEG_REPAIR_WIN_MS, n_sigma=EEG_REPAIR_SIGMA)
                else:
                    mu  = np.nanmedian(y)
                    mad = np.nanmedian(np.abs(y - mu)) + 1e-9
                    z   = (y - mu) / (1.4826*mad)
                    bad = np.abs(z) > EEG_REPAIR_SIGMA
                    fb  = float(bad.mean()) if bad.size else 0.0
                    if bad.any() and fb < 0.30:
                        yc = y.copy(); yc[bad] = np.nan
                        idx = np.arange(y.size); good = np.isfinite(yc)
                        if good.sum() >= 2:
                            y[~good] = np.interp(idx[~good], idx[good], yc[good])
                if EEG_POLARITY_CHECK:
                    y, _ = _maybe_flip_polarity(EEG_X[:, k], y)
                repaired[:, k] = y
                eeg_repair_frac.append(fb)
            EEG_car = repaired.astype(np.float32)
        EEG_mask = EEG_M.astype(np.float32)
    else:
        EEG_car = EEG_X.astype(np.float32)
        EEG_mask = EEG_M.astype(np.float32)

    # ----- EMG: envelope -----
    if EMG_X.shape[1] > 0:
        if EMG_ENVELOPE_MODE == "hp_rect_lp":
            sos_hp = _butter_high_sos(EMG_HP_HZ, fs, order=EMG_ORDER)
            hp = _apply_sos(EMG_X, sos_hp)
            rect = np.abs(hp)
            sos_lp = _butter_low_sos(EMG_LP_HZ, fs, order=EMG_ORDER)
            EMG_env = _apply_sos(rect, sos_lp)
        else:
            sos_bp = _butter_band_sos(EMG_BP[0], EMG_BP[1], fs, order=EMG_ORDER)
            bp = _apply_sos(EMG_X, sos_bp)
            rect = np.abs(bp)
            taps = max(1, int(round((MA_MS/1000.0) * fs)))
            if taps % 2 == 0: taps += 1
            EMG_env = _moving_average(rect, taps)
        EMG_env = EMG_env.astype(np.float32)
        EMG_mask = EMG_M.astype(np.float32)
    else:
        EMG_env = EMG_X.astype(np.float32)
        EMG_mask = EMG_M.astype(np.float32)

    # ----- ET: validity-assisted interpolation/smoothing -----
    if ET_X.shape[1] > 0:
        ET_clean, ET_feat_mask = _et_interpolate_and_smooth(ET_X, ET_valid, fs)
        ET_mask = (ET_feat_mask * ET_M).astype(np.float32) if ET_M.size else ET_feat_mask.astype(np.float32)
        ET_clean = ET_clean.astype(np.float32)
    else:
        ET_clean = ET_X.astype(np.float32)
        ET_mask = ET_M.astype(np.float32)

    # ----- Final NaN/Inf sanitization -----
    EEG_car  = np.nan_to_num(EEG_car,  nan=0.0, posinf=0.0, neginf=0.0)
    EMG_env  = np.nan_to_num(EMG_env,  nan=0.0, posinf=0.0, neginf=0.0)
    ET_clean = np.nan_to_num(ET_clean, nan=0.0, posinf=0.0, neginf=0.0)
    ET_valid = np.where(np.isfinite(ET_valid), ET_valid, 0).astype(np.int16)

    # ----- Build outputs -----
    N = len(t)
    out_log = {
        "file": str(csv_path.resolve()),
        "n_samples": int(N),
        "fs": float(fs),
        "resampled_to_250": bool(resampled),
        "EEG_band": EEG_BAND,
        "notch_60_fixed": bool(APPLY_NOTCH_60),
        "notch_50_fixed": bool(APPLY_NOTCH_50),
        "EEG_auto_notch_used_hz": eeg_auto_notch_hz,
        "EEG_CAR": EEG_CAR_MODE,
        "EEG_repair": EEG_REPAIR,
        "EEG_repair_frac_mean": float(np.mean(eeg_repair_frac)) if eeg_repair_frac else 0.0,
        "emg_mode": EMG_ENVELOPE_MODE,
        "et_gap_ms": ET_GAP_MS,
        "et_smooth_taps": ET_SMOOTH_TAPS,
        "missing_EEG": [c for c in EEG_CH if c not in EEG_cols],
        "missing_EMG": [c for c in EMG_CH if c not in EMG_cols],
        "missing_ET":  [c for c in ET_CONT_CH if c not in ET_cols],
    }

    out = {
        "fs": float(fs),
        "t": t.astype(float),
        "EEG": EEG_car,
        "EEG_mask": EEG_mask,
        "EEG_ch": EEG_cols,
        "EMG_env": EMG_env,
        "EMG_mask": EMG_mask,
        "EMG_ch": EMG_cols,
        "ET": ET_clean,
        "ET_mask": ET_mask,
        "ET_ch": ET_cols,
        "ET_valid": ET_valid.astype(np.int16),  # for QA; not a model feature
        "log": out_log,
    }

    # ----- Save cache -----
    if SAVE_CACHE:
        np.savez_compressed(
            out_npz,
            fs=out["fs"], t=out["t"],
            EEG=out["EEG"], EEG_mask=out["EEG_mask"], EEG_ch=np.array(out["EEG_ch"], dtype=object),
            EMG_env=out["EMG_env"], EMG_mask=out["EMG_mask"], EMG_ch=np.array(out["EMG_ch"], dtype=object),
            ET=out["ET"], ET_mask=out["ET_mask"], ET_ch=np.array(out["ET_ch"], dtype=object),
            ET_valid=out["ET_valid"],
            log=json.dumps(out["log"])
        )
        out["log"]["cache"] = str(out_npz)

    return out

# ---------------- BATCH VIA MANIFEST ----------------
def preprocess_from_manifest(manifest_csv: Path, limit_files: int|None=None, rebuild: bool=False):
    man = pd.read_csv(manifest_csv)
    paths = man["file"].tolist()
    if limit_files:
        paths = paths[:int(limit_files)]
    n_ok = 0; n_fail = 0
    for i, p in enumerate(paths, 1):
        p = Path(p)
        try:
            res = preprocess_file(p, rebuild=rebuild)
            print(f"[{i}/{len(paths)}] ok: {p.name} | fs={res['fs']:.2f} N={len(res['t'])}")
            n_ok += 1
        except Exception as e:
            print(f"[{i}/{len(paths)}] FAIL: {p.name} | {e}")
            n_fail += 1
    print(f"[summary] ok={n_ok} fail={n_fail}")

# ---------------- CLI (Jupyter-safe) ----------------
def _build_cli():
    ap = argparse.ArgumentParser(description="Phase-4 deterministic preprocessing")
    ap.add_argument("--file", type=str, help="Path to a single label CSV to preprocess")
    ap.add_argument("--manifest", type=str, help="Path to Phase-3 manifest_v1.csv")
    ap.add_argument("--limit", type=int, default=None, help="Limit number of files when using --manifest")
    ap.add_argument("--rebuild", action="store_true", help="Overwrite existing .preproc.npz caches")
    return ap

if __name__ == "__main__":
    # parse_known_args ignores unrecognised flags (presumably those a Jupyter
    # kernel injects), which keeps this cell runnable inside a notebook.
    parser = _build_cli()
    args, _unknown = parser.parse_known_args()

    if not args.file:
        # Batch mode: explicit --manifest, else the project default.
        manifest = MANIFEST_DEFAULT if not args.manifest else Path(args.manifest)
        if not manifest.exists():
            raise SystemExit(f"[stop] manifest not found: {manifest}")
        preprocess_from_manifest(manifest, limit_files=args.limit, rebuild=args.rebuild)
    else:
        # Single-file mode: preprocess and dump the run log.
        out = preprocess_file(Path(args.file), rebuild=args.rebuild)
        print(json.dumps(out["log"], indent=2))
