### ML Setup

In [2]:
import sys
import typing
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

rng_seed = 90
torch.manual_seed(rng_seed)

if torch.cuda.is_available():
    device = torch.device("cuda:0")

print(device)

cuda:0


In [31]:
from pathlib import Path
import mne

eegmmidb_path = Path.cwd() / "eeg-motor-movementimagery-dataset-1.0.0" / "files"
example_edf_path = eegmmidb_path / "S001" / "S001R03.edf"

raw = mne.io.read_raw_edf(example_edf_path, preload=False)

print(raw)

# ---- 2. Print metadata ----
print("\nSampling rate:", raw.info["sfreq"])
print("Channels:", raw.info["ch_names"])
print(raw.annotations)
print("Number of channels:", raw.info["nchan"])
print("Duration (sec):", raw.n_times / raw.info["sfreq"])


events, event_id = mne.events_from_annotations(raw)
print(events)
# print(event_id)

tmin = 0
tmax = 4
epochs = mne.Epochs(
    raw,
    events,
    event_id=event_id,
    tmin=tmin,
    tmax=tmax,
    baseline=None,
    preload=True
)

print(epochs.get_data())

<RawEDF | S001R03.edf, 64 x 20000 (125.0 s), ~52 KiB, data not loaded>

Sampling rate: 160.0
Channels: ['Fc5.', 'Fc3.', 'Fc1.', 'Fcz.', 'Fc2.', 'Fc4.', 'Fc6.', 'C5..', 'C3..', 'C1..', 'Cz..', 'C2..', 'C4..', 'C6..', 'Cp5.', 'Cp3.', 'Cp1.', 'Cpz.', 'Cp2.', 'Cp4.', 'Cp6.', 'Fp1.', 'Fpz.', 'Fp2.', 'Af7.', 'Af3.', 'Afz.', 'Af4.', 'Af8.', 'F7..', 'F5..', 'F3..', 'F1..', 'Fz..', 'F2..', 'F4..', 'F6..', 'F8..', 'Ft7.', 'Ft8.', 'T7..', 'T8..', 'T9..', 'T10.', 'Tp7.', 'Tp8.', 'P7..', 'P5..', 'P3..', 'P1..', 'Pz..', 'P2..', 'P4..', 'P6..', 'P8..', 'Po7.', 'Po3.', 'Poz.', 'Po4.', 'Po8.', 'O1..', 'Oz..', 'O2..', 'Iz..']
<Annotations | 30 segments: T0 (15), T1 (8), T2 (7)>
Number of channels: 64
Duration (sec): 125.0
[[    0     0     1]
 [  672     0     3]
 [ 1328     0     1]
 [ 2000     0     2]
 [ 2656     0     1]
 [ 3328     0     2]
 [ 3984     0     1]
 [ 4656     0     3]
 [ 5312     0     1]
 [ 5984     0     3]
 [ 6640     0     1]
 [ 7312     0     2]
 [ 7968     0     1]
 [ 8640     0

In [89]:
from typing import List
import re
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader

train_subjects = [i for i in range(1, 100)] # 1-99
test_subjects = [i for i in range(100, 110)] # 100-109

mne.set_log_level("ERROR")

# def ema_normalize(x, alpha=0.001, eps=1e-6):
#     """
#     x shape: (trials, channels, time)
#     """
#     n_trials, n_channels, n_times = x.shape
#
#     mean = np.zeros((n_trials, n_channels))
#     var = np.ones((n_trials, n_channels))
#
#     out = np.zeros_like(x)
#
#     for t in range(n_times):
#         sample = x[:, :, t]
#
#         mean = (1 - alpha) * mean + alpha * sample
#         var = (1 - alpha) * var + alpha * (sample - mean) ** 2
#
#         out[:, :, t] = (sample - mean) / np.sqrt(var + eps)
#
#     return out
# def normalize_trial(x):
#     mean = x.mean(axis=1, keepdims=True)
#     std = x.std(axis=1, keepdims=True)
#     return (x - mean) / (std + 1e-6)

lrRun = {3, 4, 7, 8, 11, 12}
bothRun = {5, 6, 9, 10, 13, 14}

def extract_run_number(edf_path: str) -> int:
    name = Path(edf_path).name
    m = re.search(r"R(\d+)\.edf$", name, flags=re.IGNORECASE)
    if not m:
        raise ValueError(f"Cannot parse run number from filename: {name}")
    return int(m.group(1))

def preprocess_one_edf(edf_path, l_freq=8.0, h_freq=40.0,tmin=0.0, tmax=4.0,run_ica=True, random_state=1):

    edf_path = Path(edf_path)

    run = extract_run_number(edf_path)

    raw = mne.io.read_raw_edf(edf_path, preload=True)

    raw.pick_types(eeg=True)
    raw.filter(l_freq, h_freq)
    raw.set_eeg_reference("average", projection=False, verbose="error")

    print(f"channel names: {raw.ch_names}")

    # ICA
    if run_ica:
        ica = mne.preprocessing.ICA(n_components=0.99,method="fastica",random_state=random_state, max_iter="auto",)
        ica.fit(raw, verbose="error")
        frontal = [ch for ch in ["Fp1", "Fp2", "AFp1", "AFp2"] if ch in raw.ch_names]
        if len(frontal) >= 1:
            eog_proxy = raw.copy().pick_channels(frontal).get_data().mean(axis=0)
            sources = ica.get_sources(raw).get_data()
            corr = np.array([np.corrcoef(sources[i], eog_proxy)[0, 1] for i in range(sources.shape[0])])
            bad_ics = np.where(np.abs(corr) > 0.3)[0].tolist()
            ica.exclude = bad_ics
            raw = ica.apply(raw, verbose="error")
        else:
            pass

    # Events / epochs
    events, event_id = mne.events_from_annotations(raw)

    labels = {"T0", "T1", "T2"}
    if not labels.issubset(event_id.keys()):
        raise RuntimeError(f"[{edf_path.name}] Missing {labels - set(event_id.keys())}. Found: {list(event_id.keys())}")
    if run in lrRun:
        semantic = {"T0": "rest","T1": "left_fist","T2": "right_fist"}
    elif run in bothRun:
        semantic = {"T0": "rest","T1": "both_fists","T2": "both_feet"}
    else:
        raise ValueError(f"{edf_path}: run R{run:02d} not in task runs (3â€“14).")

    sem_event_id = {semantic[k]: event_id[k] for k in labels}

    epochs = mne.Epochs(
        raw, events, event_id=sem_event_id,
        tmin=tmin, tmax=tmax,
        baseline=None,
        preload=True,
        reject_by_annotation=True,
        verbose="error"
    )

    return epochs

class EEGMMIDBDataset(Dataset):
    def __init__(self, path: Path, subjects: List[int]):
        self.samples = []
        for subject in subjects:
            sub_folder = path / f"S{subject:03d}"

            for fname in os.listdir(sub_folder):
                if fname.endswith(".edf"):
                    X, y = self.extract_trials_from_edf(sub_folder / fname)
                    if X is None:
                        continue
                    for i in range(len(X)):
                        # X_norm = normalize_trial(X[i])
                        self.samples.append((X[i], y[i]))

        print(f"Total samples: {self.__len__()}")

    def __len__(self):
        return len(self.samples)

    def extract_trials_from_edf(self, edf_path):
        epochs = preprocess_one_edf(edf_path, )

        X = epochs.get_data()
        y = epochs.events[:, -1] - 1
        return X, y

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

loader_train = DataLoader(
    EEGMMIDBDataset(eegmmidb_path, train_subjects),
    batch_size=64,
    shuffle=True,
    num_workers=0
)
loader_test = DataLoader(
    EEGMMIDBDataset(eegmmidb_path, test_subjects),
    batch_size=64,
    shuffle=True,
    num_workers=0
)


(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
(38, 64, 513)
Total samples: 35043
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
(24, 64, 513)
Total samples: 3222


In [90]:
import torch
import torch.nn as nn


class EEGNet(nn.Module):

    def __init__(self, nb_classes: int, in_channels: int):
        super().__init__()

        F1 = 8
        D = 2
        F2 = F1 * D
        fs = 160 # Sampling frequency
        kernel_length = fs // 2

        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, (1, kernel_length), padding=(0, kernel_length // 2), bias=False),
            nn.BatchNorm2d(F1),

            nn.Conv2d(F1, F1 * D, (in_channels, 1), groups=F1, bias=False),
            nn.BatchNorm2d(F1 * D),

            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(0.25)
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(F1 * D, F1 * D, (1, 16), padding=(0, 8), groups=F1 * D, bias=False),
            nn.Conv2d(F1 * D, F2, 1, bias=False),

            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(0.25)
        )

        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Linear(F2, nb_classes)

    def forward(self, x):
        x = x.unsqueeze(1)

        x = self.block1(x)
        x = self.block2(x)

        x = self.adaptive_pool(x)

        x = x.flatten(start_dim=1)
        return self.classifier(x)


In [91]:
# Training
dtype = torch.float32

def train_model(model, train_loader, optimizer, epochs=10):
    model = model.to(device)
    for e in range(epochs):
        for t, (x, y) in enumerate(train_loader):
            model.train()

            x = x.to(device, dtype=dtype)
            y = y.to(device, dtype=torch.long)

            # if t == 0:
            #     print("\n--- Batch Stats ---")
            #     print("shape:", x.shape)
            #     print("mean:", x.mean().item())
            #     print("std:", x.std().item())
            #     print("min:", x.min().item())
            #     print("max:", x.max().item())
            #     print("labels:", torch.unique(y))

            scores = model(x)

            if torch.isnan(scores).any():
                print("NaNs detected in model output!")
                return

            loss = F.cross_entropy(scores, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {e+1}/{epochs}, Loss {loss.item():.4f}")

def check_accuracy(model, loader):
    num_correct = 0
    num_samples = 0
    model.eval()
    with torch.no_grad():
        for t, (x, y) in enumerate(loader):
            x = x.to(device)
            y = y.to(device)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)

        acc = float(num_correct) / num_samples
        print('Got %d / %d correct of val set (%.2f)' % (num_correct, num_samples, 100 * acc))

In [None]:
model = EEGNet(nb_classes=3, in_channels=64).to(device)

check_accuracy(model, loader_test)

optimizer = torch.optim.Adamax(model.parameters(), lr=5e-4, weight_decay=1e-7)

train_model(model, loader_train, optimizer=optimizer, epochs=10)

check_accuracy(model, loader_test)

torch.save(model.state_dict(), 'model-epoch10.pt')

Got 801 / 3222 correct of val set (24.86)
Epoch 1/10, Loss 1.0580
