In [1]:
import gym
import numpy as np

In [2]:
def value_iteration(env, gamma=0.9, theta=1e-6):
    V = np.zeros(env.observation_space.n)
    while True:
        delta = 0
        for s in range(env.observation_space.n):
            Q_values = [sum(prob * (r + gamma * V[s_]) for prob, s_, r, _ in env.P[s][a])
                        for a in range(env.action_space.n)]
            max_q = max(Q_values)
            delta = max(delta, abs(V[s] - max_q))
            V[s] = max_q
        if delta < theta:
            break
    policy = np.zeros(env.observation_space.n, dtype=int)
    for s in range(env.observation_space.n):
        policy[s] = np.argmax([sum(prob * (r + gamma * V[s_]) for prob, s_, r, _ in env.P[s][a])
                               for a in range(env.action_space.n)])
    return policy, V

In [3]:
def policy_iteration(env, gamma=0.9, theta=1e-6):
    policy = np.random.choice(env.action_space.n, env.observation_space.n)
    V = np.zeros(env.observation_space.n)
    while True:
        while True:
            delta = 0
            for s in range(env.observation_space.n):
                v = V[s]
                V[s] = sum(prob * (r + gamma * V[s_]) for prob, s_, r, _ in env.P[s][policy[s]])
                delta = max(delta, abs(v - V[s]))
            if delta < theta:
                break
        policy_stable = True
        for s in range(env.observation_space.n):
            old_action = policy[s]
            policy[s] = np.argmax([sum(prob * (r + gamma * V[s_]) for prob, s_, r, _ in env.P[s][a])
                                   for a in range(env.action_space.n)])
            if old_action != policy[s]:
                policy_stable = False
        if policy_stable:
            break
    return policy, V

In [4]:
def run_taxi_rl():
    env = gym.make("Taxi-v3")
    print("Running Value Iteration...")
    policy_vi, V_vi = value_iteration(env)
    print("Optimal Policy (Value Iteration):", policy_vi)
    print("Running Policy Iteration...")
    policy_pi, V_pi = policy_iteration(env)
    print("Optimal Policy (Policy Iteration):", policy_pi)
    env.close()

In [5]:
if __name__ == "__main__":
    run_taxi_rl()

  deprecation(
  deprecation(


Running Value Iteration...
Optimal Policy (Value Iteration): [4 4 4 4 0 0 0 0 0 0 0 0 0 0 0 0 5 0 0 0 3 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 3
 0 0 0 0 0 0 0 2 2 2 2 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 2 2 2 2 0 0 0 0 0 0
 0 0 0 2 0 0 0 0 0 0 4 4 4 4 0 0 0 0 0 0 0 0 0 5 0 0 1 1 1 1 0 0 0 0 0 0 0
 0 0 0 0 0 1 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 2 2 2 2
 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 2 2 2 2 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 1
 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 1 1 1 1 2 2 2 2 0 0 0 0 2 2 2 2 1 2 0 2 1 1
 1 1 2 2 2 2 3 3 3 3 2 2 2 2 1 2 3 2 3 3 3 3 2 2 2 2 3 3 3 3 2 2 2 2 3 2 3
 2 3 3 3 3 2 2 2 2 3 3 3 3 0 0 0 0 3 2 3 0 3 3 3 3 1 1 1 1 3 3 3 3 0 0 0 0
 3 1 3 0 1 1 1 1 1 1 1 1 0 0 0 0 1 1 1 1 1 1 0 1 1 1 1 1 2 2 2 2 1 1 1 1 2
 2 2 2 1 2 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 1 1
 1 1 0 0 0 0 1 2 1 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 1 1 1 0 1 1 1 1 1 1 1
 1 4 4 4 4 1 1 1 1 1 1 5 1 1 1 1 1 2 2 2 2 1 1 1 1 2 2 2 2 1 2 1 2 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 