Skip to content

Commit

Permalink
Initial Commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric Jang committed Feb 22, 2016
1 parent a043942 commit 7b3ec04
Show file tree
Hide file tree
Showing 7 changed files with 365 additions and 0 deletions.
52 changes: 52 additions & 0 deletions README.md
@@ -0,0 +1,52 @@
# draw

TensorFlow implementation of [DRAW: A Recurrent Neural Network For Image Generation](http://arxiv.org/pdf/1502.04623.pdf) on the MNIST generation task.

For a gentle walkthrough through the paper and implementation, see the writeup here: [https://evjang/articles/draw](http://evjang/articles/draw).

| With Attention | Without Attention |
| ------------- | ------------- |
| ![AttnGIF](img/mnist_attn.gif) | ![NoAttnGIF](img/mnist_noattn.gif) |

Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as posible.

## Usage

`python draw.py --data_dir=/tmp/draw` downloads the binarized MNIST dataset to /tmp/draw/mnist and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to `/tmp/draw/draw_data.npy`

You can visualize the results by running the script `python plot_data.py <prefix> <output_data>`

For example,

`python fubar /tmp/draw/draw_data.npy`

To run training without attention, do:

`python draw.py --working_dir=/tmp/draw --read_attn=False --write_attn=False`

## Restoring from Pre-trained Model

Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in `draw.py` and editing the path to your checkpoint file as needed. Save electricity!

```python
saver.restore(sess, "/tmp/draw/drawmodel.ckpt")
```

This git repository contains the following pre-trained in the `data/` folder:

| Filename | Description |
| ------------- | ------------- |
| draw_data_attn.npy | Training outputs for DRAW with attention |
| drawmodel_attn.ckpt | Saved weights for DRAW with attention |
| draw_data_noattn.npy | Training outputs for DRAW without attention |
| drawmodel_noattn.ckpt | Saved weights for DRAW without attention |

These were trained for 10000 iterations with minibatch size=100 on a GTX 970 GPU.

## Useful Resources

- https://github.com/vivanov879/draw
- https://github.com/jbornschein/draw
- https://github.com/ikostrikov/TensorFlow-VAE-GAN-DRAW (wish I had found this earlier)
- [Video Lecture on Variational Autoencoders and Image Generation]( https://www.youtube.com/watch?v=P78QYjWh5sM&list=PLE6Wd9FR--EfW8dtjAuPoTuPcqmOV53Fu&index=3)

246 changes: 246 additions & 0 deletions draw.py
@@ -0,0 +1,246 @@
#!/usr/bin/env python

""""
Simple implementation of http://arxiv.org/pdf/1502.04623v2.pdf in TensorFlow
Example Usage:
python draw.py --data_dir=/tmp/draw --read_attn=True --write_attn=True
Author: Eric Jang
"""

import tensorflow as tf
from tensorflow.models.rnn.rnn_cell import LSTMCell
from tensorflow.examples.tutorials import mnist
import numpy as np
import os

tf.flags.DEFINE_string("data_dir", "", "")
tf.flags.DEFINE_boolean("read_attn", True, "enable attention for reader")
tf.flags.DEFINE_boolean("write_attn",True, "enable attention for writer")
FLAGS = tf.flags.FLAGS

## MODEL PARAMETERS ##

A,B = 28,28 # image width,height
img_size = B*A # the canvas size
enc_size = 256 # number of hidden units / output size in LSTM
dec_size = 256
read_n = 5 # read glimpse grid width/height
write_n = 5 # write glimpse grid width/height
read_size = 2*read_n*read_n if FLAGS.read_attn else 2*img_size
write_size = write_n*write_n if FLAGS.write_attn else img_size
z_size=10 # QSampler output size
T=10 # MNIST generation sequence length
batch_size=100 # training minibatch size
train_iters=10000
learning_rate=1e-3 # learning rate for optimizer
eps=1e-8 # epsilon for numerical stability

## BUILD MODEL ##

DO_SHARE=None # workaround for variable_scope(reuse=True)

x = tf.placeholder(tf.float32,shape=(batch_size,img_size)) # input (batch_size * img_size)
e=tf.random_normal((batch_size,z_size), mean=0, stddev=1) # Qsampler noise
lstm_enc = LSTMCell(enc_size, read_size+dec_size) # encoder Op
lstm_dec = LSTMCell(dec_size, z_size) # decoder Op

def linear(x,output_dim):
"""
affine transformation Wx+b
assumes x.shape = (batch_size, num_features)
"""
w=tf.get_variable("w", [x.get_shape()[1], output_dim])
b=tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0))
return tf.matmul(x,w)+b

def filterbank(gx, gy, sigma2,delta, N):
grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1])
mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19
mu_y = gy + (grid_i - N / 2 - 0.5) * delta # eq 20
a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1])
b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1])
mu_x = tf.reshape(mu_x, [-1, N, 1])
mu_y = tf.reshape(mu_y, [-1, N, 1])
sigma2 = tf.reshape(sigma2, [-1, 1, 1])
Fx = tf.exp(-tf.square((a - mu_x) / (2*sigma2))) # 2*sigma2?
Fy = tf.exp(-tf.square((b - mu_y) / (2*sigma2))) # batch x N x B
# normalize, sum over A and B dims
Fx=Fx/tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),eps)
Fy=Fy/tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),eps)
return Fx,Fy

def attn_window(scope,h_dec,N):
with tf.variable_scope(scope,reuse=DO_SHARE):
params=linear(h_dec,5)
gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(1,5,params)
gx=(A+1)/2*(gx_+1)
gy=(B+1)/2*(gy_+1)
sigma2=tf.exp(log_sigma2)
delta=(max(A,B)-1)/(N-1)*tf.exp(log_delta) # batch x N
return filterbank(gx,gy,sigma2,delta,N)+(tf.exp(log_gamma),)

## READ ##
def read_no_attn(x,x_hat,h_dec_prev):
return tf.concat(1,[x,x_hat])

def read_attn(x,x_hat,h_dec_prev):
Fx,Fy,gamma=attn_window("read",h_dec_prev,read_n)
def filter_img(img,Fx,Fy,gamma,N):
Fxt=tf.transpose(Fx,perm=[0,2,1])
img=tf.reshape(img,[-1,B,A])
glimpse=tf.batch_matmul(Fy,tf.batch_matmul(img,Fxt))
glimpse=tf.reshape(glimpse,[-1,N*N])
return glimpse*tf.reshape(gamma,[-1,1])
x=filter_img(x,Fx,Fy,gamma,read_n) # batch x (read_n*read_n)
x_hat=filter_img(x_hat,Fx,Fy,gamma,read_n)
return tf.concat(1,[x,x_hat]) # concat along feature axis

read = read_attn if FLAGS.read_attn else read_no_attn

## ENCODE ##
def encode(state,input):
"""
run LSTM
state = previous encoder state
input = cat(read,h_dec_prev)
returns: (output, new_state)
"""
with tf.variable_scope("encoder",reuse=DO_SHARE):
return lstm_enc(input,state)

## Q-SAMPLER (VARIATIONAL AUTOENCODER) ##

def sampleQ(h_enc):
"""
Samples Zt ~ normrnd(mu,sigma) via reparameterization trick for normal dist
mu is (batch,z_size)
"""
with tf.variable_scope("mu",reuse=DO_SHARE):
mu=linear(h_enc,z_size)
with tf.variable_scope("sigma",reuse=DO_SHARE):
logsigma=linear(h_enc,z_size)
sigma=tf.exp(logsigma)
return (mu + sigma*e, mu, logsigma, sigma)

## DECODER ##
def decode(state,input):
with tf.variable_scope("decoder",reuse=DO_SHARE):
return lstm_dec(input, state)

## WRITER ##
def write_no_attn(h_dec):
with tf.variable_scope("write",reuse=DO_SHARE):
return linear(h_dec,img_size)

def write_attn(h_dec):
with tf.variable_scope("writeW",reuse=DO_SHARE):
w=linear(h_dec,write_size) # batch x (write_n*write_n)
N=write_n
w=tf.reshape(w,[batch_size,N,N])
Fx,Fy,gamma=attn_window("write",h_dec,write_n)
Fyt=tf.transpose(Fy,perm=[0,2,1])
wr=tf.batch_matmul(Fyt,tf.batch_matmul(w,Fx))
wr=tf.reshape(wr,[batch_size,B*A])
#gamma=tf.tile(gamma,[1,B*A])
return wr*tf.reshape(1.0/gamma,[-1,1])

write=write_attn if FLAGS.write_attn else write_no_attn

## STATE VARIABLES ##

cs=[0]*T # sequence of canvases
mus,logsigmas,sigmas=[0]*T,[0]*T,[0]*T # gaussian params generated by SampleQ. We will need these for computing loss.
# initial states
h_dec_prev=tf.zeros((batch_size,dec_size))
enc_state=lstm_enc.zero_state(batch_size, tf.float32)
dec_state=lstm_dec.zero_state(batch_size, tf.float32)

## DRAW MODEL ##

# construct the unrolled computational graph
for t in range(T):
c_prev = tf.zeros((batch_size,img_size)) if t==0 else cs[t-1]
x_hat=x-tf.sigmoid(c_prev) # error image
r=read(x,x_hat,h_dec_prev)
h_enc,enc_state=encode(enc_state,tf.concat(1,[r,h_dec_prev]))
z,mus[t],logsigmas[t],sigmas[t]=sampleQ(h_enc)
h_dec,dec_state=decode(dec_state,z)
cs[t]=c_prev+write(h_dec) # store results
h_dec_prev=h_dec
DO_SHARE=True # from now on, share variables

## LOSS FUNCTION ##

def binary_crossentropy(t,o):
return -(t*tf.log(o+eps) + (1.0-t)*tf.log(1.0-o+eps))

# reconstruction term appears to have been collapsed down to a single scalar value (rather than one per item in minibatch)
x_recons=tf.nn.sigmoid(cs[-1])

# after computing binary cross entropy, sum across features then take the mean of those sums across minibatches
Lx=tf.reduce_sum(binary_crossentropy(x,x_recons),1) # reconstruction term
Lx=tf.reduce_mean(Lx)

kl_terms=[0]*T
for t in range(T):
mu2=tf.square(mus[t])
sigma2=tf.square(sigmas[t])
logsigma=logsigmas[t]
kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-T*.5 # each kl term is (1xminibatch)
KL=tf.add_n(kl_terms) # this is 1xminibatch, corresponding to summing kl_terms from 1:T
Lz=tf.reduce_mean(KL) # average over minibatches

cost=Lx+Lz

## OPTIMIZER ##

optimizer=tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads=optimizer.compute_gradients(cost)
for i,(g,v) in enumerate(grads):
if g is not None:
grads[i]=(tf.clip_by_norm(g,5),v) # clip gradients
train_op=optimizer.apply_gradients(grads)

## RUN TRAINING ##

data_directory = os.path.join(FLAGS.data_dir, "mnist")
if not os.path.exists(data_directory):
os.makedirs(data_directory)
train_data = mnist.input_data.read_data_sets(data_directory, one_hot=True).train # binarized (0-1) mnist data

fetches=[]
fetches.extend([Lx,Lz,train_op])
Lxs=[0]*train_iters
Lzs=[0]*train_iters

sess=tf.InteractiveSession()

saver = tf.train.Saver() # saves variables learned during training
tf.initialize_all_variables().run()
#saver.restore(sess, "/tmp/draw/drawmodel.ckpt") # to restore from model, uncomment this line

for i in range(train_iters):
xtrain,_=train_data.next_batch(batch_size) # xtrain is (batch_size x img_size)
feed_dict={x:xtrain}
results=sess.run(fetches,feed_dict)
Lxs[i],Lzs[i],_=results
if i%100==0:
print("iter=%d : Lx: %f Lz: %f" % (i,Lxs[i],Lzs[i]))

## TRAINING FINISHED ##

canvases=sess.run(cs,feed_dict) # generate some examples
canvases=np.array(canvases) # T x batch x img_size

out_file=os.path.join(FLAGS.data_dir,"draw_data.npy")
np.save(out_file,[canvases,Lxs,Lzs])
print("Outputs saved in file: %s" % out_file)

ckpt_file=os.path.join(FLAGS.data_dir,"drawmodel.ckpt")
print("Model saved in file: %s" % saver.save(sess,ckpt_file))

sess.close()

print('Done drawing! Have a nice day! :)')
Binary file added img/loss_attn.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/loss_noattn.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/mnist_attn.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/mnist_noattn.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
67 changes: 67 additions & 0 deletions plot_data.py
@@ -0,0 +1,67 @@
# takes data saved by DRAW model and generates animations
# example usage: python plot_data.py noattn /tmp/draw/draw_data.npy

import matplotlib
import sys
import numpy as np

interactive=False # set to False if you want to write images to file

if not interactive:
matplotlib.use('Agg') # Force matplotlib to not use any Xwindows backend.
import matplotlib.pyplot as plt


def xrecons_grid(X,B,A):
"""
plots canvas for single time step
X is x_recons, (batch_size x img_size)
assumes features = BxA images
batch is assumed to be a square number
"""
padsize=1
padval=.5
ph=B+2*padsize
pw=A+2*padsize
batch_size=X.shape[0]
N=int(np.sqrt(batch_size))
X=X.reshape((N,N,B,A))
img=np.ones((N*ph,N*pw))*padval
for i in range(N):
for j in range(N):
startr=i*ph+padsize
endr=startr+B
startc=j*pw+padsize
endc=startc+A
img[startr:endr,startc:endc]=X[i,j,:,:]
return img

if __name__ == '__main__':
prefix=sys.argv[1]
out_file=sys.argv[2]
[C,Lxs,Lzs]=np.load(out_file)
T,batch_size,img_size=C.shape
X=1.0/(1.0+np.exp(-C)) # x_recons=sigmoid(canvas)
B=A=int(np.sqrt(img_size))
if interactive:
f,arr=plt.subplots(1,T)
for t in range(T):
img=xrecons_grid(X[t,:,:],B,A)
if interactive:
arr[t].matshow(img,cmap=plt.cm.gray)
arr[t].set_xticks([])
arr[t].set_yticks([])
else:
plt.matshow(img,cmap=plt.cm.gray)
imgname='%s_%d.png' % (prefix,t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif
plt.savefig(imgname)
print(imgname)
f=plt.figure()
plt.plot(Lxs,label='Reconstruction Loss Lx')
plt.plot(Lzs,label='Latent Loss Lz')
plt.xlabel('iterations')
plt.legend()
if interactive:
plt.show()
else:
plt.savefig('%s_loss.png' % (prefix))

0 comments on commit 7b3ec04

Please sign in to comment.