# 4. Variational AutoEncoder Demo (FC)

### Imports

In [1]:
# Imports
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision.utils import save_image
import matplotlib.pyplot as plt

In [2]:
# CUDA check
CUDA = True
device = "cuda" if (torch.cuda.is_available() and CUDA) else "cpu"
print(torch.cuda.is_available())
print(device)

True
cuda


### Dataset and dataloaders

In [3]:
# Data Preprocessing
# - ToTensor
# - Image Normalization
transform = transforms.Compose([transforms.ToTensor()])

In [4]:
# Train datasets/dataloaders
train_set = torchvision.datasets.MNIST(root='./data', \
                                       train = True, \
                                       download = True, \
                                       transform = transform)
train_loader = torch.utils.data.DataLoader(train_set, \
                                           batch_size = 32, \
                                           shuffle = False)

In [5]:
# Test datasets/dataloaders
test_set = torchvision.datasets.MNIST(root = './data', \
                                      train = False, \
                                      download = True, \
                                      transform = transform)
test_loader = torch.utils.data.DataLoader(test_set, \
                                          batch_size = 4, \
                                          shuffle = False)

### Model

Conv2d and Conv2dTranspose layers

In [6]:
# Define Variational AutoEncoder Model for MNIST
class MNIST_VAE(nn.Module):
    
    def __init__(self, image_channels, init_channels, kernel_size, latent_dim):
        super().__init__()
 
        # Encoder with stacked Conv
        self.enc1 = nn.Conv2d(image_channels, init_channels, kernel_size, \
                              stride = 2, padding = 1)
        self.enc2 = nn.Conv2d(init_channels, init_channels*2, kernel_size, \
                              stride = 2, padding = 1)
        self.enc3 = nn.Conv2d(init_channels*2, init_channels*4, kernel_size, \
                              stride = 2, padding = 1)
        self.enc4 = nn.Conv2d(init_channels*4, 64, kernel_size, \
                              stride = 2, padding = 0)
        
        # FC layers for learning representations
        self.fc1 = nn.Linear(64, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_log_var = nn.Linear(128, latent_dim)
        self.fc2 = nn.Linear(latent_dim, 64)
        
        # Decoder, simply mirroring the encoder with ConvTranspose
        self.dec1 = nn.ConvTranspose2d(64, init_channels*8, kernel_size, \
                                       stride = 1, padding = 0)
        self.dec2 = nn.ConvTranspose2d(init_channels*8, init_channels*4, kernel_size, \
                                       stride = 2, padding = 1)
        self.dec3 = nn.ConvTranspose2d(init_channels*4, init_channels*2, kernel_size, \
                                       stride = 2, padding = 1)
        self.dec4 = nn.ConvTranspose2d(init_channels*2, image_channels, kernel_size, \
                                       stride = 2, padding = 1)
        
        
    def sample(self, mu, log_var):
        """
        mu: mean from the encoder's latent space
        log_var: log variance from the encoder's latent space
        """
        
        # Standard deviation
        std = torch.exp(0.5*log_var)
        
        # randn_like is used to produce a vector with same dimensionality as std
        eps = torch.randn_like(std)
        
        # Sampling
        sample = mu + (eps * std)
        return sample
    
    
    def forward(self, x):
        
        # Encoder
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        
        # Pooling
        batch, _, _, _ = x.shape
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch, -1)
        
        # FC layers to get mu and log_var
        hidden = self.fc1(x)
        mu = self.fc_mu(hidden)
        log_var = self.fc_log_var(hidden)
        
        # Get the latent vector through reparameterization
        z = self.sample(mu, log_var)
        z = self.fc2(z)
        z = z.view(-1, 64, 1, 1)
 
        # Decoding
        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = torch.sigmoid(self.dec4(x))
        return x, mu, log_var

### Train model

In [7]:
# Defining Parameters (part 1 - model parameters)
# - 4x4 kernel
# - 8 filters
# - image channels is set to 1, because MNIST is greayscale
# - latent dimension for sampling set to 16 (arbitrarily)
kernel_size = 4
init_channels = 8
image_channels = 1
latent_dim = 16

In [8]:
# Initialize MNIST Autoencoder
torch.manual_seed(10)
model = MNIST_VAE(image_channels, init_channels, kernel_size, latent_dim).to(device)

In [9]:
# Defining Parameters
# - BCE Loss, which will be our reconstruction loss for now
# - Adam as optimizer, default parameters
# - 100 Epochs
# - 64 as batch size
num_epochs = 100
batch_size = 64
optimizer = optim.Adam(model.parameters(), lr = 0.001)
criterion = nn.MSELoss(reduction = 'sum')

In [10]:
def final_loss(mse_loss, mu, logvar, alpha = 1):
    """
    This function will add the reconstruction loss (MSELoss) and the KL-Divergence.
    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    - param mse_loss: recontruction loss
    - param mu: the mean from the latent vector
    - param logvar: log variance from the latent vector
    - alpha: scaling parameter
    """
    
    MSE = bce_loss 
    KLD = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + alpha*KLD

In [11]:
# Train
loss_list = []
for epoch in range(num_epochs):
    for data in train_loader:
        
        # Send data to device
        img, _ = data
        img = Variable(img).to(device)
        
        # Forward pass
        # Gets output, mu and logvar values
        output, mu, logvar = model(img)
        
        # Twofold loss
        reco_loss = criterion(output, img)
        loss = final_loss(reco_loss, mu, logvar)
        
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # Display
    print('epoch {}/{}, loss {:.4f}'.format(epoch + 1, num_epochs, loss.item()))
    loss_list.append(loss.item())

RuntimeError: Calculated padded input size per channel: (3 x 3). Kernel size: (4 x 4). Kernel size can't be greater than actual input size

In [None]:
# Display loss
plt.figure()
plt.plot(loss_list)
plt.show()