### **ICLabel + MEEGkit**

In [None]:
# icalabel
from pathlib import Path
import numpy as np
from scipy.stats import median_abs_deviation
import mne
from mne_icalabel import label_components

# Configuration
# Input recording (named as the Meegkit-cleaned file) and the destination for
# the ICLabel-cleaned result, which is written into the same directory.
MEGKIT_FILE = Path("/input/data/path/meegkit_cleaned_eeg.fif")
OUTPUT_FILE = MEGKIT_FILE.with_name("file_icalabel_cleaned_eeg.fif")

# Per-label probability cut-offs: a component is excluded only when its
# predicted label is listed here and its probability exceeds the cut-off.
# Labels that are not listed are never excluded.
ICALABEL_THRESHOLDS = {
    "eye blink": 0.70,
    "heart beat": 0.80,
    "muscle artifact": 0.70,
    "line noise": 0.80,
    "channel noise": 0.80,
}

def log(msg: object) -> None:
    """Print *msg* to standard output.

    Single funnel for the script's progress messages, so the output
    mechanism can be changed in one place.
    """
    print(msg)

# ============================================================================
# BAD CHANNEL DETECTION FUNCTION (from ica_xtra.py)
# ============================================================================
def detect_bad_channels(raw: mne.io.Raw, mad_threshold: float = 10.0, min_amplitude_uv: float = 0.1):
    """Return the sorted names of EEG channels that look flat or noisy.

    Works on a copy of ``raw`` restricted to EEG channels, so ``raw`` itself
    is left untouched. The data are multiplied by 1e6 before analysis
    (volts to microvolts, given MNE's SI units).

    A channel is reported when either of the following holds:

    * flat  - its peak-to-peak amplitude is below ``min_amplitude_uv``;
    * noisy - its variance or its peak-to-peak amplitude lies more than
      ``mad_threshold`` robust z-scores *above* the across-channel median.
      The z-score uses the normal-consistent median absolute deviation.
      The test is one-sided: unusually quiet channels are only caught by
      the flat check.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to inspect.
    mad_threshold : float
        Robust z-score above which a channel counts as noisy.
    min_amplitude_uv : float
        Peak-to-peak amplitude (in the scaled units) below which a channel
        counts as flat.

    Returns
    -------
    list of str
        Unique channel names, sorted alphabetically.
    """
    eeg_only = raw.copy().pick("eeg")
    names = eeg_only.ch_names
    signal_uv = eeg_only.get_data() * 1e6

    peak_to_peak = np.ptp(signal_uv, axis=1)
    channel_var = np.var(signal_uv, axis=1)

    # Flat channels: essentially no signal swing over the whole recording.
    flagged = {name for name, swing in zip(names, peak_to_peak) if swing < min_amplitude_uv}

    # Noisy channels: robust outliers on either feature.
    for feature in (channel_var, peak_to_peak):
        spread = median_abs_deviation(feature, scale="normal", nan_policy="omit")
        if np.isnan(spread) or spread <= 1e-12:
            # No usable spread (e.g. all channels identical): skip this
            # feature rather than divide by ~0.
            continue
        robust_z = (feature - np.nanmedian(feature)) / spread
        flagged.update(name for name, score in zip(names, robust_z) if score > mad_threshold)

    return sorted(flagged)

# ============================================================================
# MAIN
# ============================================================================
# Pipeline: load -> mark bad channels -> average reference -> check filtering
# -> fit ICA -> classify components with ICLabel -> remove artifact components
# -> interpolate bad channels -> save.
log(f"Loading Meegkit-cleaned file: {MEGKIT_FILE}")
raw = mne.io.read_raw_fif(MEGKIT_FILE, preload=True)

# 1. Detect bad channels
# Merge with channels already flagged in the file: assigning only the freshly
# detected list would silently discard bads marked by an earlier step.
detected_bads = detect_bad_channels(raw, mad_threshold=10, min_amplitude_uv=0.1)
bads = sorted(set(raw.info['bads']) | set(detected_bads))
raw.info['bads'] = bads
if bads:
    log(f"Detected bad channels: {len(bads)}")
else:
    log("No bad channels detected.")

# 2. Apply CAR (excludes bads automatically)
raw = raw.set_eeg_reference('average', verbose=False)
log("✅ Re-applied average reference (excluding bad channels).")

# 3. Verify filtering
# ICA is sensitive to slow drifts, so the data must already be high-pass
# filtered at 1 Hz or above. Unfiltered data (highpass == 0) must be rejected.
# An explicit raise is used because `assert` is stripped under `python -O`.
highpass_hz = raw.info['highpass']
if highpass_hz < 1.0:
    raise RuntimeError(
        f"Data must be high-pass filtered ≥1 Hz before ICA (found {highpass_hz} Hz)"
    )
log("✅ Filtering verified.")

# 4. Fit ICA (excludes bads automatically)
# Picard with ortho=False, extended=True is used as the extended-Infomax
# equivalent that ICLabel expects.
log("Fitting ICA...")
ica = mne.preprocessing.ICA(
    n_components=0.99,
    method='picard',
    fit_params=dict(ortho=False, extended=True),
    random_state=99,
    max_iter='auto'
)
ica.fit(raw)
log(f"ICA fitted with {ica.n_components_} components.")

# 5. Run ICLabel
log("Running ICLabel...")
raw_eeg = raw.copy().pick("eeg")  # ← Modern .pick() syntax
labels_dict = label_components(raw_eeg, ica, method="iclabel")

# Exclude a component only when its predicted label has a configured cut-off
# and the predicted probability exceeds it. Other labels are always kept.
excluded = []
for i, (label, proba) in enumerate(zip(labels_dict["labels"], labels_dict["y_pred_proba"])):
    threshold = ICALABEL_THRESHOLDS.get(label.lower().strip())
    if threshold is not None and np.max(proba) > threshold:
        excluded.append(i)

# Indices come from enumerate, so they are already unique and ascending.
ica.exclude = excluded
log(f"ICLabel excluded components: {ica.exclude}")

# 6. Apply ICA
# NOTE: ICA.apply works in place, so `cleaned` and `raw` are the same object.
cleaned = ica.apply(raw)

# 7. Interpolate bad channels to restore full sensor set
# Skipped when nothing is marked bad, so the log does not claim an
# interpolation that never happened.
if bads:
    log("Interpolating bad channels...")
    cleaned.info['bads'] = bads
    cleaned.interpolate_bads(reset_bads=True)
    log(f"Interpolated {len(bads)} bad channels.")

# 8. Save final cleaned data
cleaned.save(OUTPUT_FILE, overwrite=True)
log(f"✅ Saved final cleaned data to: {OUTPUT_FILE}")