Make Environment

In [None]:
from itertools import product

import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import chapter03.gridworld
import chapter04.gridworld
import chapter04.car_rental
import chapter04.gambler

# Select which environment the rest of the notebook works with.
# NOTE: this name shadows the builtin `id`, but the later cells read it under
# exactly this name, so it is left unchanged.
id = "GridWorld-v0"

# Every branch defines the same notebook globals:
#   grid       - True when the observation space exposes `nvec` (states are
#                coordinate tuples), False when it is a single discrete index
#   env        - the gymnasium environment
#   policy     - action probabilities, indexed by state and then by action
#   action_map - printable action labels (empty list: print the raw action)
#   gamma      - discount factor
#   val_text   - write each value as text on top of the value-function image
#   plot_pol   - additionally plot the resulting policy
# Unknown ids raise immediately so that a stale `env` from an earlier run of
# this cell can never be picked up silently.
if id == "GridWorld-v0":
    # Special transitions handed to the environment; their exact semantics are
    # defined by the GridWorld-v0 implementation, which is not part of this file.
    exceptional_reward_dynamics = {
        "A": {"from": (0, 1), "to": (4, 1), "reward": 10.0},
        "B": {"from": (0, 3), "to": (2, 3), "reward": 5.0},
    }
    grid = True
    env = gym.make(id, shape=(5, 5), reward_dynamics=exceptional_reward_dynamics)
    # Start with probability 0.25 on every (state, action) entry.
    policy = 0.25 * np.ones(tuple(env.observation_space.nvec) + (env.action_space.n,), dtype=np.float32)
    action_map = ["right", "up", "left", "down"]
    gamma = 0.9
    val_text = True
    plot_pol = False
elif id == "GridWorld-v1":
    grid = True
    terminal_states = ((0, 0), (3, 3))
    env = gym.make(id, shape=(4, 4), terminal_states=terminal_states)
    # Start with probability 0.25 on every (state, action) entry.
    policy = 0.25 * np.ones(tuple(env.observation_space.nvec) + (env.action_space.n,), dtype=np.float32)
    action_map = ["right", "up", "left", "down"]
    gamma = 1.0
    val_text = True
    plot_pol = False
elif id.startswith("CarRental"):
    grid = True
    if id == "CarRental-v0":
        env = gym.make("CarRental-v0")
    elif id == "CarRental-v1":
        # v1 is the same registered environment created with `modified=True`.
        env = gym.make("CarRental-v0", modified=True)
    else:
        raise ValueError(f"Unknown CarRental environment id: {id!r}")
    # Initial policy: all probability mass on action index 0 in every state.
    policy = np.zeros(tuple(env.observation_space.nvec) + (env.action_space.n,), dtype=np.float32)
    policy[:, :, 0] = 1.0
    action_map = []
    gamma = 0.9
    val_text = False
    plot_pol = True
elif id.startswith("Gambler"):
    grid = False
    # All variants use the registered "Gambler-v0" environment and differ only
    # in the probability of heads.
    gambler_prob_heads = {"Gambler-v0": 0.40, "Gambler-v1": 0.25, "Gambler-v2": 0.55}
    if id not in gambler_prob_heads:
        raise ValueError(f"Unknown Gambler environment id: {id!r}")
    env = gym.make("Gambler-v0", prob_heads=gambler_prob_heads[id])
    # Initial policy: all probability mass on action index 0 in every state.
    policy = np.zeros((env.observation_space.n, env.action_space.n), dtype=np.float32)
    policy[:, 0] = 1.0
    # Action index i is displayed as i + 1.
    action_map = [i + 1 for i in range(env.action_space.n)]
    gamma = 1.0
    val_text = False
    plot_pol = True
else:
    raise ValueError(f"Unknown environment id: {id!r}")

Reset Environment

In [None]:
# Start a new episode. The reset() result is unpacked into two values; only the
# first (the initial state) is kept, the second is discarded.
state, _ = env.reset()
print("Initial state:", state)
# Show the environment in its freshly reset state.
env.render()

Step Environment

In [None]:
# Look up the action distribution of the current state. Grid environments
# index the policy with a coordinate tuple, the others with a single index.
state_key = tuple(state) if grid else state
action_probs = policy[state_key].flatten()
action = np.random.choice(env.action_space.n, p=action_probs)

if id.startswith("CarRental"):
    # Convert the sampled index into the action handed to step(): indices
    # below (1 - start) are unchanged, and every full multiple of (1 - start)
    # contained in the index subtracts n from it.
    wrap_from = 1 - env.action_space.start
    action -= env.action_space.n * (action // wrap_from)

# The fourth element of the step() result is not used in this cell.
state, reward, terminated, _, info = env.step(action)

# Use the readable label when one is configured, otherwise the raw action.
shown_action = action_map[action] if action_map else action
for label, value in (
    ("Action:", shown_action),
    ("Reward:", reward),
    ("Terminated:", terminated),
    ("Info:", info),
    ("State:", state),
):
    print(label, value)
env.render()

Initialize Variables

In [None]:
# Total number of states: the product of the per-dimension sizes for grid
# environments, otherwise the size of the discrete observation space.
num_states = env.observation_space.nvec.prod() if grid else env.observation_space.n
num_actions = env.action_space.n

# Transition model flattened to shape (num_states, num_actions, num_states);
# the later cells treat the axes as (state, action, next state).
A = env.unwrapped.prob.reshape(num_states, num_actions, num_states)
# Probability-weighted reward summed over the last axis, i.e. the expected
# immediate reward per (state, action); shape (num_states, num_actions, 1).
b = (A * env.unwrapped.rewards.reshape(num_states, num_actions, num_states)).sum(axis=2, keepdims=True)

# Flat indices of the terminal states (empty when the environment has none).
declared_terminals = getattr(env.unwrapped, "terminal_states", ())
if len(declared_terminals) == 0:
    # Guard: in the grid branch an empty collection would produce the index
    # `()`, which selects every state instead of none.
    terminal = []
elif grid:
    # Translate (row, col, ...) coordinates into flat state indices.
    terminal = np.arange(num_states).reshape(env.observation_space.nvec)[tuple(zip(*declared_terminals))]
else:
    terminal = declared_terminals

# Boolean diagonal matrix: True on the diagonal for non-terminal states, False
# for terminal ones. The indices are converted to an integer array first so a
# tuple of indices is not interpreted as a multi-dimensional index.
non_terminal = np.ones(num_states, dtype=bool)
non_terminal[np.asarray(terminal, dtype=int)] = False
diagonal = np.diag(non_terminal)

Policy Iteration

In [None]:
# Policy iteration: alternate exact policy evaluation and greedy improvement
# until the policy no longer changes.
while True:
    # Policy Evaluation
    # Solve the linear system (diagonal - gamma * P_pi) v = r_pi, where P_pi and
    # r_pi are the transition matrix and expected reward under the current
    # policy. `diagonal` is the identity with the terminal states' entries
    # set to zero.
    pol_eval = policy.reshape(num_states, 1, num_actions)
    A_eval = diagonal - gamma * (pol_eval @ A).squeeze()
    b_eval = (pol_eval @ b).squeeze()
    v_eval = np.linalg.solve(A_eval, b_eval)

    # Policy Improvement
    old_policy = policy.copy()
    # Action values under v_eval, shape (num_states, num_actions).
    A_impr = gamma * A @ v_eval + b.squeeze()
    if id in ["Gambler-v0", "Gambler-v1"]:
        # Round so that action values differing only by float noise are
        # treated as ties in the equality test below.
        A_impr = np.round(A_impr, decimals=6)
    # Split the probability evenly over all maximising actions of each state.
    argmax = A_impr == np.max(A_impr, axis=1, keepdims=True)
    policy = (argmax / np.sum(argmax, axis=1, keepdims=True))
    if grid:
        policy = policy.reshape((*env.observation_space.nvec, num_actions))
    # Converged once the improved policy equals the previous one.
    if np.array_equal(policy, old_policy):
        break

# Plot Value Function
v = v_eval.reshape(env.observation_space.nvec) if grid else v_eval
if grid:
    if val_text:
        plt.imshow(v, cmap="winter")
        # The loop variable is deliberately not called `state`: that name holds
        # the environment's current state used by the "Step Environment" cell.
        for cell in product(*[range(i) for i in v.shape]):
            # Reverse the index because plt.text takes (x, y) = (column, row).
            plt.text(*cell[::-1], f"{v[cell]:.2f}", ha="center", va="center", color="white")
    else:
        plt.imshow(v, cmap="plasma")
    plt.title("Optimal state-value function")
    plt.colorbar()
    plt.show()
else:
    plt.plot(v)
    plt.title("Optimal state-value function")
    plt.grid(True)
    plt.show()

if plot_pol:
    # Plot Optimal Policy
    if grid:
        # First maximising action index per state, converted with the same
        # index-to-action formula the "Step Environment" cell uses.
        deterministic = np.argmax(policy, axis=-1)
        deterministic -= env.action_space.n * (deterministic // (-env.action_space.start + 1))
        plt.imshow(deterministic, cmap="plasma", vmin=env.action_space.start, vmax=-env.action_space.start)
        plt.title("Optimal policy")
        plt.colorbar()
        plt.show()
    else:
        # Shift the action index by one for display.
        deterministic = np.argmax(policy, axis=-1) + 1
        plt.step(np.arange(env.observation_space.n), deterministic, where="mid")
        plt.title("Optimal policy")
        plt.grid(True)
        plt.show()

Value Iteration

In [None]:
# Start from an all-zero value estimate; stop once the largest per-state
# change of a sweep drops below theta.
v_iter = np.zeros(num_states)
theta = 0.0001

# Value Iteration
while True:
    old_v = v_iter.copy()
    # One synchronous backup: best action value per state under the old values.
    v_iter = np.max(gamma * A @ v_iter + b.squeeze(), axis=1)
    # Vectorised maximum absolute change (instead of the Python builtins, which
    # iterate over the array element by element).
    if np.max(np.abs(v_iter - old_v)) < theta:
        break

# Optimal Policy
# Action values under the final estimate, shape (num_states, num_actions).
A_impr = gamma * A @ v_iter + b.squeeze()
# NOTE(review): the policy-iteration cell applies this rounding to "Gambler-v1"
# as well - confirm whether the difference is intentional.
if id in ["Gambler-v0"]:
    # Round so that action values differing only by float noise are treated
    # as ties in the equality test below.
    A_impr = np.round(A_impr, decimals=6)
# Split the probability evenly over all maximising actions of each state.
argmax = A_impr == np.max(A_impr, axis=1, keepdims=True)
policy = (argmax / np.sum(argmax, axis=1, keepdims=True))
if grid:
    policy = policy.reshape((*env.observation_space.nvec, num_actions))

# Plot Value Function
v = v_iter.reshape(env.observation_space.nvec) if grid else v_iter
if grid:
    if val_text:
        plt.imshow(v, cmap="winter")
        # The loop variable is deliberately not called `state`: that name holds
        # the environment's current state used by the "Step Environment" cell.
        for cell in product(*[range(i) for i in v.shape]):
            # Reverse the index because plt.text takes (x, y) = (column, row).
            plt.text(*cell[::-1], f"{v[cell]:.2f}", ha="center", va="center", color="white")
    else:
        plt.imshow(v, cmap="plasma")
    plt.title("Optimal state-value function")
    plt.colorbar()
    plt.show()
else:
    plt.plot(v)
    plt.title("Optimal state-value function")
    plt.grid(True)
    plt.show()

if plot_pol:
    # Plot Optimal Policy
    if grid:
        # First maximising action index per state, converted with the same
        # index-to-action formula the "Step Environment" cell uses.
        deterministic = np.argmax(policy, axis=-1)
        deterministic -= env.action_space.n * (deterministic // (-env.action_space.start + 1))
        plt.imshow(deterministic, cmap="plasma", vmin=env.action_space.start, vmax=-env.action_space.start)
        plt.title("Optimal policy")
        plt.colorbar()
        plt.show()
    else:
        # Shift the action index by one for display.
        deterministic = np.argmax(policy, axis=-1) + 1
        plt.step(np.arange(env.observation_space.n), deterministic, where="mid")
        plt.title("Optimal policy")
        plt.grid(True)
        plt.show()