In [54]:
import numpy as np
import random

class BridgeMdp:
    def __init__(self, nc, nr, slip):
        self.nc = nc
        self.nr = nr
        self.slip = slip
    
    def states(self):
        """
        Returns the states of the system as integers.
        """
        return np.arange(1, (self.nc * self.nr) + 1)
    
    def get_block(self, s):
        """
        Given the state integer, return its co-ordinates.
        """
        x = (s - 1) // self.nc + 1
        y = (s - 1) % self.nc + 1
        
        return x, y
    
    def get_state_no(self, x, y):
        """
        Given co-ordinates, return the integer representation.
        """
        return ((x - 1) * self.nc) + y
    
    def get_actions(self, s):
        """
        Given the integer state, return all the actions that can be taken at that point.
        """
        x_co, y_co = self.get_block(s)
        
        actions = []
        if y_co - 1 >= 1:
            actions.append('left')
        if y_co <= self.nc:
            actions.append('right')
        if x_co - 1 >= 1:
            actions.append('up')
        if x_co <= self.nr:
            actions.append('down')
        
        return actions
    
    def failure_states(self):
        """
        Return the failure state of the system. The failure states are the integer representations.
        """
        return [3, 7]
    
    def success_states(self):
        """
        Return the success state of the system. The success states are the integer representations.
        """
        return [4]
    
    def is_goal(self, s):
        """
        Check the game over condition of the MDP.
        """
        if s in self.failure_states():
            return True
        if s in self.success_states():
            return True
        return False
    
    def reward(self, s, a, s_prime):
        """
        Check for the reward of the transition from s to s_prime using action a.
        """
        if s_prime in self.failure_states():
            return -50    # penalty for states
        if s_prime == 9:
            return 2    # the hard-coded state for minimal reward (3,1)
        if s_prime in self.success_states():
            return 20
        
        return 0   # the default reward
    
    def transition(self, s, a):
        pass
    
    def is_river(self,x,y):
            # neighbor configuration is left, right, up, down respectively
            neighbors = [self.get_state_no(x,y-1), self.get_state_no(x,y+1), self.get_state_no(x-1,y),self.get_state_no(x+1,y)]
            
            for state in neighbors:
                if state in self.failure_states():
                    return True
            
            return False
    
    def get_co_ord(self, a,x,y):
        if a == 'left':
            return x,y-1
        if a == 'right':
            return x, y+1
        if a == 'up':
            return x-1,y
        if a == 'down':
            return x+1,y

    def transition_probability(self, s, a, s_prime):
        """
        Return the probability of the new state 's_prime' given the action 'a' and state 's'.
        """
        x, y = self.get_block(s)
        river = self.is_river(x,y)
        
        if not river:   # it is not standing at the edge of the river
            x1, y1 = self.get_co_ord(a,x,y)
            new_s = self.get_state_no(x1,y1)
            if new_s == s_prime:
                return 1      # 60% chance of a valid transition
            
            return 0.0          # invalid transition
        else:
            # process the river neighbor case
            x1, y1 = self.get_co_ord(a,x,y)
            new_s = self.get_state_no(x1,y1)
            
            if new_s == s_prime:                            # Transitioning to a non-failure state
                return 0.6                                    # 60% chance of a valid transition
            elif s_prime in self.failure_states():                # Transitioning into the river
                return 0.4                                    # 40% chance of slipping into the river
            else:
                return 0.0  # Invalid transition        

In [55]:
import math

def policy_evaluation(mdp, gamma=0.05, epsilon=0.2):
    result_dict = {}  
    for state in mdp.states():
        result_dict[state] = 0

    for i in range(100):
        delta = 0
        for state in mdp.states():
            current = 0
            for action in mdp.get_actions(state):
                for s_prime in mdp.states():
                    current += mdp.transition_probability(state, action, s_prime) * \
                        (mdp.reward(state, action, s_prime) + result_dict[s_prime])

                    delta = max(delta, abs(current - result_dict[s_prime]))

                    result_dict[state] = math.ceil(current)
        if delta < epsilon:
            break
    temp = {}   # convert large arithmetic integers to scientific notation
    for k,v in result_dict.items():
        v = '{:e}'.format(v)
        temp[k] = v
    
    return temp

In [56]:
import math

def value_iteration(mdp, gamma=0.05, epsilon=0.2):
    value_dict = {}  # Initialize the value function dictionary
    for state in mdp.states():
        value_dict[state] = 0

    while True:
        delta = 0
        for state in mdp.states():
            v = value_dict[state]
            max_value = float("-inf")

            for action in mdp.get_actions(state):
                action_value = 0
                for s_prime in mdp.states():
                    action_value += mdp.transition_probability(state, action, s_prime) * \
                        (mdp.reward(state, action, s_prime) + gamma * value_dict[s_prime])

                max_value = max(max_value, action_value)

            value_dict[state] = max_value
            delta = max(delta, abs(v - value_dict[state]))

        if delta < epsilon:
            break

    return value_dict

In [59]:
mdp = BridgeMdp(nc=4,nr=3,slip=0.4)
# print(mdp.get_block(8))
# print(mdp.get_state_no(3,1))
# print(mdp.get_actions(1))
# print(mdp.transition_probability(2,'down',3))
print(value_iteration(mdp))

{1: 0.10025, 2: -41.45556450000001, 3: -30.699460800000004, 4: -41.400651216, 5: 2.0050125, 6: -41.400650841, 7: -42.38178504, 8: -30.703644453280003, 9: 0.100250625, 10: 2.00501253125, 11: -41.4014745408625, 12: 0.0}
