In [1]:
import numpy as np
import logging
import sys
from matplotlib import pyplot as plt

%matplotlib inline
sys.path.append("../")

import torch
from torchvision.datasets import MNIST
from torchvision import transforms

from aef.models.aef import LinearAutoencoder, ConvolutionalAutoencoder, DenseAutoencoder
from aef.trainer import AutoencoderTrainer, AutoencoderFlowTrainer
from aef.losses import nll, mse

logging.basicConfig(
    format="%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s",
    datefmt="%H:%M",
    level=logging.INFO,
)


## Data

In [2]:
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist = MNIST('./data', download=True, transform=img_transform)

## Train autoencoder

In [3]:
ae = ConvolutionalAutoencoder()

In [4]:
trainer = AutoencoderTrainer(ae, output_filename="output/conv_phase1")
trainer.train(
    dataset=mnist,
    loss_functions=[mse],
    loss_labels=["MSE"],
    batch_size=256,
    epochs=10,
    verbose="all",
    initial_lr=1.e-3,
    final_lr=1.e-5
)

12:50 aef.trainer          INFO    Epoch   1: train loss  0.36704 (MSE:  0.367)
12:50 aef.trainer          INFO               val. loss   0.21302 (MSE:  0.213)
12:53 aef.trainer          INFO    Epoch   2: train loss  0.19476 (MSE:  0.195)
12:53 aef.trainer          INFO               val. loss   0.18284 (MSE:  0.183)
12:55 aef.trainer          INFO    Epoch   3: train loss  0.17712 (MSE:  0.177)
12:55 aef.trainer          INFO               val. loss   0.17298 (MSE:  0.173)
12:56 aef.trainer          INFO    Epoch   4: train loss  0.17013 (MSE:  0.170)
12:56 aef.trainer          INFO               val. loss   0.16768 (MSE:  0.168)
12:57 aef.trainer          INFO    Epoch   5: train loss  0.16665 (MSE:  0.167)
12:57 aef.trainer          INFO               val. loss   0.16520 (MSE:  0.165)
12:58 aef.trainer          INFO    Epoch   6: train loss  0.16459 (MSE:  0.165)
12:58 aef.trainer          INFO               val. loss   0.16414 (MSE:  0.164)
12:59 aef.trainer          INFO    Epoch

(array([0.36704266, 0.19476125, 0.1771214 , 0.17013408, 0.16664528,
        0.16459072, 0.16344604, 0.16273875, 0.1622842 , 0.16201567]),
 array([0.21302194, 0.18284147, 0.17298171, 0.16768125, 0.16519777,
        0.16414152, 0.16281395, 0.16222596, 0.16186571, 0.16167386]))

In [5]:
trainer = AutoencoderFlowTrainer(ae, output_filename="output/conv_phase2")
trainer.train(
    dataset=mnist,
    loss_functions=[nll],
    loss_labels=["NLL"],
    batch_size=256,
    epochs=10,
    verbose="all",
    initial_lr=1.e-3,
    final_lr=1.e-5,
    parameters=ae.flow.parameters()
)

13:04 aef.trainer          INFO    Epoch   1: train loss -579.98395 (NLL: -579.984)
13:04 aef.trainer          INFO               val. loss  -908.13443 (NLL: -908.134)
13:05 aef.trainer          INFO    Epoch   2: train loss -986.57163 (NLL: -986.572)
13:05 aef.trainer          INFO               val. loss  -1018.04145 (NLL: -1018.041)
13:06 aef.trainer          INFO    Epoch   3: train loss -1060.13679 (NLL: -1060.137)
13:06 aef.trainer          INFO               val. loss  -1060.43996 (NLL: -1060.440)
13:08 aef.trainer          INFO    Epoch   4: train loss -1092.51311 (NLL: -1092.513)
13:08 aef.trainer          INFO               val. loss  -1076.83530 (NLL: -1076.835)
13:09 aef.trainer          INFO    Epoch   5: train loss -1109.53771 (NLL: -1109.538)
13:09 aef.trainer          INFO               val. loss  -1104.93970 (NLL: -1104.940)
13:10 aef.trainer          INFO    Epoch   6: train loss -1119.15163 (NLL: -1119.152)
13:10 aef.trainer          INFO               val. loss  -11

(array([ -579.98395277,  -986.57163273, -1060.13678845, -1092.51311007,
        -1109.5377114 , -1119.15162936, -1124.8998427 , -1128.28424974,
        -1130.11095428, -1131.47389291]),
 array([ -908.13442631, -1018.04145218, -1060.43995589, -1076.83530439,
        -1104.9396952 , -1100.79012182, -1112.96235683, -1115.77121334,
        -1111.95489812, -1116.14018844]))

## Visualize latent space

In [6]:
x = torch.cat([mnist[i][0].unsqueeze(0) for i in range(1000)], dim=0)
y = np.asarray([mnist[i][1] for i in range(1000)])

In [7]:
h, u = ae.latent(x)
x_out = ae.decoder(h)
h = h.detach().numpy().reshape((-1,2))
u = u.detach().numpy().reshape((-1,2))

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [16, 16, 3, 3], but got 2-dimensional input of size [1000, 784] instead

In [None]:
fig = plt.figure(figsize=(10,5))

ax = plt.subplot(1,2,1)
for i in range(10):
    plt.scatter(h[y==i][:,0], h[y==i][:,1], c="C{}".format(i), s=20., label="{}".format(i+1))
plt.legend()
    
ax = plt.subplot(1,2,2)
for i in range(10):
    plt.scatter(u[y==i][:,0], u[y==i][:,1], c="C{}".format(i), s=20., label="{}".format(i+1))
plt.legend()

plt.tight_layout()
plt.show()


## Visualize reconstruction

In [None]:
x = torch.cat([mnist[i][0].unsqueeze(0) for i in range(1000)], dim=0)
y = np.asarray([mnist[i][1] for i in range(1000)])

In [None]:
h = ae.encoder(x)
x_out = ae.decoder(h)
h = h.detach().numpy()
x_out = x_out.detach().numpy()

In [None]:
fig = plt.figure(figsize=(12,12))

for i in range(18):
    ax = plt.subplot(6, 6, 2*i + 1)
    plt.imshow(x[i].reshape((28,28)))
    ax = plt.subplot(6, 6, 2*i + 2)
    plt.imshow(x_out[i].reshape((28,28)))
        
plt.tight_layout()
plt.show()