# Load data
This code loads the data for 1 subject with the help of methods defined in data_handler.ipynb.
Also prints distribution of the split just to make sure it's stratified.

In [1]:
import ipynb
import torch
from ipynb.fs.full.data_handler import load_subject, load_subject_non_downsampled, extract_labels, extract_action_interval, split_data, split_info, to_device, get_innerspeech
from ipynb.fs.full.train_model import train_model, accuracy_check

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data, description = load_subject_non_downsampled(1)
data, description = get_innerspeech(data, description)
labels = extract_labels(description)
data_interval = extract_action_interval(data, hz = 1024)

train_data, val_data, test_data, train_labels, val_labels, test_labels = split_data(data_interval, labels)

train_data, val_data, test_data, train_labels, val_labels, test_labels = to_device(train_data, val_data, test_data, train_labels, val_labels, test_labels, device)
split_info(train_data, val_data, test_data, train_labels, val_labels, test_labels)

1024
yo
[39, 39, 40, 40] 
 [6, 6, 5, 5] 
 [5, 5, 5, 5]
24.68354430379747 %  24.68354430379747 %  25.31645569620253 %  25.31645569620253 % 
27.27272727272727 %  27.27272727272727 %  22.727272727272727 %  22.727272727272727 % 
25.0 %  25.0 %  25.0 %  25.0 % 


# EEGNet
Implementation of EEGNetfrom (Lawhern et. al.) for the use on "Thinking out loud" dataset.

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim


class EEGNet(nn.Module):
    def __init__(self, hz, interval = "full", dropout = 0.5):
        super(EEGNet, self).__init__()
        self.f1 = 16
        self.d = 4
        self.f2 = self.d*self.f1

        # Block 1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels = self.f1, kernel_size = (1, 128), padding = 0, bias = False) # 128 = half sampling rate, reduces to 2Hz, 16 is probably F1, the number of filters (output channels)
        self.batchnorm1 = nn.BatchNorm2d(self.f1, False)

        self.conv2 = nn.Conv2d(in_channels = self.f1, out_channels = self.f1*self.d, kernel_size = (128, 1), padding = 0, bias = False) # row, column] 128 = number of sensors for the depth wise convolution
        self.batchnorm2 = nn.BatchNorm2d(self.f1*self.d, False) # Not sure this is the right number of features.
        # Elu
        self.elu = nn.ELU()
        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.avgPool1 = nn.AvgPool2d((1,4)) # Reduces sampling rate to 64hz

        # Block 2

        # Separable convolution
        self.depthwise = nn.Conv2d(in_channels = self.f1*self.d, out_channels = self.f1*self.d, kernel_size = (1, 32),
                        bias=False, groups = self.f1*self.d, padding = (0, 16//2)) # Captures 500ms of data at sampling rate 64Hz (32 Hz is 500ms of that)
        self.pointwise = nn.Conv2d(in_channels = self.f1*self.d, out_channels = self.f2, kernel_size = 1) # This results in f2 channels of the output here.
        
        self.batchnorm3 = nn.BatchNorm2d(self.f2, False)

        self.avgPool2 = nn.AvgPool2d((1,8))
        
        # Fully connected classifier
        if interval == "action":
            if hz == 254:
                self.classifier = nn.Linear(896,4, bias = False)
            elif hz == 1024:
                self.classifier = nn.Linear(4736,4, bias = False)
        elif interval == "full":
            if hz == 254:
                self.classifier = nn.Linear(1920,4, bias = False)
            elif hz == 1024:
                self.classifier = nn.Linear(8832, 4, bias = False)
        self.softmax = nn.LogSoftmax(dim = 1)


    def forward(self, x): # Dropout 0.5 dropout for within-subject and 0.25 for cross-subject
        # Block 1
        res = self.conv1(x)
        res = self.batchnorm1(res)
        res = self.conv2(res)
        res = torch.renorm(res, p=2, dim=0, maxnorm=1)
        res = self.batchnorm2(res)
        #res = F.dropout(0.25)
        res = self.elu(res)
        res = self.avgPool1(res)
        res = self.dropout(res)
        # Block 2
        res = self.depthwise(res)
        res = self.pointwise(res)
        res = self.batchnorm3(res)
        res = self.elu(res)
        res = self.avgPool2(res)
        res = self.dropout(res)
        # Classifier
        res = torch.flatten(res, start_dim=1)
        res = self.classifier(res)
        res = self.softmax(res)
        return res


## Testing code
Code for testing the EEGNet implementation

In [5]:
network = EEGNet(hz = 1024, interval = "action").float().to(device)
accuracy_check(network, train_data, train_labels)

print("####### TRAINING #######")
op = optim.Adam(params = network.parameters(), lr = 0.0001)
lossf = nn.NLLLoss()
train_model(network, train_data = train_data, train_labels = train_labels, val_data = val_data, val_labels = val_labels, epochs = 50, batch_size = 10, loss_func = lossf, optimizer = op)

accuracy_check(network, train_data, train_labels)

ACCURACY: 0.25316455696202533
####### TRAINING #######
Epoch	 train loss	 validation loss
0 	  1.5213467915852865 	 1.3896933794021606
1 	  1.2361810048421225 	 1.3929362297058105
2 	  1.1883989969889324 	 1.3941725492477417
3 	  1.0163252512613932 	 1.393800139427185
4 	  1.0060908635457357 	 1.3930872678756714
5 	  0.9035060246785481 	 1.391650915145874
6 	  0.8808631261189779 	 1.3821842670440674
7 	  0.7935773213704427 	 1.3507463932037354
8 	  0.7113124847412109 	 1.3240143060684204
9 	  0.6597771962483724 	 1.276682734489441
10 	  0.6008935928344726 	 1.277762770652771
11 	  0.5700113932291667 	 1.3094937801361084
12 	  0.5287777582804362 	 1.3027291297912598
13 	  0.4561431566874186 	 1.3192694187164307
14 	  0.41896645228068036 	 1.3259711265563965
15 	  0.3710137367248535 	 1.3339167833328247
16 	  0.32483587265014646 	 1.3592990636825562
17 	  0.3190416971842448 	 1.3632556200027466
18 	  0.26601128578186034 	 1.36430823802948
19 	  0.2714275360107422 	 1.3952056169509888
20 

In [7]:
accuracy_check(network, train_data, train_labels)
accuracy_check(network, val_data, val_labels)
accuracy_check(network, test_data, test_labels)

ACCURACY: 0.9556962025316456
ACCURACY: 0.36363636363636365
ACCURACY: 0.25
