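# Model Inspection

Build a CRNN, print its submodules, and use `torchinfo` to check each layer's output shape and parameter count.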

In [6]:
from models import *
from torchinfo import summary

Define the CRNN to inspect: 9 input channels, a 120-step input window, and `future_size=30`.

In [7]:
model = CRNN(
    # input geometry
    num_channels=9,
    window_size=120,
    future_size=30,
    # NOTE(review): no 64-unit layer shows up in the printed module;
    # confirm CRNN actually consumes hidden_sizes.
    hidden_sizes=[64, 64, 64],
    # convolutional stages: one list entry per Conv1d -> ReLU -> MaxPool1d stage
    channel_sizes=[16, 32],
    kernel_sizes=[7, 5],
    stride_sizes=[2, 2],
    pool_sizes=[2, 2],
    dropout_rate=0,
    # recurrent back end
    rnn_layers=2,
    rnn_state_size=5,
)

Print the model's submodules and their configuration.

In [8]:
# Bare last-line expression: Jupyter displays the module's repr as the cell output.
model

CRNN(
  (conv_stack): Sequential(
    (0): Conv1d(9, 16, kernel_size=(7,), stride=(2,), padding=(3,))
    (1): ReLU()
    (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv1d(16, 32, kernel_size=(5,), stride=(2,), padding=(2,))
    (4): ReLU()
    (5): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (zero_pad): ConstantPad1d(padding=(0, 30), value=0)
  (swap_last1): SwapLast()
  (lstm): LSTM(32, 5, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=10, out_features=5, bias=True)
  (swap_last2): SwapLast()
)


Summarize each layer's output shape and parameter count with `torchinfo.summary`.

In [9]:
# torchinfo inserts a batch axis of size 1 at `batch_dim`, which is why every
# output shape in the table starts with 1. Keep `input_size` in sync with the
# num_channels / window_size the model was built with.
# NOTE(review): per the table, the zero padding appends 30 steps *after* the
# conv stack has shrunk the 120-step window to 7, so the LSTM sees 37 steps;
# confirm that is intended.
summary(
    model,
    input_size=(9, 120),  # (channels, time steps) for a single sample
    batch_dim=0,
)

Layer (type:depth-idx)                   Output Shape              Param #
CRNN                                     --                        --
├─Sequential: 1-1                        [1, 32, 7]                --
│    └─Conv1d: 2-1                       [1, 16, 60]               1,024
│    └─ReLU: 2-2                         [1, 16, 60]               --
│    └─MaxPool1d: 2-3                    [1, 16, 30]               --
│    └─Conv1d: 2-4                       [1, 32, 15]               2,592
│    └─ReLU: 2-5                         [1, 32, 15]               --
│    └─MaxPool1d: 2-6                    [1, 32, 7]                --
├─ConstantPad1d: 1-2                     [1, 32, 37]               --
├─SwapLast: 1-3                          [1, 37, 32]               --
├─LSTM: 1-4                              [1, 37, 10]               2,240
├─Linear: 1-5                            [1, 37, 5]                55
├─SwapLast: 1-6                          [1, 5, 37]                --
Total 