In [8]:
%load_ext autoreload
%autoreload 2
import os

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.models import load_model

import vae.VAE

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load MNIST data (already split into train/test), scale to [0, 1], and flatten each image

In [None]:
def _to_unit_float(images):
    """Cast pixel data to float32 and divide by 255 so 0..255 intensities land in [0, 1]."""
    return images.astype('float32') / 255.


def _flatten_images(images):
    """Collapse an (n, h, w) stack of images into an (n, h*w) matrix, one image per row."""
    n, h, w = images.shape
    return images.reshape(n, h * w)


(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = _to_unit_float(x_train), _to_unit_float(x_test)
x_train_re, x_test_re = _flatten_images(x_train), _flatten_images(x_test)
print(x_train_re.shape)
print(x_test_re.shape)

In [None]:
# Per-pixel standardization. Mean/std come from the training set only, and the
# test set is transformed with those same statistics.
# NOTE(review): x_train_re_norm / x_test_re_norm are not used by any later cell
# shown here (fit/predict use x_train_re / x_test_re) -- confirm whether needed.
train_mu = np.mean(x_train_re, axis=0)
train_std = np.std(x_train_re, axis=0)
# Pixels that are constant across the training set have std 0; divide by 1 there.
train_std = np.where(train_std == 0, 1, train_std)


def _standardize(data):
    """Shift/scale `data` by the training-set per-pixel mean and std."""
    return (data - train_mu) / train_std


x_train_re_norm = _standardize(x_train_re)
x_test_re_norm = _standardize(x_test_re)

## Instantiate the Gaussian-Gaussian VAE and Load Its Parameters

In [9]:
# Select the Gaussian-Gaussian VAE variant and build it from its JSON parameter file.
homoscedastic = True  # True: constant decoder variance; False: heteroscedastic variant

# Directory holding the JSON parameter files. Defaults to the original author's
# absolute path; set the VAE_PARAMS_DIR environment variable to run elsewhere
# instead of editing a hard-coded path.
PARAMS_DIR = os.environ.get('VAE_PARAMS_DIR', '/home/gridsan/CH24434/sandbox/vae/params')
param_file = os.path.join(
    PARAMS_DIR,
    'homo_gauss_gauss_mnist.json' if homoscedastic else 'hetero_gauss_gauss_mnist.json',
)

v = vae.VAE.VAE(param_file=param_file)
print(v.params)
# NOTE(review): the saved output of this cell ends with
# "TypeError: unsupported operand type(s) for +: 'int' and 'tuple'";
# every later cell assumes the model was built successfully.
v.construct()


Batch Size: 100
Num Epochs: 50
Number of Samples: 10

VAE Type: HomoscedasticGaussianGaussian

Optimizer: Adagrad
Learning Rate: 0.01


Name: Input
Type: Input
Output to Loss: False
Shape: (784,)



Name: Encode1
Type: Dense
Output to Loss: False
Size: 256
Activation: relu
Reshape: None 



Name: params_layer
Type: DenseKD
Output to Loss: True
K: 2
Concat: False

-----
Name: mu
Type: Dense
Output to Loss: False
Size: 2
Activation: linear
Reshape: None 

-----

-----
Name: log_sigma
Type: Dense
Output to Loss: False
Size: 2
Activation: linear
Reshape: None 

-----



Name: Sample1
Type: Sample
Output to Loss: False
Size: 2
Distribution: Gaussian



Name: Decode1
Type: Dense
Output to Loss: False
Size: 256
Activation: relu
Reshape: None 



Name: Decode2
Type: Dense
Output to Loss: False
Size: 784
Activation: sigmoid
Reshape: None 



Name: Input
Type: Input
Output to Loss: False
Shape: (784,)



Name: Encode1
Type: Dense
Output to Loss: False
Size: 256
Activation: relu
Reshape: None 



TypeError: unsupported operand type(s) for +: 'int' and 'tuple'

## Optionally Load a Saved Model

In [None]:
# Optionally resume from previously saved weights for the active model variant.
LOAD_MODEL = True  # set False to skip loading saved weights
if LOAD_MODEL:
    weights_file = 'homo_gauss_gauss_mnist.h5' if homoscedastic else 'hetero_gauss_gauss_mnist.h5'
    v.model.load_weights('../saved_models/' + weights_file)

## Fit the VAE

In [None]:
# Train on the flattened training images; only the heteroscedastic model is
# fit with add_axis=True.
v.fit(x_train_re, shuffle=True, add_axis=not homoscedastic)

In [None]:
# Save to the checkpoint matching the active variant (same names the load cell
# uses), so homoscedastic weights never overwrite the heteroscedastic file.
save_name = 'homo_gauss_gauss_mnist.h5' if homoscedastic else 'hetero_gauss_gauss_mnist.h5'
v.model.save_weights('../saved_models/' + save_name)

## Run model on test data

In [None]:
# Reconstruct the test set. Homoscedastic output has one value per pixel and is
# viewed as 28x28 images; heteroscedastic output is kept in its raw layout here.
raw_pred = v.model.predict(x_test_re, batch_size=100)
pred = raw_pred.reshape(len(x_test_re), 28, 28) if homoscedastic else raw_pred
print(pred.shape)

In [None]:
# Heteroscedastic predictions carry two values per pixel. Lay them out as
# (N, 784, 2) so the plotting cell can take channel 0 per pixel with
# pred[sample, :, 0]. Homoscedastic predictions are already (N, 28, 28) and have
# no trailing channel axis, so they are left untouched. Re-running is harmless:
# reshaping to the same shape is a no-op.
# NOTE(review): assumes the raw output holds the 2 values per pixel contiguously
# in its last axis -- confirm against the VAE's output layer.
if not homoscedastic:
    pred = pred.reshape(len(x_test_re), 28 * 28, 2)
print(pred.shape)

In [None]:
%matplotlib notebook
sample = 310

plt.subplot(1,2,1)
plt.imshow(x_test[sample,:,:])
plt.gca().axes.get_xaxis().set_visible(False)
plt.gca().axes.get_yaxis().set_visible(False)
plt.subplot(1,2,2)
if homoscedastic :
    plt.imshow(pred[sample,:,:])
else :
    plt.imshow(pred[sample,:,0].reshape(28,28))
plt.gca().axes.get_xaxis().set_visible(False)
plt.gca().axes.get_yaxis().set_visible(False)

## Reconstruction Error

In [None]:
%matplotlib notebook
#abs_loss_ber = np.genfromtxt('../data/ber_gauss_recon.loss')
lw = 2.5
N, W, H =  x_test.shape
abs_loss = np.abs(x_test-pred)
abs_loss = abs_loss.reshape(N*W*H)
plt.plot(np.sort(abs_loss), linewidth=lw, label='Gaussian-Gaussian Recon Loss')
plt.plot(np.sort(abs_loss_ber), linewidth=lw, label='Bernoulli-Gaussian Recon Loss')
plt.gca().axes.get_xaxis().set_ticks([])
plt.grid()
plt.legend(loc=2)
plt.ylabel('Absolute Difference')
plt.xlabel('Test Example')
plt.show()

## Encode some test data

In [None]:
# Map the test set into the 2-D latent space.
ENCODE_MU = True  # True: 'mu' layer output; False: np.exp of the 'log_sigma' layer output
if not ENCODE_MU:
    z = np.exp(v.encode(x_test_re, 'log_sigma', skip=['mu']))
else:
    z = v.encode(x_test_re, 'mu')
print(z.shape)

In [None]:
%matplotlib notebook
cmap = plt.get_cmap('jet', 10)
plt.scatter(z[:,0],z[:,1], c=y_test, cmap=cmap)
plt.show()
if ENCODE_MU :
    plt.xlabel('mu_1')
    plt.ylabel('mu_2')
else :
    plt.xlabel('log_var')
    plt.ylabel('log_var')
plt.colorbar()

## VAE Decoding

In [None]:
# Decode chosen points of the 2-D latent space back into 28x28 images.
USE_RANDOM_Z = False  # True: decode standard-normal draws instead of the hand-picked codes
N_RANDOM_Z = 4        # number of random codes when USE_RANDOM_Z is True

if USE_RANDOM_Z:
    Z = np.random.normal(size=(N_RANDOM_Z, 2))
else:
    # Hand-picked latent codes; the trailing numbers are presumably the digit
    # each code was observed to decode to.
    Z = np.array([[1, -2],     # 1
                  [-1, 2],     # 0
                  [-1.1, -2],  # 7
                  [2.5, 1.5]]) # 3

batch_size = len(Z)
# Decoding starts at the 'Decode1' layer.
# NOTE(review): semantics of the layer argument come from vae.VAE -- confirm.
x = v.decode(Z, 'Decode1').reshape(batch_size, 28, 28)
print(x.shape)

In [None]:
%matplotlib notebook
THRESH = None

for i in np.arange(0,x.shape[0]) :
    plt.subplot(2,2,i+1)
    if THRESH :
        plt.imshow(x[i,:,:] > THRESH)
    else :
        plt.imshow(x[i,:,:])
    #plt.colorbar()
    plt.gca().axes.get_xaxis().set_visible(False)
    plt.gca().axes.get_yaxis().set_visible(False)
plt.show()