In [1]:
import numpy as np

In [2]:
class GridWorld:
    """Small stochastic grid-world MDP.

    States are ``(row, col)`` tuples and actions are ``(d_row, d_col)`` unit
    steps. An action moves in the intended direction with probability
    ``p_intended`` and slips to each of the two perpendicular directions with
    probability ``p_perpendicular``. A move that would leave the grid or enter
    a wall leaves the agent where it is.

    Attributes:
        height, width: Grid dimensions.
        grid: ``height x width`` array of zeros (not read by any method here).
        walls: Frozen set of blocked ``(row, col)`` cells.
        terminal_states: Mapping ``(row, col) -> reward`` for absorbing cells.
        living_reward: Reward returned for every non-terminal state.
        gamma: Discount factor.
        p_intended, p_perpendicular: Transition noise model.
        actions: Available actions, in the order Right, Down, Left, Up.
    """

    def __init__(self, height=3, width=4, walls=((1, 1),), terminal_states=None):
        """Build the grid.

        Args:
            height: Number of rows.
            width: Number of columns.
            walls: Iterable of blocked ``(row, col)`` cells. Defaults to the
                single wall at ``(1, 1)``.
            terminal_states: Optional mapping ``(row, col) -> reward``. When
                omitted, ``(0, 3)`` pays +1 and ``(1, 3)`` pays -1.
        """
        self.height = height
        self.width = width
        self.grid = np.zeros((height, width))
        # Normalise to tuples so list-style cells such as [1, 1] also match.
        self.walls = frozenset(tuple(cell) for cell in walls)
        if terminal_states is None:
            terminal_states = {(0, 3): 1, (1, 3): -1}
        # Copy so later edits to the caller's mapping do not leak in.
        self.terminal_states = dict(terminal_states)  # (state): reward
        self.living_reward = -0.04
        self.gamma = 1.0
        self.p_intended = 0.8
        self.p_perpendicular = 0.1  # For each perpendicular direction
        self.actions = [(0, 1), (1, 0), (0, -1), (-1, 0)]  # Right, Down, Left, Up

    def is_valid_state(self, state):
        """Return True when ``state`` is inside the grid and not a wall."""
        row, col = state
        if not (0 <= row < self.height and 0 <= col < self.width):
            return False
        return (row, col) not in self.walls

    def get_transition_probs(self, state, action):
        """Return the list of ``(next_state, probability)`` pairs for a move.

        Terminal states are absorbing: they return themselves with
        probability 1.0. Otherwise three entries are returned (intended move,
        then the two perpendicular slips); the same next state may appear more
        than once when several moves are blocked.
        """
        if state in self.terminal_states:
            return [(state, 1.0)]

        # For an axis-aligned unit step, swapping the two components gives a
        # perpendicular step, and negating that swap gives the opposite one.
        # Which of the two is "left" or "right" of the action varies by
        # action, but both carry the same probability so it does not matter.
        perp1 = (action[1], action[0])
        perp2 = (-action[1], -action[0])

        transitions = []
        for next_action, prob in [(action, self.p_intended),
                                  (perp1, self.p_perpendicular),
                                  (perp2, self.p_perpendicular)]:
            next_state = (state[0] + next_action[0], state[1] + next_action[1])
            if self.is_valid_state(next_state):
                transitions.append((next_state, prob))
            else:
                transitions.append((state, prob))  # Blocked: stay in place

        return transitions

    def get_reward(self, state):
        """Return the terminal reward for ``state``, else the living reward."""
        if state in self.terminal_states:
            return self.terminal_states[state]
        return self.living_reward

# Value iteration

In [3]:
def value_iteration(grid, threshold=1e-3):
    """Run synchronous value iteration on ``grid``.

    Every valid cell starts at value 0. Each sweep sets terminal states to
    their reward and every other state to
    ``reward(s) + gamma * max_a sum_{s'} P(s'|s,a) * V(s')``, reading only the
    values from the previous sweep.

    Args:
        grid: Object exposing ``height``, ``width``, ``is_valid_state``,
            ``terminal_states``, ``actions``, ``gamma``, ``get_reward`` and
            ``get_transition_probs``.
        threshold: Stop once the largest absolute change in a sweep is
            strictly below this value.

    Returns:
        ``(values, sweeps)``: the state -> value dict and the sweep count.
    """
    values = {}
    for r in range(grid.height):
        for c in range(grid.width):
            if grid.is_valid_state((r, c)):
                values[(r, c)] = 0

    sweeps = 0
    while True:
        updated = dict(values)
        max_delta = 0

        for state in values:
            if state in grid.terminal_states:
                updated[state] = grid.get_reward(state)
            else:
                # Best expected successor value over all actions.
                best = float('-inf')
                for action in grid.actions:
                    expected = 0
                    for successor, p in grid.get_transition_probs(state, action):
                        expected += p * values[successor]
                    best = max(best, expected)
                updated[state] = grid.get_reward(state) + grid.gamma * best
            max_delta = max(max_delta, abs(updated[state] - values[state]))

        values = updated
        sweeps += 1

        if max_delta < threshold:
            return values, sweeps

In [4]:

grid = GridWorld()
V, iterations = value_iteration(grid)

# Report the sweep count, then the value table one grid row per line.
print(f"\nConverged after {iterations} iterations")
print("\nFinal values:")
for row in range(grid.height):
    cells = (
        f" {V[(row, col)]:7.3f} " if grid.is_valid_state((row, col)) else "   XXXXX "
        for col in range(grid.width)
    )
    print("".join(cells))


Converged after 20 iterations

Final values:
   0.812    0.868    0.918    1.000 
   0.762    XXXXX    0.660   -1.000 
   0.705    0.655    0.611    0.387 


In [5]:
def policy_extraction(grid, V):
    """Return the greedy policy for the state values ``V``.

    For each non-terminal state the action with the highest expected
    successor value ``sum_{s'} P(s'|s,a) * V(s')`` is chosen; on ties the
    action that comes first in ``grid.actions`` wins. Terminal states map to
    ``None``.

    Args:
        grid: Object exposing ``terminal_states``, ``actions`` and
            ``get_transition_probs``.
        V: Mapping state -> value.

    Returns:
        Dict mapping every state in ``V`` to an action or ``None``.
    """
    def expected_value(state, action):
        total = 0
        for successor, p in grid.get_transition_probs(state, action):
            total += p * V[successor]
        return total

    greedy = dict.fromkeys(V)  # every state starts as None

    for state in V:
        if state in grid.terminal_states:
            continue  # terminal states keep None

        top_score, top_action = float('-inf'), None
        for candidate in grid.actions:
            score = expected_value(state, candidate)
            if score > top_score:
                top_score, top_action = score, candidate
        greedy[state] = top_action

    return greedy

# Extract the greedy policy for the value estimates V; the bare name on the
# last line shows the resulting dict as the cell output.
policy = policy_extraction(grid, V)
policy


{(0, 0): (0, 1),
 (0, 1): (0, 1),
 (0, 2): (0, 1),
 (0, 3): None,
 (1, 0): (-1, 0),
 (1, 2): (-1, 0),
 (1, 3): None,
 (2, 0): (-1, 0),
 (2, 1): (0, -1),
 (2, 2): (0, -1),
 (2, 3): (0, -1)}

In [6]:
def print_policy(grid, policy):
    """Print ``policy`` as a tab-separated grid, one line per grid row.

    Invalid cells are shown as ``X``, the four moves as arrows, and any other
    policy value (such as ``None``) as ``*``.

    Args:
        grid: Object exposing ``height``, ``width`` and ``is_valid_state``.
        policy: Mapping ``(row, col)`` -> action for every valid cell.
    """
    # Ordered (action, glyph) pairs, compared with == so that unhashable
    # policy values are handled like any other unmatched value.
    glyphs = (
        ((0, 1), "\t\u2192"),
        ((1, 0), "\t\u2193"),
        ((0, -1), "\t\u2190"),
        ((-1, 0), "\t\u2191"),
    )
    for row in range(grid.height):
        for col in range(grid.width):
            cell = (row, col)
            if not grid.is_valid_state(cell):
                print("\tX", end="")
                continue
            chosen = policy[cell]
            symbol = next((glyph for move, glyph in glyphs if chosen == move), "\t*")
            print(symbol, end="")
        print()

# Show the extracted policy as a grid of arrows.
print_policy(grid, policy)

	→	→	→	*
	↑	X	↑	*
	↑	←	←	←


# Policy iteration

In [7]:
def policy_evaluation(grid, policy, updates=20):
    """Estimate the state values of ``policy`` with a fixed number of sweeps.

    Values start at 0 for every state in ``policy``. Each sweep sets terminal
    states to their reward and every other state to
    ``reward(s) + gamma * sum_{s'} P(s'|s, policy[s]) * V(s')``, reading only
    the values from the previous sweep.

    Args:
        grid: Object exposing ``terminal_states``, ``gamma``, ``get_reward``
            and ``get_transition_probs``.
        policy: Mapping state -> action; the action of a terminal state is
            never read.
        updates: Number of sweeps to perform.

    Returns:
        Dict mapping each state in ``policy`` to its estimated value.
    """
    values = dict.fromkeys(policy, 0)

    for _ in range(updates):
        refreshed = dict(values)
        for state in values:
            if state in grid.terminal_states:
                refreshed[state] = grid.get_reward(state)
                continue

            expected = 0
            for successor, p in grid.get_transition_probs(state, policy[state]):
                expected += p * values[successor]
            refreshed[state] = grid.get_reward(state) + grid.gamma * expected

        values = refreshed

    return values


def policy_iteration(grid, threshold=1e-3):
    """Alternate policy evaluation and greedy improvement until stable.

    Every non-terminal state starts with ``grid.actions[0]``. Each round
    evaluates the current policy with ``policy_evaluation`` and then switches
    each non-terminal state to the action with the highest expected successor
    value (the first such action in ``grid.actions`` on ties). The loop ends
    after the first round in which no action changes.

    Terminal states map to ``None`` and are skipped during improvement, which
    matches what ``policy_extraction`` returns and what ``print_policy``
    renders as ``*``. The action of a terminal state never influences the
    evaluation, so the returned values and iteration count are unaffected.

    Args:
        grid: Object exposing ``height``, ``width``, ``is_valid_state``,
            ``terminal_states``, ``actions`` and ``get_transition_probs``.
        threshold: Currently unused; kept so existing calls keep working.
            Evaluation runs a fixed number of sweeps and the stopping rule is
            policy stability.

    Returns:
        ``(policy, V, iteration)``: the final policy dict, the values from the
        last evaluation, and the number of evaluate/improve rounds.
    """
    # Initialize policy: first action everywhere except absorbing states.
    policy = {
        (i, j): (None if (i, j) in grid.terminal_states else grid.actions[0])
        for i in range(grid.height)
        for j in range(grid.width)
        if grid.is_valid_state((i, j))
    }

    iteration = 0
    while True:
        # Policy evaluation
        V = policy_evaluation(grid, policy)

        # Policy improvement
        policy_stable = True
        for state in V:
            if state in grid.terminal_states:
                continue  # nothing to choose in an absorbing state

            old_action = policy[state]
            max_q = float('-inf')
            best_action = None

            for action in grid.actions:
                q = 0
                for next_state, prob in grid.get_transition_probs(state, action):
                    q += prob * V[next_state]
                if q > max_q:  # strict: earlier actions win ties
                    max_q = q
                    best_action = action

            policy[state] = best_action
            if best_action != old_action:
                policy_stable = False

        iteration += 1
        if policy_stable:
            break

    return policy, V, iteration

In [8]:
policy, V, iterations = policy_iteration(grid)

# Report the number of rounds, then the value table one grid row per line.
print(f"\nConverged after {iterations} iterations")
print("\nFinal values:")
for i in range(grid.height):
    line = ""
    for j in range(grid.width):
        line += f" {V[(i, j)]:7.3f} " if grid.is_valid_state((i, j)) else "   XXXXX "
    print(line)


Converged after 3 iterations

Final values:
   0.812    0.868    0.918    1.000 
   0.762    XXXXX    0.660   -1.000 
   0.705    0.655    0.611    0.387 
