# Shallow CNN
Implementation of the shallow CNN architecture from Schirrmeister et al. for use on the "Thinking out loud" dataset.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class ShallowCNN(nn.Module):
    """Shallow ConvNet with 40 feature maps, ELU activation and 4 output classes.

    The network applies a temporal convolution (kernel 25 along time), a
    spatial convolution across 128 rows (electrodes), batch normalisation,
    ELU, mean pooling (kernel 75, stride 15 along time), dropout and a
    linear classifier followed by log-softmax.

    Args:
        hz: Sampling rate of the input; 254 and 1024 are supported.
        interval: Which part of a trial is fed in; "action" or "full".
        dropout: Dropout probability applied after mean pooling.
        bias: Whether the linear classifier has a bias term.

    Raises:
        ValueError: If the (interval, hz) combination is not supported.
    """

    # Length of the time axis after the temporal convolution and the mean
    # pooling, keyed by (interval, hz). Multiplied by the number of feature
    # maps this gives the classifier input size (40 maps -> 1440, 6600,
    # 2840 and 12040 features respectively).
    _POOLED_WIDTHS = {
        ("action", 254): 36,
        ("action", 1024): 165,
        ("full", 254): 71,
        ("full", 1024): 301,
    }

    def __init__(self, hz, interval, dropout=0.01, bias=True):
        super().__init__()
        first_channels = 40
        n_classes = 4

        # Fail at construction time instead of leaving `self.classifier`
        # undefined and crashing later inside `forward`.
        try:
            pooled_width = self._POOLED_WIDTHS[(interval, hz)]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported configuration interval={interval!r}, hz={hz!r}; "
                f"supported (interval, hz) pairs: {sorted(self._POOLED_WIDTHS)}"
            ) from None

        # Temporal convolution
        self.tempconv = nn.Conv2d(
            in_channels=1,
            out_channels=first_channels,
            kernel_size=(1, 25),
            padding=0,
            bias=False,
        )
        # Spatial convolution: collapses the 128 input rows into one.
        self.spatconv = nn.Conv2d(
            in_channels=first_channels,
            out_channels=first_channels,
            kernel_size=(128, 1),
            padding=0,
            bias=False,
        )
        # Batch normalization. The second positional argument of BatchNorm2d
        # is `eps`, so the former `BatchNorm2d(40, False)` silently set
        # eps to 0 (division by zero for zero-variance channels). The default
        # eps is used instead; the affine parameters are kept so existing
        # state dicts still load.
        self.batchnorm = nn.BatchNorm2d(first_channels)
        # ELU
        self.elu = nn.ELU()
        # Dropout
        self.dropout = nn.Dropout(dropout)
        # Mean pooling along time; _POOLED_WIDTHS assumes kernel 75, stride 15.
        self.meanpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))

        # Classifier
        self.classifier = nn.Linear(
            first_channels * pooled_width, n_classes, bias=bias
        )

        # Log-softmax over the class dimension (pair with nn.NLLLoss).
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities of shape (batch, 4).

        `x` must have shape (batch, 1, 128, samples), where `samples` matches
        the `hz` / `interval` the model was built for.
        """
        res = self.tempconv(x)
        res = self.spatconv(res)
        res = self.batchnorm(res)
        res = self.elu(res)
        res = self.meanpool(res)
        res = self.dropout(res)
        res = torch.flatten(res, start_dim=1)
        res = self.classifier(res)
        res = self.softmax(res)
        return res

class ShallowCNNv1(nn.Module):
    """Shallow ConvNet variant with 60 feature maps and ELU activation.

    The network applies a temporal convolution (kernel 25 along time), a
    spatial convolution across 128 rows (electrodes), batch normalisation,
    ELU, mean pooling (kernel 75, stride 15 along time), dropout and a
    linear classifier followed by log-softmax over 4 classes.

    Args:
        hz: Sampling rate of the input; 254 and 1024 are supported.
        interval: Which part of a trial is fed in; "action" or "full".
        dropout: Dropout probability applied after mean pooling.
        bias: Whether the linear classifier has a bias term.

    Raises:
        ValueError: If the (interval, hz) combination is not supported.
    """

    # Length of the time axis after the temporal convolution and the mean
    # pooling, keyed by (interval, hz). The classifier input size is this
    # width times the number of feature maps (60 maps -> 2160, 9900, 4260
    # and 18060 features respectively).
    _POOLED_WIDTHS = {
        ("action", 254): 36,
        ("action", 1024): 165,
        ("full", 254): 71,
        ("full", 1024): 301,
    }

    def __init__(self, hz, interval, dropout=0.01, bias=True):
        super().__init__()
        first_channels = 60
        n_classes = 4

        # Fail at construction time instead of leaving `self.classifier`
        # undefined and crashing later inside `forward`.
        try:
            pooled_width = self._POOLED_WIDTHS[(interval, hz)]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported configuration interval={interval!r}, hz={hz!r}; "
                f"supported (interval, hz) pairs: {sorted(self._POOLED_WIDTHS)}"
            ) from None

        # Temporal convolution
        self.tempconv = nn.Conv2d(
            in_channels=1,
            out_channels=first_channels,
            kernel_size=(1, 25),
            padding=0,
            bias=False,
        )
        # Spatial convolution: collapses the 128 input rows into one.
        self.spatconv = nn.Conv2d(
            in_channels=first_channels,
            out_channels=first_channels,
            kernel_size=(128, 1),
            padding=0,
            bias=False,
        )
        # Batch normalization. The second positional argument of BatchNorm2d
        # is `eps`, so the former `BatchNorm2d(n, False)` silently set eps
        # to 0 (division by zero for zero-variance channels). The default eps
        # is used instead; the affine parameters are kept so existing state
        # dicts still load.
        self.batchnorm = nn.BatchNorm2d(first_channels)
        # ELU
        self.elu = nn.ELU()
        # Dropout
        self.dropout = nn.Dropout(dropout)
        # Mean pooling along time; _POOLED_WIDTHS assumes kernel 75, stride 15.
        self.meanpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))

        # Classifier. The input size is derived from the channel count: the
        # previously hard-coded sizes (1440, 6600, 5680, 12040) did not
        # correspond to 60 feature maps and made `forward` fail with a shape
        # mismatch.
        self.classifier = nn.Linear(
            first_channels * pooled_width, n_classes, bias=bias
        )

        # Log-softmax over the class dimension (pair with nn.NLLLoss).
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities of shape (batch, 4).

        `x` must have shape (batch, 1, 128, samples), where `samples` matches
        the `hz` / `interval` the model was built for.
        """
        res = self.tempconv(x)
        res = self.spatconv(res)
        res = self.batchnorm(res)
        res = self.elu(res)
        res = self.meanpool(res)
        res = self.dropout(res)
        res = torch.flatten(res, start_dim=1)
        res = self.classifier(res)
        res = self.softmax(res)
        return res


class ShallowCNNv2(nn.Module):
    """Shallow ConvNet variant with 40 feature maps and LeakyReLU activation.

    The network applies a temporal convolution (kernel 25 along time), a
    spatial convolution across 128 rows (electrodes), batch normalisation,
    LeakyReLU, mean pooling (kernel 75, stride 15 along time), dropout and a
    linear classifier followed by log-softmax over 4 classes.

    Args:
        hz: Sampling rate of the input; 254 and 1024 are supported.
        interval: Which part of a trial is fed in; "action" or "full".
        dropout: Dropout probability applied after mean pooling.
        bias: Whether the linear classifier has a bias term.

    Raises:
        ValueError: If the (interval, hz) combination is not supported.
    """

    # Length of the time axis after the temporal convolution and the mean
    # pooling, keyed by (interval, hz). Multiplied by the number of feature
    # maps this gives the classifier input size (40 maps -> 1440, 6600,
    # 2840 and 12040 features respectively).
    _POOLED_WIDTHS = {
        ("action", 254): 36,
        ("action", 1024): 165,
        ("full", 254): 71,
        ("full", 1024): 301,
    }

    def __init__(self, hz, interval, dropout=0.01, bias=True):
        super().__init__()
        first_channels = 40
        n_classes = 4

        # Fail at construction time instead of leaving `self.classifier`
        # undefined and crashing later inside `forward`.
        try:
            pooled_width = self._POOLED_WIDTHS[(interval, hz)]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported configuration interval={interval!r}, hz={hz!r}; "
                f"supported (interval, hz) pairs: {sorted(self._POOLED_WIDTHS)}"
            ) from None

        # Temporal convolution
        self.tempconv = nn.Conv2d(
            in_channels=1,
            out_channels=first_channels,
            kernel_size=(1, 25),
            padding=0,
            bias=False,
        )
        # Spatial convolution: collapses the 128 input rows into one.
        self.spatconv = nn.Conv2d(
            in_channels=first_channels,
            out_channels=first_channels,
            kernel_size=(128, 1),
            padding=0,
            bias=False,
        )
        # Batch normalization. The second positional argument of BatchNorm2d
        # is `eps`, so the former `BatchNorm2d(n, False)` silently set eps
        # to 0 (division by zero for zero-variance channels). The default eps
        # is used instead; the affine parameters are kept so existing state
        # dicts still load.
        self.batchnorm = nn.BatchNorm2d(first_channels)
        # LeakyReLU
        self.lrelu = nn.LeakyReLU()
        # Dropout
        self.dropout = nn.Dropout(dropout)
        # Mean pooling along time; _POOLED_WIDTHS assumes kernel 75, stride 15.
        self.meanpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))

        # Classifier
        self.classifier = nn.Linear(
            first_channels * pooled_width, n_classes, bias=bias
        )

        # Log-softmax over the class dimension (pair with nn.NLLLoss).
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities of shape (batch, 4).

        `x` must have shape (batch, 1, 128, samples), where `samples` matches
        the `hz` / `interval` the model was built for.
        """
        res = self.tempconv(x)
        res = self.spatconv(res)
        res = self.batchnorm(res)
        res = self.lrelu(res)
        res = self.meanpool(res)
        res = self.dropout(res)
        res = torch.flatten(res, start_dim=1)
        res = self.classifier(res)
        res = self.softmax(res)
        return res