## The Multi-Armed Bandit

This tutorial walks through a simple example of how to build a policy-gradient-based agent that solves the multi-armed bandit problem.

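```python
# The code below uses the TensorFlow 1.x graph/session API; tf.compat.v1 also exposes it on TensorFlow 2.x
import tensorflow.compat.v1 as tf
import numpy as np

tf.disable_v2_behavior()
```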

### The Bandits

Here we define our bandits. For this example we use a four-armed bandit. Each bandit is described by a threshold value: the `pull_bandit` function draws a random number from a standard normal distribution (mean 0, standard deviation 1) and returns a reward of +1 if the draw exceeds the threshold, otherwise -1. The lower a bandit's threshold, the more likely it is to return a positive reward. We want our agent to learn to always choose the bandit that most often gives that positive reward.

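```python
# Each entry is a threshold. Bandit 4 (index 3) has the lowest threshold,
# so it is currently set to provide a positive reward most often.
bandits = [0.2, 0, -0.2, -5]
n_bandits = len(bandits)

def pull_bandit(bandit):
    """Return +1 if a standard normal draw exceeds the bandit's threshold, else -1."""
    result = np.random.randn()
    return 1 if result > bandit else -1
```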
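
To make "lower threshold, more positive rewards" concrete: for a standard normal draw $x$, bandit $i$ with threshold $b_i$ pays $+1$ with probability $P(x > b_i) = 1 - \Phi(b_i)$, where $\Phi$ is the standard normal CDF. The sketch below computes that probability and the resulting mean reward for each threshold using only the standard library; `p_positive` and `expected_reward` are illustrative helper names, not part of the tutorial.

```python
import math

# Thresholds from the tutorial: a pull pays +1 if x ~ N(0, 1) exceeds the threshold, else -1
bandits = [0.2, 0, -0.2, -5]

def p_positive(threshold):
    """P(x > threshold) for x ~ N(0, 1), i.e. 1 - Phi(threshold), via the error function."""
    return 0.5 * (1.0 - math.erf(threshold / math.sqrt(2.0)))

def expected_reward(threshold):
    """Rewards are +1 or -1, so the mean reward is p - (1 - p) = 2p - 1."""
    return 2.0 * p_positive(threshold) - 1.0

for number, threshold in enumerate(bandits, start=1):
    print(f"bandit {number}: P(+1) = {p_positive(threshold):.3f}, "
          f"E[reward] = {expected_reward(threshold):+.3f}")
```

Bandits 1 to 3 sit close to a coin flip, while bandit 4, whose threshold lies five standard deviations below the mean, returns $+1$ on essentially every pull.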

### The Agent

The code below establishes our simple neural agent. It consists of a single weight per bandit; each weight is the agent's estimate of the value of choosing that bandit, and the agent acts greedily by picking the bandit with the largest weight. We update the agent with a policy-gradient loss, `-log(w) * reward` on the weight of the chosen bandit, which raises that weight after a positive reward and lowers it after a negative one.

```python
tf.reset_default_graph()

# These two lines establish the feed-forward part of the network. This does the actual choosing:
# W holds one weight per bandit, and the greedy action is the bandit with the largest weight.
W = tf.Variable(tf.ones([n_bandits]))
a = tf.argmax(W, 0)

# Training: feed in the reward and the chosen action, pick out the weight
# responsible for that action, and minimize -log(weight) * reward.
reward_ph = tf.placeholder(shape=[1], dtype=tf.float32)
action_ph = tf.placeholder(shape=[1], dtype=tf.int32)
responsible_weight = tf.slice(W, action_ph, [1])
loss = -tf.log(responsible_weight) * reward_ph
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
update_op = optimizer.minimize(loss)
```
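
For a single chosen bandit $a$ with weight $w_a$, reward $r$, and learning rate $\eta$ (`learning_rate=0.001` above), the update that `update_op` performs can be worked out by hand. This derivation assumes plain gradient descent as configured above; because `tf.slice` picks out only $w_a$, the other weights receive no gradient and stay unchanged:

$$
\begin{aligned}
L(w_a) &= -\,r \log w_a \\
\frac{\partial L}{\partial w_a} &= -\frac{r}{w_a} \\
w_a &\leftarrow w_a - \eta \, \frac{\partial L}{\partial w_a} = w_a + \frac{\eta \, r}{w_a}
\end{aligned}
$$

Starting from the initial weight $w_a = 1$, a reward of $+1$ therefore moves the chosen weight to $1.001$ and a reward of $-1$ moves it to $0.999$. The $\log w_a$ term also means the loss is only defined while $w_a$ stays positive.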

### Training the Agent

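We train the agent by repeatedly taking an action in the environment and receiving a reward. From each action-reward pair we update the network so that, over time, it more often chooses the actions that yield the highest rewards. With probability `epsilon` the agent explores by picking a random bandit; otherwise it exploits by picking the bandit with the largest weight.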

```python
n_episodes = 1000
total_reward = np.zeros(n_bandits)  # scoreboard for the bandits, initialized to 0
epsilon = 0.1  # probability of taking a random action

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(n_episodes):
        # Explore with probability epsilon, otherwise exploit the current best guess
        if np.random.rand() < epsilon:
            action = np.random.randint(n_bandits)
        else:
            action = sess.run(a)

        reward = pull_bandit(bandits[action])

        # Apply one gradient step and fetch the updated weights
        _, W1 = sess.run([update_op, W], feed_dict={
            reward_ph: [reward],
            action_ph: [action]
        })

        # Update the running tally of scores
        total_reward[action] += reward

        if i % 50 == 0:
            print('Running reward for the %i bandits: %s' % (n_bandits, total_reward))
```

```text
Running reward for the 4 bandits: [1. 0. 0. 0.]
Running reward for the 4 bandits: [-2. -2.  1. 40.]
Running reward for the 4 bandits: [-2. -2.  1. 88.]
Running reward for the 4 bandits: [ -2.  -2.   1. 138.]
Running reward for the 4 bandits: [ -3.  -2.   1. 185.]
Running reward for the 4 bandits: [ -3.  -2.   1. 233.]
Running reward for the 4 bandits: [ -2.  -3.   2. 278.]
Running reward for the 4 bandits: [ -1.  -2.   2. 324.]
Running reward for the 4 bandits: [  0.  -1.   3. 371.]
Running reward for the 4 bandits: [  2.   0.   2. 417.]
Running reward for the 4 bandits: [  0.   0.   1. 464.]
Running reward for the 4 bandits: [  0.   0.   1. 514.]
Running reward for the 4 bandits: [ -2.   0.   2. 559.]
Running reward for the 4 bandits: [ -2.   0.   2. 605.]
Running reward for the 4 bandits: [ -4.   1.   1. 649.]
Running reward for the 4 bandits: [ -5.   1.   2. 697.]
Running reward for the 4 bandits: [ -5.   3.   2. 745.]
Running reward for the 4 bandits: [ -5.   2.   5. 791.]
```
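
The TensorFlow loop above boils down to epsilon-greedy action selection plus the update that gradient descent on `-log(w_a) * r` applies to the chosen weight, $w_a \leftarrow w_a + \eta\, r / w_a$. The framework-free sketch below reproduces that logic with the standard library so the mechanics are visible without a session. `train_bandit_agent`, `pull_arm`, and the `seed` parameter are hypothetical names; the update is written out by hand rather than derived by TensorFlow, and Python's `max` breaks ties in favor of the lowest index.

```python
import random

def pull_arm(threshold, rng):
    """Pay +1 if a standard normal draw exceeds the arm's threshold, else -1."""
    return 1 if rng.gauss(0.0, 1.0) > threshold else -1

def train_bandit_agent(bandits, n_episodes=1000, epsilon=0.1, learning_rate=0.001, seed=0):
    """Epsilon-greedy agent with one weight per arm, trained by w_a += lr * r / w_a."""
    rng = random.Random(seed)
    n_bandits = len(bandits)
    weights = [1.0] * n_bandits       # same starting point as tf.ones([n_bandits])
    totals = [0] * n_bandits          # running reward per arm
    for _ in range(n_episodes):
        if rng.random() < epsilon:
            # Explore: pick an arm uniformly at random
            action = rng.randrange(n_bandits)
        else:
            # Exploit: pick the arm with the largest weight (max() keeps the first on ties)
            action = max(range(n_bandits), key=lambda i: weights[i])
        reward = pull_arm(bandits[action], rng)
        # One gradient-descent step on loss = -log(w_a) * r, whose derivative is -r / w_a
        weights[action] += learning_rate * reward / weights[action]
        totals[action] += reward
    return weights, totals

sketch_weights, sketch_totals = train_bandit_agent([0.2, 0, -0.2, -5])
favorite = max(range(len(sketch_weights)), key=lambda i: sketch_weights[i])
print("weights:", [round(w, 4) for w in sketch_weights])
print("running reward:", sketch_totals)
print("agent prefers bandit", favorite + 1)
```

Because the greedy branch always exploits the current top weight, an arm with a positive expected reward but not the best one (bandit 3 here) can hold the lead until exploration pulls let bandit 4 overtake it, so different seeds may end on different arms. It is worth trying a few values of `seed` before reading too much into a single run.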

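```python
# The agent's favorite is the bandit with the largest weight; the truly best bandit has the lowest threshold
print('The agent thinks bandit %i is the most promising' % (np.argmax(W1) + 1))
if np.argmax(W1) == np.argmin(bandits):
    print('and it is right')
else:
    print('and it is wrong')
```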

```text
The agent thinks bandit 4 is the most promising
and it is right
```
