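## Policy Iteration
Policy iteration is another method for solving the Bellman optimality equation.

Unlike value iteration, policy iteration starts from a given initial `policy` and computes the `state value` of every state under that policy.

Once these `state value`s are known, we use them to derive an improved policy, one that in each state picks the action with the highest value. This step is called `policy improvement`.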

### 1. Policy evaluation (PE)
$$
v_{{\pi}_{k}} = r_{{\pi}_{k}} + \gamma P_{{\pi}_{k}}v_{{\pi}_{k}}
$$

Note: $ v_{{\pi}_{k}} $ here is obtained by solving the Bellman equation of the given policy $\pi_k$, so it is the `state value` of that policy.
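Because the policy is fixed, this equation is linear in $v_{\pi_k}$. It can be solved directly as $v_{\pi_k} = (I - \gamma P_{\pi_k})^{-1} r_{\pi_k}$, or approximated by repeating $v \leftarrow r_{\pi_k} + \gamma P_{\pi_k} v$, which is what the inner loop of the notebook does. The sketch below compares both on a hypothetical 3-state example; the states, rewards and the helper names `evaluate_closed_form` / `evaluate_iterative` are illustrative assumptions and not part of the `GridWorld` code.

```python
import numpy as np

def evaluate_closed_form(r_pi, P_pi, gamma):
    """Solve v = r_pi + gamma * P_pi @ v exactly, i.e. (I - gamma * P_pi) v = r_pi."""
    n = len(r_pi)
    return np.linalg.solve(np.eye(n) - gamma * P_pi, r_pi)

def evaluate_iterative(r_pi, P_pi, gamma, tol=1e-10, max_sweeps=10_000):
    """Repeat v <- r_pi + gamma * P_pi @ v until successive estimates stop changing."""
    v = np.zeros(len(r_pi))
    for _ in range(max_sweeps):
        v_new = r_pi + gamma * P_pi @ v
        if np.max(np.abs(v_new - v)) < tol:
            return v_new
        v = v_new
    return v

# Hypothetical fixed policy: state 0 -> 1, state 1 -> 2, state 2 stays in 2 (the target).
P_pi = np.array([[0.0, 1.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 1.0]])
r_pi = np.array([0.0, 0.0, 1.0])  # reward 1 for each step spent in the target
gamma = 0.9

v_exact = evaluate_closed_form(r_pi, P_pi, gamma)
v_iter = evaluate_iterative(r_pi, P_pi, gamma)
print(v_exact, v_iter)
```

The exact values can be checked by hand: $v(2) = 1 + 0.9\,v(2)$ gives $v(2) = 10$, then $v(1) = 0.9 \cdot 10 = 9$ and $v(0) = 0.9 \cdot 9 = 8.1$.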

### 2. Policy improvement (PI)
$$
\pi_{k+1} = \arg\max_{\pi} (r_{\pi} + \gamma P_{\pi}v_{{\pi}_{k}})
$$
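The maximization over $\pi$ looks like a search over whole policies, but each component of $r_{\pi} + \gamma P_{\pi} v_{\pi_k}$ depends only on the action taken in that state, so it reduces to choosing $\arg\max_a q(s, a)$ state by state, with $q(s,a) = r(s,a) + \gamma \sum_{s'} p(s'|s,a)\, v_{\pi_k}(s')$. This is what the notebook does with `np.argmax(qtable, axis=1)`. The sketch below checks the claim by brute force on a hypothetical 3-state chain with deterministic moves (actions 0 = left, 1 = right, reward 1 for landing on state 2) and an assumed value estimate `v_k`; none of these names come from `GridWorld`.

```python
import itertools
import numpy as np

GAMMA = 0.9
# Hypothetical chain: next_state[s, a] for actions 0 = left, 1 = right; state 2 is the target.
next_state = np.array([[0, 1],
                       [0, 2],
                       [1, 2]])
reward = (next_state == 2).astype(float)  # 1 when the move lands on the target
v_k = np.array([8.1, 9.0, 10.0])          # assumed value estimate from the previous evaluation

# Per-state greedy choice: q(s, a) = r(s, a) + gamma * v_k(next state)
q = reward + GAMMA * v_k[next_state]
greedy = np.argmax(q, axis=1)

def objective(policy):
    """The vector r_pi + gamma * P_pi v_k for a deterministic policy (one action per state)."""
    s = np.arange(len(policy))
    return reward[s, policy] + GAMMA * v_k[next_state[s, policy]]

# Brute force over all 2**3 deterministic policies: greedy is at least as good in every state.
best = objective(greedy)
for candidate in itertools.product([0, 1], repeat=3):
    assert np.all(best >= objective(np.array(candidate)) - 1e-12)
print(greedy, best)
```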

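A characteristic of policy iteration is that states closer to the target settle on their final policy first. The reason is that a state far from the target can only reach the target once one of its neighboring cells already has a policy that leads there, so good policies spread outward from the target step by step.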
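This can be checked on a toy problem. The sketch below runs full policy iteration (exact evaluation with `np.linalg.solve`, greedy improvement) on a hypothetical 6-cell corridor whose right-most cell is the target, starting from the policy that always moves left, and records the iteration at which each cell's action reaches its final value. The corridor, its reward of 1 for landing on the target, and the tie-breaking of `np.argmax` toward the first action (left) are assumptions of this example, not properties of the notebook's `GridWorld`.

```python
import numpy as np

N, TARGET, GAMMA = 6, 5, 0.9  # hypothetical 1-D corridor; the target is the right-most cell
LEFT, RIGHT = 0, 1

def step(s, a):
    """Deterministic move; walking off either end leaves the agent where it is."""
    return max(s - 1, 0) if a == LEFT else min(s + 1, N - 1)

# Tabular model: P[s, a, s'] and reward 1 whenever the move lands on the target.
P = np.zeros((N, 2, N))
R = np.zeros((N, 2))
for s in range(N):
    for a in (LEFT, RIGHT):
        s_next = step(s, a)
        P[s, a, s_next] = 1.0
        R[s, a] = 1.0 if s_next == TARGET else 0.0

policy = np.full(N, LEFT)  # deliberately bad start: always move away from the target
history = [policy.copy()]
idx = np.arange(N)
while True:
    # Policy evaluation: solve v = r_pi + gamma * P_pi v exactly.
    P_pi, r_pi = P[idx, policy], R[idx, policy]
    v = np.linalg.solve(np.eye(N) - GAMMA * P_pi, r_pi)
    # Policy improvement: greedy w.r.t. q(s, a) = r(s, a) + gamma * sum_s' p(s'|s, a) v(s').
    q = R + GAMMA * P @ v
    new_policy = np.argmax(q, axis=1)  # ties go to the lowest index, i.e. LEFT
    if np.array_equal(new_policy, policy):
        break
    policy = new_policy
    history.append(policy.copy())

# For each cell, the first iteration from which its action never changes again.
final = history[-1]
settled = [min(k for k in range(len(history))
               if all(h[s] == final[s] for h in history[k:]))
           for s in range(N)]
print(settled)
```

Under these assumptions the result is `[5, 4, 3, 2, 1, 1]`: each iteration fixes one more cell, moving outward from the target.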

In [1]:
import numpy as np
import random
import os
import sys
# make the parent directory importable so that the local GridWorld module can be found
sys.path.append(os.path.dirname(os.getcwd()))
from GridWorld import GridWorld



In [2]:
# discount rate in [0, 1): 0 means only the immediate reward matters; the closer to 1, the more future rewards count
gamma = 0.9 
rows = 5
cols = 5
# create the grid world
grid_world = GridWorld(rows, cols)
grid_world.show()

⬜️⬜️⬜️⬜️⬜️
⬜️⬜️⬜️⬜️⬜️
🚫⬜️⬜️⬜️⬜️
⬜️⬜️⬜️⬜️⬜️
🚫✅⬜️⬜️🚫
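The discount rate `gamma` controls how much future rewards count. A minimal sketch of the discounted return $G = r_0 + \gamma r_1 + \gamma^2 r_2 + \dots$ shows two consequences for `gamma = 0.9`: a reward of 1 received at every step is worth at most $1/(1-\gamma) = 10$, and a single reward arriving 5 steps later is worth only $0.9^5 \approx 0.59$ now. The helper `discounted_return` is hypothetical and not part of `GridWorld`.

```python
def discounted_return(rewards, gamma):
    """G = r_0 + gamma * r_1 + gamma**2 * r_2 + ... for a finite reward sequence."""
    g = 0.0
    for r in reversed(rewards):  # accumulate from the last reward backwards
        g = r + gamma * g
    return g

gamma = 0.9
# reward 1 at every step over a long horizon approaches 1 / (1 - gamma) = 10
long_run = discounted_return([1.0] * 1000, gamma)
# a single reward of 1 that arrives after 5 steps is worth gamma**5 now
delayed = discounted_return([0.0] * 5 + [1.0], gamma)
# with gamma = 0 everything after the first step is ignored
myopic = discounted_return([0.0] * 5 + [1.0], 0.0)
print(long_run, delayed, myopic)
```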


In [3]:
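# state values, one per state, initialized to 0
value = np.zeros(rows*cols) 
# action values (q-table): one value for each of the 5 actions in every state, initialized to 0
qtable = np.zeros((rows*cols, 5)) 
# policy: randomly initialized, but each state is assigned a single deterministic action.
# In general a policy may be stochastic and may differ from state to state,
# e.g. move up with probability 0.5 and down with probability 0.5;
# for simplicity, a random deterministic policy is used here
policy = np.random.randint(0, 5, rows*cols) 
print(f'initial value: \n{value}')
print(f'initial policy: \n{policy}\n')
grid_world.show_policy_list(policy)

print('\nAs we can see, the initial policy is poor.')

initial value: 
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0.]
initial policy: 
[0 2 0 4 4 0 3 3 2 2 0 2 1 1 4 2 1 3 2 3 0 3 3 1 1]

Current policy:
⬆️⬇️⬆️🔄🔄
⬆️⬅️⬅️⬇️⬇️
⏫️⬇️➡️➡️🔄
⬇️➡️⬅️⬇️⬅️
⏫️✅⬅️➡️⏩️

As we can see, the initial policy is poor.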
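As the comments in the cell above note, a policy need not be deterministic. A common representation is a table `pi[s, a]` of action probabilities, in which a deterministic policy such as `policy` corresponds to one-hot rows. The sketch below, on a hypothetical 3-state chain with actions 0 = left and 1 = right, shows how such a table yields the $r_{\pi}$ and $P_{\pi}$ of the policy-evaluation equation: $r_{\pi}(s) = \sum_a \pi(a|s)\, r(s,a)$ and $P_{\pi}(s, s') = \sum_a \pi(a|s)\, p(s'|s,a)$. The names `policy_model`, `pi_det` and `pi_rand` are illustrative only.

```python
import numpy as np

# Hypothetical model: a 3-state chain, actions 0 = left and 1 = right, deterministic moves.
n_states, n_actions = 3, 2
next_state = np.array([[0, 1],
                       [0, 2],
                       [1, 2]])                # next_state[s, a]
P = np.zeros((n_states, n_actions, n_states))  # P[s, a, s'] = p(s' | s, a)
P[np.arange(n_states)[:, None], np.arange(n_actions)[None, :], next_state] = 1.0
R = (next_state == 2).astype(float)            # reward 1 for landing on state 2

def policy_model(pi):
    """Return r_pi(s) = sum_a pi(a|s) r(s, a) and P_pi(s, s') = sum_a pi(a|s) p(s'|s, a)."""
    r_pi = np.einsum('sa,sa->s', pi, R)
    P_pi = np.einsum('sa,sat->st', pi, P)
    return r_pi, P_pi

# Deterministic policy (one action index per state) turned into one-hot rows.
pi_det = np.eye(n_actions)[np.array([1, 1, 0])]
# Stochastic policy: left and right with probability 0.5 each, in every state.
pi_rand = np.full((n_states, n_actions), 0.5)

r_det, P_det = policy_model(pi_det)
r_rand, P_rand = policy_model(pi_rand)
print(r_det, r_rand)
print(P_rand)
```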


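The next cell caps policy evaluation at 10 sweeps (`truncated_cut`), so, as its comment says, it is really truncated policy iteration. The number of evaluation sweeps interpolates between two familiar algorithms: one sweep per round behaves like value iteration, while running the sweeps to convergence gives exact policy iteration, and all settings converge to the same optimal values. The sketch below checks this on a hypothetical 5-cell corridor (actions 0 = left, 1 = right, reward 1 for landing on the right-most cell); the function `truncated_policy_iteration` and its stopping rule are illustrative assumptions, not the notebook's exact code.

```python
import numpy as np

N, TARGET, GAMMA = 5, 4, 0.9  # hypothetical corridor; the right-most cell is the target
# nxt[s, a]: next cell for action 0 = left, 1 = right (walls keep the agent in place)
nxt = np.array([[max(s - 1, 0), min(s + 1, N - 1)] for s in range(N)])
R = (nxt == TARGET).astype(float)  # reward 1 for landing on the target

def truncated_policy_iteration(eval_sweeps, tol=1e-8, max_outer=10_000):
    """Alternate `eval_sweeps` Bellman sweeps for the current policy with a greedy improvement."""
    s = np.arange(N)
    v = np.zeros(N)
    policy = np.zeros(N, dtype=int)  # start by always moving left
    for outer in range(1, max_outer + 1):
        v_before = v.copy()
        for _ in range(eval_sweeps):  # truncated policy evaluation
            v = R[s, policy] + GAMMA * v[nxt[s, policy]]
        new_policy = np.argmax(R + GAMMA * v[nxt], axis=1)  # policy improvement
        # stop once the policy is stable and the values have stopped moving
        if np.array_equal(new_policy, policy) and np.max(np.abs(v - v_before)) < tol:
            return v, policy, outer
        policy = new_policy
    return v, policy, max_outer

results = {k: truncated_policy_iteration(k) for k in (1, 10, 200)}
for k, (v, policy, rounds) in results.items():
    print(f"{k:>3} sweeps: {rounds:>3} rounds, v = {np.round(v, 3)}, policy = {policy}")
```

On this corridor every setting reaches $v^* = [7.29, 8.1, 9, 10, 10]$ up to the tolerance, but the single-sweep version needs many more outer rounds.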
In [4]:
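## cap on the number of outer iterations, to avoid an infinite loop
cut = 0
cut_max = 1000
## convergence threshold: stop iterating once the change in value falls below it
threshold = 1e-3
# offset by 1 so that the while-condition holds on the first pass
pre_value = value.copy() + 1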

while np.sum((pre_value - value)**2) > threshold and cut < cut_max:
    pre_value = value.copy()
    
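    # policy evaluation
    # offset by 1 so that the inner while-condition holds on the first pass
    policy_value = value.copy() + 1

    # cap the number of evaluation sweeps; strictly speaking, this makes the algorithm truncated policy iteration
    truncated_cut = 10
    # solve the Bellman equation iteratively: for the fixed policy, compute the next value estimate from the previous one until it converges to the state value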
    while np.sum((policy_value - value)**2) > threshold:
        policy_value = value.copy()

        truncated_cut -= 1
        if truncated_cut < 0:
            break

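        # one sweep: update the value of every state under the current policy
        for i in range(rows * cols):
            action = policy[i]
            score, next_state = grid_world.getScore(i, action)
            # Bellman equation: immediate reward (score) plus the discounted value of the next state
            # note: read from the saved copy policy_value, because value is being overwritten during this sweep
            value[i] = score + gamma * policy_value[next_state]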


    # policy improvement
    # update the q-table: for every action in every state,
    # compute the action value from the state values just evaluated
    for i in range(rows * cols):
        for j in range(5):
            score, next_state = grid_world.getScore(i, j)
            qtable[i][j] = score + gamma * value[next_state]

    # new policy: in each state, pick the action with the largest action value in the q-table
    policy = np.argmax(qtable, axis=1)

    grid_world.show_policy_list(policy=policy)

    print(f'current state value of each state: \n {value.reshape(rows, cols)}')

    cut = cut + 1

print(f'iteration times: {cut}')    
print(f'final value: \n {value.reshape(rows, cols)}')
print(f'final policy: \n {policy}')
grid_world.show_policy_list(policy)    

nowState: (0, 0), action: 0, nextState: (-1, 0)
nowState: (0, 1), action: 2, nextState: (1, 1)
nowState: (0, 2), action: 0, nextState: (-1, 2)
nowState: (0, 3), action: 4, nextState: (0, 3)
nowState: (0, 4), action: 4, nextState: (0, 4)
nowState: (1, 0), action: 0, nextState: (0, 0)
nowState: (1, 1), action: 3, nextState: (1, 0)
nowState: (1, 2), action: 3, nextState: (1, 1)
nowState: (1, 3), action: 2, nextState: (2, 3)
nowState: (1, 4), action: 2, nextState: (2, 4)
nowState: (2, 0), action: 0, nextState: (1, 0)
nowState: (2, 1), action: 2, nextState: (3, 1)
nowState: (2, 2), action: 1, nextState: (2, 3)
nowState: (2, 3), action: 1, nextState: (2, 4)
nowState: (2, 4), action: 4, nextState: (2, 4)
nowState: (3, 0), action: 2, nextState: (4, 0)
nowState: (3, 1), action: 1, nextState: (3, 2)
nowState: (3, 2), action: 3, nextState: (3, 1)
nowState: (3, 3), action: 2, nextState: (4, 3)
nowState: (3, 4), action: 3, nextState: (3, 3)
nowState: (4, 0), action: 0, nextState: (3, 0)
nowState: (