# In [None]:
# =============================================================================
# Tri-Modal Synchronizer — Lite-Union++ v3 (T-files, cleaned schema)
# =============================================================================
from __future__ import annotations
import os, glob, json, warnings
import numpy as np
import pandas as pd
from pathlib import Path
from scipy import signal
from scipy.interpolate import interp1d

# -------------------------
# CONFIG
# -------------------------
# INPUT_DIR is scanned (non-recursively, "*.csv") by main(); OUTPUT_DIR receives
# the synchronized CSV/JSON files and the batch summary.
INPUT_DIR  = r"/home/tsultan1/BioRob(Final)/Data/Sub-23/cleaned"
OUTPUT_DIR = r"/home/tsultan1/BioRob(Final)/Data/Sub-23/cleaned/synchronized_proper_lite_union_v3"

GRID_MODE    = "fixed"     # "fixed" (target grid) or "native"
FIXED_HZ     = 250         # target fs
SPAN_MODE    = "all"       # "all" (union) or "overlap"
# Grid step in ms; estimated lags are rounded to a multiple of this value.
GRID_STEP_MS = 1000.0 / FIXED_HZ  # 4 ms

# Cross-corr params (for lag estimation)
CROSS_FS       = 1000       # high-res interpolation for xcorr/GCC
XCORR_MIN_SEC  = 6.0        # NOTE(review): not referenced in the visible code
XCORR_WIN_SEC  = 12.0       # default window length used when picking windows
NWIN_BINS      = 5          # default number of bins the overlap is split into
NWIN_SELECT    = 3          # default number of highest-variance windows kept
MIN_ACCEPTED   = 2          # fewer lag estimates than this → timestamp fallback

# Filters (for feature extraction during lag estimation)
EEG_BP      = (1.0, 40.0)   # Hz
EMG_ENV_LP  = 10.0          # Hz
ET_VEL_BP   = (0.5, 10.0)   # Hz
BP_ORDER    = 3             # default Butterworth order of the helper filters

# Anti-aliasing (for downsampling to FIXED_HZ)
AA_ENABLE   = True
AA_ORDER    = 8
AA_MARGIN   = 0.45          # cutoff = min(0.49, margin)*FIXED_HZ (≈112.5 Hz at 250 Hz)
AA_TOL      = 1.05          # AA only when native fs > FIXED_HZ * AA_TOL
# Columns treated as discrete: never low-pass filtered, resampled with
# nearest-neighbour instead of the continuous interpolation kind.
AA_DISCRETE_COLS = {
    'ET_ValidityLeftEye','ET_ValidityRightEye','ET_Blink','ET_Fixation',
    'ET_Worn','EventSource.3'
}

# Guards/fallbacks
LAG_MIN_PEAK_RATIO        = 4.0       # NOTE(review): not referenced in the visible code
EMGEEG_LAG_ABS_REJECT_MS  = 150.0     # reject absurd EMG↔EEG lags
FALLBACK_MAX_ABS_MS       = 50.0      # clip for the median-timestamp fallback lag
ET_XCORR_MIN_PR           = 8.0       # minimum peak ratio to accept an EEG↔ET lag
ET_XCORR_MAX_ABS_MS       = 60.0      # search bound (and sanity limit) for EEG↔ET lag
ET_REVERT_MARGIN          = 0.007     # need corr@lag > corr@0 + margin

# >>> New robust-bounding knobs <<<
EMGEEG_LAG_SEARCH_MAX_MS = 200.0   # hard bound for EMG↔EEG search
BOUNDARY_HIT_FRACTION    = 0.90    # consider edge-hit if |lag| > 0.9*max
TAPER_WINDOWS            = True    # Hann taper in _prep_for_corr
NAN_INTERP_FOR_CORR      = True    # fill NaNs linearly before corr
AGREE_TOL_MS             = 30.0    # xcorr & gcc-phat should agree within this

# Output formatting
FLOAT_DECIMALS        = 4
KEEP_LSL_TIMESTAMP    = False
WRITE_JSON_LOG        = True
SCALE_EMG_TO_uV       = False      # keep centered raw or mV as-is

# Optional soft prior (tune for your rig)
EMG_EEG_PRIOR_MS = -28.0   # ms (if you have a known offset)
PRIOR_STRENGTH   = 0.6     # 0..1

# Silences ALL warnings process-wide, including the warnings.warn() calls
# issued by this script itself.
warnings.filterwarnings("ignore")

# -------------------------
# Utilities
# -------------------------
def load_and_prepare(path: str | Path) -> pd.DataFrame:
    """Load one CSV and return it with a clean, strictly increasing time base.

    The ``Timestamp_seconds`` column is used when present. Otherwise it is
    derived from the first available millisecond column (``Timestamp``,
    ``iMotions_Synchronization_Timestamp``,
    ``iMotions_Synchronization_Timestamp(ms)``), or as a last resort from
    ``ET_TimeSignal`` (treated as milliseconds when its maximum exceeds 1e4).

    Rows whose timestamp cannot be parsed are dropped, the frame is sorted by
    time with a stable sort, and only the first row of each duplicated
    timestamp (in original file order) is kept.

    Args:
        path: CSV path (or anything ``pandas.read_csv`` accepts).

    Returns:
        The prepared DataFrame, including a numeric ``Timestamp_seconds``.

    Raises:
        ValueError: If no usable timestamp column exists.
    """
    df = pd.read_csv(path, encoding="utf-8-sig")
    if 'Timestamp_seconds' in df.columns:
        df['Timestamp_seconds'] = pd.to_numeric(df['Timestamp_seconds'], errors='coerce')
    else:
        # Fallbacks only if Timestamp_seconds is missing (shouldn't happen)
        ts_candidates = [
            'Timestamp', 'iMotions_Synchronization_Timestamp', 'iMotions_Synchronization_Timestamp(ms)'
        ]
        ts_col = next((c for c in ts_candidates if c in df.columns), None)
        if ts_col is not None:
            df['Timestamp_seconds'] = pd.to_numeric(df[ts_col], errors='coerce') / 1000.0
        elif 'ET_TimeSignal' in df.columns:
            v = pd.to_numeric(df['ET_TimeSignal'], errors='coerce')
            df['Timestamp_seconds'] = (v/1000.0) if v.dropna().max() > 1e4 else v.astype(float)
            warnings.warn("Using ET_TimeSignal as fallback timebase.")
        else:
            raise ValueError("No global timestamp column found.")

    # A row without a timestamp cannot be placed on the timeline; keeping it
    # would leave a NaN in the time column (and in a "native" output grid).
    df = df.dropna(subset=['Timestamp_seconds'])
    # Stable sort so that drop_duplicates deterministically keeps the row that
    # appeared first in the file for each repeated timestamp.
    df = df.sort_values('Timestamp_seconds', kind='mergesort').drop_duplicates('Timestamp_seconds')
    return df

def split_modalities(df: pd.DataFrame):
    """Pick the EMG, EEG and eye-tracking columns present in ``df``.

    Returns ``(emg, eeg, et, emg_cols, eeg_cols, et_cols)``. Each frame is
    ``df`` itself when the modality has at least one column, otherwise an
    empty (zero-row) copy with the same columns.
    """
    emg_names = ("Ch1 EMG raw", "Ch2 EMG raw", "Ch3 EMG raw", "Ch4 EMG raw")
    eeg_names = tuple(f"Ch{i}" for i in range(1, 9))

    emg_cols = [name for name in emg_names if name in df.columns]
    eeg_cols = [name for name in eeg_names if name in df.columns]
    et_cols = [name for name in df.columns if name.startswith("ET_")]

    def frame_for(cols):
        # Whole frame when the modality exists, zero-row copy otherwise.
        if cols:
            return df
        return df.iloc[0:0].copy()

    return (frame_for(emg_cols), frame_for(eeg_cols), frame_for(et_cols),
            emg_cols, eeg_cols, et_cols)

def estimate_fs(ts: np.ndarray) -> float:
    """Estimate the sampling rate (Hz) as 1 / median sample interval.

    Only intervals strictly between 0 and 1 second are considered. Returns
    0.0 when ``ts`` is None, has fewer than 3 samples, or has no usable
    interval.
    """
    if ts is None or len(ts) < 3:
        return 0.0
    steps = np.diff(ts)
    usable = steps[(steps > 0) & (steps < 1.0)]
    if len(usable) == 0:
        return 0.0
    return 1.0 / np.median(usable)

def _butter_band(x, fs, f1, f2, order=BP_ORDER):
    """Zero-phase Butterworth band-pass of ``x`` between ``f1`` and ``f2`` Hz.

    Best effort: ``x`` is returned unchanged when the band is not strictly
    inside (0, fs/2) or when filter design/filtering raises.
    """
    result = x
    try:
        nyquist = fs / 2.0
        if 0 < f1 and f1 < f2 and f2 < nyquist:
            coeffs = signal.butter(order, [f1 / nyquist, f2 / nyquist], btype='band')
            result = signal.filtfilt(coeffs[0], coeffs[1], x)
    except Exception:
        # Deliberate fallback: hand back the unfiltered input.
        result = x
    return result

def _butter_lp(x, fs, fc, order=BP_ORDER):
    """Zero-phase Butterworth low-pass of ``x`` with cutoff ``fc`` Hz.

    Best effort: ``x`` is returned unchanged when ``fc`` is not strictly
    inside (0, fs/2) or when filter design/filtering raises.
    """
    result = x
    try:
        nyquist = fs / 2.0
        if 0 < fc and fc < nyquist:
            coeffs = signal.butter(order, fc / nyquist, btype='low')
            result = signal.filtfilt(coeffs[0], coeffs[1], x)
    except Exception:
        # Deliberate fallback: hand back the unfiltered input.
        result = x
    return result

def _interp_vec(t_src, y_src, t_grid, kind='linear'):
    """Interpolate ``y_src(t_src)`` onto ``t_grid``.

    Samples with a non-finite time or value are ignored. Grid points outside
    the source range become NaN. With fewer than two usable samples the
    result is an all-NaN float array shaped like ``t_grid``.
    """
    usable = np.isfinite(t_src) & np.isfinite(y_src)
    if np.count_nonzero(usable) < 2:
        return np.full_like(t_grid, np.nan, dtype=float)
    resample = interp1d(
        t_src[usable], y_src[usable],
        kind=kind, bounds_error=False, fill_value=np.nan,
    )
    return resample(t_grid)

def _round_to_grid_ms(lag_ms: float) -> float:
    """Snap a lag in milliseconds to the nearest multiple of GRID_STEP_MS."""
    n_steps = np.round(lag_ms / GRID_STEP_MS)
    return float(n_steps * GRID_STEP_MS)

def _overlap_limits(t1, t2):
    """Return ``(start, end)`` of the time span covered by both ``t1`` and ``t2``.

    NaNs are ignored. ``end`` is smaller than ``start`` when the two ranges do
    not overlap.
    """
    latest_start = max(np.nanmin(t1), np.nanmin(t2))
    earliest_end = min(np.nanmax(t1), np.nanmax(t2))
    return latest_start, earliest_end

def _pick_energy_windows(t, x, ov_start, ov_end, n_bins=NWIN_BINS, n_select=NWIN_SELECT, win_sec=XCORR_WIN_SEC):
    """Choose the highest-variance analysis windows inside an overlap span.

    The span is trimmed by 0.5 s at both ends and split into ``n_bins`` bins.
    One candidate window starts at each bin's left edge and lasts ``win_sec``
    seconds, or up to the bin's right edge when the bin is shorter than that.
    Candidates with fewer than 10 samples, or whose variance is not finite,
    are skipped.

    Args:
        t: Sample times (seconds) of ``x``.
        x: Signal whose variance ranks the windows.
        ov_start, ov_end: Overlap span in seconds.
        n_bins: Number of bins / candidate windows.
        n_select: Number of windows to keep (at least one is returned).
        win_sec: Target window length in seconds.

    Returns:
        List of ``(start, end)`` tuples ordered by decreasing variance. When
        the span is no longer than ``win_sec + 1`` the single trimmed span is
        returned (it may be inverted if the span is shorter than 1 s; callers
        then find too few samples and skip it).
    """
    lo = ov_start + 0.5
    hi = ov_end - 0.5
    if ov_end - ov_start <= win_sec + 1.0:
        return [(lo, hi)]

    edges = np.linspace(lo, hi, n_bins + 1)
    cand = []
    for i in range(n_bins):
        a = edges[i]
        b = edges[i+1] if (edges[i+1] - a) < win_sec else a + win_sec
        m = (t >= a) & (t <= b)
        if m.sum() < 10:
            continue
        v = np.nanvar(x[m])
        # A NaN variance (all-NaN window) has no defined rank; sorting on it
        # would make the selection order arbitrary, so drop the candidate.
        if not np.isfinite(v):
            continue
        cand.append((v, a, b))

    cand.sort(key=lambda z: z[0], reverse=True)
    wins = [(a, b) for _, a, b in cand[:max(1, n_select)]]
    return wins if wins else [(lo, min(hi, lo + win_sec))]

def _prep_for_corr(t, x, fs, seg_start, seg_end, bp=None):
    """Resample one window of ``x`` onto a uniform grid ready for correlation.

    Steps: interpolate onto ``int((seg_end - seg_start) * fs)`` evenly spaced
    points in ``[seg_start, seg_end)``, optionally fill NaNs linearly
    (NAN_INTERP_FOR_CORR), z-score, optionally band-pass with ``bp = (f1, f2)``,
    and optionally apply a Hann taper (TAPER_WINDOWS).

    Returns ``(grid_times, prepared_signal)``, or ``(None, None)`` when the
    window holds fewer than 10 source samples or fewer than 100 grid points.
    """
    in_window = (t >= seg_start) & (t <= seg_end)
    if in_window.sum() < 10:
        return None, None
    n_points = int((seg_end - seg_start) * fs)
    if n_points < 100:
        return None, None

    grid = np.linspace(seg_start, seg_end, n_points, endpoint=False)
    sig = _interp_vec(t[in_window], x[in_window], grid)

    # Fill NaNs BEFORE standardizing
    if NAN_INTERP_FOR_CORR:
        finite = np.isfinite(sig)
        if finite.sum() < 2:
            sig = np.nan_to_num(sig)
        else:
            positions = np.arange(len(sig))
            filler = interp1d(
                positions[finite], sig[finite], bounds_error=False,
                fill_value=(sig[finite][0], sig[finite][-1]),
            )
            sig = filler(positions)

    # Standardize; any NaN that is still present becomes 0
    centre = np.nanmean(sig)
    spread = np.nanstd(sig)
    sig = np.nan_to_num((sig - centre) / (spread + 1e-12))

    # Bandpass if requested
    if bp is not None:
        sig = _butter_band(sig, fs, bp[0], bp[1])

    # Taper to reduce edge energy
    if TAPER_WINDOWS:
        taper = np.hanning(len(sig))
        if taper.max() > 0:
            sig = sig * taper

    return grid, sig

def _bounded_peak(lags, corr, fs, lag_max_ms):
    """Find the largest-magnitude correlation value within ±``lag_max_ms``.

    Returns ``(lag_in_samples, peak_value, peak_ratio)`` where the peak ratio
    is ``|peak|`` divided by the median of ``|corr|`` over the bounded range.
    All three are NaN when no lag falls inside the bound.
    """
    limit = int(np.floor((lag_max_ms / 1000.0) * fs))
    inside = np.abs(lags) <= limit
    if not inside.any():
        return np.nan, np.nan, np.nan

    lag_sel = lags[inside]
    corr_sel = corr[inside]

    # Peak chosen by |c|, and the ratio is computed on |c| as well
    magnitude = np.abs(corr_sel)
    k = int(np.argmax(magnitude))
    peak = corr_sel[k]
    ratio = float(np.abs(peak) / (np.median(magnitude) + 1e-12))

    return float(lag_sel[k]), float(peak), ratio

def _xcorr_lag(th, x1h, x2h, fs, lag_max_ms):
    """Lag (ms) of the bounded cross-correlation peak between two signals.

    ``th`` is accepted for call-site symmetry but is not used. Returns
    ``(lag_ms, peak_ratio)``, or ``(nan, nan)`` when no lag lies within
    ±``lag_max_ms``.
    """
    xc = signal.correlate(x1h, x2h, mode='full')
    shifts = signal.correlation_lags(len(x1h), len(x2h), mode='full')
    best_shift, _peak, ratio = _bounded_peak(shifts, xc, fs, lag_max_ms)
    if np.isfinite(best_shift):
        return (best_shift / fs) * 1000.0, ratio
    return np.nan, np.nan

def _gcc_phat_lag(x1h, x2h, fs, lag_max_ms):
    """Lag (ms) of the bounded GCC-PHAT peak between two signals.

    The cross-spectrum ``X1 * conj(X2)`` is whitened to unit magnitude before
    the inverse FFT, so only phase contributes to the correlation. The
    circular result is unwrapped onto the same lag axis as
    ``scipy.signal.correlation_lags(len(x1h), len(x2h), mode='full')``, i.e.
    lags ``-(len(x2h) - 1) .. len(x1h) - 1``. This matters only for inputs of
    different lengths; for equal lengths the axis is symmetric.

    Args:
        x1h, x2h: Uniformly sampled signals at rate ``fs``.
        fs: Sampling rate in Hz.
        lag_max_ms: Search bound; only lags within ± this value are considered.

    Returns:
        ``(lag_ms, peak_ratio)``, or ``(nan, nan)`` when either input is empty
        or no lag lies within the bound.
    """
    n1, n2 = len(x1h), len(x2h)
    if n1 == 0 or n2 == 0:
        return np.nan, np.nan

    n = int(2**np.ceil(np.log2(n1 + n2 - 1)))
    X1 = np.fft.rfft(x1h, n=n)
    X2 = np.fft.rfft(x2h, n=n)
    R = X1 * np.conj(X2)
    R /= np.abs(R) + 1e-12  # PHAT weighting: keep phase, discard magnitude
    corr = np.fft.irfft(R, n=n)

    # corr[k] holds lag k for k >= 0 and lag k - n for the wrapped tail.
    # Explicit start index (instead of a negative slice) stays correct when
    # n2 == 1, where there are no negative lags at all.
    corr = np.concatenate((corr[n - (n2 - 1):], corr[:n1]))
    lags = np.arange(-(n2 - 1), n1)

    lag_samp, _peak, pr = _bounded_peak(lags, corr, fs, lag_max_ms)
    if not np.isfinite(lag_samp):
        return np.nan, np.nan
    return (lag_samp / fs) * 1000.0, pr

def _corr_at_specific_lag(t1, x1, t2, x2, lag_ms, fs, seg_start, seg_end):
    """Pearson correlation of ``x1`` and ``x2`` after undoing a lag of ``lag_ms``.

    Sign convention: the lag comes from ``_xcorr_lag(…, x1, x2, …)``, where a
    positive value means ``x1(t + lag) ≈ x2(t)`` (x1 is delayed relative to
    x2). To compare aligned signals, x1 is therefore read ``lag_ms`` LATER
    than the evaluation grid. This matches ``apply_shift_and_interp``, which
    subtracts the lag from the source timestamps.

    Only grid points where both interpolated signals are finite enter the
    correlation, so a few out-of-range samples at the window edge do not
    invalidate the result.

    Args:
        t1, x1: Times and values of the signal to shift.
        t2, x2: Times and values of the reference signal.
        lag_ms: Lag to undo, in milliseconds.
        fs: Rate of the uniform evaluation grid in Hz.
        seg_start, seg_end: Window bounds in seconds.

    Returns:
        The correlation coefficient as a float; 0.0 when it is undefined
        (e.g. a constant signal); NaN when the window cannot be prepared or
        fewer than two jointly finite samples exist.
    """
    th, _ = _prep_for_corr(t1, x1, fs, seg_start, seg_end, None)
    if th is None:
        return np.nan

    x2h = _interp_vec(t2, x2, th)
    x1hs = _interp_vec(t1, x1, th + (lag_ms / 1000.0))

    valid = np.isfinite(x1hs) & np.isfinite(x2h)
    if valid.sum() < 2:
        return np.nan

    # corrcoef is invariant to offset and scale, so no explicit z-scoring.
    c = np.nan_to_num(np.corrcoef(x1hs[valid], x2h[valid])[0, 1])
    return float(c)

# -------------- Anti-alias helpers --------------
def _fill_nan_linear(y):
    """Return a float copy of ``y`` with non-finite entries filled linearly.

    Interior gaps are interpolated by sample index; leading and trailing gaps
    take the first and last finite value. With fewer than two finite samples
    the result is ``np.nan_to_num`` of the input.
    """
    values = np.asarray(y, dtype=float)
    finite = np.isfinite(values)
    if finite.sum() < 2:
        return np.nan_to_num(values)

    known_idx = np.flatnonzero(finite)
    missing_idx = np.flatnonzero(~finite)
    filler = interp1d(
        known_idx, values[finite], bounds_error=False,
        fill_value=(values[finite][0], values[finite][-1]),
    )
    filled = values.copy()
    filled[~finite] = filler(missing_idx)
    return filled

def _aa_lowpass(y, fs_native, fs_target, order=AA_ORDER, margin=AA_MARGIN):
    """Zero-phase Butterworth low-pass applied before resampling to ``fs_target``.

    The cutoff is ``min(0.49, margin) * fs_target``. NaNs are filled linearly
    before filtering. ``y`` is returned untouched when ``fs_native`` is
    invalid, when the cutoff is at or above the native Nyquist frequency, or
    when filtering raises. The NaN-filled but unfiltered signal is returned
    when the input is too short to pad.
    """
    if not (np.isfinite(fs_native) and fs_native > 0):
        return y

    cutoff = min(0.49 * fs_target, margin * fs_target)
    nyquist = fs_native / 2.0
    if cutoff >= nyquist:
        return y

    dense = _fill_nan_linear(y)
    try:
        b, a = signal.butter(order, cutoff / nyquist, btype='low')
        # filtfilt needs padlen < len(x); shrink the default for short inputs
        pad = max(0, min(3 * max(len(a), len(b)), len(dense) - 1))
        if pad < 1:
            return dense
        return signal.filtfilt(b, a, dense, padlen=pad)
    except Exception:
        return y

# -------------------------
# EMG↔EEG lag (robust)
# -------------------------
def _median_timestamp_fallback_ms(emg: pd.DataFrame, eeg: pd.DataFrame) -> float:
    """Fallback lag: median-timestamp difference (EMG − EEG), clipped and grid-rounded."""
    s1 = np.nanmedian(emg['Timestamp_seconds'])
    s2 = np.nanmedian(eeg['Timestamp_seconds'])
    lag = float(np.clip((s1 - s2) * 1000.0, -FALLBACK_MAX_ABS_MS, FALLBACK_MAX_ABS_MS))
    return _round_to_grid_ms(lag)


def estimate_emg_eeg_lag(emg: pd.DataFrame, eeg: pd.DataFrame, emg_cols, eeg_cols):
    """Estimate the EMG→EEG lag in milliseconds, rounded to the output grid.

    Procedure:
      1. Build a rectified, low-passed envelope for every EMG channel and keep
         the one with the largest variance.
      2. Pick high-variance windows of that envelope inside the EMG/EEG overlap.
      3. Choose the EEG channel whose band-passed signal gives the highest
         median cross-correlation peak ratio against the envelope.
      4. For that channel, collect lags from plain cross-correlation and
         GCC-PHAT in each window, keep those within AGREE_TOL_MS of the median,
         add the optional prior, and take the weighted median.
      5. Reset to 0 if the result is larger than EMGEEG_LAG_ABS_REJECT_MS or
         at least two estimates sat near the search boundary.

    Args:
        emg, eeg: Frames containing ``Timestamp_seconds`` and the channel columns.
        emg_cols, eeg_cols: Channel column names to consider.

    Returns:
        ``(lag_ms, median_peak_ratio, chosen_eeg_channel)``.
        * ``(0.0, nan, None)`` when a modality is missing or has < 10 rows.
        * A clipped median-timestamp fallback lag with ``nan`` ratio when no
          envelope/channel can be scored or fewer than MIN_ACCEPTED lag
          estimates were obtained.
        The peak ratio is the median weight of the accepted estimates only;
        the synthetic weight of the prior is not included.
    """
    if not emg_cols or not eeg_cols or len(emg) < 10 or len(eeg) < 10:
        return 0.0, np.nan, None

    tE = emg['Timestamp_seconds'].values.astype(float)
    fsE = estimate_fs(tE)

    # EMG envelope: pick the most energetic EMG raw channel
    best_env, best_var = None, -np.inf
    for c in emg_cols:
        x = pd.to_numeric(emg[c], errors='coerce').to_numpy(dtype=float)
        x = _butter_lp(np.abs(np.nan_to_num(x)), fsE, EMG_ENV_LP)
        v = np.nanvar(x[np.isfinite(x)])
        if v > best_var:
            best_var, best_env = v, x
    if best_env is None:
        # No envelope had a finite variance, so there is nothing to correlate.
        return _median_timestamp_fallback_ms(emg, eeg), np.nan, None

    tG = eeg['Timestamp_seconds'].values.astype(float)
    fsG = estimate_fs(tG)
    ov_start, ov_end = _overlap_limits(tE, tG)
    wins = _pick_energy_windows(tE, best_env, ov_start, ov_end)
    maxlag = EMGEEG_LAG_SEARCH_MAX_MS

    # The EMG side of every window is independent of the EEG channel under
    # test, so prepare it once instead of once per channel.
    emg_prepped = [_prep_for_corr(tE, best_env, CROSS_FS, a, b, None) for (a, b) in wins]

    # Choose EEG channel by highest median peak ratio across windows
    best_score, chosen_eeg, chosen_x = -np.inf, None, None
    for ec in eeg_cols:
        eeg_x = pd.to_numeric(eeg[ec], errors='coerce').to_numpy(dtype=float)
        eeg_x = _butter_band(np.nan_to_num(eeg_x), fsG, EEG_BP[0], EEG_BP[1])
        prs = []
        for (a, b), (th1, env) in zip(wins, emg_prepped):
            if th1 is None:
                continue
            th2, eegf = _prep_for_corr(tG, eeg_x, CROSS_FS, a, b, None)
            if th2 is None:
                continue
            _, pr1 = _xcorr_lag(th1, env, eegf, CROSS_FS, maxlag)
            prs.append(pr1)
        if prs:
            score = float(np.median(prs))
            if score > best_score:
                best_score, chosen_eeg, chosen_x = score, ec, eeg_x

    if chosen_eeg is None:
        return _median_timestamp_fallback_ms(emg, eeg), np.nan, None

    # With the chosen EEG channel, collect lags from both estimators
    lags, weights = [], []
    boundary_hits = 0
    for (a, b), (th1, env) in zip(wins[:NWIN_SELECT], emg_prepped[:NWIN_SELECT]):
        if th1 is None:
            continue
        th2, eegf = _prep_for_corr(tG, chosen_x, CROSS_FS, a, b, None)
        if th2 is None:
            continue

        lagA, prA = _xcorr_lag(th1, env, eegf, CROSS_FS, maxlag)
        lagB, prB = _gcc_phat_lag(env, eegf, CROSS_FS, maxlag)

        for lag, pr in ((lagA, prA), (lagB, prB)):
            if not np.isfinite(lag) or not np.isfinite(pr):
                continue
            lags.append(lag)
            weights.append(max(pr, 1.0))
            if abs(lag) > BOUNDARY_HIT_FRACTION * maxlag:
                boundary_hits += 1

    # Not enough evidence → small fallback
    if len(lags) < MIN_ACCEPTED:
        return _median_timestamp_fallback_ms(emg, eeg), np.nan, chosen_eeg

    # Require estimator agreement around the median
    l = np.array(lags, float)
    w = np.array(weights, float)
    agree = np.abs(l - np.median(l)) <= AGREE_TOL_MS
    if agree.any():
        l = l[agree]
        w = w[agree]

    # Report the median weight of the measured estimates only; the prior's
    # weight below is synthetic and would distort this figure.
    pr_median = float(np.median(w)) if len(w) else np.nan

    # Add prior
    if EMG_EEG_PRIOR_MS is not None and np.isfinite(EMG_EEG_PRIOR_MS):
        prior_weight = PRIOR_STRENGTH * (np.max(w) if len(w) else 1.0)
        l = np.append(l, EMG_EEG_PRIOR_MS)
        w = np.append(w, prior_weight)

    # Weighted median
    order = np.argsort(l)
    l = l[order]
    w = w[order]
    cw = np.cumsum(w) / np.sum(w)
    use_lag = float(l[min(np.searchsorted(cw, 0.5), len(l) - 1)])

    # Final guards: boundary hit or absurd magnitude → fallback to 0
    if (abs(use_lag) > EMGEEG_LAG_ABS_REJECT_MS) or (boundary_hits >= 2):
        use_lag = 0.0

    return _round_to_grid_ms(use_lag), pr_median, chosen_eeg

# -------------------------
# EEG↔ET lag (bounded + conservative)
# -------------------------
def _et_velocity(et: pd.DataFrame, prefer=('ET_GazeLeftx','ET_GazeRightx','ET_GazeDirectionX','ET_GazeX')):
    """Finite-difference velocity of the first preferred gaze column found.

    Returns ``(timestamps, velocity)``. ``velocity[0]`` is NaN, and so is any
    sample whose time step is not positive. Returns ``(None, None)`` when no
    preferred column with at least 3 rows exists.
    """
    for column in prefer:
        if column not in et.columns:
            continue
        gaze = pd.to_numeric(et[column], errors='coerce').to_numpy(dtype=float)
        stamps = et['Timestamp_seconds'].values.astype(float)
        if len(gaze) < 3:
            continue
        step = np.diff(stamps)
        step[step <= 0] = np.nan  # avoid division by zero / negative dt
        velocity = np.full(gaze.shape, np.nan, dtype=float)
        velocity[1:] = np.diff(gaze) / step
        return stamps, velocity
    return None, None

def estimate_eeg_et_lag(eeg: pd.DataFrame, et: pd.DataFrame, eeg_ch: str | None):
    """Estimate the eye-tracker→EEG lag in ms, returning 0.0 unless well supported.

    Gaze velocity is cross-correlated with the band-passed EEG channel
    ``eeg_ch`` in the single highest-variance window of the overlap. The lag
    is accepted only if the peak ratio reaches ET_XCORR_MIN_PR, the lag is
    within ET_XCORR_MAX_ABS_MS, and the correlation at that lag beats the
    zero-lag correlation by more than ET_REVERT_MARGIN. The accepted lag is
    rounded to the output grid.
    """
    if eeg_ch is None or len(eeg) < 10 or len(et) < 10:
        return 0.0

    eeg_t = eeg['Timestamp_seconds'].values.astype(float)
    eeg_fs = estimate_fs(eeg_t)
    eeg_sig = pd.to_numeric(eeg[eeg_ch], errors='coerce').to_numpy(dtype=float)
    eeg_sig = _butter_band(np.nan_to_num(eeg_sig), eeg_fs, EEG_BP[0], EEG_BP[1])

    et_t, gaze_vel = _et_velocity(et)
    if et_t is None:
        return 0.0
    et_fs = estimate_fs(et_t)
    gaze_vel = _butter_band(np.nan_to_num(gaze_vel), et_fs, ET_VEL_BP[0], ET_VEL_BP[1])

    ov_start, ov_end = _overlap_limits(et_t, eeg_t)
    windows = _pick_energy_windows(et_t, gaze_vel, ov_start, ov_end,
                                   n_bins=3, n_select=1, win_sec=XCORR_WIN_SEC)
    w_start, w_end = windows[0]
    grid_et, vel_prep = _prep_for_corr(et_t, gaze_vel, CROSS_FS, w_start, w_end, None)
    grid_eeg, eeg_prep = _prep_for_corr(eeg_t, eeg_sig, CROSS_FS, w_start, w_end, None)
    if grid_et is None or grid_eeg is None:
        return 0.0

    lag_ms, pr = _xcorr_lag(grid_et, vel_prep, eeg_prep, CROSS_FS, ET_XCORR_MAX_ABS_MS)

    # Conservative checks
    if not (np.isfinite(lag_ms) and np.isfinite(pr)):
        return 0.0
    if pr < ET_XCORR_MIN_PR or abs(lag_ms) > ET_XCORR_MAX_ABS_MS:
        return 0.0

    corr_zero = _corr_at_specific_lag(et_t, gaze_vel, eeg_t, eeg_sig, 0.0, CROSS_FS, w_start, w_end)
    corr_shifted = _corr_at_specific_lag(et_t, gaze_vel, eeg_t, eeg_sig, lag_ms, CROSS_FS, w_start, w_end)
    if not (np.isfinite(corr_zero) and np.isfinite(corr_shifted)):
        return 0.0
    if corr_shifted <= corr_zero + ET_REVERT_MARGIN:
        return 0.0

    return _round_to_grid_ms(lag_ms)

# -------------------------
# Grid & interpolation (with AA)
# -------------------------
def make_time_grid(emg, eeg, et):
    """Build the output time axis (seconds) from the non-empty modalities.

    SPAN_MODE "overlap" uses the intersection of the modality time ranges,
    anything else their union. With GRID_MODE "fixed" the grid starts at the
    span start and steps by 1/FIXED_HZ, stopping before the span end.
    Otherwise the timestamps of the first non-empty modality (EMG, then EEG,
    then ET) are returned. An empty array signals that no valid span exists.
    """
    spans = [
        (float(np.nanmin(seg['Timestamp_seconds'])),
         float(np.nanmax(seg['Timestamp_seconds'])))
        for seg in (emg, eeg, et)
        if len(seg)
    ]
    if not spans:
        return np.array([])

    starts = [start for start, _ in spans]
    ends = [end for _, end in spans]
    if SPAN_MODE == "overlap":
        tmin, tmax = max(starts), min(ends)
    else:
        tmin, tmax = min(starts), max(ends)

    if not (np.isfinite(tmin) and np.isfinite(tmax) and tmax > tmin):
        return np.array([])

    if GRID_MODE == "fixed":
        n_rows = int(np.floor((tmax - tmin) * FIXED_HZ))
        return tmin + np.arange(n_rows) / FIXED_HZ

    # Native grid: reuse the timestamps of the first modality that has data
    for seg in (emg, eeg, et):
        if len(seg):
            return seg['Timestamp_seconds'].values
    return np.array([])

def apply_shift_and_interp(seg, t_col, cols, t_grid, lag_ms=0.0, kind='linear'):
    """Shift a modality in time and resample its columns onto ``t_grid``.

    The source timestamps are moved earlier by ``lag_ms`` before
    interpolation. Continuous columns are low-pass filtered first when the
    fixed grid downsamples them (native fs > FIXED_HZ * AA_TOL). Columns
    listed in AA_DISCRETE_COLS are never filtered and use nearest-neighbour
    interpolation.

    Returns:
        out_dict, aa_list
    out_dict: {col_name: resampled_vector}
    aa_list:  [channel names that were anti-aliased]
    """
    if len(seg)==0 or len(t_grid)==0 or not cols:
        return {}, []

    resampled = {}
    filtered_names = []
    native_t = seg[t_col].values.astype(float)
    fs_native = estimate_fs(native_t)
    shifted_t = native_t - (lag_ms / 1000.0)

    # Column-independent part of the anti-alias decision
    downsampling = (AA_ENABLE and GRID_MODE == "fixed"
                    and np.isfinite(fs_native) and fs_native > FIXED_HZ * AA_TOL)

    for name in cols:
        if name not in seg.columns:
            continue
        values = pd.to_numeric(seg[name], errors='coerce').to_numpy(dtype=float)
        is_discrete = name in AA_DISCRETE_COLS

        # Anti-alias if downsampling and column is continuous
        use_aa = bool(downsampling) and not is_discrete
        if use_aa:
            values = _aa_lowpass(values, fs_native, FIXED_HZ)

        # Optional EMG scaling (off by default)
        if SCALE_EMG_TO_uV and ('EMG' in name):
            values = values * 1000.0

        interp_kind = 'nearest' if is_discrete else kind
        resampled[name] = _interp_vec(shifted_t, values, t_grid, kind=interp_kind)
        if use_aa:
            filtered_names.append(name)

    # Optional LSL passthrough
    if KEEP_LSL_TIMESTAMP and ('LSL Timestamp' in seg.columns) and (t_col=='Timestamp_seconds'):
        lsl = pd.to_numeric(seg['LSL Timestamp'], errors='coerce').to_numpy(dtype=float)
        resampled['EEG_LSL_Timestamp'] = _interp_vec(shifted_t, lsl, t_grid, 'linear')

    return resampled, filtered_names

# -------------------------
# Per-file processing
# -------------------------
def process_one(file_path, out_dir):
    """Synchronize one input CSV and write the result (plus optional JSON log).

    Steps: load the file, split it into EMG/EEG/ET, estimate the EMG→EEG and
    ET→EEG lags, build the output time grid, shift and resample every
    modality onto it, and write ``<stem>_synchronized_corrected.csv`` (and
    ``.json`` when WRITE_JSON_LOG is set) into ``out_dir``.

    Args:
        file_path: Path of the input CSV.
        out_dir: Directory for the outputs; created if missing.

    Returns:
        A dict with the file name, row counts, native sampling rates, the
        estimated lags, the chosen EEG channel and the number of anti-aliased
        channels per modality.

    Raises:
        RuntimeError: If no valid time grid can be built.
    """
    filename = os.path.basename(file_path)
    stem = Path(filename).stem
    print("\n" + "="*70); print(f"Processing: {filename}"); print("="*70)

    df = load_and_prepare(file_path)
    emg, eeg, et, emg_cols, eeg_cols, et_cols = split_modalities(df)
    print(f"Samples → EMG:{len(emg)} EEG:{len(eeg)} ET:{len(et)}")
    fs_emg = estimate_fs(emg['Timestamp_seconds'].values) if len(emg) else 0
    fs_eeg = estimate_fs(eeg['Timestamp_seconds'].values) if len(eeg) else 0
    fs_et  = estimate_fs(et['Timestamp_seconds'].values)  if len(et)  else 0
    print(f"Native fs → EMG={fs_emg:.1f} Hz  EEG={fs_eeg:.1f} Hz  ET={fs_et:.1f} Hz")

    emg_eeg_lag, pr_med, best_eeg_ch = estimate_emg_eeg_lag(emg, eeg, emg_cols, eeg_cols)
    pr_valid = pr_med is not None and np.isfinite(pr_med)
    pr_show = pr_med if pr_valid else np.nan
    print(f"EMG–EEG lag (ms): {emg_eeg_lag:+.1f}  [median_PR={pr_show}]  | EEG*={best_eeg_ch}")

    eeg_et_lag = estimate_eeg_et_lag(eeg, et, best_eeg_ch)
    print(f"EEG–ET  lag (ms): {eeg_et_lag:+.1f}  [{'xcorr' if eeg_et_lag!=0 else 'safe=0'}]")

    # Derived EMG–ET lag (relative to EEG reference)
    emg_et_lag = float(emg_eeg_lag - eeg_et_lag)

    t_grid = make_time_grid(emg, eeg, et)
    if len(t_grid) == 0:
        raise RuntimeError("Empty time grid.")
    # Report the grid that was actually built rather than always "FIXED".
    if GRID_MODE == "fixed":
        print(f"Using FIXED grid @ {FIXED_HZ} Hz: {len(t_grid)} rows.")
    else:
        print(f"Using NATIVE grid: {len(t_grid)} rows.")

    out = {'Timestamp_seconds': t_grid, 'Timestamp_ms': t_grid*1000.0}
    aa_log = {"emg": [], "eeg": [], "et": []}

    # EEG is the reference (lag 0); EMG and ET are shifted onto it.
    # Original column names are kept for every modality.
    plan = (
        ("emg", "EMG channels", emg, emg_cols, emg_eeg_lag),
        ("eeg", "EEG channels", eeg, eeg_cols, 0.0),
        ("et",  "ET columns",   et,  et_cols,  eeg_et_lag),
    )
    for key, label, seg, cols, lag in plan:
        if not cols:
            continue
        interp, aa_names = apply_shift_and_interp(seg, 'Timestamp_seconds', cols, t_grid, lag_ms=lag)
        out.update(interp)
        aa_log[key].extend(aa_names)
        print(f"Added {len(interp)} {label}  (AA: {len(aa_names)})")

    synced = pd.DataFrame(out).dropna(axis=1, how='all')
    for c in synced.select_dtypes(include=['float64','float32']).columns:
        synced[c] = synced[c].round(FLOAT_DECIMALS)

    os.makedirs(out_dir, exist_ok=True)
    out_csv = os.path.join(out_dir, f"{stem}_synchronized_corrected.csv")
    synced.to_csv(out_csv, index=False, float_format=f"%.{FLOAT_DECIMALS}f")
    try:
        sz = os.path.getsize(out_csv) / (1024*1024)
        print(f"Saved: {out_csv} ({sz:.1f} MB)")
    except OSError:
        # The size is informational only; never fail the run over it.
        print(f"Saved: {out_csv}")

    # -------- JSON lag log --------
    if WRITE_JSON_LOG:
        log = {
            "input_file": str(Path(file_path).resolve()),
            "output_csv": str(Path(out_csv).resolve()),
            "grid_mode": GRID_MODE,
            "grid_hz": FIXED_HZ,
            "span_mode": SPAN_MODE,
            "lag_emg_eeg_ms": float(emg_eeg_lag),
            "lag_eeg_eye_ms": float(eeg_et_lag),
            "lag_emg_eye_ms": float(emg_et_lag),
            "chosen_eeg_channel": best_eeg_ch if best_eeg_ch is not None else "",
            "median_pr_emg_eeg": (float(pr_med) if pr_valid else None),
            "aa_applied_channels": aa_log,
            "log_units": {
                "Timestamp_seconds": "s",
                "Timestamp_ms": "ms",
                "EMG": ("µV" if SCALE_EMG_TO_uV else "raw_counts_or_mV"),
                "EEG": "µV",
                "ET": "normalized [0..1] for gaze; mm for pupils/dist; deg/deg/s/m/s² for IMU/head"
            },
            "notes": "Lags are shifts applied BEFORE interpolation to align with EEG reference. Positive means source lags EEG; we subtract lag when aligning."
        }
        out_json = os.path.join(out_dir, f"{stem}_synchronized_corrected.json")
        with open(out_json, "w", encoding="utf-8") as f:
            json.dump(log, f, indent=2)
        print(f"JSON lag log saved: {out_json}")

    return {
        'filename': filename,
        'rows_emg': len(emg), 'rows_eeg': len(eeg), 'rows_et': len(et),
        'grid_rows': len(t_grid),
        'fs_emg': fs_emg, 'fs_eeg': fs_eeg, 'fs_et': fs_et,
        'emg_eeg_lag_ms': float(emg_eeg_lag),
        'eeg_et_lag_ms': float(eeg_et_lag),
        'emg_et_lag_ms': float(emg_et_lag),
        'best_eeg_ch': best_eeg_ch,
        'aa_emg': len(aa_log["emg"]),
        'aa_eeg': len(aa_log["eeg"]),
        'aa_et': len(aa_log["et"])
    }

# -------------------------
# Batch
# -------------------------
def main():
    """Run the synchronizer over every CSV in INPUT_DIR and write a summary.

    Files whose name contains "_synchronized" (case-insensitive) are skipped.
    A failure on one file is reported and recorded in the summary; the batch
    continues with the next file.
    """
    print("Starting synchronization (lite-union++ v3)")
    print(f"GRID_MODE={GRID_MODE} | FIXED_HZ={FIXED_HZ} | SPAN_MODE={SPAN_MODE}")
    print("Input:", INPUT_DIR)
    print("Output:", OUTPUT_DIR)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    candidates = sorted(glob.glob(os.path.join(INPUT_DIR, "*.csv")))
    files = [
        path for path in candidates
        if '_synchronized' not in os.path.basename(path).lower()
    ]
    total = len(files)
    print(f"Found {total} files")

    rows = []
    for position, path in enumerate(files, 1):
        print(f"\n[{position}/{total}]")
        try:
            record = process_one(path, OUTPUT_DIR)
        except Exception as e:
            # Batch boundary: log the failure and keep going
            base = os.path.basename(path)
            print(f"❌ Error on {base}: {e}")
            record = {'filename': base, 'status': f'ERROR: {e}'}
        rows.append(record)

    if rows:
        summary_path = os.path.join(OUTPUT_DIR, "synchronization_summary.csv")
        pd.DataFrame(rows).to_csv(summary_path, index=False)
        print("Summary saved.")

    print("\n" + "="*80)
    print("SYNCHRONIZATION COMPLETE")
    print("="*80)

if __name__ == "__main__":
    main()
