In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from progressbar import ProgressBar
import time,os

In [None]:
from lasagne.layers import InputLayer,DenseLayer
from lasagne.nonlinearities import rectify,linear,softplus,sigmoid
from lasagne.updates import adam

In [None]:
from Tars.model import VAE
from Tars.distribution import Gaussian,Bernoulli
from Tars.load_data import mnist

In [None]:
# --- Data ---------------------------------------------------------------
# `mnist` yields a loader and a plotting helper. Of the six arrays the loader
# returns, only positions 0 (train_x) and 4 (test_x) are kept.
# NOTE(review): the discarded slots are presumably labels / validation data —
# confirm against Tars.load_data.mnist.
load, plot = mnist('../datasets/')
train_x, _, _, _, test_x, _ = load(test=True)

# --- Dimensions ---------------------------------------------------------
n_x = 28 * 28   # flattened input size
n_z = 64        # latent dimensionality
n_y = 10        # not referenced by the cells shown here

# --- Reproducibility ----------------------------------------------------
seed = 1234
np.random.seed(seed)

# --- Training hyperparameters -------------------------------------------
activation = rectify
n_epoch, n_batch = 100, 100

In [None]:
# Encoder q(z|x): two 512-unit hidden layers, then two heads of size n_z.
# The linear head is passed to Gaussian as its mean. The softplus head
# (positive output) is passed as its variance parameter.
# NOTE(review): keep the DenseLayer construction order as-is. Lasagne presumably
# draws initial weights from the global NumPy RNG seeded earlier, so reordering
# these calls would change the initial weights.
x = InputLayer((None,n_x))
q_0  = DenseLayer(x,num_units=512,nonlinearity=activation)
q_1  = DenseLayer(q_0,num_units=512,nonlinearity=activation)
q_mean = DenseLayer(q_1,num_units=n_z,nonlinearity=linear)
q_var = DenseLayer(q_1,num_units=n_z,nonlinearity=softplus)
q = Gaussian(q_mean,q_var,given=[x]) # q(z|x)

# Decoder p(x|z): same 2 x 512 hidden structure as the encoder. The sigmoid
# output lies in (0, 1) and is passed to Bernoulli as the mean for each of the
# n_x outputs.
z = InputLayer((None,n_z))
p_0  = DenseLayer(z,num_units=512,nonlinearity=activation)
p_1  = DenseLayer(p_0,num_units=512,nonlinearity=activation)
p_mean = DenseLayer(p_1,num_units=n_x,nonlinearity=sigmoid)
p = Bernoulli(p_mean,given=[z]) # p(x|z)

In [None]:
# Assemble the VAE from the encoder q(z|x) and decoder p(x|z), trained with Adam.
# NOTE(review): `l` is presumably the number of Monte Carlo samples of z per
# data point — confirm against Tars.model.VAE.
model = VAE(
    q, p,
    n_batch=n_batch,
    optimizer=adam,
    l=1,
)

In [None]:
n_sample = 100
# Fixed latent codes, drawn once so every saved grid decodes the same z and
# epochs can be compared visually. The array is sized by n_sample (the number
# of images plotted), not n_batch. The two are equal here, but the grid must
# follow n_sample if either value changes.
sample_z = np.random.standard_normal((n_sample, n_z)).astype(np.float32)

def plot_image(t, i):
    """Decode the fixed ``sample_z`` codes and save them as an image grid.

    Parameters
    ----------
    t : int
        Run identifier; the image is written under ``../plot/<t>/``.
    i : int
        Epoch number, used as the zero-padded file name ``<i:04d>.jpg``.
    """
    sample_x = p.np_sample_mean_given_x(sample_z)
    fig = plt.figure(figsize=(10, 10))
    images, cmap = plot(sample_x[:n_sample])

    # Smallest square grid that fits n_sample images (10 x 10 for 100).
    n_side = int(np.ceil(np.sqrt(n_sample)))
    for j, img in enumerate(images):
        ax = fig.add_subplot(n_side, n_side, j + 1)
        ax.imshow(img, cmap)
        ax.axis('off')

    # Save and close *this* figure explicitly rather than relying on pyplot's
    # "current figure". This avoids accumulating open figures across epochs.
    fig.savefig('../plot/%d/%04d.jpg' % (t, i))
    plt.close(fig)

In [None]:
# Unique run id (Unix time); plots and the text log go to ../plot/<t>/.
t = int(time.time())
os.mkdir('../plot/%d' % t)

pbar = ProgressBar(maxval=n_epoch).start()
for i in range(1, n_epoch + 1):
    # Reshuffle the training set in place before every epoch.
    np.random.shuffle(train_x)
    lowerbound_train = model.train([train_x])

    # Evaluate and log on the first epoch and on every 10th epoch.
    if (i % 10 == 0) or (i == 1):
        # NOTE(review): mode='iw' with k=10 presumably yields an importance-
        # weighted log-likelihood estimate. The log line still labels it
        # "lower bound (test)".
        log_likelihood_test = model.log_likelihood_test([test_x], k=10, mode='iw')
        lw = "epoch = %d, lower bound (train) = %lf (%lf %lf) lower bound (test) = %lf\n" % (
            i, sum(lowerbound_train), lowerbound_train[0], lowerbound_train[1],
            np.mean(log_likelihood_test))
        # `with` guarantees the log file is closed even if write() raises.
        with open("../plot/%d/temp.txt" % t, "a") as f:
            f.write(lw)
        # Single-argument print(...) prints the same output under Python 2
        # and is also valid Python 3 syntax.
        print(lw)
        plot_image(t, i)

    pbar.update(i)
pbar.finish()