<div style="text-align: center; font-size: 30pt; font-weight: bold; margin: 1em 0em 1em 0em">Wasserstein Auto-Encoders</div>

In [5]:
import sys
import os

In [6]:
# Make the sibling `autoencoders` directory importable (presumably where the
# local `wasserstein` module lives — confirm). The membership check keeps
# re-runs of this cell from appending duplicate entries to `sys.path`.
autoencoders_dir = os.path.abspath('../autoencoders')
if autoencoders_dir not in sys.path:
    sys.path.append(autoencoders_dir)

In [7]:
# "Magic" commands for automatic reloading of module, perfect for prototyping
# Load the autoreload extension (or reload it if it is already loaded).
%reload_ext autoreload
# Mode 2: re-import all modules before each cell is executed.
%autoreload 2

# Local module providing the WassersteinAutoEncoder class used below.
import wasserstein

In [58]:
import numpy as np

In [59]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

# The dataset

In [12]:
# Directory holding the MNIST files; they are downloaded there if missing.
root = '../data'

# The only preprocessing step is the conversion of each image to a tensor
# (no resizing and no normalisation).
trans = transforms.Compose([transforms.ToTensor()])

# Training and test splits of MNIST, sharing the same root and transform.
train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)

# Number of images per mini-batch, for both loaders.
batch_size = 100

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

# Report how many mini-batches each loader yields.
print('>> total trainning batch number: {}'.format(len(train_loader)))
print('>> total testing batch number: {}'.format(len(test_loader)))

>> total trainning batch number: 600
>> total testing batch number: 100


In [14]:
# Grab a single mini-batch of images to inspect. Pulling one item from the
# loader's iterator avoids loading and transforming every batch of the
# training set just to keep the first one.
data = next(iter(train_loader))[0]

In [16]:
# Dimensions of the inspected mini-batch tensor.
data.shape

torch.Size([100, 1, 28, 28])

# Wasserstein Auto-Encoder

In [10]:
# NOTE(review): the meaning of `ksi` is defined in the `wasserstein` module,
# which is not shown here — presumably a weight used by the loss; confirm.
ksi = 10
wae = wasserstein.WassersteinAutoEncoder(ksi=ksi)

In [49]:
# IPython introspection magic: show the source code of the model class.
%psource wasserstein.WassersteinAutoEncoder

In [9]:
# Print the model summary, i.e. the layers registered on the module.
print(wae)

WassersteinAutoEncoder(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=12, bias=True)
  (fc4): Linear(in_features=12, out_features=2, bias=True)
  (fc5): Linear(in_features=2, out_features=12, bias=True)
  (fc6): Linear(in_features=12, out_features=50, bias=True)
  (fc7): Linear(in_features=50, out_features=128, bias=True)
  (fc8): Linear(in_features=128, out_features=784, bias=True)
)


# Training procedure

## TensorBoard

First, we create a `SummaryWriter` instance (in order to use tensorboard):

In [78]:
from tensorboardX import SummaryWriter

# Directory handed to the writer for this run's TensorBoard logs.
log_dir = 'wae/experience1'
writer = SummaryWriter(log_dir)

To visualize the model graph in TensorBoard, we run the next cell:

In [79]:
# Trace the model on one random MNIST-sized input so that TensorBoard can
# display its graph. A plain tensor is enough: wrapping it in
# `torch.autograd.Variable` is deprecated and has been a no-op since
# PyTorch 0.4. The input is created on the device that holds the model's
# parameters, so the cell also works once the model lives on a GPU.
dummy_input = torch.rand(1, 1, 28, 28, device=next(wae.parameters()).device)
writer.add_graph(wae, dummy_input)

## Training and testing

We select the device used for training and testing (a GPU if one is available, otherwise the CPU):

In [63]:
# Prefer a GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The mini-batches are moved to `device` during training and testing, so the
# model's parameters must live there too; otherwise the forward pass fails
# with a device mismatch as soon as CUDA is available.
wae = wae.to(device)

We define the learning rate:

In [64]:
# Step size handed to the optimizer.
learning_rate = 0.001

And the optimizer:

In [65]:
# Adam over all model parameters, with a small L2 penalty (weight decay).
optimizer = torch.optim.Adam(
    params=wae.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)

The `train()` function is defined as:

In [70]:
def train(epoch):
    """Run one training epoch over `train_loader` and log the average loss.

    Relies on the notebook-level `wae`, `optimizer`, `train_loader`,
    `device` and `writer`.

    Args:
        epoch: epoch index, used in the printed message and as the
            TensorBoard step of the logged scalar.
    """
    wae.train()
    running_loss = 0

    for batch, _ in train_loader:

        # Move the mini-batch to the selected device (matters on a GPU).
        batch = batch.to(device)

        # Clear the gradients accumulated during the previous step.
        optimizer.zero_grad()

        # Forward pass: reconstruction of the batch and its latent encoding.
        reconstruction, code = wae(batch)
        loss = wae.loss(x_tilde=reconstruction, x=batch, z=code)

        # Backward pass, then parameter update.
        loss.backward()
        optimizer.step()

        # Accumulate the scalar loss of this mini-batch.
        running_loss += loss.item()

    # Accumulated loss divided by the number of training samples.
    average_loss = running_loss / len(train_loader.dataset)

    print('>> Epoch: {} Average loss: {:.4f}'.format(epoch, average_loss))

    writer.add_scalar('data/train-loss', average_loss, epoch)

The `test()` function is defined as:

In [74]:
def test(epoch):
    """Evaluate the model on `test_loader` and log the results.

    Logs the average test loss and, for the first mini-batch, an image grid
    comparing up to 8 inputs (top row) with their reconstructions (bottom
    row). Relies on the notebook-level `wae`, `test_loader`, `device` and
    `writer`.

    Args:
        epoch: epoch index, used as the TensorBoard step.
    """
    wae.eval()
    test_loss = 0

    # No gradients are needed during evaluation, hence the no_grad() context.
    with torch.no_grad():

        for i, (data, _) in enumerate(test_loader):

            data = data.to(device)
            x_tilde, z = wae(data)

            test_loss += wae.loss(x_tilde=x_tilde, x=data, z=z).item()

            if i == 0:
                n = min(data.size(0), 8)

                # Reshape the reconstruction to the input's own shape rather
                # than a hard-coded (100, 1, 28, 28), so that changing
                # `batch_size` does not break the comparison.
                reconstruction = x_tilde.view(data.size())
                comparison = torch.cat([data[:n], reconstruction[:n]])

                # Tile the batch into a single image before logging, as is
                # done for the decoded samples, instead of handing a raw
                # 4-D batch to `add_image`.
                writer.add_image(
                    'reconstruction',
                    utils.make_grid(comparison.cpu(), nrow=n),
                    epoch
                )

    # Accumulated loss divided by the number of test samples.
    test_loss /= len(test_loader.dataset)
    print('>> Test set loss: {:.4f}'.format(test_loss))

    writer.add_scalar('data/test-loss', test_loss, epoch)

## Sampling space

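We want to sample from the latent space. To do so, we build a regular 16 × 16 grid of points covering $[-1.5, 1.5]^2$ in the two-dimensional latent space; these points are decoded into images after each epoch.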

In [75]:
# Coordinates of the grid along one latent axis.
grid_axis = np.linspace(-1.5, 1.5, 16)

# Regular 16 x 16 grid of 2-D latent points, ordered row by row: `y` is the
# outer (slow) index and `x` the inner (fast) one.
# `np.float` was only an alias of the builtin `float` and was removed in
# NumPy 1.24; `np.float64` is the same dtype and works on every version.
space = np.array(
    [
        [x, y]
        for y in grid_axis
        for x in grid_axis
    ],
    dtype=np.float64
)

# Convert to a single-precision (float32) CPU tensor.
space = torch.from_numpy(space).type(torch.FloatTensor)

# Training

In [80]:
# The sample images are written to this directory; create it up front so
# that `save_image` does not fail on a fresh checkout.
os.makedirs('results', exist_ok=True)

# The latent grid must live on the same device as the model's parameters,
# otherwise decoding fails when the model is on a GPU.
latent_grid = space.to(next(wae.parameters()).device)

# NOTE(review): epochs are numbered from 100, presumably to continue an
# earlier run's numbering in the logs and file names — confirm.
for epoch in range(100, 200):

    train(epoch)
    test(epoch)

    with torch.no_grad():

        # Decode every point of the latent grid into a 1 x 28 x 28 image.
        sample = wae.decode(latent_grid).cpu()
        sample = sample.view(16 * 16, 1, 28, 28)

        # Save the 16 x 16 mosaic of decoded images to disk...
        utils.save_image(
            sample,
            'results/vae-sample_' + str(epoch) + '.png',
            nrow=16
        )

        # ...and log the same mosaic to TensorBoard.
        writer.add_image(
            'sample',
            utils.make_grid(sample, nrow=16),
            epoch
        )

>> Epoch: 100 Average loss: 0.0018
>> Test set loss: 0.0026
>> Epoch: 101 Average loss: 0.0018
>> Test set loss: 0.0025
>> Epoch: 102 Average loss: 0.0018
>> Test set loss: 0.0026
>> Epoch: 103 Average loss: 0.0018
>> Test set loss: 0.0026
>> Epoch: 104 Average loss: 0.0018
>> Test set loss: 0.0026
>> Epoch: 105 Average loss: 0.0018
>> Test set loss: 0.0026
>> Epoch: 106 Average loss: 0.0018
>> Test set loss: 0.0027
>> Epoch: 107 Average loss: 0.0018
>> Test set loss: 0.0026
>> Epoch: 108 Average loss: 0.0018
>> Test set loss: 0.0025
>> Epoch: 109 Average loss: 0.0018
>> Test set loss: 0.0025
>> Epoch: 110 Average loss: 0.0018
>> Test set loss: 0.0027
>> Epoch: 111 Average loss: 0.0018
>> Test set loss: 0.0025
>> Epoch: 112 Average loss: 0.0018
>> Test set loss: 0.0027
>> Epoch: 113 Average loss: 0.0017
>> Test set loss: 0.0026
>> Epoch: 114 Average loss: 0.0018
>> Test set loss: 0.0027
>> Epoch: 115 Average loss: 0.0018
>> Test set loss: 0.0026
>> Epoch: 116 Average loss: 0.0018
>> Te

KeyboardInterrupt: 