In [3]:
import torch
from torch import nn
import numpy as np

In [5]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to latent Gaussian parameters.

    Five stride-2 blocks (Conv2d -> BatchNorm2d -> LeakyReLU) each roughly halve
    the spatial resolution while widening the channels to 32, 64, 128, 256 and
    512. The flattened feature map feeds two linear heads that output the mean
    and the log-variance of the latent distribution.

    The linear heads expect ``512 * 4`` input features, i.e. a 2x2 final
    feature map, which is what a 64x64 input produces after five halvings.

    Args:
        in_channels: Number of channels in the input images.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, in_channels=3, latent_dim=20):
        super().__init__()

        dims = [in_channels, 32, 64, 128, 256, 512]
        layers = []
        for in_dim, out_dim in zip(dims, dims[1:]):
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1),
                    # torch.nn has no generic `BatchNorm`; conv feature maps
                    # (N, C, H, W) need the 2d variant.
                    nn.BatchNorm2d(out_dim),
                    nn.LeakyReLU()
                )
            )

        self.encoder = nn.Sequential(*layers)

        # 512 channels * (2 * 2) spatial positions after flattening.
        flat_features = dims[-1] * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(mu, log_var)``.

        Args:
            x: Image batch of shape ``(batch, in_channels, H, W)`` whose size
                reduces to 2x2 after the five stride-2 convolutions
                (e.g. 64x64).

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        x = self.encoder(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)

        return mu, log_var

In [6]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors back to images.

    A linear layer projects the latent vector to a ``512 x 2 x 2`` feature map.
    Four stride-2 blocks (ConvTranspose2d -> BatchNorm2d -> LeakyReLU) then
    double the spatial size each time while narrowing the channels to
    256, 128, 64 and 32. The output head upsamples once more, applies a 3x3
    convolution to ``out_channels`` and squashes the result with ``tanh``, so
    outputs lie in ``[-1, 1]``. A latent batch therefore decodes to 64x64
    images.

    Args:
        out_channels: Number of channels in the generated images.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, out_channels=3, latent_dim=20):
        super().__init__()

        dims = [512, 256, 128, 64, 32]
        layers = []

        for in_dim, out_dim in zip(dims, dims[1:]):
            layers.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(out_dim),
                    nn.LeakyReLU()
                )
            )

        # Channel count of the first (2x2) feature map; forward() reshapes to it.
        self._base_channels = dims[0]

        # The projection must produce base_channels * 2 * 2 features so that the
        # reshape in forward() yields exactly one feature map per latent vector.
        # (Sizing it from dims[-1] gives only 128 features and breaks the reshape.)
        self.decoder_input = nn.Linear(latent_dim, self._base_channels * 4)

        self.decoder = nn.Sequential(*layers)

        self.decoder_output = nn.Sequential(
            nn.ConvTranspose2d(dims[-1], dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(dims[-1], out_channels=out_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Decode latent vectors into images.

        Args:
            x: Latent batch of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, out_channels, 64, 64)`` with values in
            ``[-1, 1]``.
        """
        x = self.decoder_input(x)
        # Unflatten the projection into a (base_channels, 2, 2) feature map.
        x = x.view(-1, self._base_channels, 2, 2)
        x = self.decoder(x)
        x = self.decoder_output(x)

        return x

In [7]:
class VAE(nn.Module):
    """Variational autoencoder wiring an encoder and a decoder together.

    Args:
        encoder: Callable/module returning ``(mu, log_var)`` for an input batch.
        decoder: Callable/module mapping a latent sample to a reconstruction.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterise(self, mu, log_var):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``log_var``.

        Args:
            mu: Mean of the latent Gaussian.
            log_var: Log-variance of the latent Gaussian, same shape as ``mu``.

        Returns:
            A sample with the same shape as ``mu``.
        """
        # log_var = log(sigma^2)  =>  sigma = exp(0.5 * log_var)
        std = torch.exp(0.5 * log_var)
        # Standard *normal* noise: rand_like would give uniform [0, 1) samples,
        # which are non-negative and have the wrong mean and variance.
        eps = torch.randn_like(std)

        return eps * std + mu

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(out, mu, log_var)``: the reconstruction and the latent
            distribution parameters produced by the encoder.
        """
        mu, log_var = self.encoder(x)
        z = self.reparameterise(mu, log_var)
        out = self.decoder(z)

        return out, mu, log_var