In [1]:
from torch.utils.data import Dataset

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Reshape(nn.Module):
    """Reshape a tensor to a fixed target shape, usable inside ``nn.Sequential``.

    Args:
        *args: Target shape, one integer per dimension. At most one entry may
            be ``-1``, in which case that dimension is inferred from the
            number of elements in the input.
    """

    def __init__(self, *args: int):
        super().__init__()
        # Stored as the tuple of positional arguments, e.g. (-1, 128, 13, 10).
        self.shape = args

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` rearranged to the shape given at construction.

        Args:
            x: Tensor whose element count is compatible with ``self.shape``.

        Returns:
            A tensor with shape ``self.shape``. It shares storage with ``x``
            whenever the memory layout allows it and is a copy otherwise.

        Raises:
            RuntimeError: If the element count of ``x`` does not fit the
                target shape.
        """
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. after ``transpose``/``permute``), whereas ``reshape``
        # yields the same view when one exists and only copies when it must.
        return x.reshape(self.shape)

In [2]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The layer sizes are tied to inputs of shape ``(N, 3, 100, 75)``. Three
    stride-2 convolutions reduce such an input to ``(N, 128, 13, 10)``, which
    is flattened and projected to ``latent_dim`` values. The decoder expands
    the code back through transposed convolutions to ``(N, 1, 100, 75)``.

    NOTE(review): the decoder emits 1 channel while the encoder consumes 3.
    Confirm that a single-channel reconstruction target is intended.

    Args:
        latent_dim: Width of the bottleneck code. Defaults to 10.

    Raises:
        ValueError: If ``latent_dim`` is smaller than 1.
    """

    # (channels, height, width) of the last encoder feature map for a
    # (3, 100, 75) input; shared by the encoder's Linear and the decoder's
    # Linear/Reshape so the two halves cannot drift apart.
    _FEATURE_SHAPE = (128, 13, 10)

    def __init__(self, latent_dim: int = 10):
        super().__init__()
        if latent_dim < 1:
            raise ValueError(
                f"latent_dim must be a positive integer, got {latent_dim}"
            )

        channels, height, width = self._FEATURE_SHAPE
        flat_features = channels * height * width  # 128 * 13 * 10 = 16640
        self.latent_dim = latent_dim

        # Positional Conv2d arguments are
        # (in_channels, out_channels, kernel_size, stride, padding).
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1),        # (3, 100, 75) -> (32, 50, 38)
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),       # -> (64, 25, 19)
            nn.ReLU(),
            nn.Conv2d(64, channels, 3, 2, 1),  # -> (128, 13, 10)
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, latent_dim),
        )

        # The kernel/stride/padding choices below are picked so that each
        # transposed convolution lands exactly on the matching encoder
        # resolution, including the odd sizes 25, 19 and 75.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_features),
            Reshape(-1, channels, height, width),
            nn.ConvTranspose2d(channels, 64, 3, 2, 1),  # -> (64, 25, 19)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 2, 2),           # -> (32, 50, 38)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, (2, 1), 2),       # -> (1, 100, 75)
            nn.ReLU(),
            # 1x1 transposed convolution: a learned scale and bias per pixel,
            # with no activation afterwards, so outputs may be negative.
            nn.ConvTranspose2d(1, 1, 1, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` to the bottleneck code and decode it again.

        Args:
            x: Batch of shape ``(N, 3, 100, 75)``.

        Returns:
            Tensor of shape ``(N, 1, 100, 75)``.
        """
        return self.decoder(self.encoder(x))

    def get_codes(self, x: torch.Tensor) -> torch.Tensor:
        """Return only the bottleneck code for ``x``.

        Args:
            x: Batch of shape ``(N, 3, 100, 75)``.

        Returns:
            Tensor of shape ``(N, latent_dim)``.
        """
        return self.encoder(x)

In [3]:
from torchsummary import summary

# Use the GPU when one is available and fall back to the CPU otherwise, so
# this cell also runs on machines without CUDA instead of raising in .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoEncoder().to(device)

# Print per-layer output shapes and parameter counts for a (3, 100, 75) input.
# The device is passed explicitly so torchsummary's dummy input is created on
# the same device as the model.
summary(model, (3, 100, 75), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 50, 38]             896
              ReLU-2           [-1, 32, 50, 38]               0
            Conv2d-3           [-1, 64, 25, 19]          18,496
              ReLU-4           [-1, 64, 25, 19]               0
            Conv2d-5          [-1, 128, 13, 10]          73,856
              ReLU-6          [-1, 128, 13, 10]               0
           Flatten-7                [-1, 16640]               0
            Linear-8                   [-1, 10]         166,410
            Linear-9                [-1, 16640]         183,040
          Reshape-10          [-1, 128, 13, 10]               0
  ConvTranspose2d-11           [-1, 64, 25, 19]          73,792
             ReLU-12           [-1, 64, 25, 19]               0
  ConvTranspose2d-13           [-1, 32, 50, 38]           8,224
             ReLU-14           [-1, 32,