In [1]:
# layer size test

import torch
from torchinfo import summary
import torch.nn as nn
import torchvision
import time

In [2]:
class A64_8(nn.Module):
    """A64_6 variant with larger kernels in the first encoder layers.

    Encoder: three stride-2 convolutions (kernel sizes 9, 5, 3, each padded by
    kernel_size // 2), each followed by a ReLU. Each one halves the spatial
    size, rounding up, so a (5, 64, 64) input becomes a (40, 8, 8) code.

    Decoder: three blocks of bilinear x2 upsampling (``align_corners=True``)
    followed by a size-preserving convolution and a ReLU, then a 1x1
    convolution back to 5 channels and a final ReLU. Outputs are therefore
    non-negative.

    Input sizes are (5, 64, 64). Other spatial sizes also run: the decoder
    produces ``8 * ceil(size / 8)`` pixels per side, and ``forward`` trims that
    back to the input's height and width.
    """
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(5, 10, kernel_size=9, stride=2, padding=4),
            nn.ReLU(),

            nn.Conv2d(10, 20, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),

            nn.Conv2d(20, 40, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Checkerboard artifacts: https://distill.pub/2016/deconv-checkerboard/
        # The article suggests two remedies:
        # 1. sub-pixel convolution: use a kernel size that is divisible by the
        #    stride to avoid the uneven-overlap issue;
        # 2. separate the upsampling from the convolution that computes features,
        #    e.g. resize the image (nearest-neighbour or bilinear interpolation)
        #    and then apply a convolution. Resize-convolution is implicitly
        #    weight-tying in a way that discourages high-frequency artifacts.
        # This decoder follows option 2 (bilinear resize, then convolution); see
        # https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html
        #
        # nn.Upsample(mode='bilinear', align_corners=True) is exactly what the
        # deprecated nn.UpsamplingBilinear2d wraps. The layer indices (and thus
        # the state_dict keys) are unchanged.
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(40, 40, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(40, 20, kernel_size=5, padding=2, stride=1),
            nn.ReLU(),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(20, 10, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Conv2d(10, 5, kernel_size=1, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: tensor of shape (N, 5, H, W); the model is designed for H = W = 64.

        Returns:
            Non-negative reconstruction of shape (N, 5, H, W).
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        # The decoder yields 8 * ceil(size / 8) pixels per side, which is >= the
        # input size. Keep the top-left H x W window so the output matches the
        # input. This is a no-op for 64x64 inputs.
        return decoded[..., :x.shape[-2], :x.shape[-1]]

In [3]:
class A64_9(nn.Module):
    """A64_6 variant with all 3x3 kernels.

    by jarl @ 25 Oct 2023 15:14

    Input sizes are (5, 64, 64), encoded size is (40, 8, 8).

    Encoder: three 3x3 convolutions with stride 2 and padding 1, each followed
    by a ReLU. Each one halves the spatial size, rounding up: 64 -> 32 -> 16 -> 8.
    Downsampling is done by these strided convolutions; the encoder contains no
    max-pooling layers. TODO(review): an earlier version of this docstring said
    "Change encoder to use max pooling layers", but no pooling is present.
    Confirm whether that change was intended.

    Decoder: three blocks of bilinear x2 upsampling (``align_corners=False``,
    the ``nn.Upsample`` default, unlike A64_8 which uses ``True``) followed by
    a size-preserving 3x3 convolution and a ReLU, then a 1x1 convolution back
    to 5 channels and a final ReLU. Outputs are therefore non-negative.
    """
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(5, 10, kernel_size=3, stride=2, padding=1),  # stride 2 halves H and W (64 -> 32)
            nn.ReLU(),

            nn.Conv2d(10, 20, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            nn.Conv2d(20, 40, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(40, 40, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(40, 20, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(20, 10, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Conv2d(10, 5, kernel_size=1, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: tensor of shape (N, 5, H, W); the model is designed for H = W = 64.

        Returns:
            Non-negative reconstruction of shape (N, 5, H, W).
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        # The decoder yields 8 * ceil(size / 8) pixels per side, which is >= the
        # input size. Keep the top-left H x W window so the output matches the
        # input. This is a no-op for 64x64 inputs.
        return decoded[..., :x.shape[-2], :x.shape[-1]]
    

In [34]:
class FullAE1(nn.Module):
    """Convolutional autoencoder for a whole 5-channel 707x200 image.

    The objective is to reduce the number of data points in the latent space
    to about ~2000.

    Encoder: five blocks of 3x3 'same' convolution -> 2x2 max pool -> ReLU.
    Each pool halves the spatial size, rounding down, so a (5, 707, 200) input
    becomes a (20, 22, 6) code, i.e. 20 * 22 * 6 = 2640 values.

    Decoder: five blocks of 3x3 'same' convolution -> bilinear upsampling
    (``align_corners=True``) -> ReLU. The first four blocks upsample by 2; the
    last one resizes directly to ``output_size``, so the reconstruction always
    has exactly that spatial size. Outputs are non-negative (final ReLU).

    Args:
        output_size: (height, width) of the reconstruction. Defaults to
            (707, 200), the image size this model was designed for.
    """
    def __init__(self, output_size=(707, 200)):
        super().__init__()
        self.output_size = tuple(output_size)
        self.encoder = nn.Sequential(
            nn.Conv2d(5, 10, kernel_size=3, stride=1, padding='same'),  # padding='same' maintains the output size
            nn.MaxPool2d(2, 2),
            nn.ReLU(),

            nn.Conv2d(10, 20, kernel_size=3, stride=1, padding='same'),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),

            nn.Conv2d(20, 20, kernel_size=3, stride=1, padding='same'),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),

            nn.Conv2d(20, 20, kernel_size=3, stride=1, padding='same'),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),

            nn.Conv2d(20, 20, kernel_size=3, stride=1, padding='same'),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        )

        # nn.Upsample(mode='bilinear', align_corners=True) is exactly what the
        # deprecated nn.UpsamplingBilinear2d wraps. The layer indices (and thus
        # the state_dict keys) are unchanged.
        self.decoder = nn.Sequential(
            nn.Conv2d(20, 20, kernel_size=3, stride=1, padding='same'),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.ReLU(),

            nn.Conv2d(20, 20, kernel_size=3, stride=1, padding='same'),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.ReLU(),

            nn.Conv2d(20, 20, kernel_size=3, stride=1, padding='same'),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.ReLU(),

            nn.Conv2d(20, 10, kernel_size=3, stride=1, padding='same'),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.ReLU(),

            nn.Conv2d(10, 5, kernel_size=3, stride=1, padding='same'),
            # Resize straight to the target size. This absorbs the rounding
            # lost in the max pools (e.g. 352x96 -> 707x200 by default).
            nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True),
            nn.ReLU()
        )

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: tensor of shape (N, 5, H, W); by default H, W = 707, 200.

        Returns:
            Non-negative reconstruction of shape (N, 5, *output_size).
        """
        encoded = self.encoder(x)
        # The decoder already ends by resizing to output_size, so no crop is
        # needed (the previous crop to 707x200 was a no-op).
        return self.decoder(encoded)

In [29]:
# NOTE(review): each later cell in this notebook re-creates `model = FullAE1()`,
# so this instance is overwritten before it is used.
model = FullAE1()

In [36]:
# Summarize the decoder on a latent-sized input. (1, 20, 22, 6) is the code
# shape the encoder produces for a (1, 5, 707, 200) image.
# Use Apple's MPS backend when present; otherwise fall back to CPU instead of
# crashing on machines without MPS.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = FullAE1()
start = time.perf_counter_ns()
print(summary(model.decoder, input_size=(1, 20, 22, 6), device=device))
end = time.perf_counter_ns()
# This times the whole summary() call (which runs a forward pass to record the
# output shapes), not a pure decode benchmark. NOTE(review): GPU work may be
# asynchronous, so treat this as a rough figure.
print(f'decode time: {(end-start)/1e6} ms')

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 5, 707, 200]          --
├─Conv2d: 1-1                            [1, 20, 22, 6]            3,620
├─UpsamplingBilinear2d: 1-2              [1, 20, 44, 12]           --
├─ReLU: 1-3                              [1, 20, 44, 12]           --
├─Conv2d: 1-4                            [1, 20, 44, 12]           3,620
├─UpsamplingBilinear2d: 1-5              [1, 20, 88, 24]           --
├─ReLU: 1-6                              [1, 20, 88, 24]           --
├─Conv2d: 1-7                            [1, 20, 88, 24]           3,620
├─UpsamplingBilinear2d: 1-8              [1, 20, 176, 48]          --
├─ReLU: 1-9                              [1, 20, 176, 48]          --
├─Conv2d: 1-10                           [1, 10, 176, 48]          1,810
├─UpsamplingBilinear2d: 1-11             [1, 10, 352, 96]          --
├─ReLU: 1-12                             [1, 10, 352, 96]          --
├─C

In [7]:
# Summarize the encoder on a full-size (1, 5, 707, 200) image.
# Use Apple's MPS backend when present; otherwise fall back to CPU instead of
# crashing on machines without MPS.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = FullAE1()
start = time.perf_counter_ns()
print(summary(model.encoder, input_size=(1, 5, 707, 200), device=device))
end = time.perf_counter_ns()
# This times the whole summary() call (which runs a forward pass to record the
# output shapes), not a pure encode benchmark. NOTE(review): GPU work may be
# asynchronous, so treat this as a rough figure.
print(f'encode time: {(end-start)/1e6} ms')

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 20, 22, 6]            --
├─Conv2d: 1-1                            [1, 10, 707, 200]         460
├─MaxPool2d: 1-2                         [1, 10, 353, 100]         --
├─ReLU: 1-3                              [1, 10, 353, 100]         --
├─Conv2d: 1-4                            [1, 20, 353, 100]         1,820
├─MaxPool2d: 1-5                         [1, 20, 176, 50]          --
├─ReLU: 1-6                              [1, 20, 176, 50]          --
├─Conv2d: 1-7                            [1, 20, 176, 50]          3,620
├─MaxPool2d: 1-8                         [1, 20, 88, 25]           --
├─ReLU: 1-9                              [1, 20, 88, 25]           --
├─Conv2d: 1-10                           [1, 20, 88, 25]           3,620
├─MaxPool2d: 1-11                        [1, 20, 44, 12]           --
├─ReLU: 1-12                             [1, 20, 44, 12]           --
├─Con

In [26]:
# Summarize the full autoencoder on a (1, 5, 707, 200) image.
# Use Apple's MPS backend when present; otherwise fall back to CPU instead of
# crashing on machines without MPS.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = FullAE1()
print(summary(model, input_size=(1, 5, 707, 200), device=device))

Layer (type:depth-idx)                   Output Shape              Param #
FullAE1                                  [1, 5, 707, 200]          --
├─Sequential: 1-1                        [1, 20, 22, 6]            --
│    └─Conv2d: 2-1                       [1, 10, 707, 200]         460
│    └─MaxPool2d: 2-2                    [1, 10, 353, 100]         --
│    └─ReLU: 2-3                         [1, 10, 353, 100]         --
│    └─Conv2d: 2-4                       [1, 20, 353, 100]         1,820
│    └─MaxPool2d: 2-5                    [1, 20, 176, 50]          --
│    └─ReLU: 2-6                         [1, 20, 176, 50]          --
│    └─Conv2d: 2-7                       [1, 20, 176, 50]          3,620
│    └─MaxPool2d: 2-8                    [1, 20, 88, 25]           --
│    └─ReLU: 2-9                         [1, 20, 88, 25]           --
│    └─Conv2d: 2-10                      [1, 20, 88, 25]           3,620
│    └─MaxPool2d: 2-11                   [1, 20, 44, 12]           --
│    