In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Create a random batch of 32 single-channel 28×28 samples

In [2]:
# Random stand-in batch laid out as (batch, channels, height, width).
data = torch.randn((32, 1, 28, 28))

In [3]:
# Confirm the batch layout before feeding it through the layers.
data.size()

torch.Size([32, 1, 28, 28])

# Encoder Layers

In [4]:
# Unpadded stride-2 convolution: each side becomes floor((28 - 5) / 2) + 1 = 12.
conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=2)
out_conv1 = conv1(data)
print(out_conv1.size())

torch.Size([32, 16, 12, 12])


In [5]:
# Second unpadded stride-2 convolution: floor((12 - 5) / 2) + 1 = 4 per side.
conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2)
out_conv2 = conv2(out_conv1)
print(out_conv2.size())

torch.Size([32, 32, 4, 4])


In [6]:
# Collapse (channels, height, width) into one feature axis, keeping the batch axis.
flatten = out_conv2.reshape(data.size(0), -1)

In [7]:
# 32 channels * 4 * 4 spatial positions = 512 features per sample.
flatten.size()

torch.Size([32, 512])

In [8]:
# Fully connected layer taking the 512 flattened features down to 300.
linear1 = nn.Linear(in_features=512, out_features=300)
out_linear1 = linear1(flatten)
print(out_linear1.size())

torch.Size([32, 300])


In [9]:
# Projection head onto the latent space (latent_size = 16).
mu = nn.Linear(in_features=300, out_features=16)
out_mu = mu(out_linear1)
# A logvar head with the same dimensions yields the same output shape.
print(out_mu.size())

torch.Size([32, 16])


# Decoder Layers

In [10]:
# First decoder layer: expand the 16-dim latent vector back to 300 features.
linear2 = nn.Linear(in_features=16, out_features=300)
out_linear2 = linear2(out_mu)
print(out_linear2.size())

torch.Size([32, 300])


In [11]:
# Mirror of linear1: 300 features back up to 512 (= 32 * 4 * 4).
linear3 = nn.Linear(in_features=300, out_features=512)
out_linear3 = linear3(out_linear2)
print(out_linear3.size())

torch.Size([32, 512])


In [12]:
# Undo the encoder's flatten: 512 features -> (32 channels, 4, 4) feature map.
unflatten = out_linear3.reshape(out_linear3.size(0), 32, 4, 4)
print(unflatten.size())

torch.Size([32, 32, 4, 4])


In [13]:
# Transposed convolution upsamples each side to (4 - 1) * 2 + 5 = 11.
conv3 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=5, stride=2)
out_conv3 = conv3(unflatten)
print(out_conv3.size())

torch.Size([32, 16, 11, 11])


In [14]:
# Second upsampling step: (11 - 1) * 2 + 5 = 25 per side, back to one channel.
conv4 = nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=5, stride=2)
out_conv4 = conv4(out_conv3)
print(out_conv4.size())

torch.Size([32, 1, 25, 25])


In [15]:
# Final layer of the VAE: a stride-1 transposed convolution pads 25 up to
# (25 - 1) * 1 + 4 = 28, matching the input resolution.
conv5 = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=4)
out_conv5 = conv5(out_conv4)
print(out_conv5.size())

torch.Size([32, 1, 28, 28])


# Create Variational Autoencoder Class

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel images.

    The encoder applies two unpadded stride-2 convolutions followed by a
    fully connected layer, then two parallel linear heads producing the mean
    and log-variance of the latent distribution. The decoder mirrors this
    with two linear layers and three transposed convolutions, finishing with
    a sigmoid so every output value lies in [0, 1].

    The flatten/unflatten sizes (512 = 32 * 4 * 4) match 28x28 inputs, and
    the decoder always emits 28x28 reconstructions.

    Args:
        latent_size: Dimensionality of the latent vector. Defaults to 16.
    """

    def __init__(self, latent_size=16):
        super().__init__()

        # Encoder Layers

        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.linear1 = nn.Linear(512, 300)
        self.mu = nn.Linear(300, latent_size)
        self.logvar = nn.Linear(300, latent_size)

        # Decoder Layers

        self.linear2 = nn.Linear(latent_size, 300)
        self.linear3 = nn.Linear(300, 512)
        self.conv3 = nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2)
        self.conv4 = nn.ConvTranspose2d(16, 1, kernel_size=5, stride=2)
        self.conv5 = nn.ConvTranspose2d(1, 1, kernel_size=4)

    def encoder(self, x):
        """Encode a batch of images into latent distribution parameters.

        Args:
            x: Input tensor of shape (batch, 1, height, width); the conv
                stack must reduce it to 512 features per sample.

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, latent_size).
        """
        t = F.relu(self.conv1(x))
        t = F.relu(self.conv2(t))
        # Flatten everything except the batch axis for the linear layers.
        t = t.reshape((x.shape[0], -1))

        t = F.relu(self.linear1(t))
        mu = self.mu(t)
        logvar = self.logvar(t)
        return mu, logvar

    def reparameterization(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Args:
            mu: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution.

        Returns:
            A latent sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like allocates the noise with std's dtype and device, so no
        # explicit device transfer (or external `device` global) is needed.
        eps = torch.randn_like(std)
        return eps * std + mu

    def unFlatten(self, x):
        """Reshape (batch, 512) features into a (batch, 32, 4, 4) feature map."""
        return x.reshape((x.shape[0], 32, 4, 4))

    def decoder(self, z):
        """Decode latent vectors into images.

        Args:
            z: Latent tensor of shape (batch, latent_size).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [0, 1].
        """
        t = F.relu(self.linear2(z))
        t = F.relu(self.linear3(t))
        t = self.unFlatten(t)
        t = F.relu(self.conv3(t))
        t = F.relu(self.conv4(t))
        # Sigmoid bounds the reconstruction to [0, 1].
        t = torch.sigmoid(self.conv5(t))
        return t

    def forward(self, x):
        """Run the full encode -> sample -> decode pass.

        Args:
            x: Input tensor of shape (batch, 1, height, width).

        Returns:
            Tuple ``(pred, mu, logvar)``: the reconstruction and the latent
            distribution parameters it was sampled from.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterization(mu, logvar)
        pred = self.decoder(z)
        return pred, mu, logvar