# Draw Network with attention

Using code from 

[1]`https://github.com/ericjang/draw` 

[2]`https://github.com/ikostrikov/TensorFlow-VAE-GAN-DRAW`

In [1]:
%matplotlib inline

#Plotting
import matplotlib
import matplotlib.pyplot as plt
import sys

#Tensor flow
import tensorflow as tf
from tensorflow.models.rnn.rnn_cell import LSTMCell
#Easy way to get the data :)
from tensorflow.examples.tutorials import mnist

import numpy as np
import os

#Most of these dimensinos are used from 
A,B = 28,28 # image width,height
img_size = B*A # the canvas size

#Using the parameters from [1], as many of these dimensions aren't explicitly mentioned in the paper
enc_size = 256 # number of hidden units / output size in LSTM
dec_size = 256

#Parameterizing the attention window:
read_n = 5 # read glimpse grid width/height
write_n = 5 # write glimpse grid width/height
read_size = 2*read_n*read_n 
write_size = write_n*write_n
z_size=10 # QSampler output size
T=10 # MNIST generation sequence length

#These are standard NN net parameters
batch_size=100 # training minibatch size
train_iters=5000
learning_rate=1e-3 # learning rate for optimizer
eps=1e-8 # epsilon for numerical stability

print "All parameters set"

All parameters set


# Variable roundup

## Encoder RNN

* x = raw data read from the mnist images
* x_hat = error image, x - sigmoid(c_(t-1)) [equation 3 in the paper]

* h = encoded data 

## Decoder RNN
* z = samples drawn from the latent distribution [equation 6 in the paper]

In [2]:
#Define all elementary functions needed:

# Incredible abstraction in Tensorflow!
# Define the LSTMCell
# Documentation: https://www.tensorflow.org/versions/r0.7/tutorials/recurrent/index.html
lstm_enc = LSTMCell(enc_size, read_size+dec_size) # encoder Op
lstm_dec = LSTMCell(dec_size, z_size) # decoder Op

DO_SHARE = None #See note below, this is a necessary hack for initialization

#And define operations with those LSTMs
def encode(state,input):
    with tf.variable_scope("encoder",reuse=DO_SHARE):
        return lstm_enc(input,state)
def decode(state,input):
    with tf.variable_scope("decoder",reuse=DO_SHARE):
        return lstm_dec(input, state)
    
#For the Draw network with attention, we paramteri
def attn_window(scope,h_dec,N):
    with tf.variable_scope(scope,reuse=DO_SHARE):
        params=linear(h_dec,5)
    gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(1,5,params)
    gx=(A+1)/2*(gx_+1)
    gy=(B+1)/2*(gy_+1)
    sigma2=tf.exp(log_sigma2)
    delta=(max(A,B)-1)/(N-1)*tf.exp(log_delta) # batch x N
    return filterbank(gx,gy,sigma2,delta,N)+(tf.exp(log_gamma),)

def read(x,x_hat,h_dec_prev):
    Fx,Fy,gamma=attn_window("read",h_dec_prev,read_n)
    def filter_img(img,Fx,Fy,gamma,N):
        Fxt=tf.transpose(Fx,perm=[0,2,1])
        img=tf.reshape(img,[-1,B,A])
        glimpse=tf.batch_matmul(Fy,tf.batch_matmul(img,Fxt))
        glimpse=tf.reshape(glimpse,[-1,N*N])
        return glimpse*tf.reshape(gamma,[-1,1])
    x=filter_img(x,Fx,Fy,gamma,read_n) # batch x (read_n*read_n)
    x_hat=filter_img(x_hat,Fx,Fy,gamma,read_n)
    return tf.concat(1,[x,x_hat]) # concat along feature axis

def sampleQ(h_enc):
    """
    Samples Zt ~ normrnd(mu,sigma) via reparameterization trick for normal dist
    mu is (batch,z_size)
    """
    with tf.variable_scope("mu",reuse=DO_SHARE):
        mu=linear(h_enc,z_size)
    with tf.variable_scope("sigma",reuse=DO_SHARE):
        logsigma=linear(h_enc,z_size)
        sigma=tf.exp(logsigma)
    return (mu + sigma*e, mu, logsigma, sigma)

# The write function sans-attention is simply the linear Wx + b
def linear(x,output_dim):
    """
    affine transformation Wx+b
    assumes x.shape = (batch_size, num_features)
    """
    w=tf.get_variable("w", [x.get_shape()[1], output_dim]) 
    b=tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0))
    return tf.matmul(x,w)+b

#[equation 18 from the paper]
def write(h_dec):
    with tf.variable_scope("writeW",reuse=DO_SHARE):
        w=linear(h_dec,write_size) # batch x (write_n*write_n)
    N=write_n
    w=tf.reshape(w,[batch_size,N,N])
    Fx,Fy,gamma=attn_window("write",h_dec,write_n)
    Fyt=tf.transpose(Fy,perm=[0,2,1])
    wr=tf.batch_matmul(Fyt,tf.batch_matmul(w,Fx))
    wr=tf.reshape(wr,[batch_size,B*A])
    #gamma=tf.tile(gamma,[1,B*A])
    return wr*tf.reshape(1.0/gamma,[-1,1])


print "Functions defined."

Functions defined.


In [3]:

def filterbank(gx, gy, sigma2,delta, N):
    grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1])
    mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19
    mu_y = gy + (grid_i - N / 2 - 0.5) * delta # eq 20
    a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1])
    b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1])
    mu_x = tf.reshape(mu_x, [-1, N, 1])
    mu_y = tf.reshape(mu_y, [-1, N, 1])
    sigma2 = tf.reshape(sigma2, [-1, 1, 1])
    Fx = tf.exp(-tf.square((a - mu_x) / (2*sigma2))) # 2*sigma2?
    Fy = tf.exp(-tf.square((b - mu_y) / (2*sigma2))) # batch x N x B
    # normalize, sum over A and B dims
    Fx=Fx/tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),eps)
    Fy=Fy/tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),eps)
    return Fx,Fy

#Initialize all variables
cs=[0]*T # sequence of canvases
mus,logsigmas,sigmas=[0]*T,[0]*T,[0]*T # gaussian params generated by SampleQ. We will need these for computing loss.
x = tf.placeholder(tf.float32,shape=(batch_size,img_size)) # input (batch_size * img_size)
e=tf.random_normal((batch_size,z_size), mean=0, stddev=1) # Qsampler noise

# initial states
h_dec_prev=tf.zeros((batch_size,dec_size))
enc_state=lstm_enc.zero_state(batch_size, tf.float32)
dec_state=lstm_dec.zero_state(batch_size, tf.float32)

# Build the graph/network (which is done in an unrolled state)
for t in range(T):
    #Initialize the previos canvas as 0s for the first run
    c_prev = tf.zeros((batch_size,img_size)) if t==0 else cs[t-1]
    #Create the error image
    x_hat=x-tf.sigmoid(c_prev) # error image
    
    r=read(x,x_hat,h_dec_prev)

    #Pass through the lstm_enc. 
    #Note that the first time encode and decode are called, the reuse flag must be false
    #If not, we get the following error: 
    # """Under-sharing: Variable encoder/LSTMCell/W_0 does not exist, disallowed. 
    # Did you mean to set reuse=None in VarScope?"""
    h_enc,enc_state=encode(enc_state,tf.concat(1,[r,h_dec_prev]))
    #Draw from the latest distribution
    z,mus[t],logsigmas[t],sigmas[t]=sampleQ(h_enc)
    
    h_dec,dec_state=decode(dec_state,z)
    cs[t]=c_prev+write(h_dec) # store results
    h_dec_prev=h_dec
    
    DO_SHARE=True

print "Model defined."

Model defined.


In [4]:
# Define cost functions 
# the final canvas matrix is used to parameterize a bernoulli distrbition D(X|c). Reconstruction loss
# is negative log probability. 
def binary_crossentropy(t,o):
    return -(t*tf.log(o+eps) + (1.0-t)*tf.log(1.0-o+eps))

# reconstruction term appears to have been collapsed down to a single scalar value (rather than one per item in minibatch)
x_recons=tf.nn.sigmoid(cs[-1])

# the final canvas matrix is used to parameterize a bernoulli distrbition D(X|c). Reconstruction loss
# is negative log probability [Equation 9]
# However, Eric Jang's implmentation uses a mean of summed cross-entropy. The results seem reasonable so I am 
# going to stick with it for now
Lx=tf.reduce_sum(binary_crossentropy(x,x_recons),1) # reconstruction term
Lx=tf.reduce_mean(Lx) 

#The latent loss Lz is the sum LK divergence of P(Z) from Q(Z|h)
kl_terms=[0]*T
for t in range(T):
    mu2=tf.square(mus[t])
    sigma2=tf.square(sigmas[t])
    logsigma=logsigmas[t]
    kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-T*.5 # each kl term is (1x minibatch)
KL=tf.add_n(kl_terms) # this is 1x minibatch, corresponding to summing kl_terms from 1:T
Lz=tf.reduce_mean(KL) # average over minibatches

cost=Lx+Lz

#As this can be really finnicky, I'm using Eric Jang's implentation directly here
optimizer=tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads=optimizer.compute_gradients(cost)
for i,(g,v) in enumerate(grads):
    if g is not None:
        grads[i]=(tf.clip_by_norm(g,5),v) # clip gradients
train_op=optimizer.apply_gradients(grads)

print "Loss and Optimization defined"

Loss and Optimization defined


In [5]:
data_directory = "mnist"
if not os.path.exists(data_directory):
    os.makedirs(data_directory)
train_data = mnist.input_data.read_data_sets(data_directory, one_hot=True).train # binarized (0-1) mnist data
fetches=[]
fetches.extend([Lx,Lz,train_op])
Lxs=[0]*train_iters
Lzs=[0]*train_iters

sess=tf.InteractiveSession()

saver = tf.train.Saver() # saves variables learned during training
tf.initialize_all_variables().run()

for i in range(train_iters):
    xtrain,_=train_data.next_batch(batch_size) # xtrain is (batch_size x img_size)
    feed_dict={x:xtrain}
    results=sess.run(fetches,feed_dict)
    Lxs[i],Lzs[i],_=results
    if i%100==0:
        print("iter=%d : Lx: %f Lz: %f" % (i,Lxs[i],Lzs[i]))

print "Training completed"

Extracting mnist/train-images-idx3-ubyte.gz
Extracting mnist/train-labels-idx1-ubyte.gz
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz
iter=0 : Lx: 543.944458 Lz: 1.542613
iter=100 : Lx: 205.020157 Lz: 4.008572
iter=200 : Lx: 199.314865 Lz: 3.124275
iter=300 : Lx: 202.209946 Lz: 1.864911
iter=400 : Lx: 193.502075 Lz: 1.659343
iter=500 : Lx: 197.096466 Lz: 2.332488
iter=600 : Lx: 194.527328 Lz: 2.655602
iter=700 : Lx: 180.402847 Lz: 4.911880
iter=800 : Lx: 161.686478 Lz: 4.881144
iter=900 : Lx: 158.418243 Lz: 4.673352
iter=1000 : Lx: 153.008209 Lz: 4.565330
iter=1100 : Lx: 158.537537 Lz: 4.582819
iter=1200 : Lx: 144.478500 Lz: 4.719623
iter=1300 : Lx: 137.608810 Lz: 4.670774
iter=1400 : Lx: 136.504745 Lz: 4.235938
iter=1500 : Lx: 137.091232 Lz: 4.872277
iter=1600 : Lx: 128.613373 Lz: 4.729065
iter=1700 : Lx: 120.264236 Lz: 4.967459
iter=1800 : Lx: 123.811806 Lz: 4.903710
iter=1900 : Lx: 117.166840 Lz: 5.018393
iter=2000 : Lx: 117.712608 Lz: 4.70644

In [7]:
#Visualizer

#Taking visualization code from Eric Jang's plot_data.py
def xrecons_grid(X,B,A):
    """
    plots canvas for single time step
    X is x_recons, (batch_size x img_size)
    assumes features = BxA images
    batch is assumed to be a square number
    """
    padsize=1
    padval=.5
    ph=B+2*padsize
    pw=A+2*padsize
    batch_size=X.shape[0]
    N=int(np.sqrt(batch_size))
    X=X.reshape((N,N,B,A))
    img=np.ones((N*ph,N*pw))*padval
    for i in range(N):
        for j in range(N):
            startr=i*ph+padsize
            endr=startr+B
            startc=j*pw+padsize
            endc=startc+A
            img[startr:endr,startc:endc]=X[i,j,:,:]
    return img


    interactive = False;
    
    prefix="withattention"
    out_file=os.path.join(FLAGS.data_dir,"draw_data.npy")
    [C,Lxs,Lzs]=np.load(out_file)
    T,batch_size,img_size=C.shape
    X=1.0/(1.0+np.exp(-C)) # x_recons=sigmoid(canvas)
    B=A=int(np.sqrt(img_size))
    if interactive:
        f,arr=plt.subplots(1,T)
    for t in range(T):
        img=xrecons_grid(X[t,:,:],B,A)
        if interactive:
            arr[t].matshow(img,cmap=plt.cm.gray)
            arr[t].set_xticks([])
            arr[t].set_yticks([])
        else:
            plt.matshow(img,cmap=plt.cm.gray)
            imgname='%s_%d.png' % (prefix,t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif
            plt.savefig(imgname)
            print(imgname)
    f=plt.figure()
    plt.plot(Lxs,label='Reconstruction Loss Lx')
    plt.plot(Lzs,label='Latent Loss Lz')
    plt.xlabel('iterations')
    plt.legend()
    if interactive:
        plt.show()
    else:
        plt.savefig('%s_loss.png' % (prefix))