In [2]:
import gym 
import numpy as np 


def decay_schedule(
        init_value, min_value,
        decay_ratio, max_steps,
        log_start=-2, log_base=10):
    """Build a per-step schedule that decays exponentially, then holds.

    The first ``int(max_steps * decay_ratio)`` entries follow a reversed
    ``np.logspace`` curve rescaled to run from ``init_value`` down to
    ``min_value``; every remaining entry repeats the last decayed value
    (``min_value``).

    Args:
        init_value: Value at step 0 of the decay window.
        min_value: Value reached at the end of the decay window and held
            for the rest of the schedule.
        decay_ratio: Fraction of ``max_steps``, in ``[0, 1]``, over which
            the value decays.
        max_steps: Total length of the returned schedule.
        log_start: Start exponent of the log-spaced curve (the end
            exponent is 0).
        log_base: Base of the log-spaced curve.

    Returns:
        A 1-D float ``np.ndarray`` of length ``max_steps``.

    Raises:
        ValueError: If ``decay_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= decay_ratio <= 1:
        raise ValueError(
            f"decay_ratio must be in [0, 1], got {decay_ratio!r}")
    decay_steps = int(max_steps * decay_ratio)
    rem_steps = max_steps - decay_steps
    if decay_steps < 2:
        # A 0- or 1-step window has no room for a curve: min-max
        # normalization of a single point is 0/0 (NaN), and an empty
        # array has no min. Treat the decay as already finished.
        return np.full(max_steps, min_value, dtype=np.float64)
    values = np.logspace(
        log_start, 0, decay_steps,
        base=log_base, endpoint=True)[::-1]
    # Rescale the curve to [0, 1] so it maps exactly onto
    # [min_value, init_value].
    values = (values - values.min()) / (values.max() - values.min())
    values = (init_value - min_value) * values + min_value
    # Hold the final (minimum) value for the remaining steps.
    return np.pad(values, (0, rem_steps), 'edge')
# Q-Learning Algorithm 
def q_learning(env,
               gamma=1.0,
               init_alpha=0.5,
               min_alpha=0.01,
               alpha_decay_ratio=0.5,
               init_epsilon=1.0,
               min_epsilon=0.1,
               epsilon_decay_ratio=0.9,
               n_episodes=3000):
    """Learn a tabular action-value function with off-policy Q-learning.

    Actions are chosen epsilon-greedily; updates bootstrap from the greedy
    (max) value of the next state. The step size and exploration rate
    follow per-episode schedules built by ``decay_schedule``.

    Args:
        env: Environment with discrete ``observation_space.n`` /
            ``action_space.n``, whose ``reset()`` returns ``(obs, info)`` and
            whose ``step()`` returns
            ``(obs, reward, terminated, truncated, info)``.
        gamma: Discount factor.
        init_alpha, min_alpha, alpha_decay_ratio: Learning-rate schedule.
        init_epsilon, min_epsilon, epsilon_decay_ratio: Exploration-rate
            schedule.
        n_episodes: Number of training episodes.

    Returns:
        Q: ``(nS, nA)`` final action-value estimates.
        V: ``(nS,)`` state values, ``max_a Q[s, a]``.
        pi: Callable mapping a state index to its greedy action under the
            final ``Q``.
        Q_track: ``(n_episodes, nS, nA)`` copy of ``Q`` after each episode.
        pi_track: List with one greedy-action array per episode.
    """
    nS, nA = env.observation_space.n, env.action_space.n
    pi_track = []
    Q = np.zeros((nS, nA), dtype=np.float64)
    Q_track = np.zeros((n_episodes, nS, nA), dtype=np.float64)

    def select_action(state, Q, epsilon):
        # Epsilon-greedy: exploit with probability 1 - epsilon,
        # otherwise pick a uniformly random action.
        if np.random.random() > epsilon:
            return np.argmax(Q[state])
        return np.random.randint(len(Q[state]))

    alphas = decay_schedule(
        init_alpha, min_alpha,
        alpha_decay_ratio,
        n_episodes)
    epsilons = decay_schedule(
        init_epsilon, min_epsilon,
        epsilon_decay_ratio,
        n_episodes)
    for e in range(n_episodes):
        state, _ = env.reset()
        terminated = truncated = False
        # Stop on either signal; ignoring `truncated` would run past the
        # environment's time limit.
        while not (terminated or truncated):
            action = select_action(state, Q, epsilons[e])
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Only a true terminal state has zero future value; after a
            # time-limit truncation we still bootstrap from next_state.
            td_target = reward + gamma * Q[next_state].max() * (not terminated)
            td_error = td_target - Q[state, action]
            Q[state, action] += alphas[e] * td_error
            state = next_state
        Q_track[e] = Q  # slice assignment copies the current table
        pi_track.append(np.argmax(Q, axis=1))
    V = np.max(Q, axis=1)
    # Build the greedy lookup once instead of on every call to pi.
    greedy = {s: a for s, a in enumerate(np.argmax(Q, axis=1))}

    def pi(s):
        return greedy[s]

    return Q, V, pi, Q_track, pi_track
# Train on FrozenLake-v1 with every default hyperparameter of q_learning.
q_learning(env=gym.make('FrozenLake-v1'))

  if not isinstance(terminated, (bool, np.bool8)):


(array([[0.74995351, 0.6807002 , 0.68066623, 0.68079815],
        [0.35363666, 0.33626203, 0.31168319, 0.60183691],
        [0.46488813, 0.31384783, 0.28529757, 0.35813318],
        [0.0633004 , 0.07682739, 0.05185138, 0.30524356],
        [0.75150137, 0.49193473, 0.50174793, 0.47013392],
        [0.        , 0.        , 0.        , 0.        ],
        [0.25220991, 0.18120387, 0.36849018, 0.12507714],
        [0.        , 0.        , 0.        , 0.        ],
        [0.49477658, 0.46147414, 0.4822511 , 0.75497888],
        [0.49095078, 0.75792864, 0.45555169, 0.38368907],
        [0.69370755, 0.45293576, 0.35984855, 0.29030929],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.46039382, 0.62327574, 0.84112826, 0.48460498],
        [0.69882398, 0.91170202, 0.77386132, 0.78960212],
        [0.        , 0.        , 0.        , 0.        ]]),
 array([0.74995351, 0.60183691, 0.46488813, 0.30524356, 0.75150137,
  