# Notebook containing the classifier model classes
The four main classifier classes are defined here.

In [None]:
import torch 
import torch.nn as nn

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class FCNet(nn.Module):
  """
  Fully connected classifier for inputs of shape (batch, channels, seq_length).

  The input is flattened to (batch, input_channels * seq_length) and passed
  through three hidden blocks of Linear -> ReLU -> BatchNorm1d, followed by a
  final Linear projection to `classes` outputs and an element-wise Sigmoid.

  Args:
    input_channels: number of input channels (kept as `self.input_dim`).
    seq_length: number of time steps per channel.
    classes: number of output units.
    hidden_dims: widths of the three hidden layers; must have length 3.

  NOTE(review): the Sigmoid gives independent per-class scores, unlike the
  Softmax used by the convolutional classifiers in this notebook; confirm this
  matches the loss used for training.
  """

  def __init__(self, input_channels, seq_length, classes, hidden_dims):
    super().__init__()
    self.input_dim = input_channels
    assert len(hidden_dims) == 3
    # Build the hidden blocks in order so the Sequential indices (and hence
    # the state_dict keys) are main.0 ... main.10.
    blocks = []
    fan_in = input_channels * seq_length
    for fan_out in hidden_dims:
      blocks.extend([
          nn.Linear(fan_in, fan_out),
          nn.ReLU(inplace=True),
          nn.BatchNorm1d(fan_out),
      ])
      fan_in = fan_out
    blocks.extend([nn.Linear(fan_in, classes), nn.Sigmoid()])
    self.main = nn.Sequential(*blocks)

  def forward(self, x):
    # Collapse every non-batch dimension into one feature vector per sample.
    return self.main(x.reshape(x.shape[0], -1))

In [None]:
# Random batch of 240 samples with 22 channels x 100 time steps, plus an
# FCNet sized for that input.
test_in = torch.rand(240, 22, 100).to(device)
fc_test = FCNet(input_channels=22, seq_length=100, classes=4,
                hidden_dims=[1500, 1000, 500]).to(device)

In [None]:
class CNNValidator(nn.Module):
  """
  Convolutional classifier for inputs of shape (batch, input_channels, seq_length).

  Architecture:
    two stages of Conv1d(k=3, s=2, p=2) -> ReLU -> MaxPool1d(3, stride=1),
    one Conv1d(k=3, s=2, p=2) -> ReLU,
    then Flatten -> Linear -> Softmax(dim=1).

  Args:
    input_channels: number of channels in the input signal.
    classes: number of output classes.
    hidden_dims: output channels of the three conv layers; must have length 3.
    seq_length: length of the input sequence (default 100). Used to size the
      final Linear layer; inputs of a different length will not fit it.

  forward returns a (batch, classes) tensor whose rows are softmax outputs.
  NOTE(review): these are probabilities, not logits; nn.CrossEntropyLoss
  expects logits, so confirm the training loss.
  """
  def __init__(self, input_channels, classes, hidden_dims, seq_length=100):
    super(CNNValidator, self).__init__()
    assert len(hidden_dims) == 3
    self.main = nn.Sequential(
        # Previously hard-coded to 22 input channels, ignoring input_channels.
        nn.Conv1d(input_channels, hidden_dims[0], 3, stride=2, padding=2),
        nn.ReLU(True),
        nn.MaxPool1d(3, 1),
        nn.Conv1d(hidden_dims[0], hidden_dims[1], 3, stride=2, padding=2),
        nn.ReLU(True),
        nn.MaxPool1d(3, 1),
        nn.Conv1d(hidden_dims[1], hidden_dims[2], 3, stride=2, padding=2),
        nn.ReLU(True)
    )
    # Infer the flattened feature count from a dry run instead of hard-coding
    # it (hidden_dims[2] * 13 for seq_length=100). The dry run draws no random
    # numbers, so parameter initialisation matches the original order.
    with torch.no_grad():
      n_features = self.main(torch.zeros(1, input_channels, seq_length)).numel()
    self.linear = nn.Sequential(
        nn.Flatten(),
        nn.Linear(n_features, classes),
        nn.Softmax(dim=1)
        )

  def forward(self, x):
    features = self.main(x)
    return self.linear(features)


In [None]:
# Random (240, 22, 100) batch and a CNNValidator with conv widths 100/20/4.
test_in = torch.rand(240, 22, 100).to(device)
cnn_test = CNNValidator(input_channels=22, classes=4,
                        hidden_dims=[100, 20, 4]).to(device)

In [None]:
# Sanity check: output should be (batch, classes).
cnn_test(test_in).size()

torch.Size([240, 4])

In [None]:
class ConvLSTMValidator(nn.Module):
  """
  Hybrid classifier: a small Conv1d feature extractor followed by an LSTM.

  Pipeline for an input of shape (batch, input_channels, seq_length):
    Conv1d(k=3, s=2, p=2) -> ReLU -> MaxPool1d(3, stride=1)
    -> Conv1d(k=3, s=2, p=2) -> ReLU              => (batch, hidden_dims[1], T)
    permute -> (batch, T, hidden_dims[1]) -> LSTM  => (batch, T, hidden_dims[2])
    Flatten -> Linear -> Softmax(dim=1)            => (batch, classes)

  Args:
    input_channels: number of channels in the input signal.
    classes: number of output classes.
    hidden_dims: [conv1 channels, conv2 channels, LSTM hidden size]; must
      have length 3.
    seq_length: length of the input sequence (default 100, giving T = 26).
      Used to size the final Linear layer; inputs of a different length will
      not fit it.
  """
  def __init__(self, input_channels, classes, hidden_dims, seq_length=100):
    super(ConvLSTMValidator, self).__init__()
    assert len(hidden_dims) == 3
    self.main = nn.Sequential(
        nn.Conv1d(input_channels, hidden_dims[0], 3, stride=2, padding=2),
        nn.ReLU(True),
        nn.MaxPool1d(3, 1),
        nn.Conv1d(hidden_dims[0], hidden_dims[1], 3, stride=2, padding=2),
        nn.ReLU(True)
        )
    self.lstm = nn.LSTM(hidden_dims[1], hidden_dims[2], batch_first=True)
    # Number of time steps produced by the conv stack (previously hard-coded
    # as 26, which only holds for seq_length=100). The dry run draws no random
    # numbers, so parameter initialisation matches the original order.
    with torch.no_grad():
      time_steps = self.main(torch.zeros(1, input_channels, seq_length)).shape[-1]
    self.out = nn.Sequential(
        nn.Flatten(),
        nn.Linear(time_steps * hidden_dims[2], classes),
        nn.Softmax(dim=1)
    )

  def forward(self, x):
    # (batch, channels, T) -> (batch, T, channels) for the batch_first LSTM.
    features = self.main(x).permute(0, 2, 1)
    out_lstm, _ = self.lstm(features)
    return self.out(out_lstm)

In [None]:
# Random (240, 22, 100) batch and a ConvLSTMValidator
# (conv widths 100/20, LSTM hidden size 10).
test_in = torch.rand(240, 22, 100).to(device)
cnn_test = ConvLSTMValidator(input_channels=22, classes=4,
                             hidden_dims=[100, 20, 10]).to(device)

In [None]:
# Forward pass of the ConvLSTMValidator on the random test batch.
z = cnn_test(test_in)

In [None]:
# Sanity check: output should be (batch, classes).
z.size()

torch.Size([240, 4])