Implementation of policy iteration (Sutton and Barto, section 4.3, page 80)

In [73]:
import pandas as pd
import numpy as np

In [126]:
class Agent:

    '''
    Initializes the agent
    @param rows - The number of rows in the grid world
    @param columns - The number of columns in the grid world
    @param terminal_state - The final state the agent is trying to find the best path to
    '''
    def __init__(self, rows, columns, terminal_state, theta=0.05, gamma=0.05):
        # self.grid = np.random.rand(rows, columns) * -1 # Initializes
        self.rows = rows
        self.columns = columns
        self.values = np.zeros((rows, columns)) # Initializes all state values to zero
        self.policies = self.initializePolicies()
        # self.policies = self.initializePolicies() # Initializes policies for each cell,
        #                                           # policies[row][column] is an array where the value
        #                                           # at index 0 is the probability that up is the optimal choice,
        #                                           # 1 is right, 2 is down, 3 is left (clockwise around Cartesian)
        self.terminal_state = terminal_state # Where the end state is located on the grid
        self.theta = theta # A value close to zero signifying completion, determines the accuracy of the policy estimation
        self.gamma = gamma # A value signifying by how much to discount future rewards 

    '''
    Initializes the policies for each cell in the grid world
    '''
    def initializePolicies(self):
        policies = np.empty((self.rows, self.columns, 4))

        # Fill the array with random values that sum to 1
        for i in range(self.rows):
            for j in range(self.columns):
                policies[i, j] = np.random.dirichlet(np.ones(4))
                # policies[i, j] = np.zeros(4)
                # policies[i, j][np.random.randint(4)] = 1

        return policies

    def isValidState(self, coords):
        return coords[0] >= 0 and coords[0] < self.rows and coords[1] >= 0 and coords[1] < self.columns

    '''
    Gets the four available successor states
    Note: some of the returned successor states may be invalid, a validity check is needed
    @param coords - coordinates in tuple form (x, y) of our state
    @return a list of successors
    '''
    def getAvailableSuccessorStates(self, coords):
        successors = []

        successors.append((coords[0] - 1, coords[1]))
        successors.append((coords[0], coords[1] + 1))
        successors.append((coords[0] + 1, coords[1]))
        successors.append((coords[0], coords[1] - 1))

        return successors

    '''
    Get the discounted value of each successor state
    @return a numpy array of successor state values
    '''
    def getDiscountedValuesForSuccessors(self, state):
        values = np.zeros(4)
        
        successors = self.getAvailableSuccessorStates(state) # Get possible next states
        
        value = 0
        for successor in range(len(successors)):
            if self.isValidState(successors[successor]):
                probability_successor_chosen = self.policies[state[0]][state[1]][successor]
                successor_value = self.getReward(successors[successor]) + self.gamma * self.values[successors[successor]]
                values[successor] = probability_successor_chosen * successor_value

        return values

    '''
    Update the policy based on the values of successor states
    '''
    def updatePolicy(self, coords):
        successors = self.getAvailableSuccessorStates(coords)
        values = np.ones(4) / 4
        for successor in range(len(successors)):
            if self.isValidState(successors[successor]):
                values[successor] = self.values[successors[successor][0]][successors[successor][1]] + self.getReward(successors[successor])
            else:
                values[successor] = 0
        
        sum = np.sum(values)
        policies = np.ones(4) / 4 if sum == 0 else values / sum

        self.policies[coords[0]][coords[1]] = policies

    '''
    Gets the reward of a particular state
    High reward if terminal state, no reward otherwise
    '''
    def getReward(self, state):
        if state[0] == self.terminal_state[0] and state[1] == self.terminal_state[1]: # Terminal state
            return 100
        return 0
        

    '''
    Performs a policy evaluation step, updates the policy
    '''
    def policyEvaluation(self):
        while True:
            delta = 0 # This is the check to know when to stop evaluation
            for coords, old_value in np.ndenumerate(self.values): # Iterate over states - in the form ((row, column), value)
                
                value = np.sum(self.getDiscountedValuesForSuccessors(coords)) # Cell value is equal to the sum of the discounted successor state values
                self.values[coords] = value

                delta = max(delta, np.absolute(old_value - value)) # Get the change in the value

            if delta < self.theta: # Check if the update was significant
                break
    
    '''
    Performs a policy improvement step
    '''
    def policyImprovement(self):
        while True:
            stable = True
            for coords, value in np.ndenumerate(self.values): # Iterate over all states
                old_action = np.argmax(self.policies[coords]) # Get the previous policy value
                self.updatePolicy(coords) # Update the policy based on the updated value function
                action = np.argmax(self.policies[coords]) # Get the best action according to our new policy
                if old_action != action: # If our policy has changed, it is unstable, so we perform a policy evaluation step
                    stable = False
            
            if not stable:
                self.policyEvaluation()
            else:
                break
    
    '''
    Outputs the policy in a readable format
    '''
    def outputPolicy(self):

        mapping = {0: "Up", 1: "Right", 2: "Down", 3: "Left"}

        output = []

        for row in range(len(self.policies)):
            row_output = []
            for cell in range(len(self.policies[row])):
                if row == self.terminal_state[0] and cell == self.terminal_state[1]:
                    print('{:^6s}'.format("End"), end="")
                else:
                    print('{:^6s}'.format(mapping[np.argmax(self.policies[row][cell])]), end="")
            print()

    def valueUpdate(self, state):
        values = np.zeros(4)
        
        successors = self.getAvailableSuccessorStates(state) # Get possible next states
        
        value = 0
        for successor in range(len(successors)):
            if self.isValidState(successors[successor]):
                probability_successor_chosen = self.policies[state[0]][state[1]][successor]
                successor_value = self.getReward(successors[successor]) + self.gamma * self.values[successors[successor]]
                values[successor] = probability_successor_chosen * successor_value

        return values

    '''
    Value iteration alternative to policy iteration
    '''
    def valueIteration(self):
        while True:
            delta = 0
            for coords, old_value in np.ndenumerate(self.values):
                value = np.amax(self.getDiscountedValuesForSuccessors(coords))
                print(coords, value)
                self.values[coords] = value
                delta = max(delta, np.absolute(old_value - value))
            
            print(delta)
            if delta < self.gamma:
                break   
        
        for coords, old_value in np.ndenumerate(self.values):
            self.updatePolicy(coords)


In [127]:
a = Agent(10, 10, (2, 4), gamma=0.01)

In [128]:
a.valueIteration()

(0, 0) 0.0
(0, 1) 0.0
(0, 2) 0.0
(0, 3) 0.0
(0, 4) 0.0
(0, 5) 0.0
(0, 6) 0.0
(0, 7) 0.0
(0, 8) 0.0
(0, 9) 0.0
(1, 0) 0.0
(1, 1) 0.0
(1, 2) 0.0
(1, 3) 0.0
(1, 4) 4.746914126632535
(1, 5) 0.0025031357712945266
(1, 6) 2.980615306205682e-06
(1, 7) 1.3811205569848539e-08
(1, 8) 2.2611883179468196e-11
(1, 9) 6.25756128325571e-14
(2, 0) 0.0
(2, 1) 0.0
(2, 2) 0.0
(2, 3) 41.101523724099195
(2, 4) 0.18458146978988968
(2, 5) 17.74019094641468
(2, 6) 0.02717611808760327
(2, 7) 2.6219955556362146e-06
(2, 8) 2.4953600857364143e-09
(2, 9) 2.5347430862524173e-12
(3, 0) 0.0
(3, 1) 0.0
(3, 2) 0.0
(3, 3) 0.20577853634363344
(3, 4) 6.138429192818769
(3, 5) 0.06796002793108488
(3, 6) 0.00027193125280036083
(3, 7) 2.455559642372791e-07
(3, 8) 5.360377012053366e-10
(3, 9) 6.905274609823678e-13
(4, 0) 0.0
(4, 1) 0.0
(4, 2) 0.0
(4, 3) 0.00022733316182068028
(4, 4) 0.016611700399258588
(4, 5) 0.0004152539434043595
(4, 6) 2.107482299406e-06
(4, 7) 4.121196838361005e-09
(4, 8) 1.240648840731412e-11
(4, 9) 3.73123

In [125]:
a.outputPolicy()

  Up    Up   Down  Down  Down  Down  Down  Left  Left  Left 
  Up  Right Right Right  Down  Left  Left  Left  Left  Left 
Right Right Right Right  End   Left  Left  Left  Left  Left 
Right Right Right Right   Up    Up   Left  Left  Left  Left 
Right Right   Up    Up    Up    Up   Left   Up   Left  Left 
Right Right Right Right   Up    Up   Left   Up    Up   Left 
Right   Up  Right Right   Up   Left  Left   Up   Left  Left 
Right Right Right Right   Up    Up   Left   Up    Up   Left 
Right Right   Up    Up    Up    Up   Left  Left   Up    Up  
Right Right   Up    Up    Up    Up    Up    Up    Up    Up  


In [109]:
a.policyImprovement()

In [110]:
a.values.round(1)

array([[ 0. ,  0. ,  0. ,  0. ,  1. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0.5, 97.2,  0.5,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  1. , 97.2,  0.2, 97.2,  1. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0.5, 97.2,  0.5,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  1. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ]])

In [111]:
a.outputPolicy()

 Down  Down Right Right  Down  Left  Left  Down  Down  Down 
 Down  Down  Down  Down  Down  Down  Down  Down  Down  Down 
Right Right Right Right  End   Left  Left  Left  Left  Left 
  Up    Up    Up    Up    Up    Up    Up    Up    Up    Up  
  Up    Up    Up  Right   Up   Left   Up    Up    Up    Up  
  Up    Up  Right Right   Up   Left  Left   Up    Up    Up  
  Up    Up  Right Right   Up   Left  Left  Left  Left   Up  
  Up    Up  Right Right   Up   Left  Left  Left  Left   Up  
  Up    Up  Right Right   Up   Left  Left  Left  Left  Left 
  Up    Up  Right Right   Up   Left  Left  Left  Left  Left 
