In [18]:
import ipynb
import torch
from ipynb.fs.full.data_handler import load_subject, extract_labels, extract_action_interval, split_data, split_info, to_device, get_innerspeech
from ipynb.fs.full.train_model import train_model, accuracy_check

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data, description = load_subject(1)
data, description = get_innerspeech(data, description)
labels = extract_labels(description)
data_interval = extract_action_interval(data)

train_data, val_data, test_data, train_labels, val_labels, test_labels = split_data(data_interval, labels)

train_data, val_data, test_data, train_labels, val_labels, test_labels = to_device(train_data, val_data, test_data, train_labels, val_labels, test_labels, device)
split_info(train_data, val_data, test_data, train_labels, val_labels, test_labels)

[30, 30, 30, 30] 
 [10, 10, 10, 10] 
 [10, 10, 10, 10]
25.0 %  25.0 %  25.0 %  25.0 % 
25.0 %  25.0 %  25.0 %  25.0 % 
25.0 %  25.0 %  25.0 %  25.0 % 


# Shallow CNN
Implementation of the shallow CNN structure from (Schirrmeister et. al.) for the use on "Thinking out loud" dataset.

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim

class ShallowCNN(nn.Module):
    def __init__(self, interval = "full", dropout = 0.5):
        super(ShallowCNN, self).__init__()
        first_channels = 40
        

        # Temporal convolution
        self.tempconv = nn.Conv2d(in_channels = 1, out_channels = first_channels, kernel_size = (1, 25), padding = 0, bias = False)
        # Spatial convolution
        self.spatconv = nn.Conv2d(in_channels = first_channels, out_channels = 40, kernel_size=(128,1), padding = 0, bias = False)
        # Batch normalization
        self.batchnorm = nn.BatchNorm2d(40, False)
        # ELU
        self.elu = nn.ELU()
        # Dropout
        self.dropout = nn.Dropout(dropout)
        # Mean pooling
        self.meanpool = nn.AvgPool2d(kernel_size = (1,75), stride = (1,15)) # This 15 might be a 75 :(

        # Classifier
        if interval == "action":
            self.classifier = nn.Linear(1440,4, bias = False)
        elif interval == "full":
            self.classifier = nn.Linear(2840,4, bias = False)
        # Softmax
        self.softmax = nn.LogSoftmax(dim = 1)

        
    def forward(self, x):
        res = self.tempconv(x)
        res = self.spatconv(res)
        res = self.batchnorm(res)
        res = self.elu(res)
        res = self.meanpool(res)
        res = self.dropout(res)
        res = torch.flatten(res, start_dim=1)
        res = self.classifier(res)
        res = self.softmax(res)
        return res


In [20]:
network = ShallowCNN(interval = "action").float().to(device)
accuracy_check(network, train_data, train_labels)

print("####### TRAINING #######")
op = optim.Adam(params = network.parameters(), lr = 0.0001)
lossf = nn.NLLLoss()
train_model(network, train_data = train_data, train_labels = train_labels, val_data = val_data, val_labels = val_labels, epochs = 20, batch_size = 10, loss_func = lossf, optimizer = op)

accuracy_check(network, train_data, train_labels)

ACCURACY: 0.225
####### TRAINING #######
Epoch	 train loss	 validation loss
0 	  1.3937729199727376 	 1.3862959146499634
1 	  1.2828237215677898 	 1.3862979412078857
2 	  1.2208924293518066 	 1.386300802230835
3 	  1.174707333246867 	 1.3863054513931274
4 	  1.1091994444529216 	 1.386319875717163
5 	  1.0082961718241374 	 1.3863415718078613
6 	  1.0137550830841064 	 1.3863837718963623
7 	  0.9437669118245443 	 1.3864796161651611
8 	  0.8937573432922363 	 1.3867193460464478
9 	  0.840076764424642 	 1.3871347904205322
10 	  0.7819739182790121 	 1.3882390260696411
11 	  0.6986858050028483 	 1.389805793762207
12 	  0.6593041022618612 	 1.3922773599624634
13 	  0.6125906308492025 	 1.4010636806488037
14 	  0.548053503036499 	 1.4147851467132568
15 	  0.4983038504918416 	 1.4512081146240234
16 	  0.4855602979660034 	 1.5170865058898926
17 	  0.4292869567871094 	 1.6574971675872803
18 	  0.39979732036590576 	 1.8824495077133179
19 	  0.35339848200480145 	 2.186584711074829
ACCURACY: 0.9916666

In [21]:
accuracy_check(network, train_data, train_labels)
accuracy_check(network, val_data, val_labels)
accuracy_check(network, test_data, test_labels)

ACCURACY: 0.9916666666666667
ACCURACY: 0.2
ACCURACY: 0.225
