In [None]:
import cPickle
import gzip
import time

import numpy as np

import theano
import theano.tensor as T

import climin
import climin.initialize
import climin.stops

import breze
from breze.arch.construct.layer.distributions import DiagGauss
from breze.learn.base import theanox
from breze.learn.sgvb.dvbf import DeepVariationalBayesFilter

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Flag: Theano is configured for a GPU device (device name begins with 'gpu').
GPU = theano.config.device[:3] == 'gpu'

if GPU:
    # Presumably lets gnumpy arrays be converted implicitly (e.g. to NumPy).
    # NOTE(review): breze is imported in the cell above; if gnumpy reads this
    # variable when it is first imported, setting it here may have no effect —
    # confirm, and consider moving this before the imports.
    import os
    os.environ.update(GNUMPY_IMPLICIT_CONVERSION='allow')

In [None]:
# Load the gzipped pickle holding the train / validation / test splits; each split
# unpacks into (observations, controls, states).
# NOTE(review): cPickle.load can execute arbitrary code — only load trusted files.
datafile = 'pendulum.pkl.gz'
with gzip.open(datafile, 'rb') as f:
    splits = cPickle.load(f)
train_set, val_set, test_set = splits

(X, U, S), (VX, VU, VS), (TX, TU, TS) = train_set, val_set, test_set

# Model input: observations and controls stacked along the last (feature) axis.
XU = np.concatenate([X, U], axis=-1)
VXU = np.concatenate([VX, VU], axis=-1)

# Side length of the square image stored in the first image_dims**2 observation
# features (read by the plotting helpers).
image_dims = 16

In [None]:
# Dimensionality of the learned latent state.
n_latent = 3

# Observation / control widths are the sizes of the data arrays' third axis.
n_obs, n_control = X.shape[2], U.shape[2]

# NOTE(review): n_samples is tied to the number of training sequences (X.shape[1]);
# confirm this is valid when the model is later evaluated on VXU, whose sequence
# count may differ.
m = DeepVariationalBayesFilter(
    n_obs, n_control,
    n_state=n_latent, n_alpha=16, n_samples=X.shape[1], zeroth_transition=True,
    n_hiddens_recog=[128], transfers_recog=['rectifier'],
    n_hiddens_transition=[16], transfers_transition=['rectifier'],
    n_hiddens_gen=[128], transfers_gen=['rectifier'],
)

In [None]:
# Seed NumPy's global RNG so the random parameter initialisation below is reproducible
# under Restart & Run All.
# NOTE(review): assumes climin.initialize.randomize_normal draws from np.random; any
# Theano-side sampling noise uses its own RNG and is not covered by this seed.
SEED = 1234
np.random.seed(SEED)

# Initialise every model parameter in place from a normal with mean 0 and scale 0.1.
climin.initialize.randomize_normal(m.parameters.data, 0, 1e-1)

# Starting value of the weight m.beta, capped at 1.
beta0 = 0.01
beta = beta0
m.parameters[m.beta] = np.minimum(1.0, beta)

# Compiled helper returning [m.rec_loss, m.kl] for a stacked input.
f_reckl_loss = m.function(['inpt'], [m.rec_loss, m.kl])

# Loss history; the training loop appends one row per report.
losses = []

In [None]:
def plot(X, n_samples=3, n_timesteps=None, dims=None):
    """Show individual frames as a grid: one row per sequence, one column per time step.

    Parameters
    ----------
    X : array indexed as ``X[time, sequence, feature]``; the first ``dims * dims``
        features of each entry are reshaped into a ``dims x dims`` image.
    n_samples : number of sequences (rows) to show, taken from the start of axis 1.
    n_timesteps : number of time steps (columns); defaults to ``X.shape[0]``.
    dims : image side length; defaults to the notebook-level ``image_dims``.
    """
    if n_timesteps is None:
        n_timesteps = X.shape[0]
    if dims is None:
        dims = image_dims
    n_pixels = dims * dims

    # squeeze=False keeps `axs` 2-D even when n_samples or n_timesteps is 1,
    # so axs[j, i] indexing is always valid.
    fig, axs = plt.subplots(n_samples, n_timesteps,
                            figsize=(n_timesteps, n_samples), squeeze=False)

    for j in range(n_samples):
        for i in range(n_timesteps):
            ax = axs[j, i]
            ax.set_xticks([])
            ax.set_yticks([])
            frame = X[i, j, :n_pixels].reshape((dims, dims))
            ax.imshow(frame, cmap='binary', interpolation='none')

    plt.tight_layout()

def plot_mean(X, n_timesteps=None, dims=None):
    """Show, for each time step, the frame averaged over all sequences (axis 1).

    Parameters
    ----------
    X : array indexed as ``X[time, sequence, feature]``; the first ``dims * dims``
        features are averaged over sequences and reshaped into a ``dims x dims`` image.
    n_timesteps : number of time steps (columns); defaults to ``X.shape[0]``.
    dims : image side length; defaults to the notebook-level ``image_dims``.
    """
    if n_timesteps is None:
        n_timesteps = X.shape[0]
    if dims is None:
        dims = image_dims
    n_pixels = dims * dims

    # squeeze=False keeps `axs` 2-D; with squeeze=True and n_timesteps == 1 it
    # would be a single Axes object and could not be indexed.
    fig, axs = plt.subplots(1, n_timesteps, figsize=(n_timesteps, 10), squeeze=False)

    for i in range(n_timesteps):
        ax = axs[0, i]
        ax.set_xticks([])
        ax.set_yticks([])
        mean_frame = X[i, :, :n_pixels].mean(0).reshape((dims, dims))
        ax.imshow(mean_frame, cmap='binary', interpolation='none')

    plt.tight_layout()

In [None]:
# Optimiser and mini-batch size used by m.powerfit.
m.optimizer = ('adadelta', {'step_rate': 0.1})
m.batch_size = 500

# Annealing and run-length settings.
beta_T = 100000.0   # schedule() raises m.beta by 1/beta_T per iteration, in steps
beta0 = 0.01        # starting beta value
max_iter = 180000   # total optimiser iterations before stopping

# A report (and a beta step) happens every passes_per_report * batches_per_pass iterations.
# NOTE(review): batches_per_pass is hard-coded; it presumably equals
# ceil(number of training sequences / m.batch_size) — TODO confirm against X.shape[1].
passes_per_report = 10
batches_per_pass = 25

report = climin.stops.ModuloNIterations(passes_per_report * batches_per_pass)
stop = climin.stops.AfterNIterations(max_iter)

# NOTE(review): `beta` is not read after this point; schedule() reads m.beta directly.
beta = beta0

def schedule(info):
    """Annealing hook passed to ``m.powerfit``.

    Every ``passes_per_report * batches_per_pass`` iterations, increase the model
    weight ``m.beta`` by that interval divided by ``beta_T``. This is a stepped
    linear ramp of ``1 / beta_T`` per iteration, and ``m.beta`` is never allowed
    above 1.0.
    """
    interval = passes_per_report * batches_per_pass
    if info['n_iter'] % interval != 0:
        return
    current_beta = m.parameters[m.beta].mean()
    step = np.float32(interval) / beta_T
    m.parameters[m.beta] = np.minimum(1.0, current_beta + step)

start_time = time.time()

for info in m.powerfit((XU,), (VXU,), stop=stop, report=report, schedule=schedule):
    
    print time.time() - start_time
    start_time = time.time()

    info['rec_loss'], info['kl_loss'] = f_reckl_loss(theanox(XU))
    info['rec_loss'], info['kl_loss'] = np.asarray(info['rec_loss']), np.asarray(info['kl_loss'])

    info['val_rec_loss'], info['val_kl_loss'] = f_reckl_loss(theanox(VXU))
    info['val_rec_loss'], info['val_kl_loss'] = np.asarray(info['val_rec_loss']), np.asarray(info['val_kl_loss'])

    print '%(n_iter)i %(loss)g %(rec_loss)g %(val_rec_loss)g %(kl_loss)g ' % info,
    print m.parameters[m.beta]


    losses.append([info['loss'], info['rec_loss'], info['kl_loss'], info['val_loss'], info['val_rec_loss'], info['val_kl_loss']])


In [None]:
plt.plot(np.array(losses)[:, 2])

In [None]:
plt.plot(np.array(losses)[:, 1])
plt.plot(np.array(losses)[:, 4])

In [None]:
# Compile the model's latent expression m.z as a function of the stacked input,
# then evaluate it on the whole training set.
latent_expr = m.z
f_latent_sample = m.function(['inpt'], latent_expr)
train_inpt = theanox(XU)
Z = f_latent_sample(train_inpt)

In [None]:
# One row per latent vector: collapse every leading axis of Z, keep the last one.
Z_arr = np.asarray(Z)
Z_flat = Z_arr.reshape((-1, Z_arr.shape[-1]))

In [None]:
# Scatter each pair of latent dimensions, coloured by the true state:
# left column by angle (S[..., 0]), right column by velocity (S[..., 1]).
S_flat = S.reshape((-1, S.shape[-1]))
# Row k of Z_flat and row k of S_flat must describe the same (time step, sequence).
assert Z_flat.shape[0] == S_flat.shape[0], 'latent samples and true states are misaligned'

latent_pairs = [(0, 1), (0, 2), (1, 2)]
state_columns = [(0, 'Angle'), (1, 'Velocity')]

f, axs = plt.subplots(len(latent_pairs), len(state_columns), figsize=(15, 20))
for col, (state_idx, state_name) in enumerate(state_columns):
    axs[0, col].set_title(state_name, fontsize=25)
    for row, (a, b) in enumerate(latent_pairs):
        ax = axs[row, col]
        ax.scatter(Z_flat[:, a], Z_flat[:, b], c=S_flat[:, state_idx], cmap='jet')
        ax.set_xlabel('z[%d]' % a)
        ax.set_ylabel('z[%d]' % b)

In [None]:
from breze.arch.construct.layer.distributions import DiagGauss
f_gen = m.function(['inpt'], DiagGauss(m.gen_x_mean, m.gen_x_var).sample())

In [None]:
f_recon = m.function(['inpt'], m.output)

In [None]:
plot_mean(X)

In [None]:
plot_mean(f_recon(theanox(XU)))

In [None]:
plot_mean(f_gen(theanox(XU)))

In [None]:
plot(theanox(X))

In [None]:
plot(f_recon(theanox(XU)))

In [None]:
plot(f_gen(theanox(XU)))

In [None]:
plot(f_gen(theanox(np.repeat(np.concatenate((X, np.zeros_like(U)), 2), 5, 0))))