In [20]:
import torch
import torch.nn as nn
import numpy as np
import itertools
import torch.optim as optim
from typing import Literal
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
import json

from pathlib import Path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

RESULTS_DIR = Path("./results_mlp_multiple_channels")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
OUT_PATH = Path("./data/data.npz")
CHANNELS = ["EEG.AF3", "EEG.T7", "EEG.Pz", "EEG.T8", "EEG.AF4"]
CHANNEL_TO_IDX = { c: idx for idx, c in enumerate(CHANNELS) }

Using device: cuda


In [21]:
def load_data():
    """
    Load the dataset from .npz file.
    
    Returns:
        X: (27, 2, 25, 5, n_features)
        y: (27, 2, 25)
    """
    if not OUT_PATH.exists():
        raise FileNotFoundError(f"File {OUT_PATH} doesn't exist. Run preprocess_data() first.")
    data = np.load(OUT_PATH)
    
    return data['X'], data['y']

In [22]:
X, y = load_data()
X.shape

(27, 2, 25, 5, 16)

# Model

In [23]:
class MLP(nn.Module):
    def __init__(self, in_dim: int, h1=32, h_pre=32, dropout=0.1):
        super().__init__()
        self.in_dim = in_dim
        self.feat = nn.Sequential(
            nn.Linear(in_dim, h1),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h1, h_pre),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.out = nn.Linear(h_pre, 1)

    def forward(self, x):
        x = x.reshape(x.size(0), -1)
        if x.size(1) != self.in_dim:
            raise ValueError(f"Expected {self.in_dim} features, got {x.size(1)}")
        x = self.feat(x)
        return self.out(x)          # (batch, 1)



# Cross subject LOSO

In [24]:
def cross_subject_loso(X, y, channels, model_args, epochs=30):
    n_subjects = X.shape[0]
    channel_indices = list(range(len(channels)))
    results = {}

    # Generate all non-empty combinations of channels
    all_combos = []
    for r in range(1, len(channels) + 1):
        all_combos.extend(list(itertools.combinations(channel_indices, r)))

    for combo in all_combos:
        combo_names = [channels[i] for i in combo]
        combo_key = ", ".join(combo_names)
        print(f"\n===== Analyzing Combo: {combo_key} =====")

        loso_accs = []

        for test_sub in range(n_subjects):
            train_subs = [i for i in range(n_subjects) if i != test_sub]

            # --- TRAIN: (n_train_subs, 2, 25, 5, 16) -> (n_train_subs*50, 5, 16) -> select combo -> (N, |combo|, 16)
            X_train_np = X[train_subs].reshape(-1, X.shape[-2], X.shape[-1])   # (-1, 5, 16)
            y_train_np = y[train_subs].reshape(-1)                              # (-1,)

            X_train = torch.FloatTensor(X_train_np[:, combo, :]).to(device)    # (N, |combo|, 16)
            y_train = torch.FloatTensor(y_train_np).to(device).unsqueeze(1)    # (N, 1)

            # --- TEST: (1, 2, 25, 5, 16) -> (50, 5, 16) -> select combo -> (50, |combo|, 16)
            X_test_np = X[test_sub].reshape(-1, X.shape[-2], X.shape[-1])      # (50, 5, 16)
            y_test_np = y[test_sub].reshape(-1)                                 # (50,)

            X_test = torch.FloatTensor(X_test_np[:, combo, :]).to(device)      # (50, |combo|, 16)
            y_test = y_test_np  # keep numpy for accuracy_score

            # --- Model args (DON'T mutate model_args)
            feat_dim = X.shape[-1]                # 16
            in_dim = len(combo) * feat_dim        # |combo|*16
            margs = dict(model_args)
            margs["in_dim"] = in_dim

            model = MLP(**margs).to(device)
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            criterion = nn.BCEWithLogitsLoss()

            # Training loop
            model.train()
            for epoch in range(epochs):
                optimizer.zero_grad()
                logits = model(X_train)               # expect (N, 1) if your MLP is corrected
                loss = criterion(logits, y_train)     # y_train is (N, 1)
                loss.backward()
                optimizer.step()

            # Evaluation
            model.eval()
            with torch.inference_mode():
                logits = model(X_test)  # (50, 1)
                preds = (logits > 0).float().cpu().numpy().reshape(-1)  # (50,)
                acc = accuracy_score(y_test.reshape(-1), preds)
                loso_accs.append(acc)

        mean_acc = float(np.mean(loso_accs))
        std_acc  = float(np.std(loso_accs))
        results[combo_key] = {"mean": mean_acc, "std": std_acc, "per_subject": loso_accs}
        print(f"Combo Result -> Mean: {mean_acc:.4f} | STD: {std_acc:.4f}")

    return results

# Within subjects LOOCV

In [25]:
def within_subject_loocv(X, y, channels, model_args, epochs=30):
    n_subjects = X.shape[0]
    channel_indices = list(range(len(channels)))
    results = {}
    print(X.shape, y.shape)

    all_combos = []
    for r in range(1, len(channels) + 1):
        all_combos.extend(list(itertools.combinations(channel_indices, r)))

    for combo in all_combos:
        combo_names = [channels[i] for i in combo]
        combo_key = ", ".join(combo_names)
        print(f"\n===== Analyzing Combo: {combo_key} =====")

        subject_accuracies = []
        per_subject = {}  # NEW: store each subject's mean LOOCV acc for this combo

        for sub in range(n_subjects):
            X_sub = X[sub]   # (2, 25, 5, 16)
            y_sub = y[sub]   # (2, 25)

            # (2*25, 5, 16) -> (50, 5, 16), and (2*25,) -> (50,)
            X_sub = X_sub.reshape(-1, X_sub.shape[-2], X_sub.shape[-1])
            y_sub = y_sub.reshape(-1)

            n_trials = X_sub.shape[0]
            feat_dim = X_sub.shape[-1]  # 16

            X_trials = torch.FloatTensor(X_sub[:, combo, :]).to(device)      # (50, |combo|, 16)
            y_trials = torch.FloatTensor(y_sub).to(device).unsqueeze(1)      # (50, 1)

            loocv_accs = []

            for test_idx in range(n_trials):
                train_idx = [i for i in range(n_trials) if i != test_idx]

                X_train, X_test = X_trials[train_idx], X_trials[test_idx:test_idx + 1]
                y_train, y_test = y_trials[train_idx], y_trials[test_idx:test_idx + 1]

                model_args["in_dim"] = feat_dim * len(combo)

                model = MLP(**model_args).to(device)
                optimizer = optim.Adam(model.parameters(), lr=1e-3)
                criterion = nn.BCEWithLogitsLoss()

                model.train()
                for epoch in range(epochs):
                    optimizer.zero_grad()
                    outputs = model(X_train)
                    loss = criterion(outputs, y_train)
                    loss.backward()
                    optimizer.step()

                model.eval()
                with torch.inference_mode():
                    logits = model(X_test)
                    pred = (logits > 0).float().cpu().numpy().reshape(-1)
                    yt = y_test.cpu().numpy().reshape(-1)
                    acc = accuracy_score(yt, pred)
                    loocv_accs.append(acc)

            sub_acc = float(np.mean(loocv_accs))
            subject_accuracies.append(sub_acc)
            per_subject[f"s{sub+1}"] = sub_acc  # NEW

        mean_acc = float(np.mean(subject_accuracies))
        std_acc = float(np.std(subject_accuracies))

        results[combo_key] = {
            "mean": mean_acc,
            "std": std_acc,
            "per_subject": per_subject
        }

        print(f"Combo Result -> Mean: {mean_acc:.4f} | STD: {std_acc:.4f}")

    return results


In [26]:
model_args = {
    "in_dim":16,
      "h1": 64,
      "h_pre": 32,
      "dropout": 0.1
}

# Within-subject LOOCV
within_results = within_subject_loocv(X, y, CHANNELS, model_args, epochs=20)
out_path_1 = RESULTS_DIR / f"within_subject_loocv.json"
with out_path_1.open("w", encoding="utf-8") as f:
        json.dump(within_results, f, indent=2, sort_keys=False)
# Cross-subject LOSO
cross_results = cross_subject_loso(X, y, CHANNELS, model_args, epochs=20)
out_path_2 = RESULTS_DIR / f"cross_subject_loso.json"
with out_path_2.open("w", encoding="utf-8") as f:
        json.dump(cross_results, f, indent=2, sort_keys=False)

(27, 2, 25, 5, 16) (27, 2, 25)

===== Analyzing Combo: EEG.AF3 =====
Combo Result -> Mean: 0.3726 | STD: 0.1361

===== Analyzing Combo: EEG.T7 =====
Combo Result -> Mean: 0.3296 | STD: 0.1558

===== Analyzing Combo: EEG.Pz =====
Combo Result -> Mean: 0.3222 | STD: 0.1729

===== Analyzing Combo: EEG.T8 =====
Combo Result -> Mean: 0.3400 | STD: 0.1744

===== Analyzing Combo: EEG.AF4 =====
Combo Result -> Mean: 0.3244 | STD: 0.1725

===== Analyzing Combo: EEG.AF3, EEG.T7 =====
Combo Result -> Mean: 0.4089 | STD: 0.1662

===== Analyzing Combo: EEG.AF3, EEG.Pz =====
Combo Result -> Mean: 0.4452 | STD: 0.1460

===== Analyzing Combo: EEG.AF3, EEG.T8 =====
Combo Result -> Mean: 0.4430 | STD: 0.1676

===== Analyzing Combo: EEG.AF3, EEG.AF4 =====
Combo Result -> Mean: 0.3800 | STD: 0.1567

===== Analyzing Combo: EEG.T7, EEG.Pz =====
Combo Result -> Mean: 0.3630 | STD: 0.1534

===== Analyzing Combo: EEG.T7, EEG.T8 =====
Combo Result -> Mean: 0.3637 | STD: 0.1435

===== Analyzing Combo: EEG.T7, EE