# Semi-invertible autoencoder

In [1]:
import numpy as np
import logging
import sys
from matplotlib import pyplot as plt
import torch

%matplotlib inline
sys.path.append("../")


from aef.models.autoencoding_flow import TwoStepAutoencodingFlow
from aef.trainer import AutoencodingFlowTrainer, NumpyDataset
from aef.losses import nll, mse

logging.basicConfig(
    format="%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s",
    datefmt="%H:%M",
    level=logging.INFO,
)


## Data

In [2]:
x = np.load("../data/tth/x_train.npy")
x_means = np.mean(x, axis=0)
x_stds = np.std(x, axis=0)
x = (x - x_means[np.newaxis,:]) / x_stds[np.newaxis,:]
y = np.ones(x.shape[0])

tth_data = NumpyDataset(x, y)


In [3]:
print(x.shape, np.mean(x, axis=0), np.std(x, axis=0), np.all(np.isfinite(x)))

(1000000, 48) [ 2.71121316e-05  2.21955761e-05 -2.06415079e-05  5.99880950e-05
  3.58828365e-05  2.92703389e-05 -2.56204794e-06 -2.69238922e-06
 -2.25854528e-05  1.57471004e-05  1.90822193e-05 -8.24490144e-06
 -1.19448589e-06 -3.90379773e-06 -9.44798273e-07 -1.75509683e-06
 -4.38888082e-06 -1.60391392e-06  1.09348264e-04  6.28851849e-05
  7.23982739e-05  5.32308986e-05  6.58810459e-05  9.76492156e-05
  5.11364924e-08  1.30453106e-08 -5.40003198e-09  7.93480837e-09
  4.53948967e-09 -2.53895518e-08 -2.50021894e-05 -1.36102271e-05
  3.20206273e-05 -4.16706207e-08  5.87673483e-08 -1.34949687e-05
 -2.23240058e-05  2.96114084e-08 -7.19317228e-08 -1.31591298e-02
 -2.75077327e-05  1.56093520e-08  3.62920751e-08  8.02898370e-09
  7.22730142e-09  6.01409695e-08 -1.94656842e-08 -2.85034183e-08] [1.0000378  1.0000255  1.0000088  1.0000036  1.0000488  1.0000529
 1.000028   1.0000645  1.0000403  1.0000135  1.000053   1.0000312
 1.0014762  1.0023024  0.99985677 1.0018277  1.0017785  1.0054659
 1.0000

## Train autoencoder

In [4]:
ae = TwoStepAutoencodingFlow(data_dim=48, latent_dim=10, steps_inner=5, steps_outer=5)

10:44 aef.models.autoencod INFO    Created autoencoding flow with 3546562 trainable parameters


In [5]:
trainer = AutoencodingFlowTrainer(ae)
trainer.train(
    dataset=tth_data,
    loss_functions=[mse],
    loss_labels=["MSE"],
    loss_weights=[1.],
    batch_size=256,
    epochs=5,
    verbose="all",
    initial_lr=1.e-3,
    final_lr=1.e-4,
)

11:57 aef.trainer          INFO    Epoch   1: train loss  0.37535 (MSE:  0.375)
11:57 aef.trainer          INFO               val. loss   0.33899 (MSE:  0.339)
13:14 aef.trainer          INFO    Epoch   2: train loss  0.31291 (MSE:  0.313)
13:14 aef.trainer          INFO               val. loss   0.29001 (MSE:  0.290)
14:24 aef.trainer          INFO    Epoch   3: train loss  0.28160 (MSE:  0.282)
14:24 aef.trainer          INFO               val. loss   0.27172 (MSE:  0.272)
15:35 aef.trainer          INFO    Epoch   4: train loss  0.26741 (MSE:  0.267)
15:35 aef.trainer          INFO               val. loss   0.25783 (MSE:  0.258)
16:44 aef.trainer          INFO    Epoch   5: train loss  0.25443 (MSE:  0.254)
16:44 aef.trainer          INFO               val. loss   0.24470 (MSE:  0.245)
16:44 aef.trainer          INFO    Early stopping did not improve performance


(array([0.37534592, 0.31290702, 0.28160141, 0.26740549, 0.2544259 ]),
 array([0.3389913 , 0.29000982, 0.2717246 , 0.25783074, 0.24469922]))

In [6]:
trainer = AutoencodingFlowTrainer(ae)
trainer.train(
    dataset=tth_data,
    loss_functions=[mse, nll],
    loss_labels=["MSE", "NLL"],
    loss_weights=[1., 1.e-5],
    batch_size=256,
    epochs=5,
    verbose="all",
    initial_lr=1.e-3,
    final_lr=1.e-4,
    parameters=ae.outer_transform.parameters()
)

17:41 aef.trainer          INFO    Epoch   1: train loss  0.34536 (MSE:  0.345, NLL: 81.107)
17:41 aef.trainer          INFO               val. loss   0.43150 (MSE:  0.431, NLL: 77.953)
18:27 aef.trainer          INFO    Epoch   2: train loss  0.34496 (MSE:  0.344, NLL: 77.098)
18:27 aef.trainer          INFO               val. loss   0.33012 (MSE:  0.329, NLL: 79.739)
19:09 aef.trainer          INFO    Epoch   3: train loss  0.30396 (MSE:  0.303, NLL: 78.665)
19:09 aef.trainer          INFO               val. loss   0.28741 (MSE:  0.287, NLL: 77.428)
19:53 aef.trainer          INFO    Epoch   4: train loss  0.28227 (MSE:  0.281, NLL: 77.412)
19:53 aef.trainer          INFO               val. loss   0.26745 (MSE:  0.267, NLL: 75.137)
20:37 aef.trainer          INFO    Epoch   5: train loss  0.26703 (MSE:  0.266, NLL: 75.967)
20:37 aef.trainer          INFO               val. loss   0.25797 (MSE:  0.257, NLL: 76.492)
20:37 aef.trainer          INFO    Early stopping did not improve perf

(array([0.34536268, 0.34495671, 0.30395755, 0.28226989, 0.26703403]),
 array([0.43150306, 0.33011812, 0.28740557, 0.26745002, 0.25797144]))

## Visualize latent space

In [7]:
n = 1000

x = torch.cat([mnist[i][0].unsqueeze(0) for i in range(n)], dim=0)
y = np.asarray([mnist[i][1] for i in range(n)])

x = x.view(x.size(0), -1)
x.to(self.device, self.dtype)
x_out, _, u = self.model(x)

x = x.detach().numpy().reshape(-1, 28, 28)
x_out = x_out.detach().numpy().reshape(-1, 28, 28)
u = u.detach().numpy().reshape(x_out.shape[0], -1)
tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300).fit_transform(u)

NameError: name 'mnist' is not defined

In [None]:
fig = plt.figure(figsize=(10,5))

ax = plt.subplot(1,2,1)
for i in range(10):
    plt.scatter(u[y==i][:,0], u[y==i][:,1], c="C{}".format(i), s=20., label="{}".format(i+1))
plt.legend()
plt.xlabel(r"$u_0$")
plt.ylabel(r"$u_1$")
    
ax = plt.subplot(1,2,2)
for i in range(10):
    plt.scatter(tsne[y==i][:,0], tsne[y==i][:,1], c="C{}".format(i), s=20., label="{}".format(i+1))
plt.legend()
plt.xlabel(r"$t-SNE component 0$")
plt.ylabel(r"$t-SNE component 1$")

plt.tight_layout()
plt.savefig("../figures/mnist_latent.pdf")


## Visualize reconstruction

In [None]:
x = torch.cat([mnist[i][0].unsqueeze(0) for i in range(1000)], dim=0)
y = np.asarray([mnist[i][1] for i in range(1000)])

In [None]:
h = ae.encoder(x)
x_out = ae.decoder(h)
h = h.detach().numpy()
x_out = x_out.detach().numpy()

In [None]:
fig = plt.figure(figsize=(12,12))

for i in range(18):
    ax = plt.subplot(6, 6, 2*i + 1)
    plt.imshow(x[i].reshape((28,28)))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax = plt.subplot(6, 6, 2*i + 2)
    plt.imshow(x_out[i].reshape((28,28)))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
        
plt.tight_layout()
plt.savefig("../figures/reconstruction.pdf")
