# Windy Gridworld

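Reproducing the results of Example 6.5 (Windy Gridworld) from Sutton & Barto using TD control with Sarsa. This version uses King's moves (eight actions, as in Exercise 6.9) and stochastic wind (as in Exercise 6.10).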

In [1]:
import numpy as np

In [2]:
class WindyGridworld():
    """Grid world in which a column-dependent wind pushes the agent along +y.

    Positions are integer arrays ``[x, y]``; ``size[0]`` bounds x and
    ``size[1]`` bounds y. ``wind[x]`` is the nominal number of cells the
    agent is pushed along +y when it moves out of column ``x``.

    Args:
        size: grid extent ``(n_x, n_y)``.
        start: starting position ``(x, y)``.
        end: goal position ``(x, y)``; reaching it ends the episode.
        wind: nominal wind strength for each x column.
        stochastic_wind: if True (default), the wind in every column that
            has wind deviates by -1, 0 or +1 cell with equal probability
            (Sutton & Barto, Exercise 6.10). If False the wind is
            deterministic (Example 6.5).
    """

    def __init__(self, size, start, end, wind, stochastic_wind=True):
        self._size = size
        self._start = start
        self._end = end
        self._wind = wind
        self._stochastic_wind = stochastic_wind
        self._current_position = np.array(start)

    def reset(self):
        """Put the agent back on the start cell and return that position."""
        # Always a fresh array: returning ``self._start`` itself would let a
        # caller that mutates the state corrupt the start of every later episode.
        self._current_position = np.array(self._start)
        return self._current_position

    def move(self, move):
        """Apply one ``(dx, dy)`` step and return ``(reward, new_position, is_over)``.

        The reward is always -1, so shorter episodes accumulate more reward.
        """
        new_position = self._calc_new_position(move)
        new_position = self._cap_new_position(new_position)
        self._current_position = new_position
        is_over = self._is_game_over(new_position)
        return -1, new_position, is_over

    def _is_game_over(self, position):
        """Return True when ``position`` equals the goal cell."""
        return (position == self._end).all()

    def _calc_new_position(self, move):
        """Return the (not yet clamped) position after ``move`` plus the wind."""
        new_position = self._current_position + np.array(move)
        # The wind strength is taken from the column the agent is leaving.
        strength = self._wind[self._current_position[0]]
        if self._stochastic_wind and strength != 0:
            # Uniform over {-1, 0, +1}; calm columns stay calm.
            strength += np.random.randint(-1, 2)
        new_position[1] += strength
        return new_position

    def _cap_new_position(self, new_position):
        """Clamp ``new_position`` in place to the grid and return it."""
        new_position[0] = min(new_position[0], self._size[0]-1)
        new_position[0] = max(new_position[0], 0)
        new_position[1] = min(new_position[1], self._size[1]-1)
        new_position[1] = max(new_position[1], 0)
        return new_position
    

In [9]:
def choose_action(q, state, eps):
    """Pick an action index for ``state`` with an epsilon-greedy policy.

    Args:
        q: action-value table indexed as ``q[x, y, action]``.
        state: position ``(x, y)`` used to index ``q``.
        eps: probability of choosing a uniformly random action.

    Returns:
        An ``int`` in ``range(q.shape[-1])``. When several actions share the
        greatest value, one of them is chosen uniformly at random, so an
        all-zero table does not always favour action 0.
    """
    # Derive the action count from the table instead of hard-coding it.
    n_actions = q.shape[-1]
    if np.random.rand() < eps:
        return np.random.randint(n_actions)
    values = q[state[0], state[1], :]
    best = np.flatnonzero(values == values.max())
    return int(np.random.choice(best))
    
def parse_action(action):
    """Translate an action index into its ``[dx, dy]`` grid step.

    The eight King's moves are listed clockwise, starting from +y
    (clockwise when axis 1 is drawn pointing up).
    """
    king_moves = (
        [0, 1], [1, 1], [1, 0], [1, -1],
        [0, -1], [-1, -1], [-1, 0], [-1, 1],
    )
    return king_moves[action]

In [14]:
# Sarsa (on-policy TD control) on the 10x7 windy gridworld with 8 actions.
q = np.zeros((10, 7, 8))
start = [0, 3]
end = [7, 3]
size = [10, 7]
wind = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]

env = WindyGridworld(size, start, end, wind)

# Discount and step size stay fixed for the whole run.
gamma = 0.9
alpha = 0.5

for episode in range(2000):
    # Explore with eps = 0.1; the last 99 episodes (1901-1999) act greedily.
    eps = 0 if episode > 1900 else 1 / 10
    state = env.reset()
    action = choose_action(q, state, eps)
    cnt = 0
    done = False
    while not done:
        cnt += 1
        if episode == 999:
            print(state)  # trace every state of one mid-training episode
        r, next_state, done = env.move(parse_action(action))
        next_action = choose_action(q, next_state, eps)
        # Sarsa bootstraps from the action actually chosen in next_state.
        td_target = r + gamma * q[next_state[0], next_state[1], next_action]
        q[state[0], state[1], action] += alpha * (
            td_target - q[state[0], state[1], action]
        )
        state, action = next_state, next_action
    if episode % 10 == 0:
        print('Episode length is:', cnt)

    


Episode length is: 288
Episode length is: 87
Episode length is: 30
Episode length is: 228
Episode length is: 24
Episode length is: 93
Episode length is: 30
Episode length is: 229
Episode length is: 61
Episode length is: 67
Episode length is: 31
Episode length is: 120
Episode length is: 183
Episode length is: 112
Episode length is: 71
Episode length is: 34
Episode length is: 11
Episode length is: 51
Episode length is: 103
Episode length is: 12
Episode length is: 28
Episode length is: 60
Episode length is: 21
Episode length is: 103
Episode length is: 10
Episode length is: 36
Episode length is: 66
Episode length is: 13
Episode length is: 17
Episode length is: 54
Episode length is: 30
Episode length is: 22
Episode length is: 14
Episode length is: 13
Episode length is: 24
Episode length is: 104
Episode length is: 92
Episode length is: 9
Episode length is: 59
Episode length is: 39
Episode length is: 64
Episode length is: 75
Episode length is: 44
Episode length is: 24
Episode length is: 27
Ep

In [None]:
def test_windy_grid_world():
    """Check WindyGridworld.move on a small 4x3 grid.

    The wind may be stochastic, so instead of one exact trajectory this test
    checks invariants that hold for every outcome: the reward is -1, x moves
    exactly as commanded (clamped to the grid), y lands within one cell of
    the nominal windy target (clamped), and the episode ends exactly when
    the goal cell is reached.
    """
    size, goal, wind = (4, 3), (1, 1), [0, 1, 2, 0]
    wgw = WindyGridworld(size, (0, 0), goal, wind)
    for _ in range(50):
        pos = np.array(wgw.reset())
        for step in [(1, 1), (1, 0), (1, 0)]:
            reward, new_pos, done = wgw.move(step)
            # Wind comes from the column being left and acts along y only.
            nominal_y = pos[1] + step[1] + wind[pos[0]]
            low = min(max(nominal_y - 1, 0), size[1] - 1)
            high = min(max(nominal_y + 1, 0), size[1] - 1)
            assert reward == -1
            assert new_pos[0] == min(max(pos[0] + step[0], 0), size[0] - 1)
            assert low <= new_pos[1] <= high
            assert bool(done) == (tuple(new_pos) == goal)
            pos = np.array(new_pos)
    
def test_all():
    """Run every test defined in this notebook."""
    for check in (test_windy_grid_world,):
        check()


test_all()