In [2]:
# run_refactored.py
from __future__ import annotations
from pathlib import Path
import numpy as np
import pandas as pd

from config import SUBJECT_META, SESSION_NAMES                   # labels, male, age  :contentReference[oaicite:21]{index=21}
from dataset import load_12_subject_array, concat_sessions_timewise_from_npy  # loaders  :contentReference[oaicite:22]{index=22}
from model import EEGPanelIFEMI

In [4]:
# Column order of the per-subject summary table / CSV.
_SUMMARY_COLUMNS = ("subject", "rank", "train_mse", "train_r2", "test_mse", "test_r2")


def _optional_float(value):
    """Return ``float(value)``, or ``None`` when the metadata value is missing.

    ``None`` (and NaN, which pandas-style metadata often uses for "missing") is
    forwarded as ``None`` because the fit call treats ``None`` as "impute via EM";
    the original ``float(None)`` would raise TypeError instead.
    """
    return None if pd.isna(value) else float(value)


def _summary_row(label, rep) -> dict:
    """Flatten one subject's fit report into a row of the summary table."""
    return {
        "subject": label,
        "rank": rep.best_r,
        "train_mse": rep.train_metrics["mse"],
        "train_r2": rep.train_metrics["r2"],
        "test_mse": rep.test_metrics["mse"],
        "test_r2": rep.test_metrics["r2"],
    }


def main(
    data_path: Path | str = Path("data") / "clean_EC.npy",
    out_csv: Path | str = "summary_refactored.csv",
    *,
    K_vmf: int = 4,
    r_grid: tuple[int, ...] = (1, 2, 3),
    random_state: int = 42,
    em_iters: int = 2,
) -> pd.DataFrame:
    """Fit ``EEGPanelIFEMI`` for every subject and write a train/test summary.

    Parameters
    ----------
    data_path:
        ``.npy`` file read by ``load_12_subject_array``; its first axis indexes
        subjects and must line up with ``SUBJECT_META``.
    out_csv:
        Destination of the summary CSV (written without the index).
    K_vmf, r_grid, random_state:
        Forwarded to the ``EEGPanelIFEMI`` constructor (``r_grid`` as a list).
    em_iters:
        Forwarded to ``EEGPanelIFEMI.fit``.

    Returns
    -------
    pandas.DataFrame
        One row per subject with columns ``subject, rank, train_mse, train_r2,
        test_mse, test_r2``, sorted by ``subject``.

    Raises
    ------
    ValueError
        If the number of subjects in the array differs from ``len(SUBJECT_META)``.
    """
    # Build the path with pathlib: the old literal ".\data\clean_EC.npy" relied on
    # invalid escape sequences (\d, \c) and only resolved on Windows.
    # Passed as str in case the loader does string-based path handling.
    arr = load_12_subject_array(str(Path(data_path)))  # original notes: shape (12, 4, 32, 90000)
    n_subjects = arr.shape[0]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if n_subjects != len(SUBJECT_META):
        raise ValueError(
            f"Subject meta must match npy subjects: {n_subjects} subjects in "
            f"{data_path}, {len(SUBJECT_META)} entries in SUBJECT_META."
        )

    rows = []
    for s in range(n_subjects):
        label, male1, age = SUBJECT_META[s]
        X, sessions = concat_sessions_timewise_from_npy(arr[s], SESSION_NAMES)  # (T, C), session tags
        # Fresh estimator per subject so no fitted state (or an RNG stream, if the
        # model keeps one on the instance) carries over between independent fits.
        model = EEGPanelIFEMI(K_vmf=K_vmf, r_grid=list(r_grid), random_state=random_state)
        rep = model.fit(
            X=X,
            sessions=sessions,
            sex_male1=_optional_float(male1),  # None -> EM sex-imputation
            age_years=_optional_float(age),    # None -> EM age-imputation
            task_rest1=1.0,
            em_iters=em_iters,
        )
        rows.append(_summary_row(label, rep))
        print(f"[Done] {label}: r={rep.best_r}, train={rep.train_metrics}, test={rep.test_metrics}")

    # If the selected rank sits at the top of the grid, the optimum may lie beyond it.
    top_r = max(r_grid, default=None)
    at_edge = [row["subject"] for row in rows if row["rank"] == top_r]
    if at_edge:
        print(
            f"[Note] best rank equals max(r_grid)={top_r} for {len(at_edge)}/{len(rows)} "
            f"subjects ({', '.join(map(str, at_edge))}); consider extending r_grid."
        )

    # Explicit columns keep a stable order and make an empty run still sortable.
    df = pd.DataFrame(rows, columns=list(_SUMMARY_COLUMNS)).sort_values("subject")
    df.to_csv(out_csv, index=False)
    print("\n=== Summary (refactored) ===")
    print(df)
    return df


# In Jupyter `__name__` is also "__main__", so this runs when the cell executes.
if __name__ == "__main__":
    main()


[Done] AM: r=3, train={'mse': 2.5304067061511027e-09, 'r2': 0.7680654868729337}, test={'mse': 3.2411925913398293e-09, 'r2': 0.48772695661375853}
[Done] CL: r=3, train={'mse': 2.0597690101877677e-09, 'r2': 0.7726575318625819}, test={'mse': 2.019799314633816e-09, 'r2': 0.7455883708570217}
[Done] CQ: r=3, train={'mse': 9.318501805074211e-09, 'r2': 0.6552150013749715}, test={'mse': 7.588207961707672e-09, 'r2': 0.6583102596340841}
[Done] DB: r=3, train={'mse': 7.755242666065321e-09, 'r2': 0.9901510515257869}, test={'mse': 1.0630285472994263e-09, 'r2': 0.22970617540649974}
[Done] DC: r=3, train={'mse': 2.477168771430117e-09, 'r2': 0.6365434411203066}, test={'mse': 2.3526595965684635e-09, 'r2': 0.49807054205290857}
[Done] DL: r=3, train={'mse': 4.246590110190278e-09, 'r2': 0.6614735508904122}, test={'mse': 4.437745259259013e-09, 'r2': 0.6946420174283544}
[Done] ErL: r=3, train={'mse': 1.5962363681709411e-09, 'r2': 0.7607953217468155}, test={'mse': 1.1033169268119235e-07, 'r2': 0.1198437029022