The Recurrent Attention Model (RAM) is introduced in [1] & [2].
It is inspired by the way humans perceive their surroundings, i.e., focusing on selected parts of the environment to acquire information and combining it over time, instead of observing the scene in its entirety.
In [1], the performance of the model is demonstrated by classifying the MNIST dataset. In contrast to existing approaches that process the whole image, the RAM uses the information of glimpses taken at selected locations. These glimpses are perceived in a retina-like representation, which is then used to classify the given symbols.
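The retina-like representation described above can be sketched as follows: several patches of increasing size are cropped around the glimpse location and each is downsampled to a common resolution, so resolution decreases away from the center. This is a minimal NumPy illustration, not the repository's implementation; the patch size and number of scales are assumed values.

```python
import numpy as np

def extract_glimpse(image, center, base_size=8, n_scales=3):
    """Retina-like glimpse: patches of growing size around `center`,
    each average-pooled down to base_size x base_size."""
    # pad so even the largest patch stays in-bounds near the border
    max_half = base_size * 2 ** (n_scales - 1) // 2
    padded = np.pad(image, max_half, mode="constant")
    cy, cx = center[0] + max_half, center[1] + max_half
    scales = []
    for s in range(n_scales):
        half = base_size * 2 ** s // 2
        patch = padded[cy - half:cy + half, cx - half:cx + half]
        factor = 2 ** s
        # average-pool the patch back to the base resolution
        pooled = patch.reshape(base_size, factor,
                               base_size, factor).mean(axis=(1, 3))
        scales.append(pooled)
    return np.stack(scales)  # shape: (n_scales, base_size, base_size)
```

The innermost scale keeps full resolution while the outer scales cover a wider field of view at coarser detail, mimicking foveal vision.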
As suggested in [1], [2], the network is trained using the REINFORCE [3] learning rule. The baseline is trained by reducing the mean squared error between the baseline and the received reward.
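The two loss terms mentioned above can be written down compactly. The sketch below is a simplified NumPy illustration of the idea (it computes loss values only, not gradients): the REINFORCE term weights the log-probabilities of the taken actions by the baseline-subtracted reward, and the baseline itself is regressed onto the reward with a mean squared error.

```python
import numpy as np

def reinforce_terms(log_probs, rewards, baselines):
    """Loss terms of the hybrid objective.
    log_probs:  log-probability of the sampled actions, one per episode
    rewards:    received reward per episode
    baselines:  predicted baseline per episode
    """
    # advantage is treated as a constant w.r.t. the policy parameters
    advantage = rewards - baselines
    reinforce_loss = -np.mean(log_probs * advantage)
    # baseline is trained by MSE between baseline and received reward
    baseline_loss = np.mean((baselines - rewards) ** 2)
    return reinforce_loss, baseline_loss
```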
In contrast to the model introduced in [1], not only the mean, but also the standard deviation of the location policy is learned.
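Concretely, learning both parameters means the location network outputs a mean and a (log-)standard deviation, and the glimpse location is sampled from the resulting Gaussian. A minimal NumPy sketch of such a sampling step, under the common assumption that locations live in [-1, 1]:

```python
import numpy as np

rng = np.random.default_rng(42)

def sample_location(mean, log_std):
    """Sample a glimpse location from N(mean, std^2), where both mean
    and log_std are network outputs. Also returns the log-density of
    the sample, which is what REINFORCE differentiates."""
    std = np.exp(log_std)  # parametrize std via its log to keep it positive
    loc = mean + std * rng.standard_normal(mean.shape)
    log_prob = (-0.5 * ((loc - mean) / std) ** 2
                - log_std - 0.5 * np.log(2 * np.pi)).sum(axis=-1)
    # clip to the valid image coordinate range
    return np.clip(loc, -1.0, 1.0), log_prob
```

Parametrizing the standard deviation through its logarithm is a standard trick to keep it positive without constrained optimization.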
The code is inspired by [3], [4] & [5].
Required packages:
- NumPy
- TensorFlow
- OpenCV for evaluation
- Matplotlib for plotting
- h5py for saving the trained network weights
The parameters for the training are all defined in the configuration files `run_mnist.py` and `run_translated_mnist.py`.
After training, the network model is saved. It can be loaded for further training or evaluation.
During training, information about the current losses, the accuracy, and the behavior of the location network can be monitored using TensorBoard:

```
tensorboard --logdir=./summary
```
To create images of the glimpses that the network uses after training, simply execute the evaluation script. The first parameter is the name of the configuration file and the second is the path to the network model:

```
evaluate.py run_mnist ./model/
```
To plot the accuracy of the classification over the number of trained epochs, use the plotting script:

```
python plot.py ./results.json
```
To train the network on classifying the standard MNIST dataset, start the training via the corresponding configuration file:

```
python run_mnist.py
```
Current high score: 97.97% ± 0.14% accuracy on the MNIST test dataset.
The plot below shows the accuracy for the test-dataset over the number of trained epochs.
Examples of the images and the corresponding glimpses used by the network are displayed in the table.
Original Image | Glimpse 1 | Glimpse 3 | Glimpse 5 | Glimpse 6 |
---|---|---|---|---|
In [1], the network is also tested on non-centered digits: the digits of the MNIST dataset are embedded into a larger image patch and then randomly translated.
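Generating such a translated sample is straightforward: the 28×28 digit is placed at a uniformly random position inside a larger blank canvas. A minimal NumPy sketch, assuming the 60×60 canvas size used in [1]:

```python
import numpy as np

rng = np.random.default_rng(0)

def translate_digit(digit, canvas_size=60):
    """Embed a digit image at a uniformly random position
    inside a larger blank canvas."""
    h, w = digit.shape
    canvas = np.zeros((canvas_size, canvas_size), dtype=digit.dtype)
    # choose the top-left corner so the digit stays fully inside the canvas
    top = rng.integers(0, canvas_size - h + 1)
    left = rng.integers(0, canvas_size - w + 1)
    canvas[top:top + h, left:left + w] = digit
    return canvas
```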
To train the network on classifying the "translated" MNIST dataset, start the training via the corresponding configuration file:

```
python run_translated_mnist.py
```
Current high score: 97.5% ± 0.16% accuracy on the translated MNIST test dataset.
The plot below shows the accuracy for the test-dataset over the number of trained epochs.
Examples of the images and the corresponding glimpses used by the network are displayed in the table.
Original Image | Glimpse 1 | Glimpse 2 | Glimpse 5 | Glimpse 7 |
---|---|---|---|---|
[1] Mnih, Volodymyr, Nicolas Heess, and Alex Graves. "Recurrent models of visual attention." Advances in neural information processing systems. 2014.
[2] Ba, Jimmy, Volodymyr Mnih, and Koray Kavukcuoglu. "Multiple object recognition with visual attention." arXiv preprint arXiv:1412.7755 (2014).
[3] Williams, Ronald J. "Simple statistical gradient-following algorithms for connectionist reinforcement learning." Machine learning 8.3-4 (1992): 229-256.
[4] https://github.com/jlindsey15/RAM