In [23]:
import os
import torch
from torch import nn, optim
from torch.nn import functional as F
import torch.utils.data
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Get filepaths

In [24]:
# Home Directory
HOME_DIR = os.getcwd().split('/notebooks')[0]

# Data filepaths
DATA_DIR = os.path.join(HOME_DIR, 'data')

RESULTS_DIR = os.path.join(HOME_DIR, 'images/vae_kingma')

# Get Data

In [25]:
use_cuda = False
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

train_loader = torch.utils.data.DataLoader(
    dataset = datasets.MNIST(
        DATA_DIR, 
        train = True,
        download = True,
        transform = transforms.ToTensor() # can this not be in function call version ?
    ),
    batch_size = batch_size,
    shuffle = True,
    **kwargs
)


test_loader = torch.utils.data.DataLoader(
   dataset = datasets.MNIST(
        DATA_DIR, 
        train = False,
        transform = transforms.ToTensor()),
   batch_size = batch_size,
   shuffle = True,
   **kwargs
)

# Model Class

In [26]:
class VAE(nn.Module):

    def __init__(self):
        """ Want to model the "encoder" & "decoder" with NNs.
          The approximate posterior is N(mu,sigma^2) where mu & sigma are given by NNs.
        """
    
        super(VAE, self).__init__()
  
        # Encoder Layers (MNIST to Latent)
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
  
        # Decoder Layers (Latent to "MNSIT")
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
  
  
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
    
        mu = self.fc21(h1)
        logvar = self.fc22(h1) # Why was it 
    
        return mu, logvar
    
  
    def decode(self, z):
    
        h3 = F.relu(self.fc3(z))
        decoded_pixel_probabilities = torch.sigmoid(self.fc4(h3))
    
        return decoded_pixel_probabilities
    
    
    def reparametrization(self, mu, logvar):
        """Reparametarize the unobserved latent variable z with epsilon"""
    
        sigma = torch.exp(0.5*logvar)
        eps = torch.randn_like(sigma)
    
        return mu + sigma*eps
  
  
    def forward(self, x):
    
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparametrization(mu, logvar)
    
        decoded_pixel_probabilities = self.decode(z)
    
        return decoded_pixel_probabilities, mu, logvar
    
    

# Loss Function

In [27]:
def loss_function(reconstructed_pixel_probabilities, original_images, mu, logvar):
      """Note: From Section 3 of Kignma et al VAE paper."""
  
      kld = - 0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) # is the use of logvar here, why it's assumed to be logvar before ?
      bce = F.binary_cross_entropy(reconstructed_pixel_probabilities, original_images.view(-1, 784), reduction = 'sum')
  
      loss = kld + bce 
      # Note that both kdl & bce are negative. 
      ## In paper want to maximize the "loss" so they are positive but we optimize by finding the minimum in pytorch.
  
      return loss

# Train & Test functions

In [28]:
device = torch.device("cpu")

model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

log_interval = 10


def train(epoch):
    model.train() # sets the model in training mode.
    train_loss = 0
  
    for batch_index, (original_images, _) in enumerate(train_loader):
        # Add .to(device) for the when CUDA is used
        original_images = original_images.to(device) 
    
        # Clear gradients since will iteratively update weights based on gradient at different data points.
        optimizer.zero_grad() 
    
        reconstructed_pixel_probs, mu, logvar = model(original_images)
        loss = loss_function(reconstructed_pixel_probs, original_images, mu, logvar)
    
    
        # calculate the gradient and update the weights
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    
        # print training updates
        if batch_index % log_interval == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_index * len(original_images), len(train_loader.dataset),
                    100. * batch_index / len(train_loader), loss.item() / len(original_images)
                )
            )
  
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
  
    # Disable gradient calculation since won't be any updating weights (& calling Tensor.backward()).
    with torch.no_grad(): 
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            reconstructed_pixel_probs, mu, logvar = model(data)
            test_loss += loss_function(reconstructed_pixel_probs, data, mu, logvar).item()
      
            # print test updates
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat(
                    [data[:n], reconstructed_pixel_probs.view(batch_size, 1, 28, 28)[:n]]
                )
                save_image(comparison.cpu(), RESULTS_DIR + '/reconstruction_' + str(epoch) + '.png', nrow=n)
  
        # print average test loss
        test_loss /= len(test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))

# Train \& Test model

In [22]:
epochs = 1

for epoch in range(epochs + 1):
    train(epoch)
    test(epoch)
  
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),  RESULTS_DIR + '/sample_' + str(epoch) + '.png')

====> Epoch: 0 Average loss: 164.4333
====> Test set loss: 127.5191
====> Epoch: 1 Average loss: 121.6843
====> Test set loss: 115.9679
