In [30]:
import random

import ipynb
import numpy as np
import torch
from ipynb.fs.full.data_handler import load_subject, extract_labels, extract_action_interval, split_data, split_info, to_device, get_innerspeech
from ipynb.fs.full.train_model import train_model, accuracy_check

# Fix all RNGs up front so weight initialisation (and the train/val/test split, if
# `split_data` is random -- TODO confirm) is reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Subject 1, restricted to the inner-speech condition.
data, description = load_subject(1)
data, description = get_innerspeech(data, description)
labels = extract_labels(description)
# Optional: restrict trials to the action interval via `extract_action_interval(data)`.
# NOTE(review): the model below would then need ShallowCNN(interval="action");
# confirm what extract_action_interval returns before wiring it in.

train_data, val_data, test_data, train_labels, val_labels, test_labels = split_data(data, labels)

# Move every split to the selected device, then print class counts / balance per split.
train_data, val_data, test_data, train_labels, val_labels, test_labels = to_device(
    train_data, val_data, test_data, train_labels, val_labels, test_labels, device
)
split_info(train_data, val_data, test_data, train_labels, val_labels, test_labels)

[30, 30, 30, 30] 
 [10, 10, 10, 10] 
 [10, 10, 10, 10]
25.0 %  25.0 %  25.0 %  25.0 % 
25.0 %  25.0 %  25.0 %  25.0 % 
25.0 %  25.0 %  25.0 %  25.0 % 


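# Shallow CNN
Implementation of the shallow ConvNet architecture from Schirrmeister et al. (2017), *Deep learning with convolutional neural networks for EEG decoding and visualization*, adapted for the "Thinking Out Loud" inner-speech EEG dataset. Unlike the original network, which applies a squaring nonlinearity before the mean pooling and a logarithm after it, this version uses an ELU activation.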

In [31]:
import torch
import torch.nn as nn
import torch.optim as optim

class ShallowCNN(nn.Module):
    """Shallow ConvNet for EEG trial classification, after Schirrmeister et al. (2017).

    Pipeline: temporal conv -> spatial conv -> batch norm -> ELU -> mean pooling
    -> dropout -> flatten -> linear classifier (no bias) -> log-softmax.

    NOTE(review): the original shallow ConvNet uses square / log nonlinearities
    around the mean pooling; this implementation uses ELU instead.

    The first conv has ``in_channels=1`` and the spatial kernel is
    ``(n_channels, 1)``, so input is presumably shaped
    ``(batch, 1, n_channels, n_times)``. The output is ``(batch, n_classes)``
    log-probabilities, meant to be paired with ``nn.NLLLoss``.

    Args:
        interval: ``"full"`` or ``"action"``. Selects the previously hard-coded
            classifier input size for that trial length. Ignored when ``n_times``
            is given.
        dropout: dropout probability applied after pooling.
        n_channels: number of EEG electrodes (height of the spatial kernel).
        n_classes: number of output classes.
        n_times: samples per trial. If given, the classifier input size is
            computed from it instead of looked up from ``interval``.
        bn_eps: batch-norm epsilon. It defaults to 0.0 to reproduce the earlier
            ``nn.BatchNorm2d(40, False)`` call, where ``False`` was bound to the
            positional ``eps`` argument. With eps=0, a filter whose batch
            variance is exactly 0 divides by zero. If inputs are rescaled to a
            unit-ish range, prefer PyTorch's default of 1e-5.

    Raises:
        ValueError: if ``interval`` is unknown (and ``n_times`` is not given), or
            if ``n_times`` is too short for the convolution + pooling stack.
    """

    # Fixed architectural constants (pool kernel 75 / stride 15 as in the paper).
    _TEMPORAL_KERNEL = 25
    _POOL_KERNEL = 75
    _POOL_STRIDE = 15
    # Classifier input sizes used so far for the two trial intervals.
    _CLASSIFIER_IN = {"action": 1440, "full": 2840}

    def __init__(self, interval="full", dropout=0.5, n_channels=128, n_classes=4,
                 n_times=None, bn_eps=0.0):
        super().__init__()
        n_filters = 40

        # Resolve the classifier input size before building any layer, so a bad
        # configuration fails here rather than as an AttributeError in forward().
        if n_times is not None:
            conv_width = n_times - self._TEMPORAL_KERNEL + 1
            pooled_width = (conv_width - self._POOL_KERNEL) // self._POOL_STRIDE + 1
            if pooled_width < 1:
                raise ValueError(
                    f"n_times={n_times} is too short: need at least "
                    f"{self._TEMPORAL_KERNEL + self._POOL_KERNEL - 1} samples"
                )
            # Spatial conv collapses the electrode axis to height 1.
            classifier_in = n_filters * pooled_width
        elif interval in self._CLASSIFIER_IN:
            classifier_in = self._CLASSIFIER_IN[interval]
        else:
            raise ValueError(
                f"unknown interval {interval!r}; expected one of {sorted(self._CLASSIFIER_IN)}"
            )

        # Temporal convolution: one kernel of 25 samples per filter along the time axis.
        self.tempconv = nn.Conv2d(in_channels=1, out_channels=n_filters,
                                  kernel_size=(1, self._TEMPORAL_KERNEL), padding=0, bias=False)
        # Spatial convolution across all electrodes (no bias: batch norm follows).
        self.spatconv = nn.Conv2d(in_channels=n_filters, out_channels=n_filters,
                                  kernel_size=(n_channels, 1), padding=0, bias=False)
        # Batch normalization; eps passed by keyword (see ``bn_eps`` above).
        self.batchnorm = nn.BatchNorm2d(n_filters, eps=bn_eps)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout)
        # Mean pooling over time: window 75, stride 15.
        self.meanpool = nn.AvgPool2d(kernel_size=(1, self._POOL_KERNEL),
                                     stride=(1, self._POOL_STRIDE))
        self.classifier = nn.Linear(classifier_in, n_classes, bias=False)
        # Log-probabilities for use with nn.NLLLoss.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return ``(batch, n_classes)`` log-probabilities for input batch ``x``."""
        res = self.tempconv(x)
        res = self.spatconv(res)
        res = self.batchnorm(res)
        res = self.elu(res)
        res = self.meanpool(res)
        res = self.dropout(res)
        res = torch.flatten(res, start_dim=1)
        res = self.classifier(res)
        return self.softmax(res)


In [32]:
# Training hyper-parameters for this run.
LEARNING_RATE = 0.0001
N_EPOCHS = 20
BATCH_SIZE = 10

network = ShallowCNN(interval="full").float().to(device)
accuracy_check(network, train_data, train_labels)  # untrained baseline (chance is 25% for 4 balanced classes)

print("####### TRAINING #######")
op = optim.Adam(network.parameters(), lr=LEARNING_RATE)
lossf = nn.NLLLoss()  # the model ends in LogSoftmax, so NLLLoss gives cross-entropy
train_model(
    network,
    train_data=train_data,
    train_labels=train_labels,
    val_data=val_data,
    val_labels=val_labels,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    loss_func=lossf,
    optimizer=op,
)

# NOTE(review): in the logged run the validation loss climbs from about epoch 9 while
# the training loss keeps falling -- i.e. overfitting; consider fewer epochs or early stopping.
accuracy_check(network, train_data, train_labels)  # training-split accuracy after fitting

ACCURACY: 0.225
####### TRAINING #######
Epoch	 train loss	 validation loss
0 	  1.4338887532552083 	 1.3862944841384888
1 	  1.2662057081858318 	 1.386293888092041
2 	  1.1895232200622559 	 1.3862930536270142
3 	  1.1430881023406982 	 1.3862924575805664
4 	  1.0497805277506511 	 1.38629150390625
5 	  0.9624302387237549 	 1.3862929344177246
6 	  0.9066576162974039 	 1.3862935304641724
7 	  0.8030935128529867 	 1.3862957954406738
8 	  0.7336226304372152 	 1.386296033859253
9 	  0.677902619043986 	 1.3863813877105713
10 	  0.6215527455012003 	 1.3864519596099854
11 	  0.573543111483256 	 1.3866522312164307
12 	  0.500117301940918 	 1.387143611907959
13 	  0.46445616086324054 	 1.3902523517608643
14 	  0.43715937932332355 	 1.3924238681793213
15 	  0.3762841622034709 	 1.4045963287353516
16 	  0.3423775831858317 	 1.4311254024505615
17 	  0.31652243932088214 	 1.4832096099853516
18 	  0.25681595007578534 	 1.5713469982147217
19 	  0.23757944504419962 	 1.702101469039917
ACCURACY: 1.0


In [33]:
# Report accuracy on each split in turn: train, validation, test.
for split_inputs, split_targets in (
    (train_data, train_labels),
    (val_data, val_labels),
    (test_data, test_labels),
):
    accuracy_check(network, split_inputs, split_targets)

ACCURACY: 1.0
ACCURACY: 0.25
ACCURACY: 0.3
