# In [10]:
import gym
import numpy as np

def decay_schedule(
        init_value, min_value,
        decay_ratio, max_steps,
        log_start=-2, log_base=10):
    """Return a length-``max_steps`` schedule decaying from init to min value.

    The first ``int(max_steps * decay_ratio)`` entries follow a reversed,
    min-max normalised log-spaced curve that starts at ``init_value`` and
    ends at ``min_value``; every remaining entry is held at ``min_value``.

    Args:
        init_value: Value of the first step of the decay phase.
        min_value: Value reached at the end of the decay phase and held
            for the rest of the schedule.
        decay_ratio: Fraction of ``max_steps`` spent decaying, in [0, 1].
        max_steps: Total number of entries in the returned schedule.
        log_start: Starting exponent passed to ``np.logspace`` (the end
            exponent is always 0).
        log_base: Logarithm base passed to ``np.logspace``.

    Returns:
        A 1-D float64 ``np.ndarray`` with ``max_steps`` entries.

    Raises:
        ValueError: If ``decay_ratio`` is outside [0, 1].
    """
    if not 0 <= decay_ratio <= 1:
        raise ValueError(
            f"decay_ratio must be within [0, 1], got {decay_ratio!r}")

    decay_steps = int(max_steps * decay_ratio)

    # Start from the plateau; the decay phase (if any) overwrites the head.
    schedule = np.full(max_steps, min_value, dtype=np.float64)
    if decay_steps == 0:
        # No decay phase at all: the schedule sits at the floor throughout.
        return schedule
    if decay_steps == 1:
        # A single sample cannot be min-max normalised (0 / 0 -> NaN), so
        # emit the initial value once and drop straight to the floor.
        schedule[0] = init_value
        return schedule

    curve = np.logspace(
        log_start, 0, decay_steps,
        base=log_base, endpoint=True)[::-1]
    # Rescale the curve to span exactly [0, 1] before mapping it onto
    # [min_value, init_value].
    curve = (curve - curve.min()) / (curve.max() - curve.min())
    schedule[:decay_steps] = (init_value - min_value) * curve + min_value
    return schedule

def sarsa(env, gamma=1.0, init_alpha=0.5, min_alpha=0.01, alpha_decay_ratio=0.5, init_epsilon=1.0, min_epsilon=0.1, epsilon_decay_ratio=0.9, n_episodes=3000):
    """Run tabular SARSA with epsilon-greedy exploration.

    The step size and exploration rate are looked up per episode from
    schedules built with ``decay_schedule``.

    Args:
        env: Environment with discrete ``observation_space.n`` /
            ``action_space.n``, whose ``reset()`` returns ``(state, info)``
            and whose ``step(action)`` returns the 5-tuple
            ``(next_state, reward, terminated, truncated, info)``.
        gamma: Discount factor applied to the bootstrapped value.
        init_alpha: Learning rate for the first episode.
        min_alpha: Learning rate floor.
        alpha_decay_ratio: Fraction of episodes over which alpha decays.
        init_epsilon: Exploration rate for the first episode.
        min_epsilon: Exploration rate floor.
        epsilon_decay_ratio: Fraction of episodes over which epsilon decays.
        n_episodes: Number of training episodes.

    Returns:
        Tuple ``(Q, V, pi, Q_track, pi_track)``:
        ``Q`` is the final ``(nS, nA)`` action-value table, ``V`` the
        per-state maximum of ``Q``, ``pi`` a callable mapping a state index
        to its greedy action, ``Q_track`` an ``(n_episodes, nS, nA)`` array
        of the table after each episode, and ``pi_track`` a list of the
        greedy action arrays after each episode.
    """
    nS, nA = env.observation_space.n, env.action_space.n
    pi_track = []
    Q = np.zeros((nS, nA), dtype=np.float64)
    Q_track = np.zeros((n_episodes, nS, nA), dtype=np.float64)

    def select_action(state, Q, epsilon):
        # Epsilon-greedy: exploit with probability (1 - epsilon).
        if np.random.random() > epsilon:
            return np.argmax(Q[state])
        return np.random.randint(len(Q[state]))

    alphas = decay_schedule(
        init_alpha, min_alpha,
        alpha_decay_ratio,
        n_episodes)
    epsilons = decay_schedule(
        init_epsilon, min_epsilon,
        epsilon_decay_ratio,
        n_episodes)

    for e in range(n_episodes):
        state, _ = env.reset()
        terminated = truncated = False
        action = select_action(state, Q, epsilons[e])
        # An episode is over when the env reports either flag; stepping an
        # env after truncation is not valid, so both must end the loop.
        while not (terminated or truncated):
            next_state, reward, terminated, truncated, _ = env.step(action)
            next_action = select_action(next_state, Q, epsilons[e])
            # Only a true terminal state has zero future value; a truncated
            # (e.g. time-limited) transition still bootstraps from Q.
            td_target = reward + gamma * \
                Q[next_state][next_action] * (not terminated)
            td_error = td_target - Q[state][action]
            Q[state][action] = Q[state][action] + alphas[e] * td_error
            state, action = next_state, next_action
        # Slice assignment copies the current values of Q into the tracker.
        Q_track[e] = Q
        pi_track.append(np.argmax(Q, axis=1))

    V = np.max(Q, axis=1)
    # Build the state -> greedy-action table once instead of on every call.
    greedy_actions = dict(enumerate(np.argmax(Q, axis=1)))
    pi = lambda s: greedy_actions[s]
    return Q, V, pi, Q_track, pi_track

# Train SARSA on CliffWalking with default hyperparameters and print the
# learned tables.
env = gym.make('CliffWalking-v0')
try:
    Q, V, pi, Q_track, pi_track = sarsa(env)
finally:
    # Release the environment's resources even if training fails.
    env.close()
print(f"Q:\n{Q}")
print(f"V:\n{V}")
# pi is a callable, so formatting it directly would only show the function's
# repr; print the action it selects for each state instead.
print(f"pi:\n{np.array([pi(s) for s in range(len(V))])}")
print(f"Q_track:\n{Q_track}")
print(f"pi_track:\n{pi_track}")

Q:
[[-1.12623422e+03 -4.22176232e+02 -1.27360669e+03 -1.13959827e+03]
 [-9.79340017e+02 -3.65358471e+02 -1.31056042e+03 -1.02251236e+03]
 [-7.91064796e+02 -3.03747209e+02 -1.33005722e+03 -1.07065023e+03]
 [-7.16647818e+02 -2.39933264e+02 -1.19362808e+03 -7.99215622e+02]
 [-6.14297240e+02 -1.75637303e+02 -1.06026103e+03 -7.52138347e+02]
 [-5.20546590e+02 -1.53546801e+02 -1.00473804e+03 -5.88244095e+02]
 [-3.93622268e+02 -1.01823670e+02 -7.27423310e+02 -4.81186246e+02]
 [-2.83951785e+02 -8.17279416e+01 -6.88908895e+02 -3.79480194e+02]
 [-1.97163998e+02 -5.98642459e+01 -5.06986015e+02 -3.19200562e+02]
 [-1.50278734e+02 -4.60029842e+01 -3.84292423e+02 -2.13462036e+02]
 [-1.27708214e+02 -2.93264707e+01 -1.64542165e+02 -1.67190013e+02]
 [-7.98093307e+01 -9.08032266e+01 -2.07614795e+01 -1.31058351e+02]
 [-5.08315917e+02 -1.37963925e+03 -1.32374075e+03 -1.24046017e+03]
 [-6.75901477e+02 -2.00434568e+03 -2.54924123e+03 -2.06917354e+03]
 [-8.01417516e+02 -2.02553698e+03 -3.15074327e+03 -2.232408