In [1]:
import numpy as np

In [2]:
# Gridworld dimensions as (rows, columns): grid[0] bounds the first state
# coordinate and grid[1] bounds the second.
grid = (4, 4)

In [3]:
# Terminal states of the gridworld, given as (row, column) coordinates
terminal_states = [(0, 0), (3, 3)]

In [4]:
# Every (row, column) cell of the grid, enumerated row by row
states = [(row, col) for row in range(grid[0]) for col in range(grid[1])]

In [5]:
class Policy:
    """Deterministic tabular policy together with a state-value table.

    Attributes:
        actions: dict mapping every state in ``states`` to an action index
            (0, 1, 2 or 3), drawn uniformly at random on construction.
        values: dict mapping every state in ``states`` to its current value
            estimate, initialised to 0.
    """

    def __init__(self, seed=None):
        """Create a random initial policy with an all-zero value table.

        Args:
            seed: Optional seed (or ``numpy.random.Generator``) passed to
                ``numpy.random.default_rng`` so the initial policy can be
                reproduced. With the default ``None`` the actions are drawn
                from NumPy's global random state, exactly as before.
        """
        rng = np.random if seed is None else np.random.default_rng(seed)

        # Define an action to take (0, 1, 2, 3) for each state,
        # initialised to a uniformly random action. The call signature is
        # kept unchanged so the global-RNG path consumes the same stream.
        self.actions = {
            state: rng.choice([0, 1, 2, 3], size=1, p=[0.25, 0.25, 0.25, 0.25])[0]
            for state in states
        }

        # Define the values to be 0 in all states
        self.values = {state: 0 for state in states}

    def __call__(self, state):
        """Return the action this policy takes in ``state``."""
        return self.actions[state]

In [6]:
def reward(state):
    """Return the reward for ``state``: 0 if it is terminal, otherwise -1."""
    if state in terminal_states:
        return 0
    return -1

In [7]:
def state_transition(state, action):
    """Apply ``action`` in ``state`` and return ``(next_state, reward)``.

    Actions are encoded as 0 = North (row - 1), 1 = East (column + 1),
    2 = South (row + 1) and 3 = West (column - 1). A move that would leave
    the grid is clamped back inside it, so bumping into a wall leaves the
    state unchanged.

    Args:
        state: ``(row, column)`` tuple of the current cell.
        action: action index, one of 0, 1, 2, 3.

    Returns:
        Tuple ``(next_state, r)`` where ``next_state`` is the in-bounds cell
        reached and ``r`` is ``reward(next_state)``.

    Raises:
        ValueError: if ``action`` is not one of 0, 1, 2, 3.
    """
    # (row delta, column delta) for North, East, South, West
    moves = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}
    if action not in moves:
        raise ValueError(f"Unknown action {action!r}; expected 0, 1, 2 or 3")
    d_row, d_col = moves[action]

    # Clamp to the grid first so the reward is looked up for the state that
    # is actually reached, never for an off-grid coordinate.
    state_ = (
        min(grid[0] - 1, max(0, state[0] + d_row)),
        min(grid[1] - 1, max(0, state[1] + d_col)),
    )

    return state_, reward(state_)

In [8]:
# Discount factor: multiplies the next state's value when computing
# r + discount * V(next_state) in policy evaluation and improvement
discount = 0.9

In [9]:
# Policy evaluation
def policy_evaluation(policy, thresh=1e-6):
    """Update ``policy.values`` in place until they are consistent with ``policy``.

    Repeatedly sweeps over all non-terminal states, replacing each state's
    value with ``r + discount * V(next_state)`` for the action the policy
    takes there. Updates are in place, so later states in a sweep already
    see the new values of earlier ones.

    Args:
        policy: callable mapping a state to an action, with a mutable
            ``values`` dict keyed by state.
        thresh: stop once the largest value change in a full sweep is at
            most this amount. Defaults to 1e-6.

    Returns:
        None; the result is written to ``policy.values``.
    """
    # Always perform at least one sweep, whatever the threshold is
    while True:
        delta = 0

        # Loop over every state
        for state in states:
            # Terminal states keep their value
            if state in terminal_states:
                continue

            # Get the old value of that state
            old_value = policy.values[state]

            # One-step backup following the policy's action
            new_state, r = state_transition(state, policy(state))
            new_value = r + discount * policy.values[new_state]

            # Update the value
            policy.values[state] = new_value

            # Track the largest change seen in this sweep
            delta = max(delta, abs(old_value - new_value))

        if delta <= thresh:
            break

In [10]:
# Policy improvement
def policy_improvement(policy):
    """Alternate policy evaluation and greedy improvement until stable.

    Each pass re-evaluates ``policy`` and then, for every non-terminal state,
    switches to the action with the highest one-step q-value
    ``r + discount * V(next_state)``. The loop ends after the first pass in
    which no action changes, and the number of passes is printed.
    """
    passes = 0
    live_states = [s for s in states if s not in terminal_states]

    # Keep going for as long as the last pass changed at least one action
    changed = True
    while changed:
        changed = False

        # Bring the value table up to date with the current policy
        policy_evaluation(policy)

        for state in live_states:
            previous = policy(state)

            # q-value of every action: its reward plus the discounted value
            # of the state it leads to
            outcomes = (state_transition(state, candidate) for candidate in (0, 1, 2, 3))
            q_values = [r + discount * policy.values[nxt] for nxt, r in outcomes]
            greedy = np.argmax(q_values)

            # Act greedily from now on
            policy.actions[state] = greedy

            # Any switched action means another pass is needed
            if previous != greedy:
                changed = True

        passes += 1

    print(f"Took {passes} iterations to converge")

In [11]:
# Seed NumPy's global RNG so the random initial policy, and with it the
# iteration count printed by policy_improvement, is the same on every run
# of this cell.
np.random.seed(0)
policy = Policy()
policy_improvement(policy)

Took 4 iterations to converge


In [18]:
# Print the policy as a grid: "N" marks a terminal state, arrows show the
# chosen action (0 = ^, 1 = >, 2 = v, 3 = <)
arrows = {0: "^", 1: ">", 2: "v", 3: "<"}
for row in range(grid[0]):
    for col in range(grid[1]):
        if (row, col) in terminal_states:
            cell = "N "
        else:
            action = policy.actions[(row, col)]
            # An unrecognised action is rendered as an empty cell
            cell = arrows[action] + " " if action in arrows else ""
        print(cell, end=" ")
    print()

N  <  <  v  
^  ^  ^  v  
^  ^  >  v  
^  >  >  N  


In [19]:
# Print the value table as a grid, one line per row, 2 decimal places each
for row in range(grid[0]):
    print("".join(f"{policy.values[(row, col)]:.2f} " for col in range(grid[1])))

0.00 0.00 -1.00 -1.90 
0.00 -1.00 -1.90 -1.00 
-1.00 -1.90 -1.00 0.00 
-1.90 -1.00 0.00 0.00 
