<a href="https://colab.research.google.com/github/jenny005/Reinforcement-Learning-by-Sutton-Barto/blob/main/Chapter_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chapter 3: Finite Markov Decision Processes

# Example 3.5: Gridworld

In [22]:
# Modified from:
#######################################################################
# Copyright (C)                                                       #
# 2016-2018 Shangtong Zhang(zhangshangtong.cpp@gmail.com)             #
# 2016 Kenta Shimada(hyperkentakun@gmail.com)                         #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################


import numpy as np
import pandas as pd

# --- Gridworld Setup (Figure 3.2) ---
WORLD_SIZE = 5      # the grid is WORLD_SIZE x WORLD_SIZE cells
DISCOUNT = 0.9      # discount factor gamma
ACTION_PROB = 0.25  # equiprobable random policy: each of the four actions w.p. 1/4

# Special cells as (row, col), each paired with the cell it sends the agent to.
A_POS, A_PRIME_POS = (0, 1), (4, 1)
B_POS, B_PRIME_POS = (0, 3), (2, 3)

# (row, col) displacement of every action, in the order left, up, right, down.
ACTIONS = [np.array(delta) for delta in ((0, -1), (-1, 0), (0, 1), (1, 0))]

def step(state, action):
    """Apply one move of the Figure 3.2 gridworld and return ``(next_state, reward)``.

    * From A_POS every action jumps to A_PRIME_POS with reward 10.
    * From B_POS every action jumps to B_PRIME_POS with reward 5.
    * Any other move that would leave the grid keeps the agent in place with
      reward -1; a move that stays on the grid lands on the shifted cell with
      reward 0.
    """
    here = tuple(state)
    if here == A_POS:
        return A_PRIME_POS, 10
    if here == B_POS:
        return B_PRIME_POS, 5

    target = np.array(state) + action
    row, col = target[0], target[1]
    on_grid = 0 <= row < WORLD_SIZE and 0 <= col < WORLD_SIZE
    # Bumping into a wall costs -1 and leaves the state unchanged.
    return (tuple(target), 0) if on_grid else (here, -1)

def evaluate_uniform_random_policy(theta=1e-4):
    """Evaluate the equiprobable random policy by iterative policy evaluation.

    Each sweep applies the Bellman expectation backup
    ``v(s) <- sum_a ACTION_PROB * (r + DISCOUNT * v(s'))`` to every cell. The
    backup reads only the previous sweep's values (synchronous update).
    Sweeps continue until the summed absolute change over all cells drops
    below ``theta``.

    Args:
        theta: convergence threshold on the L1 change between two consecutive
            sweeps (defaults to the value that used to be hard-coded).

    Returns:
        A ``(WORLD_SIZE, WORLD_SIZE)`` array of state values rounded to one
        decimal place.
    """
    value = np.zeros((WORLD_SIZE, WORLD_SIZE))
    while True:
        new_value = np.zeros_like(value)
        for i in range(WORLD_SIZE):
            for j in range(WORLD_SIZE):
                v = 0.0
                for action in ACTIONS:
                    (ni, nj), reward = step((i, j), action)
                    v += ACTION_PROB * (reward + DISCOUNT * value[ni, nj])
                new_value[i, j] = v
        # Measure the change, then always adopt the new sweep. Previously the
        # loop broke before `value = new_value`, so it returned the older,
        # less accurate iterate.
        converged = np.sum(np.abs(new_value - value)) < theta
        value = new_value
        if converged:
            return np.round(value, decimals=1)

def format_value_grid(value_grid):
    """Turn a 2-D grid of numbers into a list of rows of display strings.

    Every cell is rendered with an explicit sign and exactly one decimal
    (e.g. ``3.31 -> '+3.3'``, ``-0.4 -> '-0.4'``).
    """
    def render(cell):
        return f"{cell:+.1f}"

    return [list(map(render, row)) for row in value_grid]

if __name__ == "__main__":
    # Compute v_pi for the equiprobable random policy, format every cell as a
    # signed one-decimal string, and print it as a bare table (no index or
    # column labels) under a banner.
    value_grid = evaluate_uniform_random_policy()
    formatted_grid = format_value_grid(value_grid)
    df = pd.DataFrame(formatted_grid)
    print(
        "=== Value Function: Uniform Random Policy (Figure 3.2) ===",
        df.to_string(index=False, header=False),
        sep="\n",
    )


=== Value Function: Uniform Random Policy (Figure 3.2) ===
+3.3 +8.8 +4.4 +5.3 +1.5
+1.5 +3.0 +2.3 +1.9 +0.5
+0.1 +0.7 +0.7 +0.4 -0.4
-1.0 -0.4 -0.4 -0.6 -1.2
-1.9 -1.3 -1.2 -1.4 -2.0


In [25]:
# Modified from:
#######################################################################
# Copyright (C)                                                       #
# 2016-2018 Shangtong Zhang(zhangshangtong.cpp@gmail.com)             #
# 2016 Kenta Shimada(hyperkentakun@gmail.com)                         #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

import numpy as np
import pandas as pd

# --- Gridworld Setup (Figure 3.5) ---
WORLD_SIZE = 5  # the grid is WORLD_SIZE x WORLD_SIZE cells
DISCOUNT = 0.9  # discount factor gamma

# Special cells as (row, col), each paired with the cell it sends the agent to.
A_POS, A_PRIME_POS = (0, 1), (4, 1)
B_POS, B_PRIME_POS = (0, 3), (2, 3)

# (row, col) displacement of every action and the arrow used to draw it;
# both lists share the order left, up, right, down.
ACTIONS = [np.array(delta) for delta in ((0, -1), (-1, 0), (0, 1), (1, 0))]
ACTION_SYMBOLS = list('←↑→↓')

def step(state, action):
    """Apply ``action`` in ``state`` and return ``(next_state, reward)``.

    A_POS always leads to A_PRIME_POS with reward 10, and B_POS always leads
    to B_PRIME_POS with reward 5, whatever the action. Otherwise a move off
    the grid leaves the state unchanged with reward -1, and an on-grid move
    yields reward 0.
    """
    current = tuple(state)
    for special, outcome in ((A_POS, (A_PRIME_POS, 10)), (B_POS, (B_PRIME_POS, 5))):
        if current == special:
            return outcome[0], outcome[1]

    moved = np.array(state) + action
    inside = all(0 <= moved[axis] < WORLD_SIZE for axis in (0, 1))
    if not inside:
        return current, -1
    return tuple(moved), 0

def value_iteration(theta=1e-4):
    """Compute the optimal state-value function by value iteration.

    Each sweep applies the Bellman optimality backup
    ``v(s) <- max_a [r + DISCOUNT * v(s')]`` to every cell. The backup reads
    only the previous sweep's values (synchronous update). Iteration stops
    once the largest single-cell change in a sweep is below ``theta``.

    Returns:
        A ``(WORLD_SIZE, WORLD_SIZE)`` array of values rounded to one decimal.
    """
    value = np.zeros((WORLD_SIZE, WORLD_SIZE))
    while True:
        previous = value
        value = np.copy(previous)
        largest_change = 0
        for row in range(WORLD_SIZE):
            for col in range(WORLD_SIZE):
                best = max(
                    reward + DISCOUNT * previous[nxt]
                    for nxt, reward in (step((row, col), a) for a in ACTIONS)
                )
                value[row, col] = best
                largest_change = max(largest_change, abs(best - previous[row, col]))
        if largest_change < theta:
            return np.round(value, 1)

def extract_policy(value, tol=1e-9):
    """Return the greedy policy with respect to ``value`` as a grid of arrows.

    For every cell, each action is scored by the one-step lookahead
    ``reward + DISCOUNT * value[next_state]``. The cell's entry concatenates
    the ACTION_SYMBOLS of all actions whose score is within ``tol`` of the
    best score, in ACTIONS order.

    Args:
        value: a ``(WORLD_SIZE, WORLD_SIZE)`` array of state values.
        tol: absolute tolerance for treating two action values as tied.
            Exact float equality (the previous behaviour) can silently drop an
            optimal action because of round-off. Callers passing unconverged
            values may want a larger tolerance.

    Returns:
        A WORLD_SIZE x WORLD_SIZE list of lists of arrow strings.
    """
    policy_grid = [['' for _ in range(WORLD_SIZE)] for _ in range(WORLD_SIZE)]
    for i in range(WORLD_SIZE):
        for j in range(WORLD_SIZE):
            action_values = []
            for action in ACTIONS:
                (ni, nj), reward = step((i, j), action)
                action_values.append(reward + DISCOUNT * value[ni, nj])
            best = max(action_values)
            # best - v >= 0 always, so this keeps every near-maximal action.
            policy_grid[i][j] = ''.join(
                symbol
                for symbol, v in zip(ACTION_SYMBOLS, action_values)
                if best - v <= tol
            )
    return policy_grid

def print_grid(grid, title):
    """Print a blank line and a ``=== title ===`` banner, then the grid.

    Each row goes on its own line. Every cell is converted with ``str()``,
    right-aligned to width 4, and separated from its neighbours by a single
    space.
    """
    print(f"\n=== {title} ===")
    for row in grid:
        cells = [f"{cell!s:>4}" for cell in row]
        print(*cells)

if __name__ == "__main__":
    # Solve for the optimal values with value iteration, then derive the
    # greedy policy from those values.
    optimal_value = value_iteration()
    optimal_policy = extract_policy(optimal_value)

    # Show the numeric grid first, then the arrow grid.
    print_grid(grid=optimal_value, title="Optimal Value Function (Figure 3.5)")
    print_grid(grid=optimal_policy, title="Optimal Policy (Arrows)")



=== Optimal Value Function (Figure 3.5) ===
22.0 24.4 22.0 19.4 17.5
19.8 22.0 19.8 17.8 16.0
17.8 19.8 17.8 16.0 14.4
16.0 17.8 16.0 14.4 13.0
14.4 16.0 14.4 13.0 11.7

=== Optimal Policy (Arrows) ===
   → ←↑→↓    ← ←↑→↓    ←
  ↑→    ↑   ←↑    ←    ←
  ↑→    ↑   ←↑   ←↑   ←↑
  ↑→    ↑   ←↑   ←↑   ←↑
  ↑→    ↑   ←↑   ←↑   ←↑
