# Task 2: Recurrent Attention Model (RAM)

Reference: https://papers.nips.cc/paper/5542-recurrent-models-of-visual-attention.pdf

Useful links:
https://medium.com/towards-data-science/visual-attention-model-in-deep-learning-708813c2912c

In [None]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

# Notebook auto reloads code. (Ref: http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython)
%load_ext autoreload
%autoreload 2

## Load MNIST data

In [None]:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=False)

print("number of training samples: ", mnist.train.num_examples)
print("number of validation samples: ", mnist.validation.num_examples)
print("number of test samples: ", mnist.test.num_examples)

## Task 2, Part 1: RAM networks

Recurrent attention model (RAM) is a model that processes its input sequentially. In image classification, for example, rather than using the whole image as input, RAM looks at only a small patch of the image at each step. RAM learns by itself what and where to attend to, depending on the task it is performing. The network consists of three parts: the **glimpse net**, the **rnn net**, and the **location net/action net**.

As shown in the figure below:

* Glimpse Net: shown in A) and B) in the figure. It includes a glimpse sensor, which extracts small patches from the original image, and a glimpse net, which combines the glimpse and location information through fully connected layers and outputs the vector $g_t$.

* RNN Net: takes $g_t$ as input, passes it through a single-layer RNN, and outputs the hidden state $h_t$ of the RNN cell.

* Location Net: uses $h_t$ to estimate the next location $l_t$ for the glimpse.

* Action/Classification Net: uses the final hidden state $h_T$ as features to classify the digit image.

* At each step, $l_t$ is fed back to the glimpse net to get the next glimpse $g_{t+1}$ for the RNN, as shown in C).

![ram](./img/whole_net.png)

### step 1: Glimpse network

<span style="color:red">TODO:</span> Complete **glimpse_sensor**, and **\__call__** functions of the **GlimpseNet class** in the **ecbm4040/ram/networks.py**.

* Glimpse sensor/Retina and location encodings: The retina encoding ρ(x, l) extracts k square patches centered at location l. The first patch is [glimpse_win × glimpse_win] pixels, and each successive patch is twice as wide as the previous one. All k patches are then resized to [glimpse_win × glimpse_win] and concatenated. Glimpse locations l are encoded as real-valued (x, y) coordinates, with (0, 0) at the center of the image x and (−1, −1) at its top-left corner, so each coordinate lies in [−1, 1].

* Glimpse net: a small MLP.
  * $h_l = \mathrm{ReLU}(\mathrm{Linear}(l))$
  * $h_g = \mathrm{ReLU}(\mathrm{Linear}(\rho(x, l)))$
  * $g = \mathrm{ReLU}(\mathrm{Linear}(h_l) + \mathrm{Linear}(h_g))$
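The retina described above can be illustrated with a minimal NumPy sketch for a single-channel image. It is not a drop-in answer for the TensorFlow `glimpse_sensor` TODO. Several assumptions are made here: the text only says the patches are "resized", so this sketch uses block-average pooling by a factor of 2**i. Pixels outside the image are zero-padded. The function name and signature are hypothetical.

```python
import numpy as np

def glimpse_sensor(image, loc, glimpse_win, k):
    """Hypothetical NumPy sketch of the retina rho(x, l) for one single-channel image.

    image: [H, W] array.
    loc:   (x, y), each assumed in [-1, 1]; (0, 0) is the image center, (-1, -1) its top-left.
    Returns a [k, glimpse_win, glimpse_win] array. Patch i is glimpse_win * 2**i pixels
    wide and is block-averaged by a factor 2**i back down to glimpse_win x glimpse_win.
    """
    H, W = image.shape
    # Map normalized (x, y) to the pixel coordinates of the patch center (x = column, y = row).
    cx = (loc[0] + 1.0) * W / 2.0
    cy = (loc[1] + 1.0) * H / 2.0
    # Zero-pad by the largest patch width so every crop stays inside the padded array.
    pad = glimpse_win * 2 ** (k - 1)
    padded = np.pad(image, pad, mode="constant")
    patches = []
    for i in range(k):
        size = glimpse_win * 2 ** i
        factor = 2 ** i
        top = int(round(cy - size / 2.0)) + pad
        left = int(round(cx - size / 2.0)) + pad
        patch = padded[top:top + size, left:left + size]
        # Block-average pooling (an assumed resize method): [size, size] -> [glimpse_win, glimpse_win].
        pooled = patch.reshape(glimpse_win, factor, glimpse_win, factor).mean(axis=(1, 3))
        patches.append(pooled)
    return np.stack(patches)
```

With `loc = (-1, -1)` the first patch is centered on the top-left pixel, so three quarters of it is zero padding. Flattening the returned array gives the concatenated retina vector.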

![glimpse_net](./img/glimpse_net.png)
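The three glimpse-net equations map directly onto matrix products. The following is an illustrative NumPy forward pass with randomly initialized weights. The dimensions default to `hl_dim`, `hg_dim` and `g_dim` from the configuration used later in this notebook. The helper names and the weight initialization are assumptions, and the real class in `ecbm4040/ram/networks.py` is written in TensorFlow.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def init_params(d_rho, hl_dim=128, hg_dim=128, g_dim=256, seed=0):
    """Hypothetical random initialization for the glimpse-net weights."""
    rng = np.random.default_rng(seed)
    return {
        "W_l": rng.normal(0.0, 0.1, (2, hl_dim)), "b_l": np.zeros(hl_dim),
        "W_g": rng.normal(0.0, 0.1, (d_rho, hg_dim)), "b_g": np.zeros(hg_dim),
        "W_hl": rng.normal(0.0, 0.1, (hl_dim, g_dim)),
        "W_hg": rng.normal(0.0, 0.1, (hg_dim, g_dim)),
        # The biases of the two output Linear layers are merged into one (mathematically equivalent).
        "b_out": np.zeros(g_dim),
    }

def glimpse_net(rho, loc, params):
    """rho: [B, D_rho] flattened retina output, loc: [B, 2]. Returns g: [B, g_dim]."""
    hl = relu(loc @ params["W_l"] + params["b_l"])   # h_l = ReLU(Linear(l))
    hg = relu(rho @ params["W_g"] + params["b_g"])   # h_g = ReLU(Linear(rho(x, l)))
    # g = ReLU(Linear(h_l) + Linear(h_g))
    return relu(hl @ params["W_hl"] + hg @ params["W_hg"] + params["b_out"])
```

With `glimpse_win = 12` and a single scale, the flattened retina has `d_rho = 144` inputs.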

### step 2: Location network

<span style="color:red">TODO:</span> Finish **\__call__** function of **LocNet class** in **ecbm4040/ram/networks.py**.

* Location net:
  * $E[l_t] = \mathrm{Linear}(h_t)$
  * Gaussian stochastic policy: $l_t \sim N(E[l_t], \sigma^2)$. The next glimpse location is sampled from a Gaussian distribution with mean $E[l_t]$ and a fixed standard deviation $\sigma$ (`loc_std` in the configuration below).
  * **The location net is stochastic.** The paper trains it with REINFORCE; the following link explains the details: https://medium.com/towards-data-science/visual-attention-model-in-deep-learning-708813c2912c
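The Gaussian policy above can be sketched in NumPy to make the pieces explicit. The sketch covers the linear mean, a sample drawn around it with fixed $\sigma$, and the log-likelihood $\log \pi(l_t \mid s_{1:t})$ that REINFORCE later needs. The function name is hypothetical. The sketch does not clip the mean or the sample to $[-1, 1]$; some implementations do, but that is not stated here.

```python
import numpy as np

def loc_net(h, W, b, sigma, rng):
    """Hypothetical sketch of the Gaussian location policy.

    h: [B, cell_dim] hidden states. Returns (mean [B, 2], sample [B, 2], log_prob [B]).
    """
    mean = h @ W + b                                          # E[l_t] = Linear(h_t)
    sample = mean + sigma * rng.standard_normal(mean.shape)   # l_t ~ N(E[l_t], sigma^2 I)
    # log pi(l_t | s_1:t): sum of the log-densities of the two independent coordinates.
    log_prob = np.sum(
        -0.5 * ((sample - mean) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi),
        axis=-1,
    )
    return mean, sample, log_prob
```

The returned `log_prob` is the quantity that appears inside the REINFORCE objective in Part 2.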

### step 3: Action network (classification in this experiment)

<span style="color:red">TODO:</span> Complete **\__call__** function of **ActionNet class** in **ecbm4040/ram/networks.py**.

* Action net:
  * $a = Linear(h_T)$, here $h_T$ is the last output of the rnn network.
  * $softmax\_a = softmax(a)$
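The action net is a linear layer followed by a softmax. The following NumPy sketch (function name hypothetical) subtracts each row's maximum before exponentiating. This leaves the softmax unchanged but avoids overflow for large logits.

```python
import numpy as np

def action_net(h_T, W, b):
    """h_T: [B, cell_dim] last hidden state. Returns (a, softmax_a), both [B, num_classes]."""
    a = h_T @ W + b                                   # a = Linear(h_T)
    a_shift = a - a.max(axis=1, keepdims=True)        # subtract row max for numerical stability
    exp_a = np.exp(a_shift)
    softmax_a = exp_a / exp_a.sum(axis=1, keepdims=True)
    return a, softmax_a
```

The predicted digit is `np.argmax(softmax_a, axis=1)`.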

### step 4: Core RNN network

<span style="color:red">TODO:</span> Complete **core_rnn_net** function in **ecbm4040/ram/model.py**.

In this experiment, we use an LSTM cell.

In the core RNN net:
* First, define the LSTM cell.
* Then, initialize the first input $g$, i.e. the glimpse taken at the initial (random) location.
* Build a loop function that keeps feeding new glimpses into the LSTM cell.
* Output the hidden states to the location net and the action net.

![core_rnn_net](./img/core_rnn_net.png)
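The loop in the figure can be sketched in plain NumPy with a minimal LSTM cell. The glimpse step and the location step are passed in as callables. This is an illustration of the control flow only. The class and function names are hypothetical. The assumption that the LSTM runs exactly `num_glimpses` steps, starting from the glimpse at the initial location, is also a choice of this sketch. The TODO in `ecbm4040/ram/model.py` uses TensorFlow's own LSTM cell.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

class LSTMCell:
    """Minimal NumPy LSTM cell: input, forget, output gates plus a tanh candidate."""
    def __init__(self, input_dim, cell_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.cell_dim = cell_dim
        self.W = rng.normal(0.0, 0.1, (input_dim + cell_dim, 4 * cell_dim))
        self.b = np.zeros(4 * cell_dim)

    def __call__(self, x, state):
        c, h = state
        z = np.concatenate([x, h], axis=1) @ self.W + self.b
        i, f, o, g = np.split(z, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)   # new cell state
        h = sigmoid(o) * np.tanh(c)                    # new hidden state
        return h, (c, h)

def core_rnn(cell, glimpse_fn, loc_fn, init_loc, num_glimpses):
    """Unrolled loop: glimpse -> LSTM -> next location -> glimpse -> ...

    glimpse_fn(loc) -> g [B, input_dim]; loc_fn(h) -> next location [B, 2].
    Returns the list of hidden states h_1..h_T and the list of locations used.
    """
    batch = init_loc.shape[0]
    state = (np.zeros((batch, cell.cell_dim)), np.zeros((batch, cell.cell_dim)))
    loc = init_loc
    g = glimpse_fn(loc)                 # first input: glimpse at the initial location
    hiddens, locs = [], [loc]
    for t in range(num_glimpses):
        h, state = cell(g, state)       # feed the current glimpse into the LSTM
        hiddens.append(h)
        if t + 1 < num_glimpses:        # no extra glimpse after the last step
            loc = loc_fn(h)             # location net proposes the next location
            locs.append(loc)
            g = glimpse_fn(loc)         # glimpse at the new location
    return hiddens, locs
```

`hiddens[-1]` plays the role of $h_T$ for the action net, and every entry of `hiddens` can feed the location and baseline nets.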

## Task 2, Part 2: Training 

### Goal 

Train the glimpse net, core RNN net, and action/location net.

### Hybrid Supervised Loss

The code is provided in **ecbm4040/ram/loss.py** and in the **model** function of **ecbm4040/ram/model.py**.

To train RAM, we need to define a loss function. RAM can be viewed as an agent solving a Partially Observable Markov Decision Process (POMDP). After each action, such as choosing a glimpse location, the agent receives a reward signal, and its goal is to maximize the total reward. In object recognition, for example, $R = 1$ if the object is classified correctly after T steps, and 0 otherwise. T is the total number of glimpses, excluding the initial random glimpse.

To **maximize** this reward signal, the objective function is defined as Eq. (1) in the paper.

$$J = \frac{1}{M}\sum_{i=1}^M \sum_{t=1}^T \log \pi(l_t^i \mid s_{1:t}^i;\theta)\, R^i$$

Here, $\theta$ denotes all trainable parameters of the network, and $R^i$ is the reward of sample $i$. Averaging over $M$ samples is a **Monte Carlo** estimate, a standard technique in reinforcement learning (RL). The LocNet is stochastic: its output next_loc is sampled from a Gaussian distribution. The sampling step is therefore not differentiable, and ordinary back-propagation cannot pass through it. Instead, we draw $M$ samples, average the objective over them, and back-propagate through $\log \pi$. The sampled locations themselves are treated as constants during back-propagation. This is the REINFORCE rule used to train the stochastic LocNet.
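For a Gaussian policy, the gradient that REINFORCE pushes into the location net has a closed form. With the sample $l$ held fixed, $\partial \log \pi / \partial \mu = (l - \mu)/\sigma^2$. The NumPy sketch below (all function names hypothetical) computes the log-likelihood and this score. It also evaluates the Monte Carlo objective $J$ for given log-probabilities and rewards. The usage check confirms the analytic score against a central finite difference.

```python
import numpy as np

def gaussian_log_prob(l, mean, sigma):
    """log N(l; mean, sigma^2 I), summed over the last (coordinate) axis."""
    return np.sum(-0.5 * ((l - mean) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi), axis=-1)

def score_wrt_mean(l, mean, sigma):
    """Analytic d/d(mean) of log pi(l | mean), with the sampled l treated as a constant."""
    return (l - mean) / sigma ** 2

def reinforce_objective(log_probs, rewards):
    """J = (1/M) sum_i sum_t log pi(l_t^i | s_1:t^i) * R^i; log_probs: [M, T], rewards: [M]."""
    return np.mean(np.sum(log_probs, axis=1) * rewards)

# Usage: compare the analytic score with a central finite difference.
mean, l, sigma, eps = np.array([0.2, -0.3]), np.array([0.25, -0.1]), 0.1, 1e-6
numeric = np.array([
    (gaussian_log_prob(l, mean + eps * e, sigma) - gaussian_log_prob(l, mean - eps * e, sigma)) / (2 * eps)
    for e in np.eye(2)
])
print(numeric, score_wrt_mean(l, mean, sigma))
```

Samples with $R^i = 0$ contribute nothing to $J$, so only locations that led to a correct classification are reinforced.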

However, this estimate can have high variance. To reduce it, the paper subtracts a baseline estimated by a **baseline network**. The objective then becomes

$$J = \frac{1}{M}\sum_{i=1}^M \sum_{t=1}^T \log \pi(l_t^i \mid s_{1:t}^i;\theta)\,(R_t^i - b_t^i)$$

where $R_t^i$ always equals $R^i$ under the reward definition above. $b_t^i$ is an estimate of $E[R_t]$ that depends only on the state $h_t$, not on the LocNet or its action $l_t$. In practice, the baseline network is a single fully connected layer trained to minimize the **squared error between $R_t^i$ and $b_t^i$**. Also remember that $b_t^i$ is a constant with respect to $J$: gradients of $J$ must not flow back into the baseline network.
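A toy experiment shows why the baseline helps. Take a one-dimensional policy $l \sim N(\mu, \sigma^2)$ and a hypothetical reward $R = 1$ if $l > 0$, else 0. The expected reward is then $\Phi(\mu/\sigma)$, and its true gradient is $\varphi(\mu/\sigma)/\sigma \approx 0.7365$ for $\mu = 0.2, \sigma = 0.5$. In this NumPy sketch, the baseline is a constant, namely the sample mean of $R$. This is a simplification: RAM learns $b_t$ with a network conditioned on $h_t$. Both estimators have about the same mean, but the baselined one has a much smaller variance.

```python
import numpy as np

def grad_estimates(num_samples=100_000, mu=0.2, sigma=0.5, seed=0):
    """Toy 1-D REINFORCE with a hypothetical reward R = 1 if l > 0 else 0.

    Returns per-sample estimates of dE[R]/dmu without and with a constant baseline.
    """
    rng = np.random.default_rng(seed)
    l = rng.normal(mu, sigma, num_samples)
    R = (l > 0).astype(float)
    score = (l - mu) / sigma ** 2   # d log pi(l) / d mu, with the sample l held fixed
    b = R.mean()                    # constant baseline: a sample estimate of E[R]
    g_plain = score * R             # REINFORCE term without a baseline
    g_base = score * (R - b)        # with a baseline (b is a constant, never differentiated)
    return g_plain, g_base

g_plain, g_base = grad_estimates()
print(g_plain.mean(), g_base.mean())   # both close to the true gradient
print(g_plain.var(), g_base.var())     # the baselined estimator varies much less
```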

**Hybrid supervised loss:**
As mentioned in the paper, the algorithm described above allows us to train the agent when the “best” actions are unknown, and the learning signal is only provided via the reward. For instance, we may not know a priori which sequence of fixations provides the most information about an unknown image, but the total reward at the end of an episode will give us an indication whether the tried sequence was good or bad.

However, in some situations we do know the correct action to take. For instance, in an object detection task the agent has to output the label of the object as the final action. For training images this label is known, so we can directly optimize the policy to **output the correct label associated with a training image at the end of an observation sequence**. We follow the standard approach for classification problems and optimize the **cross entropy loss** to train the action network $f_a$, backpropagating the gradients through the core and glimpse networks.

So the final hybrid loss is defined as

$$\mathrm{Hybrid\ Loss} = -J + \frac{1}{M}\sum_{i=1}^M \mathrm{cross\_entropy}(softmax\_a^i, y^i) + \frac{1}{M}\sum_{i=1}^M \sum_{t=1}^T (R^i-b_t^i)^2$$

where $y^i$ is the label of sample $i$.
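The hybrid loss can be assembled term by term from the formula. The following NumPy sketch is illustrative, and its function name is hypothetical. It derives the reward from the action net's prediction and computes $J$ with the baseline. It then adds a numerically stable cross entropy and the baseline's squared error. NumPy has no autodiff, so the sketch can only note where, in a TensorFlow version, the baseline would need to be wrapped in `tf.stop_gradient` inside $J$. How `ecbm4040/ram/loss.py` actually does this is not shown here.

```python
import numpy as np

def hybrid_loss(log_probs, baselines, logits, labels):
    """Sketch of: -J + mean_i CE_i + mean_i sum_t (R^i - b_t^i)^2.

    log_probs: [M, T] log pi(l_t^i | s_1:t^i); baselines: [M, T] b_t^i;
    logits: [M, C] action-net outputs a; labels: [M] integer class labels.
    """
    R = (np.argmax(logits, axis=1) == labels).astype(float)   # 1 if correct after T steps, else 0
    # In TensorFlow the baseline inside J would be tf.stop_gradient(b); NumPy has no gradients.
    J = np.mean(np.sum(log_probs * (R[:, None] - baselines), axis=1))
    shifted = logits - logits.max(axis=1, keepdims=True)      # stable log-softmax
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    cross_ent = -np.mean(log_softmax[np.arange(len(labels)), labels])
    b_mse = np.mean(np.sum((R[:, None] - baselines) ** 2, axis=1))
    return -J + cross_ent + b_mse, J, cross_ent, b_mse
```

Consider two samples with logits `[2, 0]`, labels `0` and `1`, and all baselines set to 0.5. Here $J = 0$ and the baseline MSE is 0.5. The cross entropy is $1 + \log(1 + e^{-2})$, giving a total of $1.5 + \log(1 + e^{-2})$.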

### Optimizer and Back-propagation

Code is given in **ecbm4040/ram/model.py**. Study and try to understand the code.

In [None]:
# network configuration
config = {
    # input configuration
    "image_size": 28,
    "num_channels": 1,
    # network settings
    ## glimpse
    "glimpse_win": 12,
    "glimpse_scale": 1,
    "hg_dim": 128,
    "hl_dim": 128,
    "g_dim": 256,
    ## rnn
    "num_glimpses": 6,
    "cell_dim": 256,
    ## location
    "loc_dim": 2,
    "loc_std": 0.1, # you can try different std
    "use_sample": True,
    ## action/classification
    "num_classes": 10
}

In [None]:
# training configuration
train_cfg = {
    "max_grad_norm": 5.,
    "lr_init": 1e-4,
    "lr_min": 1e-5,
    "decay_rate": 0.95,
    "num_epochs": 15, # you should try more epochs
    "num_train": mnist.train.num_examples,
    "batch_size": 32,
    "eval_size": 1000,
    # monte carlo sampling
    "M": 10
}

### Build the network

In [None]:
from ecbm4040.ram.model import model

In [None]:
tf.reset_default_graph()
out = model(config, train_cfg, reuse_core=False, reuse_action=False)
images_ph, labels_ph, hybrid_loss, J, cross_ent, b_mse, r_avg, correct_num, lr, train_step, loc_means, loc_samples = out

In [None]:
# Run this cell to verify that the program works before moving on to training.
with tf.Session() as sess:
    images = mnist.train.images[:10,:]
    labels = mnist.train.labels[:10]
    images = images.reshape((10, 28, 28, 1))
    
    sess.run(tf.global_variables_initializer())
    test_J, test_ent, test_bmse, test_r = sess.run([J, cross_ent, b_mse, r_avg],
                                           feed_dict={images_ph: images,labels_ph: labels})
print("outputs: J={:.5f}, cross_ent={:.5f}, baseline_mse={:.5f}, reward_avg={:.5f}".format(test_J, test_ent, test_bmse, test_r))

## Task 2, Part 3: Experiments

It is recommended to use a GPU to complete the following experiments.

In [None]:
# load functions to visualize the glimpse path.
from ecbm4040.ram.utils import glimpse_path

### Original 28x28 MNIST

In [None]:
# display some samples
num_display_samples = 4
images, labels = mnist.train.next_batch(num_display_samples)
images = images.reshape((num_display_samples, 28, 28, 1))
f, axarr = plt.subplots(1, num_display_samples, figsize=(4*num_display_samples,4))
for i in range(num_display_samples):
    axarr[i].imshow(np.squeeze(images[i,:,:,:]), cmap="gray")
    axarr[i].set_title(labels[i])
plt.show()

**Build the model:**

In [None]:
config["use_sample"] = True
tf.reset_default_graph()
out = model(config, train_cfg, reuse_core=False, reuse_action=False)
images_ph, labels_ph, hybrid_loss, J, cross_ent, b_mse, r_avg, correct_num, lr, train_step, loc_means, loc_samples = out

**Train the model:** 

* You should reach <span style="color:red">90%</span> validation and test accuracy in this experiment.
* This network may be sensitive to its initial weights.
* Overfitting may occur during training.

In [None]:
# train
M = train_cfg["M"]
num_epochs = train_cfg["num_epochs"]
num_steps_per_epoch = train_cfg["num_train"] // train_cfg["batch_size"]
eval_size = train_cfg["eval_size"]
# save
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(num_epochs):
        #####################################
        #           training phase          #
        #####################################
        for i in range(num_steps_per_epoch):
            images, labels = mnist.train.next_batch(train_cfg["batch_size"])
            images = images.reshape((train_cfg["batch_size"], 28, 28, 1))
            # Monte Carlo Estimation: duplicate M times, see Eqn (1) and (2) in paper
            images = np.tile(images, [M, 1, 1, 1])
            labels = np.tile(labels, [M])
            # training
            train_loss, train_J, train_ent, train_bmse, train_r, train_lr, _ = sess.run([hybrid_loss, J, cross_ent, b_mse, r_avg, lr, train_step],
                                                   feed_dict={images_ph: images,labels_ph: labels})
            # report progress
            if i and i % 500 == 0:
                print("epoch {} step {}: lr = {:.5f}\treward = {:.4f}\tloss = {:.4f}".
                      format(e+1, i, train_lr, train_r, train_loss))
                print("epoch {} step {}: J = {:.5f}\tcross_ent = {:.4f}\tbaseline_mse = {:.4f}".
                      format(e+1, i, train_J, train_ent, train_bmse))
                
        #####################################
        #         evaluation phase          #
        #####################################
        # validation set
        val_correct_num = 0.0
        for i in range(mnist.validation.num_examples//eval_size):
            images, labels = mnist.validation.next_batch(eval_size)
            images = images.reshape((eval_size, 28, 28, 1))
            
            val_correct_num += sess.run(correct_num, feed_dict={images_ph: images,labels_ph: labels})
        val_acc = val_correct_num/mnist.validation.num_examples
        print("------epoch {}: val_acc = {:.4f}".format(e+1, val_acc))
    
    #####################################
    #            save model             #
    #####################################
    # Save the variables to disk.
    if not os.path.exists("./tmp/"):
        os.mkdir("./tmp/")
    save_path = saver.save(sess, "./tmp/model_28_15.ckpt")
    print("Model saved in file: %s" % save_path)

In [None]:
saver = tf.train.Saver()
# evaluation: output val_acc and test_acc
eval_size = train_cfg["eval_size"]
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "./tmp/model_28_15.ckpt")
    print("Model restored.")
    # validation set
    val_correct_num = 0.0
    for i in range(mnist.validation.num_examples//eval_size):
        images = mnist.validation.images[i*eval_size:(i+1)*eval_size]
        labels = mnist.validation.labels[i*eval_size:(i+1)*eval_size]
        images = images.reshape((eval_size, 28, 28, 1))

        val_correct_num += sess.run(correct_num, feed_dict={images_ph: images,labels_ph: labels})
    val_acc = val_correct_num/mnist.validation.num_examples
    print("val_acc = {:.4f}".format(val_acc))

    # test set
    test_correct_num = 0.0
    for i in range(mnist.test.num_examples//eval_size):
        images = mnist.test.images[i*eval_size:(i+1)*eval_size]
        labels = mnist.test.labels[i*eval_size:(i+1)*eval_size]
        images = images.reshape((eval_size, 28, 28, 1))

        test_correct_num += sess.run(correct_num, feed_dict={images_ph: images,labels_ph: labels})
    test_acc = test_correct_num/mnist.test.num_examples
    print("test_acc = {:.4f}".format(test_acc))   

In [None]:
# display the glimpse path. You can use "glimpse_path" or create your own function.

### Translated 60x60 MNIST

In [None]:
from ecbm4040.ram.utils import translate_60_mnist

# display some translated samples: all samples in one batch share the same translation transform.
num_display_samples = 4
images, labels = mnist.train.next_batch(num_display_samples)
images = translate_60_mnist(images, image_size=28, num_channels=1)
f, axarr = plt.subplots(1, num_display_samples, figsize=(4*num_display_samples,4))
for i in range(num_display_samples):
    axarr[i].imshow(np.squeeze(images[i,:,:,:]), cmap="gray")
    axarr[i].set_title(labels[i])
plt.show()

**Train the model:** set a larger glimpse window and add more scales.

In [None]:
# TODO: train

In [None]:
# TODO: evaluation, output val_acc and test_acc

In [None]:
# TODO: display the glimpse path and glimpse patches.

### MNIST pair addition (Optional)

In this experiment, the task is to predict the sum of a pair of digits, so remember to change "num_classes" to 19. Running this experiment may take several hours on a GPU.

**Warning:** This experiment is very challenging. In previous experiments, it took many hours to reach only about 65% accuracy, so this part is optional.

In [None]:
from ecbm4040.ram.utils import mnist_addition_pair

# display some samples: image -- a pair of digit; label -- the sum of two digits.
num_display_samples = 4
images, labels = mnist.train.next_batch(num_display_samples)
images, labels = mnist_addition_pair(images, labels, image_size=28, num_channels=1)
f, axarr = plt.subplots(1, num_display_samples, figsize=(4*num_display_samples,4))
for i in range(num_display_samples):
    axarr[i].imshow(np.squeeze(images[i,:,:,:]), cmap="gray")
    axarr[i].set_title(labels[i])
plt.show()

**Train the model:** you may need more glimpses to give the sensor more freedom, and **a large enough M** to reduce the variance.

In [None]:
# TODO: train

In [None]:
# TODO: evaluation: output val_acc and test_acc

In [None]:
# TODO: display the glimpse path and glimpse patches.

## Other recommended reading

[1] [Multiple object recognition with visual attention](https://arxiv.org/pdf/1412.7755.pdf)