In [1]:
import sys
sys.path.insert(0, '../src')

import numpy as np

def model_summary(model):
    """Print a per-module parameter-count table for ``model`` and return the total.

    The table lists the number of parameter elements in ``model.encoder``,
    ``model.recurrent`` and ``model.decoder`` plus their sum, followed by the
    model's own printed representation.

    Args:
        model: Object exposing ``encoder``, ``recurrent`` and ``decoder``
            attributes. Each must provide a ``parameters()`` iterable whose
            items implement ``numel()`` (as ``torch.nn.Module`` sub-modules do).

    Returns:
        int: Total number of parameter elements across the three sub-modules.
    """
    def count_params(module):
        # Builtin sum over numel() keeps the count an exact Python int, even
        # for a module without parameters (np.sum([]) would yield float 0.0).
        return sum(p.numel() for p in module.parameters())

    encoder_params = count_params(model.encoder)
    recurrent_params = count_params(model.recurrent)
    decoder_params = count_params(model.decoder)

    total_params = encoder_params + decoder_params + recurrent_params

    # Print table. The 'd' presentation type prints every digit; 'g' keeps only
    # 6 significant digits and switches to scientific notation from 1e6 up.
    table  =  '|  Module   |  Num Params  |\n'
    table +=  '|-----------|--------------|\n'
    table += f'| Encoder   |{encoder_params:12d}  |\n'
    table += f'| Recurrent |{recurrent_params:12d}  |\n'
    table += f'| Decoder   |{decoder_params:12d}  |\n'
    table +=  '|-----------|--------------|\n'
    table += f'| Total     |{total_params:12d}  |\n'

    print(table)
    print(model)

    return total_params

In [2]:
from res_ae_convnext import ConvNextFramePredictor

# Build the ConvNext predictor with its default constructor arguments, print
# its parameter table and module tree, and keep the total parameter count.
convnext_params = model_summary(ConvNextFramePredictor())

|  Module   |  Num Params  |
|-----------|--------------|
| Encoder   |       60896  |
| Recurrent |       47136  |
| Decoder   |       60866  |
|-----------|--------------|
| Total     |      168898  |

ConvNextFramePredictor(
  (encoder): ConvNextEncoder(
    (downsample_blocks): ModuleList(
      (0): Sequential(
        (0): LayerNorm2d((32,), eps=1e-06, elementwise_affine=True)
        (1): Conv2d(32, 64, kernel_size=(2, 2), stride=(2, 2))
      )
      (1): Sequential(
        (0): LayerNorm2d((64,), eps=1e-06, elementwise_affine=True)
        (1): Conv2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0): ConvNextBlock(
          (depth_conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=same, groups=32)
          (norm): LayerNorm2d((32,), eps=1e-06, elementwise_affine=True)
          (conv1): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
          (conv2): Conv2d(128, 32, kernel_size=(1, 1), s

In [3]:
from res_ae import ResidualFramePredictor


# Build the residual auto-encoder predictor with its default constructor
# arguments, print its parameter table and module tree, and keep the total.
res_ae_params = model_summary(ResidualFramePredictor())

|  Module   |  Num Params  |
|-----------|--------------|
| Encoder   |         542  |
| Recurrent | 1.19578e+07  |
| Decoder   |         530  |
|-----------|--------------|
| Total     | 1.19588e+07  |

ResidualFramePredictor(
  (encoder): Encoder(
    (cells): ModuleList(
      (0): CnnCell(
        (conv): Conv2d(2, 4, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (bn): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): CnnCell(
        (conv): Conv2d(4, 6, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (bn): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (decoder): Decoder(
    (cells): ModuleList(
      (0): DeCnnCell(
        (deconv): ConvTranspose2d(6, 4, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (bn): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): DeCnnCell(
        (deconv): ConvTranspos

In [4]:
# Ratio of total parameter counts (ResAE / ConvNext); the "x" suffix marks the
# number as a multiplier and the message names the model it is compared against.
print(f'ResAE model has {res_ae_params / convnext_params:.1f}x as many parameters as the ConvNext model')

ResAE model has 70.8 as many parameters
