I created this notebook in order to practice writing and running deep learning algorithms on my own EEG data. For information about the data used in this project, see the following page: https://lpljacob.github.io/word_priming/

In [1]:
import numpy as np
import scipy.io as sio
from scipy import stats
from os.path import dirname, join as pjoin
from skimage.measure import block_reduce
from sklearn.model_selection import StratifiedKFold

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

The cell below declares variables used in loading and processing the data.

The EEG files used in this project have been pre-processed using EEGLAB (details: https://www.biorxiv.org/content/10.1101/862516v1) and are saved as .mat files (one separate file per subject) containing 3D arrays. 

Arrays consist of these dimensions:
- electrodes (64);
- timepoints (700, in ms);
- trials (variable per subject, as trials with eye blinks, eye movements, and other artifacts were discarded).

Each subject also needs a separate vector of condition labels; its length must equal the third dimension (trials) of the corresponding EEG data array.

Based on domain knowledge and the SVM results reported in the paper linked above, I am selecting only the timepoints from 200 to 500 ms (see https://lpljacob.github.io/word_priming/ for an explanation).

In [2]:
folder = r'C:\_Files\svm files' # raw string so backslashes are not treated as escape sequences
n = 20 # number of subjects
electrodes = 64
starting_point = 200 # starting point (ms) used for analyses
points = 300 # how many points (ms) to use for analyses
timewind_size = 10 # if larger than 1, average across points (ms)

# preallocate space
eeg_data = np.empty(shape=(electrodes,int(points/timewind_size),0))
trial_info = np.empty(shape=(1,0))
sub_info = np.empty(shape=(1,0))
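
As a rough illustration of what `timewind_size` does, the sketch below averages non-overlapping 10 ms windows on a synthetic array (random values standing in for real EEG; the trial count of 5 is arbitrary). It uses a plain NumPy reshape rather than `block_reduce`, which should give the same result whenever `points` is an exact multiple of `timewind_size`.

```python
import numpy as np

electrodes, points, timewind_size = 64, 300, 10
n_trials = 5  # hypothetical trial count, for illustration only

rng = np.random.default_rng(0)
fake_eeg = rng.standard_normal((electrodes, points, n_trials))  # synthetic stand-in

# split the time axis into (windows, samples per window), then average each window
n_windows = points // timewind_size
binned = fake_eeg.reshape(electrodes, n_windows, timewind_size, n_trials).mean(axis=2)

print(binned.shape)  # (64, 30, 5)
```

With the parameters declared above, each trial therefore enters the model as a sequence of 30 time steps with 64 features per step.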

The cell below loads and processes the data according to the parameters declared above.

In [3]:
for i in range(1,n+1): # iterate through subjects

    # load and prepare EEG trial data (X)
    mat_file = sio.loadmat(pjoin(folder, 'subdata' + str(i) + '.mat'))
    mat_data = mat_file['datasave']

    # selects specific timepoints and discards the rest
    crop_data = mat_data[:,starting_point:starting_point+points,:] 
    
    # if timewind_size > 1, average across timepoints
    avg_data = block_reduce(crop_data, block_size=(1,timewind_size,1), 
                            func=np.mean, cval=np.mean(crop_data))
    
    # z-score each trial across all electrodes and time windows
    normdata = np.zeros(shape=avg_data.shape)
                         
    for j in range(avg_data.shape[-1]):
        normdata[:,:,j] = (
            avg_data[:,:,j] - np.mean(avg_data[:,:,j].flatten())) / np.std(avg_data[:,:,j].flatten())

    # load and prepare trial info (y)
    mat_info = sio.loadmat(pjoin(folder, 'trialinfo' + str(i) + '.mat'))
    mat_infodata = mat_info['trialsave']

    # store everything
    eeg_data = np.append(eeg_data, normdata,axis=2)
    trial_info = np.append(trial_info, mat_infodata)
    sub_info = np.append(sub_info, np.repeat(i, avg_data.shape[-1]))
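
The per-trial normalization loop in this cell can also be written without an explicit loop. The sketch below is one possible vectorized form, run on synthetic data (the array and its trial count are made up); it z-scores each trial using the mean and standard deviation taken over all electrodes and time windows of that trial.

```python
import numpy as np

rng = np.random.default_rng(1)
avg_data = rng.normal(loc=3.0, scale=5.0, size=(64, 30, 7))  # synthetic stand-in

# one mean and one standard deviation per trial (last axis), kept broadcastable
trial_mean = avg_data.mean(axis=(0, 1), keepdims=True)
trial_std = avg_data.std(axis=(0, 1), keepdims=True)
normdata = (avg_data - trial_mean) / trial_std

# reference: the explicit loop, for comparison
loop_norm = np.zeros_like(avg_data)
for j in range(avg_data.shape[-1]):
    trial = avg_data[:, :, j]
    loop_norm[:, :, j] = (trial - trial.mean()) / trial.std()

print(np.allclose(normdata, loop_norm))
```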

The cell below creates a new array whose elements combine subject and condition information for each trial. This is used for stratified k-fold splitting; we want conditions and subjects to be counterbalanced across folds.

The first part of the code collapses the 16 labels (which encode both experimental condition and subject choice) into 8 condition labels, regardless of subject choice. (Again, see https://lpljacob.github.io/word_priming/ for information on the experimental paradigm.)

Finally, we transpose the EEG data to (batch_size, seq_len, input_size), the layout PyTorch recurrent layers accept with batch_first=True (batch first is a personal preference), and convert it to float32, the data type we will be using for the model.

In [4]:
# obtain condition information
condition_pairs = np.array([13, 14, 15, 16, 9, 10, 11, 12])
conditions = trial_info.copy()

for i in range (1,9):
    conditions[np.where(conditions == condition_pairs[i-1])] = i

# combine conditions and sub_info (for stratified kfolds)
details = np.char.add(np.char.mod('%d', conditions), np.char.mod('s%d', sub_info))

# transpose data to batch_size, seq_len, input_size
examples = eeg_data.transpose(2,1,0)
examples = examples.astype('float32')
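
To make the stratification key concrete, here is a small sketch on made-up trial codes and subject numbers. It collapses codes 9-16 onto 1-8 following `condition_pairs`, then builds one string per trial. The toy values are assumptions, and `np.char.add` is used here as the public alias of `np.core.defchararray.add`.

```python
import numpy as np

trial_info = np.array([1., 13., 9., 5., 16., 4.])  # toy condition codes (1-16)
sub_info = np.array([1., 1., 1., 2., 2., 2.])      # toy subject numbers

condition_pairs = np.array([13, 14, 15, 16, 9, 10, 11, 12])
conditions = trial_info.copy()
for new_label, old_label in enumerate(condition_pairs, start=1):
    # mask on the original codes so that relabeled trials are never touched twice
    conditions[trial_info == old_label] = new_label

# one key per trial: collapsed condition, then 's' and the subject number
details = np.char.add(np.char.mod('%d', conditions), np.char.mod('s%d', sub_info))
print(details)
```

Passing keys of this form as the `y` argument of `StratifiedKFold.split` is what keeps each subject-by-condition combination spread across folds.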

Our goal is to predict which choice (out of 2 options) the subject made in each trial. The label array is generated below.

Afterward, check if we can train on GPU.

In [5]:
# obtain labels
binlabels = np.zeros(trial_info.shape)
binlabels[trial_info<9] = 1 # 'same' choice
binlabels[trial_info>8] = 0 # 'different' choice
binlabels = binlabels.astype('float32') 

# check if we can train on gpu
train_on_gpu=torch.cuda.is_available()
if train_on_gpu:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Below, we declare the model architecture. Given we have timeseries data, I have chosen a recurrent network. The code supports both GRU and LSTM, and allows the user to stipulate the number of hidden units and number of GRU/LSTM layers.

This is a binary classification task, so we have a fully-connected layer with output of 1 following our recurrent layer(s), and we apply a sigmoid function to this output.

In [6]:
class RNN(nn.Module):
    
    def __init__(self, input_size, hidden_size, n_layers, output_size=1, rnntype='gru'):
        super(RNN, self).__init__()
        
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.type = rnntype
        
        if rnntype=='gru':
            self.rec = nn.GRU(input_size, hidden_size, n_layers, batch_first=True)
        elif rnntype == 'lstm':
            self.rec = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True)
        
        self.fc = nn.Linear(hidden_size, output_size)
        self.sig = nn.Sigmoid()

    def forward(self, x, hidden):

        batch_size = x.size(0)
        
        out, hidden = self.rec(x, hidden)
        out = out.contiguous().view(-1, self.hidden_size)
        
        out = self.fc(out)
        sig_out = self.sig(out)
        
        sig_out = sig_out.view(batch_size, -1)
        sig_out = sig_out[:, -1] # keep only the output at the last timestep
        
        return sig_out, hidden
    
    def init_hidden(self, batch_size):

        weight = next(self.parameters()).data
        
        if self.type=='gru':
            hidden = weight.new(self.n_layers, batch_size, self.hidden_size).zero_().to(device)
        elif self.type=='lstm':
            hidden = (weight.new(self.n_layers, batch_size, self.hidden_size).zero_().to(device),
                  weight.new(self.n_layers, batch_size, self.hidden_size).zero_().to(device))
        
        return hidden
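
The forward pass above applies the fully-connected layer to every timestep and then keeps the last one. As a minimal sketch (with random inputs and freshly initialized layers, not the trained model), the snippet below checks that this matches applying the layer only to the final timestep, and that for a unidirectional GRU the final timestep's output equals the last layer's final hidden state. The batch size of 8 is arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, seq_len, input_size, hidden_size = 8, 30, 64, 16

gru = nn.GRU(input_size, hidden_size, num_layers=1, batch_first=True)
fc = nn.Linear(hidden_size, 1)
x = torch.randn(batch_size, seq_len, input_size)  # random stand-in for EEG trials

out, h_n = gru(x)  # with no hidden state passed, PyTorch starts from zeros

# option A: classify from the last timestep only
last = out[:, -1, :]
probs = torch.sigmoid(fc(last)).squeeze(-1)

# option B: the notebook's route (all timesteps, then keep the last column)
flat = out.contiguous().view(-1, hidden_size)
all_probs = torch.sigmoid(fc(flat)).view(batch_size, -1)[:, -1]

print(probs.shape, torch.allclose(probs, all_probs, atol=1e-6))
```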

Below we declare an early stopping class (by [stefanonardo](https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d)). This will interrupt model training once validation loss stops improving, which helps limit overfitting.

In [7]:
# credits: https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d

class EarlyStopping(object):
    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        if self.best is None:
            self.best = metrics
            return False

        if np.isnan(metrics):
            return True

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            return True

        return False

    def _init_is_better(self, mode, min_delta, percentage):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)
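
To show the stopping rule in isolation, the sketch below re-implements only the `mode='min'`, non-percentage branch as a plain function (the name `stopping_epoch` is hypothetical, and NaN handling is left out). It returns the 1-based epoch at which training would stop, or `None` if the patience is never exhausted. The loss values are made up.

```python
def stopping_epoch(val_losses, patience=10, min_delta=0.0):
    """Return the epoch (1-based) at which early stopping would trigger."""
    best = None
    num_bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best is None:            # first epoch only sets the baseline
            best = loss
            continue
        if loss < best - min_delta:  # improvement: reset the counter
            best = loss
            num_bad_epochs = 0
        else:                        # no improvement: count a bad epoch
            num_bad_epochs += 1
        if num_bad_epochs >= patience:
            return epoch
    return None

toy_losses = [0.70, 0.68, 0.67, 0.69, 0.68, 0.70]
print(stopping_epoch(toy_losses, patience=3))  # 6
```

Here the best loss (0.67) occurs at epoch 3, and the three following epochs fail to beat it, so training stops at epoch 6.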

Finally, we define the training function.

Initial exploration with the model showed it was very prone to overfitting, likely due to the small number of trials per subject and condition. To remedy this, I've tried to keep the model relatively simple (with one recurrent layer and 16 hidden units).

I am also using L1 regularization to encourage a sparse solution, as I believe only a few electrodes (input features) will contain neural activity relevant to the classification task at hand.
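
As a small worked example of the penalty term, the sketch below computes the L1 norm of the weights of a toy linear layer with hand-set values (the layer and its weights are illustrative, not part of the model). As in the training function, only parameters whose name contains 'weight' are included, so biases are not penalized.

```python
import torch
from torch import nn

layer = nn.Linear(4, 1)  # toy layer, for illustration only
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[1.0, -2.0, 3.0, -4.0]]))

L1_scale = 0.0005

# sum of absolute values over weight tensors only (biases are skipped)
l1_penalty = sum(p.abs().sum() for name, p in layer.named_parameters()
                 if 'weight' in name)
scaled_penalty = L1_scale * l1_penalty  # this term is added to the BCE loss

print(float(l1_penalty), float(scaled_penalty))
```

The absolute values sum to 10, so the term added to the loss is about 0.005; because it is built from the parameters themselves, gradients flow through it during `backward()`.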

In [10]:
folds = 10
skf = StratifiedKFold(n_splits=folds, shuffle=True)

parameters = {'batch_size': 64, 'L1_scale': 0.0005, 'hidden_size': 16, 'n_layers': 1}

def train_model(parameters):
    
    model_type = 'gru'
    n_layers = parameters['n_layers']
    
    epochs = 1000 # arbitrary large number; early stopping will interrupt training much sooner

    counter = 0
    f_counter = 0

    all_accs = []
    all_losses = [] 

    # kfold loop
    for train_index, test_index in skf.split(examples, details):
        f_counter += 1

        # create a new instance of the model
        net = RNN(input_size=examples.shape[-1], hidden_size=parameters['hidden_size'], 
                  n_layers=n_layers, rnntype=model_type)
        
        # determine loss function and optimizer
        criterion = nn.BCELoss()
        optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

        if(train_on_gpu):
            net.cuda()

        # obtain kfold data
        X_train, X_val = examples[train_index], examples[test_index]
        y_train, y_val = binlabels[train_index], binlabels[test_index]

        # convert to tensor
        train_data = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
        val_data = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))

        # create dataloaders
        train_loader = DataLoader(train_data, batch_size=parameters['batch_size'])
        val_loader = DataLoader(val_data, batch_size=y_val.shape[-1])

        total_correct = 0
        total_examples = 0

        fold_losses = []
        fold_accs = []

        es = EarlyStopping(patience=10)

        for e in range(epochs):

            # training
            for inputs, labels in train_loader:             
                counter += 1

                # initialize hidden state
                h = net.init_hidden(inputs.size(0))
                net.train()

                if(train_on_gpu):
                    inputs, labels = inputs.cuda(), labels.cuda()

                net.zero_grad()

                if model_type == 'gru':
                    h = h.data
                else:
                    h = tuple(state.data for state in h) # 'state' avoids reusing the epoch variable name

                output, h = net(inputs, h)

                # L1 regularization
                L1_reg = torch.tensor(0., requires_grad=True)
                for name, param in net.named_parameters():
                    if 'weight' in name:
                        L1_reg = L1_reg + torch.norm(param, 1)
                        
                loss = criterion(output.squeeze(), labels.float()) + parameters['L1_scale'] * L1_reg

                loss.backward()

                optimizer.step()

            # validation
            val_losses = []
            net.eval()
            for inputs, labels in val_loader:

                val_h = net.init_hidden(inputs.size(0))

                if(train_on_gpu):
                    inputs, labels = inputs.cuda(), labels.cuda()

                output, val_h = net(inputs, val_h)
                val_loss = criterion(output.squeeze(), labels.float())

                val_losses.append(val_loss.item())

                # the validation loader yields a single batch, so plain assignment suffices
                total_correct = torch.sum(output.round()==labels).cpu().numpy()
                total_examples = labels.size(0)

            net.train()

            fold_losses.append(np.mean(val_losses))
            fold_accs.append(total_correct/total_examples)

            print("Fold: {}/{}...".format(f_counter, folds),
                  "Epoch: {}/{}...".format(e+1, epochs),
                  "Loss: {:.6f}...".format(loss.item()),
                  "Val Loss: {:.6f}...".format(np.mean(val_losses)),
                  "Accuracy: {:.6f}".format(total_correct/total_examples))

            # check if we should stop training on this fold
            if es.step(np.mean(val_losses)):
                all_accs.append(np.max(fold_accs))
                all_losses.append(np.min(fold_losses))
                print("Early stopping now...",
                      "Min val loss: {:.6f}...".format(np.min(fold_losses)),
                      "Max acc: {:.6f}".format(np.max(fold_accs)))
                break
                
    return all_accs, all_losses
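
The validation loader above returns the whole validation set as a single batch, so accuracy is computed in one step per epoch. If validation were instead split into several batches, the counts would need to be reset each epoch and accumulated across batches. The sketch below shows that bookkeeping with NumPy on made-up probabilities and labels; the 0.5 threshold is an assumption standing in for rounding the sigmoid output.

```python
import numpy as np

# made-up sigmoid outputs and labels, split into two "batches"
batches = [
    (np.array([0.9, 0.2, 0.6]), np.array([1.0, 0.0, 0.0])),
    (np.array([0.4, 0.7]),      np.array([0.0, 1.0])),
]

# reset at the start of every epoch, then accumulate over batches
total_correct = 0
total_examples = 0
for probs, labels in batches:
    preds = (probs >= 0.5).astype(float)  # threshold the probabilities
    total_correct += int(np.sum(preds == labels))
    total_examples += labels.size

accuracy = total_correct / total_examples
print(accuracy)  # 0.8
```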

We train the model (which should only take a couple of minutes if using a GPU), then look at the results.

In [11]:
all_accs, all_losses = train_model(parameters)

Fold: 1/10... Epoch: 1/1000... Loss: 0.751457... Val Loss: 0.692414... Accuracy: 0.530120
Fold: 1/10... Epoch: 2/1000... Loss: 0.686209... Val Loss: 0.688675... Accuracy: 0.530120
Fold: 1/10... Epoch: 3/1000... Loss: 0.658448... Val Loss: 0.686569... Accuracy: 0.543264
Fold: 1/10... Epoch: 4/1000... Loss: 0.643891... Val Loss: 0.685023... Accuracy: 0.541073
Fold: 1/10... Epoch: 5/1000... Loss: 0.633489... Val Loss: 0.683737... Accuracy: 0.554217
Fold: 1/10... Epoch: 6/1000... Loss: 0.625578... Val Loss: 0.682604... Accuracy: 0.562979
Fold: 1/10... Epoch: 7/1000... Loss: 0.619731... Val Loss: 0.681388... Accuracy: 0.565170
Fold: 1/10... Epoch: 8/1000... Loss: 0.615376... Val Loss: 0.680056... Accuracy: 0.564074
Fold: 1/10... Epoch: 9/1000... Loss: 0.613478... Val Loss: 0.678265... Accuracy: 0.572837
Fold: 1/10... Epoch: 10/1000... Loss: 0.609757... Val Loss: 0.676004... Accuracy: 0.570646
Fold: 1/10... Epoch: 11/1000... Loss: 0.603917... Val Loss: 0.673128... Accuracy: 0.584885
[output truncated]

Fold: 2/10... Epoch: 20/1000... Loss: 0.498756... Val Loss: 0.646932... Accuracy: 0.620044
Fold: 2/10... Epoch: 21/1000... Loss: 0.494184... Val Loss: 0.645780... Accuracy: 0.623348
Fold: 2/10... Epoch: 22/1000... Loss: 0.491623... Val Loss: 0.644746... Accuracy: 0.629956
Fold: 2/10... Epoch: 23/1000... Loss: 0.490338... Val Loss: 0.643819... Accuracy: 0.629956
Fold: 2/10... Epoch: 24/1000... Loss: 0.490543... Val Loss: 0.643043... Accuracy: 0.628855
Fold: 2/10... Epoch: 25/1000... Loss: 0.491269... Val Loss: 0.642362... Accuracy: 0.629956
Fold: 2/10... Epoch: 26/1000... Loss: 0.492293... Val Loss: 0.641792... Accuracy: 0.631057
Fold: 2/10... Epoch: 27/1000... Loss: 0.493287... Val Loss: 0.641329... Accuracy: 0.638767
Fold: 2/10... Epoch: 28/1000... Loss: 0.493895... Val Loss: 0.640916... Accuracy: 0.635463
Fold: 2/10... Epoch: 29/1000... Loss: 0.494195... Val Loss: 0.640528... Accuracy: 0.639868
Fold: 2/10... Epoch: 30/1000... Loss: 0.494521... Val Loss: 0.640189... Accuracy: 0.640969

Fold: 4/10... Epoch: 17/1000... Loss: 0.542166... Val Loss: 0.653716... Accuracy: 0.602492
Fold: 4/10... Epoch: 18/1000... Loss: 0.536111... Val Loss: 0.651092... Accuracy: 0.601359
Fold: 4/10... Epoch: 19/1000... Loss: 0.530252... Val Loss: 0.648444... Accuracy: 0.616082
Fold: 4/10... Epoch: 20/1000... Loss: 0.524409... Val Loss: 0.646052... Accuracy: 0.625142
Fold: 4/10... Epoch: 21/1000... Loss: 0.518517... Val Loss: 0.644070... Accuracy: 0.633069
Fold: 4/10... Epoch: 22/1000... Loss: 0.512870... Val Loss: 0.642609... Accuracy: 0.630804
Fold: 4/10... Epoch: 23/1000... Loss: 0.507194... Val Loss: 0.641570... Accuracy: 0.635334
Fold: 4/10... Epoch: 24/1000... Loss: 0.501518... Val Loss: 0.640837... Accuracy: 0.638732
Fold: 4/10... Epoch: 25/1000... Loss: 0.495942... Val Loss: 0.640264... Accuracy: 0.640997
Fold: 4/10... Epoch: 26/1000... Loss: 0.490358... Val Loss: 0.639854... Accuracy: 0.634202
Fold: 4/10... Epoch: 27/1000... Loss: 0.484725... Val Loss: 0.639616... Accuracy: 0.638732

Fold: 6/10... Epoch: 20/1000... Loss: 0.514862... Val Loss: 0.660228... Accuracy: 0.603708
Fold: 6/10... Epoch: 21/1000... Loss: 0.510954... Val Loss: 0.660433... Accuracy: 0.602549
Fold: 6/10... Epoch: 22/1000... Loss: 0.507265... Val Loss: 0.660663... Accuracy: 0.603708
Fold: 6/10... Epoch: 23/1000... Loss: 0.503687... Val Loss: 0.660939... Accuracy: 0.601390
Fold: 6/10... Epoch: 24/1000... Loss: 0.499968... Val Loss: 0.661147... Accuracy: 0.602549
Fold: 6/10... Epoch: 25/1000... Loss: 0.496394... Val Loss: 0.661322... Accuracy: 0.603708
Fold: 6/10... Epoch: 26/1000... Loss: 0.493209... Val Loss: 0.661401... Accuracy: 0.604867
Fold: 6/10... Epoch: 27/1000... Loss: 0.490403... Val Loss: 0.661437... Accuracy: 0.600232
Fold: 6/10... Epoch: 28/1000... Loss: 0.487451... Val Loss: 0.661251... Accuracy: 0.600232
Early stopping now... Min val loss: 0.659856... Max acc: 0.614137
Fold: 7/10... Epoch: 1/1000... Loss: 0.830648... Val Loss: 0.691585... Accuracy: 0.540828
[output truncated]

Fold: 7/10... Epoch: 81/1000... Loss: 0.433052... Val Loss: 0.653409... Accuracy: 0.607101
Fold: 7/10... Epoch: 82/1000... Loss: 0.429250... Val Loss: 0.653223... Accuracy: 0.609467
Fold: 7/10... Epoch: 83/1000... Loss: 0.425407... Val Loss: 0.653050... Accuracy: 0.609467
Fold: 7/10... Epoch: 84/1000... Loss: 0.421672... Val Loss: 0.652858... Accuracy: 0.609467
Fold: 7/10... Epoch: 85/1000... Loss: 0.417924... Val Loss: 0.652654... Accuracy: 0.610651
Fold: 7/10... Epoch: 86/1000... Loss: 0.414400... Val Loss: 0.652466... Accuracy: 0.610651
Fold: 7/10... Epoch: 87/1000... Loss: 0.410476... Val Loss: 0.652231... Accuracy: 0.611834
Fold: 7/10... Epoch: 88/1000... Loss: 0.404597... Val Loss: 0.652050... Accuracy: 0.611834
Fold: 7/10... Epoch: 89/1000... Loss: 0.398548... Val Loss: 0.651540... Accuracy: 0.611834
Fold: 7/10... Epoch: 90/1000... Loss: 0.399144... Val Loss: 0.651709... Accuracy: 0.611834
Fold: 7/10... Epoch: 91/1000... Loss: 0.395789... Val Loss: 0.651516... Accuracy: 0.611834

Fold: 9/10... Epoch: 2/1000... Loss: 0.731501... Val Loss: 0.690465... Accuracy: 0.525403
Fold: 9/10... Epoch: 3/1000... Loss: 0.708383... Val Loss: 0.688333... Accuracy: 0.532838
Fold: 9/10... Epoch: 4/1000... Loss: 0.695967... Val Loss: 0.687338... Accuracy: 0.536555
Fold: 9/10... Epoch: 5/1000... Loss: 0.686817... Val Loss: 0.686607... Accuracy: 0.550186
Fold: 9/10... Epoch: 6/1000... Loss: 0.679047... Val Loss: 0.685783... Accuracy: 0.558860
Fold: 9/10... Epoch: 7/1000... Loss: 0.672224... Val Loss: 0.684473... Accuracy: 0.555143
Fold: 9/10... Epoch: 8/1000... Loss: 0.664995... Val Loss: 0.682536... Accuracy: 0.571252
Fold: 9/10... Epoch: 9/1000... Loss: 0.657685... Val Loss: 0.680077... Accuracy: 0.583643
Fold: 9/10... Epoch: 10/1000... Loss: 0.649585... Val Loss: 0.677145... Accuracy: 0.574969
Fold: 9/10... Epoch: 11/1000... Loss: 0.640037... Val Loss: 0.674079... Accuracy: 0.581165
Fold: 9/10... Epoch: 12/1000... Loss: 0.630390... Val Loss: 0.671178... Accuracy: 0.587361
[output truncated]

Fold: 10/10... Epoch: 11/1000... Loss: 0.624863... Val Loss: 0.662969... Accuracy: 0.596154
Fold: 10/10... Epoch: 12/1000... Loss: 0.592551... Val Loss: 0.658307... Accuracy: 0.602564
Fold: 10/10... Epoch: 13/1000... Loss: 0.563597... Val Loss: 0.654526... Accuracy: 0.602564
Fold: 10/10... Epoch: 14/1000... Loss: 0.542757... Val Loss: 0.651903... Accuracy: 0.606410
Fold: 10/10... Epoch: 15/1000... Loss: 0.528591... Val Loss: 0.650086... Accuracy: 0.610256
Fold: 10/10... Epoch: 16/1000... Loss: 0.518163... Val Loss: 0.648748... Accuracy: 0.617949
Fold: 10/10... Epoch: 17/1000... Loss: 0.510779... Val Loss: 0.647632... Accuracy: 0.617949
Fold: 10/10... Epoch: 18/1000... Loss: 0.505451... Val Loss: 0.646632... Accuracy: 0.614103
Fold: 10/10... Epoch: 19/1000... Loss: 0.501689... Val Loss: 0.645742... Accuracy: 0.615385
Fold: 10/10... Epoch: 20/1000... Loss: 0.499210... Val Loss: 0.644898... Accuracy: 0.612821
Fold: 10/10... Epoch: 21/1000... Loss: 0.496949... Val Loss: 0.644069... [output truncated]

In [12]:
print("Mean accuracy across folds: {:.6f}...".format(np.mean(all_accs)),
      "Mean val loss across folds: {:.6f}".format(np.mean(all_losses)))

print("Best accuracy across folds: {:.6f}...".format(np.max(all_accs)),
      "Lowest val loss across folds: {:.6f}".format(np.min(all_losses)))

print("Worst accuracy across folds: {:.6f}...".format(np.min(all_accs)),
      "Highest val loss across folds: {:.6f}".format(np.max(all_losses)))

Mean accuracy across folds: 0.639571... Mean val loss across folds: 0.638833
Best accuracy across folds: 0.668281... Lowest val loss across folds: 0.617787
Worst accuracy across folds: 0.614137... Highest val loss across folds: 0.659856


Validation accuracy should be around 64% (averaging the best epoch of each fold), which is comparable to what I obtained with the SVM model (66%). Not too bad given the small number of trials.