In [3]:
import numpy as np

In [14]:
class grid_world_mdp(object):
    """Tabular MDP over a rectangular grid.

    States are the integers ``0 .. n-1`` (``n = grid_rows * grid_cols``) laid
    out in row-major order, so state ``i`` sits at ``(row, col) =
    divmod(i, grid_cols)``.  Actions are the integers ``0 .. m-1``; for
    ``m == 4`` or ``m == 8`` each action index maps to a ``(d_row, d_col)``
    offset.

    Attributes set by the constructor:
        P: array of shape ``(n, m, n)``; ``P[s, a, s2]`` is the probability
           of landing in ``s2`` after taking ``a`` in ``s``.
        R: array of shape ``(n, m, n)`` of rewards indexed the same way
           (when the default reward function is used).
    """

    def __init__(self, grid_cols, grid_rows, actions, terminal_states=None, prob_func=None, reward_func=None):
        """Build the state/action tables, transitions and rewards.

        Args:
            grid_cols: number of columns in the grid.
            grid_rows: number of rows in the grid.
            actions: number of actions (4 or 8 for the built-in move sets).
            terminal_states: iterable of absorbing state indices, or None
                for no terminal states.
            prob_func: optional callable ``f(mdp) -> P`` replacing the
                default deterministic dynamics.
            reward_func: optional callable ``f(mdp) -> R`` replacing the
                default reward table.

        Raises:
            ValueError: if the default dynamics are requested but
                ``actions`` is neither 4 nor 8.
            AssertionError: if some ``P[s, a, :]`` does not sum to 1.
        """
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols

        self.n = grid_cols * grid_rows
        self.m = actions

        self.states = range(self.n)
        self.actions = range(self.m)

        if self.m == 4:
            # Action of N, S, E, W.
            self.actions_to_idx = {(-1, 0): 0, (1, 0): 1, (0, 1): 2, (0, -1): 3}
        elif self.m == 8:
            # Action of N, NE, NW, S, SE, SW, E, W
            self.actions_to_idx = {(-1, 0): 0, (-1, 1): 1, (-1, -1): 2, (1, 0): 3,
                                   (1, 1): 4, (1, -1): 5, (0, 1): 6, (0, -1): 7}
        else:
            # No built-in move set; only usable together with a custom prob_func.
            self.actions_to_idx = {}

        self.idx_to_actions = {v: k for k, v in self.actions_to_idx.items()}

        # Row-major layout: index = row * grid_cols + col.  divmod keeps the
        # coordinates integral on both Python 2 and Python 3.
        self.idx_to_states = {i: divmod(i, self.grid_cols) for i in self.states}
        self.states_to_idx = {v: k for k, v in self.idx_to_states.items()}

        if terminal_states is None:
            self.terminal_states = []
        else:
            # Copy so later mutation of the caller's sequence has no effect here.
            self.terminal_states = list(terminal_states)

        self.create_prob_dist(prob_func)

        self.create_rewards(reward_func)

        self.check_valid_dist()

    def create_prob_dist(self, prob_func=None):
        """Set ``self.P`` from ``prob_func(self)``, or from the default dynamics if None."""
        if prob_func is None:
            self.get_prob_dist()
        else:
            self.P = prob_func(self)

    def get_prob_dist(self):
        """Fill ``self.P`` with the default deterministic grid dynamics.

        Terminal states are absorbing under every action.  From any other
        state an action moves by its ``(d_row, d_col)`` offset; a move that
        would leave the grid keeps the agent in place.

        Raises:
            ValueError: if no move set is defined for ``self.m`` actions.
        """
        if len(self.idx_to_actions) != self.m:
            raise ValueError('Default dynamics need 4 or 8 actions, got %r' % (self.m,))

        self.P = np.zeros((self.n, self.m, self.n))

        for state in self.states:
            for action in self.actions:

                if state in self.terminal_states:
                    self.P[state, action, state] = 1
                    continue

                row, col = self.idx_to_states[state]
                d_row, d_col = self.idx_to_actions[action]
                new_pos = (row + d_row, col + d_col)

                # Positions outside the grid are absent from states_to_idx,
                # so the fallback is a self-transition.
                new_state = self.states_to_idx.get(new_pos, state)
                self.P[state, action, new_state] = 1

    def create_rewards(self, reward_func=None):
        """Set ``self.R`` from ``reward_func(self)``, or from the default rewards if None."""
        if reward_func is None:
            self.get_rewards()
        else:
            self.R = reward_func(self)

    def get_rewards(self):
        """Fill ``self.R`` with -1 everywhere, and 0 for transitions out of terminal states."""
        self.R = -1 * np.ones((self.n, self.m, self.n))

        for state in self.terminal_states:
            self.R[state] = 0

    def check_valid_dist(self):
        """Assert that every ``P[s, a, :]`` sums to 1 within a 1e-3 tolerance."""
        totals = np.asarray(self.P).sum(axis=2)
        assert np.all(np.abs(totals - 1) < 1e-3), 'Transitions do not sum to 1'
    
    

In [15]:
class RL(object):
    """Tabular solution methods operating on an MDP object.

    The MDP is expected to expose ``n`` (state count), ``m`` (action count),
    ``states`` (iterable of state indices), and arrays ``P`` and ``R``
    indexed as ``[state, action, next_state]``.
    """

    def __init__(self, mdp):
        """Store the MDP that the methods of this object operate on."""
        self.mdp = mdp

    def iterative_policy_evaluation(self, gamma=1, theta=1e-4, max_iter=None):
        """Evaluate the uniform-random policy by in-place value sweeps.

        Each sweep updates states in order, using already-updated values of
        earlier states within the same sweep:

            v[s] = (1/m) * sum_a sum_s2 P[s, a, s2] * (R[s, a, s2] + gamma * v[s2])

        Args:
            gamma: discount factor.
            theta: stop once the largest change of any state value within a
                sweep falls below this threshold.
            max_iter: optional cap on the number of sweeps; None means sweep
                until the ``theta`` criterion is met (which may never happen
                if the values diverge, e.g. gamma=1 with no terminal state).

        Returns:
            Tuple ``(v, policy)`` where ``v`` has shape ``(n, 1)`` and
            ``policy`` is whatever ``get_iterative_policy(v)`` returns.
        """
        P = np.asarray(self.mdp.P, dtype=float)
        R = np.asarray(self.mdp.R, dtype=float)

        # Probability of each action under the uniform-random policy.
        action_prob = 1.0 / self.mdp.m if self.mdp.m else 0.0

        v = np.zeros((self.mdp.n, 1))
        values = v[:, 0]  # 1-D view; writes go straight through to v

        sweeps = 0
        while max_iter is None or sweeps < max_iter:

            delta = 0.0

            for s in self.mdp.states:
                previous = values[s]

                # P[s] and R[s] are (m, n); values broadcasts along next states.
                values[s] = action_prob * np.sum(P[s] * (R[s] + gamma * values))

                delta = max(delta, abs(previous - values[s]))

            sweeps += 1

            if delta < theta:
                break

        policy = self.get_iterative_policy(v)

        return v, policy

    def get_iterative_policy(self, v):
        """Placeholder for deriving a policy from the value estimates ``v``.

        Currently always returns an empty list.
        TODO: not implemented; the intended policy representation is not
        defined anywhere in this notebook.
        """
        return []
        
        

In [24]:
# 4x4 grid with the 8-direction action set; states 0 and 15 (opposite corners) are terminal.
mdp = grid_world_mdp(grid_cols=4, grid_rows=4, actions=8, terminal_states=[0, 15])

In [25]:
# Show which state indices the MDP treats as terminal.
mdp.terminal_states

[0, 15]

In [26]:
# Wrap the grid-world MDP in the solver object.
rl_obj = RL(mdp=mdp)

In [27]:
# Undiscounted evaluation of the uniform-random policy.
v, policy = rl_obj.iterative_policy_evaluation(gamma=1)

In [28]:
# Display the estimated state values: one row per state index, shape (n, 1).
v

array([[  0.        ],
       [-18.28478679],
       [-23.4273588 ],
       [-25.71291862],
       [-18.28481275],
       [-19.42762821],
       [-22.28462841],
       [-23.42741633],
       [-23.42741373],
       [-22.28465361],
       [-19.42767162],
       [-18.28487552],
       [-25.71299944],
       [-23.42746862],
       [-18.28489896],
       [  0.        ]])