In [6]:
import numpy as np
import logging
import sys
from matplotlib import pyplot as plt

%matplotlib inline
sys.path.append("../")

import torch
from torchvision.datasets import MNIST
from torchvision import transforms

from aef.models.aef import LinearAutoencoder, ConvolutionalAutoencoder, DenseAutoencoder
from aef.trainer import AutoencoderTrainer, AutoencoderFlowTrainer
from aef.losses import nll, mse

logging.basicConfig(
    format="%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s",
    datefmt="%H:%M",
    level=logging.INFO,
)

# Seed the global RNGs so weight initialization (and, presumably, any shuffling or
# validation splitting inside the aef trainers) is reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)


## Data

In [2]:
# ToTensor yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) then applies
# (p - 0.5) / 0.5, so the images fed to the model lie in [-1, 1].
img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# train=True is torchvision's default split, spelled out here for clarity.
mnist = MNIST(root="./data", train=True, download=True, transform=img_transform)

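## Train autoencoder

Training runs in two phases. First the whole autoencoder is trained on reconstruction MSE. Then only the flow's parameters (`ae.flow.parameters()`) are trained on the NLL loss.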

In [3]:
# Width of the autoencoder's latent space; named so it is not a bare magic number.
LATENT_DIM = 20
ae = LinearAutoencoder(latent_dim=LATENT_DIM)

In [4]:
# Phase 1: train the autoencoder on reconstruction MSE. No `parameters=` subset is
# passed here, unlike phase 2 below.
trainer = AutoencoderTrainer(ae, output_filename="output/linear_phase1")
phase1_history = trainer.train(
    dataset=mnist,
    loss_functions=[mse],
    loss_labels=["MSE"],
    epochs=10,
    batch_size=256,
    initial_lr=1.e-3,
    final_lr=1.e-5,
    verbose="all",
)
# (train losses, val. losses), one entry per epoch.
phase1_history

12:44 aef.trainer          INFO    Epoch   1: train loss  0.38373 (MSE:  0.384)
12:44 aef.trainer          INFO               val. loss   0.23628 (MSE:  0.236)
12:44 aef.trainer          INFO    Epoch   2: train loss  0.21599 (MSE:  0.216)
12:44 aef.trainer          INFO               val. loss   0.19546 (MSE:  0.195)
12:44 aef.trainer          INFO    Epoch   3: train loss  0.18591 (MSE:  0.186)
12:44 aef.trainer          INFO               val. loss   0.17671 (MSE:  0.177)
12:45 aef.trainer          INFO    Epoch   4: train loss  0.17264 (MSE:  0.173)
12:45 aef.trainer          INFO               val. loss   0.16812 (MSE:  0.168)
12:45 aef.trainer          INFO    Epoch   5: train loss  0.16611 (MSE:  0.166)
12:45 aef.trainer          INFO               val. loss   0.16357 (MSE:  0.164)
12:46 aef.trainer          INFO    Epoch   6: train loss  0.16253 (MSE:  0.163)
12:46 aef.trainer          INFO               val. loss   0.16095 (MSE:  0.161)
12:46 aef.trainer          INFO    Epoch

(array([0.38373476, 0.2159896 , 0.18590829, 0.17263613, 0.16611483,
        0.16253368, 0.16045673, 0.15920294, 0.15843015, 0.15796479]),
 array([0.23627509, 0.19546227, 0.17670867, 0.16811791, 0.16356901,
        0.16094959, 0.15944439, 0.15852198, 0.15791299, 0.15756598]))

In [7]:
# Phase 2: train with the NLL loss, passing only ae.flow's parameters as the
# trainable set (presumably leaving the phase-1 encoder/decoder weights as they are).
trainer = AutoencoderFlowTrainer(ae, output_filename="output/linear_phase2")
phase2_history = trainer.train(
    dataset=mnist,
    loss_functions=[nll],
    loss_labels=["NLL"],
    batch_size=256,
    epochs=10,
    verbose="all",
    initial_lr=1.e-3,
    final_lr=1.e-5,
    # .parameters() returns a one-shot generator; a list stays valid if the trainer
    # iterates over it more than once.
    parameters=list(ae.flow.parameters()),
)
# (train losses, val. losses), one entry per epoch.
# NOTE(review): in the recorded run the val. NLL is lowest at epoch 7 and rises
# afterwards; consider fewer epochs or early stopping.
phase2_history

12:50 aef.trainer          INFO    Epoch   1: train loss  6.70760 (NLL:  6.708)
12:50 aef.trainer          INFO               val. loss   2.94287 (NLL:  2.943)
12:51 aef.trainer          INFO    Epoch   2: train loss  2.21820 (NLL:  2.218)
12:51 aef.trainer          INFO               val. loss   1.87041 (NLL:  1.870)
12:51 aef.trainer          INFO    Epoch   3: train loss  1.52925 (NLL:  1.529)
12:51 aef.trainer          INFO               val. loss   1.51029 (NLL:  1.510)
12:51 aef.trainer          INFO    Epoch   4: train loss  1.21785 (NLL:  1.218)
12:51 aef.trainer          INFO               val. loss   1.25081 (NLL:  1.251)
12:52 aef.trainer          INFO    Epoch   5: train loss  1.05185 (NLL:  1.052)
12:52 aef.trainer          INFO               val. loss   1.24243 (NLL:  1.242)
12:52 aef.trainer          INFO    Epoch   6: train loss  0.96892 (NLL:  0.969)
12:52 aef.trainer          INFO               val. loss   1.12239 (NLL:  1.122)
12:53 aef.trainer          INFO    Epoch

(array([6.70759836, 2.21820221, 1.5292514 , 1.21785227, 1.05184544,
        0.96891843, 0.90665354, 0.87589779, 0.85561718, 0.84586942]),
 array([2.94286701, 1.87040907, 1.51028906, 1.25081363, 1.24242838,
        1.12238981, 1.01575777, 1.04433545, 1.07195497, 1.10258527]))

## Visualize latent space

In [None]:
# Fetch the first 1000 samples of `mnist` once each: every mnist[i] access runs the
# transform, and reading image and label in two separate comprehensions did it twice.
vis_samples = [mnist[i] for i in range(1000)]
# Stack the images along a new leading batch dim (same result as cat of unsqueeze(0)).
x = torch.stack([image for image, _ in vis_samples])
y = np.asarray([label for _, label in vis_samples])

In [None]:
# ae.latent returns a pair (h, u); presumably h is the autoencoder latent and u its
# image under the flow -- TODO confirm against aef.models.
h, u = ae.latent(x)

# Keep exactly one row per input sample so the rows line up with the labels `y`.
# The previous reshape((-1, 2)) split each latent wider than 2 into several rows,
# so boolean masks built from `y` no longer matched and the plot raised IndexError.
# Assumes h and u hold one latent vector per sample (TODO confirm).
h = h.detach().numpy().reshape((len(x), -1))
u = u.detach().numpy().reshape((len(x), -1))

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Scatter the first two coordinates of h and u, one colour per digit class.
for ax, (points, name) in zip(axes, [(h, "h"), (u, "u")]):
    for digit in range(10):
        mask = y == digit
        # The mask selects label `digit` itself, so the legend entry is `digit`
        # (the old "i+1" shifted every class label by one).
        ax.scatter(points[mask, 0], points[mask, 1], c="C{}".format(digit), s=20., label="{}".format(digit))
    ax.set_title("{} (first two dimensions)".format(name))
    ax.set_xlabel("dim 0")
    ax.set_ylabel("dim 1")
    ax.legend(title="digit")

plt.tight_layout()
plt.show()


## Visualize reconstruction

In [None]:
# Rebuild the same first 1000 samples so this section can be run on its own.
# Each mnist[i] is fetched (and transformed) once, not once for the image and
# again for the label.
recon_samples = [mnist[i] for i in range(1000)]
x = torch.stack([image for image, _ in recon_samples])
y = np.asarray([label for _, label in recon_samples])

In [None]:
# Reconstruct the batch through encoder -> decoder only; the flow is not used here.
# Pure inference, so skip building an autograd graph over the whole batch.
with torch.no_grad():
    h = ae.encoder(x)
    x_out = ae.decoder(h)
h = h.detach().numpy()
x_out = x_out.detach().numpy()

In [None]:
# Show the first 18 inputs next to their reconstructions: a 6x6 grid in which each
# row holds three (input, reconstruction) pairs, filled row by row.
n_rows, n_cols = 6, 6
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 12))
flat_axes = axes.ravel()

for i, (ax_in, ax_out) in enumerate(zip(flat_axes[0::2], flat_axes[1::2])):
    # Fixed grey scale on [-1, 1], the range Normalize((0.5,), (0.5,)) maps pixels to,
    # so input and reconstruction share one intensity scale instead of each being
    # auto-stretched. Reconstruction values outside that range are clipped.
    for ax, image, title in ((ax_in, x[i], "input {}"), (ax_out, x_out[i], "recon. {}")):
        ax.imshow(np.asarray(image).reshape((28, 28)), cmap="gray", vmin=-1., vmax=1.)
        ax.set_title(title.format(i))
        ax.axis("off")

plt.tight_layout()
plt.show()