In [1]:
import numpy as np

In [3]:
import numpy as np

class GridworldMDP:
    """Deterministic square gridworld MDP solved with tabular policy iteration.

    States are ``(row, col)`` tuples on a ``grid_size`` x ``grid_size`` board.
    Each action moves the agent one cell; a move that would leave the board
    keeps the agent in place.  Rewards are attached to the state being
    *entered*: ``-1`` for ordinary cells, ``+10`` for the goal in the top-right
    corner and ``-10`` for the trap at ``(2, 2)`` (when that cell exists).

    The goal is absorbing and its self-loop keeps paying the goal reward, so
    its value is ``10 / (1 - gamma)`` (100 for the default ``gamma``), not 0.
    The trap is not absorbing.

    Attributes:
        grid_size: Side length of the board.
        states: All ``(row, col)`` cells in row-major order.
        actions: ``['up', 'down', 'left', 'right']``.
        gamma: Discount factor.
        rewards: Mapping ``state -> reward`` received on entering the state.
        transition_probs: ``transition_probs[state][action]`` is a mapping
            ``next_state -> probability``.
    """

    def __init__(self, grid_size=4, gamma=0.9):
        """Build the state list, reward table and transition table.

        Args:
            grid_size: Side length of the square board.
            gamma: Discount factor; must be below 1 for ``policy_iteration``.
        """
        self.grid_size = grid_size
        self.states = [(i, j) for i in range(grid_size) for j in range(grid_size)]
        self.actions = ['up', 'down', 'left', 'right']
        self.gamma = gamma
        self.rewards = self.create_rewards()
        self.transition_probs = self.create_transition_probs()

    def create_rewards(self):
        """Return the ``state -> reward`` table (reward for entering a state)."""
        rewards = {state: -1 for state in self.states}
        rewards[(0, self.grid_size - 1)] = 10  # Goal state
        # Only register the trap when the cell is actually on the board;
        # otherwise small grids end up with a reward for a non-existent state.
        if (2, 2) in rewards:
            rewards[(2, 2)] = -10  # Trap state
        return rewards

    def create_transition_probs(self):
        """Return the nested table ``state -> action -> {next_state: prob}``."""
        return {
            state: {
                action: self.get_next_state_probs(state, action)
                for action in self.actions
            }
            for state in self.states
        }

    def get_next_state_probs(self, state, action):
        """Return ``{next_state: probability}`` for taking ``action`` in ``state``.

        Transitions are deterministic, so the mapping always has one entry.
        The goal state is absorbing: every action leads back to it.

        Raises:
            ValueError: If ``action`` is not one of the four known moves
                (and ``state`` is not the goal).
        """
        i, j = state
        goal = (0, self.grid_size - 1)
        if state == goal:
            return {goal: 1.0}

        last = self.grid_size - 1
        if action == 'up':
            next_state = (max(i - 1, 0), j)
        elif action == 'down':
            next_state = (min(i + 1, last), j)
        elif action == 'left':
            next_state = (i, max(j - 1, 0))
        elif action == 'right':
            next_state = (i, min(j + 1, last))
        else:
            # Without this branch an unknown action surfaced as an
            # UnboundLocalError on ``next_state``.
            raise ValueError(f"Unknown action: {action!r}")

        return {next_state: 1.0}

    def _action_value(self, state, action, V):
        """One-step Bellman backup: expected reward plus discounted value."""
        return sum(
            prob * (self.rewards[next_state] + self.gamma * V[next_state])
            for next_state, prob in self.transition_probs[state][action].items()
        )

    def policy_iteration(self, theta=1e-3, seed=None):
        """Solve the MDP with policy iteration.

        Alternates in-place iterative policy evaluation (sweeps until the
        largest value change is below ``theta``) with greedy policy
        improvement, until an improvement step changes no action.

        Args:
            theta: Convergence threshold for policy evaluation; must be > 0.
            seed: Optional seed for the random initial policy.  When ``None``
                the global ``np.random`` state is used.

        Returns:
            ``(V, policy)`` where ``V`` maps each state to its value and
            ``policy`` maps each state to an action name.

        Raises:
            ValueError: If ``gamma`` is not below 1 (the goal self-loop pays a
                reward forever, so evaluation would never converge) or if
                ``theta`` is not positive.
        """
        if not self.gamma < 1:
            raise ValueError(
                f"gamma must be < 1 for policy evaluation to converge, got {self.gamma!r}"
            )
        if not theta > 0:
            raise ValueError(f"theta must be > 0, got {theta!r}")

        # Step 1: Initialize random policy (stored as plain ``str`` so the
        # result has a uniform type regardless of which entries get improved).
        rng = np.random if seed is None else np.random.default_rng(seed)
        policy = {state: str(rng.choice(self.actions)) for state in self.states}
        V = {state: 0 for state in self.states}

        while True:
            # Step 2: Policy Evaluation (in place, so later states in the
            # sweep already see the freshly updated values).
            while True:
                delta = 0
                for state in self.states:
                    old_value = V[state]
                    V[state] = self._action_value(state, policy[state], V)
                    delta = max(delta, abs(old_value - V[state]))
                if delta < theta:
                    break

            # Step 3: Policy Improvement (greedy; ties go to the first action
            # in ``self.actions`` order).
            policy_stable = True
            for state in self.states:
                action_values = {
                    action: self._action_value(state, action, V)
                    for action in self.actions
                }
                best_action = max(action_values, key=action_values.get)
                if policy[state] != best_action:
                    policy_stable = False
                policy[state] = best_action

            if policy_stable:
                break

        return V, policy

    def print_policy(self, policy):
        """Print ``policy`` as a grid of action initials under a header line.

        Args:
            policy: Mapping ``(row, col) -> action name``; cells without an
                entry are printed as empty strings.
        """
        grid_policy = np.zeros((self.grid_size, self.grid_size), dtype=str)
        for state, action in policy.items():
            i, j = state
            grid_policy[i, j] = action[0]  # First letter of the action
        print("Optimal Policy:")
        for row in grid_policy:
            print(' '.join(row))


# Create Gridworld MDP
gridworld = GridworldMDP(grid_size=4, gamma=0.9)

# Solve MDP using Policy Iteration
optimal_values, optimal_policy = gridworld.policy_iteration()

# Print the results
print("Optimal Value Function:")
for state, value in sorted(optimal_values.items()):
    print(f"State {state}: {value:.2f}")

# print_policy() emits its own "Optimal Policy:" header, so only a blank
# separator line is printed here (avoids the duplicated header).
print()
gridworld.print_policy(optimal_policy)


Optimal Value Function:
State (0, 0): 79.10
State (0, 1): 89.00
State (0, 2): 100.00
State (0, 3): 100.00
State (1, 0): 70.19
State (1, 1): 79.10
State (1, 2): 89.00
State (1, 3): 100.00
State (2, 0): 62.17
State (2, 1): 70.19
State (2, 2): 79.10
State (2, 3): 89.00
State (3, 0): 54.95
State (3, 1): 62.17
State (3, 2): 70.19
State (3, 3): 79.10

Optimal Policy:
Optimal Policy:
r r r u
r r r u
r u r u
r r r u


In [2]:
def policy_iteration(env, eps=0.1, gamma=1):
    """Run policy iteration on ``env`` and return ``(policy, v)``.

    The global NumPy RNG is seeded with 1, then every state in
    ``env.state_space`` gets a starting policy of the form ``{action: 1}``
    with the action drawn from ``env.action_space``.  Each round calls
    ``policy_evaluation`` to refresh the value table and then rebuilds the
    policy state by state from the first element returned by
    ``policy_improvement``.  The loop ends when the rebuilt policy compares
    equal to the previous one.

    Args:
        env: Object exposing ``state_space`` and ``action_space``; it is also
            forwarded to ``policy_evaluation`` and ``policy_improvement``.
        eps: Forwarded to ``policy_evaluation`` (presumably its convergence
            tolerance -- confirm against that helper).
        gamma: Discount factor forwarded to both helpers.

    Returns:
        Tuple ``(policy, v)``: the final policy and the value table from the
        last evaluation.
    """
    # Initialization
    np.random.seed(1)
    state_space = env.state_space
    action_space = env.action_space

    policy = {}
    for s in state_space:
        policy[s] = {np.random.choice(action_space): 1}
    v = dict.fromkeys(state_space, 0)

    converged = False
    while not converged:
        # Policy evaluation, warm-started from the previous value table.
        v = policy_evaluation(env, policy, v=v, eps=eps, gamma=gamma)

        # Policy improvement: build a fresh policy from the new values.
        previous_policy, policy = policy, {}
        for s in state_space:
            improved, _ = policy_improvement(env, v, s, action_space, gamma)
            policy[s] = improved

        converged = previous_policy == policy

    print("Optimal policy found!")
    return policy, v