In [1]:
import numpy as np
import pandas as pd
from gridworld import Environment, Agent
import matplotlib.pyplot as plt

In [2]:
def update_q(q, eligibility_trace, alpha, gamma, state_prime, state, reward_prime, action_prime, action):
    delta = reward_prime + gamma*q[state_prime, action_prime]-q[state, action]
    q = q + alpha*delta*eligibility_trace
    return q
    
def update_eligibilty_trace(eligibility_trace, gamma,lambd, state, action):
    eligibility_trace_prime = gamma*lambd*eligibility_trace
    eligibility_trace_prime[state, action] =  gamma*lambd*eligibility_trace[state,action]+1
    return eligibility_trace_prime

def update_policy(q):
    policy = q.argmax(axis = 1)
    return policy

In [3]:
def SARSA_lambda_episode(q, eligibility_trace, alpha=0.1, gamma=0.9, lambd = 0.5):
    ''' Runs the temporal difference algorithm for one episode
    Args:
        q (2d array, (n_states, n_actions)): the initial q fuction
        eligibility_trace (2d array, (n_states, n_actions)): the initial eligibility trace
        alpha (float, optional): learning_rate
        gamma (float, optional): discount rate
        lamb (float, optional): eligibilty trace decay rate
    Returns:
        q (2d array, (n_states, n_actions)): resulting q function
        df (Dataframe): progress over iterations
    '''
    agent = Agent()
    env = Environment()
    state = env.current_state
    reward = env.initial_reward
    episode_finished=False
    # for tracking utility
    i = 0
    while not episode_finished:
        i += 1
        action = agent.step(state)
        state_prime, reward_prime, episode_finished = env.step(action)
        action_prime =  agent.step(state_prime) # Just observing the action, not applying it
        eligibility_trace = update_eligibilty_trace(eligibility_trace, gamma,lambd, state, action)
        q = update_q(q, eligibility_trace, alpha,gamma,state_prime, state, reward_prime, action_prime, action)
        agent.policy = update_policy(q)
        state, reward = state_prime, reward_prime
        agent.render_policy()
    return q

In [4]:
# initializing
q = np.zeros((12,4))
eligibility_trace = np.zeros((12,4))
lambd = 0.5
q = SARSA_lambda_episode(q, eligibility_trace, lambd = lambd)

[['<' '<' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['<' '>' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['>' '>' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '>' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '^' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' 'v' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' 'v' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '^' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '>' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '<' '>' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '<' '^' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '<' '^' '>']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '<' 'v' '>']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '<' 'v' '^']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '<' 'v' 'v']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


[['^' '<' 'v' '<']
 ['<' '<' '<' '<']
 ['<' '<' '<' '<']]


