# Gait FOG Detection — Fast Baseline (with robust `events.csv` support)

This notebook builds windowed features from your **gait** CSVs and trains a simple classifier for **FOG detection**.
It includes fast-mode controls and robust handling of your `events.csv` schema (e.g., `id`, `init`, `completion`).

In [21]:
# --- Environment & deps ---
import importlib, sys, subprocess

def ensure(pkg, pip_name=None):
    name = pip_name or pkg
    try:
        importlib.import_module(pkg)
        print(f"{pkg} OK")
    except ImportError:
        print(f"Installing {name} ...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", name])
        importlib.invalidate_caches()
        importlib.import_module(pkg)
        print(f"{pkg} installed")

ensure("numpy")
ensure("pandas")
ensure("scipy")
ensure("sklearn", "scikit-learn")
ensure("tqdm")


numpy OK
pandas OK
scipy OK
sklearn OK
tqdm OK


In [22]:
# --- Imports ---
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# --- Locate GAIT directory automatically (edit POSSIBLE_ROOTS if needed) ---
POSSIBLE_ROOTS = [
    Path(r"C:\Users\muham\_Projects\PD New"),
    Path(r"C:\Users\muham\_Projects\PD"),
    Path.cwd(),
]

def find_gait_dir():
    for base in POSSIBLE_ROOTS:
        for cand in [base / "data" / "gait", base / "gait"]:
            if cand.exists():
                return cand
    return None

GAIT_DIR = find_gait_dir()
print("GAIT_DIR ->", GAIT_DIR)
assert GAIT_DIR is not None, "Couldn't locate gait folder. Set GAIT_DIR manually to your '.../data/gait' path."

# --- Fast mode and signal config ---
FS = 100          # default Hz if no time column
WIN_S = 2.0       # seconds
HOP_S = 0.5       # seconds

# ===== Speed knobs =====
FAST_MODE = True               # quick run
DOWN_SAMPLE_TO = 50            # Hz target if a time column exists (try 25 for faster)
MAX_RECORDINGS = 120           # limit # of training recordings processed (raise later)
MAX_WINDOWS_PER_REC = 500      # cap windows per recording
USE_FREQ_FEATS = False         # time-domain stats only (faster). Set True later for accuracy.
PRINT_EVERY = 10               # status every N recordings
NUMERIC_LIMIT = 128            # limit number of channels to keep feature size manageable


GAIT_DIR -> C:\Users\muham\_Projects\PD New\data\gait


## Manifest (list of recordings)
If you already created a `gait_manifest.csv`, we'll load it. Otherwise, we build a simple manifest by scanning `train/*/*.csv`.

In [23]:
MANIFESTS_DIR = GAIT_DIR.parent / "manifests"
MANIFESTS_DIR.mkdir(parents=True, exist_ok=True)
MANIFEST_PATH = MANIFESTS_DIR / "gait_manifest.csv"

def build_simple_gait_manifest(gait_dir: Path) -> pd.DataFrame:
    rows = []
    # scan train split only for supervised training
    train_dir = gait_dir / "train"
    if not train_dir.exists():
        # fallback: scan all CSVs under gait_dir
        scan_dirs = [gait_dir]
    else:
        scan_dirs = [train_dir]
    for root in scan_dirs:
        for p in root.rglob("*.csv"):
            # skip metadata-ish files
            if p.name.lower() in {"subjects.csv","tasks.csv","events.csv","daily_metadata.csv",
                                  "tdcsfog_metadata.csv","defog_metadata.csv","sample_submission.csv"}:
                continue
            split = "train" if "train" in str(p).replace("\\","/").lower() else "unknown"
            source = Path(p).parent.name  # e.g., defog/tdcsfog/notype
            rows.append({
                "path": str(p),
                "recording_id": p.stem,
                "split": split,
                "source": source,
            })
    df = pd.DataFrame(rows).drop_duplicates(subset=["path"]).reset_index(drop=True)
    # No subject_id available? default to recording_id (grouping will still prevent leakage per recording)
    if "subject_id" not in df.columns:
        df["subject_id"] = df["recording_id"]
    return df

if MANIFEST_PATH.exists():
    gman = pd.read_csv(MANIFEST_PATH)
    print(f"[manifest] Loaded -> {MANIFEST_PATH} (rows={len(gman)})")
else:
    gman = build_simple_gait_manifest(GAIT_DIR)
    gman.to_csv(MANIFEST_PATH, index=False)
    print(f"[manifest] Built -> {MANIFEST_PATH} (rows={len(gman)})")

# Restrict to train split and records that actually exist
gman = gman[gman["split"]=="train"].copy()
gman = gman[gman["path"].apply(lambda s: Path(str(s)).is_file())].reset_index(drop=True)
print("[manifest] train rows:", len(gman))
gman.head(3)


[manifest] Loaded -> C:\Users\muham\_Projects\PD New\data\manifests\gait_manifest.csv (rows=970)
[manifest] train rows: 970


Unnamed: 0,path,recording_id,split,source,subject_id
0,C:\Users\muham\_Projects\PD New\data\gait\trai...,02ea782681,train,defog,02ea782681
1,C:\Users\muham\_Projects\PD New\data\gait\trai...,06414383cf,train,defog,06414383cf
2,C:\Users\muham\_Projects\PD New\data\gait\trai...,092b4c1819,train,defog,092b4c1819


## Events (`events.csv`) normalization
Your file uses columns like `id`, `init`, `completion`. We'll map them to a unified schema:
- `Id` — recording identifier (matched to CSV file stem)
- `start_time` — event start (seconds or samples)
- `end_time` — event end (seconds or samples)

In [24]:
# --- Load & normalize events.csv (supports columns: id/init/completion) ---
import pandas as pd

# Try common locations
candidates = [
    GAIT_DIR / "events.csv",
    GAIT_DIR.parent / "gait" / "events.csv",
    GAIT_DIR.parent.parent / "gait" / "events.csv",
]
events = None
for p in candidates:
    if p.is_file():
        events = pd.read_csv(p)
        print("[events] Using ->", p)
        break
if events is None:
    # Last-resort: search upward
    found = list(GAIT_DIR.parent.rglob("events.csv"))
    if found:
        events = pd.read_csv(found[0])
        print("[events] Using (found) ->", found[0])

if events is None:
    raise FileNotFoundError("events.csv not found. Place it under the gait folder.")

ev = events.copy()
ev.columns = [c.strip().lower() for c in ev.columns]

# Map different schemas to unified names
rename_map = {}
if "id" in ev.columns:          rename_map["id"] = "Id"
if "recording_id" in ev.columns:rename_map["recording_id"] = "Id"
if "series_id" in ev.columns:   rename_map["series_id"] = "Id"

if "init" in ev.columns:        rename_map["init"] = "start_time"
if "start_time" in ev.columns:  rename_map["start_time"] = "start_time"
if "start" in ev.columns:       rename_map["start"] = "start_time"
if "begin" in ev.columns:       rename_map["begin"] = "start_time"

if "completion" in ev.columns:  rename_map["completion"] = "end_time"
if "end_time" in ev.columns:    rename_map["end_time"] = "end_time"
if "end" in ev.columns:         rename_map["end"] = "end_time"
if "stop" in ev.columns:        rename_map["stop"] = "end_time"

ev = ev.rename(columns=rename_map)

required = {"Id","start_time","end_time"}
missing = required - set(ev.columns)
if missing:
    raise RuntimeError(f"events.csv missing columns {missing}. Found columns: {list(ev.columns)}")

# Coerce numeric and clean
ev["start_time"] = pd.to_numeric(ev["start_time"], errors="coerce")
ev["end_time"]   = pd.to_numeric(ev["end_time"], errors="coerce")
ev = ev.dropna(subset=["start_time","end_time"]).reset_index(drop=True)
ev = ev[ev["end_time"] > ev["start_time"]].reset_index(drop=True)

# If looks like milliseconds, convert to seconds
if float(ev[["start_time","end_time"]].max().max()) > 1e4:
    ev[["start_time","end_time"]] = ev[["start_time","end_time"]] / 1000.0

# Normalize Id to be comparable to CSV stems
ev["Id"] = ev["Id"].astype(str).str.replace(".csv","", regex=False).str.strip()
events = ev

# Ensure manifest has recording_id as stem
gman["recording_id"] = gman["recording_id"].astype(str).str.replace(".csv","", regex=False)

# Quick sanity
from pathlib import Path as _P
stems = set(_P(str(s)).stem for s in gman["path"].astype(str))
unmatched = set(events["Id"].unique()) - stems
print(f"[events] rows={len(events)} | unique Ids={events['Id'].nunique()} | unmatched IDs={len(unmatched)}")
events.head(3)


[events] Using -> C:\Users\muham\_Projects\PD New\data\gait\events.csv
[events] rows=3544 | unique Ids=535 | unmatched IDs=0


Unnamed: 0,Id,start_time,end_time,type,kinetic
0,003f117e14,8.61312,14.7731,Turn,1.0
1,009ee11563,11.3847,41.1847,Turn,1.0
2,009ee11563,54.6647,58.7847,Turn,1.0


In [25]:
import numpy as np
import pandas as pd

def safe_fs(fs, default=None):
    """Return a valid sampling rate (>=1 Hz)."""
    if default is None:
        default = float(globals().get("FS", 100.0))
    try:
        fs = float(fs)
    except Exception:
        fs = np.nan
    if not np.isfinite(fs) or fs <= 0:
        return float(default)
    return max(1.0, fs)

def finalize_fs(fs):
    """Clamp fs after you infer it from timestamps."""
    return safe_fs(fs)

def window_signal(n_samples, fs, win_s, hop_s, cap=None):
    """Window emitter with guards so hop is never zero."""
    fs = safe_fs(fs)
    win = max(1, int(round(win_s * fs)))
    hop = max(1, int(round(hop_s * fs)))
    count = 0
    for start in range(0, max(1, n_samples - win + 1), hop):
        if cap is not None and count >= cap:
            break
        count += 1
        yield start, start + win
    # If no window got emitted but we have enough samples, emit one
    if count == 0 and n_samples >= win:
        yield 0, win


## Fast feature extraction
- Optional **downsampling** to `DOWN_SAMPLE_TO`
- **Time-domain** features only (mean/std/RMS) by default
- Caps windows per recording for speed

In [26]:
from scipy.signal import welch
import numpy as np
import pandas as pd

def window_signal(n_samples, fs, win_s, hop_s, cap=None):
    win = int(round(win_s*fs))
    hop = int(round(hop_s*fs))
    count = 0
    for start in range(0, max(1, n_samples - win + 1), hop):
        if cap is not None and count >= cap:
            break
        count += 1
        yield start, start+win

def downsample_if_needed(df, target_fs, time_col):
    """Round timestamps to 1/target_fs bins and average numeric columns."""
    if time_col is None or target_fs is None:
        return df, None
    tt = pd.to_numeric(df[time_col], errors="coerce").to_numpy()
    if np.isnan(tt).all() or tt.size < 2:
        return df, None
    dt = np.nanmedian(np.diff(tt))
    if not (dt > 0):
        return df, None
    fs_now = 1.0 / dt
    if fs_now <= target_fs + 1e-6:
        return df, fs_now
    bins = np.floor(tt * target_fs) / target_fs
    g = df.groupby(bins)
    df_ds = g.mean(numeric_only=True).reset_index(drop=True)
    return df_ds, target_fs

def extract_numeric_matrix(df):
    num_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.number)]
    if not num_cols:
        return None, []
    num_cols = num_cols[:NUMERIC_LIMIT]
    Xs = [df[c].to_numpy(dtype=float, copy=False) for c in num_cols]
    Xmat = np.vstack(Xs)  # (C, T)
    return Xmat, num_cols

def summarize_time(x):
    mean = x.mean(axis=1)
    std  = x.std(axis=1)
    rms  = np.sqrt((x**2).mean(axis=1))
    return np.concatenate([mean, std, rms], axis=0)

def summarize_with_freq(x, fs):
    feats = [x.mean(axis=1), x.std(axis=1), np.sqrt((x**2).mean(axis=1))]
    f, Pxx = welch(x, fs=fs, axis=1, nperseg=min(128, x.shape[1]))
    def band_power(lo, hi):
        idx = (f >= lo) & (f < hi)
        return Pxx[:, idx].sum(axis=1)
    for lo, hi in [(0.1,0.5),(0.5,3),(3,8),(8,20)]:
        feats.append(band_power(lo, hi))
    return np.concatenate(feats, axis=0)


In [27]:
import numpy as np
import pandas as pd

def safe_fs(fs, default=None):
    """Clamp sampling rate to a valid value (>=1 Hz)."""
    if default is None:
        default = float(globals().get("FS", 100.0))
    try:
        fs = float(fs)
    except Exception:
        fs = np.nan
    if not np.isfinite(fs) or fs <= 0:
        return float(default)
    return max(1.0, fs)

def window_signal(n_samples, fs, win_s, hop_s, cap=None):
    """Emit [start, end) indices with guarded hop so it never becomes 0."""
    fs = safe_fs(fs)
    win = max(1, int(round(win_s * fs)))
    hop = max(1, int(round(hop_s * fs)))
    count = 0
    for start in range(0, max(1, n_samples - win + 1), hop):
        if cap is not None and count >= cap:
            break
        count += 1
        yield start, start + win
    # If nothing emitted but long enough for one window, emit one
    if count == 0 and n_samples >= win:
        yield 0, win


In [28]:
FEAT_DIM = None          # will lock after first window
PAD_VALUE = np.nan       # or 0.0


## Build dataset (X, y, groups)
Labels are assigned 1 if the **center** of a window lies within any `[start_time, end_time]` of the same recording in `events.csv`.
If your events times are in **samples**, we auto-convert them to seconds using the estimated sampling rate.

In [29]:
rows, labels, groups = [], [], []

gtrain = gman.copy()  # already restricted to train
rec_paths = gtrain["path"].astype(str).tolist()
if FAST_MODE and MAX_RECORDINGS is not None:
    rec_paths = rec_paths[:int(MAX_RECORDINGS)]

# cache events by Id for quick lookup
ev_by_id = {k: v.reset_index(drop=True) for k, v in events.groupby("Id")}

added_total = 0
for i, path_str in enumerate(tqdm(rec_paths, desc="Recordings")):
    p = Path(path_str)
    try:
        df = pd.read_csv(p)
    except Exception as e:
        print("Skip", p.name, "->", e, flush=True)
        continue

    # time column if present
    time_col = next((c for c in ["time","Time","t","timestamp","Timestamp"] if c in df.columns), None)

    # estimate fs
    fs = FS
    if time_col is not None:
        tt = pd.to_numeric(df[time_col], errors="coerce").to_numpy()
        if tt.size > 1:
            dt = np.nanmedian(np.diff(tt))
            if dt and dt > 0:
                fs = float(round(1.0/dt))

    # optional downsample
    if FAST_MODE and time_col is not None and DOWN_SAMPLE_TO is not None:
        df, fs2 = downsample_if_needed(df, DOWN_SAMPLE_TO, time_col)
        if fs2 is not None:
            fs = fs2

    Xmat, cols = extract_numeric_matrix(df)
    if Xmat is None:  # no numeric data
        continue

    # windowing with cap
    cap = MAX_WINDOWS_PER_REC if FAST_MODE else None
    T = Xmat.shape[1]
    rec_dur_s = T / float(fs) if fs else 0.0

    rec_id = p.stem
    ev_rec = ev_by_id.get(rec_id, None)

    # if events likely in samples, convert to seconds
    if ev_rec is not None and len(ev_rec):
        et_max = float(ev_rec["end_time"].max())
        if rec_dur_s > 0 and (et_max > rec_dur_s * 1.5) and (et_max <= T * 1.5):
            evc = ev_rec.copy()
            evc["start_time"] = evc["start_time"] / float(fs)
            evc["end_time"]   = evc["end_time"]   / float(fs)
            ev_rec = evc

    n_added = 0
    for s, e in window_signal(T, fs, WIN_S, HOP_S, cap=cap):
        xw = Xmat[:, s:e]
        if xw.shape[1] < int(WIN_S*fs):
            continue

        feats = summarize_time(xw) if (FAST_MODE and not USE_FREQ_FEATS) else summarize_with_freq(xw, fs)
        feats = np.asarray(feats, dtype=float).ravel()
        
        if FEAT_DIM is None:
            FEAT_DIM = feats.size
        
        if feats.size != FEAT_DIM:
            if feats.size < FEAT_DIM:
                feats = np.pad(feats, (0, FEAT_DIM - feats.size), constant_values=PAD_VALUE)
            else:
                feats = feats[:FEAT_DIM]

        # label via event center
        center_t = (s + (e - s)/2) / float(fs)
        yv = 0
        if ev_rec is not None and len(ev_rec):
            st = ev_rec["start_time"].to_numpy()
            en = ev_rec["end_time"].to_numpy()
            yv = int(((st <= center_t) & (center_t <= en)).any())

        rows.append(feats)
        labels.append(yv)
        # group by subject (fallback to recording_id)
        subj = gtrain.loc[gtrain["path"]==path_str, "subject_id"].astype(str).values
        groups.append(subj[0] if subj.size else rec_id)
        n_added += 1

    if (i+1) % max(1, PRINT_EVERY) == 0:
        print(f"[{i+1}/{len(rec_paths)}] {p.name}: {n_added} windows | fs={fs}Hz", flush=True)

X = np.asarray(rows, dtype=float)
y = np.asarray(labels, dtype=int)
groups = np.asarray(groups)

print("Feature matrix:", X.shape)
print("Label counts:", pd.Series(y).value_counts().to_dict())


Recordings:   8%|▊         | 9/120 [00:02<00:36,  3.07it/s]

[10/120] 15508c7f41.csv: 500 windows | fs=1.0Hz


Recordings:  16%|█▌        | 19/120 [00:06<00:33,  2.97it/s]

[20/120] 32d03020a9.csv: 500 windows | fs=1.0Hz


Recordings:  24%|██▍       | 29/120 [00:09<00:33,  2.74it/s]

[30/120] 4f613ccf88.csv: 500 windows | fs=1.0Hz


Recordings:  32%|███▎      | 39/120 [00:12<00:25,  3.12it/s]

[40/120] 6a20935af5.csv: 500 windows | fs=1.0Hz


Recordings:  41%|████      | 49/120 [00:16<00:24,  2.85it/s]

[50/120] 850748a138.csv: 500 windows | fs=1.0Hz


Recordings:  49%|████▉     | 59/120 [00:19<00:19,  3.17it/s]

[60/120] a2f1a8ab76.csv: 500 windows | fs=1.0Hz


Recordings:  57%|█████▊    | 69/120 [00:23<00:17,  2.93it/s]

[70/120] be9d33541d.csv: 500 windows | fs=1.0Hz


Recordings:  66%|██████▌   | 79/120 [00:26<00:14,  2.86it/s]

[80/120] e1f92471b9.csv: 500 windows | fs=1.0Hz


Recordings:  74%|███████▍  | 89/120 [00:29<00:10,  2.92it/s]

[90/120] f9efef91fb.csv: 500 windows | fs=1.0Hz


Recordings:  82%|████████▎ | 99/120 [00:33<00:08,  2.43it/s]

[100/120] 2cc3c30645.csv: 500 windows | fs=1.0Hz


Recordings:  91%|█████████ | 109/120 [00:37<00:04,  2.31it/s]

[110/120] 60f28aa837.csv: 500 windows | fs=1.0Hz


Recordings:  99%|█████████▉| 119/120 [00:41<00:00,  2.57it/s]

[120/120] 89e9ed32d1.csv: 500 windows | fs=1.0Hz


Recordings: 100%|██████████| 120/120 [00:42<00:00,  2.85it/s]

Feature matrix: (60000, 21)
Label counts: {0: 59044, 1: 956}





## Train & evaluate (StratifiedGroupKFold)
We use a simple **LogisticRegression** with standardization. Metrics: **ROC AUC** and **accuracy**.

In [30]:
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.model_selection import StratifiedGroupKFold
import numpy as np

# Replace ±inf with NaN so the imputer can handle them
X = np.asarray(X, dtype=float)
X[~np.isfinite(X)] = np.nan

# Sanity: require at least 2 classes
if len(np.unique(y)) < 2:
    raise RuntimeError("All windows share the same label. Check events mapping or units.")

cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

model = make_pipeline(
    SimpleImputer(strategy="median"),                    # ← handles NaN
    StandardScaler(with_mean=True, with_std=True),
    LogisticRegression(max_iter=1000, class_weight="balanced", solver="lbfgs")
)

aucs, accs = [], []
for k, (tr, va) in enumerate(cv.split(X, y, groups=groups), 1):
    model.fit(X[tr], y[tr])
    prob = model.predict_proba(X[va])[:, 1]
    pred = (prob >= 0.5).astype(int)
    try:
        auc = roc_auc_score(y[va], prob)
    except Exception:
        auc = np.nan
    acc = accuracy_score(y[va], pred)
    aucs.append(auc); accs.append(acc)
    print(f"Fold {k}: AUC={auc:.3f} ACC={acc:.3f} (n={len(va)})")

print("\nMean AUC={:.3f} ± {:.3f} | Mean ACC={:.3f} ± {:.3f}".format(
    np.nanmean(aucs), np.nanstd(aucs), np.mean(accs), np.std(accs)))


Fold 1: AUC=0.745 ACC=0.772 (n=12000)
Fold 2: AUC=0.832 ACC=0.616 (n=12000)
Fold 3: AUC=0.333 ACC=0.649 (n=12000)
Fold 4: AUC=0.777 ACC=0.692 (n=12000)
Fold 5: AUC=0.586 ACC=0.793 (n=12000)

Mean AUC=0.655 ± 0.180 | Mean ACC=0.704 ± 0.069


In [31]:
from sklearn.metrics import average_precision_score, precision_recall_curve, f1_score

def eval_fold(y_true, p):
    ap = average_precision_score(y_true, p)
    prs, rcs, th = precision_recall_curve(y_true, p)
    f1s = 2*prs[:-1]*rcs[:-1]/(prs[:-1]+rcs[:-1]+1e-12)
    i = f1s.argmax()
    return ap, f1s[i], th[i]

ap_list, f1_list = [], []
for k,(tr,va) in enumerate(cv.split(X, y, groups=groups),1):
    model.fit(X[tr], y[tr])
    prob = model.predict_proba(X[va])[:,1]
    ap, f1b, thb = eval_fold(y[va], prob)
    ap_list.append(ap); f1_list.append(f1b)
    print(f"Fold {k}: PR-AUC={ap:.3f}  Best-F1={f1b:.3f} @ thr={thb:.3f}")
print("Mean PR-AUC={:.3f}  Mean Best-F1={:.3f}".format(np.mean(ap_list), np.mean(f1_list)))


Fold 1: PR-AUC=0.032  Best-F1=0.100 @ thr=0.802
Fold 2: PR-AUC=0.204  Best-F1=0.351 @ thr=0.783
Fold 3: PR-AUC=0.009  Best-F1=0.029 @ thr=0.090
Fold 4: PR-AUC=0.025  Best-F1=0.058 @ thr=0.600
Fold 5: PR-AUC=0.018  Best-F1=0.052 @ thr=0.459
Mean PR-AUC=0.058  Mean Best-F1=0.118


In [32]:
from sklearn.linear_model import LogisticRegressionCV
model = make_pipeline(
    SimpleImputer(strategy="median"),
    StandardScaler(),
    LogisticRegressionCV(
        Cs=np.logspace(-3,3,13),
        cv=3,
        scoring="roc_auc",
        class_weight="balanced",
        max_iter=3000,
        n_jobs=-1
    )
)


In [33]:
from sklearn.ensemble import HistGradientBoostingClassifier
model = make_pipeline(
    SimpleImputer(strategy="median"),
    HistGradientBoostingClassifier(
        max_depth=None,
        learning_rate=0.05,
        max_iter=300,
        random_state=42
    )
)
# When fitting, pass sample_weight to balance classes:
#   from sklearn.utils.class_weight import compute_sample_weight
#   w = compute_sample_weight("balanced", y[tr])
#   model.fit(X[tr], y[tr], histgradientboostingclassifier__sample_weight=w)


In [34]:
print(events.head())
print("Example rec:", rec_id, "dur_s=", rec_dur_s, "fs=", fs)


           Id  start_time  end_time  type  kinetic
0  003f117e14     8.61312   14.7731  Turn      1.0
1  009ee11563    11.38470   41.1847  Turn      1.0
2  009ee11563    54.66470   58.7847  Turn      1.0
3  011322847a    28.09660   30.2966  Turn      1.0
4  01d0fe7266    30.31840   31.8784  Turn      1.0
Example rec: 89e9ed32d1 dur_s= 164756.0 fs= 1.0


In [36]:
# ==== Cross-validated evaluation: window / recording / subject ====
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import (
    roc_auc_score, accuracy_score, average_precision_score,
    precision_recall_curve
)
from sklearn.model_selection import StratifiedGroupKFold
import warnings

# ---------- Safety & prep ----------
X = np.asarray(X, dtype=float)
y = np.asarray(y, dtype=int)
groups = np.asarray(groups)

# If rec_ids wasn't created during feature-building, fall back (not ideal but prevents crash)
try:
    rec_ids = np.asarray(rec_ids)
    if rec_ids.shape[0] != X.shape[0]:
        raise ValueError("rec_ids length mismatch.")
except Exception:
    warnings.warn("rec_ids not found or mismatched; falling back to groups for rec-level aggregation.")
    rec_ids = groups.copy()

# Replace ±inf with NaN, to be imputed
X[~np.isfinite(X)] = np.nan

# Sanity: require at least 2 classes overall
if np.unique(y).size < 2:
    raise RuntimeError("All windows share the same label. Check events mapping or units.")

# ---------- CV + model ----------
cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

# Stronger baseline than plain LR: LR with inner CV on C, balanced classes
model = make_pipeline(
    SimpleImputer(strategy="median"),
    StandardScaler(),
    LogisticRegressionCV(
        Cs=np.logspace(-3, 3, 13),
        cv=3,
        scoring="roc_auc",
        class_weight="balanced",
        max_iter=3000,
        n_jobs=-1,
        solver="lbfgs"
    )
)

# ---------- Helpers ----------
def pr_summary(y_true, p):
    ap = average_precision_score(y_true, p)
    pr, rc, th = precision_recall_curve(y_true, p)
    f1 = (2 * pr[:-1] * rc[:-1]) / (pr[:-1] + rc[:-1] + 1e-12)
    best_idx = int(f1.argmax())
    return ap, float(f1[best_idx]), float(th[best_idx])

# ---------- Loop ----------
aucs_win, accs_win = [], []
pr_aucs_win, best_f1_win = [], []
aucs_rec, aucs_subj = [], []

for k, (tr, va) in enumerate(cv.split(X, y, groups=groups), 1):
    Xtr, ytr = X[tr], y[tr]
    Xva, yva = X[va], y[va]

    # Fit
    model.fit(Xtr, ytr)

    # Predict probabilities
    p = model.predict_proba(Xva)[:, 1]
    yhat = (p >= 0.5).astype(int)

    # Window-level metrics
    try:
        auc_w = roc_auc_score(yva, p)
    except Exception:
        auc_w = np.nan
    acc_w = accuracy_score(yva, yhat)
    ap_w, f1b_w, thr_w = pr_summary(yva, p)

    aucs_win.append(auc_w); accs_win.append(acc_w)
    pr_aucs_win.append(ap_w); best_f1_win.append(f1b_w)

    # Recording-level AUC (mean prob per recording, label = any-positive window)
    dfp = pd.DataFrame({"rec": rec_ids[va], "y": yva, "p": p})
    agg_rec = dfp.groupby("rec", as_index=False).agg(y=("y", "max"), p=("p", "mean"))
    if agg_rec["y"].nunique() > 1:
        auc_r = roc_auc_score(agg_rec["y"], agg_rec["p"])
    else:
        auc_r = np.nan
    aucs_rec.append(auc_r)

    # Subject-level AUC (mean prob per subject, label = any-positive window)
    dfs = pd.DataFrame({"subj": groups[va], "y": yva, "p": p})
    agg_subj = dfs.groupby("subj", as_index=False).agg(y=("y", "max"), p=("p", "mean"))
    if agg_subj["y"].nunique() > 1:
        auc_s = roc_auc_score(agg_subj["y"], agg_subj["p"])
    else:
        auc_s = np.nan
    aucs_subj.append(auc_s)

    pos_rate = float(yva.mean())
    print(
        f"Fold {k}: "
        f"win AUC={auc_w:.3f} ACC={acc_w:.3f} PR-AUC={ap_w:.3f} Best-F1={f1b_w:.3f} "
        f"| rec AUC={auc_r if np.isfinite(auc_r) else float('nan'):.3f} "
        f"| subj AUC={auc_s if np.isfinite(auc_s) else float('nan'):.3f} "
        f"(n_win={len(va)}, n_rec={len(agg_rec)}, n_subj={len(agg_subj)}, pos_rate={pos_rate:.4f})"
    )

# ---------- Summary ----------
print("\nMEANS → "
      f"Window:  AUC={np.nanmean(aucs_win):.3f}±{np.nanstd(aucs_win):.3f}  "
      f"ACC={np.mean(accs_win):.3f}±{np.std(accs_win):.3f}  "
      f"PR-AUC={np.mean(pr_aucs_win):.3f}  Best-F1={np.mean(best_f1_win):.3f}\n"
      f"Recording AUC: {np.nanmean(aucs_rec):.3f}±{np.nanstd(aucs_rec):.3f}\n"
      f"Subject   AUC: {np.nanmean(aucs_subj):.3f}±{np.nanstd(aucs_subj):.3f}")




Fold 1: win AUC=0.746 ACC=0.772 PR-AUC=0.032 Best-F1=0.099 | rec AUC=0.659 | subj AUC=0.659 (n_win=12000, n_rec=24, n_subj=24, pos_rate=0.0102)
Fold 2: win AUC=0.828 ACC=0.621 PR-AUC=0.230 Best-F1=0.289 | rec AUC=0.552 | subj AUC=0.552 (n_win=12000, n_rec=24, n_subj=24, pos_rate=0.0307)
Fold 3: win AUC=0.363 ACC=0.652 PR-AUC=0.009 Best-F1=0.029 | rec AUC=0.385 | subj AUC=0.385 (n_win=12000, n_rec=24, n_subj=24, pos_rate=0.0132)
Fold 4: win AUC=0.776 ACC=0.692 PR-AUC=0.025 Best-F1=0.058 | rec AUC=0.454 | subj AUC=0.454 (n_win=12000, n_rec=24, n_subj=24, pos_rate=0.0111)
Fold 5: win AUC=0.585 ACC=0.793 PR-AUC=0.018 Best-F1=0.052 | rec AUC=0.696 | subj AUC=0.696 (n_win=12000, n_rec=24, n_subj=24, pos_rate=0.0144)

MEANS → Window:  AUC=0.660±0.169  ACC=0.706±0.067  PR-AUC=0.063  Best-F1=0.106
Recording AUC: 0.549±0.118
Subject   AUC: 0.549±0.118
