In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from progressbar import ProgressBar
import time,os

In [None]:
from lasagne.layers import InputLayer,DenseLayer,ReshapeLayer,NonlinearityLayer
from lasagne.nonlinearities import rectify,linear,softplus,sigmoid,softmax,tanh
from lasagne.updates import adam

In [None]:
from Tars.model import VAE
from Tars.distribution import Gaussian,Bernoulli,Categorical
from Tars.load_data import mnist

In [None]:
# Seed NumPy's global RNG before anything else in this cell, so that any draw
# from it below (data loading, and -- presumably -- Lasagne's default weight
# initialisation in the model cells) is reproducible.
rseed = 1234
np.random.seed(rseed)

# `load` returns a 6-tuple; only elements 0 and 4 are kept here.
# NOTE(review): presumably (train_x, train_y, valid_x, valid_y, test_x, test_y)
# -- confirm against Tars.load_data.mnist.
load,plot = mnist('../datasets/')
train_x,_,_,_,test_x,_ = load(test=True)

n_x = 28*28   # flattened image size (28x28 pixels)
n_z = 64      # NOTE(review): not used by any cell in this notebook
n_y = 10      # NOTE(review): not used by any cell in this notebook
K = 10        # classes per latent categorical variable (see the Categorical cell's (-1,K) reshape)
N = 30        # number of latent variables; the latent code has N*K units

activation = rectify   # hidden-layer nonlinearity for encoder and decoder

n_epoch = 100
n_batch = 100

In [None]:
# Bernoulli latent variant.
# Encoder q(z|x): two 512-unit hidden layers, then K*N sigmoid outputs used as
# Bernoulli means. Decoder p(x|z): two 512-unit hidden layers, then n_x sigmoid
# outputs used as Bernoulli pixel means.
# NOTE(review): the following cell rebinds x, q, z and p to the Categorical
# variant, so with "Run All" the networks built here are discarded. Run only
# one of the two model cells (or select between them with a flag).
x = InputLayer((None,n_x))
q_0  = DenseLayer(x,num_units=512,nonlinearity=activation)
q_1  = DenseLayer(q_0,num_units=512,nonlinearity=activation)
q_mean = DenseLayer(q_1,num_units=K*N,nonlinearity=sigmoid)
# NOTE(review): `temp` is presumably a relaxation temperature used when sampling
# z -- confirm in Tars.distribution.Bernoulli.
q = Bernoulli(q_mean,given=[x],temp=0.01) #q(z|x)


z = InputLayer((None,K*N))
p_0  = DenseLayer(z,num_units=512,nonlinearity=activation)
p_1  = DenseLayer(p_0,num_units=512,nonlinearity=activation)
p_mean = DenseLayer(p_1,num_units=n_x,nonlinearity=sigmoid)
p = Bernoulli(p_mean,given=[z]) #p(x|z)

In [None]:
# Categorical latent variant (this cell overrides the Bernoulli one if both run).
# Encoder q(z|x): two 512-unit hidden layers, then N*K linear logits. These are
# reshaped to rows of K and softmaxed per row, i.e. N categorical variables
# with K classes each.
x = InputLayer((None,n_x))
q_0  = DenseLayer(x,num_units=512,nonlinearity=activation)
q_1  = DenseLayer(q_0,num_units=512,nonlinearity=activation)
q_2 = DenseLayer(q_1,num_units=N*K,nonlinearity=linear)
# Reshape (batch, N*K) -> (batch*N, K) so softmax normalises over each group of K.
q_mean = NonlinearityLayer(ReshapeLayer(q_2,((-1,K))),nonlinearity=softmax)
# NOTE(review): `n_dim=N` presumably tells Categorical how many K-way variables
# make up one sample, and `temp` is presumably a relaxation temperature --
# confirm both in Tars.distribution.Categorical.
q = Categorical(q_mean,given=[x],temp=0.001,n_dim=N) #q(z|x)

# Decoder p(x|z): takes the flattened N*K latent code and outputs n_x sigmoid
# Bernoulli pixel means.
z = InputLayer((None,N*K))
p_0  = DenseLayer(z,num_units=512,nonlinearity=activation)
p_1  = DenseLayer(p_0,num_units=512,nonlinearity=activation)
p_mean = DenseLayer(p_1,num_units=n_x,nonlinearity=sigmoid)
p = Bernoulli(p_mean,given=[z]) #p(x|z)

In [None]:
# Assemble the VAE from whichever q(z|x) / p(x|z) the most recently executed
# model cell bound, trained with Lasagne's adam updates.
# NOTE(review): n_batch is presumably the minibatch size used by model.train.
model = VAE(q, p, n_batch=n_batch, optimizer=adam)

In [None]:
def plot_z(x,N,K,t,i):
    """Save a 3-panel figure: input image, sampled latent code, reconstruction.

    Relies on the module-level encoder ``q``, decoder ``p`` and the MNIST
    ``plot`` helper returned by ``mnist(...)``.

    Parameters
    ----------
    x : array
        One flattened input example; a batch axis is added with
        ``x[np.newaxis,:]`` before it is passed to ``plot`` and ``q``.
    N, K : int
        Grid shape used to display the first latent sample as an (N, K) image.
    t : int
        Run identifier; the figure is written under ``../plot/<t>/``, which
        must already exist.
    i : int
        Epoch number, used (zero-padded to 4 digits) in the file name.
    """
    # Draw on a dedicated figure rather than pyplot's implicit "current figure",
    # so we never draw onto (and then close) a figure opened by other code.
    fig, (ax_in, ax_z, ax_rec) = plt.subplots(1, 3)
    try:
        # Panel 1: the input image.
        X, cmap = plot(x[np.newaxis,:])
        ax_in.imshow(X[0], cmap=cmap)
        ax_in.axis('off')

        # Panel 2: one latent sample from q(z|x), shown as an N x K grid.
        sample_z = q.np_sample_given_x(x[np.newaxis,:])
        ax_z.imshow(sample_z[0].reshape((N,K)), interpolation='nearest', cmap="gray")
        ax_z.axis('off')

        # Panel 3: the decoder mean for that latent sample.
        sample_x = p.np_sample_mean_given_x(sample_z)
        # NOTE(review): sample_x presumably already carries a batch axis, so this
        # extra np.newaxis may be redundant -- confirm against `plot`.
        X, cmap = plot(sample_x[np.newaxis,:])
        ax_rec.imshow(X[0], cmap=cmap)
        ax_rec.axis('off')

        fig.savefig('../plot/%d/%04d_sample_z.jpg'%(t,i))
    finally:
        # Always release the figure, even if sampling or saving raises, so a
        # long training loop does not accumulate open figures.
        plt.close(fig)

In [None]:
# Run identifier: the wall-clock second at start. Logs and figures for this run
# go under ../plot/<t>/ (plot_z writes to the same directory).
t = int(time.time())
plot_dir = '../plot/%d' % t
# os.makedirs (unlike os.mkdir) also creates ../plot itself when it is missing.
# It still raises if this run's directory already exists, which avoids mixing
# two runs' logs.
os.makedirs(plot_dir)

pbar = ProgressBar(maxval=n_epoch).start()
for i in range(1, n_epoch+1):
    # Reshuffle the training set in place before every epoch.
    np.random.shuffle(train_x)
    lowerbound_train = model.train([train_x])

    # Evaluate and log on the first epoch and then every 10th epoch.
    if (i%10 == 0) or (i == 1):
        # NOTE(review): k is presumably the number of importance samples used
        # by model.test -- confirm in Tars.model.VAE.
        log_likelihood_test = model.test([test_x],k=10)
        # The train bound is logged as the sum of the terms returned by
        # model.train; the first two terms are also shown individually.
        lw = "epoch = %d, lower bound (train) = %lf (%lf %lf) lower bound (test) = %lf\n" %(i,sum(lowerbound_train),lowerbound_train[0],lowerbound_train[1],np.mean(log_likelihood_test))
        with open(os.path.join(plot_dir, "temp.txt"), "a") as f:
            f.write(lw)
        # print(...) with one argument behaves the same on Python 2 and 3;
        # strip the trailing newline so the console does not show a blank line.
        print(lw.rstrip("\n"))
        plot_z(test_x[0],N,K,t,i)

    pbar.update(i)
pbar.finish()