# Multiple object recognition with visual attention

In this notebook, I implement the model from Ba et al. "Multiple object recognition with visual attention"

In [1]:
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import os
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials import mnist
from tensorflow.python.ops import array_ops


In [2]:
T = 4  # times is the total number of pixels
num_classes = 10 
num_units = 128  # number of units in network
max_iter = 50000
init_lr = 0.001  # initial learning rate
batch_size = 512

In [3]:
class mnist_provider(object):
  """MNIST data provider."""
  def __init__(self, data_directory, split="train"):
    mnist_data = mnist.input_data.read_data_sets(data_directory, one_hot=True)
    if split == "train":
      self.mnist_data = mnist_data.train
    elif split == "valid":
      self.mnist_data = mnist_data.validation
    elif split == "test":
      self.mnist_data = mnist_data.test

  def next_batch(self, batch_size):
    images, one_hot_labels = self.mnist_data.next_batch(batch_size)
    images = np.reshape(images, [-1, 28, 28, 1], order='C')
    return images, one_hot_labels

In [4]:
class GlimpseNet(object):
  """Glimpse network.
  
  "The glimpse network is a non-linear function that receives the current input image
  patch, or glimpse, xn and its location tuple ln , where ln = (xn, yn), as input and 
  outputs a vector gn"
  """
  def __init__(self, 
               layers_size=(
                 (3, 3, 1, 64),
                 (3, 3, 64, 64),
                 (3, 3, 64, 64),
                 (16 * 16 * 64, num_units)
               )):
    """ 
    
    Args:
      layers_size: layer sizes (tuple of shape tuples).
        default:
          layer 1: 3x3 conv, input with 1 color channel, 64 channels output
          layer 2: 3x3 conv, 64 channels output
          layer 3: 3x3 conv, 64 channels output
          layer 4: fully connected, input with size 16x16, output size 128
    
    """
    
    self.params = {}
    with tf.variable_scope("glimpse_net"):
      self.params["W_conv1"] = tf.get_variable(name="W_conv1", shape=layers_size[0],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_conv1"] = tf.get_variable(name="b_conv1", shape=layers_size[0][-1],
                          initializer=tf.zeros_initializer())
      
      self.params["W_conv2"] = tf.get_variable(name="W_conv2", shape=layers_size[1],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_conv2"] = tf.get_variable(name="b_conv2", shape=layers_size[1][-1],
                          initializer=tf.zeros_initializer())
      
      self.params["W_conv3"] = tf.get_variable(name="W_conv3", shape=layers_size[2],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_conv3"] = tf.get_variable(name="b_conv3", shape=layers_size[2][-1],
                          initializer=tf.zeros_initializer())
      
      
      self.params["W_fc1"] = tf.get_variable(name="W_fc1", shape=layers_size[3],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_fc1"] = tf.get_variable(name="b_fc1", shape=layers_size[3][-1],
                          initializer=tf.zeros_initializer())      
      
      self.params["W_loc1"] = tf.get_variable(name="W_loc1", shape=(2, layers_size[3][-1]),
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_loc1"] = tf.get_variable(name="b_loc1", shape=layers_size[3][-1],
                          initializer=tf.zeros_initializer())
    
  def build(self, x_l, l, activation=tf.nn.relu):
    """Build glimpse network based on observation x_l at location l and location l."""
    endpoints = {}
    with tf.name_scope("glimpse_net"):
      # image network
      # convolutional layer 1
      endpoints["conv_layer1"] = activation(
        tf.nn.conv2d(
          x_l, self.params["W_conv1"], strides=(1, 1, 1, 1), padding='SAME'
        ) + self.params["b_conv1"])

      # convolutional layer 2
      endpoints["conv_layer2"] = activation(
        tf.nn.conv2d(
          endpoints["conv_layer1"], self.params["W_conv2"], strides=[1, 1, 1, 1], padding='SAME'
        ) + self.params["b_conv2"])

      # convolutional layer 3
      endpoints["conv_layer3"] = activation(
        tf.nn.conv2d(
          endpoints["conv_layer2"], self.params["W_conv3"], strides=[1, 1, 1, 1], padding='SAME'
        ) + self.params["b_conv3"])

      endpoints["conv_layer3_flattened"] = tf.reshape(
        endpoints["conv_layer3"],
        shape=(-1, np.prod(endpoints["conv_layer3"].get_shape().as_list()[1:])))

      # fully connected layer
      endpoints["fc_layer1"] = activation(
        tf.matmul(
          endpoints["conv_layer3_flattened"], self.params["W_fc1"]) + self.params["b_fc1"])

      # location network
      endpoints["loc_layer1"] = activation(
        tf.matmul(l, self.params["W_loc1"]) + self.params["b_loc1"]
      )

      # combined output
      g = endpoints["fc_layer1"] * endpoints["loc_layer1"]

    return g

In [5]:
class EmissionNet(object):
  """Emission network.
  
  "The emission network takes the current state of recurrent network as input and 
  makes a prediction on where to extract the next image patch for the glimpse network."
  
  """
  def __init__(self, layers_size=((num_units, 2),)):
    """ 
    
    Args:
      layers_size: layer sizes (tuple of shape tuples).
        default:
          layer 1: fully connected, input with size 128, output size 2
    
    """
    
    self.params = {}
    with tf.variable_scope("emission_net"):
      
      self.params["W_fc1"] = tf.get_variable(name="W_fc1", shape=layers_size[0],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_fc1"] = tf.get_variable(name="b_fc1", shape=layers_size[0][-1],
                          initializer=tf.zeros_initializer())      
    
  def build(self, r, activation=tf.nn.relu):
    """Build emission network based on rnn state r."""
    endpoints = {}
    
    with tf.name_scope("emission_net"):
      # fully connected layer
      endpoints["fc_layer1"] = activation(
        tf.matmul(
          r, self.params["W_fc1"]) + self.params["b_fc1"])

      # normalized location (-1, 1) range. [0, 0] is center.
      l = 2 * tf.nn.sigmoid(endpoints["fc_layer1"]) - 1
    return l

In [6]:
class ContextNet(object):
  """Context network.
  
  "The context network provides the initial state for the recurrent network and its
  output is used by the emission network to predict the location of the first glimpse. The context
  network C(·) takes a down-sampled low-resolution version of the whole input image Icoarse and
  outputs a fixed length vector cI . The contextual information provides sensible hints on where the
  potentially interesting regions are in a given image."
  """
  def __init__(self, 
               layers_size=(
                 (3, 3, 1, 64),
                 (3, 3, 64, 64),
                 (3, 3, 64, 64),
                 (16 * 16 * 64, num_units)
               )):
    self.params = {}
    with tf.variable_scope("context_net"):
      self.params["W_conv1"] = tf.get_variable(name="W_conv1", shape=layers_size[0],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_conv1"] = tf.get_variable(name="b_conv1", shape=layers_size[0][-1],
                          initializer=tf.zeros_initializer())
      
      self.params["W_conv2"] = tf.get_variable(name="W_conv2", shape=layers_size[1],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_conv2"] = tf.get_variable(name="b_conv2", shape=layers_size[1][-1],
                          initializer=tf.zeros_initializer())
      
      self.params["W_conv3"] = tf.get_variable(name="W_conv3", shape=layers_size[2],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_conv3"] = tf.get_variable(name="b_conv3", shape=layers_size[2][-1],
                          initializer=tf.zeros_initializer())
      
      
      self.params["W_fc1"] = tf.get_variable(name="W_fc1", shape=layers_size[3],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
      self.params["b_fc1"] = tf.get_variable(name="b_fc1", shape=layers_size[3][-1],
                          initializer=tf.zeros_initializer())      

  
  def build(self, x_coarse, activation=tf.nn.relu):
    """Build context network."""
    endpoints = {}
    with tf.name_scope("context_net"):
      # image network
      # convolutional layer 1
      endpoints["conv_layer1"] = activation(
        tf.nn.conv2d(
          x_coarse, self.params["W_conv1"], strides=(1, 1, 1, 1), padding='SAME'
        ) + self.params["b_conv1"])

      # convolutional layer 2
      endpoints["conv_layer2"] = activation(
        tf.nn.conv2d(
          endpoints["conv_layer1"], self.params["W_conv2"], strides=[1, 1, 1, 1], padding='SAME'
        ) + self.params["b_conv2"])

      # convolutional layer 3
      endpoints["conv_layer3"] = activation(
        tf.nn.conv2d(
          endpoints["conv_layer2"], self.params["W_conv3"], strides=[1, 1, 1, 1], padding='SAME'
        ) + self.params["b_conv3"])

      endpoints["conv_layer3_flattened"] = tf.reshape(
        endpoints["conv_layer3"],
        shape=(-1, np.prod(endpoints["conv_layer3"].get_shape().as_list()[1:])))

      # fully connected layer
      endpoints["fc_layer1"] = activation(
        tf.matmul(
          endpoints["conv_layer3_flattened"], self.params["W_fc1"]) + self.params["b_fc1"])
      
    return endpoints["fc_layer1"]
  

In [7]:
class DRAM(object):
  """Class for deep recurrent attention model (DRAM).
  
  Args:
    weights: dictionary of weights as tensorflow variables.
    biases: dictionary of biases as tensorflow variables.
    outputs: rnn output pre logit (list of size number of time steps).
    state: final rnn state.
    cell: tensorflow cell obbject
    logits_list: list of logits at each time point.
  """
  def __init__(self, images, num_units, num_classes, time_steps, activation=tf.nn.tanh):
    """Init function.
    
    Inputs:
      images: model input.
      num_units: number of network units.
      num_classes: network output classes.
      time_steps: number of time steps for RNNS
      activation: activation function to use for cell (default tanh).
    """
    self.weights = {}
    self.biases = {}
    
    # initial states 
    # context information
    glimpse_size = coarse_size = (16, 16)
    C = ContextNet()
    images_coarse = tf.image.resize_images(images, coarse_size)
    h0 = C.build(images_coarse)
    
    # initial location
    E = EmissionNet()
    l0 = E.build(h0)
    
    # first glimpse
    G = GlimpseNet()

    #TODO(gamaleldin): change GRU to LSTM
    with tf.variable_scope("rnn_layer1"):
      with tf.variable_scope("internal"):
        self.rnn_layer1 = tf.nn.rnn_cell.GRUCell(
          num_units=num_units, activation=activation)
        
        self.rnn_layer2 = tf.nn.rnn_cell.GRUCell(
          num_units=num_units, activation=activation)
        
      
      state_layer1 = self.rnn_layer1.zero_state(
        batch_size=array_ops.shape(images)[0],
        dtype=tf.float32)
      
      state_layer2 = h0
      self.outputs1 = []
      self.states1 = []
      
      self.outputs2 = []
      self.states2 = []
      l = l0
      for t in range(time_steps):
        with tf.variable_scope("internal", reuse=(t != 0)):
          images_glimpse0 = tf.image.extract_glimpse(
            images,
            size=glimpse_size,
            offsets=l,
            centered=True,
            normalized=True)
          g = G.build(images_glimpse0, l)
          
          output1, state_layer1 = self.rnn_layer1(g, state_layer1)
          self.outputs1.append(output1)
          self.states1.append(state_layer1) 
          
          output2, state_layer2 = self.rnn_layer1(output1, state_layer2)
          self.outputs2.append(output2)
          self.states2.append(state_layer2) 
            
          l = E.build(output2)
          
      with tf.variable_scope("output"):
        self.weights["out"] = tf.get_variable(name="w",
                        shape=(num_units, num_classes),
                        initializer=tf.truncated_normal_initializer(stddev=0.1))
        
        self.biases["out"] = tf.get_variable(name="b",
                        shape=(num_classes),
                        initializer=tf.zeros_initializer())

    # compute logits at each time step
    self.logits_list = []
    for output in self.outputs1:
      self.logits_list.append(
        tf.matmul(output, self.weights["out"]) + self.biases["out"]) 
    

In [8]:
tf.reset_default_graph()
data_dir = "/Users/gamal/git_local_repo/playground/data/mnist"
data_provider_train = mnist_provider(data_dir, split='train')
data_provider_valid = mnist_provider(data_dir, split='valid')
data_provider_test = mnist_provider(data_dir, split='test')
g = tf.Graph()
with g.as_default():
  # place golder for data
  images = tf.placeholder(shape=(None, 28, 28, 1), dtype=tf.float32)
  one_hot_labels = tf.placeholder(dtype=tf.float32, shape=(None, num_classes))
  
  # build DRAM model
  D = DRAM(images, num_units=num_units, num_classes=num_classes, time_steps=T)


Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/gamal/git_local_repo/playground/data/mnist/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/gamal/git_local_repo/playground/data/mnist/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting /Users/gamal/git_local_repo/playground/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /Users/gamal/git_local_repo/playground/data/mnist/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Extracting /Users/gamal/git_local_repo/playground/data/mnist/train-images-idx3-ubyte.gz
Extracting /Users/gamal/git_local_repo/playground/data/mn

In [9]:
with g.as_default():
  logits = D.logits_list[-1]
  
  # accuracy metric
  top1_op = tf.nn.in_top_k(logits, tf.argmax(one_hot_labels, 1), 1)
  accuracy = tf.reduce_mean(tf.cast(top1_op, dtype=tf.float32))
  

In [10]:
with g.as_default():
  loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                               labels=one_hot_labels)
  )
  opt = tf.train.AdamOptimizer(learning_rate=init_lr)
  
  train_op = opt.minimize(loss)
  
  init_op = tf.global_variables_initializer()
  
  

In [11]:
with g.as_default():
  saver = tf.train.Saver(tf.global_variables())
  train_dir = "/Users/gamal/git_local_repo/playground/experiments/visual_attention/train"
  if not os.path.exists(train_dir):
    os.mkdir(train_dir)

## Optimize

In [None]:
with g.as_default():
  with tf.Session() as sess:
    sess.run(init_op)
    for i in xrange(max_iter):
      imgs, lbls = data_provider_train.next_batch(
        batch_size=batch_size)
      _, ls, acc = sess.run(
        (train_op, loss, accuracy),
        feed_dict={images: imgs, one_hot_labels: lbls}
      )
      if i % 100 == 0:
        saver.save(sess, os.path.join(train_dir, 'dram_model.ckpt'), global_step=i)
        print(
          "iter %d: (train loss: %.3f) (train accuracy: %.3f)" 
          %(i, ls, acc)
        )
    

iter 0: (train loss: 2.303) (train accuracy: 0.570)
