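# Model Inspection

Build a `TNN`, print its module structure, summarize per-layer output shapes and parameter counts, and run a forward-pass shape check.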

In [5]:
from models import *
# Explicit imports stay after the star import so they are not shadowed by names re-exported from `models`.
import torch
from torchinfo import summary

## Define the model

Instantiate a `TNN` with 9 input channels, a 120-step input window, and `future_size=30`.

In [6]:
# One keyword argument per line. Each per-layer list below has five entries.
# The printed model further down pads the sequence by 30 steps and the forward pass
# yields length 150 (= 120 + 30), presumably driven by `future_size` — verify in `models.TNN`.
model = TNN(
    num_channels=9,
    window_size=120,
    future_size=30,
    # per-layer settings (five entries each)
    hidden_sizes=[32, 64, 128, 64, 32],
    channel_sizes=[128, 128, 128, 64, 32],
    kernel_sizes=[3, 5, 7, 5, 3],
    stride_sizes=[1] * 5,
    pool_sizes=[1] * 5,
    dropout_rate=0,
    # RNN settings
    rnn_layers=4,
    rnn_state_size=32,
    # Transformer settings
    transformer_encoders=1,
    transformer_decoders=1,
    transformer_dim=256,
    attention_heads=16,
)

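## Inspect the model's components

The module tree below lists every submodule and its configuration.

**TODO:** the transformer layers report `Dropout(p=0.1)` even though the model was built with `dropout_rate=0`. Check whether `TNN` forwards `dropout_rate` to its `Transformer`.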

In [7]:
# Bare last expression: Jupyter renders the module's repr (the full submodule tree) as the cell output.
model

TNN(
  (zero_pad): ConstantPad1d(padding=(0, 0, 0, 30), value=0)
  (swap_last1): SwapLast()
  (embed): Linear(in_features=9, out_features=256, bias=True)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=256, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): Tr

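## Analyze parameter counts

`torchinfo.summary` runs a forward pass on a dummy input of shape `(9, 120)`, with a batch axis inserted at position 0 (so `(1, 9, 120)`). It reports each layer's output shape and parameter count.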

In [8]:
# batch_dim=0 inserts a batch axis of size 1 in front, so the traced input is (1, 9, 120).
# NOTE(review): no `device=` is passed, so torchinfo picks a device itself and (per its docs)
# moves `model` onto it — CUDA when available. Confirm for the installed torchinfo version.
summary(
    model,
    batch_dim=0,
    input_size=(9, 120),
)

Layer (type:depth-idx)                             Output Shape              Param #
TNN                                                --                        --
├─Transformer: 1                                   --                        --
│    └─TransformerEncoder: 2                       --                        --
│    │    └─ModuleList: 3-1                        --                        395,776
│    └─TransformerDecoder: 2                       --                        --
│    │    └─ModuleList: 3-2                        --                        659,456
├─SwapLast: 1-1                                    [1, 120, 9]               --
├─Linear: 1-2                                      [1, 120, 256]             2,560
├─ConstantPad1d: 1-3                               [1, 150, 256]             --
├─Transformer: 1-4                                 [1, 150, 256]             --
│    └─TransformerEncoder: 2-1                     [1, 120, 256]             --
│    │    └─LayerNorm:

In [9]:
# Smoke-test one forward pass on a random (batch=1, 9, 120) window and show the output shape.
# Choose the device explicitly and move the model there. No visible cell moves `model` to
# CUDA, so sending only the input to "cuda" relied on hidden state (presumably the previous
# `summary(...)` call relocating the model) and failed on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
sample_input = torch.rand((1, 9, 120), device=device)
with torch.no_grad():  # shape check only: no autograd graph needed
    sample_output = model(sample_input)
sample_output.shape

torch.Size([1, 5, 150])
