VAE = Variational Inferent + AutoEncoder

From encoder_decoder experiments, the naive encoder decoder doesn't not give normal latent variables, which makes generationg hard:
 - we are trying to learn the distribution of X, encoding to latent variable only reduce the dimension, we still need to learn the latent distribution
 - we wish the latent variable follows the normal distribution.
 - consider variational inference: learn the latent distirbution as they are variational gaussian

In [1]:
import torch
import torchvision
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.MNIST(root='./data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Split the dataset into training and test sets
train_data, test_data = train_test_split(trainset, test_size=0.2, random_state=42)

# Create data loaders for training and test sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

In [2]:
for batch in trainloader:
    images, labels = batch
    print(images.shape)
    print(labels.shape)
    break

torch.Size([64, 1, 28, 28])
torch.Size([64])


In [3]:
import torch.nn as nn

class Inferencer(nn.Module):
    def __init__(self, latent_dim):
        super(Inferencer, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 8, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(8, 8, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.fc_mu = nn.Linear(8*3*3, latent_dim)        
        self.fc_sigma = nn.Linear(8*3*3, latent_dim)     

    def forward(self,x):
        x = self.conv1(x) # 1*28*28 -> 8*28*28
        x = torch.relu(x)
        x = self.pool1(x) # 8*28*28 -> 8*14*14

        x = self.conv2(x) # 8*14*14 -> 8*14*14
        x = torch.relu(x)
        x = self.pool2(x) # 8*14*14 -> 8*7*7

        x = self.conv3(x) # 8*7*7 -> 8*7*7
        x = torch.relu(x)
        x = self.pool3(x) # 8*7*7 -> 8*3*3

        x = x.view(-1, 8*3*3)
        
        mu = self.fc_mu(x)
        sigma = self.fc_sigma(x)
        return mu,sigma
    
class Generator(nn.Module):
    def __init__(self,latent_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.fc = nn.Linear(latent_dim, 8*3*3)
        self.deconv1 = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2,output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)
        self.deconv3 = nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(8,1,3,padding=1)
        
    def forward(self,z):
        z = self.fc(z)
        z = z.view(-1, 8, 3, 3)
        z = self.deconv1(z)
        z = torch.relu(z)
        z = self.deconv2(z)
        z = torch.relu(z)
        z = self.deconv3(z)
        z = self.conv1(z)
        z = torch.sigmoid(z)
        
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.inferencer = Inferencer(latent_dim)
        self.generator = Generator(latent_dim)
    
    def forward(self,x):
        mu, sigma = self.inferencer(x)
        z = mu + sigma * torch.randn_like(mu)
        x_hat = self.generator(z)
        return x_hat, mu, sigma
    
    def get_loss(self,x): 
        x_hat, mu, sigma = self.forward(x)
        # Reconstruction loss
        reconstruction_loss = torch.sum((x - x_hat)**2)
        # KL divergence 
        KL_divergence = 0.5 * torch.sum(mu**2 + sigma**2 - torch.log(sigma**2) - 1)
        ELBO = reconstruction_loss + KL_divergence
        return ELBO

In [None]:
for batch in trainloader:
    images, labels = batch
    inferencer = Inferencer(2)
    generator = Generator(2)
in    
    print(loss)
    break

RuntimeError: Given groups=1, weight of size [1, 8, 3, 3], expected input[64, 1, 28, 28] to have 8 channels, but got 1 channels instead

In [None]:
vae = VAE(2)
vae(batch)