In [1]:
import numpy as np
import constants as ct

from planning_utils import create_grid

In [2]:
TARGET_ALTITUDE = 5
SAFETY_DISTANCE = 5
source = (10, 700)
target = (800, 150)
terminal_states = [target]
delta = 10

# Read in obstacle map
data = np.loadtxt(ct.COLLIDERS_FILE, delimiter=',', dtype='Float64', skiprows=2)
grid, north_offset, east_offset = create_grid(data, TARGET_ALTITUDE, SAFETY_DISTANCE)

# Transform grid to MDP representation
grid[grid == 1] = None
grid[grid == 0] = -0.02
grid[target] = 1
grid

array([[  nan,   nan,   nan, ..., -0.02, -0.02, -0.02],
       [  nan,   nan,   nan, ..., -0.02, -0.02, -0.02],
       [  nan,   nan,   nan, ..., -0.02, -0.02, -0.02],
       ...,
       [-0.02, -0.02, -0.02, ..., -0.02, -0.02, -0.02],
       [-0.02, -0.02, -0.02, ..., -0.02, -0.02, -0.02],
       [-0.02, -0.02, -0.02, ..., -0.02, -0.02, -0.02]])

In [3]:
# grid = np.full((3, 4), -0.02)
# terminal_states = [(2, 3), (1, 3)]
# grid[1, 1] = None
# grid[terminal_states[0]] = 1
# grid[terminal_states[1]] = -1
# grid

In [4]:
# Extract states, rewards from grid
states = set()
rewards = dict()
optimal_policy = dict()

numRows, numCols = grid.shape
for rowNum in range(numRows):
    for colNum in range(numCols):
        rewards[rowNum, colNum] = grid[rowNum, colNum]
        if not np.isnan(grid[rowNum, colNum]):
                    states.add((rowNum, colNum))
                
print('States:\n', states, len(states))
print('Rewards:\n', rewards, len(rewards))

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [5]:
# (0, 1) -> next col, (0, -1) -> prev col, (-1, 0) -> prev row, (1, 0) -> next row
possible_actions = [(0, 1), (0, -1), (-1, 0), (1, 0)]

def valid_actions(state):
    actions = []
    for action in possible_actions:
        new_state = (state[0] + action[0], state[1] + action[1])
        if new_state in states:
            actions.append(action)
    return actions

def transition(state, action):
    return [ \
        (0.8 if valid_action == action else 0.2 / 3.0, \
         (state[0] + valid_action[0], state[1] + valid_action[1]) \
        ) for valid_action in valid_actions(state)]

In [6]:
def value_iteration(discount_factor=0.9, epsilon=0.001):
    utilities = np.zeros((numRows, numCols))
    itr = 0
    while True:
        itr += 1
        prev_utilities = utilities.copy()
        delta = 0.0
        for state in states:
            utility_row_sums = [sum([prob * prev_utilities[new_state] \
                                     for prob, new_state in transition(state, action)]) \
                                for action in valid_actions(state)]
            utilities[state] = rewards[state] + discount_factor * max(utility_row_sums)
            
        delta = max(delta, abs(utilities[state] - prev_utilities[state]))
        print(itr, delta, epsilon * (1 - discount_factor) / discount_factor)
        #check for convergence, if values converged then return V
        if delta < epsilon * (1 - discount_factor) / discount_factor:
            return utilities

In [7]:
final_utilities = value_iteration()
final_utilities

1 0.02 0.00011111111111111109
2 0.018 0.00011111111111111109
3 0.0162 0.00011111111111111109
4 0.014579999999999996 0.00011111111111111109
5 0.012674102400000012 0.00011111111111111109
6 0.010896088895999995 0.00011111111111111109
7 0.009292114402560017 0.00011111111111111109
8 0.007894749641241608 0.00011111111111111109
9 0.006692678646022646 0.00011111111111111109
10 0.005667026235944617 0.00011111111111111109
11 0.0047953587266558045 0.00011111111111111109
12 0.0040562674331346416 0.00011111111111111109
13 0.0034303659557510524 0.00011111111111111109
14 0.0029006963080860926 0.00011111111111111109
15 0.00245264165187023 0.00011111111111111109
16 0.002073712959523216 0.00011111111111111109
17 0.0017420016705675945 0.00011111111111111109
18 0.0014721346702979698 0.00011111111111111109
19 0.0012228899355914424 0.00011111111111111109
20 0.001032558723171817 0.00011111111111111109
21 0.0008475995444058304 0.00011111111111111109
22 0.000715001873098059 0.00011111111111111109
23 0.00058099

array([[ 0.        ,  0.        ,  0.        , ..., -0.1140749 ,
        -0.11054312, -0.10619256],
       [ 0.        ,  0.        ,  0.        , ..., -0.12506056,
        -0.1211941 , -0.11054312],
       [ 0.        ,  0.        ,  0.        , ..., -0.13392574,
        -0.12506056, -0.1140749 ],
       ...,
       [-0.1140749 , -0.12506055, -0.13392573, ..., -0.13392574,
        -0.12506056, -0.1140749 ],
       [-0.11054312, -0.1211941 , -0.12506056, ..., -0.12506056,
        -0.1211941 , -0.11054312],
       [-0.10619256, -0.11054312, -0.1140749 , ..., -0.1140749 ,
        -0.11054312, -0.10619256]])

In [8]:
def get_state_action_utility(action, state):
    return sum([prob * final_utilities[new_state] for prob, new_state in transition(state, action)])

for state in states:
    if state in terminal_states:
        continue
    valid_actions_for_state = valid_actions(state)
    state_action_utils = [get_state_action_utility(action, state) for action in valid_actions_for_state]
    max_action_util_index = state_action_utils.index(max(state_action_utils))
    optimal_policy[state] = valid_actions_for_state[max_action_util_index]
    
# action -> meaning: (0, 1) -> next col, (0, -1) -> prev col, (-1, 0) -> prev row, (1, 0) -> next row
# state, action pairs
optimal_policy

{(102, 595): (1, 0),
 (621, 577): (-1, 0),
 (727, 68): (-1, 0),
 (342, 418): (0, -1),
 (697, 145): (1, 0),
 (418, 370): (0, 1),
 (367, 853): (0, -1),
 (5, 178): (-1, 0),
 (722, 452): (-1, 0),
 (886, 687): (0, 1),
 (669, 203): (0, 1),
 (376, 760): (-1, 0),
 (747, 795): (0, 1),
 (694, 806): (0, 1),
 (164, 789): (0, 1),
 (0, 862): (0, 1),
 (25, 701): (-1, 0),
 (544, 667): (0, -1),
 (122, 635): (-1, 0),
 (69, 912): (1, 0),
 (468, 611): (-1, 0),
 (530, 418): (1, 0),
 (892, 107): (0, 1),
 (569, 56): (-1, 0),
 (454, 586): (0, 1),
 (915, 36): (1, 0),
 (809, 617): (-1, 0),
 (661, 520): (0, -1),
 (555, 125): (0, -1),
 (424, 521): (0, -1),
 (209, 569): (1, 0),
 (119, 90): (0, 1),
 (781, 755): (0, 1),
 (873, 594): (-1, 0),
 (820, 111): (0, 1),
 (89, 227): (-1, 0),
 (753, 132): (0, 1),
 (181, 691): (0, 1),
 (22, 761): (-1, 0),
 (421, 228): (-1, 0),
 (61, 861): (0, 1),
 (100, 805): (0, -1),
 (619, 735): (1, 0),
 (47, 798): (1, 0),
 (725, 310): (0, 1),
 (764, 284): (0, -1),
 (86, 240): (0, 1),
 (801,