Motivated by the following posts

https://vxlabs.com/2017/12/08/variational-autoencoder-in-pytorch-commented-and-annotated/#a-simple-vae-implemented-using-pytorch

https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/variational_autoencoder/main.py

In [1]:
from typing import Tuple
import numpy as np
import os
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data
from torchvision import datasets, transforms
from torchvision.utils import save_image
import time

In [10]:
class VAE(nn.Module):
    """
    Variational Auto Encoder made of fully connected layers.

    The encoder turns an input vector into the mean and log-variance of a
    Gaussian over the latent space; the decoder turns a latent sample back
    into a vector of the input size, squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, input_dim=28*28, hidden_dim=400, latent_dim=20):
        """
        @param input_dim: dimension of the (flattened) input; the default matches 28x28 MNIST images
        @param hidden_dim: dimension of the hidden fully connected layer used by both encoder and decoder
        @param latent_dim: dimension of the latent space, i.e., dimension of the mean and
                           log-variance vectors of the underlying Gaussian
        """
        super().__init__()

        # Encoder: one shared hidden layer feeding two separate heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)  # head producing the mean
        self.fc3 = nn.Linear(hidden_dim, latent_dim)  # head producing the log-variance

        # Decoder: latent -> hidden -> input-sized output.
        self.fc4 = nn.Linear(latent_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Map an input batch to the parameters of the latent Gaussian.

        @param x: input batch whose last dimension is input_dim
        @return: tuple (mu, log_var), each with last dimension latent_dim
        """
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden), self.fc3(hidden)

    def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        Draw a latent sample as mu + eps * std with eps ~ N(0, I).

        Writing the sample this way keeps it differentiable with respect to
        mu and log_var, since the randomness lives only in eps.

        @param mu: mean of the latent Gaussian
        @param log_var: log-variance of the latent Gaussian
        @return: sampled latent tensor with the same shape as mu
        """
        # std = sqrt(var) = exp(0.5 * log(var))
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Map latent vectors back to input space.

        @param z: latent batch whose last dimension is latent_dim
        @return: reconstruction with last dimension input_dim and values in (0, 1)
        """
        hidden = F.relu(self.fc4(z))
        logits = self.fc5(hidden)
        return torch.sigmoid(logits)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run a full encode -> sample -> decode pass.

        @param x: input batch whose last dimension is input_dim
        @return: tuple (x_tilde, mu, log_var): the reconstruction plus the
                 latent Gaussian parameters it was sampled from
        """
        mu, log_var = self.encode(x)
        latent = self.reparametrize(mu, log_var)
        return self.decode(latent), mu, log_var
        

In [24]:
# Hardware selection: run on the GPU when one is visible, otherwise on the CPU.
CUDA = torch.cuda.is_available()
device = torch.device('cuda' if CUDA else 'cpu')

# Run configuration.
SEED = 1
BATCH_SIZE = 128
LOG_INTERVAL = 10
EPOCHS = 30
ZDIMS = 20
learning_rate = 0.003 #1e-3

# Seed the CPU generator (and the CUDA one when in use).
torch.manual_seed(SEED)
if CUDA:
    torch.cuda.manual_seed(SEED)

# Extra DataLoader arguments: a worker process and pinned host memory are
# only requested when CUDA is in use.
if CUDA:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

In [25]:
# Both MNIST splits are read from the same directory under the user's home
# and converted to tensors by transforms.ToTensor().
mnist_root = os.path.expanduser('~/ml_datasets/')

# Training split: fetched on first use if it is not already on disk.
train_set = datasets.MNIST(
    mnist_root,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    **kwargs
)

# Test split: no download flag is passed here — presumably it relies on the
# files already fetched for the training split above.
test_set = datasets.MNIST(
    mnist_root,
    train=False,
    transform=transforms.ToTensor(),
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    **kwargs
)


In [26]:
# Build the network and place it on the same device the training/evaluation
# loops move every batch to; leaving it on the CPU makes the first forward
# pass fail with a device mismatch on a CUDA machine.
# The move happens before the optimizer is created so that Adam tracks the
# parameters that are actually trained.
model = VAE(input_dim=28*28, hidden_dim=400, latent_dim=20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [27]:
def loss_func(x: torch.Tensor, x_tilde: torch.Tensor, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """
    VAE loss for one batch: reconstruction error plus KL divergence, both
    summed (not averaged) over every element of the batch.

    @param x: target values, expected in [0, 1] as required by binary cross entropy
    @param x_tilde: reconstructed values in [0, 1], same shape as x
    @param mu: mean of the latent Gaussian
    @param log_var: log-variance of the latent Gaussian
    @return: scalar tensor holding the total loss
    """
    # Reconstruction term: element-wise binary cross entropy, summed.
    bce = F.binary_cross_entropy(x_tilde, x, reduction='sum')

    # KL divergence between N(mu, exp(log_var)) and the standard normal:
    #   -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
    kl_terms = 1.0 + log_var - mu.pow(2) - log_var.exp()
    kld = -0.5 * kl_terms.sum()

    return bce + kld
   

In [28]:
def train(epoch_idx) -> float:
    """
    Run one optimisation pass over the whole training set.

    Uses the module-level model, optimizer, train_loader, device and loss_func.

    @param epoch_idx: index of the current epoch (accepted for symmetry with
                      test(); not used inside this function)
    @return: summed batch losses divided by the number of training samples
    """
    model.train()
    running_total = 0.0
    # Labels are ignored: the autoencoder reconstructs its own input.
    for images, _labels in train_loader:
        # Flatten each 28x28 image into a 784-vector and move it to the device.
        flat = images.view(-1, 28 * 28).to(device)
        recon, mu, log_var = model(flat)
        batch_loss = loss_func(flat, recon, mu, log_var)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_total += batch_loss.item()

    # loss_func sums over the batch, so divide by the sample count to get a
    # per-sample average.
    return running_total / len(train_loader.dataset)
                

In [29]:
def test(epoch_idx) -> float:
    """
    Evaluate the model on the test set and save a sample of reconstructions.

    Uses the module-level model, test_loader, device and loss_func. For the
    first batch, up to 8 original images and their reconstructions are
    written to ./mnist_vae_results/reconstruction_<epoch_idx>.png.

    @param epoch_idx: index of the current epoch, used in the image file name
    @return: summed batch losses divided by the number of test samples
    """
    model.eval()

    # save_image does not create missing directories, so make sure the
    # output folder exists before the first write.
    out_dir = './mnist_vae_results'
    os.makedirs(out_dir, exist_ok=True)

    test_loss = 0.0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time without changing the loss values.
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            x = data.view(-1, 28 * 28).to(device)
            x_tilde, mu, log_var = model(x)
            test_loss += loss_func(x, x_tilde, mu, log_var).item()

            if i == 0:
                n = min(data.size(0), 8)
                # `data` is still on the CPU while `x_tilde` lives on `device`;
                # bring the reconstructions back to the CPU before
                # concatenating, otherwise torch.cat fails on a CUDA run.
                reconstructions = x_tilde.view(-1, 1, 28, 28)[:n].cpu()
                comparison = torch.cat([data[:n], reconstructions])
                save_image(comparison,
                           os.path.join(out_dir, 'reconstruction_' + str(epoch_idx) + '.png'),
                           nrow=8)

    test_loss /= len(test_loader.dataset)
    return test_loss


In [31]:
# Per-epoch loss history, kept for later inspection.
train_losses, test_losses = [], []

for epoch_idx in range(EPOCHS):

    # Training pass, timed with wall-clock time.
    start_time = time.time()
    train_loss = train(epoch_idx)
    train_losses.append(train_loss)
    train_time = time.time() - start_time

    # Evaluation pass (also writes the reconstruction image), timed the same way.
    start_time = time.time()
    test_loss = test(epoch_idx)
    test_losses.append(test_loss)
    test_time = time.time() - start_time

    print(
        'Epoch {}, train loss: {:.4f}, train time: {:.2f}; test loss: {:.4f}, test time: {:.2f}'.format(
            epoch_idx, train_loss, train_time, test_loss, test_time
        )
    )
            

Epoch 0, train loss: 144.6460, train time: 13.67; test loss: 117.9572, test time: 1.24
Epoch 1, train loss: 114.8587, train time: 14.42; test loss: 111.6151, test time: 1.24
Epoch 2, train loss: 111.0970, train time: 15.64; test loss: 109.6300, test time: 1.25
Epoch 3, train loss: 109.4085, train time: 15.92; test loss: 108.4468, test time: 1.26
Epoch 4, train loss: 108.3950, train time: 15.85; test loss: 107.3400, test time: 1.24
Epoch 5, train loss: 107.6463, train time: 16.12; test loss: 106.8236, test time: 1.33
Epoch 6, train loss: 107.1626, train time: 15.91; test loss: 106.2297, test time: 1.34
Epoch 7, train loss: 106.6156, train time: 16.06; test loss: 106.0287, test time: 1.31
Epoch 8, train loss: 106.3203, train time: 16.02; test loss: 105.8030, test time: 1.33
Epoch 9, train loss: 105.9918, train time: 15.97; test loss: 105.7344, test time: 1.31
Epoch 10, train loss: 105.7012, train time: 16.12; test loss: 105.3562, test time: 1.34
Epoch 11, train loss: 105.4922, train time

In [30]:
# Sanity check: number of mini-batches vs. number of training samples.
n_batches, n_samples = len(train_loader), len(train_loader.dataset)
print(n_batches, n_samples)

# Pull a single mini-batch and show how many images / labels it contains.
x, y = next(iter(train_loader))
len(x), len(y)

469 60000


(128, 128)