In [26]:
#https://www.youtube.com/watch?v=iKdlKYG78j4
import numpy as np

# Q-learning hyper-parameters.
# NOTE: get_next_action treats `epsilon` as the probability of EXPLOITING
# (taking the greedy action), i.e. the opposite of the usual epsilon-greedy
# convention -- 0.9 here means 90% greedy moves and 10% random moves.
epsilon = 0.9
discount_factor = 0.9   # gamma: weight given to the best Q-value of the next state
learning_rate = 0.9     # alpha: step size of each Q-value update

# Seed numpy's global RNG (used by the starting-location and action pickers)
# so training -- and therefore the learned paths -- is reproducible on
# Restart & Run All.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Reward grid: a reward of exactly -1 marks a walkable cell; any other value
# marks a terminal cell (see is_terminal_state).
rewards = np.genfromtxt('rewards.csv', dtype=float, delimiter=',')
assert rewards.ndim == 2, f"rewards.csv must describe a 2-D grid, got shape {rewards.shape}"
assert not np.isnan(rewards).any(), "rewards.csv contains missing or non-numeric entries"

# Derive the grid size from the data instead of hard-coding it, so the
# Q-table and the boundary checks always match the reward grid.
num_rows, num_cols = rewards.shape

# One Q-value per (row, col, action); actions are up/right/down/left (0..3).
q_values = np.zeros((num_rows, num_cols, 4))

print(rewards)

[[-100. -100. -100. -100. -100.  100. -100. -100. -100. -100. -100.]
 [-100.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1. -100.]
 [-100.   -1. -100. -100. -100. -100. -100.   -1. -100.   -1. -100.]
 [-100.   -1.   -1.   -1.   -1.   -1.   -1.   -1. -100.   -1. -100.]
 [-100. -100. -100.   -1. -100. -100. -100.   -1. -100. -100. -100.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [-100. -100. -100. -100. -100.   -1. -100. -100. -100. -100. -100.]
 [-100.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1. -100.]
 [-100. -100. -100.   -1. -100. -100. -100.   -1. -100. -100. -100.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [-100. -100. -100. -100. -100. -100. -100. -100. -100. -100. -100.]]


In [27]:
#a state is terminal unless its reward is exactly the -1 step cost
def is_terminal_state(row_index,col_index):
    """Return True when (row_index, col_index) is a terminal cell.

    Only cells whose reward is exactly -1 are walkable; every other reward
    value (in this grid the +100 goal and the -100 cells) ends an episode.
    Reads the module-level ``rewards`` grid.
    """
    return bool(rewards[row_index, col_index] != -1.)

#randomly picks a starting location that is not a terminal state
def get_starting_location():
    """Return a random (row_index, col_index) that is not a terminal state.

    Uses rejection sampling: draw a row, then a column, uniformly from the
    grid and retry until the drawn cell is non-terminal.
    """
    while True:
        candidate = (np.random.randint(num_rows), np.random.randint(num_cols))
        if not is_terminal_state(*candidate):
            return candidate

#epsilon greedy implementation (epsilon = probability of exploiting)
def get_next_action(row_index,col_index):
    """Choose the next action (0..3) for the agent at (row_index, col_index).

    With probability ``epsilon`` the action with the highest Q-value is
    returned (np.argmax picks the lowest index on ties); otherwise a uniformly
    random action in 0..3 is returned to explore.
    """
    exploit = np.random.random() < epsilon
    if not exploit:
        return np.random.randint(4)
    return np.argmax(q_values[row_index,col_index,:])

#actions: up right down left (0,1,2,3)
def get_next_location(row_index,col_index,action_index):
    """Return the (row, col) reached by taking ``action_index`` from a location.

    Actions: 0 = up, 1 = right, 2 = down, 3 = left. A move toward an edge the
    agent is already on (per the module-level num_rows / num_cols) is ignored,
    and an unknown action index also leaves the location unchanged.
    """
    row_step, col_step = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}.get(action_index, (0, 0))

    # Each direction only checks the edge it is moving toward.
    if (row_step < 0 and row_index > 0) or (row_step > 0 and row_index < num_rows - 1):
        row_index += row_step
    if (col_step < 0 and col_index > 0) or (col_step > 0 and col_index < num_cols - 1):
        col_index += col_step

    return row_index, col_index

#follow the greedy policy (argmax of the q_values) from the given start
#until a terminal state is reached, recording every location visited
def get_shortest_path(row_index,col_index):
    """Return the greedy path from (row_index, col_index) to a terminal state.

    At each step the action with the highest Q-value is taken, so the path
    reflects the policy currently stored in ``q_values``. Each visited location
    is recorded as a ``[row, col]`` list, including the start cell and the
    terminal cell.

    Raises:
        RuntimeError: if the greedy policy revisits a location. The policy is
            deterministic and ``q_values`` is not modified here, so a revisit
            means the walk would cycle forever -- e.g. the argmax action points
            off the grid edge and get_next_location leaves the agent in place.
    """
    shortest_path = []
    visited = set()

    while True:
        # Without this guard a cyclic greedy policy hangs the notebook.
        if (row_index, col_index) in visited:
            raise RuntimeError(
                f"greedy policy loops at {[row_index, col_index]} without reaching "
                "a terminal state; the Q-values may be under-trained"
            )
        visited.add((row_index, col_index))
        shortest_path.append([row_index,col_index])

        if is_terminal_state(row_index,col_index):
            break

        action_index = np.argmax(q_values[row_index,col_index,:])
        row_index,col_index = get_next_location(row_index,col_index,action_index)

    return shortest_path

In [28]:
# Tabular Q-learning: 1000 episodes, each starting from a random non-terminal
# cell and ending as soon as the agent steps into a terminal cell.
for episode in range(1000):
    row_index, col_index = get_starting_location()

    while True:
        if is_terminal_state(row_index, col_index):
            break

        # Choose an action and remember the (state, action) pair being updated.
        action_index = get_next_action(row_index, col_index)
        old_row_index, old_col_index = row_index, col_index
        old_q_value = q_values[old_row_index, old_col_index, action_index]

        # Take the step and observe the reward of the cell we land in.
        row_index, col_index = get_next_location(old_row_index, old_col_index, action_index)
        reward = rewards[row_index, col_index]

        # TD error against the target r + gamma * max_a' Q(s', a'). Terminal
        # cells are never updated (the loop stops there), so their Q-values stay 0.
        temporal_difference = reward + (discount_factor * np.max(q_values[row_index, col_index, :])) - old_q_value
        q_values[old_row_index, old_col_index, action_index] += learning_rate * temporal_difference

In [31]:
#paths are different from the ones in the video but are also valid optimal paths
for start_row, start_col in [(3, 9), (5, 0), (9, 5)]:
    print(get_shortest_path(start_row, start_col))

[[3, 9], [2, 9], [1, 9], [1, 8], [1, 7], [1, 6], [1, 5], [0, 5]]
[[5, 0], [5, 1], [5, 2], [5, 3], [4, 3], [3, 3], [3, 2], [3, 1], [2, 1], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [0, 5]]
[[9, 5], [9, 6], [9, 7], [8, 7], [7, 7], [7, 6], [7, 5], [6, 5], [5, 5], [5, 6], [5, 7], [4, 7], [3, 7], [2, 7], [1, 7], [1, 6], [1, 5], [0, 5]]
