# VAE on MNIST

In [35]:
'''
Loading necessary libraries
'''
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [36]:
'''
Setup parameters
'''
# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training hyperparameters.
n_epochs = 10      # number of full passes over the training loader
n_classes = 10     # NOTE(review): not referenced in the visible code
batch_size = 100   # samples per mini-batch (also the number of images sampled per epoch)
lr = 1e-3          # learning rate handed to the optimizer

# Model dimensions.
image_size = 784   # flattened 28x28 image
h_dim = 400        # hidden layer width
z_dim = 20         # latent code size

# Directory that receives the images written after every epoch.
# exist_ok=True makes the call idempotent and avoids the check-then-create
# race of testing os.path.exists() first.
imgs_dir = 'vae_fake_images'
os.makedirs(imgs_dir, exist_ok=True)

In [37]:
'''
Loading MNIST dataset
'''

# Training split of MNIST, downloaded to ./data/mnist on first use.
# ToTensor() converts each image to a float tensor.
train_dataset = torchvision.datasets.MNIST(
    './data/mnist',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batch iterator over the training split, backed by 12 worker
# processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=12,
)

In [48]:
'''
Define model class
'''
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code. The decoder maps a latent code
    back to input space and squashes the result with a sigmoid.

    Args:
        image_size: Number of features in a flattened input sample.
        h_dim: Width of the hidden layer used by both encoder and decoder.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, image_size: int, h_dim: int = 200, z_dim: int = 10):
        super().__init__()
        # Encoder: input -> hidden -> (mu, log_var).
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)  # mean head
        self.fc3 = nn.Linear(h_dim, z_dim)  # log-variance head
        # Decoder: latent -> hidden -> input space.
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x: torch.Tensor):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        h = self.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mu + eps * std

    def raparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Backward-compatible alias for the misspelled original method name."""
        return self.reparameterize(mu, log_var)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent codes ``z`` to sigmoid-activated reconstructions."""
        h = self.relu(self.fc4(z))
        return self.sigmoid(self.fc5(h))

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            A tuple ``(x_reconstruct, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconstruct = self.decode(z)
        return x_reconstruct, mu, log_var
    
# Build the VAE with the configured layer sizes, then move its parameters to
# the selected device (Module.to returns the module itself).
model = VAE(image_size, h_dim=h_dim, z_dim=z_dim)
model = model.to(device)

In [54]:
'''
Optimizer and Loss function
'''
# Binary cross-entropy summed over every element of the batch (no averaging).
loss_fn = nn.BCELoss(reduction='sum')
# Adam over all trainable parameters of the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [55]:
'''
Train the model
'''
# Number of mini-batches per epoch; only used for progress reporting.
total_steps = len(train_loader)

for epoch in range(n_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Flatten every image to a vector of length image_size; labels are unused.
        images = images.to(device).view(-1, image_size)

        # forward
        imgs_reconstruct, mu, log_var = model(images)

        # Loss = reconstruction term + KL divergence between the encoder's
        # diagonal Gaussian and a standard normal (closed form), both summed
        # over the batch.
        reconst_loss = loss_fn(imgs_reconstruct, images)
        KL_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        loss = reconst_loss + KL_div

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 100 steps.
        if (i+1)%100==0:
            print('Epoch [{}/{}]-[{}/{}] Reconstruct loss {:.4f}, KL: {:.4f}'
                 .format(epoch+1, n_epochs, i+1, total_steps, reconst_loss.item(), KL_div.item()))

    # Write this epoch's images; no gradients are needed here.
    with torch.no_grad():
        # Samples: decode latent codes drawn from the standard normal prior.
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1,1,28,28)
        # Use imgs_dir (not a hard-coded path) so both outputs land in the
        # same, configurable directory.
        save_image(out, os.path.join(imgs_dir, 'sample-{}.png'.format(epoch+1)))

        # Reconstruction of the last training batch of the epoch, placed to
        # the right of the corresponding inputs (concatenated along width).
        out, _,_ = model(images)
        x_concat = torch.cat([images.view(-1,1,28,28), out.view(-1,1,28,28)], dim=3)
        # Save the side-by-side comparison; previously only `out` was written
        # and x_concat was computed but never used.
        save_image(x_concat, os.path.join(imgs_dir, 'reconst-{}.png'.format(epoch+1)))
        

Epoch [1/10]-[100/600] Reconstruct loss 16755.3906, KL: 583.6683
Epoch [1/10]-[200/600] Reconstruct loss 17380.3789, KL: 630.0624
Epoch [1/10]-[300/600] Reconstruct loss 15849.9414, KL: 720.6468
Epoch [1/10]-[400/600] Reconstruct loss 16990.7754, KL: 752.9923
Epoch [1/10]-[500/600] Reconstruct loss 16223.6152, KL: 690.3672
Epoch [1/10]-[600/600] Reconstruct loss 15819.6104, KL: 657.9927
Epoch [2/10]-[100/600] Reconstruct loss 16828.2812, KL: 652.1355
Epoch [2/10]-[200/600] Reconstruct loss 14084.6494, KL: 673.8443
Epoch [2/10]-[300/600] Reconstruct loss 15514.7969, KL: 682.3788
Epoch [2/10]-[400/600] Reconstruct loss 15129.8174, KL: 682.4009
Epoch [2/10]-[500/600] Reconstruct loss 14471.8418, KL: 707.8922
Epoch [2/10]-[600/600] Reconstruct loss 15656.1328, KL: 717.9883
Epoch [3/10]-[100/600] Reconstruct loss 15907.7871, KL: 681.7916
Epoch [3/10]-[200/600] Reconstruct loss 15301.7539, KL: 689.6980
Epoch [3/10]-[300/600] Reconstruct loss 15304.9805, KL: 662.9360
Epoch [3/10]-[400/600] Re