<div style="text-align: center; font-size: 30pt; font-weight: bold; margin: 1em 0em 1em 0em">Auto-Encoders v. Variational Auto-Encoders</div>

In [1]:
import sys
import os

In [2]:
# Put the sibling `../autoencoders` directory on the import path so that the
# `import autoencoders` in the following cell can be resolved.
autoencoders_dir = os.path.abspath('../autoencoders')
sys.path.append(autoencoders_dir)

In [3]:
# "Magic" commands for automatic reloading of module, perfect for prototyping
%reload_ext autoreload
%autoreload 2

import autoencoders

# The dataset

In [20]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

In [5]:
# Directory holding the MNIST files; they are downloaded here if missing.
root = '../data'

# Images are only converted to tensors. No normalisation is applied, so the
# pixel values keep whatever range ToTensor produces.
trans = transforms.Compose([transforms.ToTensor()])

# The training and test splits share every option except `train`.
mnist_options = dict(root=root, transform=trans, download=True)
train_set = datasets.MNIST(train=True, **mnist_options)
test_set = datasets.MNIST(train=False, **mnist_options)

batch_size = 100

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

print('>> total trainning batch number: {}'.format(len(train_loader)))
print('>> total testing batch number: {}'.format(len(test_loader)))

>> total trainning batch number: 600
>> total testing batch number: 100


# Auto-encoder

In [4]:
# Build the plain auto-encoder from the local `autoencoders` module and print
# its layer summary.
model = autoencoders.AutoEncoder()
print(model)

AutoEncoder(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=12, bias=True)
  (fc4): Linear(in_features=12, out_features=2, bias=True)
  (fc5): Linear(in_features=2, out_features=12, bias=True)
  (fc6): Linear(in_features=12, out_features=50, bias=True)
  (fc7): Linear(in_features=50, out_features=128, bias=True)
  (fc8): Linear(in_features=128, out_features=784, bias=True)
)


## Loss and training routine

In [7]:
# Device that input batches and sampled latent codes are moved to.
# NOTE(review): no visible cell moves `model` itself to this device, so
# switching away from "cpu" presumably also needs `model.to(device)` — confirm.
device = torch.device("cpu")

In [8]:
# Adam over the auto-encoder's parameters, with a small weight decay.
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

In [9]:
# The auto-encoder `train` function prints a progress line every
# `log_interval` batches.
log_interval = 200

In [10]:
import torch.nn.functional as F

In [11]:
def loss_function(recon_x, x):
    """Return the summed squared reconstruction error for one batch.

    This is the loss of the plain auto-encoder: it contains a reconstruction
    term only and, unlike a VAE loss, no KL-divergence term.

    Args:
        recon_x: Reconstruction produced by the model.
        x: Original input batch. It is reshaped to ``recon_x``'s shape before
            the comparison, so it must hold the same number of elements.

    Returns:
        Scalar tensor with the squared error summed over every element of
        every sample in the batch.
    """
    # Match the targets to the layout the model emitted instead of
    # hard-coding 28 * 28, so other image sizes work unchanged.
    target = x.reshape(recon_x.shape)
    return F.mse_loss(recon_x, target, reduction='sum')


def train(epoch):
    """Run one optimisation pass of the auto-encoder over ``train_loader``.

    Prints a progress line every ``log_interval`` batches and, at the end,
    the loss averaged over the number of training samples.

    Args:
        epoch: Epoch number; used only in the printed messages.
    """
    model.train()
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    running_loss = 0
    for step, (images, _) in enumerate(train_loader):
        images = images.to(device)

        optimizer.zero_grad()
        batch_loss = loss_function(model(images), images)
        batch_loss.backward()
        running_loss += batch_loss.item()
        optimizer.step()

        if step % log_interval != 0:
            continue
        # Progress line: samples seen so far and the per-sample batch loss.
        seen = step * len(images)
        percent = 100. * step / n_batches
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, n_samples, percent, batch_loss.item() / len(images)))

    print('>> Epoch: {} Average loss: {:.4f}'.format(
          epoch, running_loss / n_samples))


def test(epoch):
    """Evaluate the auto-encoder on ``test_loader`` and save a comparison image.

    For the first test batch, up to eight inputs and their reconstructions
    are written to ``results/reconstruction_<epoch>.png``. The loss averaged
    over the number of test samples is printed.

    Args:
        epoch: Epoch number; used in the image file name.
    """
    # save_image does not create missing directories.
    os.makedirs('results', exist_ok=True)

    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch = model(data)
            test_loss += loss_function(recon_batch, data).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape to the input's own shape rather than a hard-coded
                # (100, 1, 28, 28), which breaks for any other batch size.
                comparison = torch.cat([data[:n],
                                        recon_batch.view_as(data)[:n]])
                utils.save_image(comparison.cpu(),
                                 'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('>> Test set loss: {:.4f}'.format(test_loss))

## Training

In [15]:
# Snapshot of the model with its current weights: test inputs next to their
# reconstructions, plus images decoded from random latent codes.
# save_image does not create missing directories.
os.makedirs('results', exist_ok=True)

with torch.no_grad():
    # Only the first test batch is shown, so fetch just that batch instead of
    # running the model over the whole test set.
    data, _ = next(iter(test_loader))
    data = data.to(device)
    recon_batch = model(data)
    n = min(data.size(0), 8)
    # Reshape to the input's own shape rather than a hard-coded batch of 100.
    comparison = torch.cat([data[:n],
                            recon_batch.view_as(data)[:n]])
    utils.save_image(comparison.cpu(),
                     'results/reconstruction_0.png', nrow=n)

    # 64 random 2-D latent codes: standard-normal draws scaled by 3.
    sample = 3 * torch.randn(64, 2).to(device)
    sample = model.decode(sample).cpu()
    utils.save_image(sample.view(64, 1, 28, 28),
                     'results/sample_0.png')

In [19]:
import numpy as np

In [72]:
# 16 x 16 grid of 2-D latent codes covering [0, 5] x [0, 5]. x varies
# fastest, so 16 consecutive codes share one y value.
# The array is built as float32 directly: the `np.float` alias used before no
# longer exists in NumPy >= 1.24, and float32 is what the tensor ended up as.
grid_axis = np.linspace(0, 5, 16)
space = np.array([[x, y] for y in grid_axis for x in grid_axis], dtype=np.float32)
space = torch.from_numpy(space)

In [73]:
# `epoch` normally carries over from an earlier run of a training loop. On a
# fresh kernel it does not exist yet, so start at 1 instead of failing with a
# NameError.
try:
    epoch += 1
except NameError:
    epoch = 1

# save_image does not create missing directories.
os.makedirs('results', exist_ok=True)

with torch.no_grad():
    # Decode every latent code of the grid and save the images 16 per row.
    sample = model.decode(space.to(device)).cpu()
    utils.save_image(sample.view(-1, 1, 28, 28),
                     'results/sample_' + str(epoch) + '.png', nrow=16)

In [74]:
# Train and evaluate for epochs 100..199, saving a decoded latent grid after
# each epoch.
# NOTE(review): the numbering starts at 100, presumably continuing an earlier
# run whose cell is no longer in the notebook — confirm.
# save_image does not create missing directories.
os.makedirs('results', exist_ok=True)

for epoch in range(100, 200):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Move the grid to the same device as the noise, and size the noise
        # from the grid instead of repeating 16 * 16.
        latents = space.to(device)
        sample = latents + .1 * torch.randn(latents.shape).to(device)
        sample = model.decode(sample).cpu()
        utils.save_image(sample.view(-1, 1, 28, 28),
                         'results/sample_' + str(epoch) + '.png', nrow=16)

>> Epoch: 100 Average loss: 28.0925
>> Test set loss: 29.1073
>> Epoch: 101 Average loss: 27.8610
>> Test set loss: 28.4108
>> Epoch: 102 Average loss: 27.5723
>> Test set loss: 27.8513
>> Epoch: 103 Average loss: 27.5940
>> Test set loss: 27.8839
>> Epoch: 104 Average loss: 27.4592
>> Test set loss: 27.8237
>> Epoch: 105 Average loss: 27.4738
>> Test set loss: 27.8678
>> Epoch: 106 Average loss: 27.3777
>> Test set loss: 27.7388
>> Epoch: 107 Average loss: 27.2969
>> Test set loss: 27.8645
>> Epoch: 108 Average loss: 27.3988
>> Test set loss: 27.8786
>> Epoch: 109 Average loss: 27.2711
>> Test set loss: 27.6674
>> Epoch: 110 Average loss: 27.2504
>> Test set loss: 27.8331
>> Epoch: 111 Average loss: 27.4583
>> Test set loss: 27.7299
>> Epoch: 112 Average loss: 27.2916
>> Test set loss: 27.8134
>> Epoch: 113 Average loss: 27.3659
>> Test set loss: 27.7088
>> Epoch: 114 Average loss: 27.3040
>> Test set loss: 27.7145
>> Epoch: 115 Average loss: 27.3061
>> Test set loss: 27.7898
>> Epoch

>> Epoch: 138 Average loss: 27.2994
>> Test set loss: 27.8493
>> Epoch: 139 Average loss: 27.3739
>> Test set loss: 27.8157
>> Epoch: 140 Average loss: 27.2803
>> Test set loss: 27.7956
>> Epoch: 141 Average loss: 27.7083
>> Test set loss: 27.9262
>> Epoch: 142 Average loss: 27.2205
>> Test set loss: 27.9210
>> Epoch: 143 Average loss: 27.1406
>> Test set loss: 27.5399
>> Epoch: 144 Average loss: 27.2417
>> Test set loss: 27.6991
>> Epoch: 145 Average loss: 27.2332
>> Test set loss: 27.7731
>> Epoch: 146 Average loss: 27.0897
>> Test set loss: 27.5159
>> Epoch: 147 Average loss: 27.1091
>> Test set loss: 27.5932
>> Epoch: 148 Average loss: 27.0706
>> Test set loss: 27.4927
>> Epoch: 149 Average loss: 27.0501
>> Test set loss: 27.4842
>> Epoch: 150 Average loss: 27.1112
>> Test set loss: 27.5175
>> Epoch: 151 Average loss: 27.1166
>> Test set loss: 27.5081
>> Epoch: 152 Average loss: 27.1016
>> Test set loss: 27.4090
>> Epoch: 153 Average loss: 27.0550
>> Test set loss: 27.5480
>> Epoch

>> Epoch: 176 Average loss: 26.9754
>> Test set loss: 27.5705
>> Epoch: 177 Average loss: 26.9804
>> Test set loss: 27.3789
>> Epoch: 178 Average loss: 26.8892
>> Test set loss: 27.7058
>> Epoch: 179 Average loss: 26.9147
>> Test set loss: 27.4598
>> Epoch: 180 Average loss: 27.0304
>> Test set loss: 27.3528
>> Epoch: 181 Average loss: 26.8897
>> Test set loss: 27.6347
>> Epoch: 182 Average loss: 27.0705
>> Test set loss: 27.7295
>> Epoch: 183 Average loss: 27.0751
>> Test set loss: 27.5121
>> Epoch: 184 Average loss: 26.9925
>> Test set loss: 27.4053
>> Epoch: 185 Average loss: 26.9059
>> Test set loss: 27.4238
>> Epoch: 186 Average loss: 26.9411
>> Test set loss: 27.2976
>> Epoch: 187 Average loss: 26.8023
>> Test set loss: 27.3571
>> Epoch: 188 Average loss: 26.7959
>> Test set loss: 27.2420
>> Epoch: 189 Average loss: 26.7553
>> Test set loss: 27.3986
>> Epoch: 190 Average loss: 26.8741
>> Test set loss: 27.3006
>> Epoch: 191 Average loss: 26.8649
>> Test set loss: 27.3316
>> Epoch

# Variational Auto-Encoder

In [23]:
# Build the variational auto-encoder from the local `autoencoders` module and
# print its layer summary.
vae = autoencoders.VariationalAutoEncoder()
print(vae)

VariationalAutoEncoder(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=12, bias=True)
  (fc41): Linear(in_features=12, out_features=2, bias=True)
  (fc42): Linear(in_features=12, out_features=2, bias=True)
  (fc5): Linear(in_features=2, out_features=12, bias=True)
  (fc6): Linear(in_features=12, out_features=50, bias=True)
  (fc7): Linear(in_features=50, out_features=128, bias=True)
  (fc8): Linear(in_features=128, out_features=784, bias=True)
)


## Loss and training routine

In [24]:
# Rebinds `device` to the CPU, the same value as in the auto-encoder section.
device = torch.device("cpu")

In [25]:
# Rebinds the global `optimizer`: from here on it updates the VAE's parameters,
# not the plain auto-encoder's.
learning_rate = 1e-3
optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate, weight_decay=1e-5)

In [26]:
# Same value as in the auto-encoder section; the VAE `train` defined below does
# not read it.
log_interval = 200

In [27]:
import torch.nn.functional as F

In [28]:
from tensorboardX import SummaryWriter
# TensorBoard writer, created with default arguments, used by the VAE
# train/test functions and the sampling loop to log scalars and images.
writer = SummaryWriter()

In [29]:
def loss_function_vae(recon_x, x, mu, logvar):
    """Return the VAE loss: reconstruction BCE plus KL divergence.

    Both terms are summed over all elements and over the batch.

    Args:
        recon_x: Reconstruction produced by the model; it is the input of the
            binary cross-entropy.
        x: Original input batch. It is reshaped to ``recon_x``'s shape before
            the comparison, so it must hold the same number of elements.
        mu: Latent means.
        logvar: Latent log-variances, log(sigma^2).

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    # Match the targets to the layout the model emitted instead of
    # hard-coding 784, so other image sizes work unchanged.
    BCE = F.binary_cross_entropy(recon_x, x.reshape(recon_x.shape), reduction='sum')

    # Closed-form KL term, see Appendix B of the VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD


def train(epoch):
    """Train the VAE for one pass over ``train_loader``.

    The loss averaged over the number of training samples is printed and
    logged to TensorBoard under ``data/train-loss``.

    Args:
        epoch: Epoch number; used in the printed message and as the
            TensorBoard step.
    """
    vae.train()
    total_loss = 0
    for images, _ in train_loader:
        images = images.to(device)

        optimizer.zero_grad()
        reconstruction, mu, log_variance = vae(images)
        batch_loss = loss_function_vae(reconstruction, images, mu, log_variance)
        batch_loss.backward()
        total_loss += batch_loss.item()
        optimizer.step()

    mean_loss = total_loss / len(train_loader.dataset)
    print('>> Epoch: {} Average loss: {:.4f}'.format(epoch, mean_loss))
    writer.add_scalar('data/train-loss', mean_loss, epoch)


def test(epoch):
    """Evaluate the VAE on ``test_loader`` and log the results to TensorBoard.

    For the first test batch, up to eight inputs and their reconstructions
    are logged as one image grid under ``reconstruction``. The loss averaged
    over the number of test samples is printed and logged under
    ``data/test-loss``.

    Args:
        epoch: Epoch number; used as the TensorBoard step.
    """
    vae.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, log_variance = vae(data)
            test_loss += loss_function_vae(recon_batch, data, mu, log_variance).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape to the input's own shape rather than a hard-coded
                # (100, 1, 28, 28), which breaks for any other batch size.
                comparison = torch.cat([data[:n],
                                        recon_batch.view_as(data)[:n]])
                # add_image takes one image, so tile the batch into a single
                # grid first, as the sampling loop does for its own images.
                writer.add_image('reconstruction',
                                 utils.make_grid(comparison.cpu(), nrow=n), epoch)

    test_loss /= len(test_loader.dataset)
    print('>> Test set loss: {:.4f}'.format(test_loss))

    writer.add_scalar('data/test-loss', test_loss, epoch)

## Training

In [30]:
import numpy as np

In [31]:
# 16 x 16 grid of 2-D latent codes covering [-1.5, 1.5] x [-1.5, 1.5]. x varies
# fastest, so 16 consecutive codes share one y value.
# The array is built as float32 directly: the `np.float` alias used before no
# longer exists in NumPy >= 1.24, and float32 is what the tensor ended up as.
grid_axis = np.linspace(-1.5, 1.5, 16)
space = np.array([[x, y] for y in grid_axis for x in grid_axis], dtype=np.float32)
space = torch.from_numpy(space)

In [32]:
# Train and evaluate the VAE for epochs 1..99. After each epoch the latent grid
# is decoded, saved to disk and logged to TensorBoard.
# save_image does not create missing directories.
os.makedirs('results', exist_ok=True)

for epoch in range(1, 100):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Decode the grid on the configured device, then reshape once and
        # reuse the images for both outputs.
        sample = vae.decode(space.to(device)).cpu()
        images = sample.view(-1, 1, 28, 28)
        utils.save_image(images,
                         'results/vae-sample_' + str(epoch) + '.png', nrow=16)
        writer.add_image('sample', utils.make_grid(images, nrow=16), epoch)

>> Epoch: 1 Average loss: 207.4965
>> Test set loss: 183.7865
>> Epoch: 2 Average loss: 173.9688
>> Test set loss: 168.5135
>> Epoch: 3 Average loss: 165.8932
>> Test set loss: 163.6707
>> Epoch: 4 Average loss: 161.1973
>> Test set loss: 159.0849
>> Epoch: 5 Average loss: 157.4966
>> Test set loss: 156.1624
>> Epoch: 6 Average loss: 155.0789
>> Test set loss: 154.2824
>> Epoch: 7 Average loss: 153.4772
>> Test set loss: 153.5132
>> Epoch: 8 Average loss: 152.2615
>> Test set loss: 151.9901
>> Epoch: 9 Average loss: 151.1153
>> Test set loss: 151.1691
>> Epoch: 10 Average loss: 150.1952
>> Test set loss: 150.4310
>> Epoch: 11 Average loss: 149.4633
>> Test set loss: 149.4782
>> Epoch: 12 Average loss: 148.7509
>> Test set loss: 148.9511
>> Epoch: 13 Average loss: 148.2019
>> Test set loss: 148.4081
>> Epoch: 14 Average loss: 147.6030
>> Test set loss: 147.8209
>> Epoch: 15 Average loss: 147.0881
>> Test set loss: 147.1794
>> Epoch: 16 Average loss: 146.7729
>> Test set loss: 147.1360
>

KeyboardInterrupt: 