www.bigrabbitdata.com

In [201]:
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import datasets, transforms
from torchvision.utils import save_image

import os

## Load MNIST Data

In [202]:
# ToTensor() turns each image into a float tensor (scaled to [0, 1] per the
# torchvision docs), matching the target range binary cross-entropy needs.
transform = transforms.Compose([transforms.ToTensor()])

training_dataset = datasets.MNIST(root='./mnist', train=True,
                                  download=True, transform=transform)
validation_dataset = datasets.MNIST(root='./mnist', train=False,
                                    download=True, transform=transform)

training_loader = torch.utils.data.DataLoader(dataset=training_dataset,
                                              batch_size=100,
                                              shuffle=True)

# The validation set is NOT shuffled: the first validation batch is used to
# save reconstruction samples each epoch, and keeping it fixed means every
# epoch's comparison image shows the same digits, so progress is visible.
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                batch_size=100,
                                                shuffle=False)

# Create directory to save results (no error if it already exists)
result_dir = 'Intro18-VAE-Result'
os.makedirs(result_dir, exist_ok=True)

## Create our Variational Autoencoder Network Model

In [203]:
class VariationalAutoencoder(nn.Module):
    """Fully connected variational autoencoder (VAE).

    Encoder: flattened input -> ReLU hidden layer -> (mean, log-variance) of a
    diagonal Gaussian over the latent space.
    Decoder: latent vector -> ReLU hidden layer -> sigmoid output in [0, 1].

    Args:
        input_size: number of features per flattened input (784 for 28x28 MNIST).
        hidden_dim: width of the single hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, input_size, hidden_dim, latent_dim):
        super().__init__()
        self.input_size = input_size
        # Encoder layers
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder layers
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_size)

    def encode(self, x):
        """Map flattened inputs ``(batch, input_size)`` to ``(mean, log_var)``.

        Both outputs have shape ``(batch, latent_dim)``.
        """
        h = F.relu(self.fc1(x))
        mean = self.fc2_mean(h)
        log_var = self.fc2_logvar(h)
        return mean, log_var

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * std`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps ``z`` differentiable w.r.t. ``mean``
        and ``logvar`` (the reparameterization trick). This always samples,
        regardless of train/eval mode.
        """
        # logvar = log(sigma^2)  =>  sigma = exp(logvar / 2)
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z):
        """Map latent vectors ``(batch, latent_dim)`` to ``(batch, input_size)`` in [0, 1]."""
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: tensor whose elements can be grouped into rows of ``input_size``
               (e.g. ``(batch, 1, 28, 28)`` images).

        Returns:
            ``(reconstructed, mean, logvar)``: ``reconstructed`` has shape
            ``(batch, input_size)``; ``mean`` and ``logvar`` have shape
            ``(batch, latent_dim)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # transposed batches, are accepted instead of raising a RuntimeError.
        mean, logvar = self.encode(x.reshape(-1, self.input_size))
        z = self.reparameterize(mean, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mean, logvar

## Define Model

In [204]:
# Fix the RNG seed so weight initialisation (and later torch-RNG-dependent
# steps in this run) is reproducible under Restart & Run All.
torch.manual_seed(0)

INPUT_SIZE = 784   # 28 x 28 MNIST pixels, flattened
HIDDEN_DIM = 400   # width of the hidden layer in encoder and decoder
LATENT_DIM = 20    # dimensionality of the latent space

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = VariationalAutoencoder(INPUT_SIZE, HIDDEN_DIM, LATENT_DIM).to(device)

## Loss Function

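Loss = reconstruction loss (binary cross-entropy between the input and its reconstruction) + KL divergence between the encoder's Gaussian $\mathcal{N}(\mu, \sigma^2)$ and the standard-normal prior $\mathcal{N}(0, I)$:

$$\mathcal{L} = \mathrm{BCE}(\hat{x}, x) + \tfrac{1}{2}\sum\left(\sigma^2 + \mu^2 - 1 - \log\sigma^2\right)$$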

In [205]:
def loss_function(reconstructed_image, original_image, mean, logvar):
    """VAE loss (negative ELBO), summed over the batch.

    loss = BCE(reconstruction, input) + KL( N(mean, exp(logvar)) || N(0, I) )

    Args:
        reconstructed_image: decoder output with values in [0, 1],
            shape (batch, features).
        original_image: target images with the same number of elements as
            ``reconstructed_image`` (e.g. (batch, 1, 28, 28)), values in [0, 1].
        mean: encoder means, shape (batch, latent_dim).
        logvar: encoder log-variances, same shape as ``mean``.

    Returns:
        0-dim tensor: summed reconstruction loss plus summed KL divergence.
    """
    # Flatten the targets to the reconstruction's shape instead of hard-coding
    # 784, so the loss works for any input size.
    target = original_image.reshape(reconstructed_image.shape)
    bce = F.binary_cross_entropy(reconstructed_image, target, reduction='sum')
    # Closed-form KL divergence between a diagonal Gaussian and N(0, I):
    #   0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2))
    kld = 0.5 * torch.sum(logvar.exp() + mean.pow(2) - 1 - logvar)
    return bce + kld

## Train the encoder and decoder

In [206]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 10

# Main loop: each epoch does one pass over the training set followed by one
# pass over the validation set.
for e in range(1, epochs + 1):
    train_epoch_loss = 0.0
    val_epoch_loss = 0.0

    # --- Training ---
    model.train()
    for batch_idx, data in enumerate(training_loader, start=1):
        inputs = data[0].to(device)
        # The model returns the log-variance, not the variance.
        reconstruct_image, mean, logvar = model(inputs)
        loss = loss_function(reconstruct_image, inputs, mean, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_epoch_loss += loss.item()

        if batch_idx % 100 == 0:
            # loss is summed over the batch; divide by batch size for a per-sample value
            print("Train Epoch {} [Batch {}/{}]\tLoss: {:.3f}".format(
                e, batch_idx, len(training_loader), loss.item() / len(inputs)))

    # --- Validation ---
    model.eval()
    with torch.no_grad():
        for val_batch_idx, val_data in enumerate(validation_loader, start=1):
            val_inputs = val_data[0].to(device)
            val_reconstruct_image, val_mean, val_logvar = model(val_inputs)
            # .item() keeps the accumulator a Python float rather than a device tensor
            val_epoch_loss += loss_function(val_reconstruct_image,
                                            val_inputs,
                                            val_mean,
                                            val_logvar).item()

            if val_batch_idx == 1:
                # Save 5 originals (top row) next to their reconstructions (bottom row)
                comparison = torch.cat([val_inputs[:5],
                                        val_reconstruct_image.view_as(val_inputs)[:5]])
                save_image(comparison.cpu(),
                           os.path.join(result_dir, 'reconstruction_' + str(e) + '.png'),
                           nrow=5)

    print('=====> Epoch {}, Average Loss: {:.3f}'.format(e, train_epoch_loss / len(training_loader.dataset)))
    print('Validation set: Average Loss: {:.6f}'.format(val_epoch_loss / len(validation_loader.dataset)))

Train Epoch 1 [Batch 100/600]	Loss: 173.417
Train Epoch 1 [Batch 200/600]	Loss: 157.241
Train Epoch 1 [Batch 300/600]	Loss: 144.682
Train Epoch 1 [Batch 400/600]	Loss: 139.074
Train Epoch 1 [Batch 500/600]	Loss: 128.661
Train Epoch 1 [Batch 600/600]	Loss: 129.323
=====> Epoch 1, Average Loss: 156.280
Validation set: Average Loss: 123.136131
Train Epoch 2 [Batch 100/600]	Loss: 125.259
Train Epoch 2 [Batch 200/600]	Loss: 116.801
Train Epoch 2 [Batch 300/600]	Loss: 117.336
Train Epoch 2 [Batch 400/600]	Loss: 120.858
Train Epoch 2 [Batch 500/600]	Loss: 114.732
Train Epoch 2 [Batch 600/600]	Loss: 113.585
=====> Epoch 2, Average Loss: 118.724
Validation set: Average Loss: 113.701401
Train Epoch 3 [Batch 100/600]	Loss: 114.495
Train Epoch 3 [Batch 200/600]	Loss: 114.459
Train Epoch 3 [Batch 300/600]	Loss: 119.679
Train Epoch 3 [Batch 400/600]	Loss: 110.759
Train Epoch 3 [Batch 500/600]	Loss: 110.121
Train Epoch 3 [Batch 600/600]	Loss: 113.252
=====> Epoch 3, Average Loss: 113.060
Validation s

## Test the decoder: generate new digits from random latent vectors

In [208]:
# Draw 100 latent vectors from the standard-normal prior N(0, I) and decode
# them into new images. Sampling needs no gradients, so run the decoder in
# eval mode under no_grad (no autograd graph is built).
model.eval()
with torch.no_grad():
    samples_latent_space = torch.randn(100, 20, device=device)
    generated_samples = model.decode(samples_latent_space).cpu()
save_image(generated_samples.view(100, 1, 28, 28),
           os.path.join(result_dir, 'generated_samples.png'),
           nrow=10)