In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [7]:
class GridWorld(object):
    """A small rectangular gridworld with obstacles and absorbing cells.

    Locations are ``(row, column)`` tuples. Cells listed in
    ``obstacle_locs`` are not valid locations; every other in-bounds cell is
    a state. States are numbered in row-major order over the valid cells
    (see :meth:`get_topology`).
    """

    # Direction labels in the same order as ``action_names`` (N, E, S, W).
    _DIRECTIONS = ('nr', 'ea', 'so', 'we')

    # (row, column) offset applied to a location for each direction label.
    _DIRECTION_OFFSETS = {
        'nr': (-1, 0),
        'ea': (0, 1),
        'so': (1, 0),
        'we': (0, -1),
    }

    def __init__(self):
        """Set the fixed attributes that define this gridworld."""

        ### Attributes defining the Gridworld #######

        # Shape of the gridworld as (height, width)
        self.shape = (6, 6)

        # Locations of the obstacles
        self.obstacle_locs = [(1, 1), (2, 3), (2, 5), (3, 1), (4, 1), (4, 2), (4, 4)]

        # Locations of the absorbing states
        self.absorbing_locs = [(1, 3), (4, 3)]

        # Rewards for the absorbing states, in the same order as absorbing_locs
        self.special_rewards = [+10, -100]

        # Reward for all the other states
        self.default_reward = -1

        # Starting location
        # TODO: randomize this
        self.starting_loc = None

        # Action names: action 0 is 'N', 1 is 'E' and so on
        self.action_names = ['N', 'E', 'S', 'W']

        # Number of actions
        self.action_size = len(self.action_names)

        # Randomizing action results: [1, 0, 0, 0] means no noise in the
        # action results.
        # NOTE(review): presumably the first entry is the probability of the
        # intended action and the rest are deviations from it -- confirm.
        # TODO
        self.action_randomizing_array = [0.8, 0.1, 0.0, 0.1]

    def loc_to_state(self, loc, locs):
        """Return the index of ``loc`` within the list of locations ``locs``.

        ``loc`` is converted to a tuple before the lookup, so any sequence of
        coordinates is accepted. Raises ``ValueError`` if ``loc`` is not in
        ``locs``.
        """
        return locs.index(tuple(loc))

    def is_location(self, loc):
        """Return True if ``loc`` is inside the grid and not an obstacle."""
        # Normalise to a tuple so list-like locations compare equal to the
        # tuples stored in obstacle_locs (consistent with loc_to_state).
        loc = tuple(loc)
        row, col = loc[0], loc[1]

        if row < 0 or col < 0 or row > self.shape[0] - 1 or col > self.shape[1] - 1:
            return False
        return loc not in self.obstacle_locs

    def get_neighbour(self, loc, direction):
        """Return the location reached by moving one cell from ``loc``.

        ``direction`` is one of ``'nr'``, ``'ea'``, ``'so'``, ``'we'``. If the
        target cell is not a valid location (out of bounds or an obstacle),
        or ``direction`` is not recognised, ``loc`` itself is returned.
        """
        offset = self._DIRECTION_OFFSETS.get(direction)
        if offset is None:
            # Unknown direction: default is to stay where you are
            return loc

        candidate = (loc[0] + offset[0], loc[1] + offset[1])

        # If the neighbour is a valid location accept it, otherwise stay put
        if self.is_location(candidate):
            return candidate
        return loc

    def get_topology(self):
        """Enumerate the valid states and their neighbours.

        Returns:
            locs: list of ``(row, column)`` tuples, one per valid
                (non-obstacle) cell, in row-major order. The position of a
                location in this list is its state index.
            state_neighbours: integer array of shape ``(num_states, 4)``;
                entry ``[s, d]`` is the state index reached from state ``s``
                in direction ``d`` (order: nr, ea, so, we). A blocked move
                maps a state to itself.
            absorbing_states: list of state indices of the absorbing
                locations, in the same order as ``absorbing_locs``.
        """
        height = self.shape[0]
        width = self.shape[1]

        locs = []
        neighbour_locs = []

        for i in range(height):
            for j in range(width):

                # Get the location of each state
                loc = (i, j)

                # Keep it only if it is a valid state, i.e. not an obstacle
                if self.is_location(loc):
                    locs.append(loc)

                    # Neighbours of this state in every direction, as locations
                    local_neighbours = [
                        self.get_neighbour(loc, direction)
                        for direction in self._DIRECTIONS
                    ]
                    neighbour_locs.append(local_neighbours)

        # Translate neighbour lists from locations to state indices
        num_states = len(locs)
        state_neighbours = np.zeros((num_states, len(self._DIRECTIONS)), dtype=int)

        for state in range(num_states):
            for direction in range(len(self._DIRECTIONS)):
                # Get neighbour location
                nloc = neighbour_locs[state][direction]

                # Turn location into state number and insert into the matrix
                state_neighbours[state, direction] = self.loc_to_state(nloc, locs)

        # Translate absorbing locations into state indices
        absorbing_states = [self.loc_to_state(a_loc, locs) for a_loc in self.absorbing_locs]

        return locs, state_neighbours, absorbing_states

In [8]:
# Sanity check for loc_to_state: (1,1) is the fourth entry of locs,
# so the expected state index is 3.
locs = [(0,0),(0,1),(1,0),(1,1)]
loc = (1,1)
GridWorld().loc_to_state(loc,locs)

3

In [10]:
# Sanity check for is_location: (3,3) lies inside the 6x6 grid and is not
# in obstacle_locs, so the expected result is True.
GridWorld().is_location((3,3))

True

In [12]:
# Sanity check for get_neighbour: the cell north of (2,1) is (1,1), which is
# an obstacle, so the move is rejected and the original location is returned.
GridWorld().get_neighbour((2,1),'nr')

(2, 1)