In [1]:
import ipynb
import torch
from ipynb.fs.full.data_handler import load_subject, extract_labels, extract_action_interval, split_data, split_info, to_device
from ipynb.fs.full.train_model import train_model, accuracy_check

# Seed torch's RNG up front so the model's weight initialisation, dropout masks
# and any torch-based shuffling are reproducible across Restart & Run All.
# NOTE(review): if split_data draws from numpy or `random` instead of torch,
# those generators need seeding as well — TODO confirm against data_handler.
SEED = 0
torch.manual_seed(SEED)

SUBJECT_ID = 1  # which participant's recordings to load
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data, description = load_subject(SUBJECT_ID)
labels = extract_labels(description)
data_interval = extract_action_interval(data)

train_data, val_data, test_data, train_labels, val_labels, test_labels = split_data(data_interval, labels)

# Move every split to the selected device before printing the split summary.
train_data, val_data, test_data, train_labels, val_labels, test_labels = to_device(
    train_data, val_data, test_data, train_labels, val_labels, test_labels, device
)
split_info(train_data, val_data, test_data, train_labels, val_labels, test_labels)

[99, 99, 99, 99] 
 [13, 13, 14, 14] 
 [13, 13, 12, 12]
25.0 %  25.0 %  25.0 %  25.0 % 
24.074074074074073 %  24.074074074074073 %  25.925925925925924 %  25.925925925925924 % 
26.0 %  26.0 %  24.0 %  24.0 % 


# Shallow CNN
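Implementation of the Shallow ConvNet architecture from Schirrmeister et al. (2017), "Deep learning with convolutional neural networks for EEG decoding and visualization", adapted for use on the "Thinking out loud" dataset.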

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

class ShallowCNN(nn.Module):
    """Shallow ConvNet for EEG trial classification, after Schirrmeister et al. (2017).

    Layer order: temporal conv -> spatial conv -> batch norm -> ELU ->
    mean pooling -> dropout -> flatten -> linear classifier -> log-softmax.

    NOTE(review): the paper's Shallow ConvNet applies a squaring non-linearity
    before pooling and a log after it; this version uses ELU instead.

    Args:
        interval: Which signal window the model is fed ("action" or "full").
            It only selects the classifier's input size (1440 or 2840), which
            must equal the flattened size after pooling for that window length.
        dropout: Dropout probability applied to the pooled features.
        n_channels: Height of the spatial kernel, i.e. the number of EEG
            channels in the input (default 128).
        n_classes: Number of output classes (default 4).

    Raises:
        ValueError: If ``interval`` is not "action" or "full".
    """

    def __init__(self, interval="full", dropout=0.5, n_channels=128, n_classes=4):
        super().__init__()
        n_filters = 40
        # Flattened feature count reaching the classifier:
        # n_filters * (number of pooled time steps for that window length).
        classifier_in_features = {"action": 1440, "full": 2840}
        if interval not in classifier_in_features:
            raise ValueError(
                f"interval must be one of {sorted(classifier_in_features)}, got {interval!r}"
            )

        # Biases are omitted in both convolutions: the batch norm right after the
        # spatial conv subtracts any constant per-filter offset anyway.
        # Temporal convolution: 25-sample filters along the time axis.
        self.tempconv = nn.Conv2d(
            in_channels=1, out_channels=n_filters, kernel_size=(1, 25), padding=0, bias=False
        )
        # Spatial convolution: each filter mixes all n_channels electrodes.
        self.spatconv = nn.Conv2d(
            in_channels=n_filters, out_channels=n_filters, kernel_size=(n_channels, 1),
            padding=0, bias=False,
        )
        # eps passed by keyword: BatchNorm2d's 2nd positional argument is eps,
        # and eps=0 yields NaN whenever a batch channel has zero variance.
        self.batchnorm = nn.BatchNorm2d(n_filters, eps=1e-5)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout)
        # Mean pooling over time; kernel 75 / stride 15 as in the paper's Shallow ConvNet.
        self.meanpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))
        self.classifier = nn.Linear(classifier_in_features[interval], n_classes, bias=False)
        # LogSoftmax (not Softmax) so the output pairs with nn.NLLLoss.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of trials.

        Args:
            x: Float tensor, assumed (batch, 1, n_channels, n_times) given the
               single-input-channel temporal conv and (n_channels, 1) spatial
               kernel — TODO confirm against data_handler.

        Returns:
            Tensor of shape (batch, n_classes) holding log-softmax scores.
        """
        features = self.spatconv(self.tempconv(x))
        features = self.elu(self.batchnorm(features))
        features = self.dropout(self.meanpool(features))
        logits = self.classifier(torch.flatten(features, start_dim=1))
        return self.softmax(logits)


In [6]:
# Build a fresh model; the first check is the untrained baseline on the training split.
network = ShallowCNN(interval="action").float().to(device)
accuracy_check(network, train_data, train_labels)

print("####### TRAINING #######")
LEARNING_RATE = 0.0001
N_EPOCHS = 50
BATCH_SIZE = 10

op = optim.Adam(params=network.parameters(), lr=LEARNING_RATE)
lossf = nn.NLLLoss()  # the model ends in LogSoftmax, so NLL gives cross-entropy
train_model(
    network,
    train_data=train_data,
    train_labels=train_labels,
    val_data=val_data,
    val_labels=val_labels,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    loss_func=lossf,
    optimizer=op,
)

accuracy_check(network, train_data, train_labels)

ACCURACY: 0.25
####### TRAINING #######
Epoch	 train loss	 validation loss
0 	  1.4495796790489783 	 1.3862603902816772
1 	  1.3266488099709535 	 1.386222004890442
2 	  1.2864039494441106 	 1.3860478401184082
3 	  1.2438658689841247 	 1.3855808973312378
4 	  1.1853823539538262 	 1.3835722208023071
5 	  1.1565294510278947 	 1.4194326400756836
6 	  1.1126894828600762 	 1.4478398561477661
7 	  1.0734599186823919 	 1.502573847770691
8 	  1.0422704647748897 	 1.5182851552963257
9 	  0.984276111309345 	 1.5553083419799805
10 	  0.9529058016263522 	 1.559572696685791
11 	  0.9428936884953425 	 1.6133947372436523
12 	  0.8807775057279147 	 1.597985863685608
13 	  0.83496827345628 	 1.6592047214508057
14 	  0.8217211992312701 	 1.685848355293274
15 	  0.7803047375801282 	 1.7050511837005615
16 	  0.7454075446495643 	 1.7079063653945923
17 	  0.7248755234938401 	 1.713695764541626
18 	  0.6899392054631159 	 1.7811042070388794
19 	  0.681616465250651 	 1.8257538080215454
20 	  0.6263808225974058 

In [7]:
# Report accuracy of the trained network on each split, in order: train, validation, test.
for split_inputs, split_targets in (
    (train_data, train_labels),
    (val_data, val_labels),
    (test_data, test_labels),
):
    accuracy_check(network, split_inputs, split_targets)

ACCURACY: 0.9772727272727273
ACCURACY: 0.3148148148148148
ACCURACY: 0.26
