In [1]:
import numpy as np
from VAE import VAE
from matplotlib import pyplot as plt
from utils import *

In [2]:
# Make 10x10 the default figure size for every plot drawn later in the notebook.
plt.rcParams["figure.figsize"] = (10, 10)

<h2>Params</h2>

In [3]:
# Training hyperparameters, consumed by the training loop below.
batch_size = 32
epochs = 300

# Training set, loaded from a NumPy file in the working directory.
x_train = np.load('data.npy')

<h2>Train</h2>

In [None]:
vae = VAE()

# Per-epoch history of the two loss terms, read from the model after each epoch.
ces, kls = [], []
n_samples = x_train.shape[0]

for epoch in range(epochs):
    # Draw a fresh random visiting order for this epoch.
    order = np.arange(n_samples)
    np.random.shuffle(order)

    # One forward/backward pass per mini-batch; the final batch may be short.
    for start in range(0, n_samples, batch_size):
        batch_idx = order[start:start + batch_size]
        out = vae.forward(x_train[batch_idx])
        _ = vae.backward(x_train[batch_idx], out)

    # Evaluate on the full training set once the epoch's updates are done.
    # NOTE(review): vae.ce / vae.kl are read right after loss(); presumably
    # loss() is what refreshes them -- confirm in VAE.py.
    out_all = vae.forward(x_train)
    l = vae.loss(x_train, out_all)
    ces.append(vae.ce)
    kls.append(vae.kl)

    print("Epoch {0} :: Reconstruction loss {1} :: Regularization loss {2}".format(epoch+1, ces[-1], kls[-1]))

<h2>Reconstruction</h2>

In [10]:
# Reconstruct 100 randomly chosen training samples: encode each one, draw a
# latent vector from the encoded distribution, decode it, and collect the
# results as 28x28 images.
ims = []
for i in range(100):
    # np.random.randint excludes its upper bound, so pass the sample count
    # itself; using shape[0]-1 would make the last sample unreachable.
    j = np.random.randint(0, x_train.shape[0])
    mu, logvar = vae.encode(x_train[j])
    # z = mu + sigma * eps, with sigma = sqrt(exp(logvar)) and eps ~ N(0, 1).
    eps = np.random.normal(0, 1, vae.z_units)
    z = mu + np.multiply(eps, np.sqrt(np.exp(logvar)))
    im = vae.decode(z).reshape(28, 28)
    ims.append(im)
# Stack the list into a single array of images.
ims = np.array(ims)

In [11]:
# NOTE(review): img_tile is not defined in this notebook; presumably it comes
# from `from utils import *`. With save=True it likely writes a tiled image of
# `ims` under "images" tagged with epoch=300 -- confirm against utils.
img_tile(ims, path="images", epoch=300, save=True)

In [4]:
def interpolate(im1, im2, steps=10):
    """Decode a straight-line walk between the encodings of two images.

    Both images are encoded with the notebook-level ``vae``. For each
    ``k`` in ``range(steps)`` the mean and log-variance are linearly blended
    with fraction ``k / steps``, a latent vector is sampled from the blended
    distribution with fresh Gaussian noise, and the decoded result is
    reshaped to 28x28.

    The fraction runs from 0 up to ``(steps - 1) / steps``, so the first
    frame uses ``im1``'s encoding while ``im2``'s own encoding is never
    reached. Because noise is redrawn for every frame, the output is not
    deterministic.

    Args:
        im1: Start image, passed to ``vae.encode``.
        im2: End image, passed to ``vae.encode``.
        steps: Number of frames to generate.

    Returns:
        ``np.ndarray`` built from the list of ``steps`` decoded 28x28 frames.
    """
    start_mu, start_logvar = vae.encode(im1)
    end_mu, end_logvar = vae.encode(im2)
    mu_span = end_mu - start_mu
    logvar_span = end_logvar - start_logvar

    frames = []
    for k in range(steps):
        frac = k / steps
        mu = start_mu + frac * mu_span
        logvar = start_logvar + frac * logvar_span

        # Sample z = mu + sigma * noise, with sigma = sqrt(exp(logvar)).
        noise = np.random.normal(0, 1, vae.z_units)
        z = mu + noise * np.sqrt(np.exp(logvar))

        frames.append(vae.decode(z).reshape(28, 28))
    return np.array(frames)

In [5]:
def plot_images(images, cols=3, cmap='gray'):
    """Display ``images`` in a grid that is ``cols`` subplots wide.

    Args:
        images: Sequence of image arrays accepted by ``plt.imshow``.
        cols: Number of grid columns.
        cmap: Colormap name forwarded to ``plt.imshow``.

    The row count is the ceiling of ``len(images) / cols``, so a partially
    filled last row is still allocated.
    """
    # Integer ceiling division: plt.subplot needs an int row count, and true
    # division (`/`) would hand it a float under Python 3.
    rows = (len(images) + cols - 1) // cols
    for i, image in enumerate(images):
        # Subplot indices are 1-based.
        plt.subplot(rows, cols, i+1)
        plt.imshow(image, cmap=cmap)
    plt.show()

In [None]:
# Build interpolation frames between two fixed training samples (indices 17 and 23).
ims = interpolate(x_train[17], x_train[23])
# NOTE(review): img_save is not defined in this notebook; presumably it comes
# from `from utils import *` and writes the frames under "images". The meaning
# of epoch=22 here is unclear -- confirm against utils.
img_save(ims, path="images", epoch=22)

<h2>Losses</h2>

In [None]:
# Reconstruction loss recorded after each epoch; the x axis is the 1-based epoch number.
epoch_numbers = range(1, len(ces) + 1)
plt.plot(epoch_numbers, ces)
plt.title("Reconstruction Loss")
plt.ylabel("Reconstruction Loss")
plt.xlabel("Epochs")
# Uncomment to also write the figure to disk before showing it:
# plt.savefig("images/ce.png")
plt.show()

In [None]:
# Regularization loss recorded after each epoch; the x axis is the 1-based epoch number.
epoch_numbers = range(1, len(kls) + 1)
plt.plot(epoch_numbers, kls)
plt.title("Regularization Loss")
plt.ylabel("Regularization Loss")
plt.xlabel("Epochs")
# Uncomment to also write the figure to disk before showing it:
# plt.savefig("images/kl.png")
plt.show()

<h2>Save</h2>

In [6]:
# Base name (no extension) of the pickle file under states/ that the Save and
# Load cells write to and read from.
fname = "epochs300"

In [None]:
import pickle

# Serialize the whole VAE object to states/<fname>.pkl; the states/ directory
# must already exist, since open() will not create it.
state_path = f"states/{fname}.pkl"
with open(state_path, "wb") as state_file:
    pickle.dump(vae, state_file)

<h2>Load</h2>

In [7]:
import pickle

# Restore a previously pickled VAE from states/<fname>.pkl, replacing `vae`.
# WARNING: unpickling can execute arbitrary code, so only load state files
# that were produced by this notebook or another trusted source.
state_path = f"states/{fname}.pkl"
with open(state_path, "rb") as state_file:
    vae = pickle.load(state_file)