TensorFlow Implementation of "DRAW: A Recurrent Neural Network For Image Generation"
Switch branches/tags
Nothing to show
Clone or download
Fetching latest commit…
Cannot retrieve the latest commit at this time.
Failed to load latest commit information.
data Initial Commit Feb 22, 2016
img Initial Commit Feb 22, 2016
LICENSE Initial commit Feb 22, 2016
README.md Update README.md Jul 26, 2016
draw.py Fix filterbank bug. Nov 28, 2017
plot_data.py Initial Commit Feb 22, 2016



TensorFlow implementation of DRAW: A Recurrent Neural Network For Image Generation on the MNIST generation task.

With Attention Without Attention

Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as posible.

For a gentle walkthrough through the paper and implementation, see the writeup here: http://blog.evjang.com/2016/06/understanding-and-implementing.html.


python draw.py --data_dir=/tmp/draw downloads the binarized MNIST dataset to /tmp/draw/mnist and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to /tmp/draw/draw_data.npy

You can visualize the results by running the script python plot_data.py <prefix> <output_data>

For example,

python myattn /tmp/draw/draw_data.npy

To run training without attention, do:

python draw.py --working_dir=/tmp/draw --read_attn=False --write_attn=False

Restoring from Pre-trained Model

Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in draw.py and editing the path to your checkpoint file as needed. Save electricity!

saver.restore(sess, "/tmp/draw/drawmodel.ckpt")

This git repository contains the following pre-trained in the data/ folder:

Filename Description
draw_data_attn.npy Training outputs for DRAW with attention
drawmodel_attn.ckpt Saved weights for DRAW with attention
draw_data_noattn.npy Training outputs for DRAW without attention
drawmodel_noattn.ckpt Saved weights for DRAW without attention

These were trained for 10000 iterations with minibatch size=100 on a GTX 970 GPU.

Useful Resources