In [1]:
import numpy as np
import pandas as pd
from gridworld import Environment, Agent
import matplotlib.pyplot as plt

In [2]:
def update_q(q, eligibility_trace, alpha, gamma, state_prime, state, reward_prime, action_prime, action):
    delta = reward_prime + gamma*q[state_prime, action_prime]-q[state, action]
    q = q + alpha*delta*eligibility_trace
    return q
    
def update_eligibilty_trace(eligibility_trace, gamma,lambd, state, action):
    eligibility_trace_prime = gamma*lambd*eligibility_trace
    eligibility_trace_prime[state, action] =  gamma*lambd*eligibility_trace[state,action]+1
    return eligibility_trace_prime

def update_policy(q):
    policy = q.argmax(axis = 1)
    return policy

def initialize():
    q = np.zeros((12,4))
    eligibility_trace = np.zeros((12,4))
    agent = Agent()
    env = Environment(random_initial_state=True)
    state = env.current_state
    reward = env.initial_reward
    episode_finished=False
    return agent, env, state, reward, episode_finished

In [5]:
def sarsa(convergence_criterion = 10000, alpha=0.01, gamma=0.999, lambd = 0.0, epsilon = None):
    n_iters = 0
    same_policy_iter = 0
    n_episodes = 0
    q = np.zeros((12,4))
    eligibility_trace = np.zeros((12,4))
    agent, env, state, reward, episode_finished = initialize()
    no_action_states = env.impossible_states + env.terminal_states
    while same_policy_iter < convergence_criterion:
        action = agent.step(state, epsilon = epsilon)
        state_prime, reward_prime, episode_finished = env.step(action)
        action_prime =  agent.step(state_prime) # Just observing the action, not applying it
        eligibility_trace = update_eligibilty_trace(eligibility_trace, gamma,lambd, state, action)
        q = update_q(q, eligibility_trace, alpha,gamma,state_prime, state, reward_prime, action_prime, action)
        previous_policy = agent.policy
        agent.policy = update_policy(q)
        state, reward = state_prime, reward_prime
        current_policy = agent.policy
        
        previous_policy = previous_policy.ravel()
        previous_policy[no_action_states] = 0
        current_policy = current_policy.ravel()
        current_policy[no_action_states] = 0
    
        
        if np.array_equal(previous_policy, agent.policy):
            same_policy_iter += 1
        else:
            same_policy_iter = 0
        n_iters += 1
        if n_iters%10000 == 0:
            print('Iteration {} ---- Current policy same for {} iterations'.format(n_iters, same_policy_iter))
            agent.render_policy()
        if episode_finished:
            n_episodes += 1
            agent, env, state, reward, episode_finished = initialize()
    print('Final iteration number {}'.format(n_iters))
    return agent.policy

In [4]:
sarsa()

Iteration 10000 ---- Current policy same for 0 iterations
[['v' '>' 'v' '^']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Iteration 20000 ---- Current policy same for 2 iterations
[['v' '<' 'v' '^']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Iteration 30000 ---- Current policy same for 0 iterations
[['v' '<' 'v' '<']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Iteration 40000 ---- Current policy same for 6880 iterations
[['v' '<' '<' '<']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Final iteration number 43120


In [5]:
sarsa(lambd = 0.5)

Iteration 10000 ---- Current policy same for 0 iterations
[['v' '<' 'v' '<']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Iteration 20000 ---- Current policy same for 0 iterations
[['v' '<' 'v' '<']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Iteration 30000 ---- Current policy same for 783 iterations
[['v' '<' '<' '<']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Iteration 40000 ---- Current policy same for 5 iterations
[['v' '<' 'v' '<']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Iteration 50000 ---- Current policy same for 6444 iterations
[['v' '<' '<' '<']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Final iteration number 53556


In [None]:
best_policy = sarsa(epsilon = 0.1)

Iteration 10000 ---- Current policy same for 3 iterations
[['v' '>' 'v' '^']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


Iteration 20000 ---- Current policy same for 6 iterations
[['v' '<' 'v' '<']
 ['v' '*' 'v' '*']
 ['>' '>' '>' '*']]


