# Load data
This code loads the data for 1 subject with the help of methods defined in data_handler.ipynb.
Also prints distribution of the split just to make sure it's stratified.

In [1]:
import ipynb
import torch
from ipynb.fs.full.data_handler import load_subject, load_subject_non_downsampled, extract_labels, extract_action_interval, split_data, split_info, to_device, get_innerspeech
from ipynb.fs.full.train_model import train_model, accuracy_check

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data, description = load_subject_non_downsampled(1)
data, description = get_innerspeech(data, description)
labels = extract_labels(description)
#data_interval = extract_action_interval(data, hz = 1024)

train_data, val_data, test_data, train_labels, val_labels, test_labels = split_data(data, labels)

train_data, val_data, test_data, train_labels, val_labels, test_labels = to_device(train_data, val_data, test_data, train_labels, val_labels, test_labels, device)
split_info(train_data, val_data, test_data, train_labels, val_labels, test_labels)

[39, 39, 40, 40] 
 [6, 6, 5, 5] 
 [5, 5, 5, 5]
24.68354430379747 %  24.68354430379747 %  25.31645569620253 %  25.31645569620253 % 
27.27272727272727 %  27.27272727272727 %  22.727272727272727 %  22.727272727272727 % 
25.0 %  25.0 %  25.0 %  25.0 % 


# EEGNet
Implementation of EEGNetfrom (Lawhern et. al.) for the use on "Thinking out loud" dataset.

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim


class EEGNet(nn.Module):
    def __init__(self, hz, interval = "full", dropout = 0.5):
        super(EEGNet, self).__init__()
        self.f1 = 16
        self.d = 4
        self.f2 = self.d*self.f1

        # Block 1
        if hz == 254:
            self.conv1 = nn.Conv2d(in_channels=1, out_channels = self.f1, kernel_size = (1, 128), padding = 0, bias = False) # 128 = half sampling rate, reduces to 2Hz, 16 is probably F1, the number of filters (output channels)
        elif hz == 1024:
            self.conv1 = nn.Conv2d(in_channels=1, out_channels = self.f1, kernel_size = (1, 512), padding = 0, bias = False) # 128 = half sampling rate, reduces to 2Hz, 16 is probably F1, the number of filters (output channels)
        self.batchnorm1 = nn.BatchNorm2d(self.f1, False)

        self.conv2 = nn.Conv2d(in_channels = self.f1, out_channels = self.f1*self.d, kernel_size = (128, 1), padding = 0, bias = False) # row, column] 128 = number of sensors for the depth wise convolution
        self.batchnorm2 = nn.BatchNorm2d(self.f1*self.d, False) # Not sure this is the right number of features.
        # Elu
        self.elu = nn.ELU()
        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.avgPool1 = nn.AvgPool2d((1,4)) # Reduces sampling rate to 64hz

        # Block 2

        # Separable convolution
        if hz == 254:
            self.depthwise = nn.Conv2d(in_channels = self.f1*self.d, out_channels = self.f1*self.d, kernel_size = (1, 32),
                        bias=False, groups = self.f1*self.d, padding = (0, 16//2)) # Captures 500ms of data at sampling rate 64Hz (32 Hz is 500ms of that)
        elif hz == 1024:
            self.depthwise = nn.Conv2d(in_channels = self.f1*self.d, out_channels = self.f1*self.d, kernel_size = (1, 128),
                        bias=False, groups = self.f1*self.d, padding = (0, 16//2)) # Captures 500ms of data at sampling rate 64Hz (32 Hz is 500ms of that)
        self.pointwise = nn.Conv2d(in_channels = self.f1*self.d, out_channels = self.f2, kernel_size = 1) # This results in f2 channels of the output here.
        
        self.batchnorm3 = nn.BatchNorm2d(self.f2, False)

        self.avgPool2 = nn.AvgPool2d((1,8))
        
        # Fully connected classifier
        if interval == "action":
            if hz == 254:
                self.classifier = nn.Linear(896,4, bias = False)
            elif hz == 1024:
                self.classifier = nn.Linear(3200,4, bias = False)
        elif interval == "full":
            if hz == 254:
                self.classifier = nn.Linear(1920,4, bias = False)
            elif hz == 1024:
                self.classifier = nn.Linear(7296, 4, bias = False)
        self.softmax = nn.LogSoftmax(dim = 1)


    def forward(self, x): # Dropout 0.5 dropout for within-subject and 0.25 for cross-subject
        # Block 1
        res = self.conv1(x)
        res = self.batchnorm1(res)
        res = self.conv2(res)
        res = torch.renorm(res, p=2, dim=0, maxnorm=1)
        res = self.batchnorm2(res)
        #res = F.dropout(0.25)
        res = self.elu(res)
        res = self.avgPool1(res)
        res = self.dropout(res)
        # Block 2
        res = self.depthwise(res)
        res = self.pointwise(res)
        res = self.batchnorm3(res)
        res = self.elu(res)
        res = self.avgPool2(res)
        res = self.dropout(res)
        # Classifier
        res = torch.flatten(res, start_dim=1)
        res = self.classifier(res)
        res = self.softmax(res)
        return res


## Testing code
Code for testing the EEGNet implementation

In [5]:
network = EEGNet(hz = 1024, interval = "full").float().to(device)
accuracy_check(network, train_data, train_labels)

print("####### TRAINING #######")
op = optim.Adam(params = network.parameters(), lr = 0.0001)
lossf = nn.NLLLoss()
train_model(network, train_data = train_data, train_labels = train_labels, val_data = val_data, val_labels = val_labels, epochs = 50, batch_size = 10, loss_func = lossf, optimizer = op)

accuracy_check(network, train_data, train_labels)

ACCURACY: 0.2468354430379747
####### TRAINING #######
Epoch	 train loss	 validation loss
0 	  1.4979248046875 	 1.3854141235351562
1 	  1.203192647298177 	 1.3860247135162354
2 	  1.023139444986979 	 1.387166976928711
3 	  0.9223775863647461 	 1.389386534690857
4 	  0.8008235295613607 	 1.3947343826293945
5 	  0.7182073593139648 	 1.406744360923767
6 	  0.611872673034668 	 1.4423842430114746
7 	  0.538532002766927 	 1.4882687330245972
8 	  0.4843051592508952 	 1.5046709775924683
9 	  0.4180446624755859 	 1.506285548210144
10 	  0.378174877166748 	 1.5722187757492065
11 	  0.31573301951090493 	 1.5447083711624146
12 	  0.25503455797831215 	 1.6191210746765137
13 	  0.20807323455810547 	 1.6635719537734985
14 	  0.19754304885864257 	 1.6640446186065674
15 	  0.16564200719197592 	 1.683122992515564
16 	  0.13731668790181478 	 1.704674482345581
17 	  0.11952701409657797 	 1.7069644927978516
18 	  0.11170883178710937 	 1.7412528991699219
19 	  0.0937974770863851 	 1.771518349647522
20 	  0.

In [6]:
accuracy_check(network, train_data, train_labels)
accuracy_check(network, val_data, val_labels)
accuracy_check(network, test_data, test_labels)

ACCURACY: 0.9620253164556962
ACCURACY: 0.3181818181818182
ACCURACY: 0.3
