# Model Inspection

In [1]:
from learning.models import *
from torchinfo import summary

Define a model

In [2]:
# Constructor arguments for the CRNN under inspection, gathered in one mapping
# and unpacked into the constructor. Insertion order matches the keyword order
# the call would otherwise list.
# NOTE(review): kernel_sizes, stride_sizes and pool_sizes each carry five
# entries while channel_sizes carries three, and the printed model in this
# notebook shows three conv layers (kernels 3, 5, 7) and no attention modules.
# Presumably the surplus entries and attn_sizes/head_sizes are unused for this
# configuration -- confirm against the CRNN constructor.
crnn_kwargs = {
    "num_channels": 9,       # shows up as the in_channels of the first Conv1d
    "window_size": 120,      # equals the input length passed to summary()
    "future_size": 20,       # shows up as the right-hand zero padding
    "hidden_sizes": None,
    "channel_sizes": [128, 128, 128],
    "kernel_sizes": [3, 5, 7, 5, 3],
    "stride_sizes": [1, 1, 1, 1, 1],
    "dilation_sizes": [1, 1, 1],
    "pool_sizes": [1, 1, 1, 1, 1],
    "state_sizes": [32, 32],  # shows up as the hidden size of each LSTM
    "attn_sizes": [128, 128],
    "head_sizes": [4, 4, 4],
    "dropout_rate": 0,
    "batch_normalization": False,
}
model = CRNN(**crnn_kwargs)

Print the model's components

In [3]:
# Write the model's textual representation (its nested module tree) to stdout.
print(model)

CRNN(
  (conv_stack): ConvStack(
    (conv0): Conv1d(9, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (relu0): ReLU()
    (conv1): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
    (relu1): ReLU()
    (conv2): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
    (relu2): ReLU()
  )
  (zero_pad): ConstantPad1d(padding=(0, 20), value=0)
  (swap_in): Transposer(dim1=-1, dim2=-2)
  (lstm_stack): RecurrentStack(
    (lstm0): LSTM(128, 32, batch_first=True, bidirectional=True)
    (proj0): Projector(dim=0)
    (lstm1): LSTM(64, 32, batch_first=True, bidirectional=True)
    (proj1): Projector(dim=0)
    (linear): Linear(in_features=64, out_features=5, bias=True)
  )
  (swap_out): Transposer(dim1=-1, dim2=-2)
)


Summarize each layer's output shape and parameter count

In [4]:
# Layer-by-layer torchinfo report for an input of size (1, 9, 120), computed on
# the CPU and expanded two nesting levels deep. The call is the last expression
# in the cell, so its result is what the notebook displays.
summary(
    model,
    input_size=(1, 9, 120),
    device="cpu",
    depth=2,
)

Layer (type:depth-idx)                   Output Shape              Param #
CRNN                                     --                        --
├─ConvStack: 1-1                         [1, 128, 120]             --
│    └─Conv1d: 2-1                       [1, 128, 120]             3,584
│    └─ReLU: 2-2                         [1, 128, 120]             --
│    └─Conv1d: 2-3                       [1, 128, 120]             82,048
│    └─ReLU: 2-4                         [1, 128, 120]             --
│    └─Conv1d: 2-5                       [1, 128, 120]             114,816
│    └─ReLU: 2-6                         [1, 128, 120]             --
├─ConstantPad1d: 1-2                     [1, 128, 140]             --
├─Transpose: 1-3                         [1, 140, 128]             --
├─RecurrentStack: 1-4                    [1, 140, 5]               --
│    └─LSTM: 2-7                         [1, 140, 64]              41,472
│    └─Projector: 2-8                    [1, 140, 64]              --