In [1]:
import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi']=200

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
class Encoder(nn.Module):
    """Two-layer MLP mapping a flattened input to a latent vector.

    Args:
        latent_dims: width of the latent code produced by ``forward``.

    Every dimension after the batch dimension is flattened, and the result
    must have 784 features. NOTE(review): presumably (N, 1, 28, 28)
    MNIST-sized images; confirm with the caller.
    """

    def __init__(self, latent_dims):
        super().__init__()
        # 784 flattened input features -> 512 hidden units -> latent code.
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)

    def forward(self, x):
        """Return the latent code for batch ``x``; the output layer has no activation."""
        flat = x.flatten(start_dim=1)
        hidden = self.linear1(flat).relu()
        return self.linear2(hidden)

In [4]:
class Decoder(nn.Module):
    """Two-layer MLP mapping a latent vector back to a (1, 28, 28) image.

    Args:
        latent_dim: width of the incoming latent code. It should equal the
            encoder's ``latent_dims`` when the two are paired.
    """

    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        # Input width comes from the constructor argument. Referencing a
        # global here would break on a fresh kernel or pick up stale state.
        self.linear1 = nn.Linear(latent_dim, 512)
        self.linear2 = nn.Linear(512, 784)

    def forward(self, z):
        """Decode ``z`` into images of shape (-1, 1, 28, 28).

        Assumes ``z`` is (N, latent_dim), giving an (N, 1, 28, 28) output.
        The final sigmoid keeps every pixel value in [0, 1].
        """
        z = F.relu(self.linear1(z))
        z = torch.sigmoid(self.linear2(z))
        return z.reshape((-1, 1, 28, 28))

In [5]:
class Autoencoder(nn.Module):
    """Encoder followed by Decoder, sharing a single latent width.

    Args:
        latent_dims: latent code width passed to both ``Encoder`` and ``Decoder``.
    """

    def __init__(self, latent_dims):
        super().__init__()
        # Registration order (encoder, then decoder) is kept so parameter
        # initialisation consumes the RNG in the same order.
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        """Encode ``x`` to a latent code and return the decoder's reconstruction."""
        return self.decoder(self.encoder(x))