# Multiple object recognition with visual attention

In this notebook, I implement the glimpse network from Ba et al., "Multiple Object Recognition with Visual Attention".

In [24]:
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf
import numpy as np

In [89]:
class glimpse_network(object):
  """Glimpse network from Ba et al., "Multiple object recognition with visual attention".

  Combines an image pathway (three conv layers followed by one fully
  connected layer) with a location pathway (one fully connected layer) by
  element-wise multiplication:

      g = act(FC_img(Conv3(Conv2(Conv1(x))))) * act(FC_loc(l))

  All variables are created once in `__init__` and stored in `self.params`,
  so `build` can be called repeatedly (e.g. once per glimpse) and every call
  shares the same weights.
  """

  def __init__(self,
               layer_size=((5, 5, 3, 32), (5, 5, 32, 64), (5, 5, 64, 32), (28 * 28 * 32, 1024)),
               scope="glimpse_net"):
    """Create the network's variables.

    Args:
      layer_size: 4-tuple of weight shapes. The first three are conv kernel
        shapes `(height, width, in_channels, out_channels)`. The last is the
        `(in_features, out_features)` shape of the image fully connected
        layer. Every conv uses stride 1 and 'SAME' padding, so the spatial
        size is preserved. `in_features` must therefore equal
        glimpse_height * glimpse_width * out_channels of the third conv;
        the default assumes a 28x28 glimpse.
      scope: variable scope the variables are created in. The default is
        "glimpse_net". Pass a different name to create a second, independent
        instance in the same graph: `tf.get_variable` refuses to recreate an
        existing variable unless reuse is enabled.
    """
    self.params = {}
    # Intermediate activations of the most recent `build` call, for inspection.
    self.endpoints = {}
    n_hidden = layer_size[3][-1]
    with tf.variable_scope(scope):
      # Names and creation order are W_conv1, b_conv1, ..., W_loc1, b_loc1.
      for i, kernel_shape in enumerate(layer_size[:3], start=1):
        self._add_layer_params("conv%d" % i, kernel_shape)
      self._add_layer_params("fc1", layer_size[3])
      # The location pathway maps the 2-D location to the same width as the
      # image pathway so the two outputs can be multiplied element-wise.
      self._add_layer_params("loc1", (2, n_hidden))

  def _add_layer_params(self, suffix, weight_shape):
    """Create `W_<suffix>` (truncated normal, stddev 0.1) and a zero `b_<suffix>`."""
    w_name, b_name = "W_" + suffix, "b_" + suffix
    self.params[w_name] = tf.get_variable(
      name=w_name, shape=weight_shape,
      initializer=tf.truncated_normal_initializer(stddev=0.1))
    # One bias per output unit/channel, i.e. the weight's last dimension.
    self.params[b_name] = tf.get_variable(
      name=b_name, shape=(weight_shape[-1],),
      initializer=tf.zeros_initializer())

  @staticmethod
  def _flatten(t):
    """Reshape a 4-D (batch, h, w, c) tensor to (batch, h * w * c).

    When the non-batch dimensions are statically known, this uses the same
    static reshape as before. Otherwise (e.g. spatial dims are None) it
    falls back to a dynamic reshape that keeps the batch axis explicit. A
    size mismatch with W_fc1 then surfaces as a matmul error instead of
    silently folding rows.
    """
    inner_dims = t.get_shape().as_list()[1:]
    if all(d is not None for d in inner_dims):
      return tf.reshape(t, shape=(-1, int(np.prod(inner_dims))))
    return tf.reshape(t, shape=tf.stack([tf.shape(t)[0], -1]))

  def build(self, x, l, activation=tf.nn.relu):
    """Build the glimpse representation for observation `x` at location `l`.

    Args:
      x: glimpse image batch, consumed by `tf.nn.conv2d` in its default NHWC
        layout. The channel count must match `layer_size[0][2]`.
      l: location batch of shape (batch, 2).
      activation: non-linearity applied after every layer.

    Returns:
      Tensor of shape (batch, layer_size[3][-1]): the element-wise product
      of the image features and the location features.
    """
    endpoints = {}

    # Image pathway: three stride-1, 'SAME'-padded conv layers.
    net = x
    for i in (1, 2, 3):
      net = activation(
        tf.nn.conv2d(net, self.params["W_conv%d" % i],
                     strides=(1, 1, 1, 1), padding='SAME')
        + self.params["b_conv%d" % i])
      endpoints["conv_layer%d" % i] = net

    endpoints["conv_layer3_flattened"] = self._flatten(net)

    # Fully connected layer of the image pathway.
    endpoints["fc_layer1"] = activation(
      tf.matmul(endpoints["conv_layer3_flattened"], self.params["W_fc1"])
      + self.params["b_fc1"])

    # Location pathway.
    endpoints["loc_layer1"] = activation(
      tf.matmul(l, self.params["W_loc1"]) + self.params["b_loc1"])

    # Combined output: multiplicative interaction of image and location features.
    g = endpoints["fc_layer1"] * endpoints["loc_layer1"]

    self.endpoints = endpoints
    return g