In [179]:
from abc import ABC, abstractmethod
import random
import numpy as np
import copy

In [18]:
class DiscreteEnvironment(ABC):
    """Interface that every discrete-action environment in this notebook implements.

    Subclasses must provide both ``step`` and ``reset``; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def step(self, action):
        """Advance the environment by applying ``action``.

        The base implementation does nothing and returns ``None``; the return
        contract is defined by the concrete subclass.
        """

    @abstractmethod
    def reset(self):
        """Put the environment back into its initial configuration.

        The base implementation does nothing and returns ``None``.
        """
    
class DiscreteAgent(ABC):
    """Interface for agents that act in a discrete environment.

    Subclasses must override the constructor, ``update`` and ``get_action``
    before they can be instantiated.
    """

    @abstractmethod
    def __init__(self, env):
        """Bind the agent to ``env``; the base implementation stores nothing."""

    @abstractmethod
    def update(self):
        """Perform one learning/planning update; the base version is a no-op."""

    @abstractmethod
    def get_action(self, state):
        """Choose an action for ``state``; the base version returns ``None``."""

In [376]:
class GridWorld(DiscreteEnvironment):
    """Deterministic rectangular grid with one goal cell.

    States are ``(row, col)`` tuples with ``0 <= row < maxrow`` and
    ``0 <= col < maxcol``. Four actions are available:

    ====== ==========
    action effect
    ====== ==========
    0      row - 1
    1      row + 1
    2      col - 1
    3      col + 1
    ====== ==========

    Rewards: ``+10`` for entering (or already standing on) the goal, ``-10``
    for trying to leave the grid, ``0`` otherwise. Both the goal and an
    off-grid attempt end the episode (``done`` is ``True``).

    Parameters
    ----------
    r, c : int
        Number of rows and columns.
    start : tuple of int, optional
        Start cell; defaults to ``(0, 0)``.
    goal : tuple of int, optional
        Goal cell; defaults to ``(7, 7)`` clamped to the bottom-right corner
        of grids smaller than 8x8 so the default goal is always on the grid.

    Raises
    ------
    ValueError
        If an explicitly supplied ``start`` or ``goal`` is outside the grid.
    """

    # (row delta, col delta) applied by each action id.
    _MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, r, c, start=None, goal=None):
        self.maxrow = r
        self.maxcol = c
        self.tot_act = 4
        self.actions = [0, 1, 2, 3]

        if start is None:
            self._start_cell = (0, 0)
        else:
            self._start_cell = self._checked_cell(start, 'start')

        if goal is None:
            # Historical default is (7, 7); clamp it so that smaller grids
            # still get a goal that can actually be reached.
            self._goal_cell = (min(7, r - 1), min(7, c - 1))
        else:
            self._goal_cell = self._checked_cell(goal, 'goal')

        self.reset()

    def _checked_cell(self, cell, label):
        """Return ``cell`` as a ``(row, col)`` tuple or raise if it is off-grid."""
        row, col = cell
        if self.foul_state(row, col):
            raise ValueError(
                '{} cell {!r} lies outside the {}x{} grid'.format(
                    label, (row, col), self.maxrow, self.maxcol
                )
            )
        return (row, col)

    def foul_state(self, row, col):
        """Return 1 if ``(row, col)`` is outside the grid, otherwise 0."""
        inside = 0 <= row < self.maxrow and 0 <= col < self.maxcol
        return 0 if inside else 1

    def step(self, action):
        """Apply ``action`` from the current cell.

        Returns
        -------
        (state, reward, done)
            ``state`` is the cell occupied after the move. When the move would
            leave the grid the agent stays where it was, so ``state`` is the
            previous cell. An action id that is not one of ``self.actions``
            leaves the agent in place.
        """
        here = (self.navigate_row, self.navigate_col)
        goal = (self.goal_row, self.goal_col)

        # The goal is absorbing: every step taken from it pays out again.
        if here == goal:
            return here, 10, True

        d_row, d_col = self._MOVES.get(action, (0, 0))
        target = (here[0] + d_row, here[1] + d_col)

        if self.foul_state(*target):
            # Do not commit the move: the internal position must stay
            # consistent with the state handed back to the caller.
            return here, -10, True

        self.navigate_row, self.navigate_col = target
        if target == goal:
            return target, 10, True
        return target, 0, False

    def set_start_state(self, state):
        """Place the agent on ``state`` (a ``(row, col)`` pair).

        Used by the dynamic-programming agents, which pick the start state
        iteratively while sweeping over the grid.
        """
        self.navigate_row = state[0]
        self.navigate_col = state[1]

    def reset(self):
        """Restore the configured start/goal cells and move the agent to the start.

        Returns
        -------
        tuple of int
            The start state ``(row, col)``.
        """
        self.start_row, self.start_col = self._start_cell
        self.goal_row, self.goal_col = self._goal_cell
        self.navigate_row = self.start_row
        self.navigate_col = self.start_col
        return (self.navigate_row, self.navigate_col)

In [360]:
class ValueIteration(DiscreteAgent):
    """Tabular value iteration using the environment as a one-step model.

    The environment must expose ``maxrow``, ``maxcol``, ``tot_act``,
    ``actions``, ``set_start_state(state)`` and ``step(action)`` returning
    ``(new_state, reward, done)``.

    Attributes
    ----------
    gamma : float
        Discount factor (0.9).
    Q : numpy.ndarray
        Action values, shape ``(maxrow, maxcol, tot_act)``.
    V : numpy.ndarray
        State values, shape ``(maxrow, maxcol)``.
    """

    def __init__(self, env):
        self.env = env
        self.gamma = 0.9
        self.Q = np.zeros((env.maxrow, env.maxcol, env.tot_act))
        self.V = np.zeros((env.maxrow, env.maxcol))

    def update(self):
        """Run one in-place sweep of Bellman optimality backups over all states.

        The ``done`` flag returned by the environment is deliberately not
        used: the backup always bootstraps from ``V[new_state]``, i.e. the
        environment's terminal transitions are treated as absorbing cells.

        Returns
        -------
        float
            Largest absolute change made to ``V`` during the sweep, which
            callers may use as a convergence measure.
        """
        # Use the environment this agent was built with, not whatever the
        # name ``env`` happens to be bound to in the notebook namespace.
        env = self.env
        largest_change = 0.0
        # np.ndindex walks (row, col) in the same row-major order as two
        # nested range loops would.
        for state in np.ndindex(self.V.shape):
            for action in env.actions:
                # step() moves the agent, so re-seat it before every probe.
                env.set_start_state(state)
                new_state, reward, _done = env.step(action)
                self.Q[state][action] = reward + self.gamma * self.V[new_state]
            best = np.max(self.Q[state])
            largest_change = max(largest_change, abs(best - self.V[state]))
            self.V[state] = best
        return float(largest_change)

    def get_action(self, state=None):
        """Return the greedy action index for ``state`` under the current ``Q``.

        Parameters
        ----------
        state : tuple of int, optional
            ``(row, col)`` cell. When omitted, the environment's current
            position (``navigate_row``, ``navigate_col``) is used.

        Returns
        -------
        int
            Index of the highest-valued action; ties resolve to the lowest
            index.
        """
        if state is None:
            state = (self.env.navigate_row, self.env.navigate_col)
        return int(np.argmax(self.Q[tuple(state)]))

In [438]:
class PolicyIteration(DiscreteAgent):
    """Tabular policy iteration using the environment as a one-step model.

    The environment must expose ``maxrow``, ``maxcol``, ``tot_act``,
    ``actions``, ``set_start_state(state)`` and ``step(action)`` returning
    ``(new_state, reward, done)``.

    Attributes
    ----------
    gamma : float
        Discount factor (0.9).
    Q : numpy.ndarray
        Action values, shape ``(maxrow, maxcol, tot_act)``.
    V : numpy.ndarray
        State values, shape ``(maxrow, maxcol)``.
    policy : numpy.ndarray
        Action index per cell, initialised uniformly at random.
    policy_stable : bool
        ``True`` once an improvement pass changed no action.
    """

    def __init__(self, env):
        self.env = env
        self.gamma = 0.9
        self.Q = np.zeros((env.maxrow, env.maxcol, env.tot_act))
        self.V = np.zeros((env.maxrow, env.maxcol))
        self.policy = np.random.randint(env.tot_act, size=(env.maxrow, env.maxcol))
        self.policy_stable = False

    def update(self):
        """Alternate evaluation and improvement until the policy stops changing."""
        while not self.policy_stable:
            self.policy_evaluation()
            self.policy_improvement()

    def _backup(self, state, action):
        """Store and return the one-step lookahead value of ``action`` in ``state``.

        The environment's ``done`` flag is not used; the backup always
        bootstraps from ``V[new_state]``.
        """
        # Use the environment this agent was built with, not whatever the
        # name ``env`` happens to be bound to in the notebook namespace.
        env = self.env
        # step() moves the agent, so re-seat it before every probe.
        env.set_start_state(state)
        new_state, reward, _done = env.step(action)
        value = reward + self.gamma * self.V[new_state]
        self.Q[state][action] = value
        return value

    def policy_evaluation(self, eps=1e-10):
        """Sweep in place until ``V`` matches the current policy.

        Parameters
        ----------
        eps : float, optional
            Sweeping stops once the largest per-state change in a full sweep
            drops below this threshold.
        """
        while True:
            delta = 0
            for state in np.ndindex(self.V.shape):
                previous = self.V[state]
                self.V[state] = self._backup(state, self.policy[state])
                delta = max(delta, np.abs(previous - self.V[state]))

            if delta < eps:
                break

    def policy_improvement(self):
        """Make the policy greedy with respect to one-step lookahead values.

        Besides updating ``policy``, each cell's ``V`` entry is overwritten
        with its best lookahead value, which becomes the starting point of the
        next evaluation round. ``policy_stable`` is left ``True`` only if no
        cell changed its action.
        """
        self.policy_stable = True
        for state in np.ndindex(self.V.shape):
            for action in self.env.actions:
                self._backup(state, action)
            self.V[state] = np.max(self.Q[state])
            greedy = np.argmax(self.Q[state])
            if greedy != self.policy[state]:
                self.policy[state] = greedy
                self.policy_stable = False

    def get_action(self, state=None):
        """Return the action the current policy prescribes for ``state``.

        Parameters
        ----------
        state : tuple of int, optional
            ``(row, col)`` cell. When omitted, the environment's current
            position (``navigate_row``, ``navigate_col``) is used.

        Returns
        -------
        int
            Action index stored in ``policy`` for that cell.
        """
        if state is None:
            state = (self.env.navigate_row, self.env.navigate_col)
        return int(self.policy[tuple(state)])

## Value Iteration

In [362]:
# Build the 8x8 grid and report its start and goal cells.
env = GridWorld(8, 8)
print('start = ', env.start_row, ' ', env.start_col)
print('goal = ', env.goal_row, ' ', env.goal_col)
agent = ValueIteration(env)

# Run a fixed budget of full sweeps; the counter starts at 0 and ends equal
# to max_sweeps, exactly as a manual while-counter would leave it.
max_sweeps = 10000
sweep_no = 0
for sweep_no in range(1, max_sweeps + 1):
    agent.update()

start =  0   0
goal =  7   7


In [363]:
# Show the state-value table produced by the sweeps above (one entry per grid cell).
agent.V

array([[ 32.87679245,  35.41865828,  38.24295365,  41.38105961,
         44.86784401,  48.7420489 ,  53.046721  ,  57.82969   ],
       [ 35.41865828,  38.24295365,  41.38105961,  44.86784401,
         48.7420489 ,  53.046721  ,  57.82969   ,  63.1441    ],
       [ 38.24295365,  41.38105961,  44.86784401,  48.7420489 ,
         53.046721  ,  57.82969   ,  63.1441    ,  69.049     ],
       [ 41.38105961,  44.86784401,  48.7420489 ,  53.046721  ,
         57.82969   ,  63.1441    ,  69.049     ,  75.61      ],
       [ 44.86784401,  48.7420489 ,  53.046721  ,  57.82969   ,
         63.1441    ,  69.049     ,  75.61      ,  82.9       ],
       [ 48.7420489 ,  53.046721  ,  57.82969   ,  63.1441    ,
         69.049     ,  75.61      ,  82.9       ,  91.        ],
       [ 53.046721  ,  57.82969   ,  63.1441    ,  69.049     ,
         75.61      ,  82.9       ,  91.        , 100.        ],
       [ 57.82969   ,  63.1441    ,  69.049     ,  75.61      ,
         82.9       ,  91.       

## Policy Iteration

In [439]:
# Fresh 8x8 grid for the policy-iteration run; report its start and goal cells.
env = GridWorld(8,8)
print('start = ',env.start_row,' ',env.start_col)
print('goal = ',env.goal_row,' ',env.goal_col)
agent  = PolicyIteration(env)
# Disabled debug aids: inspect the initial tables and the random starting policy.
#print(agent.V)
#print(agent.Q)
#print(agent.policy)
# Alternates policy evaluation and improvement until the policy is stable.
agent.update()

start =  0   0
goal =  7   7


In [440]:
# Action ids as used by GridWorld.step: 0 = row-1, 1 = row+1, 2 = col-1, 3 = col+1.
print(agent.policy)

[[1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1]
 [3 3 3 3 3 3 3 0]]
