Implementation of policy iteration (Sutton and Barto, section 4.3, page 80)

In [2]:
import pandas as pd
import numpy as np

In [76]:
class Agent:

    '''
    Initializes the agent
    @param rows - The number of rows in the grid world
    @param columns - The number of columns in the grid world
    @param terminal_state - The final state the agent is trying to find the best path to
    '''
    def __init__(self, rows, columns, terminal_state):
        # self.grid = np.random.rand(rows, columns) * -1 # Initializes
        self.rows = rows
        self.columns = columns
        self.values = np.zeros((rows, columns)) # Initializes all state values to zero
        self.policies = self.initializePolicies() # Initializes policies for each cell,
                                                  # policies[row][column] is an array where the value
                                                  # at index 0 is the probability that up is the optimal choice,
                                                  # 1 is right, 2 is down, 3 is left (clockwise around Cartesian)
        self.terminal_state = terminal_state # Where the end state is located on the grid

    '''
    Initializes the policies for each cell in the grid world
    '''
    def initializePolicies(self):
        policies = np.empty((self.rows, self.columns, 4))

        # Fill the array with random values that sum to 1
        for i in range(self.rows):
            for j in range(self.columns):
                policies[i, j] = np.random.dirichlet(np.ones(4))

        return policies

    '''
    Performs a policy evaluation step, updates the policy
    @param theta - A value close to zero signifying completion, determines the accuracy of the estimation
    @param gamma - A value signifying by how much to discount future rewards 
    '''
    def policyEvaluation(self, theta, gamma):
        delta = 0 # This is the check to know when to stop evaluation
        for state in np.ndenumerate(self.values): # Iterate over states - in the form ((row, column), value)
            prev_value = self.values[state[0]] # Get value of the state
            successors = self.getAvailableSuccessorStates(state[0]) # Get possible next states
            
            value = 0
            for successor in range(len(successors)):
                if self.isValidState(successors[successor]):
                    value += 





    def isValidState(self, coords):
        return coords[0] >= 0 and coords[0] < self.columns and coords[1] >= 0 and coords[1] < self.rows

    '''
    Gets the available successor states (necessary because our grid world has walls)
    @param coords - coordinates in tuple form (x, y) of our state
    @return a list of successors
    '''
    def getAvailableSuccessorStates(self, coords):
        successors = []
        
        # if coords[1] != 0:
        successors.append((coords[0], coords[1] - 1))

        # if coords[0] != self.rows - 1:
        successors.append((coords[0] + 1, coords[1]))

        # if coords[1] != self.columns - 1:
        successors.append((coords[0], coords[1] + 1))

        # if coords[0] != 0:
        successors.append((coords[0] - 1, coords[1]))

        return successors
            
                

SyntaxError: invalid syntax (Temp/ipykernel_6264/1627684492.py, line 44)

In [65]:
a = Agent(3, 5, (2, 4))
a.terminal_state
a.policyEvaluation(0.01)

((0, 0), 0.0)
((0, 1), 0.0)
((0, 2), 0.0)
((0, 3), 0.0)
((0, 4), 0.0)
((1, 0), 0.0)
((1, 1), 0.0)
((1, 2), 0.0)
((1, 3), 0.0)
((1, 4), 0.0)
((2, 0), 0.0)
((2, 1), 0.0)
((2, 2), 0.0)
((2, 3), 0.0)
((2, 4), 0.0)


In [75]:
import numpy as np

# Set the seed for reproducibility
np.random.seed(0)

# Create the array with 3 rows and 5 columns
array = np.empty((3, 5, 4))

# Fill the array with random values that sum to 1
for i in range(3):
    for j in range(5):
        array[i, j] = np.random.dirichlet(np.ones(4))

print(array)
array[0][1]

[[[0.21154331 0.33382619 0.24539256 0.20923794]
  [0.12557359 0.23657699 0.13115001 0.50669941]
  [0.54164484 0.07901886 0.25635204 0.12298427]
  [0.23301182 0.72122972 0.02045376 0.0253047 ]
  [0.00381501 0.3338595  0.2812441  0.3810814 ]]

 [[0.50688886 0.21160757 0.08158917 0.1999144 ]
  [0.02999541 0.24340343 0.03687151 0.68972966]
  [0.24041855 0.1745123  0.10012628 0.48494287]
  [0.25067527 0.34584997 0.00780675 0.39566801]
  [0.15970671 0.16182343 0.48534736 0.19312251]]

 [[0.19555088 0.2521788  0.52500612 0.02726419]
  [0.42531765 0.42984014 0.09142027 0.05342194]
  [0.168217   0.20068307 0.37483209 0.25626784]
  [0.89585399 0.02164674 0.0471209  0.03537837]
  [0.46866893 0.12929399 0.2779681  0.12406898]]]


array([0.12557359, 0.23657699, 0.13115001, 0.50669941])