### **Minimalist EEG Preprocessing**

In [None]:
# Dependencies: MNE-Python for I/O and filtering; meegkit for ASR, STAR and SNS denoising
import sys
import warnings
from pathlib import Path
import numpy as np
from scipy.stats import median_abs_deviation
import mne
from meegkit import asr, star, sns

# Silence benign NumPy RuntimeWarnings raised during STAR's internal adaptation
warnings.filterwarnings("ignore", message="Mean of empty slice")
warnings.filterwarnings("ignore", message="invalid value encountered in divide")

# ============================================================================
# CONFIGURATION
# ============================================================================
SUBJECT = "sub-008"
TRIAL_ID = "trial003"
CROP_BASE_DIR = Path("/input/data/path")  # Replace with actual input data path
MEGKIT_OUTPUT_DIR = Path("/output/data/path/cleaned")  # Replace with actual output data path
GPS_FILE = Path("/montage/coordinates/file/ghw280_from_egig.gpsc")

# Numeric labels "1".."280" map to "E1".."E280"; the reference label maps to "Cz".
CHANNEL_RENAME_MAP = {str(idx): f"E{idx}" for idx in range(1, 281)}
CHANNEL_RENAME_MAP["REF CZ"] = "Cz"

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def log(msg: str) -> None:
    """Print ``msg`` to standard output.

    Single funnel for the script's progress messages, so the output sink can
    be changed in one place.
    """
    print(msg)

def parse_gpsc(filepath: Path):
    """Read named 3-D positions from a whitespace-delimited .gpsc file.

    Each usable line has a label followed by three numeric coordinates; any
    further fields on the line are ignored. Lines with fewer than four fields,
    or whose coordinate fields cannot be parsed as floats (e.g. headers), are
    skipped silently.

    Parameters
    ----------
    filepath : Path
        Location of the .gpsc file.

    Returns
    -------
    list of tuple
        ``(name, x, y, z)`` entries in file order, with coordinates as floats.
    """
    positions = []
    with open(filepath, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.split()
            if len(fields) < 4:
                continue
            label, *coord_tokens = fields[:4]
            try:
                coords = tuple(float(token) for token in coord_tokens)
            except ValueError:
                # Non-numeric coordinates: treat the line as header/noise.
                continue
            positions.append((label, *coords))
    return positions

def apply_channel_renaming(raw: mne.io.Raw) -> mne.io.Raw:
    """Rename channels of ``raw`` according to ``CHANNEL_RENAME_MAP``.

    Only map entries whose old name is actually present in ``raw.ch_names``
    are applied. The renaming happens in place and is logged; when nothing
    matches, ``raw`` is returned untouched and nothing is logged.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose channel names should be normalised.

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object that was passed in.
    """
    present = raw.ch_names
    renames = {}
    for old_name, new_name in CHANNEL_RENAME_MAP.items():
        if old_name in present:
            renames[old_name] = new_name

    if not renames:
        return raw

    raw.rename_channels(renames)
    log(f"Renamed {len(renames)} channels.")
    return raw

def apply_montage(raw: mne.io.Raw, gpsc_file: Path) -> mne.io.Raw:
    """Build a digitisation montage from a .gpsc file and attach it to ``raw``.

    Positions returned by ``parse_gpsc`` are centred on their mean and divided
    by 1000 (presumably millimetres to metres; confirm against the file's
    units). Entries named ``FidNz``, ``FidT9`` and ``FidT10`` are used as
    nasion, LPA and RPA when present, otherwise ``None`` is passed.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording that receives the montage (modified in place).
    gpsc_file : Path
        Path to the coordinates file.

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object, with the montage set.

    Raises
    ------
    ValueError
        If the file yields no parsable position entries.
    """
    entries = parse_gpsc(gpsc_file)
    if not entries:
        raise ValueError("No valid channels in .gpsc file")

    labels = [entry[0] for entry in entries]
    coords = np.array([entry[1:4] for entry in entries])
    # Centre on the centroid of all listed points, then rescale.
    centred = (coords - coords.mean(axis=0)) / 1000.0
    ch_pos = {label: row.copy() for label, row in zip(labels, centred)}

    nasion = ch_pos.get('FidNz')
    lpa = ch_pos.get('FidT9')
    rpa = ch_pos.get('FidT10')
    montage = mne.channels.make_dig_montage(
        ch_pos=ch_pos,
        nasion=nasion,
        lpa=lpa,
        rpa=rpa,
        coord_frame='head',
    )

    # Channels in ``raw`` without a position only trigger a warning.
    raw.set_montage(montage, on_missing='warn')
    log("Montage applied.")
    return raw

def detect_bad_channels(raw: mne.io.Raw, mad_threshold: float = 10.0, min_amplitude_uv: float = 0.1):
    """Return the names of EEG channels that look flat or abnormally noisy.

    Works on a copy of ``raw`` restricted to EEG channels, with the data
    multiplied by 1e6 (volts to microvolts under MNE's SI convention).

    * Flat: peak-to-peak amplitude below ``min_amplitude_uv``.
    * Noisy: per-channel variance or peak-to-peak amplitude whose robust
      z-score (deviation from the median divided by the normal-scaled MAD
      across channels) exceeds ``mad_threshold``. A feature is skipped when
      its MAD is NaN or not above 1e-12.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to inspect; it is not modified.
    mad_threshold : float
        Robust z-score above which a channel counts as noisy.
    min_amplitude_uv : float
        Peak-to-peak amplitude below which a channel counts as flat.

    Returns
    -------
    list of str
        Sorted, de-duplicated names of the flagged channels.
    """
    eeg_only = raw.copy().pick("eeg")
    names = eeg_only.ch_names
    microvolts = eeg_only.get_data() * 1e6

    peak_to_peak = np.ptp(microvolts, axis=1)
    channel_var = np.var(microvolts, axis=1)

    is_flat = peak_to_peak < min_amplitude_uv

    is_noisy = np.zeros(len(peak_to_peak), dtype=bool)
    for feature in (channel_var, peak_to_peak):
        spread = median_abs_deviation(feature, scale="normal", nan_policy="omit")
        if np.isnan(spread) or spread <= 1e-12:
            # Degenerate spread: a z-score would be meaningless for this feature.
            continue
        robust_z = (feature - np.nanmedian(feature)) / spread
        is_noisy |= robust_z > mad_threshold

    flagged = {
        name
        for name, flat, noisy in zip(names, is_flat, is_noisy)
        if flat or noisy
    }
    return sorted(flagged)

def find_cleanest_segment(raw: mne.io.Raw, duration_sec: float = 30.0, step_sec: float = 2.0):
    """Select the quietest fixed-length window of ``raw`` as calibration data.

    A window of ``duration_sec`` seconds slides over the recording in steps of
    ``step_sec`` seconds. Each window is summarised by the median across
    channels of its per-channel variance and peak-to-peak amplitude. Windows
    that are essentially flat (overall peak-to-peak below 1e-9) or have
    non-finite metrics are discarded. The remaining windows are ranked by the
    sum of the robust (median/MAD) z-scores of both metrics, and the window
    with the lowest score is returned.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to scan; it is not modified.
    duration_sec : float
        Window length in seconds.
    step_sec : float
        Stride between consecutive window starts in seconds. Strides shorter
        than one sample are clamped to one sample.

    Returns
    -------
    calib_data : numpy.ndarray
        Data of the selected window, channels by samples. The full recording
        is returned if it is shorter than one window, and the first window if
        no window is valid.
    start_time : float
        Start of the selected window in seconds (0.0 for both fallbacks).

    Raises
    ------
    ValueError
        If ``duration_sec`` spans less than one sample.
    """
    sfreq = raw.info["sfreq"]
    duration_samp = int(duration_sec * sfreq)
    if duration_samp <= 0:
        raise ValueError("duration_sec must span at least one sample")
    # A stride below one sample would otherwise truncate to 0 and break windowing.
    step_samp = max(1, int(step_sec * sfreq))
    total_samp = raw.n_times

    data_v = raw.get_data()

    if total_samp < duration_samp:
        log(f"⚠️ Trial too short ({total_samp / sfreq:.1f}s). Using full trial.")
        return data_v, 0.0

    # Keep each valid window's start sample next to its metrics. Windows can be
    # skipped, so a position in the metric lists is not a window number.
    starts = []
    variances = []
    amplitudes = []

    for start in range(0, total_samp - duration_samp + 1, step_samp):
        win = data_v[:, start:start + duration_samp]

        # Skip windows with no signal at all (e.g. zero-padded stretches).
        if np.ptp(win) < 1e-9:
            continue

        var_metric = np.median(np.var(win, axis=1))
        amp_metric = np.median(np.ptp(win, axis=1))

        if np.isfinite(var_metric) and np.isfinite(amp_metric):
            starts.append(start)
            variances.append(var_metric)
            amplitudes.append(amp_metric)

    if not starts:
        log("⚠️ No valid windows found. Using first segment as calibration.")
        return data_v[:, :duration_samp], 0.0

    variances = np.array(variances)
    amplitudes = np.array(amplitudes)

    def mad_zscore(x):
        """Robust z-score of ``x``; all zeros when the MAD is degenerate."""
        x = np.atleast_1d(x)
        x_clean = x[np.isfinite(x)]
        if len(x_clean) == 0:
            return np.zeros_like(x)
        med = np.median(x_clean)
        mad = median_abs_deviation(x_clean, scale='normal', nan_policy='omit')
        if not np.isfinite(mad) or mad == 0:
            return np.zeros_like(x)
        z = (x - med) / mad
        z[~np.isfinite(z)] = 0
        return z

    score = mad_zscore(variances) + mad_zscore(amplitudes)
    best_idx = int(np.argmin(score))
    # Map the position among valid windows back to the true sample offset.
    best_start = starts[best_idx]
    calib_data_v = data_v[:, best_start:best_start + duration_samp]
    start_time = best_start / sfreq
    log(f"✅ Cleanest segment at t={start_time:.1f}s (score={score[best_idx]:.2f})")
    return calib_data_v, start_time

# ============================================================================
# MAIN EXECUTION
# ============================================================================
if __name__ == "__main__":
    # End-to-end cleaning of one trial: load -> rename/montage -> filter ->
    # mark bad channels -> CAR -> ASR -> STAR -> SNS -> interpolate -> save.
    input_fif = CROP_BASE_DIR / SUBJECT / f"file_{SUBJECT}_{TRIAL_ID}_eeg.fif"
    output_dir = MEGKIT_OUTPUT_DIR / SUBJECT / TRIAL_ID
    output_dir.mkdir(parents=True, exist_ok=True)
    output_fif = output_dir / f"file_{SUBJECT}_{TRIAL_ID}_meegkit_cleaned_eeg.fif"

    log(f"Processing: {input_fif}")

    # 1. Load & prepare
    raw = mne.io.read_raw_fif(input_fif, preload=True)
    raw = apply_channel_renaming(raw)
    raw = apply_montage(raw, GPS_FILE)

    # 2. Filter: 1-100 Hz band-pass, then remove 60 Hz line noise
    raw = raw.filter(l_freq=1.0, h_freq=100.0, picks='eeg', n_jobs=-1, verbose=False)
    raw = raw.notch_filter(freqs=60, picks='eeg', method='spectrum_fit',
                           filter_length='auto', mt_bandwidth=1.0, p_value=0.05, verbose=False)

    # 3. Force real-valued data
    raw._data = np.real(raw._data).astype(np.float64)

    # 4. Drop Cz (reference channel, not part of 280 E-channels)
    if 'Cz' in raw.ch_names:
        raw.drop_channels(['Cz'])
        log("Dropped Cz (reference channel).")

    # 5. Detect and MARK bad channels (do NOT drop)
    bad_chs = detect_bad_channels(raw, mad_threshold=10.0, min_amplitude_uv=0.1)
    raw.info['bads'] = bad_chs
    log(f"Marked {len(bad_chs)} channels as bad.")

    # 6. Get indices of GOOD EEG channels (excludes bads automatically)
    good_idx = mne.pick_types(raw.info, eeg=True, exclude='bads')
    good_ch_names = [raw.ch_names[i] for i in good_idx]
    if len(good_idx) == 0:
        # Fail with a clear message rather than deep inside ASR/STAR.
        raise RuntimeError("No good EEG channels left after bad-channel detection.")
    log(f"Cleaning {len(good_ch_names)} good EEG channels.")

    # 7-8. Create temporary Raw object holding only the good channels
    info_good = mne.create_info(ch_names=good_ch_names, sfreq=raw.info['sfreq'], ch_types='eeg')
    raw_good = mne.io.RawArray(raw.get_data(picks=good_idx), info_good, verbose=False)
    raw_good.set_montage(raw.get_montage(), on_missing='ignore')

    # 9. Apply CAR on good channels only
    raw_good = raw_good.set_eeg_reference('average', verbose=False)
    log("CAR applied on good E-channels only.")

    # Read the re-referenced data back explicitly, so that ASR is calibrated
    # and applied on the same (average-referenced) signal. This must not
    # depend on RawArray sharing its buffer with the array it was built from.
    data_good = raw_good.get_data()  # (n_good, n_times)

    # 10. Auto-calibrate ASR on good channels
    calib_data, _ = find_cleanest_segment(raw_good)
    calib_data = np.real(calib_data).astype(np.float64)

    asr_model = asr.ASR(sfreq=raw_good.info['sfreq'], cutoff=3, estimator='oas')
    asr_model.fit(calib_data)

    # 11. Clean entire good data
    cleaned_good = asr_model.transform(data_good)
    cleaned_good = np.real(cleaned_good).astype(np.float64)

    # STAR (meegkit expects samples x channels, hence the transposes)
    denoised_t_ch, _, _ = star.star(cleaned_good.T, thresh=1.5, verbose=True)
    cleaned_good = denoised_t_ch.T

    # SNS (same samples x channels layout)
    denoised_t_ch, _ = sns.sns(cleaned_good.T, n_neighbors=8)
    cleaned_good = denoised_t_ch.T

    # 12. PUT CLEANED DATA BACK INTO FULL RAW OBJECT
    raw._data[good_idx, :] = cleaned_good

    # 13. INTERPOLATE BAD CHANNELS (final step)
    raw.interpolate_bads(reset_bads=True)
    log(f"Interpolated {len(bad_chs)} bad channels to restore full 280-channel set.")

    # 14. Save
    raw.save(output_fif, overwrite=True)
    log(f"✅ Saved cleaned 280-channel data to: {output_fif}")