In [1]:
# Load data
import numpy as np
from sklearn.model_selection import train_test_split
from Inner_Speech_Dataset.Python_Processing.Data_extractions import  Extract_data_from_subject

# Load all data for subject
def load_subject(subject_nr):
    datatype = "EEG"
    root_dir = "../dataset"

    data, description = Extract_data_from_subject(root_dir, subject_nr, datatype)
    return data, description

# Extract labels from the description
def extract_labels(desc):
    return desc[:,1]


# Test when extracting only the action interval
def extract_action_interval(data):
    return data[:,:,254:890]

def split_data(data, labels):
    trainv_data, test_data, trainv_labels, test_labels = train_test_split(data, labels, test_size = 0.2, random_state = 1)
    train_data, val_data, train_labels, val_labels = train_test_split(trainv_data, trainv_labels, test_size = 0.25, random_state = 1)

    return train_data, val_data, test_data, train_labels, val_labels, test_labels




# Load, extract and split data
data, description = load_subject(1)
labels = extract_labels(description)
#new_data = extract_action_interval(data)
#print(new_data.shape)


train_data, val_data, test_data, train_labels, val_labels, test_labels = split_data(data, labels)

# EEGNet
Implementation of EEGNetfrom (Lawhern et. al.) for the use on "Thinking out loud" dataset.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import os
import copy

class EEGNet(nn.Module):
    def __init__(self):
        super(EEGNet, self).__init__()
        self.f1 = 16
        self.d = 4
        self.f2 = self.d*self.f1

        # Block 1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels = self.f1, kernel_size = (1, 128), padding = 0, bias = False) # 128 = half sampling rate, reduces to 2Hz, 16 is probably F1, the number of filters (output channels)
        self.batchnorm1 = nn.BatchNorm2d(self.f1, False)

        self.conv2 = nn.Conv2d(in_channels = self.f1, out_channels = self.f1*self.d, kernel_size = (128, 1), padding = 0, bias = False) # row, column] 128 = number of sensors for the depth wise convolution
        self.batchnorm2 = nn.BatchNorm2d(self.f1*self.d, False) # Not sure this is the right number of features.
        self.elu = nn.ELU()
        self.avgPool1 = nn.AvgPool2d((1,4)) # Reduces sampling rate to 64hz

        # Block 2

        # Separable convolution
        self.depthwise = nn.Conv2d(in_channels = self.f1*self.d, out_channels = self.f1*self.d, kernel_size = (1, 32),
                        bias=False, groups = self.f1*self.d, padding = (0, 16//2)) # Captures 500ms of data at sampling rate 64Hz (32 Hz is 500ms of that)
        self.pointwise = nn.Conv2d(in_channels = self.f1*self.d, out_channels = self.f2, kernel_size = 1) # This results in f2 channels of the output here.
        
        self.batchnorm3 = nn.BatchNorm2d(self.f2, False)

        self.avgPool2 = nn.AvgPool2d((1,8))
        
        # Fully connected classifier
        self.classifier = nn.Linear(1920,4, bias = False)
        self.softmax = nn.Softmax()


    def forward(self, x, dropout = 0.5): # Dropout 0.5 dropout for within-subject and 0.25 for cross-subject
        # Block 1
        res = self.conv1(x)
        res = self.batchnorm1(res)
        res = self.conv2(res)
        res = torch.renorm(res, p=2, dim=0, maxnorm=1)
        res = self.batchnorm2(res)
        #res = F.dropout(0.25)
        res = self.elu(res)
        res = self.avgPool1(res)
        res = F.dropout(res, dropout)
        # Block 2
        res = self.depthwise(res)
        res = self.pointwise(res)
        res = self.batchnorm3(res)
        res = self.elu(res)
        res = self.avgPool2(res)
        res = F.dropout(res, dropout)
        # Classifier
        res = torch.flatten(res, start_dim=1)
        res = self.classifier(res)
        res = self.softmax(res)
        return res

    def train_model(self, data, labels, epochs, batch_size, loss_func, optimizer):
        print("Epoch\t train loss\t validation loss\t train acc\t validation acc")
        best_model = copy.deepcopy(self.state_dict())
        best_loss = 500 # bad value for dis

        for epoch in range(epochs):        
            epoch_loss = 0
            self.train() # Set model to train mode
            
            for i in range(len(data)//batch_size): # BATCH SIZE MUST BE EVEN DIVIDER OF DATA LEN, otherwise we miss stuff here
                start = i*batch_size
                end = (i+1)*batch_size

                train_inputs = torch.FloatTensor(data[start:end])
                train_labels = torch.LongTensor(labels[start:end])
                train_outputs = self(train_inputs)

                loss = loss_func(train_outputs, train_labels)
                
                #print("LOSS: ", loss)
                epoch_loss += loss
                loss.backward()
                optimizer.step()
            
            # Check for new best model, this should be on val_data instead
            if epoch_loss < best_loss:
                best_model = copy.deepcopy(self.state_dict())

            print("EPOCH LOSS:", epoch_loss)

        self.load_state_dict(best_model) # Set model to best performing one.
                



# TODO: Run on cuda.
# TODO: Add validation data and labels to train_model method params
# TODO: Implement early stopping.
# TODO: Graphing of train and val loss.

## Testing code
Code for testing the EEGNet implementation

In [3]:
train_data = torch.tensor(train_data)
train_data = torch.unsqueeze(train_data,1).float()
train_labels = torch.tensor(train_labels).long()

val_data = torch.tensor(val_data)
val_data = torch.unsqueeze(val_data,1).float()
val_labels = torch.tensor(val_labels).long()

test_data = torch.tensor(test_data)
test_data = torch.unsqueeze(test_data,1).float()
test_labels = torch.tensor(test_labels).long()


# Uncomment this if running on only 1 trial
#train_data = torch.unsqueeze(torch.unsqueeze(d,0),0).float() # Unsqueeze adds a wrapper dimension of 1s


In [4]:
# Dumb checking of initial weights
def accuracy_check(data, labels):
    r = network.forward(data)
    p = torch.max(r,1)[1]
    correct = 0.0
    for i in range(len(p)):
        if p[i] == labels[i]:
            correct += 1
    print("ACCURACY:", correct/len(p))


network = EEGNet().float()
accuracy_check(train_data, train_labels)

print("####### TRAINING #######")
op = optim.Adam(params = network.parameters(), lr = 0.0001)
lossf = nn.NLLLoss()
network.train_model(data = train_data, labels = train_labels, epochs = 20, batch_size = 10, loss_func = lossf, optimizer = op)

accuracy_check(train_data, train_labels)

  res = self.softmax(res)


ACCURACY: 0.21666666666666667
####### TRAINING #######
Epoch	 train loss	 validation loss	 train acc	 validation acc
EPOCH LOSS: tensor(-7.4676, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-9.7885, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-11.0765, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-13.5384, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-15.7294, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-17.4631, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-19.6669, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-20.9731, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-21.6067, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-22.8588, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-21.9961, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-22.3128, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-22.3893, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-23.0253, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-23.1827, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-23.9734, grad_fn=<AddBackward0>)
EPOCH LOSS: tensor(-23.4140, grad_fn=

In [5]:
accuracy_check(train_data, train_labels)
accuracy_check(val_data, val_labels)
accuracy_check(test_data, test_labels)

  res = self.softmax(res)


ACCURACY: 0.7866666666666666
ACCURACY: 0.22
ACCURACY: 0.28
