### Rational agent in a 3x5 grid world formulated as a Markov Decision Process

Description: the agent's goal is to gather resources by moving one block at a time in any of the four directions of the grid (north, east, south, west). Movement is stochastic. The agent moves in the intended direction 80% of the time and slips to either side of it (perpendicular to the intended direction) with a probability of 10% each. When a move would take it into a boundary wall, the agent stays where it is. The grid contains several types of resources and hazards, and each is assigned a reward value: reaching a block with iron (a common resource) or diamonds (a rare resource) yields a positive reward, while falling into a pit (which damages the agent) or stepping into lava (which kills the agent) yields a negative one. Resources and hazards are modeled as terminal goal states in this MDP: once the agent enters one, it makes no further moves.
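The movement model above can be expressed generically by deriving the two slip directions from the intended one. This is a minimal sketch, assuming (row, column) coordinates with row 0 on the north edge; `MOVES`, `PERPENDICULAR`, `step` and `slip_transitions` are hypothetical names that do not appear in this notebook, whose transition function `T` hand-codes each direction instead.

```python
# Hypothetical sketch of the 80/10/10 "slip" movement model on a 3x5 grid.
ROWS, COLS = 3, 5

# Unit moves as (d_row, d_col); row 0 is the north edge of the grid.
MOVES = {'n': (-1, 0), 'e': (0, 1), 's': (1, 0), 'w': (0, -1)}

# (left, right) relative to the intended direction of travel.
PERPENDICULAR = {'n': ('w', 'e'), 'e': ('n', 's'), 's': ('e', 'w'), 'w': ('s', 'n')}


def step(pos, direction):
    """Move one block; bumping into a boundary wall leaves the agent in place."""
    dr, dc = MOVES[direction]
    r, c = pos[0] + dr, pos[1] + dc
    if 0 <= r < ROWS and 0 <= c < COLS:
        return (r, c)
    return pos


def slip_transitions(pos, action):
    """Return {next_position: probability} for taking `action` at `pos`."""
    left, right = PERPENDICULAR[action]
    outcomes = [(step(pos, action), 0.8), (step(pos, left), 0.1), (step(pos, right), 0.1)]
    dist = {}
    for nxt, p in outcomes:
        # Outcomes that land on the same block (e.g. two wall bumps) are merged.
        dist[nxt] = dist.get(nxt, 0.0) + p
    return dist


# Heading north from the top-left corner: the north and west moves both bounce back.
print({k: round(v, 2) for k, v in slip_transitions((0, 0), 'n').items()})
# {(0, 0): 0.9, (0, 1): 0.1}
```

Merging duplicate outcomes keeps each result a proper probability distribution that sums to 1, even in corners where two of the three moves hit a wall.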

Goal: find the optimal policy for the agent using the policy iteration algorithm.
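Policy iteration alternates two steps until the policy stops changing. Written out with $T(s,a,s')$ as the transition probability, $R(s,a,s')$ as the reward for entering $s'$ and $\gamma$ as the discount factor (notation chosen here to match the `T`, `R` and `gamma` used in the code), policy evaluation repeats the update

$$V_{k+1}^{\pi_t}(s) = \sum_{s'} T(s,\pi_t(s),s')\,\big[R(s,\pi_t(s),s') + \gamma\,V_k^{\pi_t}(s')\big]$$

and policy improvement then acts greedily with respect to the evaluated values:

$$\pi_{t+1}(s) = \arg\max_{a \in \{n,e,s,w\}} \sum_{s'} T(s,a,s')\,\big[R(s,a,s') + \gamma\,V^{\pi_t}(s')\big]$$

Terminal states have no outgoing transitions, so their value stays at 0, and the algorithm stops once $\pi_{t+1} = \pi_t$.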

In [1]:
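# Utility function that formats a single grid cell: strings (cell labels or
# policy actions) are padded to width 2, numbers get two decimals plus a
# leading space when non-negative so that the columns line up
def format_value(v):
    if isinstance(v, str):
        return '{:2s}'.format(v)
    return (' ' if v >= 0 else '') + '{:.2f}'.format(float(v))

# Utility function that pretty prints the grid world (or a value/policy grid)
def print_grid(grid):
    for row in grid:
        print('|'.join(format_value(v) for v in row))
    print()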

In [9]:
# MDP definition: a 3x5 grid indexed as grid[row][col], row 0 is the north edge
grid = [['' for j in range(5)] for i in range(3)]
grid[0][2] = 'I' # iron
grid[0][3] = 'L' # lava
grid[0][4] = 'D' # diamond
grid[1][2] = 'L' # lava
grid[1][3] = 'L' # lava
gamma = 0.9 # discount factor
rewards = {'D': 10, 'I': 3, 'L': -10, 'P': -3, '': 0} # reward for entering each block type ('P' = pit, not placed in this layout)
actions = ['n', 'e', 's', 'w'] # north, east, south, west
# print the grid world
print_grid(grid)

  |  |I |L |D 
  |  |L |L |  
  |  |  |  |  
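Because the values start at zero, the first value update under a policy just computes the expected immediate reward. For example, the initial policy (defined further below) moves north from the empty block at row 1, column 4 (0-indexed, directly below the diamond): the agent reaches the diamond with probability 0.8, bumps the east wall and stays put with probability 0.1, and slips west into lava with probability 0.1:

$$V_1(1,4) = 0.8 \cdot 10 + 0.1 \cdot 0 + 0.1 \cdot (-10) = 7$$

This is the 7.00 printed for that block in the first value update. Since the only non-terminal successor is the block itself, repeated updates converge to the fixed point of $V = 7 + 0.9 \cdot 0.1\,V$, i.e. $V = 7/0.91 \approx 7.69$.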



In [3]:
# Reward function: the reward depends only on the block the agent lands on
def R(s, a, s_new):
    return rewards[grid[s_new[0]][s_new[1]]]

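# Transition function: returns {next_state: probability} for taking action a in
# state s. The intended move succeeds with probability 0.8, and the agent slips
# to each of the two perpendicular directions with probability 0.1. A move into
# a boundary wall leaves the agent where it is.
def T(s, a):
    (y, x) = s # grid location (row, col) from state
    if grid[y][x] != '':
        return {} # terminal state (resource or hazard): no outgoing transitions
    if a == 'n': # north
        P = [
            ((max(0, y-1), x), 0.8),
            ((y, min(4, x+1)), 0.1),
            ((y, max(0, x-1)), 0.1)
        ]
    elif a == 'e': # east
        P = [
            ((y, min(4, x+1)), 0.8),
            ((max(0, y-1), x), 0.1),
            ((min(2, y+1), x), 0.1)
        ]
    elif a == 's': # south
        P = [
            ((min(2, y+1), x), 0.8),
            ((y, max(0, x-1)), 0.1),
            ((y, min(4, x+1)), 0.1)
        ]
    elif a == 'w': # west
        P = [
            ((y, max(0, x-1)), 0.8),
            ((min(2, y+1), x), 0.1),
            ((max(0, y-1), x), 0.1)
        ]
    else:
        raise ValueError('unknown action: {}'.format(a))

    # Merge outcomes that land on the same block (e.g. two wall bumps)
    ret = {}
    for k, p in P:
        ret[k] = p + ret.get(k, 0)
    return ret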

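# Q-value calculation function: expected reward plus discounted value of the
# successor state for taking action a in state s and following V afterwards
def q_value(V, s, a, gamma):
    return sum(
        p * (R(s, a, sn) + gamma * V[sn[0]][sn[1]])
        for sn, p in T(s, a).items()
    )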

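# Value update function: one sweep of iterative policy evaluation for a fixed
# policy P, where P[i][j] is the action taken in block (i, j)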
def value_update(V, P, gamma):
    return [
        [ q_value(V, (i,j), P[i][j], gamma)
            for j in range(len(V[i]))
        ] for i in range(len(V))
    ]

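# Policy update function: greedy policy improvement with respect to V. Since
# max() compares (q, a) tuples, ties go to the alphabetically last action letter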
def policy_update(V, gamma):
    return [
        [ max((q_value(V, (i,j), a, gamma),a) for a in actions)[1]
            for j in range(len(V[i]))
        ] for i in range(len(V))
    ]
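`value_update` evaluates a fixed policy iteratively, and the policy iteration loop truncates that evaluation after `max_k = 20` sweeps or once the largest change drops below `conv_threshold`. Because the policy is fixed during evaluation, $V^\pi$ can instead be obtained exactly by solving the linear system $(I - \gamma P_\pi)V = r_\pi$. The following is a minimal, dependency-free sketch of that alternative (a swapped-in technique, not the method used in this notebook), assuming the same 3x5 layout, rewards and 80/10/10 slip model; `step`, `solve` and `evaluate_policy_exact` are hypothetical helpers.

```python
# Hypothetical sketch: exact policy evaluation by solving (I - gamma*P_pi) V = r_pi.
ROWS, COLS, GAMMA = 3, 5, 0.9
GRID = [['', '', 'I', 'L', 'D'],
        ['', '', 'L', 'L', ''],
        ['', '', '', '', '']]
REWARDS = {'D': 10, 'I': 3, 'L': -10, 'P': -3, '': 0}
MOVES = {'n': (-1, 0), 'e': (0, 1), 's': (1, 0), 'w': (0, -1)}
SLIPS = {'n': 'ew', 'e': 'ns', 's': 'ew', 'w': 'ns'}  # perpendicular directions


def step(r, c, d):
    """One-block move; a boundary wall leaves the agent in place."""
    dr, dc = MOVES[d]
    nr, nc = r + dr, c + dc
    return (nr, nc) if 0 <= nr < ROWS and 0 <= nc < COLS else (r, c)


def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]  # augmented matrix
    for col in range(n):
        piv = max(range(col, n), key=lambda i: abs(M[i][col]))
        M[col], M[piv] = M[piv], M[col]
        for i in range(col + 1, n):
            f = M[i][col] / M[col][col]
            for j in range(col, n + 1):
                M[i][j] -= f * M[col][j]
    x = [0.0] * n
    for i in reversed(range(n)):  # back substitution
        x[i] = (M[i][n] - sum(M[i][j] * x[j] for j in range(i + 1, n))) / M[i][i]
    return x


def evaluate_policy_exact(policy):
    """Return V^pi as a ROWS x COLS list of lists for the given policy grid."""
    n = ROWS * COLS
    A = [[float(i == j) for j in range(n)] for i in range(n)]  # starts as identity
    b = [0.0] * n  # expected immediate reward per state
    for r in range(ROWS):
        for c in range(COLS):
            if GRID[r][c] != '':
                continue  # terminal block: identity row and zero reward, so V = 0
            i = r * COLS + c
            a = policy[r][c]
            for d, p in [(a, 0.8), (SLIPS[a][0], 0.1), (SLIPS[a][1], 0.1)]:
                nr, nc = step(r, c, d)
                A[i][nr * COLS + nc] -= GAMMA * p
                b[i] += p * REWARDS[GRID[nr][nc]]
    V = solve(A, b)
    return [V[r * COLS:(r + 1) * COLS] for r in range(ROWS)]
```

For the initial policy, the block below the diamond evaluates to $7/0.91 \approx 7.69$, the value the iterative updates settle on. The matrix $I - \gamma P_\pi$ is strictly diagonally dominant for $\gamma < 1$, so the system always has a unique solution; the $O(n^3)$ solve is trivial for 15 states, which is also why iterative evaluation tends to be preferred for much larger state spaces.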

In [10]:
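# Initial policy: one action per block (entries on terminal blocks are ignored)
Pi = {0: [
    ['n', 'e', 'n', 'n', 'n'],
    ['s', 'e', 'n', 'n', 'n'],
    ['e', 'n', 'n', 'w', 'n'],
]}
max_iters = 10 # max number of policy iterations (evaluation + improvement)
max_k = 20 # max number of value updates per policy evaluation
conv_threshold = 0.1 # stop evaluating once the largest value change is below this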

# Policy iteration loop
for t in range(max_iters):
    # Policy evaluation: start from V = 0 and apply value updates for policy Pi[t]
    Vk = {0: [[0 for j in range(len(grid[i]))] for i in range(len(grid))]}
    lastk = 0
    for k in range(max_k):
        Vk[k+1] = value_update(Vk[k], Pi[t], gamma)
        lastk = k+1
        # Largest absolute change of any block's value in this update
        delta = max(
            abs(Vk[k+1][i][j] - Vk[k][i][j])
            for i in range(len(Vk[k]))
            for j in range(len(Vk[k][i]))
        )
        print("V(iter={},k={}): (Delta={:.3f})".format(t+1, k+1, delta))
        print_grid(Vk[k+1])
        print()
        if delta < conv_threshold:
            break

    # Policy improvement: act greedily with respect to the evaluated values
    Pi[t+1] = policy_update(Vk[lastk], gamma)
    print("Pi(iter={})".format(t+1))
    print_grid(Pi[t+1])
    print("-    -    -   -   -   -   -   -") # end of current iteration
    print()
    # Stop once the policy no longer changes (the policy is stable)
    if Pi[t+1] == Pi[t]:
        break

V(iter=1,k=1): (Delta=8.000)
 0.00| 2.40| 0.00| 0.00| 0.00
 0.00|-8.00| 0.00| 0.00| 7.00
 0.00| 0.00|-8.00|-1.00| 0.00


V(iter=1,k=2): (Delta=6.480)
 0.22| 1.90| 0.00| 0.00| 0.00
-0.72|-7.78| 0.00| 0.00| 7.63
 0.00|-6.48|-8.09|-6.85| 4.95


V(iter=1,k=3): (Delta=4.730)
 0.35| 1.87| 0.00| 0.00| 0.00
-0.77|-8.41| 0.00| 0.00| 7.69
-4.73|-6.33|-9.20|-7.44| 5.32


V(iter=1,k=4): (Delta=3.467)
 0.45| 1.81| 0.00| 0.00| 0.00
-4.23|-8.40| 0.00| 0.00| 7.69
-5.05|-7.31|-9.24|-8.29| 5.34


V(iter=1,k=5): (Delta=1.045)
 0.53| 1.81| 0.00| 0.00| 0.00
-4.78|-8.49| 0.00| 0.00| 7.69
-6.10|-7.34|-9.40|-8.40| 5.27


V(iter=1,k=6): (Delta=0.810)
 0.59| 1.80| 0.00| 0.00| 0.00
-5.59|-8.50| 0.00| 0.00| 7.69
-6.26|-7.51|-9.42|-8.53| 5.26


V(iter=1,k=7): (Delta=0.214)
 0.64| 1.80| 0.00| 0.00| 0.00
-5.78|-8.51| 0.00| 0.00| 7.69
-6.47|-7.53|-9.44|-8.55| 5.24


V(iter=1,k=8): (Delta=0.173)
 0.68| 1.80| 0.00| 0.00| 0.00
-5.95|-8.52| 0.00| 0.00| 7.69
-6.52|-7.56|-9.45|-8.57| 5.24


V(iter=1,k=9): (Delta=0.051)
 0.