In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available else 'cpu'
IMG_SIZE = 224

In [26]:
class Encoder(nn.Module):
    def __init__(self, latent_dims, capacity):
        super(Encoder, self).__init__()
        c = capacity
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1)
        self.fc = nn.Linear(in_features=c*2*56*56, out_features=latent_dims)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [27]:
class Decoder(nn.Module):
    def __init__(self, latent_dims, capacity):
        super(Decoder, self).__init__()
        self.capacity = capacity
        c = capacity
        self.fc = nn.Linear(in_features=latent_dims, out_features=c*2*56*56)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), self.capacity*2, 56, 56)
        x = F.relu(self.conv2(x))
        x = torch.tanh(self.conv1(x))
        return x

In [28]:
class AutoEncoder(nn.Module):
    def __init__(self, latent_dims, capacity):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(latent_dims, capacity)
        self.decoder = Decoder(latent_dims, capacity)

    def forward(self, x):
        latent = self.encoder(x)
        x_recon = self.decoder(latent)
        return x_recon

In [29]:
model = AutoEncoder(latent_dims=10, capacity=64).to(device)
x = torch.randn(10, 3, IMG_SIZE, IMG_SIZE).to(device)
model(x)

tensor([[[[-0.1902, -0.2515, -0.2436,  ..., -0.1444, -0.2528, -0.1960],
          [-0.2400, -0.0931, -0.2858,  ..., -0.2936, -0.1930, -0.0973],
          [-0.1969, -0.2048, -0.3651,  ..., -0.0318, -0.2925, -0.1465],
          ...,
          [-0.2126, -0.1665, -0.2605,  ..., -0.0445, -0.1612, -0.1794],
          [-0.1907, -0.0835, -0.1628,  ..., -0.1932, -0.3620, -0.0912],
          [-0.1907, -0.1681, -0.1871,  ..., -0.2092, -0.1916, -0.1450]]],


        [[[-0.1882, -0.2578, -0.2616,  ..., -0.1309, -0.2510, -0.1985],
          [-0.2173, -0.1150, -0.2659,  ..., -0.2456, -0.2049, -0.0919],
          [-0.1961, -0.1965, -0.3515,  ..., -0.0536, -0.3299, -0.1486],
          ...,
          [-0.2108, -0.1696, -0.2627,  ..., -0.0566, -0.1207, -0.1768],
          [-0.2011, -0.1151, -0.2192,  ..., -0.1859, -0.3358, -0.0850],
          [-0.1787, -0.1800, -0.1912,  ..., -0.2132, -0.2159, -0.1405]]],


        [[[-0.1966, -0.2435, -0.2434,  ..., -0.1363, -0.2432, -0.1926],
          [-0.2564, -0.064