# Imports

In [1]:
import os
import glob
import random
import numpy as np

from sklearn.metrics import accuracy_score
from scipy.signal import butter, filtfilt
from scipy.io import loadmat
import matplotlib.pyplot as plt


import torch
import pandas as pd

# Fix every random number generator used by the notebook so runs are repeatable.
seed = 42

for seed_fn in (np.random.seed, random.seed):
    seed_fn(seed)
# Kept as the last expression: the cell displays the returned torch Generator.
torch.manual_seed(seed)

<torch._C.Generator at 0x79c472ad7db0>

# Prepare Data

## Download dataset

In [2]:
## Download and extract DataSet
### Nakanishi et. al 2015
!curl -O https://sccn.ucsd.edu/download/cca_ssvep.zip
!unzip cca_ssvep.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  145M  100  145M    0     0  11.1M      0  0:00:13  0:00:13 --:--:-- 13.2M
Archive:  cca_ssvep.zip
   creating: cca_ssvep/
  inflating: cca_ssvep/s4.mat        
  inflating: cca_ssvep/s5.mat        
  inflating: cca_ssvep/s3.mat        
  inflating: cca_ssvep/s7.mat        
  inflating: cca_ssvep/chan_locs.pdf  
  inflating: cca_ssvep/readme.txt    
  inflating: cca_ssvep/s2.mat        
  inflating: cca_ssvep/s8.mat        
  inflating: cca_ssvep/s10.mat       
  inflating: cca_ssvep/s9.mat        
  inflating: cca_ssvep/s6.mat        
  inflating: cca_ssvep/s1.mat        


## Pre-processing

In [3]:
# Pre-Processing
def segment_eeg(folder, elecs=None, fs=256, duration=1., band=[5., 45.], order=4, onset=0.135):
  """Load every ``.mat`` recording found in ``folder`` and cut one epoch per trial.

  Each file must contain an ``"eeg"`` array which, after the
  ``(2, 1, 3, 0)`` transpose applied here, is laid out as
  (samples, channels, blocks, targets).

  Args:
    folder: directory holding the ``*.mat`` files (one file per subject).
    elecs: optional sequence of channel indices to keep; ``None`` keeps all.
    fs: sampling rate in Hz, used for filtering and for sample arithmetic.
    duration: epoch length in seconds.
    band: ``[low, high]`` pass-band in Hz handed to ``filter_eeg``.
    order: Butterworth filter order handed to ``filter_eeg``.
    onset: extra delay in seconds added to the fixed 38-sample offset.

  Returns:
    ``X`` float32 array of shape (subjects, samples, channels, blocks * targets)
    and ``Y`` float32 array of shape (subjects, blocks * targets) holding
    1-based target labels.

  Raises:
    FileNotFoundError: if ``folder`` contains no ``.mat`` file.
  """
  # glob returns files in arbitrary order; sort by (name length, name) so the
  # subject axis is reproducible and "s2.mat" comes before "s10.mat".
  eeg_files = sorted(
      glob.glob(f"{folder}/*.mat"),
      key=lambda path: (len(os.path.basename(path)), os.path.basename(path)))
  if not eeg_files:
    raise FileNotFoundError(f"No .mat recordings found in {folder!r}")

  # NOTE(review): 38 is presumably the stimulus-onset sample index of this
  # dataset at fs=256 -- confirm against the dataset's readme.txt.
  first = 38 + int(onset * fs)
  n_samples = int(duration * fs)

  X, Y = [], []  # per-subject epochs and labels
  for record in eeg_files:
    data = loadmat(record)
    # samples, channels, trials, targets
    eeg = data["eeg"].transpose((2, 1, 3, 0))
    if elecs is not None:
      eeg = eeg[:, np.atleast_1d(elecs), :, :]
    # filter the whole recording before cutting, then keep the epoch window
    eeg = filter_eeg(eeg, fs=fs, band=band, order=order)
    eeg = eeg[first:first + n_samples, :, :, :]
    samples, channels, blocks, targets = eeg.shape

    # Labels follow the same column-major (block index fastest) flattening
    # that is applied to the trials below.
    y = np.tile(np.arange(1, targets + 1), (blocks, 1))
    y = y.reshape(blocks * targets, order='F')

    X.append(eeg.reshape((samples, channels, blocks * targets), order='F'))
    Y.append(y)

  X = np.array(X, dtype=np.float32, order="F")
  # Labels are already 1-D per subject, so the subject axis survives even
  # when only one recording is present.
  Y = np.array(Y, dtype=np.float32)

  return X, Y

def filter_eeg(data, fs=256, band=[5., 45.], order=4, axis=0):
  """Band-pass ``data`` along ``axis`` with a forward-backward Butterworth filter.

  Args:
    data: array to filter.
    fs: sampling rate in Hz.
    band: ``[low, high]`` cut-off frequencies in Hz.
    order: order passed to ``scipy.signal.butter``.
    axis: axis of ``data`` along which the filter is applied.

  Returns:
    The filtered array as returned by ``scipy.signal.filtfilt``.
  """
  nyquist = fs / 2
  # butter expects the cut-offs as fractions of the Nyquist frequency
  normalized_band = np.array(band) / nyquist
  numerator, denominator = butter(order, normalized_band, btype='bandpass')
  return filtfilt(numerator, denominator, data, axis=axis)

# Evaluation

## Segment data into epochs

In [4]:
# Epoching configuration shared by the rest of the notebook.
fs = 256          # sampling rate (Hz)
duration = 1.     # epoch length (s)
band = [8, 64]    # pass-band (Hz) used before epoching
order = 4         # Butterworth order
data_folder = os.path.abspath('./cca_ssvep')

# Filter and epoch every recording found in data_folder.
X, Y = segment_eeg(data_folder, band=band, order=order, fs=fs, duration=duration)
print(f"X shape: {X.shape}") # subject x samples x channels x trials
print(f"Y shape: {Y.shape}")

X shape: (10, 256, 8, 180)
Y shape: (10, 180)


## Define Model

In [5]:
from torch import flatten
from torch import nn
import torch.nn.functional as F

class ChComb(nn.Module):
  """Channel-combination stem: 1x1 convolution, LayerNorm, GELU, dropout.

  The convolution takes ``Chans // 2`` input channels and produces ``Chans``
  output channels; LayerNorm normalizes over a last axis of length ``Samples``.
  """

  def __init__(self, Chans=8, Samples=220, dropout=0.5):
    super().__init__()
    self.conv = nn.Conv1d(in_channels=Chans // 2, out_channels=Chans,
                          kernel_size=1, padding='same')
    self.ln   = nn.LayerNorm(Samples)
    self.act  = nn.GELU()
    self.do   = nn.Dropout(p=dropout)

  def forward(self, x):
    """Apply conv -> norm -> activation -> dropout to ``x``."""
    mixed = self.conv(x)
    normed = self.ln(mixed)
    activated = self.act(normed)
    return self.do(activated)

class Encoder(nn.Module):
  """Residual encoder block: a convolutional sub-block followed by an MLP sub-block.

  Inputs and outputs have shape (batch, Chans, Samples). Each sub-block adds
  its result to its own input (skip connection).

  Args:
    Chans: number of channels of the input (also the conv in/out channels).
    Samples: length of the last axis (normalized by the LayerNorms).
    dropout: drop probability used by both dropout layers.
  """

  def __init__(self, Chans=16, Samples=220, dropout=0.5):
    super().__init__()
    # CNN module
    self.channels = Chans
    self.ln1  = nn.LayerNorm(Samples)
    self.conv = nn.Conv1d(Chans, Chans, 31, padding='same')
    self.ln2  = nn.LayerNorm(Samples)
    self.act  = nn.GELU()
    self.do   = nn.Dropout(p=dropout)
    # MLP module
    self.ln3  = nn.LayerNorm(Samples)
    self.proj = nn.Linear(Chans, Samples)
    self.do2  = nn.Dropout(p=dropout)

  def forward(self, x):
    # --- CNN sub-block: norm -> conv -> norm -> GELU -> dropout, plus skip ---
    shortcut1 = x
    x = self.conv(self.ln1(x))
    x = self.act(self.ln2(x))
    x = self.do(x) + shortcut1

    # --- MLP sub-block: norm -> projection -> dropout, plus skip ---
    shortcut2 = x
    x = self.ln3(x)
    # For each i < Chans, the cross-channel vector at position i of the last
    # axis (length Chans) is projected to a vector of length Samples.
    # NOTE(review): this indexes the last axis with a channel counter and
    # requires Samples >= Chans -- confirm this is the intended projection.
    x = torch.stack([self.proj(x[:, :, i]) for i in range(self.channels)], dim=1)
    # Use the dropout layer that belongs to the MLP sub-block (it was created
    # in __init__ but never called; same drop probability as self.do).
    x = self.do2(x) + shortcut2
    return x

class MlpHead(nn.Module):
  """Classification head: flatten, then two linear layers with a hidden width of ``6 * n_classes``.

  Args:
    Chans: channel count of the incoming feature map.
    Samples: last-axis length of the incoming feature map.
    n_classes: number of output logits.
    drop_rate: drop probability of both dropout layers.
  """

  def __init__(self, Chans, Samples, n_classes, drop_rate=0.5):
    super().__init__()
    hidden = 6 * n_classes
    self.drop       = nn.Dropout(drop_rate)
    self.linear1    = nn.Linear(Chans * Samples, hidden)
    self.norm       = nn.LayerNorm(hidden)
    self.activation = nn.GELU()
    self.drop2      = nn.Dropout(drop_rate)
    self.linear2    = nn.Linear(hidden, n_classes)

  def forward(self, x):
    """Return one logit per class for every item of the batch."""
    x = flatten(x, 1)  # keep the batch axis, merge everything else
    stages = (self.drop, self.linear1, self.norm,
              self.activation, self.drop2, self.linear2)
    for stage in stages:
      x = stage(x)
    return x

class SSVEPFormerTH(nn.Module):
  """Spectrum-based classifier: FFT features -> ChComb -> two Encoders -> MlpHead.

  Args:
    Chans: number of EEG channels of the input.
    n_classes: number of output logits.
    fs: sampling rate in Hz.
    band: ``[low, high]`` frequency range (Hz) kept from the spectrum.
    resolution: frequency resolution in Hz; the FFT length is ``round(fs / resolution)``.
    drop_rate: drop probability forwarded to every sub-module.
  """

  def __init__(self, Chans=8, n_classes=12, fs=256,
               band=[8, 64], resolution=0.25,
               drop_rate=0.25):
    super().__init__()
    self.name = "SSVEPFORMER"
    self.fs = fs
    self.resolution = resolution
    self.nfft  = round(fs / resolution)
    # first / one-past-last FFT bin of the requested band
    low_hz, high_hz = band[0], band[1]
    self.fft_start = int(round(low_hz / self.resolution))
    self.fft_end   = int(round(high_hz / self.resolution)) + 1
    # real and imaginary parts are concatenated, hence the factor 2
    n_bins = self.fft_end - self.fft_start
    samples = n_bins * 2
    filters = 2*Chans

    self.channel_comb = ChComb(filters,  samples, drop_rate)
    self.encoder1     = Encoder(filters, samples, drop_rate)
    self.encoder2     = Encoder(filters, samples, drop_rate)
    self.head         = MlpHead(filters, samples, n_classes, drop_rate)

    self.init_weights()

  def init_weights(self):
    """Re-initialize every sub-module that exposes a ``weight`` attribute.

    Modules whose class name contains "BatchNorm" or "LayerNorm" get a weight
    of 1; all others get N(0, 0.01) weights. A non-None bias is set to 0.
    """
    norm_markers = ("BatchNorm", "LayerNorm")
    for layer in self.modules():
      if not hasattr(layer, 'weight'):
        continue
      layer_type = type(layer).__name__
      if any(marker in layer_type for marker in norm_markers):
        nn.init.constant_(layer.weight, 1)
      else:
        nn.init.normal_(layer.weight, mean=0.0, std=0.01)
      if getattr(layer, "bias", None) is not None:
        nn.init.constant_(layer.bias, 0)

  def forward(self, x):
    """Map a batch of time-domain signals to class logits."""
    features = self.transform(x)
    features = self.channel_comb(features)
    for encoder in (self.encoder1, self.encoder2):
      features = encoder(features)
    return self.head(features)

  def transform(self, x):
    """Return the band-limited spectrum of ``x`` as [real | imag] along the last axis.

    The FFT is taken over the last axis with length ``self.nfft`` and divided
    by the original number of samples. No gradient is tracked.
    """
    with torch.no_grad():
      n_time = x.shape[-1]
      spectrum = torch.fft.fft(x, n=self.nfft) / n_time
      band_bins = slice(self.fft_start, self.fft_end)
      x = torch.cat((spectrum.real[:, :, band_bins],
                     spectrum.imag[:, :, band_bins]), axis=-1)
    return x


In [6]:
class FBSSVEPFormer(nn.Module):
  """Filter-bank ensemble: one sub-network per sub-band, fused by a 1x1 convolution.

  Sub-band ``i`` (1-based) spans ``[8 * i, 80]`` Hz. The input is band-pass
  filtered once per sub-band, each filtered copy is fed to the matching
  sub-network, and the per-band outputs are combined by ``self.conv``.

  Args:
    fs: sampling rate in Hz, used to design the band-pass filters.
    n_subbands: number of sub-bands (and of sub-networks consumed).
    models: sequence of ``nn.Module`` sub-networks, one per sub-band.
  """

  def __init__(self, fs=256, n_subbands=3, models=None):
    super().__init__()
    self.name = "FB-SSVEPFORMER"
    self.fs = fs
    self.subbands = [[8*i, 80] for i in range(1, n_subbands+1)]
    # Register the sub-networks: a plain Python list is invisible to
    # .to(device), .eval(), state_dict() and parameters().
    self.subnets = nn.ModuleList(models if models is not None else [])
    self.conv = nn.Conv1d(n_subbands, 1, 1, padding='same')
    self.init_weights()

  def init_weights(self):
    """Initialize only the fusion convolution; sub-networks are left untouched."""
    nn.init.normal_(self.conv.weight, mean=0.0, std=0.01)
    nn.init.constant_(self.conv.bias, 0)

  def train(self, mode=True):
    """Set the training mode, keeping fully frozen sub-networks in eval mode.

    A sub-network with no trainable parameter is used as a fixed feature
    extractor, so its dropout stays disabled while the fusion layer trains.
    """
    super().train(mode)
    for net in self.subnets:
      if not any(p.requires_grad for p in net.parameters()):
        net.eval()
    return self

  def forward(self, x):
    out = []
    for i, band in enumerate(self.subbands):
      c = self.filter_band(x, band)
      c = self.subnets[i](c)
      out.append(c.unsqueeze(1))  # (batch, 1, n_outputs)
    # stack the per-band outputs as channels and fuse them
    x = torch.cat(out, 1)
    x = self.conv(x)
    return x.squeeze(1)

  def filter_band(self, x, band):
    """Band-pass ``x`` (batch, channels, samples) along its last axis.

    Filtering runs on the CPU through SciPy; the result is returned as a
    float tensor on the device of ``x``. No gradient flows through it.
    """
    device = x.device
    with torch.no_grad():
      # detach so inputs that require grad can still be converted to NumPy
      x = x.detach().cpu().numpy()
      B, A = butter(4, np.array(band) / (self.fs / 2), btype='bandpass')
      x = filtfilt(B, A, x, axis=-1)
      # copy before handing the filtered array to torch.tensor
      x = x.copy()
    return torch.tensor(x, dtype=torch.float, device=device)



## Utils for training

In [7]:
def make_loader(x, y, batch_size=32, shuffle=True):
    """Wrap arrays ``x`` and ``y`` in a ``DataLoader``.

    ``x`` is converted to a float32 tensor and ``y`` to an int64 tensor; the
    pair is served in batches of ``batch_size``, reshuffled when ``shuffle``
    is true.
    """
    features = torch.tensor(x, dtype=torch.float32)
    labels = torch.tensor(y, dtype=torch.long)
    dataset = torch.utils.data.TensorDataset(features, labels)
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle)

In [8]:
# Basic Pytorch training loop
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

def train(dataloader, model, loss_fn, optimizer, epochs):
    """Train ``model`` in place for ``epochs`` passes over ``dataloader``.

    The model and every batch are moved to the notebook-level ``device``.
    Each batch runs a forward pass, computes ``loss_fn(pred, y)``,
    back-propagates and takes one ``optimizer`` step. Nothing is returned.
    """
    model.to(device)
    model.train()
    for _ in range(epochs):
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            loss.backward()
            optimizer.step()

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    return 100*correct

In [9]:
def concatenate_subjects(x, y, fold):
  """Pool the trials of the subjects listed in ``fold``.

  ``x[idx]`` and ``y[idx]`` are joined along their last axis; the pooled data
  is then reordered with ``transpose((2, 1, 0))`` so the trial axis comes
  first, and 1 is subtracted from the labels.

  Returns:
    Tuple ``(data, labels)``.
  """
  pooled_data = np.concatenate([x[subject] for subject in fold], axis=-1)
  pooled_labels = np.concatenate([y[subject] for subject in fold], axis=-1)
  trials_first = np.transpose(pooled_data, (2, 1, 0))
  return trials_first, pooled_labels - 1

## Training and testing

In [11]:
def train_model(model, x_train, y_train, x_val, y_val, band):
    """Train ``model`` on data band-passed to ``band`` and print its validation accuracy.

    Both splits are filtered along their last axis (fs=256, order 4), the
    model weights are re-initialized, and training runs for ``EPOCHS`` epochs
    with SGD using the notebook-level ``LR``, ``WD`` and ``BATCH_SIZE``.
    Nothing is returned; the accuracy is only printed.
    """
    # Copy the filtered arrays before they are turned into tensors
    # (presumably to obtain a plain contiguous buffer for torch.tensor).
    train_signals = filter_eeg(x_train, fs=256, band=band, order=4, axis=-1).copy()
    val_signals = filter_eeg(x_val, fs=256, band=band, order=4, axis=-1).copy()

    # Data loaders: shuffled for training, default DataLoader settings for validation.
    train_dataloader = make_loader(train_signals, y_train,
                                   batch_size=BATCH_SIZE, shuffle=True)
    test_dataloader = make_loader(val_signals, y_val, batch_size=BATCH_SIZE)

    model.init_weights()

    sgd_options = dict(lr=LR, momentum=0.9, weight_decay=WD, nesterov=False)
    optimizer = torch.optim.SGD(model.parameters(), **sgd_options)
    loss_fn = nn.CrossEntropyLoss()

    train(train_dataloader, model, loss_fn, optimizer, EPOCHS)

    acc = test(test_dataloader, model, loss_fn)
    print(f"Accuracy: {acc}")

def freeze_model(model):
  """Disable gradient tracking for every parameter of ``model`` (in place)."""
  for weight in model.parameters():
    weight.requires_grad_(False)

In [12]:
# Training settings
BATCH_SIZE = 128
EPOCHS = 100
LR = 0.001
WD = 0.001
DROP_RATE = 0.5

channels  = 8
n_classes = 12
fft_resolution = 0.25
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Leave-one-subject-out evaluation of the single-band model: for every
# subject, train on the nine others and test on the held-out one.
acc = np.zeros((10, 1))

for subject in range(0, 10):
    print(f"Subject: {subject + 1} Training...")

    # all subject indices except the held-out one
    folds = np.delete(np.arange(10), subject)
    train_index = folds
    test_index  = [subject]

    # create data split for each subject
    x_train, y_train = concatenate_subjects(X, Y, train_index)
    x_val, y_val     = concatenate_subjects(X, Y, test_index)

    # Create data loaders.
    train_dataloader = make_loader(x_train, y_train, batch_size=BATCH_SIZE, shuffle=True)
    test_dataloader  = make_loader(x_val, y_val, batch_size=BATCH_SIZE)

    # create model
    model = SSVEPFormerTH(Chans=channels,
                           n_classes=n_classes,
                           fs=fs,
                           band=band,
                           resolution=fft_resolution,
                           drop_rate=DROP_RATE)
    # the constructor already initializes the weights; this draws them again
    model.init_weights()


    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LR,
                                momentum=0.9,
                                weight_decay=WD, nesterov=False)

    loss_fn = nn.CrossEntropyLoss()

    train(train_dataloader, model, loss_fn, optimizer, EPOCHS)

    # accuracy (in percent) on the held-out subject
    acc[subject] = test(test_dataloader, model, loss_fn)
    print(f"Subhject {subject+1} Accuracy: {acc[subject]}")
    print("-------------------------------")

# mean and standard deviation over the ten held-out subjects
print(f"Mean Accuracy: {np.mean(acc)} +- {np.std(acc)}")

Subject: 1 Training...
Subhject 1 Accuracy: [87.77777778]
-------------------------------
Subject: 2 Training...
Subhject 2 Accuracy: [97.77777778]
-------------------------------
Subject: 3 Training...
Subhject 3 Accuracy: [96.11111111]
-------------------------------
Subject: 4 Training...
Subhject 4 Accuracy: [95.]
-------------------------------
Subject: 5 Training...
Subhject 5 Accuracy: [56.11111111]
-------------------------------
Subject: 6 Training...
Subhject 6 Accuracy: [56.66666667]
-------------------------------
Subject: 7 Training...
Subhject 7 Accuracy: [94.44444444]
-------------------------------
Subject: 8 Training...
Subhject 8 Accuracy: [97.22222222]
-------------------------------
Subject: 9 Training...
Subhject 9 Accuracy: [66.11111111]
-------------------------------
Subject: 10 Training...
Subhject 10 Accuracy: [98.88888889]
-------------------------------
Mean Accuracy: 84.61111111111111 +- 16.788535918048506


In [18]:
# Leave-one-subject-out evaluation of the filter-bank model: one sub-network
# is trained per sub-band and frozen, then only the fusion layer is trained.
acc = np.zeros((10, 1))
subbands = 3
bands = [ [8*i, 80] for i in range(1, subbands+1)]
fusion_epochs = 20  # epochs used to train the fusion layer

for subject in range(0, 10):
    print(f"Subject: {subject + 1} Training...")

    # all subject indices except the held-out one
    folds = np.delete(np.arange(10), subject)
    train_index = folds
    test_index  = [subject]

    # create data split for each subject
    x_train, y_train = concatenate_subjects(X, Y, train_index)
    x_val, y_val     = concatenate_subjects(X, Y, test_index)

    # Create data loaders.
    train_dataloader = make_loader(x_train, y_train, batch_size=BATCH_SIZE, shuffle=True)
    test_dataloader  = make_loader(x_val, y_val, batch_size=BATCH_SIZE)

    # one sub-network per sub-band
    models = [SSVEPFormerTH(Chans=channels,
                           n_classes=n_classes,
                           fs=fs,
                           band=b,
                           resolution=fft_resolution,
                           drop_rate=DROP_RATE) for b in bands]

    # train each sub-network on its own band, then freeze it
    for mod, b in zip(models, bands):
        train_model(mod, x_train, y_train, x_val, y_val, b)
        freeze_model(mod)


    # create model
    # Pass the same fs / sub-band count used to build `bands` and `models`,
    # so the ensemble cannot silently disagree with them.
    model = FBSSVEPFormer(fs=fs, n_subbands=subbands, models=models)
    model.init_weights()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LR,
                                momentum=0.9,
                                weight_decay=WD, nesterov=False)

    loss_fn = nn.CrossEntropyLoss()

    train(train_dataloader, model, loss_fn, optimizer, fusion_epochs)

    # accuracy (in percent) of the fused model on the held-out subject
    acc[subject] = test(test_dataloader, model, loss_fn)
    print(f"Subhject {subject+1} Accuracy: {acc[subject]}")
    print("-------------------------------")

print(f"Mean Accuracy: {np.mean(acc)} +- {np.std(acc)}")

Subject: 1 Training...
Accuracy: 96.66666666666667
Accuracy: 44.44444444444444
Accuracy: 60.55555555555555
Subhject 1 Accuracy: [96.66666667]
-------------------------------
Subject: 2 Training...
Accuracy: 98.88888888888889
Accuracy: 85.55555555555556
Accuracy: 50.55555555555556
Subhject 2 Accuracy: [100.]
-------------------------------
Subject: 3 Training...
Accuracy: 94.44444444444444
Accuracy: 87.22222222222223
Accuracy: 92.77777777777779
Subhject 3 Accuracy: [97.77777778]
-------------------------------
Subject: 4 Training...
Accuracy: 46.666666666666664
Accuracy: 29.444444444444446
Accuracy: 20.555555555555554
Subhject 4 Accuracy: [49.44444444]
-------------------------------
Subject: 5 Training...
Accuracy: 88.33333333333333
Accuracy: 57.77777777777777
Accuracy: 40.55555555555556
Subhject 5 Accuracy: [83.88888889]
-------------------------------
Subject: 6 Training...
Accuracy: 94.44444444444444
Accuracy: 87.22222222222223
Accuracy: 70.55555555555556
Subhject 6 Accuracy: [98.33