# SARSA Learning

Originally from https://skettee.github.io/post/sarsa_learning/ (in Korean)

## Load Libraries and Extensions

In [1]:
# (Re)load the autoreload extension so edited modules can be picked up
# without restarting the kernel.
%reload_ext autoreload
# Mode 2: reload all imported modules before running each cell.
%autoreload 2
# Render matplotlib figures inline in the notebook
# (no figure is drawn in the cells shown here).
%matplotlib inline

In [2]:
from IPython.display import display, clear_output, Pretty
import numpy as np
from time import sleep
from tqdm import tqdm_notebook as tqdm

import gym

## Frozen Lake Environment

In [3]:
# Gym environment id handed to gym.make() when the environment is created.
ENV_NAME = 'FrozenLake8x8-v0'
# Upper bound on the number of steps taken by the random-policy demo rollout.
N_STEP = 100

In [4]:
# Create the Frozen Lake environment. is_slippery=False is forwarded to the
# environment; per the Gym docs it disables random slipping so that moves are
# deterministic -- TODO confirm against the installed gym version.
env = gym.make(ENV_NAME, is_slippery=False)
# Start a new episode; reset() returns the initial state.
state = env.reset()

# Render the board as text ('ansi' mode) and show it in the cell output.
world = env.render(mode='ansi')
display(Pretty(world))
# Short pause so the frame stays visible for a moment.
sleep(0.5)


[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG


In [5]:
# Random-policy demo: take at most N_STEP uniformly sampled actions and
# redraw the board after every move.
for step in range(N_STEP):
    # Sample an action and advance the environment by one step.
    action = env.action_space.sample()
    next_state, reward, done, info = env.step(action)
    state = next_state

    # Swap the previous frame for the updated board.
    clear_output(wait=True)
    world = env.render(mode='ansi')
    display(Pretty(world))
    sleep(0.5)

    if not done:
        continue

    # env.step() reported the end of the episode: report its length and stop.
    print("Episode finished after {} timesteps".format(step+1))
    break

  (Down)
SFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
F[41mH[0mHFFFHF
FHFFHFHF
FFFHFFFG


Episode finished after 48 timesteps


In [6]:
# Display the environment's observation (state) space.
env.observation_space

Discrete(64)

There are 64 states as follows:

$S = \{0, 1, \cdots , 63\}$   

$\begin{vmatrix}
0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 \\
8 & 9 & 10 & 11 & 12 & 13 & 14 & 15 \\
16 & 17 & 18 & 19 & 20 & 21 & 22 & 23 \\
24 & 25 & 26 & 27 & 28 & 29 & 30 & 31 \\
32 & 33 & 34 & 35 & 36 & 37 & 38 & 39 \\
40 & 41 & 42 & 43 & 44 & 45 & 46 & 47 \\
48 & 49 & 50 & 51 & 52 & 53 & 54 & 55 \\
56 & 57 & 58 & 59 & 60 & 61 & 62 & 63
\end{vmatrix}$

There are 11 terminal states (the 10 holes and the goal) as follows:

$S_{\text{terminal}} = \{19, 29, 35, 41, 42, 46, 49, 52, 54, 59, 63\}$

## SARSA Learning

$\begin{align}
q_{\pi}(s,a) & = \mathbb E_{\pi} [G_t | S_t = s, A_t = a] \\
&= \mathbb E_{\pi} [R_{t+1} + \gamma q_{\pi}(S_{t+1}, A_{t+1}) | S_t = s, A_t = a] 
\end{align}$

$q_*(s,a) = \max_{\pi} q_{\pi}(s,a)$  

$\pi_*(s,a) = \begin{cases}
1 & \text{if } a= \text{argmax}_{a \in A} q_\star(s,a) \\
0 & \text{otherwise}
\end{cases}$

$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left( R_{t+1} + \gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right)$

$\pi \leftarrow \epsilon \text{-greedy(Q)}$

In [7]:
# Sizes of the Q-table, taken from the environment's discrete spaces.
n_state = env.observation_space.n
n_action = env.action_space.n
# Number of training episodes.
n_episode = 1000
# Discount factor applied to the next state-action value in the TD target.
GAMMA = .9
# Probability of choosing a random action instead of the greedy one.
EPSILON = .3
# Learning rate (step size) of the Q-value update.
ALPHA = .1

In [8]:
# Every state that ends an episode on the 8x8 map rendered above: the ten
# holes (H) plus the goal (G, state 63). State 46 (row 5, column 6) is a hole
# too and must be listed, otherwise its Q-values keep their random init and
# leak into the bootstrapped target.
terminal_states = [19, 29, 35, 41, 42, 46, 49, 52, 54, 59, 63]


def epsilon_greedy(q_values, epsilon=EPSILON):
    """Pick an action epsilon-greedily.

    Args:
        q_values: 1-D array holding Q(s, a) for every action of one state.
        epsilon: probability of taking a uniformly random action.

    Returns:
        A random action from ``env.action_space`` with probability
        ``epsilon``, otherwise the index of the largest entry of ``q_values``.
    """
    if np.random.uniform() < epsilon:
        return env.action_space.sample()
    return np.argmax(q_values)


# Tiny random initial values break argmax ties between untried actions.
Q_table = np.random.uniform(low=0.0, high=0.00000001, size=(n_state, n_action))
# No return can be collected after a terminal state, so its Q-values are
# pinned to 0; the target below then reduces to the reward at episode end.
Q_table[terminal_states] = 0

for episode in tqdm(range(n_episode)):
    state = env.reset()
    done = False
    action = epsilon_greedy(Q_table[state])

    while not done:
        next_state, reward, done, info = env.step(action)
        # On-policy: the action used in the target is the one actually taken next.
        next_action = epsilon_greedy(Q_table[next_state])

        # SARSA update: Q(S,A) <- Q(S,A) + alpha * (R + gamma * Q(S',A') - Q(S,A))
        target = reward + GAMMA * Q_table[next_state, next_action]
        delta = target - Q_table[state, action]
        Q_table[state, action] += ALPHA * delta
        state, action = next_state, next_action

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




## Solution

In [9]:
# Replay one episode following the greedy policy derived from Q_table.
state, done = env.reset(), False

# Show the starting position before the first move.
world = env.render(mode='ansi')
display(Pretty(world))
sleep(.5)

while True:
    # Always take the action with the highest learned value.
    action = Q_table[state].argmax()
    state, reward, done, info = env.step(action)

    # Redraw the board in place of the previous frame.
    clear_output(wait=True)
    world = env.render(mode='ansi')
    display(Pretty(world))
    sleep(.5)

    if done:
        # State 63 is the goal cell (bottom-right corner of the 8x8 grid).
        if state == 63:
            print('\nSuccess!')
        break

  (Down)
SFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFF[41mG[0m



Success!
