## Using Dynamic Programming to solve a gridworld problem 

In [1]:
import numpy as np
import matplotlib.pyplot as plt 

In [2]:
# 4x4 array whose cells hold the state numbers 0..15 in row-major order;
# a state's (row, col) position is looked up by searching this array
gridworld = np.arange(16).reshape((4, 4))
gridworld

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

In [3]:
r = -1  # reward of -1 on every transition
actions = ['u', 'd', 'l', 'r']  # up, down, left, right
values = dict.fromkeys(gridworld.flatten(), 0)  # every state's value starts at 0
print(values)

# equiprobable random policy: each action has probability 1/4 in every state
# except 0 and 15, which get no policy entry at all
policy = {
    s: {a: 1/4 for a in actions}
    for s in gridworld.flatten()
    if s not in [0, 15]
}
policy

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0}


{1: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 2: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 3: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 4: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 5: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 6: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 7: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 8: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 9: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 10: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 11: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 12: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 13: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 14: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25}}

In [4]:
def transition(s, a, grid=None):
    """
    get s' when the agent is currently in s and takes action a

    s (int, the current state number)
    a (str, one of the actions: 'u', 'd', 'l', 'r')
    grid (2-D array of state numbers, optional; when omitted the
          notebook-level `gridworld` array is used)

    return s' (the next state number, read from the grid); a move that
    would leave the grid keeps the agent in its current cell

    raises ValueError if `a` is not one of the four actions
    raises IndexError if `s` does not appear in the grid
    """
    if grid is None:
        grid = gridworld
    grid = np.asarray(grid)
    n_rows, n_cols = grid.shape  # bounds come from the grid, not a fixed 4x4

    # (row, col) offset applied by each action
    moves = {'u': (-1, 0), 'd': (1, 0), 'l': (0, -1), 'r': (0, 1)}
    if a not in moves:
        # fail loudly: silently staying put would corrupt the value backups
        raise ValueError(f'undefined action: {a!r}')

    rows, cols = np.where(grid == s)
    if rows.size == 0:
        raise IndexError(f'state {s!r} not found in grid')

    d_row, d_col = moves[a]
    # clamp to the grid so that bumping into a wall leaves the state unchanged
    new_row = min(max(rows[0] + d_row, 0), n_rows - 1)
    new_col = min(max(cols[0] + d_col, 0), n_cols - 1)
    return grid[new_row, new_col]

In [5]:
# policy evaluation 
def evaluate(s):
    """
    one-step expected backup of state s under the current policy

    For every action a, looks up the successor s' = transition(s, a) and
    weights (r + values[s']) by policy[s][a]; the weighted terms are summed.
    Reads the notebook-level `actions`, `policy`, `r` and `values`; the
    result is returned, not stored.

    s (int, a state number that has an entry in `policy`)

    return the backed-up value of s
    """
    weighted_backups = []
    for action in actions:
        successor = transition(s, action)
        weighted_backups.append(policy[s][action] * (r + values[successor]))
    return sum(weighted_backups)

# policy improvement
def policy_improvement():
    """
    make the policy greedy with respect to the current state values

    For each state in `policy`, the value of the successor reached by every
    action is collected; the actions whose successor value equals the maximum
    share probability equally and all other actions get probability 0.
    Updates the notebook-level `policy` in place and returns nothing.
    """
    for state, action_probs in policy.items():
        successor_values = np.array(
            [values[transition(state, action)] for action in actions]
        )
        # indices of every action tied for the best successor value
        greedy = np.flatnonzero(successor_values == successor_values.max())
        share = 1/len(greedy)
        for idx, action in enumerate(actions):
            action_probs[action] = share if idx in greedy else 0

In [6]:
# alternate one evaluation sweep with one greedy policy improvement, 20 times
for k in range(20):
    # states are updated one at a time, so later states in the same sweep
    # already see the values written earlier in that sweep
    for s in policy:
        values[s] = evaluate(s)
    value_grid = np.reshape(list(values.values()), (4, 4))
    print(f'k = {k}, values: ')
    print(value_grid)
    policy_improvement()
    print(f'policy: {policy}')

k = 0, values: 
[[ 0.        -1.        -1.25      -1.3125   ]
 [-1.        -1.5       -1.6875    -1.75     ]
 [-1.25      -1.6875    -1.84375   -1.8984375]
 [-1.3125    -1.75      -1.8984375  0.       ]]
policy: {1: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0}, 2: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0}, 3: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0}, 4: {'u': 1.0, 'd': 0, 'l': 0, 'r': 0}, 5: {'u': 0.5, 'd': 0, 'l': 0.5, 'r': 0}, 6: {'u': 1.0, 'd': 0, 'l': 0, 'r': 0}, 7: {'u': 1.0, 'd': 0, 'l': 0, 'r': 0}, 8: {'u': 1.0, 'd': 0, 'l': 0, 'r': 0}, 9: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0}, 10: {'u': 0.5, 'd': 0, 'l': 0.5, 'r': 0}, 11: {'u': 0, 'd': 1.0, 'l': 0, 'r': 0}, 12: {'u': 1.0, 'd': 0, 'l': 0, 'r': 0}, 13: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0}, 14: {'u': 0, 'd': 0, 'l': 0, 'r': 1.0}}
k = 1, values: 
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -4.]
 [-2. -3. -4. -1.]
 [-3. -4. -1.  0.]]
policy: {1: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0}, 2: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0}, 3: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0}, 4: {'u': 1.0

In [7]:
# state values left in `values` after the rounds above, arranged as a 4x4 array
np.reshape(list(values.values()), (4, 4))

array([[ 0., -1., -2., -3.],
       [-1., -2., -3., -2.],
       [-2., -3., -2., -1.],
       [-3., -2., -1.,  0.]])

In [8]:
# action probabilities for states 1-14 after the rounds above
# (states 0 and 15 never get a policy entry)
policy 

{1: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0},
 2: {'u': 0, 'd': 0, 'l': 1.0, 'r': 0},
 3: {'u': 0, 'd': 0.5, 'l': 0.5, 'r': 0},
 4: {'u': 1.0, 'd': 0, 'l': 0, 'r': 0},
 5: {'u': 0.5, 'd': 0, 'l': 0.5, 'r': 0},
 6: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 7: {'u': 0, 'd': 1.0, 'l': 0, 'r': 0},
 8: {'u': 1.0, 'd': 0, 'l': 0, 'r': 0},
 9: {'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25},
 10: {'u': 0, 'd': 0.5, 'l': 0, 'r': 0.5},
 11: {'u': 0, 'd': 1.0, 'l': 0, 'r': 0},
 12: {'u': 0.5, 'd': 0, 'l': 0, 'r': 0.5},
 13: {'u': 0, 'd': 0, 'l': 0, 'r': 1.0},
 14: {'u': 0, 'd': 0, 'l': 0, 'r': 1.0}}