In [2]:
import gym
import numpy as np

def decay_schedule(
        init_value, min_value,
        decay_ratio, max_steps,
        log_start=-2, log_base=10):
    """Build a per-step schedule that decays from ``init_value`` to ``min_value``.

    The first ``int(max_steps * decay_ratio)`` entries follow a log-spaced
    (geometric) curve: ``log_base ** [log_start .. 0]``, reversed,
    min-max normalised to [0, 1], then rescaled onto
    [min_value, init_value]. Every remaining entry repeats ``min_value``.

    Args:
        init_value: value at step 0 (whenever at least one decay step exists).
        min_value: value reached at the end of the decay window and held after.
        decay_ratio: fraction of ``max_steps`` spent decaying.
        max_steps: total length of the returned array.
        log_start: exponent of the smallest log-spaced point; controls the
            curvature of the decay.
        log_base: base passed to ``np.logspace``.

    Returns:
        A float ndarray of length ``max_steps``.
    """
    decay_steps = int(max_steps * decay_ratio)
    rem_steps = max_steps - decay_steps

    if decay_steps <= 1:
        # Min-max normalisation is undefined here: an empty array cannot be
        # reduced (0 steps), and a single point gives 0/0 = NaN (1 step).
        # Keep the general contract instead: entry 0 is init_value when a decay
        # step exists, and everything after the window is min_value.
        head = np.full(decay_steps, init_value, dtype=np.float64)
        tail = np.full(rem_steps, min_value, dtype=np.float64)
        return np.concatenate((head, tail))

    # Reversed so the schedule starts at the largest point and decays.
    values = np.logspace(
        log_start, 0, decay_steps,
        base=log_base, endpoint=True)[::-1]
    values = (values - values.min()) / (values.max() - values.min())
    values = (init_value - min_value) * values + min_value
    # 'edge' padding repeats the final (min_value) entry for the rest.
    return np.pad(values, (0, rem_steps), 'edge')

def sarsa(env, gamma=1.0, init_alpha=0.5, min_alpha=0.01, alpha_decay_ratio=0.5, init_epsilon=1.0, min_epsilon=0.1, epsilon_decay_ratio=0.9, n_episodes=3000):
    """Learn a tabular action-value function with SARSA (on-policy TD control).

    Actions are chosen epsilon-greedily from the current ``Q``. The learning
    rate and epsilon follow per-episode schedules built by ``decay_schedule``.

    Args:
        env: environment with discrete ``observation_space.n`` /
            ``action_space.n``. ``reset()`` returns ``(state, info)`` and
            ``step()`` returns ``(state, reward, terminated, truncated, info)``.
        gamma: discount factor.
        init_alpha, min_alpha, alpha_decay_ratio: learning-rate schedule.
        init_epsilon, min_epsilon, epsilon_decay_ratio: exploration schedule.
        n_episodes: number of episodes to run.

    Returns:
        Q: final action-value table, shape (nS, nA).
        V: ``Q.max(axis=1)``, shape (nS,).
        pi: greedy policy; maps a state index to the argmax action of the
            final ``Q`` (raises KeyError for an unknown state).
        Q_track: copy of ``Q`` after each episode, shape (n_episodes, nS, nA).
        pi_track: list holding the greedy action array after each episode.
    """
    nS, nA = env.observation_space.n, env.action_space.n
    pi_track = []
    Q = np.zeros((nS, nA), dtype=np.float64)
    Q_track = np.zeros((n_episodes, nS, nA), dtype=np.float64)

    def select_action(state, Q, epsilon):
        # Exploit with probability 1 - epsilon, otherwise pick uniformly.
        if np.random.random() > epsilon:
            return np.argmax(Q[state])
        return np.random.randint(len(Q[state]))

    alphas = decay_schedule(
        init_alpha, min_alpha,
        alpha_decay_ratio,
        n_episodes)
    epsilons = decay_schedule(
        init_epsilon, min_epsilon,
        epsilon_decay_ratio,
        n_episodes)

    for e in range(n_episodes):
        state, _ = env.reset()
        terminated = truncated = False
        action = select_action(state, Q, epsilons[e])
        # Stop on either flag. Ignoring `truncated` would run past a
        # time limit, and forever if no terminal state is reachable.
        while not (terminated or truncated):
            next_state, reward, terminated, truncated, _ = env.step(action)
            next_action = select_action(next_state, Q, epsilons[e])
            # Bootstrap unless the next state is truly terminal. A truncated
            # state still has future value, so only `terminated` zeroes it.
            td_target = reward + gamma * Q[next_state, next_action] * (not terminated)
            td_error = td_target - Q[state, action]
            Q[state, action] += alphas[e] * td_error
            state, action = next_state, next_action
        Q_track[e] = Q
        pi_track.append(np.argmax(Q, axis=1))

    V = np.max(Q, axis=1)
    # Built once, rather than rebuilding a dict on every policy lookup.
    greedy = dict(enumerate(np.argmax(Q, axis=1)))
    pi = lambda s: greedy[s]
    return Q, V, pi, Q_track, pi_track

# Train SARSA on CliffWalking with default hyper-parameters; as the cell's
# last expression, the returned (Q, V, pi, Q_track, pi_track) tuple is shown.
sarsa(env=gym.make('CliffWalking-v0'))

  if not isinstance(terminated, (bool, np.bool8)):


(array([[-1.26494313e+03, -5.26968435e+02, -1.47629433e+03,
         -1.27089892e+03],
        [-1.09823532e+03, -4.52344660e+02, -1.42656715e+03,
         -1.22035827e+03],
        [-9.86189251e+02, -3.98894621e+02, -1.47309144e+03,
         -1.11734876e+03],
        [-8.60409431e+02, -3.12047209e+02, -1.46781328e+03,
         -1.07542533e+03],
        [-6.76626155e+02, -2.33357758e+02, -1.19993269e+03,
         -8.04027085e+02],
        [-5.49201931e+02, -1.95847668e+02, -1.11826370e+03,
         -7.22779601e+02],
        [-4.70329207e+02, -1.57966483e+02, -8.90822382e+02,
         -5.63763260e+02],
        [-3.72368867e+02, -1.28931129e+02, -7.65678520e+02,
         -4.83998833e+02],
        [-2.75961267e+02, -8.37925433e+01, -5.70595660e+02,
         -4.05087831e+02],
        [-2.22218939e+02, -6.13245449e+01, -4.43784348e+02,
         -2.78896871e+02],
        [-1.66471623e+02, -4.46179760e+01, -2.45535997e+02,
         -2.20231142e+02],
        [-1.25983808e+02, -1.33420712e+02, 