Version 1 — the original GitHub version (with optional ICA).
Version 1: The "Simple Padder"
This version's philosophy is "get to 22 channels first (by zero-padding), then do everything else."

Load .edf file.

Pick EEG-only channels.

Clean channel names (the simple version).

Pad with Zeros: Calls pick_order_and_pad with pad_missing=True. This adds all-zero channels immediately to get to 22 CORE_CHS.

Set Montage: Assigns 10-20 locations to the 22 channels (including the new fake-zero ones).

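Common Average Reference (CAR): Applies set_eeg_reference('average'). Because padding happened first, the all-zero channels are included in the average. This dilutes the reference, and the padded channels end up carrying the negated average instead of staying zero.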

Notch Filter: 60 Hz.

Band-pass Filter: 0.5 Hz to 100 Hz.

(Optional) ICA: Only a placeholder. When enabled, an ICA is fitted but never applied (no ICLabel; the exclude/apply lines are commented out), so it does not change the data.

Resample: 250 Hz.

(Optional) Crop Controls: Trims the first 10 seconds of "control" files.

Epoch: Chops data into 2-second windows.

Artifact Rejection (Simple): Rejects epochs whose maximum peak-to-peak amplitude exceeds a single percentile (default: 95th) of that value, computed across all epochs of the recording.

Z-score: Normalizes each channel within each epoch.

Assign Labels.

preprocess core

In [None]:
from pathlib import Path
import re, pickle, numpy as np, mne
from mne.preprocessing import ICA
#python src\preprocess_batch.py --input_dir data_raw\DATA --output_dir data_pp --psd_dir figures\psd --max_patients 10 --pad-missing 
# Fixed 10–20 core channel layout (target topology for graphs/ML).
# 22 names in a fixed order. pick_order_and_pad() reorders every recording to
# this list (zero-padding absent channels when pad_missing=True), so index i
# refers to the same electrode in every output file and in the bool[22]
# present_mask. set_montage_for_corechs() falls back to approximate
# (FT9/FT10-like) coordinates for T1/T2.
CORE_CHS = ["Fp1","Fp2","F7","F3","Fz","F4","F8",
            "T1","T3","C3","Cz","C4","T4","T2",
            "T5","P3","Pz","P4","T6","O1","Oz","O2"]

def clean_channel_names(raw: mne.io.BaseRaw):
    """
    Normalize channel labels of `raw` in place (simple variant, no case fixing).

    Each label loses a leading "EEG" prefix (plus any whitespace right after
    it), every occurrence of "-LE" and then "-REF", and surrounding
    whitespace. Letter case is left untouched, so e.g. "FP1" stays "FP1".
    """
    def _normalize(label):
        text = re.sub(r'^(?:EEG\s*)', '', label)
        for token in ('-LE', '-REF'):  # removed anywhere in the label, in this order
            text = text.replace(token, '')
        return text.strip()

    renames = {label: _normalize(label) for label in raw.ch_names}
    raw.rename_channels(renames)

def pick_order_and_pad(raw: mne.io.BaseRaw, pad_missing: bool = True):
    """
    Restrict `raw` to CORE_CHS in CORE_CHS order; optionally zero-pad missing ones.

    Operates in place on `raw`. Channels not in CORE_CHS are dropped.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Preloaded recording whose channel names were already cleaned.
    pad_missing : bool
        If True, every CORE_CHS channel absent from `raw` is appended as an
        all-zero EEG channel, so `raw` ends up with exactly CORE_CHS.
        If False, only the CORE_CHS channels that exist are kept.

    Returns
    -------
    ordered_list : list of str
        Channel names of `raw` after the call: CORE_CHS when padding,
        otherwise the present subset. Always a new list, so callers may
        mutate it without corrupting the module-level CORE_CHS.
    present_mask : np.ndarray of bool, shape (len(CORE_CHS),)
        True where the CORE_CHS channel existed in `raw` before padding.

    Raises
    ------
    RuntimeError
        If no CORE_CHS channel is present and `pad_missing` is False.
    """
    existing = set(raw.ch_names)
    present = [ch for ch in CORE_CHS if ch in existing]  # in CORE_CHS order
    present_mask = np.array([ch in existing for ch in CORE_CHS], dtype=bool)

    if pad_missing:
        missing = [ch for ch in CORE_CHS if ch not in existing]
        if missing:
            data = np.zeros((len(missing), raw.n_times))
            info = mne.create_info(missing, sfreq=raw.info["sfreq"], ch_types="eeg")
            raw.add_channels([mne.io.RawArray(data, info)], force_update_info=True)
        # reorder_channels keeps only the listed channels, so this also drops
        # any non-CORE channel still in `raw`.
        raw.reorder_channels(CORE_CHS)
        return list(CORE_CHS), present_mask

    if not present:
        raise RuntimeError("No recognizable CORE_CHS channels found in this recording.")
    # Keep only the present subset, in CORE_CHS order (reorder_channels picks + orders).
    raw.reorder_channels(present)
    return present, present_mask

def set_montage_for_corechs(raw: mne.io.BaseRaw):
    """
    Assign 10–20 electrode positions to the channels of `raw`, in place.

    Positions come from MNE's 'standard_1020' template. They are looked up
    case-insensitively, so e.g. a channel named "FP1" receives the "Fp1"
    position. This matches the `match_case=False` passed to set_montage,
    which previously could not help because the dict lookup itself was
    case-sensitive. T1/T2, when not in the template, get fixed approximate
    coordinates copied from FT9/FT10.

    Channels with no known position are not added to the montage.
    set_montage is called with its default `on_missing` handling.
    """
    std_pos = mne.channels.make_standard_montage('standard_1020').get_positions()['ch_pos']
    template = {name.lower(): xyz for name, xyz in std_pos.items()}

    # Approximate positions for T1/T2 (FT9/FT10-like coordinates):
    # FT9 : [-0.0840759 0.0145673 -0.050429 ] FT10 : [ 0.0841131 0.0143647 -0.050538 ]
    fallback = {
        't1': (-0.0840759, 0.0145673, -0.050429),
        't2': (0.0841131, 0.0143647, -0.050538),
    }

    ch_pos = {}
    for ch in raw.ch_names:
        key = ch.lower()
        if key in template:        # template position wins, as before
            ch_pos[ch] = template[key]
        elif key in fallback:
            ch_pos[ch] = fallback[key]

    montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
    raw.set_montage(montage, match_case=False)

def infer_label_from_path(p: Path) -> int:
    """
    Infer the binary class label from a file path.

    Returns 1 when any *directory* component of `p` equals '00_epilepsy'
    (case-insensitive), otherwise 0 (control). Both '/' and '\\' count as
    separators, so Windows-style strings work on any OS.

    Relative paths such as '00_epilepsy/p1/t1.edf' are recognised. The old
    substring test '/00_epilepsy/' required a leading separator and
    labelled them as controls.
    """
    # Drop the last component (the file name itself); only folders carry the label.
    dirs = str(p).replace('\\', '/').lower().split('/')[:-1]
    return 1 if "00_epilepsy" in dirs else 0

def preprocess_single(
    edf_path: Path,
    notch: float = 60.0,
    band: tuple = (0.5, 100.0),
    resample_hz: float = 250.0,
    epoch_len: float = 2.0,
    epoch_overlap: float = 0.0,
    reject_percentile: float = 95.0,
    crop_first10_if_control: bool = True,
    ica_components=None,
    return_psd: bool = False,
    pad_missing: bool = True               # << default: pad with zeros for fixed topology
):
    """
    End-to-end preprocessing for a single EDF ("pad first" variant).

    Steps, in order:
      - Load EDF, keep EEG channels, clean names
      - Enforce CORE_CHS order; optionally zero-pad missing channels
      - Set montage (incl. T1/T2 approximations)
      - Common average reference, notch, band-pass
      - Optional ICA *fit* (scaffold only: nothing is excluded or applied)
      - Resample; optionally crop the first 10 s of control recordings
      - Fixed-length epochs, percentile-based peak-to-peak rejection
      - Per-epoch, per-channel z-scoring across time
      - Per-epoch labels from the path
      - (Optional) PSD before/after, over originally present channels only

    NOTE: padding happens *before* the average reference. The zero-padded
    channels therefore take part in the average: the reference is diluted,
    and the padded channels become the negated average instead of zero.

    Parameters
    ----------
    edf_path : path-like
        EDF file to process; its path also determines the label.
    notch : float
        Line-noise frequency in Hz for the notch filter.
    band : (float, float)
        Band-pass edges in Hz; band[1] is also the PSD fmax.
    resample_hz : float
        Target sampling rate.
    epoch_len, epoch_overlap : float
        Fixed-length epoch duration and overlap in seconds.
    reject_percentile : float
        Percentile of the per-epoch max peak-to-peak amplitude used as the
        rejection threshold.
    crop_first10_if_control : bool
        Drop the first 10 s of recordings labelled 0.
    ica_components : truthy or None
        When truthy, an ICA is fitted (see NOTE in the body).
    return_psd : bool
        Also return "psd_before"/"psd_after" spectra.
    pad_missing : bool
        Zero-pad absent CORE_CHS channels to a fixed 22-channel layout.

    Returns
    -------
    dict
        Keys "raw_after", "epochs", "labels", "threshold_uv",
        "present_channels" (fresh list), "present_mask" (True for
        originally present CORE_CHS) and, if `return_psd`, "psd_before"
        and "psd_after".

    Raises
    ------
    RuntimeError
        If no CORE_CHS channel is present, or if no epoch can be cut from
        the resampled/cropped recording.
    """
    edf_path = Path(edf_path)
    # Load raw data, keep EEG channels, normalise names
    raw_before = mne.io.read_raw_edf(str(edf_path), preload=True, verbose="ERROR")
    raw_before.pick_types(eeg=True)
    clean_channel_names(raw_before)
    raw = raw_before.copy()

    # --- Enforce CORE_CHS order; pad missing if requested ---
    if pad_missing and not any(ch in raw_before.ch_names for ch in CORE_CHS):
        raise RuntimeError("No CORE_CHS present in this recording; aborting to avoid all-zero data.")
    present, present_mask = pick_order_and_pad(raw, pad_missing=pad_missing)
    pick_order_and_pad(raw_before, pad_missing=pad_missing)  # mirror for PSD-before alignment

    # --- Montage consistent with CORE_CHS names ---
    set_montage_for_corechs(raw)

    # Common average reference (includes any zero-padded channels) and filtering
    raw.set_eeg_reference('average', projection=False)
    raw.notch_filter(freqs=[notch])
    raw.filter(l_freq=band[0], h_freq=band[1], fir_design='firwin', filter_length='auto')

    if ica_components:
        # NOTE: scaffold only. The ICA is fitted but no component is excluded
        # and it is never applied, so `raw` is unchanged. Presumably the
        # intent was `ica.exclude = sorted(set(ica_components)); ica.apply(raw)`
        # — TODO confirm before enabling.
        ica = ICA(n_components=min(20, len(raw.ch_names)), method='fastica', random_state=42)
        ica.fit(raw)

    # Resample & optional crop
    raw.resample(resample_hz, npad="auto")
    label = infer_label_from_path(edf_path)
    if crop_first10_if_control and label == 0:
        raw.crop(tmin=10.0)

    # Epoching + percentile-based peak-to-peak rejection
    epochs = mne.make_fixed_length_epochs(raw, duration=epoch_len, overlap=epoch_overlap, preload=True)
    X = epochs.get_data()
    if len(X) == 0:
        # Fail clearly here; np.percentile on an empty array errors obscurely.
        raise RuntimeError(f"No {epoch_len:g} s epochs could be created from {edf_path.name}.")
    max_ptp_uv = np.ptp(X, axis=2).max(axis=1) * 1e6
    thr_uv = float(np.percentile(max_ptp_uv, reject_percentile))
    epochs_clean = epochs.copy().drop_bad(reject=dict(eeg=thr_uv * 1e-6))

    # --- Per-epoch, per-channel z-score across time ---
    Xc = epochs_clean.get_data()
    m = Xc.mean(axis=2, keepdims=True)
    s = Xc.std(axis=2, keepdims=True)
    s[s == 0] = 1.0  # flat/zero-padded channels: avoid division by zero
    Xz = (Xc - m) / s
    epochs_clean = mne.EpochsArray(
        Xz, epochs_clean.info, events=epochs_clean.events,
        tmin=epochs_clean.tmin, event_id=epochs_clean.event_id, on_missing='ignore'
    )

    # --- Labels per epoch (subject-level from path) ---
    y = np.full(len(epochs_clean), label, dtype=int)

    out = {
        "raw_after": raw,
        "epochs": epochs_clean,
        "labels": y,
        "threshold_uv": thr_uv,
        "present_channels": list(present),  # copy: never hand out CORE_CHS itself
        "present_mask": present_mask,       # True for originally present CORE_CHS
    }
    if return_psd:
        # Channels that actually existed in the EDF (padded zeros excluded)
        real_chs = [ch for ch, is_real in zip(CORE_CHS, present_mask) if is_real]

        rb = raw_before.copy()
        if real_chs:
            rb.pick_channels(real_chs)
        out["psd_before"] = rb.compute_psd(fmax=band[1], average='mean')

        ra = raw.copy()
        if real_chs:
            ra.pick_channels(real_chs)
        out["psd_after"] = ra.compute_psd(fmax=band[1], average='mean')
    return out

preprocess single

In [None]:
"""
Usage:
  python src/preprocess_single.py --edf path/to/file.edf --out path/to/output_dir --psd_dir path/to/psd_figures
  python src/preprocess_single.py --edf data_raw/DATA/...t001.edf --out data_pp --psd_dir figures/psd
"""
import sys
from pathlib import Path

# Let this script import sibling modules from src/
sys.path.append(str(Path(__file__).resolve().parent))

import argparse
import pickle
import numpy as np
import matplotlib.pyplot as plt
import json


from preprocess_core import preprocess_single

def main(edf:str, out:str, psd_dir:str, pad_missing: bool):
    """
    Run the preprocessing pipeline on one EDF file and write its outputs.

    Written to `out` (pid = EDF file stem):
      - {pid}_epochs.npy            : epochs array (n_epochs, n_ch, n_times)
      - {pid}_labels.npy            : one label per epoch
      - {pid}_raw.npy               : preprocessed continuous data (n_ch, n_times)
      - {pid}_info.pkl              : pickled MNE Info of the preprocessed Raw
      - {pid}_present_mask.npy      : bool mask over CORE_CHS (True = real, False = padded)
      - {pid}_present_channels.json : channel names after pick/pad, in order
    Written to `psd_dir`:
      - {pid}_PSD_before.png / {pid}_PSD_after.png, for each PSD that is available
    """
    edf_path = Path(edf)
    out_dir = Path(out)
    fig_dir = Path(psd_dir)
    for folder in (out_dir, fig_dir):
        folder.mkdir(parents=True, exist_ok=True)

    result = preprocess_single(edf_path, return_psd=True, pad_missing=pad_missing)
    pid = edf_path.stem
    raw_after = result["raw_after"]

    np.save(out_dir / f"{pid}_epochs.npy", result["epochs"].get_data())
    np.save(out_dir / f"{pid}_labels.npy", result["labels"])
    np.save(out_dir / f"{pid}_raw.npy", raw_after.get_data())
    np.save(out_dir / f"{pid}_present_mask.npy", result["present_mask"])

    with open(out_dir / f"{pid}_info.pkl", "wb") as handle:
        pickle.dump(raw_after.info, handle)
    with open(out_dir / f"{pid}_present_channels.json", "w", encoding="utf-8") as handle:
        json.dump(result["present_channels"], handle, ensure_ascii=False, indent=2)

    spectra = (("before", result.get("psd_before")), ("after", result.get("psd_after")))
    for tag, spectrum in spectra:
        if spectrum is not None:
            figure = spectrum.plot(show=False)
            figure.suptitle(f"PSD {tag.upper()}")
            figure.savefig(fig_dir / f"{pid}_PSD_{tag}.png", dpi=150, bbox_inches="tight")
            plt.close(figure)

    print(f"Saved epochs={len(result['epochs'])}, thr={result['threshold_uv']:.1f} µV")

if __name__ == "__main__":
    # CLI entry point: parse arguments and run main() on a single EDF file.
    parser = argparse.ArgumentParser(description="Preprocess a single EEG EDF file for epilepsy detection")
    parser.add_argument("--edf", required=True, help="Path to raw EDF file")
    parser.add_argument("--out", required=True, help="Output directory for arrays and metadata")
    parser.add_argument("--psd_dir", required=True, help="Output directory for PSD figures")
    # "--pad-missing" is accepted too, matching the batch CLI's spelling.
    parser.add_argument("--pad_missing", "--pad-missing", dest="pad_missing", action="store_true",
                        help="Zero-pad missing CORE_CHS")
    args = parser.parse_args()
    main(args.edf, args.out, args.psd_dir, args.pad_missing)

preprocess batch

In [None]:
import sys
from pathlib import Path
import argparse
import traceback
import json
from tqdm import tqdm

sys.path.append(str(Path(__file__).resolve().parent))

from preprocess_core import preprocess_single


def _save_outputs(res, X, pid_full, output_subdir, psd_subdir):
    """
    Write all outputs of one processed file.

    `<pid>_epochs.npy` is the resume sentinel checked by main(). It is
    written last and atomically (temp file + rename), so its presence implies
    every other output was saved completely.
    """
    # Local imports: these names used to be bound only by the script's
    # __main__ block, so calling main() after importing it raised NameError.
    import pickle
    import numpy as np
    import matplotlib.pyplot as plt

    output_subdir.mkdir(parents=True, exist_ok=True)
    psd_subdir.mkdir(parents=True, exist_ok=True)

    np.save(output_subdir / f"{pid_full}_labels.npy", res["labels"])
    np.save(output_subdir / f"{pid_full}_raw.npy", res["raw_after"].get_data())
    np.save(output_subdir / f"{pid_full}_present_mask.npy", res["present_mask"])
    with open(output_subdir / f"{pid_full}_info.pkl", "wb") as f:
        pickle.dump(res["raw_after"].info, f)
    with open(output_subdir / f"{pid_full}_present_channels.json", "w", encoding="utf-8") as f:
        json.dump(res["present_channels"], f, ensure_ascii=False, indent=2)

    for tag, psd in [("before", res.get("psd_before")), ("after", res.get("psd_after"))]:
        if psd is None:
            continue
        fig = psd.plot(show=False)
        try:
            fig.suptitle(f"PSD {tag.upper()} {pid_full}")
            fig.savefig(psd_subdir / f"{pid_full}_PSD_{tag}.png", dpi=150, bbox_inches="tight")
        finally:
            plt.close(fig)  # never leak figures across a long batch

    epochs_file = output_subdir / f"{pid_full}_epochs.npy"
    tmp_file = epochs_file.with_name(epochs_file.name + ".tmp")
    with open(tmp_file, "wb") as f:
        np.save(f, X)
    tmp_file.replace(epochs_file)


def main(input_dir: str, output_dir: str, psd_dir: str, max_patients: int = None, pad_missing: bool = False):
    """
    Batch preprocess EEG EDF files found recursively under a root directory.

    Parameters
    ----------
    input_dir : str
        Root folder containing EEG EDF files and subfolders.
    output_dir : str
        Folder for processed numpy arrays and metadata. The input folder
        structure is mirrored underneath it.
    psd_dir : str
        Folder for PSD plot images. The input folder structure is mirrored.
    max_patients : int or None
        If set, only the first `max_patients` unique patient IDs (in sorted
        path order) are processed. Files of further patients are skipped,
        but later files of already-accepted patients are still processed.
    pad_missing : bool
        Forwarded to preprocess_single(): zero-pad missing CORE_CHS channels.

    Behaviour
    ---------
    - Files are visited in sorted path order, so runs and patient selection
      are reproducible (rglob order is filesystem-dependent).
    - The patient ID is the part of the file stem before the first "_".
    - A file whose `<stem>_epochs.npy` already exists is skipped (resume).
      Its patient still counts towards `max_patients`.
    - preprocess_single() returning None (no usable epochs) is reported and
      skipped.
    - Per-file exceptions are reported with a traceback and the batch
      continues.
    """
    import numpy as np  # see _save_outputs for why this is local

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    psd_path = Path(psd_dir)

    output_path.mkdir(parents=True, exist_ok=True)
    psd_path.mkdir(parents=True, exist_ok=True)

    edf_files = sorted(input_path.rglob("*.edf"))
    print(f"Found {len(edf_files)} EDF files in {input_path}")

    processed_patients = set()
    limit_announced = False

    with tqdm(total=len(edf_files), desc="Processing EDF files") as pbar:
        for edf_file in edf_files:
            try:
                pid_full = edf_file.stem
                patient_id = pid_full.split('_')[0]  # adjust if IDs are encoded differently

                relative_path = edf_file.parent.relative_to(input_path)
                output_subdir = output_path / relative_path
                psd_subdir = psd_path / relative_path
                expected_output_file = output_subdir / f"{pid_full}_epochs.npy"

                # Resume: the epochs file is written last, so it marks a finished file.
                if expected_output_file.exists():
                    processed_patients.add(patient_id)
                    tqdm.write(f"Skipping {pid_full}: Output already exists.")
                    continue

                if patient_id not in processed_patients:
                    if max_patients is not None and len(processed_patients) >= max_patients:
                        # `continue`, not `break`: later files of accepted patients still run.
                        if not limit_announced:
                            tqdm.write(f"Reached max patient limit: {max_patients}. "
                                       f"Skipping files of additional patients.")
                            limit_announced = True
                        continue
                    processed_patients.add(patient_id)

                try:
                    tqdm.write(f"Processing {edf_file} (patient {patient_id})...")
                    res = preprocess_single(edf_file, return_psd=True, pad_missing=pad_missing)
                    if res is None:
                        tqdm.write(f"Skipping {pid_full}: no usable epochs.")
                        continue

                    X = res["epochs"].get_data()  # (E, C, T)
                    if not np.isfinite(X).all():
                        raise ValueError(f"{edf_file}: non-finite values in epochs")

                    _save_outputs(res, X, pid_full, output_subdir, psd_subdir)
                    tqdm.write(f"Finished {pid_full}: saved {len(res['epochs'])} epochs, "
                               f"threshold={res['threshold_uv']:.1f} µV")
                except Exception as e:
                    tqdm.write(f"Error processing {edf_file}: {e}")
                    traceback.print_exc()
            finally:
                pbar.update(1)


if __name__ == "__main__":
    # CLI entry point for batch preprocessing.
    parser = argparse.ArgumentParser(description="Batch preprocess EEG EDF files for epilepsy dataset")
    parser.add_argument("--input_dir", required=True, help="Root EEG EDF folder")
    parser.add_argument("--output_dir", required=True, help="Folder to save preprocessed arrays and metadata")
    parser.add_argument("--psd_dir", required=True, help="Folder to save PSD plots")
    parser.add_argument("--max_patients", type=int, default=None, help="Limit unique patients processed")
    # Simple positive flag; default False. "--pad_missing" is accepted as an
    # alias so the spelling matches the single-file CLI.
    parser.add_argument("--pad-missing", "--pad_missing", dest="pad_missing", action="store_true",
                        help="Enable zero-padding of missing channels (fixed topology).")
    parser.set_defaults(pad_missing=False)

    args = parser.parse_args()

    # Bind these as module-level names for any code that expects them when
    # this file runs as a script.
    import numpy as np
    import pickle
    import matplotlib.pyplot as plt

    main(args.input_dir, args.output_dir, args.psd_dir, max_patients=args.max_patients, pad_missing=args.pad_missing)

Version 2: The "Clean-Then-Pad"
This version's philosophy is "clean the real data first (with ICA), then pad with zeros to get 22 channels."

Load .edf file.

Pick EEG-only channels.

Clean channel names (the robust, case-correcting version).

Pick Present Channels: Calls pick_order_and_pad with pad_missing=False. This discards every non-CORE_CHS channel but does not pad yet.

Notch Filter: 60 Hz.

Band-pass Filter: 1.0 Hz to 100 Hz (1.0 Hz rather than 0.5 Hz, because ICA/ICLabel work better on data high-passed at 1 Hz).

Set Montage: Assigns 10-20 locations to only the channels that are present.

Common Average Reference (CAR).

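Automated ICA (optional, enabled via ica_components / --ica): Fits an extended-Infomax ICA, labels the components with mne_icalabel's ICLabel, and removes every component not labelled "brain" or "other" (e.g. eye and muscle artifacts). This is a major step.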

Pad with Zeros: Now it calls pick_order_and_pad with pad_missing=True to fill in the missing spots with all-zero channels to get the final 22.

Resample: 250 Hz.

(Optional) Crop Controls: Trims the first 10 seconds of "control" files.

Epoch: Chops data into 2-second windows.

Artifact Rejection (Robust): Rejects epochs whose maximum peak-to-peak amplitude exceeds the lower of two thresholds: an adaptive percentile threshold (default: 95th) and a fixed "sanity cap" of 500 µV.

Z-score: Normalizes each channel within each epoch.

Assign Labels.

preprocess core

In [None]:
from pathlib import Path
import re, pickle, numpy as np, mne
from mne.preprocessing import ICA
from mne_icalabel import label_components
import matplotlib.pyplot as plt
#python src\preprocess_batch.py --input_dir data_raw\DATA --output_dir data_pp --psd_dir figures\psd --max_patients 10 --pad-missing 
# Fixed 10–20 core channel layout (target topology for graphs/ML).
# 22 names in a fixed order. pick_order_and_pad() reorders every recording to
# this list (zero-padding absent channels when pad_missing=True), so index i
# refers to the same electrode in every output file and in the bool[22]
# present_mask. clean_channel_names() also uses these spellings to fix case
# (e.g. "FP1" -> "Fp1"). set_montage_for_corechs() falls back to approximate
# (FT9/FT10-like) coordinates for T1/T2.
CORE_CHS = ["Fp1","Fp2","F7","F3","Fz","F4","F8",
            "T1","T3","C3","Cz","C4","T4","T2",
            "T5","P3","Pz","P4","T6","O1","Oz","O2"]

def clean_channel_names(raw: mne.io.BaseRaw):
    """
    Normalize channel labels of `raw` in place, in two renaming passes.

    Pass 1 strips a leading "EEG" prefix (plus any whitespace after it), every
    "-LE" and then "-REF" occurrence, and surrounding whitespace.
    Pass 2 rewrites any label that equals a CORE_CHS name ignoring case to
    the exact CORE_CHS spelling (e.g. "FP1" -> "Fp1").
    """
    canonical = {name.lower(): name for name in CORE_CHS}

    generic = {}
    for original in raw.ch_names:
        trimmed = re.sub(r'^(?:EEG\s*)', '', original)
        trimmed = trimmed.replace('-LE', '').replace('-REF', '')
        generic[original] = trimmed.strip()
    raw.rename_channels(generic)

    raw.rename_channels({
        name: canonical[name.lower()]
        for name in raw.ch_names
        if name.lower() in canonical
    })

def pick_order_and_pad(raw: mne.io.BaseRaw, pad_missing: bool = True):
    """
    Restrict `raw` to CORE_CHS in CORE_CHS order; optionally zero-pad missing ones.

    Operates in place on `raw`. Channels not in CORE_CHS are dropped.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Preloaded recording whose channel names were already cleaned.
    pad_missing : bool
        If True, every CORE_CHS channel absent from `raw` is appended as an
        all-zero EEG channel, so `raw` ends up with exactly CORE_CHS.
        If False, only the CORE_CHS channels that exist are kept.

    Returns
    -------
    ordered_list : list of str
        Channel names of `raw` after the call: CORE_CHS when padding,
        otherwise the present subset. Always a new list, so callers may
        mutate it without corrupting the module-level CORE_CHS.
    present_mask : np.ndarray of bool, shape (len(CORE_CHS),)
        True where the CORE_CHS channel existed in `raw` before padding.

    Raises
    ------
    RuntimeError
        If no CORE_CHS channel is present and `pad_missing` is False.
    """
    existing = set(raw.ch_names)
    present = [ch for ch in CORE_CHS if ch in existing]  # in CORE_CHS order
    present_mask = np.array([ch in existing for ch in CORE_CHS], dtype=bool)

    if pad_missing:
        missing = [ch for ch in CORE_CHS if ch not in existing]
        if missing:
            data = np.zeros((len(missing), raw.n_times))
            info = mne.create_info(missing, sfreq=raw.info["sfreq"], ch_types="eeg")
            raw.add_channels([mne.io.RawArray(data, info)], force_update_info=True)
        # reorder_channels keeps only the listed channels, so this also drops
        # any non-CORE channel still in `raw`.
        raw.reorder_channels(CORE_CHS)
        return list(CORE_CHS), present_mask

    if not present:
        raise RuntimeError("No recognizable CORE_CHS channels found in this recording.")
    # Keep only the present subset, in CORE_CHS order (reorder_channels picks + orders).
    raw.reorder_channels(present)
    return present, present_mask

def set_montage_for_corechs(raw: mne.io.BaseRaw):
    """
    Assign 10–20 electrode positions to the channels of `raw`, in place.

    Positions come from MNE's 'standard_1020' template. They are looked up
    case-insensitively, so e.g. a channel named "FP1" receives the "Fp1"
    position. This matches the `match_case=False` passed to set_montage,
    which previously could not help because the dict lookup itself was
    case-sensitive. T1/T2, when not in the template, get fixed approximate
    coordinates copied from FT9/FT10.

    Channels with no known position are not added to the montage.
    set_montage is called with its default `on_missing` handling.
    """
    std_pos = mne.channels.make_standard_montage('standard_1020').get_positions()['ch_pos']
    template = {name.lower(): xyz for name, xyz in std_pos.items()}

    # Approximate positions for T1/T2 (FT9/FT10-like coordinates):
    # FT9 : [-0.0840759 0.0145673 -0.050429 ] FT10 : [ 0.0841131 0.0143647 -0.050538 ]
    fallback = {
        't1': (-0.0840759, 0.0145673, -0.050429),
        't2': (0.0841131, 0.0143647, -0.050538),
    }

    ch_pos = {}
    for ch in raw.ch_names:
        key = ch.lower()
        if key in template:        # template position wins, as before
            ch_pos[ch] = template[key]
        elif key in fallback:
            ch_pos[ch] = fallback[key]

    montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
    raw.set_montage(montage, match_case=False)

def infer_label_from_path(p: Path) -> int:
    """
    Infer the binary class label from a file path.

    Returns 1 when any *directory* component of `p` equals '00_epilepsy'
    (case-insensitive), otherwise 0 (control). Both '/' and '\\' count as
    separators, so Windows-style strings work on any OS.

    Relative paths such as '00_epilepsy/p1/t1.edf' are recognised. The old
    substring test '/00_epilepsy/' required a leading separator and
    labelled them as controls.
    """
    # Drop the last component (the file name itself); only folders carry the label.
    dirs = str(p).replace('\\', '/').lower().split('/')[:-1]
    return 1 if "00_epilepsy" in dirs else 0

def _remove_ica_artifacts(raw: mne.io.BaseRaw, pid: str, ica_dir: str = None):
    """
    Best-effort automated ICA cleaning of `raw`, in place.

    Fits an extended-Infomax ICA, labels the components with ICLabel and
    removes every component not labelled "brain" or "other". If `ica_dir` is
    given, a property plot of each removed component is saved there.

    Failures are printed and leave `raw` uncleaned instead of aborting the
    file. A failure while *plotting* no longer prevents the removal itself.
    """
    # The average reference removes one degree of freedom, hence len - 1.
    n_ica_comp = min(20, len(raw.ch_names) - 1)
    if n_ica_comp < 2:
        print(f"   ! Skipping ICA: Not enough channels ({len(raw.ch_names)})")
        return

    try:
        ica = ICA(
            n_components=n_ica_comp,
            method='infomax',
            fit_params=dict(extended=True),
            random_state=42
        )
        ica.fit(raw)
        ic_labels = label_components(raw, ica, method="iclabel")["labels"]
        ica.exclude = [i for i, lab in enumerate(ic_labels) if lab not in ["brain", "other"]]
    except Exception as e:
        print(f"   ! ICA failed for {pid}: {e}")
        return

    if not ica.exclude:
        print("   ICA: No artifacts found to remove.")
        return

    print(f"   ICA: Found {len(ica.exclude)} artifact components. Plotting...")
    if ica_dir:
        try:
            ica_plot_dir = Path(ica_dir)
            ica_plot_dir.mkdir(parents=True, exist_ok=True)
            for comp_idx in ica.exclude:
                fig = ica.plot_properties(raw, picks=comp_idx, show=False)[0]
                try:
                    fig.savefig(ica_plot_dir / f"{pid}_ica_comp_{comp_idx:02d}_REMOVED.png", dpi=100)
                finally:
                    plt.close(fig)
        except Exception as e:
            print(f"   ! ICA plotting failed for {pid}: {e}")

    print(f"   ICA: Removing {len(ica.exclude)} components.")
    try:
        ica.apply(raw)
    except Exception as e:
        print(f"   ! ICA failed for {pid}: {e}")


def _zscore_epochs(epochs):
    """
    Return a new EpochsArray in which each channel of each epoch is z-scored
    across time. Channels with zero std (flat or zero-padded) are only
    mean-centred.
    """
    data = epochs.get_data()
    mean = data.mean(axis=2, keepdims=True)
    std = data.std(axis=2, keepdims=True)
    std[std == 0] = 1.0
    return mne.EpochsArray(
        (data - mean) / std, epochs.info, events=epochs.events,
        tmin=epochs.tmin, event_id=epochs.event_id, on_missing='ignore'
    )


def preprocess_single(
    edf_path: Path,
    notch: float = 60.0,
    band: tuple = (1.0, 100.0),  # Tuned for ICLabel
    resample_hz: float = 250.0,
    epoch_len: float = 2.0,
    epoch_overlap: float = 0.0,
    reject_percentile: float = 95.0,
    crop_first10_if_control: bool = True,
    ica_components: bool = False,
    return_psd: bool = False,
    pad_missing: bool = True,
    ica_dir: str = None,
    sanity_cap_uv: float = 500.0,
):
    """
    End-to-end preprocessing for a single EDF (v13 - Final Order).

    Steps, in order:
      1. Load, keep EEG, clean names (with case correction)
      2. Keep *only* the CORE_CHS channels that are present (no padding yet)
      3. Notch + band-pass filter (default 1-100 Hz, tuned for ICLabel)
      4. Set montage, then common average reference
      5. If `ica_components`, automated ICA (extended Infomax + ICLabel),
         best-effort
      6. If `pad_missing`, zero-pad to all CORE_CHS. This happens after
         CAR/ICA, so the padded channels stay exactly zero.
      7. Resample; optionally crop the first 10 s of control recordings
      8. Fixed-length epochs. Reject epochs whose max peak-to-peak amplitude
         exceeds min(`reject_percentile` threshold, `sanity_cap_uv`).
      9. Per-epoch, per-channel z-score across time
     10. Per-epoch labels from the path

    Parameters
    ----------
    edf_path : path-like
        EDF file to process; its path also determines the label.
    notch, band, resample_hz, epoch_len, epoch_overlap :
        Filter, resampling and epoching settings (Hz / seconds).
    reject_percentile : float
        Percentile of the per-epoch max peak-to-peak amplitude used as the
        adaptive rejection threshold.
    crop_first10_if_control : bool
        Drop the first 10 s of recordings labelled 0.
    ica_components : bool
        Run the automated ICA step.
    return_psd : bool
        Also return "psd_before"/"psd_after", both over the originally
        present channels only.
    pad_missing : bool
        Zero-pad absent CORE_CHS channels to the fixed 22-channel layout.
        Previously this argument was ignored and padding always happened.
    ica_dir : str or None
        Folder for plots of removed ICA components.
    sanity_cap_uv : float
        Hard upper bound (µV) on the rejection threshold.

    Returns
    -------
    dict or None
        None when no epoch can be cut or all epochs are rejected. Otherwise
        a dict with "raw_after", "epochs", "labels", "threshold_uv",
        "present_channels", "present_mask" and, if `return_psd`,
        "psd_before" and "psd_after".

    Raises
    ------
    RuntimeError
        If none of CORE_CHS is present (raised by pick_order_and_pad).
    """
    edf_path = Path(edf_path)
    pid = edf_path.stem

    raw_before = mne.io.read_raw_edf(str(edf_path), preload=True, verbose="ERROR")
    # NOTE: `exclude` takes channel *names*, not types. It only drops channels
    # literally named EOG/ECG/EMG/MISC/STIM; all other non-CORE channels are
    # removed by the CORE_CHS pick below.
    raw_before.pick_types(eeg=True, exclude=['EOG', 'ECG', 'EMG', 'MISC', 'STIM'])
    clean_channel_names(raw_before)  # also fixes case, e.g. "FP1" -> "Fp1"
    raw = raw_before.copy()

    # --- 1. PICK (NO PADDING) ---
    # raw_before mirrors the pick so the PSD-before covers the same channels.
    # Both calls raise RuntimeError if no CORE_CHS channel is present.
    pick_order_and_pad(raw_before, pad_missing=False)
    present, _ = pick_order_and_pad(raw, pad_missing=False)

    # --- 2. Filtering ---
    raw.notch_filter(freqs=[notch])
    raw.filter(l_freq=band[0], h_freq=band[1], fir_design='firwin', filter_length='auto')

    # --- 3. Montage (raw holds only known, case-corrected EEG channels) ---
    set_montage_for_corechs(raw)

    # --- 4. Re-reference (CAR) ---
    raw.set_eeg_reference('average', projection=False)

    # --- 5. ICA artifact removal ---
    if ica_components:
        _remove_ica_artifacts(raw, pid, ica_dir)

    # --- 6. Optional zero-padding to the fixed CORE_CHS layout ---
    present_final, present_mask = pick_order_and_pad(raw, pad_missing=pad_missing)

    # --- 7. Resample & optional crop ---
    raw.resample(resample_hz, npad="auto")
    label = infer_label_from_path(edf_path)
    if crop_first10_if_control and label == 0:
        raw.crop(tmin=10.0)

    # --- 8. Epoching + amplitude artifact rejection (Sanity Cap) ---
    epochs = mne.make_fixed_length_epochs(raw, duration=epoch_len, overlap=epoch_overlap, preload=True)
    X = epochs.get_data()  # volts
    if X.shape[0] == 0:
        print(f"   ! No epochs created for {edf_path.stem}. Skipping file.")
        return None

    max_ptp_uv = np.ptp(X, axis=2).max(axis=1) * 1e6
    adaptive_thr_uv = float(np.percentile(max_ptp_uv, reject_percentile))
    final_thr_uv = min(adaptive_thr_uv, sanity_cap_uv)

    print(f"   Adaptive threshold ({reject_percentile:g}th percentile): {adaptive_thr_uv:.1f} uV")
    print(f"   Sanity-capped threshold: {final_thr_uv:.1f} uV (Using this one)")

    epochs_clean = epochs.copy().drop_bad(reject=dict(eeg=final_thr_uv * 1e-6))
    if len(epochs_clean) == 0:
        print(f"   ! All epochs rejected for {edf_path.stem} (Threshold={final_thr_uv:.1f} uV). Skipping file.")
        return None

    # --- 9. Per-epoch, per-channel z-score across time ---
    epochs_clean = _zscore_epochs(epochs_clean)

    # --- 10. Labels per epoch (subject-level from path) ---
    y = np.full(len(epochs_clean), label, dtype=int)

    out = {
        "raw_after": raw,
        "epochs": epochs_clean,
        "labels": y,
        "threshold_uv": final_thr_uv,
        "present_channels": list(present_final),  # copy: never hand out CORE_CHS itself
        "present_mask": present_mask,             # len(CORE_CHS) mask, True = real channel
    }

    if return_psd:
        # raw_before already holds exactly `present`, in CORE_CHS order.
        # Padding it would add all-zero spectra that psd_after does not have.
        out["psd_before"] = raw_before.compute_psd(fmax=band[1], average='mean')

        ra = raw.copy()
        ra.pick_channels(present)  # plot only the *real* channels
        out["psd_after"] = ra.compute_psd(fmax=band[1], average='mean')

    return out

preprocess single

"""
Usage:
  python src/preprocess_single.py --edf path/to/file.edf --out path/to/output_dir --psd_dir path/to/psd_figures
  python src/preprocess_single.py --edf data_raw/DATA/...t001.edf --out data_pp --psd_dir figures/psd
"""
import sys
from pathlib import Path

# Let this script import sibling modules from src/
sys.path.append(str(Path(__file__).resolve().parent))

import argparse
import pickle
import numpy as np
import matplotlib.pyplot as plt
import json


from preprocess_core import preprocess_single

def main(edf:str, out:str, psd_dir:str, pad_missing: bool, ica: bool, ica_dir:str):
    """
    Preprocess a single raw EEG EDF file and save the outputs.

    Parameters
    ----------
    edf, out, psd_dir : str
        Input EDF path, output folder for arrays/metadata, folder for PSD figures.
    pad_missing : bool
        Forwarded to preprocess_single(): zero-pad missing CORE_CHS channels.
    ica : bool
        Forwarded as `ica_components`: run automated ICA.
    ica_dir : str or None
        Folder for plots of removed ICA components.

    Saves (pid = EDF file stem):
      - {pid}_epochs.npy            : epochs array (n_epochs, n_ch, n_times)
      - {pid}_labels.npy            : per-epoch labels
      - {pid}_raw.npy               : raw AFTER preprocessing (channels x times)
      - {pid}_info.pkl              : MNE Raw.info for the post-processed raw
      - {pid}_present_mask.npy      : bool mask over CORE_CHS (True if real channel, False if padded)
      - {pid}_present_channels.json : ordered channel names after pick/pad
      - {pid}_PSD_before.png / {pid}_PSD_after.png : PSD figures (if available)

    Nothing is saved when preprocess_single() returns None (no usable epochs).
    Previously that case crashed with a TypeError.
    """
    edf = Path(edf); out = Path(out); psd_dir = Path(psd_dir)
    out.mkdir(parents=True, exist_ok=True); psd_dir.mkdir(parents=True, exist_ok=True)

    res = preprocess_single(edf, return_psd=True, pad_missing=pad_missing,
                            ica_components=ica, ica_dir=ica_dir)
    pid = edf.stem
    if res is None:
        print(f"Nothing saved for {pid}: no usable epochs.")
        return

    np.save(out / f"{pid}_epochs.npy", res["epochs"].get_data())
    np.save(out / f"{pid}_labels.npy", res["labels"])
    np.save(out / f"{pid}_raw.npy",    res["raw_after"].get_data())
    np.save(out / f"{pid}_present_mask.npy", res["present_mask"])
    with open(out / f"{pid}_info.pkl", "wb") as f:
        pickle.dump(res["raw_after"].info, f)
    with open(out / f"{pid}_present_channels.json", "w", encoding="utf-8") as f:
        json.dump(res["present_channels"], f, ensure_ascii=False, indent=2)

    for tag, psd in [("before", res.get("psd_before")), ("after", res.get("psd_after"))]:
        if psd is None:
            continue
        fig = psd.plot(show=False)
        try:
            fig.suptitle(f"PSD {tag.upper()}")
            fig.savefig(psd_dir / f"{pid}_PSD_{tag}.png", dpi=150, bbox_inches="tight")
        finally:
            plt.close(fig)

    print(f"Saved epochs={len(res['epochs'])}, thr={res['threshold_uv']:.1f} µV")

if __name__ == "__main__":
    # CLI entry point: parse arguments and run main() on a single EDF file.
    parser = argparse.ArgumentParser(description="Preprocess a single EEG EDF file for epilepsy detection")
    parser.add_argument("--edf", required=True, help="Path to raw EDF file")
    parser.add_argument("--out", required=True, help="Output directory for arrays and metadata")
    parser.add_argument("--psd_dir", required=True, help="Output directory for PSD figures")
    # Hyphenated aliases are accepted as well ("--pad-missing" is the batch CLI's spelling).
    parser.add_argument("--pad_missing", "--pad-missing", dest="pad_missing", action="store_true",
                        help="Zero-pad missing CORE_CHS")
    parser.add_argument("--ica", action="store_true", help="Run automated ICA artifact removal")
    parser.add_argument("--ica_dir", "--ica-dir", dest="ica_dir",
                        help="Output directory for ICA component figures")
    args = parser.parse_args()
    main(args.edf, args.out, args.psd_dir, args.pad_missing, args.ica, args.ica_dir)