In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from progressbar import ProgressBar
import time,os
from sklearn.utils import shuffle

In [None]:
from lasagne.layers import InputLayer,DenseLayer,ConcatLayer
from lasagne.nonlinearities import rectify,linear,softplus,sigmoid
from lasagne.updates import adam

In [None]:
from Tars.models import VAE
from Tars.distributions import Gaussian,Bernoulli
from Tars.load_data import mnist

In [None]:
# Load MNIST; `plot` converts flat image vectors back into displayable images.
load, plot = mnist('../datasets/')
# The middle pair of returned arrays (presumably a validation split) is unused here.
train_x, train_y, _, _, test_x, test_y = load(test=True)

# Dimensionalities: flattened 28x28 image, latent code, number of label classes.
x_dim, z_dim, y_dim = 28 * 28, 64, 10

# Hidden-layer nonlinearity and global numpy seed (np.random.choice and
# sklearn's shuffle below draw from numpy's global RNG).
activation, seed = rectify, 1234
np.random.seed(seed)

# Training schedule.
n_epoch, n_batch = 100, 100

In [None]:
# Input placeholders: image x, label y (fed one-hot rows by plot_image), latent z.
x = InputLayer((None, x_dim))
y = InputLayer((None, y_dim))
z = InputLayer((None, z_dim))

n_hidden = 512  # width of every hidden layer in both networks

# Encoder q(z|x,y): MLP on [x, y] with a linear mean head and a softplus
# (strictly positive) per-dimension variance head.
# NOTE: DenseLayers must be created in this exact order -- their weights are
# initialised as they are constructed, so reordering changes the init.
q_in = ConcatLayer([x, y])
q_0 = DenseLayer(q_in, num_units=n_hidden, nonlinearity=activation)
q_1 = DenseLayer(q_0, num_units=n_hidden, nonlinearity=activation)
q_mean = DenseLayer(q_1, num_units=z_dim, nonlinearity=linear)
q_var = DenseLayer(q_1, num_units=z_dim, nonlinearity=softplus)
q = Gaussian(q_mean, q_var, given=[x, y])  # q(z|x,y)

# Decoder p(x|z,y): MLP on [z, y]; sigmoid keeps each pixel probability in (0, 1).
p_in = ConcatLayer([z, y])
p_0 = DenseLayer(p_in, num_units=n_hidden, nonlinearity=activation)
p_1 = DenseLayer(p_0, num_units=n_hidden, nonlinearity=activation)
p_mean = DenseLayer(p_1, num_units=x_dim, nonlinearity=sigmoid)
p = Bernoulli(p_mean, given=[z, y])  # p(x|z,y)

In [None]:
# Combine encoder q and decoder p into the VAE trained below, optimised with Adam.
model = VAE(q, p, optimizer=adam, n_batch=n_batch)

In [None]:
n_sample = 10  # number of test images (rows) shown in each saved figure
# Fixed test subset reused by every snapshot. replace=False guarantees distinct
# images; the default (replace=True) could show the same digit twice.
choice   = np.random.choice(test_x.shape[0], n_batch, replace=False)
sample_x = test_x[choice]
sample_y = test_y[choice]

def plot_image(t,i):
    """Save a grid of label-conditioned reconstructions for the fixed test samples.

    Row ``j`` shows ``sample_x[j]`` in the first column, followed by one
    decoded image per class label: the image is encoded with ``q`` under its
    true label, and the resulting latent mean is decoded with ``p`` under each
    of the ``y_dim`` one-hot labels in turn.

    The grid is ``n_sample`` rows by ``y_dim + 1`` columns, so it adapts to
    those settings instead of assuming 10 x 11.

    Parameters
    ----------
    t : int
        Run identifier; the figure is written to ``../plot/<t>/``.
    i : int
        Epoch number, used as the zero-padded file name ``%04d.jpg``.
    """
    n_cols = y_dim + 1  # original image + one column per label
    fig = plt.figure(figsize=(12, 8))

    # One one-hot row per class; named `labels` so it does not shadow the
    # global `y` InputLayer.
    labels = np.eye(y_dim).astype(np.float32)
    sample_z = q.np_sample_mean_given_x(sample_x,sample_y)
    # X[k, j] is sample j's latent mean decoded under label k.
    X = np.array([[np.array(p.np_sample_mean_given_x(_z[np.newaxis],
                                                     labels[k][np.newaxis]))[0]
                   for _z in sample_z[:n_sample]]
                  for k in range(y_dim)])

    for j in range(n_sample):
        ax = fig.add_subplot(n_sample, n_cols, n_cols * j + 1)
        _X, cmap = plot(sample_x[j][np.newaxis])
        ax.imshow(_X[0], cmap)
        ax.axis('off')
        for k in range(y_dim):
            ax = fig.add_subplot(n_sample, n_cols, n_cols * j + k + 2)
            _X, cmap = plot(X[k, j][np.newaxis])
            ax.imshow(_X[0], cmap)
            ax.axis('off')

    fig.savefig('../plot/%d/%04d.jpg'%(t,i))
    plt.close(fig)  # release the figure; this is called repeatedly during training

In [None]:
t = int(time.time())  # run id: one output directory per launch
log_dir = '../plot/%d' % t
# makedirs also creates ../plot on a fresh checkout (os.mkdir would fail there),
# and the existence check keeps a quick re-run of this cell from crashing.
if not os.path.isdir(log_dir):
    os.makedirs(log_dir)

pbar = ProgressBar(maxval=n_epoch).start()
for i in range(1, n_epoch+1):
    # sklearn.utils.shuffle returns shuffled *copies* (it does not work in
    # place), so the result must be rebound; x and y are permuted together.
    train_x, train_y = shuffle(train_x, train_y)
    lowerbound_train = model.train([train_x,train_y])

    # Evaluate, log and snapshot on the first epoch and every 10th epoch.
    if (i%10 == 0) or (i == 1):
        # NOTE(review): k=10 presumably sets the number of samples for the test
        # estimate; the log labels its mean "lower bound (test)" -- confirm
        # against Tars' VAE.test.
        log_likelihood_test = model.test([test_x,test_y],k=10)
        lw = "epoch = %d, lower bound (train) = %lf (%lf %lf) lower bound (test) = %lf\n" %(i,sum(lowerbound_train),lowerbound_train[0],lowerbound_train[1],np.mean(log_likelihood_test))
        # `with` guarantees the log file is closed even if the write fails.
        with open(os.path.join(log_dir, "temp.txt"), "a") as f:
            f.write(lw)
        print(lw.rstrip())  # valid in Python 2 and 3; avoids a doubled newline
        plot_image(t,i)

    pbar.update(i)
pbar.finish()