In [1]:
import torch
from models.res_autoencoder import Encoder, Decoder

def count_params(model):
    """Return the number of trainable scalar parameters in ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is set; parameters with the flag
    cleared are skipped. Returns ``0`` when nothing qualifies.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Build the encoder with its default constructor arguments, then print its
# trainable-parameter count followed by the module's repr (one per line).
enc = Encoder()
print('Encoder with param count {}:'.format(count_params(enc)), enc, sep='\n')

# Push a single random tensor of shape (1, 1, 28, 28) through the encoder and
# report how the shape changes. `eg_output` is reused by the decoder cell.
eg_input = torch.randn((1, 1, 28, 28))
eg_output = enc(eg_input)
print('Enc: {} -> {}'.format(eg_input.shape, eg_output.shape))

# Blank lines to separate this report from whatever is printed next.
print('\n\n')

Encoder with param count 525000:
Encoder(
  (input_transform): Conv2d(1, 4, kernel_size=(1, 1), stride=(1, 1))
  (fe_modules): ModuleDict(
    (block_0): ModuleDict(
      (enc): EncoderBlock(
        (model): Sequential(
          (0): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): LayerNorm((16, 28, 28), eps=1e-05, elementwise_affine=True)
          (2): LeakyReLU(negative_slope=0.2)
          (3): Conv2d(16, 16, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (4): LayerNorm((16, 14, 14), eps=1e-05, elementwise_affine=True)
        )
      )
      (sc): ShortcutBlock(
        (model): Sequential(
          (0): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (1): LayerNorm((16, 14, 14), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (block_1): ModuleDict(
      (enc): EncoderBlock(
        (model): Sequential(
          (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), b

In [2]:
# Build the decoder with its default constructor arguments, then print its
# trainable-parameter count followed by the module's repr (one per line).
dec = Decoder()
print('Decoder with param count {}:'.format(count_params(dec)), dec, sep='\n')

# Feed the encoder's example output (`eg_output`, produced in the encoder
# cell) through the decoder and report how the shape changes.
eg_dec_output = dec(eg_output)
print('Dec: {} -> {}'.format(eg_output.shape, eg_dec_output.shape))

Decoder with param count 610065:
Decoder(
  (mixer): Sequential(
    (0): LayerNorm((128, 1, 1), eps=1e-05, elementwise_affine=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): ConvTranspose2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fc_modules): ModuleDict(
    (expand): Sequential(
      (0): LayerNorm((128, 1, 1), eps=1e-05, elementwise_affine=True)
      (1): ConvTranspose2d(128, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
    )
    (block_0): ModuleDict(
      (dec): DecoderBlock(
        (model): Sequential(
          (0): LayerNorm((64, 7, 7), eps=1e-05, elementwise_affine=True)
          (1): ConvTranspose2d(64, 64, kernel_size=(2, 2), stride=(2, 2), bias=False)
          (2): LayerNorm((64, 14, 14), eps=1e-05, elementwise_affine=True)
          (3): LeakyReLU(negative_slope=0.2)
          (4): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (sc): ShortcutBlock(
        (model): Seq