<a href="https://colab.research.google.com/github/emi-emi671/EEG-Anonymization/blob/main/MarkkuEEG_Pipeline_saving_output_features_on_drive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive so files under MyDrive can be accessed
# from this Colab runtime (the data folder read below lives there).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!cp -r /content/drive/MyDrive/data /content/EEG_DATA


In [3]:
# List the source folder on Drive to confirm the .mat recordings are visible.
!ls /content/drive/MyDrive/data

Patient10_MTU207295UUS_t1_anonymized.mat
Patient10_MTU207297UUS_t1_anonymized_158849.mat
Patient10_MTU207297UUS_t1_anonymized_158865.mat
Patient10_MTU207297UUS_t1_anonymized_158873.mat
Patient10_MTU207297UUS_t1_anonymized_158877.mat
Patient10_MTU207297UUS_t1_anonymized_158884.mat
Patient10_MTU207297UUS_t1_anonymized_158927.mat
Patient10_MTU207297UUS_t1_anonymized_158936.mat
Patient10_MTU207297UUS_t1_anonymized_158940.mat
Patient10_MTU207297UUS_t1_anonymized_158952.mat
Patient10_MTU207297UUS_t1_anonymized.mat
Patient11_MTU207297UUS_t1_anonymized_158964.mat
Patient11_MTU207297UUS_t1_anonymized_158971.mat
Patient11_MTU207297UUS_t1_anonymized_158973.mat
Patient11_MTU207297UUS_t1_anonymized_158985.mat
Patient11_MTU207297UUS_t1_anonymized_158990.mat
Patient11_MTU207297UUS_t1_anonymized_159010.mat
Patient11_MTU207297UUS_t1_anonymized.mat
Patient12_MTU207297UUS_t1_anonymized_159014.mat
Patient12_MTU207297UUS_t1_anonymized_159018.mat
Patient12_MTU207297UUS_t1_anonymized.mat
Patient13_MTU207297U

In [9]:
!pip install numpy scipy h5py matplotlib pandas




In [4]:
import os
import re
import glob
import numpy as np
import h5py
import scipy.io as sio
from scipy.signal import butter, sosfiltfilt, iirnotch, filtfilt, welch
from typing import Dict, List, Tuple, Optional, Literal


# ============================================================
# 1) Patient ID extraction + file grouping
# ============================================================

# Default pattern: the patient id is the leading "Patient<digits>" of the file name.
_PATIENT_REGEX_DEFAULT = r"^(Patient\d+)"

def extract_patient_id(filename: str, pattern: str = _PATIENT_REGEX_DEFAULT) -> str:
    """Extract the patient identifier from a file name.

    The pattern is matched (anchored at the start, via ``re.match``) against the
    base name of ``filename``. If the pattern defines at least one capture group,
    the first group is the id; otherwise the whole match is used, so simple
    patterns such as ``r"Patient\\d+"`` work without parentheses.

    Args:
        filename: File name or path; only the base name is inspected.
        pattern: Regular expression identifying the patient id.

    Returns:
        The patient id string.

    Raises:
        ValueError: If the pattern does not match, or matches but yields an
            empty / non-participating id (which would otherwise silently group
            unrelated files under ``""`` or ``None``).
    """
    base = os.path.basename(filename)
    m = re.match(pattern, base)

    pid = None
    if m:
        # m.re.groups is the number of capture groups in the compiled pattern.
        pid = m.group(1) if m.re.groups >= 1 else m.group(0)

    if not pid:
        raise ValueError(
            f"Could not extract patient id from filename: {base}\n"
            f"Regex used: {pattern}\n"
            f"Update 'pattern' to match your naming scheme."
        )
    return pid

def group_files_by_patient(folder: str, mat_glob: str = "*.mat",
                           patient_regex: str = _PATIENT_REGEX_DEFAULT) -> Dict[str, List[str]]:
    """Collect the files in ``folder`` matching ``mat_glob`` and bucket them by patient id.

    Args:
        folder: Directory to search (non-recursive glob).
        mat_glob: Glob pattern joined onto ``folder``.
        patient_regex: Pattern handed to ``extract_patient_id`` for every file.

    Returns:
        Mapping patient id -> sorted list of file paths.

    Raises:
        FileNotFoundError: If the glob matches nothing.
        ValueError: Propagated from ``extract_patient_id`` for unparsable names.
    """
    matches = sorted(glob.glob(os.path.join(folder, mat_glob)))
    if len(matches) == 0:
        raise FileNotFoundError(f"No files found in {folder} matching {mat_glob}")

    by_patient: Dict[str, List[str]] = {}
    for path in matches:
        key = extract_patient_id(path, pattern=patient_regex)
        if key not in by_patient:
            by_patient[key] = []
        by_patient[key].append(path)

    # Same patient order as first encountered; each patient's files sorted.
    return {key: sorted(paths) for key, paths in by_patient.items()}


# ============================================================
# 2) MATLAB/HDF5 loaders
# ============================================================

def _numeric_datasets(h5file):
    items = []
    def visit(name, obj):
        if isinstance(obj, h5py.Dataset) and np.issubdtype(obj.dtype, np.number):
            items.append((name, obj))
    h5file.visititems(visit)
    return items

def _read_scalar_ref(f: h5py.File, ref):
    """Dereference ``ref`` in ``f`` and return it as a float if it is a single finite value.

    Returns ``None`` when the referenced object holds more than one element,
    is not finite, or when anything goes wrong while reading/converting it.
    """
    try:
        value = np.asarray(np.asarray(f[ref])[()]).squeeze()
        if value.size != 1 or not np.isfinite(value).all():
            return None
        return float(value)
    except Exception:
        # Best-effort read: any failure simply means "no usable scalar here".
        return None

def _detect_fs(f: h5py.File, n_ch: int, default_fs: float) -> float:
    """Look up the sampling rate inside the HDF5 file, falling back to ``default_fs``.

    Two locations are probed in order:

    1. ``#refs#/e/samplingRate`` - accepted when it is a 1-D array with one
       entry per channel, all entries (nearly) equal.
    2. ``#refs#/t6/dSamplingRate`` - an array of object references, each
       resolved with ``_read_scalar_ref``; accepted when all plausible values agree.

    Only rates within 50..5000 are considered plausible.
    """
    min_rate, max_rate = 50.0, 5000.0

    per_channel_key = "#refs#/e/samplingRate"
    if per_channel_key in f:
        rates = np.asarray(f[per_channel_key][()]).squeeze()
        uniform = rates.ndim == 1 and rates.size == n_ch and np.allclose(rates, rates[0])
        if uniform:
            rate = float(rates[0])
            if min_rate <= rate <= max_rate:
                return rate

    reference_key = "#refs#/t6/dSamplingRate"
    if reference_key in f:
        try:
            refs = np.asarray(f[reference_key][()]).squeeze().reshape(-1)
            resolved = (_read_scalar_ref(f, r) for r in refs)
            plausible = [v for v in resolved if v is not None and min_rate <= v <= max_rate]
            if plausible:
                first = float(plausible[0])
                if all(abs(v - first) < 1e-6 for v in plausible):
                    return first
        except Exception:
            # Deliberately best-effort: an unreadable reference table just
            # means we use the default rate below.
            pass

    return float(default_fs)

def load_eeg_and_fs(mat_path: str, default_fs: float = 256.0, mat_key_fallback: str = "val"):
    """Load an EEG matrix and its sampling rate from a ``.mat`` file.

    The file is first opened as HDF5. Among its 2-D numeric datasets, those
    whose smaller dimension is 4..512 and whose larger dimension is at least
    10x the smaller are candidates; the largest wins, with floating-point
    datasets given a 1.7x size weighting. If opening as HDF5 raises
    ``OSError``, the file is read with ``scipy.io.loadmat`` and the array
    under ``mat_key_fallback`` is used with ``default_fs``.

    In both paths the array is transposed if needed so the shorter axis comes
    first, and is cast to float32.

    Returns:
        ``(eeg, fs, source)`` where ``source`` is the HDF5 dataset path or the
        fallback key.

    Raises:
        RuntimeError: No suitable dataset (HDF5) or missing key (loadmat).
    """
    try:
        with h5py.File(mat_path, "r") as h5:
            top_score = None
            top_path = None
            for ds_path, ds in _numeric_datasets(h5):
                if ds.ndim != 2:
                    continue
                n_rows, n_cols = ds.shape
                n_short, n_long = min(n_rows, n_cols), max(n_rows, n_cols)
                plausible = 4 <= n_short <= 512 and n_long / n_short >= 10
                if not plausible:
                    continue
                float_bonus = 0.7 * ds.size if np.issubdtype(ds.dtype, np.floating) else 0
                candidate = ds.size + float_bonus
                if top_score is None or candidate > top_score:
                    top_score, top_path = candidate, ds_path

            if top_path is None:
                raise RuntimeError(f"EEG dataset not found in {mat_path}")

            signal = np.asarray(h5[top_path][()])
            if signal.shape[0] > signal.shape[1]:
                signal = signal.T
            signal = signal.astype(np.float32, copy=False)

            rate = _detect_fs(h5, n_ch=signal.shape[0], default_fs=default_fs)

        return signal, rate, top_path

    except OSError:
        # Not readable as HDF5: try the classic MATLAB reader instead.
        contents = sio.loadmat(mat_path)
        if mat_key_fallback not in contents:
            raise RuntimeError(f"'{mat_key_fallback}' not found in {mat_path}. Keys: {list(contents.keys())}")

        signal = contents[mat_key_fallback]
        if signal.shape[0] > signal.shape[1]:
            signal = signal.T
        signal = signal.astype(np.float32, copy=False)
        return signal, float(default_fs), mat_key_fallback


# ============================================================
# 3) Preprocessing + Features
# ============================================================

def preprocess(eeg: np.ndarray, fs: float,
               band: Tuple[float, float] = (0.5, 70.0),
               notch: Optional[float] = 50.0,
               reref: bool = True) -> np.ndarray:
    """Band-pass, optionally notch-filter and re-reference a 2-D signal array.

    Filtering runs along axis 1, so axis 0 is treated as the channel axis.

    Steps:
      1. 4th-order Butterworth band-pass (second-order sections), applied
         forward and backward with ``sosfiltfilt``.
      2. If ``notch`` is not None: IIR notch (Q=30) at that frequency, applied
         forward and backward with ``filtfilt``.
      3. If ``reref``: subtract, at every sample, the mean over axis 0.

    Args:
        eeg: 2-D array; computation is done in float64.
        fs: Sampling rate; band edges and notch must lie below ``fs / 2``.
        band: ``(low, high)`` pass band.
        notch: Notch frequency, or None to skip.
        reref: Whether to subtract the across-axis-0 mean.

    Returns:
        Filtered array cast to float32.

    Raises:
        ValueError: If ``band`` or ``notch`` is not strictly inside (0, Nyquist).
    """
    signal = eeg.astype(np.float64, copy=False)
    nyquist = 0.5 * fs
    f_lo, f_hi = band

    band_ok = 0 < f_lo < f_hi < nyquist
    if not band_ok:
        raise ValueError(f"Bad band {band} for fs={fs} (Nyq={nyquist}).")

    bandpass_sos = butter(4, [f_lo / nyquist, f_hi / nyquist], btype="band", output="sos")
    signal = sosfiltfilt(bandpass_sos, signal, axis=1)

    if notch is not None:
        notch_ok = 0 < notch < nyquist
        if not notch_ok:
            raise ValueError(f"Bad notch {notch} for fs={fs} (Nyq={nyquist}).")
        numer, denom = iirnotch(notch / nyquist, Q=30.0)
        signal = filtfilt(numer, denom, signal, axis=1)

    if reref:
        across_channel_mean = signal.mean(axis=0, keepdims=True)
        signal = signal - across_channel_mean

    return signal.astype(np.float32, copy=False)

def extract_window_features(eeg_prep: np.ndarray, fs: float,
                            win_sec: float = 2.0, overlap: float = 0.5) -> Tuple[np.ndarray, List[str]]:
    """Compute per-channel time- and frequency-domain features on sliding windows.

    For every window and every channel, 11 features are produced, in this order:
    line length (sum of absolute first differences), RMS, variance,
    zero-crossing rate (fraction of adjacent sample pairs with opposite sign),
    band power in delta/theta/alpha/beta/gamma/high_gamma (integral of the
    Welch PSD over the band), and SEF95 (frequency below which 95% of the
    cumulative PSD lies).

    Args:
        eeg_prep: 2-D array indexed as ``[channel, sample]``.
        fs: Sampling rate used for windowing and for the Welch PSD.
        win_sec: Window length in seconds.
        overlap: Fractional overlap between consecutive windows.

    Returns:
        ``(X, feat_names)`` where ``X`` is float32 with shape
        ``(n_windows, n_channels * 11)`` and ``feat_names`` labels its columns.
        When the recording is shorter than one window, ``X`` has shape
        ``(0, n_features)`` so it can still be stacked with other files.

    Raises:
        ValueError: If the window is shorter than 2 samples or the step is not
            positive.
    """
    bands = {
        "delta": (0.5, 4), "theta": (4, 8), "alpha": (8, 13),
        "beta": (13, 30), "gamma": (30, 45), "high_gamma": (45, 70)
    }
    n_ch, n_samp = eeg_prep.shape
    win = int(round(win_sec * fs))
    step = int(round(win * (1 - overlap)))
    if win <= 1 or step <= 0:
        raise ValueError("Bad window parameters; check win_sec/overlap/fs.")
    nper = min(256, win)

    # np.trapezoid only exists in NumPy >= 2.0; older releases spell it np.trapz.
    # `or` short-circuits, so the deprecated name is never touched on new NumPy.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    def bandpower(f, p, lo, hi):
        m = (f >= lo) & (f < hi)
        return float(integrate(p[m], f[m])) if np.any(m) else 0.0

    def sef95(f, p):
        c = np.cumsum(np.maximum(p, 0.0))
        if c[-1] <= 0:
            return 0.0
        return float(f[np.searchsorted(c, 0.95 * c[-1])])

    feat_names = []
    for ch in range(n_ch):
        feat_names += [f"ch{ch}_ll", f"ch{ch}_rms", f"ch{ch}_var", f"ch{ch}_zcr"]
        feat_names += [f"ch{ch}_{b}_bp" for b in bands]
        feat_names += [f"ch{ch}_sef95"]

    X = []
    for s in range(0, n_samp - win + 1, step):
        seg = eeg_prep[:, s:s + win]
        row = []
        for ch in range(n_ch):
            # Every segment holds exactly `win` >= 2 samples (checked above),
            # so diff/zero-crossing computations never see an empty input.
            x = seg[ch].astype(np.float64, copy=False)
            row += [
                float(np.sum(np.abs(np.diff(x)))),
                float(np.sqrt(np.mean(x * x))),
                float(np.var(x)),
                float(np.mean(x[:-1] * x[1:] < 0)),
            ]
            f, p = welch(x, fs=fs, nperseg=nper, noverlap=nper // 2)
            for lo, hi in bands.values():
                row.append(bandpower(f, p, lo, hi))
            row.append(sef95(f, p))
        X.append(row)

    # Explicit reshape: with zero windows np.asarray([]) would be 1-D (shape
    # (0,)), which cannot be vstack-ed with (n, n_features) blocks from other files.
    features = np.asarray(X, dtype=np.float32).reshape(len(X), len(feat_names))
    return features, feat_names


# ============================================================
# 4) Patient aggregation + optional log1p
# ============================================================

def log1p_bandpower_vector(v: np.ndarray, feature_names: List[str]) -> np.ndarray:
    """Return a copy of ``v`` with ``log1p`` applied to the band-power entries.

    An entry counts as band power when its name in ``feature_names`` contains
    ``"_bp"``. Values are clipped at 0 before the logarithm. The input array
    is not modified.
    """
    out = v.copy()

    bp_positions = []
    for pos, name in enumerate(feature_names):
        if "_bp" in name:
            bp_positions.append(pos)

    if not bp_positions:
        return out

    clipped = np.maximum(out[bp_positions], 0.0)
    out[bp_positions] = np.log1p(clipped)
    return out

def aggregate_patient_windows(X_all_windows: np.ndarray, method: str = "mean") -> np.ndarray:
    """Collapse a ``(n_windows, n_features)`` matrix into one vector per patient.

    Args:
        X_all_windows: Non-empty 2-D window/feature matrix.
        method: ``"mean"`` or ``"median"`` (length ``n_features``), or
            ``"mean_std"`` (means followed by population standard deviations,
            length ``2 * n_features``).

    Raises:
        ValueError: If the matrix is not 2-D or has no rows, or if ``method``
            is not one of the supported names.
    """
    has_rows = X_all_windows.ndim == 2 and X_all_windows.shape[0] != 0
    if not has_rows:
        raise ValueError("Need non-empty window matrix (n_windows, n_features).")

    if method not in ("mean", "median", "mean_std"):
        raise ValueError("method must be one of: 'mean', 'median', 'mean_std'")

    if method == "median":
        return np.median(X_all_windows, axis=0)

    center = X_all_windows.mean(axis=0)
    if method == "mean":
        return center

    spread = X_all_windows.std(axis=0, ddof=0)
    return np.concatenate([center, spread], axis=0)


# ============================================================
# 5) NEW: Normalizer across patients (zscore / robust / minmax)
# ============================================================

# Names of the supported across-patient normalization schemes.
NormMethod = Literal["none", "zscore", "robust", "minmax"]

def fit_normalizer(X: np.ndarray, method: NormMethod):
    """Estimate per-feature normalization parameters from ``X`` (rows = samples).

    Returns a dict whose ``"method"`` entry names the scheme and whose other
    entries hold the per-column statistics:

    * ``"none"``   - no statistics.
    * ``"zscore"`` - ``"mean"`` and ``"std"`` (population std).
    * ``"robust"`` - ``"median"`` and ``"iqr"`` (75th minus 25th percentile).
    * ``"minmax"`` - ``"min"`` and ``"range"`` (max minus min).

    Any scale that is not strictly positive is replaced by 1.0 so constant
    columns do not cause a division by zero later.

    Raises:
        ValueError: For an unrecognised ``method``.
    """
    data = np.asarray(X, dtype=np.float64)

    def _usable_scale(scale):
        # Non-positive scales (constant columns) fall back to 1.0.
        return np.where(scale > 0, scale, 1.0)

    if method == "none":
        return {"method": "none"}

    if method == "zscore":
        center = data.mean(axis=0)
        spread = _usable_scale(data.std(axis=0, ddof=0))
        return {"method": "zscore", "mean": center, "std": spread}

    if method == "robust":
        center = np.median(data, axis=0)
        lower = np.quantile(data, 0.25, axis=0)
        upper = np.quantile(data, 0.75, axis=0)
        return {"method": "robust", "median": center, "iqr": _usable_scale(upper - lower)}

    if method == "minmax":
        lowest = data.min(axis=0)
        highest = data.max(axis=0)
        return {"method": "minmax", "min": lowest, "range": _usable_scale(highest - lowest)}

    raise ValueError("Unknown method")

def transform_with_normalizer(X: np.ndarray, params: dict) -> np.ndarray:
    """Apply normalization parameters (as returned by ``fit_normalizer``) to ``X``.

    ``params["method"]`` selects the scheme; every scheme other than
    ``"none"`` computes ``(X - center) / scale`` with the matching entries of
    ``params``. The result is always float32.

    Raises:
        ValueError: If ``params["method"]`` is not a known scheme.
    """
    data = np.asarray(X, dtype=np.float64)
    scheme = params["method"]

    if scheme == "none":
        return data.astype(np.float32)

    # (scheme name, key of the center statistic, key of the scale statistic)
    layouts = (
        ("zscore", "mean", "std"),
        ("robust", "median", "iqr"),
        ("minmax", "min", "range"),
    )
    for name, center_key, scale_key in layouts:
        if scheme == name:
            scaled = (data - params[center_key]) / params[scale_key]
            return scaled.astype(np.float32)

    raise ValueError("Bad normalizer params")


# ============================================================
# 6) Main: build patient vectors + (optional) normalize
# ============================================================

def build_patient_feature_vectors(
    folder: str,
    *,
    mat_glob: str = "*.mat",
    patient_regex: str = _PATIENT_REGEX_DEFAULT,
    default_fs: float = 256.0,
    mat_key_fallback: str = "val",
    band: Tuple[float, float] = (0.5, 60.0),
    notch: Optional[float] = 50.0,
    reref: bool = True,
    win_sec: float = 2.0,
    overlap: float = 0.5,
    agg_method: str = "mean",
    log1p_bp: bool = True,
    normalize: NormMethod = "none",
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], List[str], Dict[str, dict], dict]:
    """
    Build one feature vector per patient from all of that patient's recordings.

    For every file: load -> preprocess -> windowed feature extraction. All
    windows of a patient (across files) are stacked and aggregated with
    ``agg_method``; band-power entries are optionally passed through log1p;
    finally the patient vectors are optionally normalized across patients.

    Args:
      folder, mat_glob, patient_regex : file discovery / grouping
                                        (see group_files_by_patient)
      default_fs, mat_key_fallback    : loader options (see load_eeg_and_fs)
      band, notch, reref              : filtering options (see preprocess)
      win_sec, overlap                : windowing (see extract_window_features)
      agg_method                      : 'mean', 'median' or 'mean_std'
      log1p_bp                        : apply log1p to features whose name
                                        contains '_bp' (after aggregation)
      normalize                       : 'none', 'zscore', 'robust' or 'minmax',
                                        fitted on the patient vectors themselves

    Returns:
      patient_vectors_raw : dict patient_id -> raw aggregated vector (float32)
      patient_vectors_norm: dict patient_id -> normalized vector (float32);
                            equal to the raw vector when normalize == 'none'
      feature_names       : list[str] (with _mean/_std suffixes for 'mean_std')
      metadata            : dict patient_id -> per-patient / per-file info
      norm_params         : dict (normalizer parameters)

    Raises:
      ValueError   : a file yields a different feature layout than the first
                     one, or a patient has no complete window in any file.
      RuntimeError : no features were extracted at all.
    """
    groups = group_files_by_patient(folder, mat_glob=mat_glob, patient_regex=patient_regex)

    def _aggregated_names(base_names: List[str]) -> List[str]:
        # 'mean_std' doubles the vector: all means first, then all stds.
        if agg_method == "mean_std":
            return [f"{n}_mean" for n in base_names] + [f"{n}_std" for n in base_names]
        return base_names

    patient_vectors_raw: Dict[str, np.ndarray] = {}
    metadata: Dict[str, dict] = {}
    feat_names_ref: Optional[List[str]] = None

    for pid, file_list in groups.items():
        windows_all = []
        per_file_info = []

        for fp in file_list:
            eeg, fs, eeg_path = load_eeg_and_fs(fp, default_fs=default_fs, mat_key_fallback=mat_key_fallback)
            eeg_prep = preprocess(eeg, fs, band=band, notch=notch, reref=reref)
            Xw, feat_names = extract_window_features(eeg_prep, fs, win_sec=win_sec, overlap=overlap)

            if feat_names_ref is None:
                feat_names_ref = feat_names
            elif feat_names != feat_names_ref:
                raise ValueError(f"Feature mismatch detected. File: {fp}")

            # A recording shorter than one window contributes no rows. Skip it
            # here so a single short file cannot break np.vstack below (an
            # empty result may be 1-D and would not stack with 2-D blocks).
            if Xw.shape[0] > 0:
                windows_all.append(Xw)

            per_file_info.append({
                "file": os.path.basename(fp),
                "path": fp,
                "fs": float(fs),
                "eeg_path": eeg_path,
                "n_windows": int(Xw.shape[0]),
                "n_channels": int(eeg.shape[0]),
                "n_samples": int(eeg.shape[1]),
            })

        if not windows_all:
            raise ValueError(f"No valid windows for patient {pid}")

        X_patient_windows = np.vstack(windows_all)
        v = aggregate_patient_windows(X_patient_windows, method=agg_method)

        if log1p_bp:
            # Names must match the aggregated layout so '_bp' entries line up.
            v = log1p_bandpower_vector(v, _aggregated_names(feat_names_ref))

        patient_vectors_raw[pid] = v.astype(np.float32, copy=False)
        metadata[pid] = {
            "patient_id": pid,
            "n_files": len(file_list),
            "files": per_file_info,
            "total_windows": int(X_patient_windows.shape[0]),
            "agg_method": agg_method,
        }

    if feat_names_ref is None:
        raise RuntimeError("No features extracted.")

    final_feature_names = _aggregated_names(feat_names_ref)

    # ----- normalization across patients (optional) -----
    pids_sorted = sorted(patient_vectors_raw.keys())
    X_raw = np.stack([patient_vectors_raw[k] for k in pids_sorted])

    norm_params = fit_normalizer(X_raw, method=normalize)
    # transform_with_normalizer handles 'none' itself (plain float32 cast),
    # so no special case is needed here.
    X_norm = transform_with_normalizer(X_raw, norm_params)

    patient_vectors_norm: Dict[str, np.ndarray] = {
        pid: X_norm[i].astype(np.float32, copy=False)
        for i, pid in enumerate(pids_sorted)
    }

    return patient_vectors_raw, patient_vectors_norm, final_feature_names, metadata, norm_params


**Now test the EEG pipeline**

In [7]:
# ============================================================
# 7) Example usage (Colab)
# ============================================================
if __name__ == "__main__":
    folder = "/content/EEG_DATA/"  # change
    patient_vectors_raw, patient_vectors_norm, feature_names, meta, norm_params = build_patient_feature_vectors(
        folder,
        mat_glob="*.mat",
        patient_regex=r"^(Patient\d+)",  # adjust if needed
        default_fs=256.0,
        band=(0.5, 60.0),
        # NOTE(review): 40 Hz is neither 50 nor 60 Hz mains - confirm this
        # notch frequency is intended.
        notch=40.0,
        win_sec=2.0,
        overlap=0.5,
        agg_method="mean",
        log1p_bp=True,
        # normalize="zscore",  # enable when working with multiple patients
    )

    # The summary lives under the same guard as the computation it reports on,
    # so the names it prints are always defined when it runs.
    print("Patients:", len(patient_vectors_raw))
    first_pid = sorted(patient_vectors_raw.keys())[0]
    print("Example patient:", first_pid, "raw shape:", patient_vectors_raw[first_pid].shape)
    print("Example patient:", first_pid, "norm shape:", patient_vectors_norm[first_pid].shape)
    print("Feature count:", len(feature_names))
    print("Files for", first_pid, "=", meta[first_pid]["n_files"], "Total windows:", meta[first_pid]["total_windows"])
    print("Normalizer method:", norm_params.get("method"))

Patients: 9
Example patient: Patient1 raw shape: (363,)
Example patient: Patient1 norm shape: (363,)
Feature count: 363
Files for Patient1 = 7 Total windows: 12619
Normalizer method: none


In [10]:
# Create the cache folder on Drive (no error if it already exists: -p).
!mkdir -p /content/drive/MyDrive/EEG_CACHE

In [12]:
import joblib
import os
CACHE_DIR = "/content/drive/MyDrive/EEG_CACHE"

CACHE_FILE = os.path.join(CACHE_DIR, "patient_features_v1.joblib")

# Create the target folder here so this cell does not depend on a separate
# shell cell having been run first.
os.makedirs(CACHE_DIR, exist_ok=True)

# Persist everything the pipeline returned in a single file; joblib.dump
# returns the list of written paths, which is displayed as the cell output.
joblib.dump(
    {
        "patient_vectors_raw": patient_vectors_raw,
        "patient_vectors_norm": patient_vectors_norm,
        "feature_names": feature_names,
        "meta": meta,
        "norm_params": norm_params,
    },
    CACHE_FILE,
    compress=3,  # balances speed & size
)

['/content/drive/MyDrive/EEG_CACHE/patient_features_v1.joblib']

In [13]:
# Summary of the in-memory pipeline results: patient count, vector shapes for
# the alphabetically first patient id, feature count, file/window totals and
# the normalizer method. (Same statements as at the end of the pipeline cell.)
print("Patients:", len(patient_vectors_raw))
first_pid = sorted(patient_vectors_raw.keys())[0]
print("Example patient:", first_pid, "raw shape:", patient_vectors_raw[first_pid].shape)
print("Example patient:", first_pid, "norm shape:", patient_vectors_norm[first_pid].shape)
print("Feature count:", len(feature_names))
print("Files for", first_pid, "=", meta[first_pid]["n_files"], "Total windows:", meta[first_pid]["total_windows"])
print("Normalizer method:", norm_params.get("method"))

Patients: 9
Example patient: Patient1 raw shape: (363,)
Example patient: Patient1 norm shape: (363,)
Feature count: 363
Files for Patient1 = 7 Total windows: 12619
Normalizer method: none
