# Model Inspection

Build a `TNN` model, then inspect its registered submodules and per-layer parameter counts.

In [1]:
from learning.models import *
from torchinfo import summary

## Define the model

Instantiate a `TNN` for inputs with 9 channels over a 120-step window, with `future_size=30`.

In [2]:
# Hyperparameters are grouped by role; keyword order does not affect the call.
model = TNN(
    # Input geometry. The printed model below pads the window by `future_size`
    # steps (ConstantPad1d(padding=(0, 30)) -> length 150).
    num_channels=9,
    window_size=120,
    future_size=30,
    # Layer size lists (their exact meaning is defined by TNN).
    hidden_sizes=[64, 128],
    channel_sizes=[64, 64, 64],
    state_sizes=[32, 32],
    attn_sizes=[128, 128, 128],
    head_sizes=[4, 4, 4],
    # Kernel / stride / dilation / pool settings.
    # NOTE(review): channel_sizes and dilation_sizes have 3 entries while
    # kernel_sizes, stride_sizes and pool_sizes have 5 -- confirm how TNN pairs them.
    kernel_sizes=[3, 5, 7, 5, 3],
    stride_sizes=[1, 1, 1, 1, 1],
    dilation_sizes=[1, 1, 1],
    pool_sizes=[1, 1, 1, 1, 1],
    # Regularization.
    # NOTE(review): the printed encoder shows Dropout(p=0.1) even though
    # dropout_rate=0 -- confirm TNN forwards this value.
    dropout_rate=0,
    batch_normalization=False,
)

## Inspect the architecture

List the model's registered submodules.

In [3]:
# As the cell's last expression, the model is shown via its repr, which lists
# every registered submodule -- no print() needed.
model

TNN(
  (zero_pad): ConstantPad1d(padding=(0, 30), value=0)
  (swap_in): Transposer(dim1=-1, dim2=-2)
  (linear): Linear(in_features=9, out_features=128, bias=True)
  (trans): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (linear1): Linear(in_features=128, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=128, bias=True)
    (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (swap_out): Transposer(dim1=-1, dim2=-2)
)


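The printed architecture is much smaller than the constructor call suggests. Only `zero_pad`, `swap_in`, `linear`, a single `TransformerEncoderLayer`, and `swap_out` are registered. Arguments such as `hidden_sizes`, `channel_sizes`, `kernel_sizes`, and `state_sizes` do not appear. The encoder also uses `Dropout(p=0.1)` and a 2048-wide feed-forward layer, even though `dropout_rate=0` was passed. In addition, the summary below labels the transpose layers `Transpose`, while the printed model calls them `Transposer`. The two outputs may therefore come from different versions of `TNN`. **TODO:** confirm whether `TNN` ignores these arguments, and re-run the notebook from a fresh kernel.

## Analyze the parameter counts

Use `torchinfo.summary` to report each layer's output shape and parameter count for a single input of shape `(1, 9, 120)`.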

In [4]:
# Per-layer output shapes and parameter counts for one input sample;
# depth=2 expands one level inside composite layers such as TransformerEncoderLayer.
summary(
    model,
    input_size=(1, 9, 120),  # (batch, num_channels, window_size) -- keep in sync with the TNN arguments
    depth=2,
    device="cpu",
)

Layer (type:depth-idx)                   Output Shape              Param #
TNN                                      --                        --
├─ConstantPad1d: 1-1                     [1, 9, 150]               --
├─Transpose: 1-2                         [1, 150, 9]               --
├─Linear: 1-3                            [1, 150, 128]             1,280
├─TransformerEncoderLayer: 1-4           [1, 150, 128]             --
│    └─MultiheadAttention: 2-1           [1, 150, 128]             66,048
│    └─Dropout: 2-2                      [1, 150, 128]             --
│    └─LayerNorm: 2-3                    [1, 150, 128]             256
│    └─Linear: 2-4                       [1, 150, 2048]            264,192
│    └─Dropout: 2-5                      [1, 150, 2048]            --
│    └─Linear: 2-6                       [1, 150, 128]             262,272
│    └─Dropout: 2-7                      [1, 150, 128]             --
│    └─LayerNorm: 2-8                    [1, 150, 128]             