TensorFlow Implementation of "DRAW: A Recurrent Neural Network For Image Generation"
Latest commit 00c55a4 Aug 27, 2016 @ericjang committed on GitHub Merge pull request #1 from guohengkai/master
Update usage of LSTMCell for newly Tensorflow
data Initial Commit Feb 22, 2016
img Initial Commit Feb 22, 2016
LICENSE Initial commit Feb 22, 2016
README.md Update README.md Jul 26, 2016
draw.py Update usage of LSTMCell for newly Tensorflow Aug 27, 2016
plot_data.py Initial Commit Feb 22, 2016



TensorFlow implementation of DRAW: A Recurrent Neural Network For Image Generation on the MNIST generation task.

[Figure: With Attention | Without Attention]

Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as possible.
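For reference, the DRAW paper defines one generation step t as follows. Here x is the input image, \hat{x}_t the error image, c_t the canvas, σ the logistic sigmoid, and read/write stand for either the attentive or the non-attentive operations. After T steps the pixel means of the output are σ(c_T). This is the paper's notation, reproduced as a reading aid for draw.py rather than a description of its exact variable names.

```latex
\begin{aligned}
\hat{x}_t &= x - \sigma(c_{t-1}) \\
r_t &= \mathrm{read}(x, \hat{x}_t, h^{dec}_{t-1}) \\
h^{enc}_t &= \mathrm{RNN}^{enc}\left(h^{enc}_{t-1}, [r_t, h^{dec}_{t-1}]\right) \\
z_t &\sim Q(Z_t \mid h^{enc}_t) \\
h^{dec}_t &= \mathrm{RNN}^{dec}(h^{dec}_{t-1}, z_t) \\
c_t &= c_{t-1} + \mathrm{write}(h^{dec}_t)
\end{aligned}
```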

For a gentle walkthrough of the paper and implementation, see the write-up here: http://blog.evjang.com/2016/06/understanding-and-implementing.html.


Running python draw.py --data_dir=/tmp/draw downloads the binarized MNIST dataset to /tmp/draw/mnist and trains the DRAW model with attention enabled for both reading and writing. After training, the output data is written to /tmp/draw/draw_data.npy.

You can visualize the results by running the script: python plot_data.py <prefix> <output_data>

For example,

python plot_data.py myattn /tmp/draw/draw_data.npy
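plot_data.py is the reference for rendering. As a rough sketch of what visualizing the output involves, the functions below load an output file and tile one step's canvases into a single image. The file layout ([canvases, Lxs, Lzs], with canvases of shape T x batch x 784 holding pre-sigmoid logits) is an assumption about what draw.py saves, and load_draw_data / canvas_grid are hypothetical helper names.

```python
import numpy as np


def load_draw_data(path):
    """Load a DRAW output file.

    ASSUMPTION: the file stores [canvases, Lxs, Lzs] as a pickled object
    array, with canvases shaped (T, batch, 784) holding pre-sigmoid logits.
    """
    canvases, lxs, lzs = np.load(path, allow_pickle=True)
    return np.asarray(canvases, dtype=np.float64), lxs, lzs


def canvas_grid(canvases, t=-1, pad=1):
    """Tile the canvases of step t into one grayscale image (values in [0, 1])."""
    probs = 1.0 / (1.0 + np.exp(-canvases[t]))   # sigmoid(c_t): pixel means
    batch, img_size = probs.shape
    side = int(round(np.sqrt(img_size)))         # 784 -> 28x28 digits
    n = int(np.ceil(np.sqrt(batch)))             # n x n tiles fit the batch
    cell = side + pad
    grid = np.ones((n * cell + pad, n * cell + pad))  # white padding
    for i in range(batch):
        row, col = divmod(i, n)
        y0, x0 = pad + row * cell, pad + col * cell
        grid[y0:y0 + side, x0:x0 + side] = probs[i].reshape(side, side)
    return grid
```

For example, loading with canvases, _, _ = load_draw_data("/tmp/draw/draw_data.npy") and then calling matplotlib.pyplot.imsave("myattn_final.png", canvas_grid(canvases), cmap="gray") would save the last step; looping t over range(canvases.shape[0]) shows how the canvas builds up.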

To train without attention, run:

python draw.py --data_dir=/tmp/draw --read_attn=False --write_attn=False
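The --read_attn and --write_attn flags switch between the two read/write variants described in the paper. Without attention, read simply concatenates the image and the error image. With attention, read extracts N x N glimpses through Gaussian filterbanks F_X and F_Y, and write paints an N x N patch back through the transposed filterbanks. The NumPy sketch below passes the five attention parameters (g_X, g_Y, δ, σ², γ) in directly, whereas in the model they are derived from a linear layer applied to h^{dec}_{t-1}. The function names are illustrative and do not match draw.py.

```python
import numpy as np


def filterbank(g, delta, sigma2, N, size, eps=1e-8):
    """N x size Gaussian filterbank along one image axis.

    Filter i (1-based, as in the paper) is centred at g + (i - N/2 - 0.5) * delta;
    each row is normalised to sum to 1.
    """
    i = np.arange(1, N + 1)
    mu = g + (i - N / 2 - 0.5) * delta                # filter centres
    a = np.arange(size)                               # pixel coordinates
    F = np.exp(-((a[None, :] - mu[:, None]) ** 2) / (2.0 * sigma2))
    return F / (F.sum(axis=1, keepdims=True) + eps)   # 1/Z normalisation


def read_attn(x, x_hat, gx, gy, delta, sigma2, gamma, N):
    """Attentive read: gamma * [F_Y x F_X^T, F_Y x_hat F_X^T], flattened."""
    B, A = x.shape                                    # image height, width
    Fx = filterbank(gx, delta, sigma2, N, A)          # N x A
    Fy = filterbank(gy, delta, sigma2, N, B)          # N x B

    def glimpse(img):
        return Fy @ img @ Fx.T                        # N x N patch

    return gamma * np.concatenate([glimpse(x).ravel(), glimpse(x_hat).ravel()])


def read_no_attn(x, x_hat):
    """Non-attentive read: pass the full image and error image through."""
    return np.concatenate([x.ravel(), x_hat.ravel()])


def write_attn(w, gx, gy, delta, sigma2, gamma, N, B=28, A=28):
    """Attentive write: (1/gamma) * F_Y^T w F_X maps an N x N patch to B x A."""
    Fx = filterbank(gx, delta, sigma2, N, A)          # N x A
    Fy = filterbank(gy, delta, sigma2, N, B)          # N x B
    return (Fy.T @ w @ Fx) / gamma
```

With N = 5, an attentive read of a 28 x 28 digit yields 2 x 25 = 50 values instead of 2 x 784 = 1568, which is why the encoder input size differs between the two settings.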

Restoring from Pre-trained Model

Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in draw.py and editing the path to your checkpoint file as needed. Save electricity!

saver.restore(sess, "/tmp/draw/drawmodel.ckpt")
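The bundled checkpoints in data/ only restore into a graph built with matching attention flags, because attention changes the encoder's input size and the write layer's output size, and therefore the shapes of the saved variables. The sketch below picks the matching files for a given flag setting. pretrained_paths is a hypothetical helper, not part of draw.py, and it assumes you run from the repository root so that data/ resolves.

```python
import os


def pretrained_paths(read_attn, write_attn, data_dir="data"):
    """Return (checkpoint, outputs) paths for the bundled files matching the flags.

    Hypothetical helper, not part of draw.py. Only the all-on and all-off
    attention settings have bundled files.
    """
    if read_attn and write_attn:
        tag = "attn"
    elif not read_attn and not write_attn:
        tag = "noattn"
    else:
        raise ValueError("no bundled checkpoint for mixed attention settings")
    return (os.path.join(data_dir, f"drawmodel_{tag}.ckpt"),
            os.path.join(data_dir, f"draw_data_{tag}.npy"))


ckpt, outputs = pretrained_paths(read_attn=True, write_attn=True)
print(ckpt)  # data/drawmodel_attn.ckpt on POSIX systems
```

After building the graph with the same flags, the returned path can replace the hard-coded one, as in saver.restore(sess, ckpt).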

This git repository contains the following pre-trained models and outputs in the data/ folder:

Filename               Description
draw_data_attn.npy     Training outputs for DRAW with attention
drawmodel_attn.ckpt    Saved weights for DRAW with attention
draw_data_noattn.npy   Training outputs for DRAW without attention
drawmodel_noattn.ckpt  Saved weights for DRAW without attention

These were trained for 10,000 iterations with a minibatch size of 100 on a GTX 970 GPU.
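To put that training budget in perspective: 10,000 iterations at 100 images each is 1,000,000 training examples. Assuming the 55,000-image MNIST training split produced by TensorFlow's tutorial MNIST loader (an assumption; your split may differ), that works out to roughly 18 passes over the data. epochs_seen is an illustrative helper name.

```python
def epochs_seen(iterations, batch_size, train_size):
    """Passes over the training set implied by iterations x batch_size."""
    return iterations * batch_size / train_size


# ASSUMPTION: a 55,000-image training split; change train_size to match yours.
print(f"{epochs_seen(10_000, 100, 55_000):.1f} epochs")  # prints "18.2 epochs"
```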

Useful Resources