# Variational Autoencoder 

In [1]:
import os
import sys
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision.utils import save_image
from torchvision.transforms import transforms


print(f'numpy version is : {np.__version__}')
print(f'torch version is : {torch.__version__}')

numpy version is : 1.21.5
torch version is : 1.11.0+cpu


In [2]:
# Let's defin the save directory
sample_dir = 'variational-auto'
if not os.path.exists(os.path.join(os.getcwd(), sample_dir)):
    sample_dir = os.path.join(os.getcwd(), sample_dir)
    os.makedirs(sample_dir)


In [4]:
# Define the hyper-parameters
z_dim = 20
epochs = 10
h_dim = 400
batch_size = 128
input_size = 28 * 28
learning_rate = 0.01
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# Get the dataste 
train_dataset = datasets.MNIST(
    root='../basics/mnist/',
    train=True,
    transform=transforms.ToTensor(),
    download=False
)

test_dataset = datasets.MNIST(
    root='../basics/mnist/',
    train=False,
    transform=transforms.ToTensor(),
    download=False
)

In [6]:
# Let's define the data loading pipeline
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [11]:
# Let's define the variational auto-encoder
from distutils.log import log


class VAE(nn.Module):
    # Define the parameters
    def __init__(self, image_size=28*28, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        # Let's define the layers
        self.fc1 = nn.Linear(in_features=image_size, out_features=h_dim)
        self.fc2 = nn.Linear(in_features=h_dim, out_features=z_dim)
        self.fc3 = nn.Linear(in_features=z_dim, out_features=z_dim)
        self.fc4 = nn.Linear(in_features=z_dim, out_features=h_dim)
        self.fc5 = nn.Linear(in_features=h_dim, out_features=image_size)

    def encode(self, x):
        h = torch.nn.functional.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.rand_like(std)
        return mu + eps * std
    def decode(self, z):
        h = torch.nn.functional.relu(self.fc4(z))
        return torch.nn.functional.sigmoid(self.fc5(h))

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

# Let's define the model 
model = VAE().to(device=device)

In [9]:
# Let's define the optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
for epoch in tqdm(range(epochs)):
    for i, (x, _) in enumerate(train_loader):
        # Forward pass
        x = x.to(device).view(-1, input_size)
        x_reconst, mu, log_var = model(x)
        
        # Compute reconstruction loss and kl divergence
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        reconst_loss = torch.nn.functional.binary_cross_entropy(x_reconst, x, size_average=False)
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, epochs, i+1, len(train_loader), reconst_loss.item(), kl_div.item()))
    
    with torch.no_grad():
        # Save the sampled images
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

  0%|          | 0/10 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x400 and 20x20)