In [297]:
import numpy as np
import random
import pandas as pd

# Seed both random number generators used in this notebook (NumPy for the
# Q-table initialisation, wind and exploration draws; `random` for picking
# exploratory actions) so that Restart & Run All reproduces the same results.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

# Print floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

In [298]:
# Grid layout. States are (horizontal, vertical) index pairs.
grid_width, grid_height = 10, 7

# Where every episode begins and where it ends.
starting_state = (0, 3)
terminal_state = (7, 3)

# Wind strength per column (indexed by horizontal position); it is added to
# the vertical position whenever a move is made from that column.
wind_levels = [0] * 3 + [1] * 3 + [2] * 2 + [1, 0]

In [299]:
# Action set: every "<horizontal>-<vertical>" combination of king's moves,
# followed by the stay-in-place move "middle-middle" as the last entry.
# The order matters: positions in this list index the last axis of q.
action_set = [
    horizontal + "-" + vertical
    for horizontal in ("left", "middle", "right")
    for vertical in ("down", "middle", "up")
    if (horizontal, vertical) != ("middle", "middle")
] + ["middle-middle"]

In [300]:
# Learning hyper-parameters
alpha, eps = 0.5, 0.01   # step size of the Q update / probability of a random action
num_episodes = 10 ** 5   # number of training episodes

In [310]:
# Initialize Q(S, A) with random values between -10 and 0, laid out as
# [horizontal position, vertical position, action index].
q = -10 * np.random.rand(grid_width, grid_height, len(action_set))

# Every action value of the terminal state is fixed at 0.
q[tuple(terminal_state)] = 0

In [311]:
def action_effect(state, action, stochastic_wind=False):
    """Return the state reached by taking ``action`` from ``state``.

    Parameters
    ----------
    state : tuple
        Current ``(horizontal, vertical)`` position.
    action : str
        Move in ``"<horizontal>-<vertical>"`` form, where the horizontal part
        is ``"left"``, ``"middle"`` or ``"right"`` and the vertical part is
        ``"down"``, ``"middle"`` or ``"up"``.
    stochastic_wind : bool, optional
        When true, the wind of the current column is decreased by one with
        probability 1/3, increased by one with probability 1/3 and left
        unchanged otherwise.

    Returns
    -------
    tuple
        Next ``(horizontal, vertical)`` position, clamped to the grid.

    Raises
    ------
    ValueError
        If ``action`` is not made of exactly two known components.
    """
    #Parse action
    horizontal_action, vertical_action = action.split("-")

    horizontal_moves = {"left": -1, "middle": 0, "right": 1}
    vertical_moves = {"down": -1, "middle": 0, "up": 1}
    if horizontal_action not in horizontal_moves or vertical_action not in vertical_moves:
        raise ValueError("unknown action: %r" % (action,))

    #Set wind (taken from the column the move starts in)
    wind_effect = wind_levels[state[0]]
    if stochastic_wind:
        draw = np.random.rand()
        if draw < 1/3:
            wind_effect -= 1
        elif draw < 2/3:
            wind_effect += 1

    #Move horizontally, staying inside the grid
    horizontal_position = state[0] + horizontal_moves[horizontal_action]
    horizontal_position = min(max(horizontal_position, 0), grid_width - 1)

    #Move vertically and add wind. Both bounds are enforced for every move:
    #a stochastic wind of -1 could otherwise push a "middle" move below row 0,
    #and a negative position would silently wrap around when indexing q.
    vertical_position = state[1] + vertical_moves[vertical_action] + wind_effect
    vertical_position = min(max(vertical_position, 0), grid_height - 1)

    return (horizontal_position, vertical_position)


In [312]:
# Sanity check: from the top row (vertical index 6) of column 4, which has
# wind 1, "right-up" moves one column right and stays clamped at row 6.
action_effect((4,6), "right-up")

(5, 6)

In [313]:
def choose_action(state, greedy=False):
    """Pick an action for ``state`` from the current Q-table.

    Parameters
    ----------
    state : tuple
        ``(horizontal, vertical)`` position used to index ``q``.
    greedy : bool, optional
        When true, always return the action with the highest Q-value.
        Otherwise a uniformly random action is returned with probability
        ``eps`` and the highest-valued action the rest of the time.

    Returns
    -------
    str
        An entry of ``action_set``.
    """
    if not greedy and np.random.rand() < eps:
        return random.choice(action_set)

    # Only the action values of this one state are needed; taking the argmax
    # over the whole table on every call is wasted work in the training loop.
    return action_set[q[state[0], state[1]].argmax()]

In [314]:
# Sample the epsilon-greedy policy at state (3, 3); the answer depends on the
# current contents of q (and, with probability eps, on a random draw).
choose_action((3,3))

'middle-middle'

In [315]:
# SARSA training: play num_episodes episodes, updating q after every step.
for e in range(num_episodes):

    # Each episode starts from the start state with an action drawn from the
    # epsilon-greedy policy.
    state = starting_state
    action = choose_action(state)

    r = 0  # accumulated reward of this episode (-1 per step)
    while state != terminal_state:
        r -= 1

        # Apply action A (with deterministic wind) to obtain S'
        next_state = action_effect(state, action, stochastic_wind=False)

        # Draw A' for S' from the same epsilon-greedy policy
        next_action = choose_action(next_state)

        # Q(S, A) <- Q(S, A) + alpha * (R + Q(S', A') - Q(S, A)), with R = -1
        current = (state[0], state[1], action_set.index(action))
        successor = (next_state[0], next_state[1], action_set.index(next_action))
        td_error = -1 + q[successor] - q[current]
        q[current] = q[current] + alpha * td_error

        # S <- S', A <- A'
        state, action = next_state, next_action
    

In [316]:
# Follow the greedy policy from the start state, printing each visited state
# and the action taken there, then print the total reward (-1 per step).

#Set starting state
state = starting_state

#Choose starting action from the greedy policy
action = choose_action(state, greedy=True)

# The wind is deterministic and the policy is greedy, so coming back to a
# state that was already visited means the rollout would loop forever.
visited_states = set()

r = 0
while state != terminal_state:
    if state in visited_states:
        raise RuntimeError(
            "greedy policy revisits state %r and never reaches the terminal state"
            % (state,)
        )
    visited_states.add(state)

    print(state)
    print(action)
    r -= 1

    #Take wind and then action A
    next_state = action_effect(state, action, stochastic_wind=False)

    #Choose A prime from the greedy policy (q is not updated during evaluation)
    next_action = choose_action(next_state, greedy=True)

    #Set S prime to S
    state = next_state

    #Set A prime to A
    action = next_action
print(r)

(0, 3)
right-down
(1, 2)
right-down
(2, 1)
right-down
(3, 0)
right-down
(4, 0)
right-down
(5, 0)
right-middle
(6, 1)
right-middle
-7


In [317]:
# Inspect the learned action values of state (1, 3); entries follow the order
# of action_set.
q[1,3,:]

array([-9.68321763, -9.95923736, -9.7482902 , -9.24483275, -9.78070063,
       -7.96287472, -9.62146822, -9.66021521, -9.75145249])

In [318]:
# Best action value of every state. q is indexed [horizontal, vertical, action],
# so after the transpose rows are vertical positions and columns are
# horizontal positions.
display(pd.DataFrame(q.max(axis=2).T))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,-9.312601,-7.699336,-5.085456,-4.0,-3.0,-2.0,-1.0,-1.926817,-2.842981,-3.006888
1,-8.218078,-6.635924,-5.0,-4.037982,-3.006361,-4.036025,-1.0,-3.065658,-0.786842,-3.742171
2,-8.236824,-6.050552,-7.506546,-9.267021,-7.101467,-4.311704,-1.0,-1.0,-2.000102,-4.592263
3,-7.252764,-7.962875,-8.528399,-8.519621,-9.021433,-8.001777,-7.007658,0.0,-1.0,-3.670505
4,-9.063312,-9.420858,-8.418935,-9.370724,-9.199839,-8.000229,-7.000027,-6.001216,-4.117522,-2.392487
5,-9.949292,-10.181505,-7.725756,-10.932075,-10.109218,-8.028639,-7.0,-6.000001,-5.000054,-4.096038
6,-10.816503,-10.941439,-10.052258,-11.510854,-9.424909,-8.082949,-7.033083,-6.0,-5.311064,-5.014615
