In [23]:
from torch import nn
import torch

# a = torch.Tensor([1, 2, 3,5,6])
# b = torch.Tensor([1, 2, 3])

# print(a + b)


class RNNModel(nn.Module):
    """Stacked LSTM followed by a small MLP head applied to the last time step.

    The LSTM consumes a batch-first sequence; the hidden output of the final
    time step is fed through ReLU -> Linear -> Dropout -> ReLU -> Linear to
    produce ``num_classes`` values per sequence.

    Args:
        input_size: Size of each per-timestep feature vector (last dim of ``x``).
        hidden_size: Number of LSTM hidden units; also the MLP head's input width.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of units in the final linear layer.
        head_hidden_size: Width of the intermediate linear layer in the head.
        dropout: Dropout probability used inside the head.

    All defaults equal the values that were previously hard-coded, so
    ``RNNModel()`` builds the same architecture and the same state_dict keys.
    """

    def __init__(
        self,
        input_size=768,
        hidden_size=256,
        num_layers=2,
        num_classes=1,
        head_hidden_size=128,
        dropout=0.2,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes

        self.lstm = nn.LSTM(
            self.input_size, self.hidden_size, self.num_layers, batch_first=True
        )

        # Module order (indices 0-4) is kept unchanged so existing checkpoints still load.
        self.linear_layers = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.hidden_size, head_hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(head_hidden_size, self.num_classes),
        )

    def forward(self, x):
        """Run the LSTM over ``x`` and decode the last time step.

        Args:
            x: Tensor of shape (batch, seq_len, input_size) (the LSTM is batch_first).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # Zero initial hidden/cell states, created on x's device and dtype rather
        # than a hard-coded "cuda", so the model runs on CPU or any GPU.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        # out: (batch, seq_len, hidden_size)
        out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden output of the last time step only.
        return self.linear_layers(out[:, -1, :])


# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# can still be run end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rnn_model = RNNModel()
rnn_model.to(device)

# x = torch.randn(1, 1, 768)
# y = rnn_model(x)
# from torchviz import make_dot
# make_dot(y.mean(), params=dict(rnn_model.named_parameters()))

RNNModel(
  (lstm): LSTM(768, 256, num_layers=2, batch_first=True)
  (linear_layers): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [26]:
from pytorch_model_summary import summary
# Dummy (batch=1, seq_len=1, features) input built on the model's own device,
# rather than a hard-coded 'cuda', with the feature size taken from the model.
x = torch.randn(1, 1, rnn_model.input_size, device=next(rnn_model.parameters()).device)
print(summary(rnn_model, x, show_input=True))

---------------------------------------------------------------------------------------------
      Layer (type)                               Input Shape         Param #     Tr. Param #
            LSTM-1     [1, 1, 768], [2, 1, 256], [2, 1, 256]       1,576,960       1,576,960
            ReLU-2                                  [1, 256]               0               0
          Linear-3                                  [1, 256]          32,896          32,896
         Dropout-4                                  [1, 128]               0               0
            ReLU-5                                  [1, 128]               0               0
          Linear-6                                  [1, 128]             129             129
Total params: 1,609,985
Trainable params: 1,609,985
Non-trainable params: 0
---------------------------------------------------------------------------------------------


In [19]:
from pytorch_model_summary import summary
# Dummy (batch=1, seq_len=1, features) input built on the model's own device,
# rather than a hard-coded 'cuda', with the feature size taken from the model.
x = torch.randn(1, 1, rnn_model.input_size, device=next(rnn_model.parameters()).device)
print(summary(rnn_model, x, show_hierarchical=True))

---------------------------------------------------------------------------------------------
      Layer (type)                              Output Shape         Param #     Tr. Param #
            LSTM-1     [1, 1, 256], [2, 1, 256], [2, 1, 256]       1,576,960       1,576,960
            ReLU-2                                  [1, 256]               0               0
          Linear-3                                  [1, 128]          32,896          32,896
         Dropout-4                                  [1, 128]               0               0
            ReLU-5                                  [1, 128]               0               0
          Linear-6                                    [1, 1]             129             129
Total params: 1,609,985
Trainable params: 1,609,985
Non-trainable params: 0
---------------------------------------------------------------------------------------------



RNNModel(
  (lstm): LSTM(768, 256, num_layers=2, batch_first=True), 1,576,960 para