## Policy Iteration
#### Bellman Expectation Equation
##### *Mintae Kim, Hybrid Robotics Lab, UC Berkeley*
##### 07/15/2023

### Import Packages



In [None]:
import numpy as np
from environment import GraphicDisplay, Env

### Implementation of Policy Iteration
```python
class PolicyIteration:
    def __init__(self, env):
        pass
    def policy_evaluation(self):
        pass
    def policy_improvement(self):
        pass
    def get_action(self, state):
        pass
    def get_policy(self, state):
        pass
    def get_value(self, state):
        pass
```

In [None]:
class PolicyIteration:
    """Tabular policy-iteration agent for a grid-world environment.

    The agent alternates, when its caller asks it to, between:

    * policy evaluation: one synchronous sweep of the Bellman expectation
      equation, V(s) <- sum_a pi(a|s) * (r(s, a) + gamma * V(s')); and
    * policy improvement: greedy over r(s, a) + gamma * V(s'), splitting
      probability evenly among tied best actions.

    This class never loops to convergence by itself; each step must be
    invoked explicitly.

    The environment object must provide ``width``, ``height``,
    ``possible_actions``, ``get_all_states()``,
    ``state_after_action(state, action)`` and ``get_reward(state, action)``.
    Value and policy tables are nested lists indexed as
    ``table[state[0]][state[1]]``, with ``env.height`` outer entries of
    ``env.width`` cells each.
    """

    def __init__(self, env, discount_factor=0.9, terminal_state=(2, 2)):
        """Create the agent with V = 0 everywhere and a uniform random policy.

        Args:
            env: grid-world environment (see class docstring for the API).
            discount_factor: gamma used in every backup (default 0.9).
            terminal_state: (row, col) of the absorbing state whose value is
                pinned to 0 and whose policy is the empty list.
        """
        self.env = env
        n_actions = len(env.possible_actions)
        # Value function table (2D list), initialised to zero.
        self.value_table = [[0.0] * env.width for _ in range(env.height)]
        # Uniform initial policy (0.25 each for Up/Down/Left/Right with 4
        # actions). Build a fresh list for every cell: `[[...]] * width`
        # would make all cells of a row share ONE list object, so an
        # in-place edit of one cell's policy would silently change its row.
        uniform = 1.0 / n_actions
        self.policy_table = [[[uniform] * n_actions for _ in range(env.width)]
                             for _ in range(env.height)]
        # Terminal state: an episode stops when the agent reaches it.
        self.terminal_state = list(terminal_state)
        term_row, term_col = self.terminal_state
        self.policy_table[term_row][term_col] = []
        # Discount factor gamma.
        self.discount_factor = discount_factor

    def _is_terminal(self, state):
        """Return True if *state* is the configured terminal state."""
        return list(state) == self.terminal_state

    def _action_returns(self, state):
        """Return [r(s, a) + gamma * V(s')] for each action, in possible_actions order."""
        returns = []
        for action in self.env.possible_actions:
            next_state = self.env.state_after_action(state, action)
            reward = self.env.get_reward(state, action)
            returns.append(reward + self.discount_factor * self.get_value(next_state))
        return returns

    # Policy evaluation through Bellman Expectation Equation
    def policy_evaluation(self):
        """Run one synchronous Bellman-expectation sweep and replace value_table.

        All backups read the *previous* value table; the new table is swapped
        in only after every state has been updated. The terminal state keeps
        value 0.
        """
        next_value_table = [[0.0] * self.env.width
                            for _ in range(self.env.height)]

        for state in self.env.get_all_states():
            if self._is_terminal(state):
                continue  # terminal value stays 0.0
            policy = self.get_policy(state)
            returns = self._action_returns(state)
            # Policy entries are matched to actions by position, the same
            # convention policy_improvement uses when it writes them.
            next_value_table[state[0]][state[1]] = sum(
                prob * ret for prob, ret in zip(policy, returns))

        self.value_table = next_value_table

    # Policy improvement with respect to the updated value function
    def policy_improvement(self):
        """Make the policy greedy w.r.t. the current value table.

        Each non-terminal state gets probability 1/k on each of its k tied
        best actions and 0 elsewhere. A new table is built and assigned, so
        any previously obtained reference to policy_table is left unchanged.
        """
        # Per-cell copy: the previous policy object is never mutated.
        next_policy = [[list(cell) for cell in row] for row in self.policy_table]

        for state in self.env.get_all_states():
            if self._is_terminal(state):
                continue

            returns = np.asarray(self._action_returns(state))
            # Exact-equality tie detection, as in the original formulation.
            best_actions = np.flatnonzero(returns == returns.max())
            prob = 1 / len(best_actions)

            result = [0.0] * len(returns)
            for idx in best_actions:
                result[idx] = prob
            next_policy[state[0]][state[1]] = result

        self.policy_table = next_policy

    # Return random action when state is given
    def get_action(self, state):
        """Sample an action index from the policy at *state*.

        Raises:
            ValueError: if *state* has an empty policy (the terminal state).
        """
        policy = self.get_policy(state)
        if not policy:
            raise ValueError(f"no actions available in terminal state {state}")
        return np.random.choice(len(policy), 1, p=np.array(policy))[0]

    # Return policy when state is given
    def get_policy(self, state):
        """Return the action-probability list stored for *state*."""
        return self.policy_table[state[0]][state[1]]

    # Return state value
    def get_value(self, state):
        """Return the current value estimate stored for *state*."""
        return self.value_table[state[0]][state[1]]

### Execution

In [None]:
if __name__ == "__main__":
    # Build the grid-world environment from the local `environment` module.
    env = Env()
    # The agent only stores `env` and initialises its tables; no evaluation
    # or improvement sweep runs at construction time.
    policy_iteration = PolicyIteration(env)
    # GraphicDisplay receives the agent; presumably its controls call
    # policy_evaluation / policy_improvement — TODO confirm in environment.py.
    grid_world = GraphicDisplay(policy_iteration)
    # Start the display's event loop (assumed to block until the window
    # is closed — verify against GraphicDisplay).
    grid_world.mainloop()