In [45]:
import torch
import torch.nn as nn
import numpy as np
import itertools
import torch.optim as optim

from typing import Literal
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed every RNG the experiments touch so weight initialisation (and therefore
# the reported accuracies) is repeatable across Restart & Run All.
# NOTE: cuDNN kernels (e.g. for the LSTM) may still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

DATA_PATH = '/kaggle/input/bio-project-data/data.npz'
CHANNELS = ["EEG.AF3", "EEG.T7", "EEG.Pz", "EEG.T8", "EEG.AF4"]
CHANNEL_TO_IDX = {c: idx for idx, c in enumerate(CHANNELS)}

Using device: cuda


In [46]:
def load_data(data_path, normalize: bool = True):
    """Load EEG trials from an ``.npz`` archive and merge sessions into one trial axis.

    The archive must contain ``X_raw`` with shape
    ``(subjects, sessions, channels, trials, samples)`` and ``y`` with shape
    ``(subjects, sessions, trials)`` holding binary (0/1) labels.

    Args:
        data_path: Path to the ``.npz`` file.
        normalize: If True, z-score every individual time series (one per
            subject, session, channel and trial) along the sample axis.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(subjects, sessions * trials, channels, samples)``
        and ``y`` has shape ``(subjects, sessions * trials)``. Both use the same
        session-major trial order, so ``X[s, k]`` is labelled by ``y[s, k]``.

    Raises:
        ValueError: If the arrays have the wrong rank, or ``y`` does not match the
            (subject, session, trial) layout of ``X``.
        RuntimeError: If ``y`` contains labels other than 0 and 1.
    """
    # NpzFile keeps the archive open; the context manager closes it once the
    # arrays have been read into memory.
    with np.load(data_path) as data:
        X, y = data["X_raw"], data["y"]

    if X.ndim != 5 or y.ndim != 3:
        raise ValueError("X must be 5D and y 3D")

    n_sub, n_sess, n_chan, n_trials, n_samples = X.shape

    # Validate the label layout explicitly: a y with the same total size but a
    # different (sess, trial) split would otherwise reshape silently into
    # misaligned labels.
    expected_y_shape = (n_sub, n_sess, n_trials)
    if y.shape != expected_y_shape:
        raise ValueError(f"Mismatch: y shape {y.shape}, expected {expected_y_shape} from X {X.shape}")

    print(f"X original: {X.shape}")
    print(f"y original: {y.shape}")

    if normalize:
        X_mean = X.mean(axis=-1, keepdims=True)
        X_std = X.std(axis=-1, keepdims=True) + 1e-6  # epsilon guards flat signals
        X = (X - X_mean) / X_std

    # Bring trials next to sessions so the two can be merged in the same order as y.
    X_transposed = X.transpose(0, 1, 3, 2, 4)  # (sub, sess, trial, chan, time)
    X_new = X_transposed.reshape(n_sub, n_sess * n_trials, n_chan, n_samples)
    y_new = y.reshape(n_sub, n_sess * n_trials)

    unique_labels = np.unique(y_new)
    if not set(unique_labels).issubset({0, 1}):
        raise RuntimeError(f"Non binary labels found: {unique_labels}")

    print(f"X final: {X_new.shape}")
    print(f"y final: {y_new.shape}")

    return X_new, y_new

In [47]:
# Load all subjects with per-trial z-scoring applied.
X, y = load_data(DATA_PATH, normalize=True)

X original: (27, 2, 5, 25, 384)
y original: (27, 2, 25)
X final: (27, 50, 5, 384)
y final: (27, 50)


# Model: convolutional front-end + BiLSTM binary classifier

In [61]:
class EEGLieFeatureExtractor(nn.Module):
    """Convolution + bidirectional LSTM binary classifier for multichannel EEG trials.

    Input has shape ``(batch, n_channels, seq_len)``. Output has shape ``(batch, 1)``
    and holds raw logits: pair it with ``nn.BCEWithLogitsLoss`` and threshold at 0.

    Args:
        n_channels: Number of EEG input channels.
        seq_len: Samples per trial. It must equal the input's last dimension,
            because ``LayerNorm`` normalises over ``(tcn_channels, seq_len)``.
        tcn_channels: Number of feature maps produced by each convolution.
        kernel_size: Temporal kernel width. "same" padding keeps the time axis at
            ``seq_len`` for odd and even widths alike.
        lstm_hidden: Hidden size per LSTM direction.
        lstm_layers: Number of stacked bidirectional LSTM layers.
        dropout: Dropout probability between stacked LSTM layers (only when
            ``lstm_layers > 1``) and on the pooled features before the head.
    """

    def __init__(self, n_channels=5, seq_len=384, tcn_channels=64, kernel_size=5,
                 lstm_hidden=128, lstm_layers=2, dropout=0.3):
        super().__init__()

        # Two plain (non-dilated) 1D convolutions extract local temporal features.
        self.tcn = nn.Sequential(
            nn.Conv1d(n_channels, tcn_channels, kernel_size=kernel_size, padding="same"),
            nn.ReLU(),
            nn.LayerNorm([tcn_channels, seq_len]),
            nn.Conv1d(tcn_channels, tcn_channels, kernel_size=kernel_size, padding="same"),
            nn.ReLU(),
            nn.LayerNorm([tcn_channels, seq_len]),
            nn.MaxPool1d(2),  # halves the time axis before the LSTM
        )

        # BiLSTM for longer-range temporal dynamics. nn.LSTM applies `dropout` only
        # between layers and warns if it is non-zero with a single layer.
        self.lstm = nn.LSTM(
            input_size=tcn_channels,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_layers > 1 else 0.0,
        )

        self.dropout = nn.Dropout(dropout)
        # Linear head: concatenated final forward/backward states -> one logit.
        self.classifier = nn.Linear(2 * lstm_hidden, 1)

    def forward(self, x):
        """Return one logit per trial for ``x`` of shape ``(batch, n_channels, seq_len)``."""
        x = self.tcn(x)        # (batch, tcn_channels, seq_len // 2)
        x = x.transpose(1, 2)  # (batch, seq_len // 2, tcn_channels) for batch_first LSTM

        # h_n: (num_layers * 2, batch, lstm_hidden). The last layer's forward and
        # backward states are h_n[-2] and h_n[-1]. Reading lstm_out[:, -1] instead
        # would give a backward state that has only seen the final time step.
        _, (h_n, _) = self.lstm(x)
        features = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (batch, 2 * lstm_hidden)

        return self.classifier(self.dropout(features))  # (batch, 1)

# Cross-subject evaluation: leave-one-subject-out (LOSO)

In [56]:
def cross_subject_loso(X, y, channels, model_args, epochs=30, lr=1e-3, combos=None):
    """Leave-one-subject-out (LOSO) evaluation for every channel subset.

    For each channel combination, trains a fresh ``EEGLieFeatureExtractor`` on all
    subjects except one and tests on the held-out subject. This repeats so that
    every subject is held out once.

    Args:
        X: Array of shape ``(n_subjects, n_trials, n_channels, seq_len)``.
        y: Binary labels of shape ``(n_subjects, n_trials)``.
        channels: Channel names aligned with axis 2 of ``X``.
        model_args: Keyword arguments for ``EEGLieFeatureExtractor``. ``n_channels``
            is overridden per combination on a copy; the caller's dict is not modified.
        epochs: Number of full-batch optimisation steps per fold.
        lr: Adam learning rate.
        combos: Optional iterable of channel-index tuples to evaluate. Defaults to
            every non-empty subset of ``channels``.

    Returns:
        Dict mapping ``"ch1, ch2, ..."`` to ``{'mean': ..., 'std': ...}`` over the
        per-subject test accuracies (``np.std`` default, i.e. population std).
    """
    n_subjects = X.shape[0]
    seq_len = X.shape[-1]
    if combos is None:
        combos = itertools.chain.from_iterable(
            itertools.combinations(range(len(channels)), r)
            for r in range(1, len(channels) + 1)
        )
    criterion = nn.BCEWithLogitsLoss()  # the model returns raw logits
    results = {}

    for combo in combos:
        combo = list(combo)
        combo_key = ", ".join(channels[i] for i in combo)
        print(f"\n===== Analyzing Combo: {combo_key} =====")
        combo_args = {**model_args, 'n_channels': len(combo)}
        X_combo = X[:, :, combo, :]  # (n_subjects, n_trials, len(combo), seq_len)
        subject_accuracies = []

        for test_sub in range(n_subjects):
            train_subs = [s for s in range(n_subjects) if s != test_sub]

            # Pool the training subjects' trials into a single full batch.
            X_train = torch.FloatTensor(X_combo[train_subs].reshape(-1, len(combo), seq_len)).to(device)
            y_train = torch.FloatTensor(y[train_subs].reshape(-1, 1)).to(device)
            X_test = torch.FloatTensor(X_combo[test_sub]).to(device)
            y_test = np.asarray(y[test_sub]).ravel()

            model = EEGLieFeatureExtractor(**combo_args).to(device)
            optimizer = optim.Adam(model.parameters(), lr=lr)

            model.train()
            for _ in range(epochs):
                optimizer.zero_grad()
                loss = criterion(model(X_train), y_train)
                loss.backward()
                optimizer.step()

            model.eval()
            with torch.inference_mode():
                preds = (model(X_test) > 0).float().cpu().numpy().ravel()
            subject_accuracies.append(accuracy_score(y_test, preds))

        mean_acc = np.mean(subject_accuracies)
        std_acc = np.std(subject_accuracies)
        results[combo_key] = {'mean': mean_acc, 'std': std_acc}
        print(f"Combo Result -> Mean: {mean_acc:.4f} | STD: {std_acc:.4f}")

    return results


# Within-subject evaluation: leave-one-trial-out cross-validation (LOOCV)

In [54]:

def within_subject_loocv(X, y, channels, model_args, epochs=30, lr=1e-3, combos=None):
    """Within-subject leave-one-trial-out cross-validation for every channel subset.

    For each channel combination and each subject, trains a fresh
    ``EEGLieFeatureExtractor`` on all but one of that subject's trials and tests on
    the held-out trial. This repeats so that every trial is held out once.

    Cost: one model is trained per (combination, subject, trial). That is
    ``(2**len(channels) - 1) * n_subjects * n_trials`` fits for the default sweep.
    Pass ``combos`` to evaluate only selected channel subsets.

    Args:
        X: Array of shape ``(n_subjects, n_trials, n_channels, seq_len)``.
        y: Binary labels of shape ``(n_subjects, n_trials)``.
        channels: Channel names aligned with axis 2 of ``X``.
        model_args: Keyword arguments for ``EEGLieFeatureExtractor``. ``n_channels``
            is overridden per combination on a copy; the caller's dict is not modified.
        epochs: Number of full-batch optimisation steps per fold.
        lr: Adam learning rate.
        combos: Optional iterable of channel-index tuples to evaluate. Defaults to
            every non-empty subset of ``channels``.

    Returns:
        Dict mapping ``"ch1, ch2, ..."`` to ``{'mean': ..., 'std': ...}`` over the
        per-subject mean LOOCV accuracies (``np.std`` default, i.e. population std).
    """
    n_subjects, n_trials = X.shape[:2]
    if combos is None:
        combos = itertools.chain.from_iterable(
            itertools.combinations(range(len(channels)), r)
            for r in range(1, len(channels) + 1)
        )
    criterion = nn.BCEWithLogitsLoss()  # the model returns raw logits
    results = {}

    for combo in combos:
        combo = list(combo)
        combo_key = ", ".join(channels[i] for i in combo)
        print(f"\n===== Analyzing Combo: {combo_key} =====")
        combo_args = {**model_args, 'n_channels': len(combo)}
        subject_accuracies = []

        for sub in range(n_subjects):
            X_trials = torch.FloatTensor(X[sub][:, combo, :]).to(device)  # (n_trials, len(combo), seq_len)
            y_trials = torch.FloatTensor(np.asarray(y[sub]).reshape(-1, 1)).to(device)

            fold_accs = []
            for test_idx in range(n_trials):
                train_idx = [i for i in range(n_trials) if i != test_idx]
                X_train, y_train = X_trials[train_idx], y_trials[train_idx]
                X_test = X_trials[test_idx:test_idx + 1]
                y_test = y_trials[test_idx:test_idx + 1].cpu().numpy().ravel()

                model = EEGLieFeatureExtractor(**combo_args).to(device)
                optimizer = optim.Adam(model.parameters(), lr=lr)

                model.train()
                for _ in range(epochs):
                    optimizer.zero_grad()
                    loss = criterion(model(X_train), y_train)
                    loss.backward()
                    optimizer.step()

                model.eval()
                with torch.inference_mode():
                    pred = (model(X_test) > 0).float().cpu().numpy().ravel()
                fold_accs.append(accuracy_score(y_test, pred))

            # Each fold scores a single trial, so the mean is the subject's LOOCV accuracy.
            subject_accuracies.append(np.mean(fold_accs))

        mean_acc = np.mean(subject_accuracies)
        std_acc = np.std(subject_accuracies)
        results[combo_key] = {'mean': mean_acc, 'std': std_acc}
        print(f"Combo Result -> Mean: {mean_acc:.4f} | STD: {std_acc:.4f}")

    return results

In [63]:
# Hyper-parameters shared by both evaluation schemes. Each evaluation function
# sets `n_channels` itself for every channel combination.
model_args = {
    "seq_len": X.shape[-1],  # trial length; the model's LayerNorm is sized on it
    "tcn_channels": 64,
    "kernel_size": 5,
    "lstm_hidden": 128,
    "lstm_layers": 2,
    "dropout": 0.3
}
EPOCHS = 20  # full-batch optimisation steps per fold

# X shape: (n_subjects, n_trials, n_channels, seq_len)
# y shape: (n_subjects, n_trials)

# Within-subject LOOCV
# NOTE: trains one model per (channel combination, subject, trial), so it is very slow.
within_results = within_subject_loocv(X, y, CHANNELS, model_args, epochs=EPOCHS)

# Cross-subject LOSO
cross_results = cross_subject_loso(X, y, CHANNELS, model_args, epochs=EPOCHS)


===== Analyzing Combo: EEG.AF3 =====


KeyboardInterrupt: 