In [3]:
# layer size test

import torch
from torchinfo import summary
import torch.nn as nn
import torchvision

In [26]:
class A64_8(nn.Module):
    """Convolutional autoencoder: A64_6 with larger input kernels.

    Designed for inputs of shape (N, 5, 64, 64). The encoder halves the
    spatial size three times (64 -> 32 -> 16 -> 8) while widening channels
    5 -> 10 -> 20 -> 40; the decoder mirrors it with
    bilinear-upsample-then-convolve stages back to 5 channels.

    Other spatial sizes also work: each stride-2 conv maps a side of length
    n to ceil(n / 2), so the decoder emits 8 * ceil(n / 8) >= n pixels per
    side, and ``forward`` crops the result back to the input's H x W.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layer order/indices are kept stable so state_dict keys
        # (e.g. "encoder.0.weight", "decoder.1.weight") stay the same.
        self.encoder = nn.Sequential(
            nn.Conv2d(5, 10, kernel_size=9, stride=2, padding=4),
            nn.ReLU(),

            nn.Conv2d(10, 20, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),

            nn.Conv2d(20, 40, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Resize-convolution decoder (upsample, then a stride-1 conv) instead
        # of transposed convolutions, to avoid checkerboard artifacts:
        # https://distill.pub/2016/deconv-checkerboard/
        #   1. Subpixel convolution: use a kernel size divisible by the stride
        #      to avoid uneven overlap.
        #   2. Separate upsampling from feature computation: resize the image
        #      (nearest-neighbour or bilinear) and then convolve. Resize-
        #      convolution is implicitly weight-tying in a way that
        #      discourages high-frequency artifacts. (This is what is used.)
        # nn.Upsample(mode="bilinear", align_corners=True) is the documented
        # replacement for the deprecated nn.UpsamplingBilinear2d and computes
        # the same interpolation; mode="nearest" is the other option to try:
        # https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(40, 40, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(40, 20, kernel_size=5, padding=2, stride=1),
            nn.ReLU(),

            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(20, 10, kernel_size=(3, 3), padding=1, stride=1),
            nn.ReLU(),

            nn.Conv2d(10, 5, kernel_size=1, stride=1),
            # Final ReLU: reconstructions are constrained to be non-negative.
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode and decode ``x``.

        Args:
            x: Tensor of shape (N, 5, H, W); (N, 5, 64, 64) is the intended size.

        Returns:
            Non-negative reconstruction with the same shape as ``x``.
        """
        height, width = x.shape[-2:]
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        # The decoder emits 8 * ceil(side / 8) pixels per side, never fewer
        # than the input, so slicing the top-left corner restores the input
        # size exactly. A hard-coded 64x64 crop would zero-pad smaller inputs
        # and truncate larger ones.
        return decoded[..., :height, :width]

In [27]:
# Summarize the decoder on exactly what the encoder emits for one
# (1, 5, 64, 64) input, instead of hard-coding (1, 40, 8, 8), so this cell
# stays correct if the encoder's channels or strides change.
model = A64_8()
with torch.no_grad():
    latent_shape = tuple(model.encoder(torch.zeros(1, 5, 64, 64)).shape)
summary(model.decoder, input_size=latent_shape)

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 5, 64, 64]            --
├─UpsamplingBilinear2d: 1-1              [1, 40, 16, 16]           --
├─Conv2d: 1-2                            [1, 40, 16, 16]           14,440
├─ReLU: 1-3                              [1, 40, 16, 16]           --
├─UpsamplingBilinear2d: 1-4              [1, 40, 32, 32]           --
├─Conv2d: 1-5                            [1, 20, 32, 32]           20,020
├─ReLU: 1-6                              [1, 20, 32, 32]           --
├─UpsamplingBilinear2d: 1-7              [1, 20, 64, 64]           --
├─Conv2d: 1-8                            [1, 10, 64, 64]           1,810
├─ReLU: 1-9                              [1, 10, 64, 64]           --
├─Conv2d: 1-10                           [1, 5, 64, 64]            55
├─ReLU: 1-11                             [1, 5, 64, 64]            --
Total params: 36,325
Trainable params: 36,325
Non-trainable params: 0
Tota

In [24]:
# Layer-by-layer shapes and parameter counts of the encoder for a single
# 5-channel 64x64 sample (batch size 1).
model = A64_8()
summary(
    model=model.encoder,
    input_size=(1, 5, 64, 64),
)

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 40, 8, 8]             --
├─Conv2d: 1-1                            [1, 10, 32, 32]           4,060
├─ReLU: 1-2                              [1, 10, 32, 32]           --
├─Conv2d: 1-3                            [1, 20, 16, 16]           5,020
├─ReLU: 1-4                              [1, 20, 16, 16]           --
├─Conv2d: 1-5                            [1, 40, 8, 8]             7,240
├─ReLU: 1-6                              [1, 40, 8, 8]             --
Total params: 16,320
Trainable params: 16,320
Non-trainable params: 0
Total mult-adds (M): 5.91
Input size (MB): 0.08
Forward/backward pass size (MB): 0.14
Params size (MB): 0.07
Estimated Total Size (MB): 0.29