# Reinforcement Learning: multi-armed bandit with TensorFlow

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# One threshold per arm. pullBandit pays +1 when a standard-normal draw exceeds
# the arm's threshold, so the lower the threshold, the more often the arm pays.
bandits = [
    0.2,
    0,
    -0.2,
    -5,
]
num_bandits = len(bandits)

In [3]:
def pullBandit(banditNumber):
    """Pull one arm and return its reward.

    A single standard-normal sample is compared with the threshold stored at
    ``bandits[banditNumber]``.

    Args:
        banditNumber: index into the module-level ``bandits`` list.

    Returns:
        1 if the sample is strictly greater than the threshold, otherwise -1.
    """
    threshold = bandits[banditNumber]
    sample = np.random.randn(1)
    return 1 if sample > threshold else -1
    

In [4]:
# Clear the default graph so re-running this cell starts from scratch.
tf.reset_default_graph()

# One trainable weight per arm, all initialised to 1. The greedy choice is the
# index of the largest weight.
weights = tf.Variable(tf.ones([num_bandits]))
best_action = tf.argmax(weights, 0)

# Fed on every training step: the arm that was pulled and the reward it gave.
selected_action = tf.placeholder(dtype=tf.int32, shape=[1])
selected_action_reward = tf.placeholder(dtype=tf.float32, shape=[1])

# Pick out the single weight belonging to the pulled arm.
selected_action_weight = tf.slice(weights, selected_action, [1])

# The raw reward is used directly as the advantage.
advantage = selected_action_reward

# loss = -log(weight) * advantage, minimised with plain gradient descent.
log_weight = tf.log(selected_action_weight)
loss = -(log_weight * advantage)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
update = optimizer.minimize(loss)

In [5]:
# Run configuration.
total_episodes = 1000        # number of arm pulls in the training loop
explore_probability = 0.1    # chance of pulling a random arm instead of the greedy one
stats_print_steps = 50       # print the running totals every this many steps

# Running sum of rewards collected from each arm.
total_reward = np.zeros(num_bandits)

In [6]:
# Train the agent: repeatedly choose an arm (epsilon-greedy), observe the
# reward, and apply one gradient step to the weight of the arm that was pulled.
initialize = tf.global_variables_initializer()

with tf.Session() as session:
    session.run(initialize)

    for step in range(total_episodes):
        # Explore with probability explore_probability, otherwise pull the
        # arm that currently has the largest weight.
        if np.random.rand(1) < explore_probability:
            action = np.random.randint(num_bandits)
        else:
            action = session.run(best_action)

        reward = pullBandit(action)

        feed_dict = {selected_action: [action], selected_action_reward: [reward]}
        _, selected_weight = session.run([update, selected_action_weight],
                                         feed_dict=feed_dict)

        total_reward[action] += reward

        if step % stats_print_steps == 0:
            print("Step "+str(step)+" ,bandits rewards:"+str(total_reward))
    print("Final rewards:"+str(total_reward))

    # Read the weights once, after training has finished. A value fetched in
    # the same run() call as `update` is not guaranteed to include that step's
    # assignment, and reading here also keeps all_weights defined when
    # total_episodes is 0.
    all_weights = session.run(weights)
    best_bandit = np.argmax(all_weights)

    # Bandits are reported 1-based to the user.
    print("The agent thinks bandit "+str(best_bandit+1)+ " is the best ")
    # The best arm is the one with the lowest threshold in `bandits`.
    if best_bandit == np.argmin(bandits):
        print("...The agent got the right answer")
    else:
        print("...The agent got the wrong answer")

Step 0 ,bandits rewards:[ 1.  0.  0.  0.]
Step 50 ,bandits rewards:[-1.  8.  0.  0.]
Step 100 ,bandits rewards:[ -1.   2.   0.  26.]
Step 150 ,bandits rewards:[ -1.   1.   1.  74.]
Step 200 ,bandits rewards:[  -1.    1.    1.  122.]
Step 250 ,bandits rewards:[  -1.    4.    1.  169.]
Step 300 ,bandits rewards:[  -1.    3.    1.  216.]
Step 350 ,bandits rewards:[  -3.    1.    0.  259.]
Step 400 ,bandits rewards:[  -3.    0.   -1.  305.]
Step 450 ,bandits rewards:[  -5.    1.    1.  350.]
Step 500 ,bandits rewards:[  -5.    3.    2.  395.]
Step 550 ,bandits rewards:[  -5.    4.    1.  441.]
Step 600 ,bandits rewards:[  -5.    4.    2.  488.]
Step 650 ,bandits rewards:[  -6.    6.    3.  532.]
Step 700 ,bandits rewards:[  -5.    6.    3.  581.]
Step 750 ,bandits rewards:[  -4.    6.    3.  628.]
Step 800 ,bandits rewards:[  -2.    8.    4.  673.]
Step 850 ,bandits rewards:[  -1.   11.    5.  718.]
Step 900 ,bandits rewards:[  -2.   12.    5.  764.]
Step 950 ,bandits rewards:[  -3.   12. 