In [1]:
import torch
import torch.nn as nn


class Encoder(torch.nn.Module):
    """Encode an image batch into a flat feature vector.

    Three stages of (5x5 stride-2 conv -> 2x2 max-pool -> BatchNorm -> ReLU)
    are followed by a fully-connected layer with BatchNorm1d and ReLU.

    Args:
        in_channels: Number of channels in the input images.
        out_channels: Size of the output feature vector.
        image_length: Height of the input images.
        image_width: Width of the input images.

    Raises:
        ValueError: If the image is too small to survive the downsampling
            (the final feature map would have zero height or width).
    """

    def __init__(self, in_channels, out_channels, image_length=128, image_width=128):
        super().__init__()
        # TODO: Check if the MaxPool layers are intended. With them each stage
        # downsamples 4x (stride-2 conv + 2x2 pool), i.e. ~64x overall.
        # 5x5 64 conv. downsampling, BNorm, ReLu
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=5, stride=2, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )  # in_channels x H x W -> 64 x ~H/4 x ~W/4
        # 5x5 128 conv. downsampling, BNorm, ReLu
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )  # 64 x ~H/4 x ~W/4 -> 128 x ~H/16 x ~W/16
        # 5x5 256 conv. downsampling, BNorm, ReLu
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )  # 128 x ~H/16 x ~W/16 -> 256 x ~H/64 x ~W/64

        # Exact spatial size of the conv3 output. Plain `size // 64` disagrees
        # with the real size for many inputs (e.g. 190 -> 3, but 190 // 64 == 2),
        # which would make fc1 reject the flattened features.
        feat_length = self._downsampled_size(image_length)
        feat_width = self._downsampled_size(image_width)
        if feat_length < 1 or feat_width < 1:
            raise ValueError(
                f"image_length={image_length}, image_width={image_width} is too small: "
                f"the encoder downsamples it to a {feat_length}x{feat_width} feature map"
            )

        # out_channels fully-connected, BNorm, ReLu
        self.fc1 = nn.Sequential(
            nn.Linear(256 * feat_length * feat_width, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU()
        )  # 256 x feat_length x feat_width -> out_channels

    @staticmethod
    def _downsampled_size(size):
        """Return the size of one spatial dimension after the three conv+pool stages."""
        for _ in range(3):
            size = (size - 1) // 2 + 1  # Conv2d(k=5, s=2, p=2): floor((n + 4 - 5) / 2) + 1
            size = size // 2            # MaxPool2d(k=2, s=2): floor(n / 2)
        return size

    def forward(self, x):
        """Map a (batch, in_channels, image_length, image_width) tensor to (batch, out_channels)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Flatten everything except the batch dimension for the linear layer.
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)
    
class Decoder(torch.nn.Module):
    """Decode a flat vector into an image-shaped tensor.

    A fully-connected layer (BatchNorm1d, ReLU) expands the input vector into a
    256-channel seed feature map. Three 5x5 stride-2 transposed convolutions,
    each followed by BatchNorm2d and ReLU, then double its height and width,
    upsampling 8x in total.

    Args:
        in_channels: Size of the input vector.
        out_channels: Number of channels of the produced image.
        image_length: Height of the produced image.
        image_width: Width of the produced image.
    """

    def __init__(self, in_channels, out_channels, image_length=128, image_width=128):
        super().__init__()
        self.image_length = image_length
        self.image_width = image_width
        seed_features = 256 * self._seed_size(image_length) * self._seed_size(image_width)
        # in_channels -> 256 x seed_length x seed_width fully-connected, BNorm, ReLu
        self.fc1 = nn.Sequential(
            nn.Linear(in_channels, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.ReLU()
        )
        # 5x5 256 conv. upsampling, BNorm, ReLu (each transposed conv maps n -> 2n)
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )  # 256 x s_L x s_W -> 128 x 2*s_L x 2*s_W
        # 5x5 128 conv. upsampling, BNorm, ReLu
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )  # 128 x 2*s_L x 2*s_W -> 64 x 4*s_L x 4*s_W
        # 5x5 64 conv. upsampling, BNorm, ReLu
        # NOTE(review): the output passes through BatchNorm2d + ReLU, so it is
        # non-negative; image decoders often end in a plain conv + tanh/sigmoid.
        # Confirm this is intended.
        self.conv3 = nn.Sequential(
            nn.ConvTranspose2d(64, out_channels, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )  # 64 x 4*s_L x 4*s_W -> out_channels x 8*s_L x 8*s_W

    @staticmethod
    def _seed_size(size):
        """Return the seed-map size for one spatial dimension: ceil(size / 8).

        Rounding up guarantees the 8x upsampling reaches at least `size`;
        forward() crops the surplus. As a result, sizes that are not multiples
        of 8 still produce the requested shape. Multiples of 8 are unaffected.
        """
        return (size + 7) // 8

    def forward(self, x):
        """Map a (batch, in_channels) tensor to (batch, out_channels, image_length, image_width)."""
        x = self.fc1(x)
        x = x.view(x.size(0), 256,
                   self._seed_size(self.image_length), self._seed_size(self.image_width))
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # x is now 8*ceil(size/8) along each spatial dim; trim it to the
        # requested size (a no-op when the size is a multiple of 8).
        return x[:, :, :self.image_length, :self.image_width]


In [2]:
# Define test parameters
torch.manual_seed(0)  # reproducible random inputs
# Fall back to CPU so the smoke test also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
image_length = 64
image_width = int(image_length * 1.5)
in_channels = 3
out_channels = 8*8*256  # size of the encoder's output / decoder's input vector
batch_size = 2

print(f'Batches: {batch_size}, Channels: {in_channels}, Image Length: {image_length}, Image Width: {image_width}')

# Test Encoder
images = torch.randn(batch_size, in_channels, image_length, image_width, device=device)
test_encoder = Encoder(in_channels, out_channels, image_length, image_width).to(device)
latent = test_encoder(images)
print(f'Test Encoder: {latent.shape}')
assert latent.shape == (batch_size, out_channels)

# Test Decoder
test_decoder = Decoder(out_channels, in_channels, image_length, image_width).to(device)
output = test_decoder(latent)
print(f'Test Decoder: {output.shape}')
assert output.shape == (batch_size, in_channels, image_length, image_width)

# Print model summary (torchsummary defaults to device='cuda', so pass it explicitly)
from torchsummary import summary
summary(test_encoder, input_size=(in_channels, image_length, image_width), device=device)
summary(test_decoder, input_size=(out_channels,), device=device)

Batches: 2, Channels: 3, Image Length: 64, Image Width: 96
Test Encoder: torch.Size([2, 16384])
Test Decoder: torch.Size([2, 3, 64, 96])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 48]           4,864
         MaxPool2d-2           [-1, 64, 16, 24]               0
       BatchNorm2d-3           [-1, 64, 16, 24]             128
              ReLU-4           [-1, 64, 16, 24]               0
            Conv2d-5           [-1, 128, 8, 12]         204,928
         MaxPool2d-6            [-1, 128, 4, 6]               0
       BatchNorm2d-7            [-1, 128, 4, 6]             256
              ReLU-8            [-1, 128, 4, 6]               0
            Conv2d-9            [-1, 256, 2, 3]         819,456
        MaxPool2d-10            [-1, 256, 1, 1]               0
      BatchNorm2d-11            [-1, 256, 1, 1]             512
             ReLU-12          