In [None]:
import cPickle
import gzip

from breze.learn.data import one_hot
from breze.learn.base import cast_array_to_local_type
from breze.learn.utils import tile_raster_images

import climin.stops


import climin.initialize

from breze.learn import sgvb
from matplotlib import pyplot as plt
from matplotlib import cm

import numpy as np

from IPython.html import widgets
%matplotlib inline 

import theano
# Theano test-value mode. Per the Theano config docs, 'raise' is the stricter
# alternative (errors when a test value is missing), handy while debugging
# graph construction.
theano.config.compute_test_value = 'ignore'  # alternative: 'raise'

In [None]:
# True when Theano is configured to run on a GPU device.
GPU = theano.config.device[:3] == 'gpu'

if GPU:
    import os
    # Presumably lets gnumpy convert implicitly between its arrays and numpy.
    # NOTE(review): gnumpy likely reads this variable when it is first
    # imported; if the breze imports in the first cell already imported it,
    # this assignment may have no effect -- TODO confirm, and consider
    # setting it before those imports.
    os.environ.update(GNUMPY_IMPLICIT_CONVERSION='allow')

In [None]:
# Pickled MNIST archive (source: deeplearning.net/data/mnist/mnist.pkl.gz).
datafile = '../mnist.pkl.gz'

# NOTE: unpickling can execute arbitrary code; only load a trusted archive.
with gzip.open(datafile, 'rb') as f:
    train_set, val_set, test_set = cPickle.load(f)

# Every split is an (inputs, labels) pair.
(X, Z), (VX, VZ), (TX, TZ) = train_set, val_set, test_set

# Integer labels -> one-hot vectors over the 10 digit classes.
Z, VZ, TZ = one_hot(Z, 10), one_hot(VZ, 10), one_hot(TZ, 10)

image_dims = 28, 28

# Convert every array to breze's local array type (presumably GPU arrays when
# running on the GPU, numpy arrays otherwise).
X, Z, VX, VZ, TX, TZ = [cast_array_to_local_type(i) for i in (X, Z, VX, VZ, TX, TZ)]

In [None]:
# Preview the first 64 training images, tiled on an 8x8 grid
# (the last tuple is presumably the pixel spacing between tiles).
fig, ax = plt.subplots(figsize=(9, 9))

img = tile_raster_images(np.array(X[:64]), image_dims, (8, 8), (1, 1))
ax.imshow(img, cmap=cm.binary)
ax.set_title('First 64 training images')
# Pixel coordinates of the tiled mosaic carry no meaning; hide the axes.
ax.set_axis_off();

In [None]:
# Training configuration: minibatch size and climin optimizer name.
batch_size = 200
# Alternative optimizer settings, kept for reference ((name, hyper-params)):
#optimizer = 'rmsprop', {'step_rate': 1e-4, 'momentum': 0.95, 'decay': .95, 'offset': 1e-6}
#optimizer = 'adam', {'step_rate': .5, 'momentum': 0.9, 'decay': .95, 'offset': 1e-6}
optimizer = 'adam'

# Re-import sgvb so edits to the module are picked up without restarting the
# kernel (`reload` is a Python 2 builtin).
reload(sgvb)


class MyVAE(sgvb.VariationalAutoEncoder,
            sgvb.MlpGaussLatentVAEMixin,
            sgvb.MlpBernoulliVisibleVAEMixin):
    """VAE assembled from `sgvb.VariationalAutoEncoder` and the
    `MlpGaussLatentVAEMixin` / `MlpBernoulliVisibleVAEMixin` mixins.

    The class body adds nothing; all behaviour comes from the base classes.
    """


# Extra keyword arguments forwarded to the MyVAE constructor (none for now).
kwargs = {}

# This is the number of random variables NOT the size of
# the sufficient statistics for the random variables.
n_latents = 64
n_hidden = 512

# Positional arguments, presumably: input dimensionality, recognition hidden
# layer sizes, number of latents, generative hidden layer sizes, recognition
# transfer functions, generative transfer functions -- TODO confirm against
# the sgvb.VariationalAutoEncoder signature.
m = MyVAE(X.shape[1],
          [n_hidden], n_latents, [n_hidden],
          ['rectifier'], ['rectifier'],
          optimizer=optimizer,
          batch_size=batch_size,
          **kwargs)

# Disabled experiment: squared-weight penalty on enc_in_to_hidden, scaled by
# 0.001 / batch size.
#m.exprs['loss'] += 0.001 * (m.parameters.enc_in_to_hidden ** 2).sum() / m.exprs['inpt'].shape[0]

# Draw every parameter from a normal with mean 0 and std 1e-2 (climin's
# loc / scale arguments). NOTE(review): no RNG seed is set anywhere in this
# notebook, so initialisation differs between runs.
climin.initialize.randomize_normal(m.parameters.data, 0, 1e-2)

# Disabled experiment: sparse column initialisation.
#climin.initialize.sparsify_columns(m.parameters['enc_in_to_hidden'], 15)
#climin.initialize.sparsify_columns(m.parameters['enc_hidden_to_hidden_0'], 15)
#climin.initialize.sparsify_columns(m.parameters['dec_hidden_to_out'], 15)

# Earlier helper functions, commented out (f_sample / f_recons are compiled
# again after training).
#f_latent_mean = m.function(['inpt'], 'latent_mean')
#f_sample = m.function([('gen', 'layer-0-inpt')], 'output')
#f_recons = m.function(['inpt'], 'output')

In [None]:
# Sanity check before training: run estimate_nll on the first ten training rows.
m.estimate_nll(X[0:10])

In [None]:
# Train for `max_passes` passes over X, pausing to report once per pass.
max_passes = 250
# Floor division keeps the iteration counts integral even under true division
# (Python 3 or `from __future__ import division`); in Python 2 it equals `/`.
max_iter = max_passes * X.shape[0] // batch_size
n_report = X.shape[0] // batch_size  # iterations per pass over X

stop = climin.stops.AfterNIterations(max_iter)
pause = climin.stops.ModuloNIterations(n_report)

# powerfit presumably yields an info dict at every pause; `info` from the
# last iteration is used afterwards to restore the best parameters.
for i, info in enumerate(m.powerfit((X,), (VX,), stop, pause)):
    # Single pre-formatted argument: prints identically in Python 2 and 3.
    print('%s %s %s' % (i, info['loss'], info['val_loss']))


In [None]:
# Copy the best parameters recorded during training (info['best_pars'],
# presumably selected by validation loss) into the model's flat parameter
# array in place. Relies on `info` left over from the training loop.
m.parameters.data[...] = info['best_pars']

In [None]:
# f_sample: feeds values for m.recog_sample (latent vectors) and returns a
# sample from the generative model.
# f_recons: feeds a batch of inputs ('inpt') and returns a generative sample,
# i.e. a reconstruction of those inputs.
# NOTE(review): each gen.sample() call may create its own random stream, so
# keep these two lines in this order if exact reproducibility matters.
f_sample = m.function([m.recog_sample], m.vae.gen.sample())
f_recons = m.function(['inpt'], m.vae.gen.sample())

In [None]:
# Left: images decoded from standard-normal latent vectors.
# Right: reconstructions of the first training images.
# The grid shape, number of images and pixel count are derived in one place
# so they cannot drift apart (previously hard-coded as (8, 8), 64 and 784).
tile_grid = (8, 8)
n_show = tile_grid[0] * tile_grid[1]
n_pixels = int(np.prod(image_dims))  # 28 * 28 = 784

fig, axs = plt.subplots(1, 2, figsize=(18, 9))

latent_draws = cast_array_to_local_type(
    np.random.randn(n_show, m.n_latent).astype('float32'))
# Keep only the first n_pixels output columns (the output may carry more
# columns than pixels -- TODO confirm).
S = f_sample(latent_draws)[:, :n_pixels].astype('float32')
img = tile_raster_images(np.array(S), image_dims, tile_grid, (1, 1))
axs[0].imshow(img, cmap=cm.binary)
axs[0].set_title('Samples from random latents')

R = f_recons(X[:n_show])[:, :n_pixels].astype('float32')
img = tile_raster_images(np.array(R), image_dims, tile_grid, (1, 1))
axs[1].imshow(img, cmap=cm.binary)
axs[1].set_title('Reconstructions of the first training images');

In [None]:
# Show learned weights as image-shaped tiles on a 10x10 grid (presumably the
# first 100 rows of each matrix):
# left, the recognition MLP's first-layer weights (transposed);
# right, the generative MLP's last-layer weights.
filter_grid = (10, 10)
fig, axs = plt.subplots(1, 2, figsize=(18, 9))

recog_w = m.parameters[m.vae.recog.mlp.layers[0].weights]
img = tile_raster_images(np.array(recog_w.T), image_dims, filter_grid, (1, 1))
axs[0].imshow(img, cmap=cm.binary)
axs[0].set_title('Recognition MLP: first-layer weights')

gen_w = m.parameters[m.vae.gen.mlp.layers[-1].weights]
img = tile_raster_images(np.array(gen_w), image_dims, filter_grid, (1, 1))
axs[1].imshow(img, cmap=cm.binary)
axs[1].set_title('Generative MLP: last-layer weights');

In [None]:
# Compile a function mapping inputs to the recognition model's sufficient
# statistics (`recog.stt`). The column layout of `stt` (e.g. where means vs.
# variances live) is defined by sgvb -- TODO confirm.
f_L = m.function([m.vae.inpt], m.vae.recog.stt)

In [None]:
# Recognition statistics for every training example in X (one row per example).
L = f_L(X)

In [None]:
# Scatter the first two columns of the recognition statistics, coloured by
# the true digit label. NOTE(review): confirm these columns are latent means
# given sgvb's layout of `stt`.
# Z and L may be backend (e.g. GPU) arrays; convert them with np.array, as
# the other plotting cells do, before handing them to matplotlib.
labels = np.array(Z).argmax(1)
stats = np.array(L)

fig, ax = plt.subplots(figsize=(9, 9))
points = ax.scatter(stats[:, 0], stats[:, 1], c=labels, lw=0, s=10, alpha=.2)
cbar = fig.colorbar(points, ax=ax)
cbar.set_label('digit label')
ax.set_xlabel('recog.stt[:, 0]')
ax.set_ylabel('recog.stt[:, 1]')
ax.set_title('Training set in recognition-statistic space');