In [1]:
!pip install gym



In [2]:
import gym
import numpy as np
from collections import defaultdict
import sys
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [3]:
# Build the Taxi-v3 environment; this `env` object is what gets passed to the training call below.
env = gym.make('Taxi-v3')

In [4]:
def make_epsilon_greedy_policy(Q, epsilon, nA):
    """Build an epsilon-greedy action-probability function on top of `Q`.

    Args:
        Q: Mapping from an observation to an array of action values. It is
            looked up on every call, so later changes to `Q` are reflected.
        epsilon: Total probability mass spread uniformly over all actions.
        nA: Number of actions.

    Returns:
        A function taking an observation and returning a length-`nA` float
        array of action probabilities: every action gets `epsilon / nA`, and
        the action with the highest value in `Q[observation]` additionally
        gets `1 - epsilon`.
    """
    def policy_fn(observation):
        # Uniform exploration share for every action.
        action_probs = np.full(nA, epsilon, dtype=float) / nA
        # Remaining mass goes to the greedy action (first index on ties).
        greedy_action = np.argmax(Q[observation])
        action_probs[greedy_action] += (1.0 - epsilon)
        return action_probs

    return policy_fn

In [5]:
def _reset_env(env):
  """Reset `env` and return only the initial observation.

  Accepts both `reset()` conventions: a bare observation, or an
  `(observation, info_dict)` pair.
  """
  result = env.reset()
  if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
    return result[0]
  return result


def _step_env(env, action):
  """Step `env` and return `(next_state, reward, terminated, truncated)`.

  Accepts both step conventions: the 4-tuple `(obs, reward, done, info)` and
  the 5-tuple `(obs, reward, terminated, truncated, info)`.
  """
  result = env.step(action)
  if len(result) == 5:
    next_state, reward, terminated, truncated, _info = result
    return next_state, reward, bool(terminated), bool(truncated)

  next_state, reward, done, info = result
  # 4-tuple API: a time-limit cut-off arrives as done=True plus this info flag.
  truncated = bool(done) and isinstance(info, dict) and bool(info.get("TimeLimit.truncated", False))
  return next_state, reward, bool(done) and not truncated, truncated


def sarsa_lambda(env, num_episodes, discount_factor=0.8, decay_param=0.9, learning_rate=0.01, epsilon=0.1, max_steps=200):
  """Tabular SARSA(lambda) control with accumulating eligibility traces.

  Args:
    env: Environment exposing `action_space.n`, `reset()` and `step(action)`.
    num_episodes: Number of training episodes to run.
    discount_factor: Discount applied to the next state-action value (gamma).
    decay_param: Trace decay (lambda); traces shrink by
      `discount_factor * decay_param` after every step.
    learning_rate: Step size applied to the TD error (alpha).
    epsilon: Exploration rate of the epsilon-greedy behaviour policy.
    max_steps: Upper bound on the number of steps taken per episode.

  Returns:
    A defaultdict mapping each state touched during training to a float
    array holding one action value per action.
  """
  n_actions = env.action_space.n
  Q = defaultdict(lambda: np.zeros(n_actions))
  E = defaultdict(lambda: np.zeros(n_actions))

  # The policy closes over Q, so it always acts on the latest estimates.
  policy = make_epsilon_greedy_policy(Q, epsilon, n_actions)
  trace_decay = discount_factor * decay_param

  for i_episode in range(1, num_episodes+1):
    if i_episode % 10 == 0:
      print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
      sys.stdout.flush()

    # Traces are per-episode: credit must not leak from the previous episode.
    E.clear()

    state = _reset_env(env)
    probs = policy(state)
    action = np.random.choice(np.arange(len(probs)), p=probs)

    # Repeat for each step of the episode.
    for _ in range(max_steps):
      next_state, reward, terminated, truncated = _step_env(env, action)

      next_probs = policy(next_state)
      next_action = np.random.choice(np.arange(len(next_probs)), p=next_probs)

      # A terminal state has zero value, so only bootstrap when the episode goes on
      # (a time-limit truncation still bootstraps).
      if terminated:
        target = reward
      else:
        target = reward + discount_factor*Q[next_state][next_action]
      delta = target - Q[state][action]

      E[state][action] += 1

      # Only states with a trace this episode can change, so sweep E rather than
      # all of Q. A distinct loop name keeps `state` intact.
      for traced_state, trace in E.items():
        Q[traced_state] += learning_rate * delta * trace
        trace *= trace_decay

      if terminated or truncated:
        break

      state = next_state
      action = next_action

  return Q

In [6]:
# Train for 10,000 episodes using sarsa_lambda's default hyperparameters;
# the returned Q maps each state to its array of action values.
Q = sarsa_lambda(env, num_episodes=10000)

Episode 10000/10000.

In [7]:
# Dump the whole learned table: one array of per-action values for every state key in Q.
print(Q)

defaultdict(<function sarsa_lambda.<locals>.<lambda> at 0x7f50fade83a0>, {271: array([-5.73032732, -5.79020626, -5.75884749, -5.45632127, -6.57883076,
       -6.59107921]), 371: array([-5.87727181, -5.72963135, -5.90744579, -6.18520627, -6.79330132,
       -6.39092276]), 471: array([-6.04401828, -6.00748174, -5.79084668, -5.97862412, -6.30073374,
       -6.70116351]), 171: array([-5.8080849 , -5.76517331, -5.83472692, -5.81679316, -6.93711736,
       -6.24585814]), 71: array([-5.85946014, -5.92883468, -6.07608952, -5.71899533, -6.68985146,
       -6.41028621]), 291: array([-5.91688073, -6.01810093, -5.89749106, -5.75720616, -6.19831544,
       -6.63938213]), 391: array([-5.91150573, -5.86312908, -5.96878716, -5.71532643, -6.92714951,
       -6.84493012]), 491: array([-5.97295282, -5.73530189, -5.92238016, -5.8897196 , -6.90734687,
       -6.66498944]), 191: array([-5.96066181, -5.86795883, -6.15067185, -5.93528526, -7.06191414,
       -6.96165004]), 251: array([-5.64264658, -5.5776943 

In [8]:
# Greedy policy: for every state present in Q, keep the index of its highest-valued action.
# States never stored fall back to action 0 through defaultdict(int).
policy = defaultdict(int)
for state, action_values in Q.items():
  # Store a built-in int rather than a NumPy scalar so the saved policy can be
  # loaded without NumPy and prints the same regardless of NumPy version.
  policy[state] = int(np.argmax(action_values))
print(policy)

defaultdict(<class 'int'>, {271: 3, 371: 1, 471: 2, 171: 1, 71: 3, 291: 3, 391: 3, 491: 1, 191: 1, 251: 3, 351: 1, 451: 3, 431: 1, 331: 1, 151: 0, 231: 3, 131: 0, 211: 0, 31: 0, 11: 2, 111: 0, 311: 0, 411: 4, 419: 1, 319: 1, 219: 2, 119: 0, 223: 3, 323: 1, 423: 1, 123: 0, 243: 3, 343: 3, 443: 1, 143: 0, 263: 3, 363: 1, 463: 2, 163: 0, 283: 3, 383: 1, 483: 1, 183: 3, 43: 2, 203: 1, 303: 1, 403: 1, 103: 1, 3: 4, 23: 3, 19: 2, 239: 2, 339: 2, 439: 1, 139: 0, 259: 2, 359: 1, 459: 1, 159: 2, 279: 2, 299: 0, 399: 3, 499: 1, 199: 0, 149: 2, 249: 3, 349: 1, 449: 1, 49: 0, 169: 0, 269: 3, 369: 1, 469: 1, 289: 3, 389: 3, 489: 3, 189: 3, 329: 1, 429: 2, 229: 3, 129: 0, 29: 0, 109: 2, 209: 0, 309: 0, 409: 4, 417: 1, 317: 1, 9: 0, 217: 2, 193: 0, 293: 3, 393: 3, 493: 3, 93: 0, 173: 0, 273: 0, 373: 0, 473: 4, 477: 1, 377: 1, 497: 1, 73: 0, 253: 2, 353: 1, 453: 1, 153: 2, 53: 0, 233: 2, 333: 1, 133: 0, 433: 2, 213: 2, 313: 1, 413: 1, 113: 0, 68: 0, 168: 3, 268: 3, 368: 1, 468: 1, 88: 0, 188: 3, 288: 

In [9]:
import pickle

# Persist the greedy policy to the file "taxi_policy" in the working directory.
# The context manager closes the file even if pickle.dump raises.
with open("taxi_policy", "wb") as pickle_out:
  pickle.dump(policy, pickle_out)

In [10]:
# `policy` is a defaultdict(int): looking up a state that was never stored returns 0
# and also inserts that key into the dict as a side effect.
policy[532]

0