# About this notebook

This notebook showcase how we can load an EEG dataset, and perform preprocessing. We train the EEGNet model with leave one subject paradigm and k-fold cross-validation.

In [1]:
%cd ../

D:\workspace\github\bcikit


In [2]:
import numpy as np
import torch
from torch import nn, optim

In [3]:
config = {
  "seed": 12,
  "segment_config": {
    "window_len": 1,
    "shift_len": 250,
    "sample_rate": 250,
  },
  "bandpass_config": {
      "sample_rate": 250,
      "lowcut": 7,
      "highcut": 70,
      "order": 6
  },
  "subject_ids": {
    "low": 1,
    "high": 10
  },
  "root": "../data/hsssvep",
  "selected_channels": ['PZ', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'O1', 'Oz', 'O2', 'PO7', 'PO8'],
  "batchsize": 64,
  "num_classes": 40,
}

# Data

## Create preprocessing function

In [4]:
from bcikit.transforms.channels import pick_channels
from bcikit.transforms.segment_time import segment_data_time_domain
from bcikit.transforms.bandpass import butter_bandpass_filter


def preprocessing(data, targets, channel_names, sample_rate, selected_channels, segment_config, bandpass_config, verbose, **kwargs):
    print()
    print("preprocessing data shape", data.shape) # (subject,session,trial,channel,time)

    # filter channels
    data = pick_channels(
        data=data, 
        channel_names=channel_names, 
        selected_channels=selected_channels,
        verbose=False
    )
    print("after pick_channels", data.shape)

    # segment signal and select the first segment
    data = segment_data_time_domain(
        data=data,
        window_len=segment_config['window_len'],
        shift_len=segment_config['shift_len'],
        sample_rate=segment_config['sample_rate'],
        add_segment_axis=True,
    )
    data = data[:, :, :, :, 0, :] # select the first segment, and remove the rest
    print("after segment_data_time_domain", data.shape)

    # bandpass filter
    data = butter_bandpass_filter(
        signal=data, 
        lowcut=bandpass_config["lowcut"], 
        highcut=bandpass_config["highcut"], 
        sample_rate=bandpass_config["sample_rate"], 
        order=bandpass_config["order"]
    )
    print("after butter_bandpass_filter", data.shape)

    # since we are doing leave one subject out, we don't care about `session`, we only want data in this format (subject, trial, channel, time).
    data = data.reshape((data.shape[0], data.shape[1]*data.shape[2], data.shape[3], data.shape[4]))
    targets = targets.reshape((targets.shape[0], targets.shape[1]*targets.shape[2]))

    return data, targets

## Load data

In [5]:
from bcikit.datasets.ssvep import HSSSVEP
from bcikit.datasets import EEGDataloader
from bcikit.datasets.data_selection_methods import leave_one_subject_out


subject_ids = list(np.arange(config['subject_ids']['low'], config['subject_ids']['high']+1, dtype=int))
print("Load subject IDs", subject_ids)

data = EEGDataloader(
    dataset=HSSSVEP, 
    root=config["root"], 
    subject_ids=subject_ids,
    preprocessing_fn=preprocessing, # customize your own preprocessing
    data_selection_fn=leave_one_subject_out, # customize your data selection function or use common ones from `data_selection_methods`
    verbose=True,
    selected_channels=config["selected_channels"],
    segment_config=config["segment_config"],
    bandpass_config=config["bandpass_config"],
)

train_loader, val_loader, test_loader = data.get_dataloaders(test_subject_id=1, batchsize=64)

print()
print("train_loader:", train_loader.dataset.data.shape, train_loader.dataset.targets.shape)
print("val_loader:", val_loader.dataset.data.shape, val_loader.dataset.targets.shape)
print("test_loader:", test_loader.dataset.data.shape, test_loader.dataset.targets.shape)

Load subject IDs [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Load subject: 1
Load subject: 2
Load subject: 3
Load subject: 4
Load subject: 5
Load subject: 6
Load subject: 7
Load subject: 8
Load subject: 9
Load subject: 10

preprocessing data shape (10, 1, 240, 64, 1000)
after pick_channels (10, 1, 240, 11, 1000)
after segment_data_time_domain (10, 1, 240, 11, 250)
after butter_bandpass_filter (10, 1, 240, 11, 250)

train_loader: (1440, 11, 250) (1440,)
val_loader: (720, 11, 250) (720,)
test_loader: (240, 11, 250) (240,)


# Model

In [6]:
num_epochs = 10
learning_rate = 0.001
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
from bcikit.models import CompactEEGNet
from bcikit.models.utils import count_params


model = CompactEEGNet(
    num_channel=len(config['selected_channels']),
    num_classes=config['num_classes'],
    signal_length=config['segment_config']['window_len'] * config['bandpass_config']['sample_rate'],
).to(device)

x = torch.ones((16, len(config['selected_channels']), config['segment_config']['window_len']*config['bandpass_config']['sample_rate'])).to(device)
y = model(x)
print("Input shape:", x.shape)
print("Output shape:", y.shape)
print('Model size:', count_params(model))

Input shape: torch.Size([16, 11, 250])
Output shape: torch.Size([16, 40])
Model size: 63496


## Training

In [8]:
params_to_update = []
for name, param in model.named_parameters():
    if param.requires_grad == True:
        params_to_update.append(param)

optimizer = optim.Adam(params_to_update, lr=learning_rate, weight_decay=0.05)

criterion = nn.CrossEntropyLoss()

In [9]:
for epoch in range(num_epochs):
    epoch_loss = 0.0
    model.train()
    
    for X, Y in train_loader:
        inputs = X.to(device)
        labels = Y.long().to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        epoch_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
    print('Epoch {} loss: {:.5f}'.format(epoch, epoch_loss))

Epoch 0 loss: 85.87423
Epoch 1 loss: 81.52420
Epoch 2 loss: 72.76737
Epoch 3 loss: 62.58075
Epoch 4 loss: 54.42659
Epoch 5 loss: 47.62129
Epoch 6 loss: 45.38964
Epoch 7 loss: 45.18582
Epoch 8 loss: 41.74949
Epoch 9 loss: 44.05182
