In [None]:
import csv
import itertools
import json
import random
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# ---------- PATHS ----------
# Every JSON result written by this notebook lands under this directory.
RESULTS_DIR = Path("mlp_results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# ---------- SHARED CONFIG ----------
# MLP constructor kwargs. in_dim stays None here and is filled in per run
# from the feature width of the channel combo being evaluated.
model_args = dict(
    in_dim=None,
    h1=64,
    h_pre=32,
    dropout=0.1,
)

# Training settings for the cross-subject LOSO runs.
train_args = dict(
    seed=42,
    epochs=200,
    batch_size=64,
    lr=1e-3,
    weight_decay=1e-4,
    patience=30,
    val_size=0.2,   # for LOSO
)

# Training settings for the subject-dependent LOOCV runs.
loocv_args = dict(
    seed=42,
    epochs=150,
    batch_size=32,
    lr=1e-3,
    weight_decay=1e-4,
    patience=30,
    val_size=0.2,   # inside-subject train/val split on remaining trials
)



# Model

In [None]:
class MLP(nn.Module):
    """
    Two-hidden-layer perceptron that emits one logit per input row.

    Layout: Linear(in_dim, h1) -> ReLU -> Dropout -> Linear(h1, h_pre) -> ReLU
    -> Dropout, held in ``self.feat``, followed by the scalar head
    ``self.out = Linear(h_pre, 1)``. ``forward`` drops the trailing size-1 axis,
    so a (batch, in_dim) input yields (batch,) raw logits (no sigmoid applied).
    """

    def __init__(self, in_dim: int, h1=64, h_pre=32, dropout=0.1):
        super().__init__()
        hidden_layers = [
            nn.Linear(in_dim, h1), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(h1, h_pre), nn.ReLU(), nn.Dropout(dropout),
        ]
        self.feat = nn.Sequential(*hidden_layers)
        self.out = nn.Linear(h_pre, 1)

    def forward(self, x):
        hidden = self.feat(x)
        logits = self.out(hidden)
        return logits.squeeze(-1)



# Helper functions

In [None]:


def set_seed(seed: int = 42) -> None:
    """
    Seed every RNG the training code in this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, and torch on
    CPU and on all CUDA devices. It also sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False`` so cuDNN does not auto-select kernels between runs.
    """
    # Python's stdlib RNG is seeded too, so any library code drawing from
    # `random` is reproducible alongside numpy/torch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_device():
    """Return the torch device string to train on: "cuda" if available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def scale_train_test(X_train_raw, X_test_raw):
    """
    Robust-scale a train/test pair using statistics from the training rows only.

    A default ``RobustScaler`` is fit on ``X_train_raw`` and applied to both
    arrays, so nothing about the test rows leaks into the scaling.

    Returns (X_train_scaled, X_test_scaled, fitted_scaler).
    """
    fitted = RobustScaler().fit(X_train_raw)
    return fitted.transform(X_train_raw), fitted.transform(X_test_raw), fitted


def stratified_train_val_split(X_train, y_train, val_size=0.1, seed=42):
    """
    Split training data into train/validation parts, preserving class ratios.

    Thin wrapper over ``sklearn.model_selection.train_test_split`` with
    ``stratify=y_train``. Returns (X_tr, X_val, y_tr, y_val), in sklearn's order.
    """
    split_options = dict(test_size=val_size, random_state=seed, stratify=y_train)
    return train_test_split(X_train, y_train, **split_options)


@torch.no_grad()
def predict_proba(model: nn.Module, X: np.ndarray, device: str) -> np.ndarray:
    """
    Return sigmoid probabilities for every row of ``X`` as a NumPy array.

    Switches ``model`` to eval mode (so dropout is off) and runs all of ``X`` as
    a single float32 batch on ``device``; gradients are disabled by the decorator.
    """
    model.eval()
    inputs = torch.tensor(X, dtype=torch.float32, device=device)
    return torch.sigmoid(model(inputs)).detach().cpu().numpy()


def accuracy_from_probs(probs: np.ndarray, y_true: np.ndarray, thr: float = 0.5) -> float:
    """Fraction of samples whose hard prediction (probs >= thr -> 1, else 0) equals y_true."""
    hard_labels = np.where(probs >= thr, 1, 0).astype(np.int64)
    matches = hard_labels == y_true
    return float(np.mean(matches))


@torch.no_grad()
def eval_on_loader_bce(model: nn.Module, loader: DataLoader, device: str):
    """
    Evaluate a logit-producing binary classifier over every batch of ``loader``.

    The model is put in eval mode. For each (xb, yb) batch the BCE-with-logits
    loss and the 0.5-threshold accuracy are accumulated.

    Returns
    -------
    (mean_loss, accuracy) : tuple of float
        ``mean_loss`` is the per-sample average BCE over the whole loader, and
        ``accuracy`` is the fraction of correctly classified samples. For an
        empty loader this returns (nan, 0.0).
    """
    model.eval()
    # Sum (not mean) per batch so the final average weights every sample
    # equally; averaging per-batch means over-weights a short last batch.
    loss_fn = nn.BCEWithLogitsLoss(reduction="sum")
    loss_sum, correct, total = 0.0, 0, 0

    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)

        logits = model(xb)
        loss_sum += loss_fn(logits, yb).item()

        preds = (torch.sigmoid(logits) >= 0.5).float()
        correct += (preds == yb).sum().item()
        total += yb.numel()

    if total == 0:
        return float("nan"), 0.0
    return loss_sum / total, correct / total


def train_one_split_bce(
    model: nn.Module,
    X_tr, y_tr, X_val, y_val,
    epochs=200, batch_size=64, lr=1e-3, weight_decay=1e-4,
    patience=25, device=None, log_every=50
):
    """
    Train a binary classifier with BCE-with-logits and early stopping on validation loss.

    AdamW is run over shuffled mini-batches of (X_tr, y_tr). After every epoch
    the validation loss on (X_val, y_val) is computed with ``eval_on_loader_bce``.
    The weights of the epoch with the lowest validation loss are snapshotted to
    CPU and restored before returning. Training stops once the validation loss
    has not improved by more than 1e-6 for ``patience`` consecutive epochs.

    Parameters
    ----------
    model : nn.Module
        Must map a float32 (batch, n_features) tensor to (batch,) logits, the
        shape BCEWithLogitsLoss compares against the targets.
    X_tr, y_tr, X_val, y_val : array-likes
        Converted with ``torch.tensor(..., dtype=float32)``; labels become 0/1
        float targets.
    epochs, batch_size, lr, weight_decay, patience : training hyperparameters.
    device : str or None
        Torch device string. ``None`` (default) picks "cuda" when available,
        otherwise "cpu".
    log_every : int
        Currently unused (per-epoch logging is disabled). Kept so existing
        callers that pass it keep working.

    Returns
    -------
    nn.Module
        The same ``model`` object, on ``device``, holding the best-validation weights.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Move the model before building the optimizer (torch.optim recommendation).
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.BCEWithLogitsLoss()

    train_ds = TensorDataset(
        torch.tensor(X_tr, dtype=torch.float32),
        torch.tensor(y_tr, dtype=torch.float32),
    )
    val_ds = TensorDataset(
        torch.tensor(X_val, dtype=torch.float32),
        torch.tensor(y_val, dtype=torch.float32),
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    best_val = float("inf")
    best_state = None
    bad = 0  # consecutive epochs without a validation improvement

    for _ in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            opt.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            opt.step()

        mean_val, _ = eval_on_loader_bce(model, val_loader, device=device)

        if mean_val < best_val - 1e-6:
            best_val = mean_val
            # CPU clone so later optimizer steps cannot overwrite the snapshot.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1

        if bad >= patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model

In [None]:
def extract_subject_combo(X, y, subject_idx: int, combo: tuple):
    """
    Return one subject's trials, flattened over the channels in ``combo``.

    Assumes X is indexed (subject, session, trial, channel, feature) and y is
    indexed (subject, session, trial). TODO confirm against load_data(); the
    original notes give 2 sessions x 25 trials = 50 trials per subject.

    Returns
    -------
    X_sub_flat : ndarray, shape (n_trials, len(combo) * n_feat)
        One row per (session, trial), with the selected channels' feature
        vectors concatenated in ``combo`` order.
    y_sub : ndarray of int64, shape (n_trials,)
        Labels aligned row-for-row with ``X_sub_flat``.
    """
    n_feat = X.shape[-1]
    n_total_trials = X.shape[1] * X.shape[2]

    # Select the subject first, then the channels. X[subject_idx, :, :, combo, :]
    # mixes an integer and a sequence index separated by slices. NumPy then moves
    # the channel axis to the front, and the reshape below would interleave
    # trials with channels, misaligning rows and labels for multi-channel combos.
    X_sub = X[subject_idx][:, :, list(combo), :]
    X_sub = X_sub.reshape(n_total_trials, len(combo), n_feat)
    y_sub = y[subject_idx, :, :].reshape(n_total_trials).astype(np.int64)

    X_sub_flat = X_sub.reshape(n_total_trials, -1)
    return X_sub_flat, y_sub


def extract_loso_combo(X, y, test_subj: int, combo: tuple):
    """
    Build the raw (unscaled) leave-one-subject-out split for one channel combo.

    Every subject except ``test_subj`` is stacked, in subject order, into the
    training set. ``test_subj`` alone forms the test set. Rows are produced by
    ``extract_subject_combo``, so each has ``len(combo) * n_feat`` columns.

    Returns (X_train_raw, y_train, X_test_raw, y_test).
    """
    per_subject = [
        extract_subject_combo(X, y, subj, combo)
        for subj in range(X.shape[0])
        if subj != test_subj
    ]
    X_train_raw = np.concatenate([feats for feats, _ in per_subject], axis=0)
    y_train = np.concatenate([labels for _, labels in per_subject], axis=0)

    X_test_raw, y_test = extract_subject_combo(X, y, test_subj, combo)
    return X_train_raw, y_train, X_test_raw, y_test


# Subject-dependent LOOCV (leave one trial out, within each subject)

In [None]:
def run_subject_loocv(X, y, subject_idx: int, combo: tuple, channel_label: str, cfg_model, cfg_train):
    """
    Subject-dependent leave-one-out CV for one subject and one channel combo.

    Each of the subject's trials is held out once. A fresh MLP is trained on
    the remaining trials, which are further split (stratified) into train and
    validation parts for early stopping, and then scored on the held-out trial.

    Parameters
    ----------
    cfg_model : dict of MLP kwargs; ``in_dim`` is overwritten per fold.
    cfg_train : dict with seed, epochs, batch_size, lr, weight_decay, patience, val_size.

    Returns
    -------
    dict with "subject" (e.g. "S01"), "combo" (``channel_label``) and "acc"
    (mean held-out accuracy over all folds).
    """
    set_seed(cfg_train["seed"])
    device = get_device()

    X_sub_raw, y_sub = extract_subject_combo(X, y, subject_idx, combo)

    n_samples = X_sub_raw.shape[0]
    fold_accs = []

    for holdout in tqdm(range(n_samples), desc="Training loocv"):
        X_test_raw = X_sub_raw[holdout:holdout + 1]
        y_test = y_sub[holdout:holdout + 1]

        X_train_raw = np.delete(X_sub_raw, holdout, axis=0)
        y_train = np.delete(y_sub, holdout, axis=0)

        # Fit the scaler on this fold's training trials only. Fitting it on all
        # trials would leak the held-out trial into the scaling statistics.
        X_train, X_test, _ = scale_train_test(X_train_raw, X_test_raw)

        X_tr, X_val, y_tr, y_val = stratified_train_val_split(
            X_train, y_train,
            val_size=cfg_train["val_size"],
            seed=cfg_train["seed"]
        )

        margs = dict(cfg_model)
        margs["in_dim"] = X_tr.shape[1]

        model = MLP(**margs)
        model = train_one_split_bce(
            model, X_tr, y_tr, X_val, y_val,
            epochs=cfg_train["epochs"],
            batch_size=cfg_train["batch_size"],
            lr=cfg_train["lr"],
            weight_decay=cfg_train["weight_decay"],
            patience=cfg_train["patience"],
            device=device,
            log_every=999999
        )

        probs = predict_proba(model, X_test, device=device)
        fold_accs.append(accuracy_from_probs(probs, y_test))

    mean_acc = float(np.mean(fold_accs))
    return {
        "subject": f"S{subject_idx+1:02d}",
        "combo": channel_label,
        "acc": mean_acc
    }


# Cross-subject LOSO (leave one subject out)

In [None]:
def run_cross_subject_loso(X, y, combo: tuple, channel_label: str, cfg_model, cfg_train):
    """
    Cross-subject leave-one-subject-out evaluation of one channel combo.

    For each subject, a fresh MLP is trained on all other subjects and tested
    on the held-out subject. Scaling statistics come from the training subjects
    only, and a stratified split of them is used for early stopping.

    Parameters
    ----------
    cfg_model : dict of MLP kwargs; ``in_dim`` is overwritten per fold.
    cfg_train : dict with seed, epochs, batch_size, lr, weight_decay, patience, val_size.

    Returns
    -------
    dict with "combo", "per_subject" (accuracy per held-out subject, in
    subject order), "mean" and "std" (population std) of those accuracies.
    """
    set_seed(cfg_train["seed"])
    device = get_device()

    n_subjects = X.shape[0]
    fold_accs = []

    for test_subj in tqdm(range(n_subjects), desc="Training loso"):
        X_train_raw, y_train, X_test_raw, y_test = extract_loso_combo(X, y, test_subj, combo)

        X_train, X_test, _ = scale_train_test(X_train_raw, X_test_raw)
        X_tr, X_val, y_tr, y_val = stratified_train_val_split(
            X_train, y_train,
            val_size=cfg_train["val_size"],
            seed=cfg_train["seed"]
        )

        margs = dict(cfg_model)
        margs["in_dim"] = X_tr.shape[1]

        model = MLP(**margs)
        model = train_one_split_bce(
            model, X_tr, y_tr, X_val, y_val,
            epochs=cfg_train["epochs"],
            batch_size=cfg_train["batch_size"],
            lr=cfg_train["lr"],
            weight_decay=cfg_train["weight_decay"],
            patience=cfg_train["patience"],
            device=device,
            log_every=50
        )

        probs = predict_proba(model, X_test, device=device)
        acc = accuracy_from_probs(probs, y_test)
        fold_accs.append(acc)

        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"{channel_label} | test S{test_subj+1:02d}: {acc*100:.2f}%")

    return {
        "combo": channel_label,
        "per_subject": [float(a) for a in fold_accs],
        "mean": float(np.mean(fold_accs)),
        "std": float(np.std(fold_accs)),
    }

In [None]:
def all_combos(n_channels: int, min_len=1, max_len=None):
    """
    List every channel-index combination with size in [min_len, max_len].

    ``max_len`` defaults to ``n_channels``. Combinations are ordered by size,
    then lexicographically (``itertools.combinations`` order), e.g.
    all_combos(3) -> [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)].
    """
    upper = n_channels if max_len is None else max_len
    channels = range(n_channels)
    return [
        combo
        for size in range(min_len, upper + 1)
        for combo in itertools.combinations(channels, size)
    ]


def combo_label(combo: tuple, channel_names: list):
    """Human-readable label for a combo: channel names joined by '+', e.g. (0, 2) -> "C1+C3"."""
    names = [channel_names[idx] for idx in combo]
    return "+".join(names)

# Optuna hyperparameter search (per channel)

In [None]:
import optuna
from src.feature_extraction import load_data

def optuna_objective_factory(channel_idx: int, seed: int = 42, n_trials_cache=None):
    """
    Build an Optuna objective that scores MLP hyperparameters on a single channel.

    Data are loaded once, via ``load_data()``, when the factory is called. Each
    trial samples h1, h_pre, dropout, lr and batch_size, then runs a full
    leave-one-subject-out loop restricted to channel ``channel_idx``. The trial
    returns the mean held-out-subject accuracy, which is to be maximised. The
    running mean is reported after every fold so a pruner can stop weak trials early.

    Parameters
    ----------
    channel_idx : int
        Index into X's channel axis.
    seed : int
        Re-seeds all RNGs at the start of every trial and fixes the stratified
        train/val split.
    n_trials_cache : unused
        Kept for backward compatibility.
    """
    X, y = load_data()
    n_subjects = X.shape[0]
    combo = (channel_idx,)  # single-channel combo for the shared LOSO helpers
    device = "cpu"  # keep optuna cheap; switch to cuda if you really want

    def objective(trial):
        set_seed(seed)

        h1 = trial.suggest_int("h1", 32, 256, step=32)
        h_pre = trial.suggest_int("h_pre", 8, 64, step=8)
        dropout = trial.suggest_float("dropout", 0.0, 0.3)
        lr = trial.suggest_float("lr", 1e-4, 3e-3, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])

        # Fixed (not searched) training settings.
        weight_decay = 1e-4
        val_size = 0.1
        patience = 25
        epochs = 300

        fold_accs = []
        for test_subj in range(n_subjects):
            # Same split -> scale -> train pipeline as run_cross_subject_loso,
            # using the helpers defined in this notebook.
            X_train_raw, y_train, X_test_raw, y_test = extract_loso_combo(X, y, test_subj, combo)
            X_train, X_test, _ = scale_train_test(X_train_raw, X_test_raw)
            X_tr, X_val, y_tr, y_val = stratified_train_val_split(X_train, y_train, val_size=val_size, seed=seed)

            model = MLP(in_dim=X_tr.shape[1], h1=h1, h_pre=h_pre, dropout=dropout)
            model = train_one_split_bce(
                model,
                X_tr, y_tr, X_val, y_val,
                epochs=epochs, batch_size=batch_size, lr=lr, weight_decay=weight_decay,
                patience=patience, device=device, log_every=999999
            )

            probs = predict_proba(model, X_test, device=device)
            acc = accuracy_from_probs(probs, y_test)
            fold_accs.append(acc)

            # pruning: report the running mean so far after each held-out subject
            trial.report(float(np.mean(fold_accs)), step=test_subj)
            if trial.should_prune():
                raise optuna.TrialPruned()

        return float(np.mean(fold_accs))

    return objective


def run_optuna(channel_idx: int, n_trials: int = 15, seed: int = 42):
    """
    Run a median-pruned TPE hyperparameter search for one channel.

    Parameters
    ----------
    channel_idx : int
        Channel to search on (passed to ``optuna_objective_factory``).
    n_trials : int
        Number of Optuna trials.
    seed : int
        Seeds both the TPE sampler and the per-trial training, so the whole
        search is reproducible.

    Returns
    -------
    dict summarising the best trial: channel, best_value, best_params,
    n_trials, seed, study_name.
    """
    import optuna
    pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=5)
    # Without a seeded sampler, TPE proposes hyperparameters from an unseeded
    # RNG, so `seed` alone would not make the search reproducible.
    sampler = optuna.samplers.TPESampler(seed=seed)

    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner,
                                study_name=f"mlp_channel_{channel_idx}_search")

    study.optimize(optuna_objective_factory(channel_idx, seed=seed), n_trials=n_trials, n_jobs=1)

    best = {
        "channel": channel_idx,
        "best_value": float(study.best_trial.value),
        "best_params": dict(study.best_trial.params),
        "n_trials": n_trials,
        "seed": seed,
        "study_name": study.study_name,
    }
    return best


def run_optuna_all_channels(n_channels=5, n_trials=15, out_path=None, seed=42):
    """
    Run ``run_optuna`` for channels 0..n_channels-1 and save all summaries as JSON.

    Parameters
    ----------
    n_channels : int
        Number of channels to search, starting at index 0.
    n_trials : int
        Optuna trials per channel.
    out_path : str, Path or None
        Output JSON file. Defaults to
        RESULTS_DIR / "results_optuna_mlp_per_channel.json".
    seed : int
        Forwarded to ``run_optuna``. The default 42 matches its own default.

    Returns
    -------
    dict mapping "channel_<i>" to that channel's best-trial summary.
    """
    all_results = {}
    for ch in range(n_channels):
        print(f"\n=== OPTUNA SEARCH: CHANNEL {ch} ===")
        best = run_optuna(channel_idx=ch, n_trials=n_trials, seed=seed)
        all_results[f"channel_{ch}"] = best
        print(best)

    # Coerce to Path so a plain string path also works with .resolve() below.
    out_path = RESULTS_DIR / "results_optuna_mlp_per_channel.json" if out_path is None else Path(out_path)

    with open(out_path, "w") as f:
        json.dump(all_results, f, indent=2)

    print(f"Saved Optuna results to: {out_path.resolve()}")
    return all_results


In [None]:
from src.feature_extraction import load_data

channel_names = ["C1", "C2", "C3", "C4", "C5"]
X, y = load_data()
n_subjects, n_channels = X.shape[0], X.shape[3]
assert len(channel_names) >= n_channels, "channel_names needs a name for every channel in X"

# ---- A) SUBJECT-DEPENDENT LOOCV for ALL SINGLE CHANNELS ----
# max_len=1 keeps this sweep to single channels, matching this section's
# title, the results variable and the output filename.
loocv_single_results = []
for combo in all_combos(n_channels, min_len=1, max_len=1):
    label = combo_label(combo, channel_names)
    for s in range(n_subjects):
        r = run_subject_loocv(X, y, s, combo, label, model_args, loocv_args)
        loocv_single_results.append(r)
        print(label, r["subject"], f'{r["acc"]*100:.2f}%')

with open(RESULTS_DIR / "loocv_single_channels.json", "w") as f:
    json.dump(loocv_single_results, f, indent=2)

# ---- B) CROSS-SUBJECT LOSO for ALL CHANNEL COMBOS (single and multi-channel) ----
loso_combo_results = {}
for combo in all_combos(n_channels, min_len=1):
    label = combo_label(combo, channel_names)
    r = run_cross_subject_loso(X, y, combo, label, model_args, train_args)
    loso_combo_results[label] = r
    print("=>", label, f'{r["mean"]*100:.2f}% ± {r["std"]*100:.2f}%')

with open(RESULTS_DIR / "loso_multichannel_combos.json", "w") as f:
    json.dump(loso_combo_results, f, indent=2)