https://www.mdpi.com/1424-8220/16/1/115/html#B49-sensors-16-00115

In [None]:
import torch
from torch import nn

class Net(nn.Module):
    """Convolutional + recurrent classifier over windows of sensor data.

    Four ``Conv2d`` layers with kernels of shape ``(FILTER_SIZE, 1)`` convolve
    along input dimension 2 only. The result is fed, one step per position of
    that dimension, through two stacked LSTMs. A linear layer is then applied
    to the final step.

    Args:
        BATCH_SIZE: Nominal batch size. Stored as ``self.batch_size`` for
            callers; ``forward`` takes the batch size from its input, so
            smaller (e.g. final, partial) batches are also accepted.
        SLIDING_WINDOW_LENGTH: Length of input dimension 2. Not used by the
            layers themselves.
        NB_SENSOR_CHANNELS: Size of input dimension 3. The convolutions keep
            this dimension, so the first LSTM receives
            ``NUM_FILTERS * NB_SENSOR_CHANNELS`` features per step.
        NUM_FILTERS: Output channels of every convolution.
        FILTER_SIZE: Kernel length along dimension 2. Convolutions use the
            default zero padding, so each one shortens that dimension by
            ``FILTER_SIZE - 1``.
        NUM_UNITS_LSTM: Hidden size of both LSTM layers.
        NUM_CLASSES: Output width of the final linear layer.
        FINAL_SEQUENCE_LENGTH: Length of dimension 2 after the four
            convolutions (``SLIDING_WINDOW_LENGTH - 4 * (FILTER_SIZE - 1)``).
            Stored as ``self.final_sequence_length``; ``forward`` uses the
            actual length of its tensor instead.
    """

    def __init__(self, BATCH_SIZE, SLIDING_WINDOW_LENGTH, NB_SENSOR_CHANNELS, NUM_FILTERS, FILTER_SIZE, NUM_UNITS_LSTM, NUM_CLASSES, FINAL_SEQUENCE_LENGTH):
        super().__init__()

        self.conv1 = nn.Conv2d(1, NUM_FILTERS, (FILTER_SIZE, 1))
        self.conv2 = nn.Conv2d(NUM_FILTERS, NUM_FILTERS, (FILTER_SIZE, 1))
        self.conv3 = nn.Conv2d(NUM_FILTERS, NUM_FILTERS, (FILTER_SIZE, 1))
        self.conv4 = nn.Conv2d(NUM_FILTERS, NUM_FILTERS, (FILTER_SIZE, 1))

        # Each step's feature vector is every filter response for every
        # sensor channel, flattened together (see forward).
        self.lstm1 = nn.LSTM(NUM_FILTERS * NB_SENSOR_CHANNELS, NUM_UNITS_LSTM, batch_first=True)
        self.lstm2 = nn.LSTM(NUM_UNITS_LSTM, NUM_UNITS_LSTM, batch_first=True)

        self.fc1 = nn.Linear(NUM_UNITS_LSTM, NUM_CLASSES)

        # Kept as attributes for external callers; forward does not rely on
        # batch_size / final_sequence_length.
        self.batch_size = BATCH_SIZE
        self.final_sequence_length = FINAL_SEQUENCE_LENGTH
        self.num_classes = NUM_CLASSES

    def forward(self, x):
        """Return class scores computed from the last step of the sequence.

        Args:
            x: Tensor of shape ``(batch, 1, window, NB_SENSOR_CHANNELS)``.

        Returns:
            Tensor of shape ``(batch, NUM_CLASSES)``. No softmax is applied.
        """
        # NOTE(review): no nonlinearity is applied between the convolutions,
        # so the four layers compose into a single affine map. The
        # DeepConvLSTM paper referenced at the top of this notebook applies
        # ReLU after each convolution -- confirm whether omitting it is
        # intentional.
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        # (batch, filters, steps, channels) -> (batch, steps, filters, channels)
        x = x.permute(0, 2, 1, 3).contiguous()

        # Merge filters and sensor channels into one feature vector per step.
        # Batch and step counts come from the tensor itself, so batches smaller
        # than BATCH_SIZE are handled.
        batch, steps = x.size(0), x.size(1)
        x = x.view(batch, steps, -1)

        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)

        # Only the last step's scores are returned, so classify just that step.
        return self.fc1(x[:, -1, :])
