## 1. 原理解释

### Sarsa (state-action-reward-state-action) 

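- 用来估计在当前策略下，某一个状态采取某一个动作的 action value $q(s, a)$。每得到一组 $(s_t, a_t, r_{t+1}, s_{t+1}, a_{t+1})$ 就做一次更新：$q(s_t, a_t) \leftarrow q(s_t, a_t) - \alpha \left[ q(s_t, a_t) - \left( r_{t+1} + \gamma\, q(s_{t+1}, a_{t+1}) \right) \right]$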

- 注意：当我们进行 policy evaluation 得到 action value 后，立刻进行 policy update

## 2. 代码详解

In [1]:
import numpy as np
import random
from IPython.display import clear_output
import sys, os
sys.path.append(os.path.dirname(os.getcwd()))
from GridWorld import GridWorld
from tqdm import tqdm
import time

c:\Users\callmest\.conda\envs\RBP-TSTL\lib\site-packages\numpy\.libs\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll
c:\Users\callmest\.conda\envs\RBP-TSTL\lib\site-packages\numpy\.libs\libopenblas64__v0.3.23-gcc_10_3_0.dll


In [2]:
# Grid dimensions; the tables below hold one entry per cell (rows * columns).
rows, columns = 5, 5

# Map layout handed to GridWorld, one string per grid row.
# NOTE(review): judging by the rendered grid, '#' looks like a forbidden cell
# and 'T' the target -- confirm against the GridWorld implementation.
grid_layout = [
    ".....",
    ".##..",
    "..#..",
    ".#T#.",
    ".#...",
]
gridworld = GridWorld(forbiddenAreaReward=-10, reward=1, desc=grid_layout)
print('Initial Grid World')
gridworld.show()

# Random deterministic starting policy: draw one of the 5 action indices for
# every state and one-hot encode it, giving a (rows * columns, 5) matrix.
random_actions = np.random.randint(0, 5, size=rows * columns)
policy = np.eye(5)[random_actions]
print('Initial Policy')
gridworld.show_policy_matirx(policy)

# State-value table, one zero entry per state.
value = np.zeros(rows * columns)
print(f'Initial State Value: {value}')

# Action-value (Q) table: one row per state, one column per action.
action_value = np.zeros((rows * columns, 5))
print(f'Initial Action Value: {action_value}')

# Hyperparameters
num_episodes = 1000  # number of training episodes
alpha = 0.1          # step size applied to the TD error
gamma = 0.9          # discount factor
epsilon = 0.1        # exploration rate of the epsilon-greedy policy

Initial Grid World
⬜️⬜️⬜️⬜️⬜️
⬜️🚫🚫⬜️⬜️
⬜️⬜️🚫⬜️⬜️
⬜️🚫✅🚫⬜️
⬜️🚫⬜️⬜️⬜️
Initial Policy
⬆️⬅️⬆️⬅️➡️
⬇️⏪⏩️⬅️🔄
⬇️⬇️⏫️⬆️⬆️
🔄⏪✅⏬➡️
⬅️⏬➡️⬅️🔄
Initial State Value: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0.]
Initial Action Value: [[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


In [3]:

# Sarsa control loop: for every episode, sample one trajectory with the current
# epsilon-greedy policy, apply the Sarsa TD update to each recorded transition,
# then make the policy greedy with respect to the refreshed action values.

# Number of actions, taken from the Q-table instead of a hard-coded 5.
num_actions = action_value.shape[1]

# Epsilon-greedy probabilities depend only on epsilon and the action count,
# so they are computed once instead of once per episode.
greedy_action_prob = 1 - epsilon * ((num_actions - 1) / num_actions)
non_greedy_action_prob = epsilon / num_actions


def to_epsilon_greedy(greedy_policy):
    """Turn a one-hot greedy policy into epsilon-greedy action probabilities.

    Entries equal to 1 (the greedy action of each state) receive
    ``greedy_action_prob``; every other entry receives
    ``non_greedy_action_prob``.
    """
    return np.where(greedy_policy == 1, greedy_action_prob, non_greedy_action_prob)


# Behaviour policy for the first episode; refreshed after every policy update.
policy_epsilon_greedy = to_epsilon_greedy(policy)
# Defined up front so the final report still works when num_episodes == 0.
state_value = np.sum(action_value * policy_epsilon_greedy, axis=1)

for episode in range(num_episodes):
    clear_output(wait=True)
    # '\\' prints the same single backslash as before, without the
    # invalid-escape-sequence warning raised by a bare '\ '.
    print(f'episode: {episode} \\ {num_episodes}')
    # Per-episode visit counter for each state (diagnostic only; it is not
    # read anywhere else in this cell).
    state_visited = [0] * (rows * columns)
    # Random starting state and action for this episode.
    init_state = random.choice(range(rows * columns))
    init_action = random.choice(range(num_actions))

    # Collect one trajectory; an episode only ends once the target is reached.
    # NOTE(review): steps=-1 presumably means "no step limit" -- confirm
    # against GridWorld.get_episode_score.
    trajectory = gridworld.get_episode_score(now_state=init_state,
                                             action=init_action,
                                             policy=policy_epsilon_greedy,
                                             steps=-1,
                                             stop_when_reach_target=True)

    print(f'episode end, trajectory length: {len(trajectory)}')
    # Walk the trajectory from the last transition back to the first, so the
    # values updated later in the episode are already refreshed when earlier
    # transitions bootstrap from them.
    for last_state, last_action, reward, next_state, next_action in reversed(trajectory):
        state_visited[last_state] += 1
        # Sarsa: the TD target uses the action actually taken in the next
        # state (next_action), not a max over actions.
        TD_error = action_value[last_state][last_action] - (reward + gamma * action_value[next_state][next_action])
        action_value[last_state][last_action] -= alpha * TD_error

    # Policy improvement: one-hot encode the arg-max action of every state.
    policy = np.eye(num_actions)[np.argmax(action_value, axis=1)]
    policy_epsilon_greedy = to_epsilon_greedy(policy)
    # State value = expectation of the action values under the epsilon-greedy policy.
    state_value = np.sum(action_value * policy_epsilon_greedy, axis=1)
    mean_state_value = np.mean(state_value)

    print(f'state value updated: \n{state_value}')
    print('policy updated')
    gridworld.show_policy_matirx(policy)
    # Short pause so each refreshed frame stays visible.
    time.sleep(0.2)


print('Final Policy')
gridworld.show_policy_matirx(policy)
print('Final State Value')
print(state_value)


episode: 286 \ 1000
episode end, trajectory length: 14
state value updated: 
[-6.11020297 -4.65421421 -4.3238123  -2.14133525 -2.67905476 -4.1471937
 -3.76379137 -1.04817291  0.44754809 -1.86193979 -3.34648771 -2.9077425
  5.8558343   1.89208948  2.42288588 -3.47177349  5.71492079  5.64838663
  5.58699491  3.03196797 -4.89566669  1.94418577  5.96934303  4.98459251
  3.51467724]
policy updated
⬇️➡️⬅️⬇️🔄
⬇️⏬⏩️⬇️⬅️
➡️🔄⏬➡️⬇️
⬆️⏩️✅⏪⬇️
⬆️⏩️⬆️⬅️⬅️


KeyboardInterrupt: 