In [2]:
import torch
import torch.nn as nn
from torchsummary import summary
from lstm_dpcnn import * 

In [3]:
class LSTM_DPCNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.lstm = nn.LSTM(input_size=256, hidden_size=128, num_layers=2, batch_first=True, bidirectional=True) # (b, seq_len, input_size) => (b, seq_len, hidden_size*2)
        self.dpcnn = DPCNN(1)
        self.linear = nn.Linear(512, 9)
        
    def forward(self, x):
        """
        each forward takes a 256 vector as input 
        """
        lstm_output, (_, _) = self.lstm(x) # (b, seq_len, hidden_size*2)
        lstm_output = torch.transpose(lstm_output, 2, 1) 
        lstm_output = lstm_output.view(lstm_output.shape[0], -1) # padding: (b, seq_len*hidden_size*2)
        dpcnn_output = self.dpcnn(x) # (b, 256)
        output = torch.cat((lstm_output, dpcnn_output), 1) #(b, seqlen*hidden_size*2+256) 
        output = self.linear(output)
        return output

In [4]:
model = LSTM_DPCNN()

In [5]:
example = torch.randn((3, 1, 256))

In [8]:
summary(model, example, device="cpu", depth=5)

Layer (type:depth-idx)                   Output Shape              Param #
├─LSTM: 1-1                              [-1, 1, 256]              790,528
├─DPCNN: 1-2                             [-1, 256]                 --
|    └─BasicBlock: 2-1                   [-1, 256, 256]            --
|    |    └─ReLU: 3-1                    [-1, 1, 256]              --
|    |    └─Conv1d: 3-2                  [-1, 256, 256]            1,024
|    |    └─ReLU: 3-3                    [-1, 256, 256]            --
|    |    └─Conv1d: 3-4                  [-1, 256, 256]            196,864
|    |    └─Conv1d: 3-5                  [-1, 256, 256]            512
|    └─Sequential: 2-2                   [-1, 256, 2]              --
|    |    └─Sequential: 3-6              [-1, 256, 128]            --
|    |    |    └─MaxPool1d: 4-1          [-1, 256, 128]            --
|    |    |    └─BasicBlock: 4-2         [-1, 256, 128]            --
|    |    |    |    └─ReLU: 5-1          [-1, 256, 128]            --
|

Layer (type:depth-idx)                   Output Shape              Param #
├─LSTM: 1-1                              [-1, 1, 256]              790,528
├─DPCNN: 1-2                             [-1, 256]                 --
|    └─BasicBlock: 2-1                   [-1, 256, 256]            --
|    |    └─ReLU: 3-1                    [-1, 1, 256]              --
|    |    └─Conv1d: 3-2                  [-1, 256, 256]            1,024
|    |    └─ReLU: 3-3                    [-1, 256, 256]            --
|    |    └─Conv1d: 3-4                  [-1, 256, 256]            196,864
|    |    └─Conv1d: 3-5                  [-1, 256, 256]            512
|    └─Sequential: 2-2                   [-1, 256, 2]              --
|    |    └─Sequential: 3-6              [-1, 256, 128]            --
|    |    |    └─MaxPool1d: 4-1          [-1, 256, 128]            --
|    |    |    └─BasicBlock: 4-2         [-1, 256, 128]            --
|    |    |    |    └─ReLU: 5-1          [-1, 256, 128]            --
|