# RL with Tensorflow

The motivation is to adapt https://medium.com/emergent-future/simple-reinforcement-learning-with-tensorflow-part-1-5-contextual-bandits-bff01d1aad9c for a DSR and PSR task, and see if anything similar to a learning set emerges. This will not be a temporal difference model, instead we will seek to minimize the policy error with the Bellman equations. Not sure this is exactly policy iteration and am generally unsure as to how to frame the problem. So far in the tutorial he's dealt with stationary problems i think and DSR and PSR are non-stationary. Whatever, let's try it and see what happens. 

In [28]:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

class Agent():
    def __init__(self, lr, s_size, a_size, h_size):
        '''
        These lines established the feed-forward part of the network. The agent takes a state and produces
        an action.       
        '''
        self.state_in= tf.placeholder(shape = [None, s_size], dtype = tf.float32)
        hidden = slim.fully_connected(self.state_in, h_size, biases_initializer=None, activation_fn=tf.nn.relu)
        self.output = slim.fully_connected(hidden, a_size, activation_fn=tf.nn.softmax, biases_initializer=None)
        self.chosen_action = tf.argmax(self.output, 1)


        '''
        The next six lines establish the training proceedure. We feed the reward and chosen action into the
        network to compute the loss, and use it to update the network.
        '''
        self.reward_holder = tf.placeholder(shape=[None], dtype=tf.float32)
        self.action_holder = tf.placeholder(shape=[None], dtype=tf.int32)
        
        self.indexes = tf.range(0, tf.shape(self.output)[0]) * tf.shape(self.output)[1] + self.action_holder
        self.responsible_outputs = tf.gather(tf.reshape(self.output, [-1]), self.indexes)

        self.loss = -tf.reduce_mean(tf.log(self.responsible_outputs) * self.reward_holder)
        
        tvars = tf.trainable_variables()
        self.gradient_holders = []
        for idx, var in enumerate(tvars):
            placeholder = tf.placeholder(tf.float32, name = str(idx) + '_holder')
            self.gradient_holders.append(placeholder)
        
        self.gradients = tf.gradients(self.loss, tvars)
        
        optimizer = tf.train.AdamOptimizer(learning_rate = lr)
        self.update_batch = optimizer.apply_gradients(zip(self.gradient_holders, tvars))



In [45]:
gamma = 0.99

def discount_rewards(r):
    """ take 1D float array of rewards and compute discounted reward """
    discounted_r = np.zeros_like(r)
    running_add = 0
    for t in reversed(xrange(0, r.size)):
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r



In [95]:
import pandas as pd

class Environment(object):
    
    def __init__(self,
                 sessionID = 'test',
                 task = 'DSR',
                 noTrials = 120):

        if task == 'DSR':
            self.p = 1
        elif task == 'PSR':
            self.p = 0.85
        self.task = task
        self.noTrials = noTrials
        self.info = pd.DataFrame(np.zeros((self.noTrials, 5)),
                                 columns = ['GA','Choice','Correct','AR','State'])
        self.sessionID = sessionID
        self.block = [0] * 12
        self.current_trial = -1
        self.GA = int(np.random.random() > 0.5)
        self.done = False

    def step(self, choice):
        self.current_trial += 1
        if self.current_trial < self.noTrials:
            #record current goal arm
            self.info['GA'].iat[self.current_trial] = self.GA
            #saving choice
            self.info['Choice'].iat[self.current_trial] = choice
            #save that information
            self.info['Correct'].iat[self.current_trial] = int(choice == self.GA)

            #determine whether you actually get a reward
            if self.info['Correct'].iat[self.current_trial] == 1:
                self.info['AR'].iat[self.current_trial] = int(np.random.random() <= self.p)
            elif self.info['Correct'].iat[self.current_trial] == 0:
                self.info['AR'].iat[self.current_trial] = int(np.random.random() <= (1 - self.p))

            #see if a reversal has been accomplished
            #update block
            self.block[0:-1] = self.block[1:]
            self.block[-1] = self.info['Correct'].iat[self.current_trial]
            #check if criteria happened
            if np.sum(self.block) >= 10:
                #reset counter
                self.block = [0] * 12
                #reversal
                self.GA = (self.GA + 1) % 2
            
        else:
            self.done = 1    
            self.current_trial = self.noTrials - 1
            rev_points = np.nonzero(self.info['GA'].diff())[0]
            rev_points = [w for w in rev_points] + [self.noTrials]
            rev_points = np.diff(rev_points)
            mtuples = [(self.sessionID, i + 1, v) \
                           for i,w in enumerate(rev_points) for v in range(w)]
            index = pd.MultiIndex.from_tuples(mtuples, names=['training_session','block','trials'])
            self.info.index = index
            
        return 0, self.info['AR'].iat[self.current_trial], self.done, self.info

    def reset(self):
        self.__init__(sessionID = self.sessionID, task = self.task, noTrials = self.noTrials)
        return 0

In [43]:
'''
For debugging environment class
'''

env = Environment(task = 'DSR', noTrials = 100)
for i in range(110):
    observation, reward, done, info = env.step(1)
  
env.reset()


array([0.])

In [102]:
tf.reset_default_graph() #Clear the Tensorflow graph.

myAgent = Agent(lr=1e-2, s_size=1, a_size=2, h_size=8) #Load the agent.

total_episodes = 3 #Set total number of episodes to train agent on.
noTrials = 120
update_frequency = 1

init = tf.global_variables_initializer()
env = Environment(task = 'DSR', noTrials = noTrials)

# Launch the tensorflow graph
with tf.Session() as sess:
    sess.run(init)
    i = 0
    total_reward = []
        
    gradBuffer = sess.run(tf.trainable_variables())
    for ix, grad in enumerate(gradBuffer):
        gradBuffer[ix] = grad * 0
        
    while i < total_episodes:
        s = env.reset()
        running_reward = 0
        
        while not env.done:
            #Probabilistically pick an action given our network outputs.
            a_dist = sess.run(myAgent.output, feed_dict={myAgent.state_in: np.array([[s]])})
            a = np.random.choice(a_dist[0], p = a_dist[0])
            a = np.argmax(a_dist == a)
            s1, r, done, info = env.step(a) #Get our reward for taking an action given a bandit.
            print(env.current_trial, r, a, s1, a_dist)
            
            s = s1
            running_reward += r
    
            #Update the network.
            info['AR'] = discount_rewards(info['AR'])
            feed_dict={myAgent.reward_holder : np.array([r]),
                       myAgent.action_holder : np.array([a]),
                       myAgent.state_in : np.array([[s]])}


            grads = sess.run(myAgent.gradients, feed_dict=feed_dict)
            print grads
            for idx, grad in enumerate(grads):
                gradBuffer[idx] += grad
                
            
            feed_dict = dictionary = dict(zip(myAgent.gradient_holders, gradBuffer))
            _ = sess.run(myAgent.update_batch, feed_dict=feed_dict)
            for ix, grad in enumerate(gradBuffer):
                gradBuffer[ix] = grad * 0

        total_reward.append(running_reward)
        i += 1

(0, 1.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(1, 1.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(2, 1.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(3, 1.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],

(46, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(47, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(48, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(49, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 

(102, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(103, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(104, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(105, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [

(35, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(36, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(37, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(38, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 

[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(91, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(92, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(93, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float3

[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(24, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(25, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(26, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float3

(78, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(79, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(80, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)]
(81, 0.0, 0, 0, array([[0.5, 0.5]], dtype=float32))
[array([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 

In [80]:
print total_reward

[10.0, 10.0, 0.0]
