In [7]:
import os
import numpy as np
import mne
from scipy import stats
import scipy.io
import h5py

mne.set_log_level('error')

from utils.load import Load
from config.default import cfg

import torch
import torch.nn as nn

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from torch.utils.data import Dataset, DataLoader


from models.eegnet import EEGNet
from torchsummary import summary

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
loader = Load(cfg)

In [3]:
# Change this one-by-one to process all subjects in the dataset
# Yeah, Im lazy, shut up!
subject_id = 0

In [9]:
class UHD_Dataset(Dataset):
    def __init__(self, subject_id, split, device="cpu", config="default", verbose=False):
        self.split = split
        self.device = device
        raw_runs = loader.load_subject(subject_id = subject_id)

        


        preprocessed_runs = raw_runs

        for run in preprocessed_runs:
            run = run.resample(200)
            run = run.notch_filter(60)
            run = run.filter(8, 25)
            run = run.set_eeg_reference('average', projection=False)
            run = run.drop_channels(cfg['not_ROI_channels'])  

        def create_epochs(raw):
                events, event_ids = mne.events_from_annotations(raw)
                return mne.Epochs(
                    raw,
                    events=events,
                    event_id=event_ids,
                    tmin=-2,
                    tmax=7,
                    baseline=(-2,0),
                    preload=True,
                )

        epochs = [create_epochs(run) for run in preprocessed_runs]
        epochs = mne.concatenate_epochs(epochs)
        epochs = epochs.crop(0, 7)

        X = epochs.get_data()
        orig_shape = X.shape
        X = X.reshape(X.shape[0], -1)
        scaler = StandardScaler()
        X = scaler.fit_transform(X)
        X = X.reshape(orig_shape)



        self.X = X
        self.y = epochs.events[:, -1]-2

        train_X, test_X, train_y, test_y = train_test_split(self.X, self.y, test_size=0.2, random_state=42)
        if self.split == 'train':
            self.X = train_X
            self.y = train_y
        elif self.split == 'test':
            self.X = test_X
            self.y = test_y

        self.X = torch.from_numpy(self.X).float()
        self.y = torch.from_numpy(self.y).long()

        self.X = self.X.to(self.device)
        self.y = self.y.to(self.device)

        self.time_steps = self.X.shape[-1]
        self.channels = self.X.shape[-2]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


In [10]:
device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

In [11]:
subject = 0
batch_size = 5

train_runs = [0,1,2,3]
test_runs = [4]

train_dataset = UHD_Dataset(subject, 'train', device = device)
test_dataset = UHD_Dataset(subject, 'test', device = device)

train_dataloader = DataLoader(train_dataset,  batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, drop_last=True)

In [12]:
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

for features, label in train_dataloader:
    print(features.shape)
    print(label)
    break
    

Train dataset: 200 samples
Test dataset: 50 samples
torch.Size([5, 158, 1401])
tensor([1, 1, 1, 3, 2])


In [13]:
channels = train_dataset.channels
samples = train_dataset.time_steps
model = EEGNet(channels = channels, samples= samples, num_classes = 5)
model.to(device)
summary(model, input_size=(5, 10, *next(iter(train_dataloader))[0][0].shape));

Layer (type:depth-idx)                   Param #
├─Conv2d: 1-1                            512
├─BatchNorm2d: 1-2                       16
├─Conv2d: 1-3                            2,528
├─BatchNorm2d: 1-4                       32
├─ELU: 1-5                               --
├─AvgPool2d: 1-6                         --
├─Dropout: 1-7                           --
├─Conv2d: 1-8                            256
├─Conv2d: 1-9                            128
├─BatchNorm2d: 1-10                      16
├─AvgPool2d: 1-11                        --
├─Dropout: 1-12                          --
├─Flatten: 1-13                          --
├─Linear: 1-14                           1,725
Total params: 5,213
Trainable params: 5,213
Non-trainable params: 0


In [14]:
# Test forward pass
model(next(iter(train_dataloader))[0]);

In [15]:
from torch import no_grad
from sklearn.metrics import accuracy_score

def accuracy(model, dataloader):
    all_labels = []
    all_predictions = []

    with no_grad():
        for features, labels in dataloader:
            outputs = model(features)
            _, predicted = torch.max(outputs.data, 1)
            
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    return accuracy_score(all_labels, all_predictions) * 100

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Training loop
for epoch in range(100):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(batch_features)
    
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 5 == 4:
        train_accuracy = accuracy(model, train_dataloader)
        test_accuracy = accuracy(model, test_dataloader)
        print(f"Epoch {epoch + 1}/{100}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(model, train_dataloader):.2f}%')
print(f'Final test accuracy: {accuracy(model, test_dataloader):.2f}%')

Epoch 5/100, Loss: 64.22528767585754, Train accuracy: 23.00%, Test accuracy: 14.00%
Epoch 10/100, Loss: 63.42068552970886, Train accuracy: 36.50%, Test accuracy: 10.00%
Epoch 15/100, Loss: 61.631113052368164, Train accuracy: 40.50%, Test accuracy: 14.00%
Epoch 20/100, Loss: 57.01328504085541, Train accuracy: 53.00%, Test accuracy: 14.00%
Epoch 25/100, Loss: 49.401441395282745, Train accuracy: 59.00%, Test accuracy: 18.00%
Epoch 30/100, Loss: 42.51008361577988, Train accuracy: 66.50%, Test accuracy: 14.00%
Epoch 35/100, Loss: 35.80991643667221, Train accuracy: 74.00%, Test accuracy: 22.00%
Epoch 40/100, Loss: 30.603190273046494, Train accuracy: 81.00%, Test accuracy: 18.00%
Epoch 45/100, Loss: 25.391809657216072, Train accuracy: 83.50%, Test accuracy: 14.00%
Epoch 50/100, Loss: 20.543932557106018, Train accuracy: 87.00%, Test accuracy: 18.00%
Epoch 55/100, Loss: 17.0146426782012, Train accuracy: 91.50%, Test accuracy: 18.00%
Epoch 60/100, Loss: 14.17531968653202, Train accuracy: 96.00%,