In [2]:
import gym 
import numpy as np 


def decay_schedule(
        init_value, min_value,
        decay_ratio, max_steps,
        log_start=-2, log_base=10):
    """Build a per-step schedule that decays from ``init_value`` to ``min_value``.

    The first ``int(max_steps * decay_ratio)`` entries follow a reversed,
    min-max-normalised ``np.logspace(log_start, 0, ...)`` curve, rescaled so
    it starts at ``init_value`` and ends at ``min_value``. Every remaining
    entry is held at ``min_value``.

    Args:
        init_value: Value at step 0.
        min_value: Value reached at the end of the decay phase and kept for
            all remaining steps.
        decay_ratio: Fraction of ``max_steps`` spent decaying, in [0, 1].
        max_steps: Total length of the returned schedule (>= 0).
        log_start: Exponent of the smallest ``logspace`` sample (curve shape).
        log_base: Base passed to ``np.logspace``.

    Returns:
        A float64 ``np.ndarray`` of length ``max_steps``.

    Raises:
        ValueError: If ``decay_ratio`` is outside [0, 1] or ``max_steps`` is
            negative.
    """
    if not 0 <= decay_ratio <= 1:
        raise ValueError(f"decay_ratio must be in [0, 1], got {decay_ratio}")
    if max_steps < 0:
        raise ValueError(f"max_steps must be non-negative, got {max_steps}")

    decay_steps = int(max_steps * decay_ratio)
    # Everything after the decay phase sits at the floor value; with zero
    # decay steps the whole schedule is the floor.
    values = np.full(max_steps, min_value, dtype=np.float64)
    if decay_steps == 1:
        # A single sample cannot be min-max normalised (it would be 0/0 -> NaN),
        # so the one decay step simply uses the starting value.
        values[0] = init_value
    elif decay_steps >= 2:
        curve = np.logspace(
            log_start, 0, decay_steps,
            base=log_base, endpoint=True)[::-1]
        # Normalise to [1, 0] so the rescaled curve hits init/min exactly.
        curve = (curve - curve.min()) / (curve.max() - curve.min())
        values[:decay_steps] = (init_value - min_value) * curve + min_value
    return values
# Q-Learning Algorithm 
def q_learning(env,
    gamma=1.0,
    init_alpha=0.5,
    min_alpha=0.01,
    alpha_decay_ratio=0.5,
    init_epsilon=1.0,
    min_epsilon=0.1,
    epsilon_decay_ratio=0.9,
    n_episodes=3000):
    """Learn a tabular action-value function with Q-learning.

    Actions are chosen epsilon-greedily from the current Q table. After every
    transition, ``Q[s, a]`` moves toward ``r + gamma * max_a' Q[s', a']`` by
    the episode's learning rate. Both the learning rate and epsilon follow
    ``decay_schedule`` across episodes.

    Args:
        env: Environment with discrete ``observation_space.n`` and
            ``action_space.n``, whose ``reset()`` returns ``(state, info)``
            and whose ``step()`` returns
            ``(next_state, reward, terminated, truncated, info)``.
        gamma: Discount factor.
        init_alpha, min_alpha, alpha_decay_ratio: Learning-rate schedule
            parameters passed to ``decay_schedule``.
        init_epsilon, min_epsilon, epsilon_decay_ratio: Exploration-rate
            schedule parameters passed to ``decay_schedule``.
        n_episodes: Number of training episodes.

    Returns:
        ``(Q, V, pi, Q_track, pi_track)``:
        - ``Q``: final (nS, nA) action-value table.
        - ``V``: per-state max of ``Q``.
        - ``pi``: callable mapping a state index to its greedy action under
          the final ``Q``.
        - ``Q_track``: (n_episodes, nS, nA) snapshots of ``Q`` after each
          episode.
        - ``pi_track``: list of per-episode greedy-action arrays.
    """
    nS, nA = env.observation_space.n, env.action_space.n
    pi_track = []
    Q = np.zeros((nS, nA), dtype=np.float64)
    Q_track = np.zeros((n_episodes, nS, nA), dtype=np.float64)

    def select_action(state, Q, epsilon):
        # Epsilon-greedy: exploit with probability 1 - epsilon, else act
        # uniformly at random.
        if np.random.random() > epsilon:
            return np.argmax(Q[state])
        return np.random.randint(len(Q[state]))

    alphas = decay_schedule(
        init_alpha, min_alpha,
        alpha_decay_ratio,
        n_episodes)
    epsilons = decay_schedule(
        init_epsilon, min_epsilon,
        epsilon_decay_ratio,
        n_episodes)
    for e in range(n_episodes):
        state, _ = env.reset()
        terminated = truncated = False
        # Stop on either a true terminal state or a time-limit truncation.
        while not (terminated or truncated):
            action = select_action(state, Q, epsilons[e])
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Only a real terminal state has zero future value; a truncated
            # episode still bootstraps from next_state.
            td_target = reward + gamma * Q[next_state].max() * (not terminated)
            td_error = td_target - Q[state, action]
            Q[state, action] += alphas[e] * td_error
            state = next_state
        Q_track[e] = Q
        pi_track.append(np.argmax(Q, axis=1))
    V = np.max(Q, axis=1)
    # Build the greedy lookup once instead of on every pi(s) call.
    greedy_actions = dict(enumerate(np.argmax(Q, axis=1)))

    def pi(s):
        """Return the greedy action for state ``s`` under the final Q."""
        return greedy_actions[s]

    return Q, V, pi, Q_track, pi_track
# Train on CliffWalking with the default hyperparameters. As the cell's last
# expression, the returned (Q, V, pi, Q_track, pi_track) tuple is displayed.
q_learning(env=gym.make('CliffWalking-v0'))

(array([[ -15.,  -14.,  -14.,  -15.],
        [ -14.,  -13.,  -13.,  -15.],
        [ -13.,  -12.,  -12.,  -14.],
        [ -12.,  -11.,  -11.,  -13.],
        [ -11.,  -10.,  -10.,  -12.],
        [ -10.,   -9.,   -9.,  -11.],
        [  -9.,   -8.,   -8.,  -10.],
        [  -8.,   -7.,   -7.,   -9.],
        [  -7.,   -6.,   -6.,   -8.],
        [  -6.,   -5.,   -5.,   -7.],
        [  -5.,   -4.,   -4.,   -6.],
        [  -4.,   -4.,   -3.,   -5.],
        [ -15.,  -13.,  -13.,  -14.],
        [ -14.,  -12.,  -12.,  -14.],
        [ -13.,  -11.,  -11.,  -13.],
        [ -12.,  -10.,  -10.,  -12.],
        [ -11.,   -9.,   -9.,  -11.],
        [ -10.,   -8.,   -8.,  -10.],
        [  -9.,   -7.,   -7.,   -9.],
        [  -8.,   -6.,   -6.,   -8.],
        [  -7.,   -5.,   -5.,   -7.],
        [  -6.,   -4.,   -4.,   -6.],
        [  -5.,   -3.,   -3.,   -5.],
        [  -4.,   -3.,   -2.,   -4.],
        [ -14.,  -12.,  -14.,  -13.],
        [ -13.,  -11., -113.,  -13.],
        [ -1