In [None]:
# EEG Motor Imagery Classification Pipeline

# --- INSTALL DEPENDENCIES ---
# Run in Colab or locally to install required libraries
!pip install mne torch scikit-learn matplotlib --quiet

# --- IMPORT LIBRARIES ---
import os
import urllib.request
import numpy as np
import mne
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_score
from mne.decoding import CSP
import matplotlib.pyplot as plt

# --- PARAMETERS ---
# Subjects 1 to 11 (the upper bound of range() is exclusive)
subject_ids = [*range(1, 12)]
# Motor imagery task: left and right hand (S###R08.edf)
recording_id = 8
# Time window in seconds for epochs
tmin = 0
tmax = 2

# --- DATA LOADING FUNCTION ---
def load_and_preprocess(subject_id):
    """Download (if needed), filter and epoch one subject's EDF recording.

    The recording is selected by the module-level ``recording_id`` and the
    epoch window by the module-level ``tmin`` / ``tmax``.

    Args:
        subject_id: Integer subject number; it is zero-padded to three digits
            to build the file name and the download URL.

    Returns:
        A tuple ``(epochs, labels)`` where ``epochs`` is an ``mne.Epochs``
        restricted to EEG channels and ``labels`` is the event-code column of
        ``epochs.events``. Returns ``(None, None)`` when the recording has no
        annotation mapped to event code 2 or 3.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises when the download fails
        (e.g. ``urllib.error.URLError``); no partial file is left behind.
    """
    filename = f"S{subject_id:03d}R{recording_id:02d}.edf"
    url = f"https://physionet.org/static/published-projects/eegmmidb/1.0.0/S{subject_id:03d}/{filename}"

    if not os.path.exists(filename):
        # Download under a temporary name and only promote it to the final
        # name once the transfer has completed. Writing straight to `filename`
        # would leave a truncated file after an interrupted download, and the
        # existence check above would then reuse it on every later run.
        partial = filename + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, filename)
        finally:
            # Only present here if the download or the rename failed.
            if os.path.exists(partial):
                os.remove(partial)

    raw = mne.io.read_raw_edf(filename, preload=True, stim_channel='auto', verbose=False)
    # Drop leading/trailing dots from channel names before applying the montage.
    raw.rename_channels(lambda x: x.strip('.'))
    raw.set_montage('standard_1020', on_missing='ignore')
    # Band-pass 1-40 Hz.
    raw.filter(1., 40., fir_design='firwin', verbose=False)

    events, event_id = mne.events_from_annotations(raw, verbose=False)
    # Keep only the annotations that were assigned integer codes 2 and 3.
    # NOTE(review): presumably these are the two task cues (left/right hand)
    # under MNE's automatic annotation-to-code mapping -- confirm.
    relevant_event_ids = {k: v for k, v in event_id.items() if v in [2, 3]}
    if not relevant_event_ids:
        return None, None

    epochs = mne.Epochs(raw, events, event_id=relevant_event_ids, tmin=tmin, tmax=tmax, baseline=None, preload=True, verbose=False)
    epochs.pick_types(eeg=True)  # Keep only EEG channels
    labels = epochs.events[:, 2]
    return epochs, labels

# Load every requested subject, keeping only those that produced epochs.
all_epochs, all_labels = [], []

for sid in subject_ids:
    epochs, labels = load_and_preprocess(sid)
    if epochs is None:
        # load_and_preprocess found no matching events for this subject.
        continue
    all_epochs.append(epochs)
    all_labels.append(labels)

if len(all_epochs) == 0:
    raise RuntimeError("No valid data loaded.")

import warnings

# Pool all subjects into a single Epochs object; RuntimeWarnings emitted
# while concatenating are silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    epochs = mne.concatenate_epochs(all_epochs)

# Binary target: 1 where the event code is 3, 0 otherwise.
y = np.concatenate(all_labels)
is_code_three = y == 3
y_bin = is_code_three.astype(int)

# Feature array as float64.
epoch_data = epochs.get_data()
X = epoch_data.astype(np.float64)

# CSP + LDA pipeline
# Log-transformed CSP features feed a linear discriminant classifier; the
# pipeline is scored with shuffled, stratified 5-fold cross-validation.
csp = CSP(
    n_components=4,
    reg='ledoit_wolf',
    log=True,
    norm_trace=False,
)
lda = LinearDiscriminantAnalysis()
clf = Pipeline(steps=[('CSP', csp), ('LDA', lda)])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(clf, X, y_bin, cv=cv, n_jobs=1)

mean_pct = np.mean(scores) * 100
spread_pct = np.std(scores) * 100
print(f"CSP + LDA Accuracy: {mean_pct:.2f}% (+/- {spread_pct:.2f}%)")

# Prepare data for RNN
# Wrap features and binary labels as tensors, then hold out 10% of the
# samples for validation using a seeded (reproducible) random split.
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y_bin, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)

n_samples = len(dataset)
train_size = int(0.9 * n_samples)
val_size = n_samples - train_size

split_rng = torch.Generator()
split_rng.manual_seed(42)
train_ds, val_ds = random_split(
    dataset,
    lengths=[train_size, val_size],
    generator=split_rng,
)

# Only the training loader shuffles.
train_dl = DataLoader(dataset=train_ds, batch_size=10, shuffle=True)
val_dl = DataLoader(dataset=val_ds, batch_size=10)

class EEG_RNN(nn.Module):
    """Stacked-GRU classifier for batches laid out as (batch, channels, time).

    The input is rearranged so that time is the sequence axis and channels are
    the per-step features, run through the GRU, and the GRU output at the last
    time step is mapped to class logits by a linear layer.

    Args:
        input_size: Number of features per time step (the channel count).
        hidden_size: Width of the GRU hidden state.
        num_layers: Number of stacked GRU layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=64, hidden_size=64, num_layers=2, num_classes=2):
        super().__init__()
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for input ``x``."""
        # (B, C, T) -> (B, T, C) so the batch_first GRU iterates over time.
        sequence = x.permute(0, 2, 1)
        gru_out, _final_hidden = self.rnn(sequence)
        # Classify from the top layer's output at the final time step only.
        last_step = gru_out[:, -1]
        return self.fc(last_step)

# Size the GRU input from the channel axis of the data (axis 1 of the
# (batch, channels, time) tensor) rather than relying on the 64-channel
# default, so recordings with a different channel count do not fail with a
# shape error inside the GRU.
model = EEG_RNN(input_size=X_tensor.shape[1])
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train for 20 epochs, reporting validation accuracy after each one.
for epoch in range(20):
    model.train()
    total_loss = 0  # sum of per-batch losses over this epoch (not averaged)
    for xb, yb in train_dl:
        pred = model(xb)
        loss = criterion(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation pass: no gradient tracking, accuracy over the held-out split.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in val_dl:
            pred = model(xb)
            # Predicted class is the index of the largest logit.
            correct += (pred.argmax(1) == yb).sum().item()
            total += yb.size(0)
    acc = 100 * correct / total
    print(f"Epoch {epoch+1}, Loss: {total_loss:.2f}, Val Acc: {acc:.2f}%")