<div style="text-align: center; font-size: 30pt; font-weight: bold; margin: 1em 0em 1em 0em">Wasserstein Auto-Encoders</div>

In [1]:
import sys
import os

In [2]:
# Make the sibling `autoencoders` directory importable (presumably where the
# `wasserstein` module imported in the next cell lives). The membership check
# keeps the cell idempotent: re-running it does not append a duplicate entry.
module_dir = os.path.abspath('../autoencoders')
if module_dir not in sys.path:
    sys.path.append(module_dir)

In [3]:
# "Magic" commands for automatic reloading of module, perfect for prototyping
%reload_ext autoreload
%autoreload 2

import wasserstein

In [4]:
import numpy as np

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

# The dataset

In [6]:
# Directory holding the MNIST files (downloaded there when missing).
root = '../data'

# The only preprocessing step is the PIL-image -> tensor conversion;
# no resizing or normalisation is applied.
trans = transforms.Compose([transforms.ToTensor()])

# Training and test splits, fetched on first use.
train_set = datasets.MNIST(root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root, train=False, transform=trans, download=True)

batch_size = 100

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

print('>> total trainning batch number: {}'.format(len(train_loader)))
print('>> total testing batch number: {}'.format(len(test_loader)))

>> total trainning batch number: 600
>> total testing batch number: 100


In [7]:
# Grab one batch of images for inspection. `next(iter(...))` fetches only the
# first batch; the previous list comprehension ran through the whole loader and
# kept every batch in memory just to index element 0.
# Each batch is an (images, labels) pair, so `[0]` selects the images. Indexing
# is used instead of unpacking into `_` to avoid shadowing IPython's `_`.
data = next(iter(train_loader))[0]

In [8]:
# Display the shape of the batch tensor fetched above.
data.size()

torch.Size([100, 1, 28, 28])

# Wasserstein Auto-Encoder

In [65]:
# Instantiate the auto-encoder defined in the `wasserstein` module.
# NOTE(review): `ksi` is forwarded to the constructor; presumably it weights the
# latent regularisation term — confirm against wasserstein.WassersteinAutoEncoder.
wae = wasserstein.WassersteinAutoEncoder(ksi=10)

In [49]:
# Show the source code of the model class (IPython magic; development aid).
%psource wasserstein.WassersteinAutoEncoder

In [66]:
# Print the module summary: the registered layers and their sizes.
print(wae)

WassersteinAutoEncoder(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=12, bias=True)
  (fc4): Linear(in_features=12, out_features=2, bias=True)
  (fc5): Linear(in_features=2, out_features=12, bias=True)
  (fc6): Linear(in_features=12, out_features=50, bias=True)
  (fc7): Linear(in_features=50, out_features=128, bias=True)
  (fc8): Linear(in_features=128, out_features=784, bias=True)
)


# Training procedure

## TensorBoard

First, we create a `SummaryWriter` instance (in order to use tensorboard):

In [67]:
# tensorboardX supplies the writer object that the later cells log to.
from tensorboardX import SummaryWriter
# 'wae/bce-mean3' is the path handed to the writer — presumably this run's log
# directory (TODO confirm against the tensorboardX documentation).
writer = SummaryWriter('wae/bce-mean3')

To visualize the model graph in TensorBoard, we run the next cell:

In [68]:
# Feed the model one random input of shape (1, 1, 28, 28) so that its graph can
# be recorded by the writer. A plain tensor is passed directly: the
# `torch.autograd.Variable` wrapper is deprecated (tensors and Variables were
# merged in PyTorch 0.4) and simply returns a tensor.
dummy_input = torch.rand(1, 1, 28, 28)
writer.add_graph(wae, dummy_input)

## Training and testing

We select the device used for training (the GPU when CUDA is available, otherwise the CPU):

In [69]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

We define the learning rate:

In [70]:
# Step size handed to the optimizer created in the next cell.
learning_rate = 1e-3

And the optimizer:

In [71]:
# Adam over every parameter of the model, using `learning_rate` as the step size
# and a weight decay of 1e-5.
optimizer = torch.optim.Adam(wae.parameters(), lr=learning_rate, weight_decay=1e-5)

## Sampling space

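We want to sample from the latent space: we build a regular 16 × 16 grid of points covering $[-1.5, 1.5]^2$, which is decoded into images after every training epoch.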

In [74]:
# Regular 16 x 16 grid over [-1.5, 1.5]^2, flattened to shape (256, 2).
# x varies fastest (inner loop) and y slowest, so consecutive rows scan one grid
# line from left to right.
# `np.float` was deprecated in NumPy 1.20 and removed in 1.24, where it raises
# AttributeError; `np.float64` is the explicit dtype the alias resolved to.
space = np.array(
    [
        [x, y]
        for y in np.linspace(-1.5, 1.5, 16)
        for x in np.linspace(-1.5, 1.5, 16)
    ],
    dtype=np.float64
)

# Convert to a float32 CPU tensor.
space = torch.from_numpy(space).float()

# Training

In [78]:
# Train and evaluate for epochs 1..99 (range(1, 100) excludes 100). After each
# epoch, decode the fixed latent grid `space` and log the result as one image.
for epoch in range(1, 100):

    wasserstein.train(epoch, wae, optimizer, train_loader, device, writer)
    wasserstein.test(epoch, wae, test_loader, device, writer)

    with torch.no_grad():
        # `space` is created on the CPU. Move it to the device the model's
        # parameters currently live on before decoding; otherwise decoding fails
        # with a device mismatch if train/test moved the model to the GPU
        # (not visible from this notebook).
        model_device = next(wae.parameters()).device
        sample = wae.decode(space.to(model_device)).cpu()

        # One 1x28x28 tile per latent grid point, laid out 16 tiles per row.
        writer.add_image(
            'sample',
            utils.make_grid(sample.view(16 * 16, 1, 28, 28), nrow=16),
            epoch
        )

>> Epoch: 1 Average loss: 0.0020
>> Test set loss: 0.0030


KeyboardInterrupt: 