TensorFlow implementation of DRAW: A Recurrent Neural Network For Image Generation on the MNIST generation task.
*Sample MNIST generations: With Attention | Without Attention*
Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as possible.
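For reference, here is a rough, framework-agnostic sketch of how one DRAW timestep (without attention) maps onto code. The helper names (`rnn_enc`, `rnn_dec`, `sample_z`, `write`) are illustrative placeholders, not necessarily the names used in `draw.py`:

```python
# Illustrative sketch of one DRAW timestep without attention (Gregor et al., 2015).
# This is NOT the repo's code; it only shows the equation-to-code correspondence.
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def draw_step(x, c_prev, h_enc_prev, h_dec_prev, rnn_enc, rnn_dec, sample_z, write):
    x_hat = x - sigmoid(c_prev)                                   # error image
    r = np.concatenate([x, x_hat], axis=-1)                       # read (no attention)
    h_enc = rnn_enc(h_enc_prev, np.concatenate([r, h_dec_prev], axis=-1))
    z, mu, logsigma = sample_z(h_enc)                             # z = mu + sigma * eps
    h_dec = rnn_dec(h_dec_prev, z)
    c = c_prev + write(h_dec)                                     # additive canvas update
    return c, h_enc, h_dec, mu, logsigma
```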
For a gentle walkthrough through the paper and implementation, see the writeup here: http://blog.evjang.com/2016/06/understanding-and-implementing.html.
`python draw.py --data_dir=/tmp/draw` downloads the binarized MNIST dataset to `/tmp/draw/mnist` and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to `/tmp/draw/draw_data.npy`.
You can visualize the results by running the script:

```
python plot_data.py <prefix> <output_data>
```

e.g.

```
python plot_data.py myattn /tmp/draw/draw_data.npy
```
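If you want to inspect the saved outputs directly instead of going through `plot_data.py`, a minimal sketch like the one below is enough to see what the file contains. It assumes only that `draw_data.npy` was written with `numpy.save`; the internal layout is whatever `draw.py` stored, so check the shapes before plotting anything:

```python
# Minimal sketch for peeking at the training outputs; not part of the repo.
# Assumes only that the file was written with numpy.save (allow_pickle is needed
# if a Python list of arrays was saved rather than a single array).
import sys
import numpy as np

data = np.asarray(np.load(sys.argv[1], allow_pickle=True))
print("top-level:", data.dtype, data.shape)
if data.dtype == object:
    # heterogeneous entries (e.g. canvases, loss curves); print each one's shape
    for i, item in enumerate(data):
        arr = np.asarray(item)
        print(f"entry {i}: dtype={arr.dtype}, shape={arr.shape}")
```

Usage: `python inspect_draw_data.py /tmp/draw/draw_data.npy` (the script name is hypothetical).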
To run training without attention, do:

```
python draw.py --data_dir=/tmp/draw --read_attn=False --write_attn=False
```
Restoring from Pre-trained Model
Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in
draw.py and editing the path to your checkpoint file as needed. Save electricity!
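The line in question is a standard TensorFlow checkpoint restore; it looks roughly like the following (the exact variable names in `draw.py` may differ, and the checkpoint path here is illustrative):

```python
# Inside draw.py, after the tf.train.Saver and session have been created:
# uncomment to restore weights from a checkpoint instead of training from scratch.
saver.restore(sess, "drawmodel_attn.ckpt")  # path is illustrative; point it at your checkpoint
```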
This git repository contains the following pre-trained models and training outputs:
| File | Description |
| --- | --- |
| draw_data_attn.npy | Training outputs for DRAW with attention |
| drawmodel_attn.ckpt | Saved weights for DRAW with attention |
| draw_data_noattn.npy | Training outputs for DRAW without attention |
| drawmodel_noattn.ckpt | Saved weights for DRAW without attention |
These were trained for 10,000 iterations with a minibatch size of 100 on a GTX 970 GPU.