## Value Iteration
Value iteration is one way to solve the Bellman optimality equation.

The Bellman optimality equation (in matrix-vector form):
$$ v^{*} = \max_{\pi}(r_{\pi} + \gamma P_{\pi}v^{*}) = r_{\pi^{*}} + \gamma P_{\pi^{*}}v^{*} $$

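By the `contraction mapping theorem`, the right-hand side $f(v) = \max_{\pi}(r_{\pi} + \gamma P_{\pi}v)$ is a contraction with factor $\gamma$. The equation therefore has a unique solution $v^{*}$, and it can be solved by iterating from any starting point: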
$$ v_{k+1} = f(v_k) = \max_{\pi}(r_{\pi} + \gamma P_{\pi}v_{k}) $$
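A minimal NumPy sketch of this iteration on a hypothetical 3-state, 2-action MDP with random rewards and transition probabilities (not the grid world used later). The arrays `R[s, a]` and `P[s, a, s']` are assumptions standing for $\sum_r p(r|s,a)r$ and $p(s'|s,a)$; the function `f` takes a max over actions, which is what the max over policies reduces to. The sketch checks the contraction inequality $\|f(v_1)-f(v_2)\|_\infty \le \gamma\|v_1-v_2\|_\infty$ and shows that two different starting points converge to the same fixed point.

```python
import numpy as np

# Hypothetical toy MDP (3 states, 2 actions), not the grid world used later.
# R[s, a]     plays the role of sum_r p(r|s,a) * r
# P[s, a, s'] plays the role of p(s'|s,a)
rng = np.random.default_rng(0)
n_states, n_actions, gamma = 3, 2, 0.9
R = rng.uniform(-1.0, 1.0, size=(n_states, n_actions))
P = rng.uniform(size=(n_states, n_actions, n_states))
P /= P.sum(axis=2, keepdims=True)   # each P[s, a, :] becomes a probability distribution


def f(v):
    """Bellman optimality operator: f(v)(s) = max_a [R(s,a) + gamma * sum_s' P(s'|s,a) v(s')]."""
    q = R + gamma * P @ v            # action values, shape (n_states, n_actions)
    return q.max(axis=1)


# Contraction property: ||f(v1) - f(v2)||_inf <= gamma * ||v1 - v2||_inf
v1 = rng.normal(size=n_states)
v2 = rng.normal(size=n_states)
lhs = np.max(np.abs(f(v1) - f(v2)))
rhs = gamma * np.max(np.abs(v1 - v2))
print(lhs <= rhs + 1e-12)            # True

# Iterating v_{k+1} = f(v_k) from two different starting points reaches the same fixed point
a, b = v1.copy(), v2.copy()
for _ in range(500):
    a, b = f(a), f(b)
print(np.max(np.abs(a - b)))         # numerically zero
print(np.allclose(a, f(a)))          # True: a satisfies v = f(v)
```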

Each iteration consists of two steps:

Step 1: Policy update

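Start from an arbitrary initial guess $v_{0}$ (random values or all zeros both work). Given the current $v_{k}$, solve the optimization problem:
$$
\pi_{k+1} = \arg\max_{\pi}(r_{\pi} + \gamma P_{\pi}v_{k})
$$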
That is, for every state:
$$
\pi_{k+1}(s) = \arg\max_{\pi} \sum_{a} \pi(a|s) \underbrace{\Big(\sum_{r}p(r|s,a)r + \gamma \sum_{s^{\prime}}p(s^{\prime}|s,a)v_{k}(s^{\prime})\Big)}_{q_{k}(s,a)}, \quad s \in \mathbb{S}
$$
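After plugging in the current state values $v_{k}$, we can compute, for every state $s$, the action value of each action; write $q_{k}(s,a)$ for the bracketed term. It depends only on $v_{k}$, not on $\pi$, so the weighted sum is maximized by putting all probability on the action with the largest action value. $\pi_{k+1}$ is therefore the greedy (deterministic) policy:
$$
\pi_{k+1}(a|s) = \begin{cases} 1, & a = a_{k}^{*}(s) \\ 0, & a \neq a_{k}^{*}(s) \end{cases}, \qquad a_{k}^{*}(s) = \arg\max_{a} q_{k}(s,a)
$$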
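The policy update can be sketched as follows on a small hand-made MDP (hypothetical, chosen so the numbers are easy to check by hand). `R` and `P` are assumed stand-ins for $\sum_r p(r|s,a)r$ and $p(s'|s,a)$, and `policy_update` is an illustrative helper, not part of the notebook's code.

```python
import numpy as np

# Hypothetical toy MDP: 3 states, 2 actions.
# R[s, a]     = sum_r p(r|s,a) * r   (expected immediate reward)
# P[s, a, s'] = p(s'|s,a)
R = np.array([[0.0, 1.0],
              [0.5, 0.0],
              [0.0, 0.0]])
P = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
              [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
              [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]])
gamma = 0.9


def policy_update(v_k):
    """Return q_k and the greedy one-hot policy pi_{k+1}."""
    q_k = R + gamma * P @ v_k                # q_k(s, a), shape (S, A)
    best = q_k.argmax(axis=1)                # a_k^*(s); ties go to the lowest action index
    pi = np.zeros_like(q_k)
    pi[np.arange(len(best)), best] = 1.0     # pi_{k+1}(a|s) = 1 only for a = a_k^*(s)
    return q_k, pi


v_k = np.array([0.0, 2.0, 1.0])              # an arbitrary v_k
q_k, pi_next = policy_update(v_k)
print(q_k)                                   # rows: [0, 2.8], [1.4, 0], [0.9, 1.8]
print(pi_next.argmax(axis=1))                # [1 0 1]
```

Storing the policy as a one-hot matrix mirrors $\pi(a|s)$ in the equation; the notebook's own code stores only the action index per state, which carries the same information for a deterministic policy.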

Step 2: Value update

Then plug the updated policy $\pi_{k+1}$ into the same expression to update the `state value`:
$$
v_{k+1} = r_{{\pi}_{k+1}} + \gamma P_{{\pi}_{k+1}}v_{k}
$$
That is:
$$
v_{k+1}(s) =  \sum_{a} \pi_{k+1}(a|s) \underbrace{\Big(\sum_{r}p(r|s,a)r + \gamma \sum_{s^{\prime}}p(s^{\prime}|s,a)v_{k}(s^{\prime})\Big)}_{q_{k}(s,a)}, \quad s \in \mathbb{S}
$$
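Because $\pi_{k+1}$ is greedy, only one term of the sum over $a$ is nonzero, so $v_{k+1}(s) = \max_{a} q_{k}(s,a)$: the new value is simply the `action value` of the greedy action $\pi_{k+1}(s)$.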
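Continuing with a hypothetical hand-made MDP (redefined here so the block runs on its own), the value update can be written either in matrix form, by selecting $r_{\pi_{k+1}}$ and $P_{\pi_{k+1}}$ for the greedy actions, or directly as a max over action values. The sketch checks that the two agree.

```python
import numpy as np

# Hypothetical toy MDP: R[s, a] = expected reward, P[s, a, s'] = p(s'|s,a)
R = np.array([[0.0, 1.0],
              [0.5, 0.0],
              [0.0, 0.0]])
P = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
              [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
              [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]])
gamma = 0.9
v_k = np.array([0.0, 2.0, 1.0])

q_k = R + gamma * P @ v_k                  # q_k(s, a)
greedy = q_k.argmax(axis=1)                # deterministic pi_{k+1}(s)
states = np.arange(len(v_k))

# Matrix form: v_{k+1} = r_{pi_{k+1}} + gamma * P_{pi_{k+1}} v_k
r_pi = R[states, greedy]                   # r_pi(s) = R(s, pi(s))
P_pi = P[states, greedy]                   # row s of P_pi is p(. | s, pi(s))
v_next_matrix = r_pi + gamma * P_pi @ v_k

# Greedy policy: the same thing as the largest action value in each state
v_next_max = q_k.max(axis=1)
print(np.allclose(v_next_matrix, v_next_max))  # True
print(v_next_max)                              # [2.8 1.4 1.8]
```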

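Note that $v_{k}$ is, in general, not a true `state value`. It starts from an arbitrary $v_{0}$, and the update uses $v_{k}$ (not $v_{k+1}$) on the right-hand side, so $v_{k+1}$ generally does not solve the Bellman equation of $\pi_{k+1}$. Only in the limit $k \to \infty$ does $v_{k}$ converge to $v^{*}$, which is the state value of an optimal policy.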
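A quick numerical check of this point, again on a hypothetical hand-made MDP: at each iteration, compute the true state value $v_{\pi_k}$ of the current greedy policy by solving $(I - \gamma P_{\pi_k})v = r_{\pi_k}$ and compare it with $v_k$. The helper names `value_iteration_step` and `evaluate` are illustrative only. The gap is clearly nonzero after the first iteration and shrinks toward zero as $k$ grows.

```python
import numpy as np

# Hypothetical toy MDP: R[s, a] = expected reward, P[s, a, s'] = p(s'|s,a)
R = np.array([[0.0, 1.0],
              [0.5, 0.0],
              [0.0, 0.0]])
P = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
              [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
              [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]])
gamma = 0.9
states = np.arange(3)


def value_iteration_step(v):
    """One round of value iteration: greedy policy update, then value update."""
    q = R + gamma * P @ v
    return q.argmax(axis=1), q.max(axis=1)


def evaluate(policy):
    """Exact state value of a deterministic policy: solve (I - gamma * P_pi) v = r_pi."""
    r_pi, P_pi = R[states, policy], P[states, policy]
    return np.linalg.solve(np.eye(len(states)) - gamma * P_pi, r_pi)


v = np.zeros(3)                                  # v_0
gaps = []
for k in range(1, 201):
    policy, v = value_iteration_step(v)          # pi_k and v_k
    gaps.append(np.max(np.abs(v - evaluate(policy))))
print(gaps[0])    # about 0.45: v_1 is not the state value of pi_1
print(gaps[-1])   # tiny: v_k has converged to v*
```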

```python
import numpy as np
import os
import sys

# Make the parent directory importable so that the GridWorld module can be found
sys.path.append(os.path.dirname(os.getcwd()))
from GridWorld import GridWorld
```



```python
gamma = 0.9  # discount rate in [0, 1): 0 means only the immediate reward matters; closer to 1 means future rewards matter more
rows = 5
cols = 5
# Create and display the grid world
grid_world = GridWorld(rows, cols)
grid_world.show()
```

```
⬜️⬜️⬜️⬜️⬜️
⬜️⬜️⬜️⬜️⬜️
🚫⬜️⬜️⬜️⬜️
⬜️⬜️⬜️⬜️⬜️
🚫✅⬜️⬜️🚫
```


```python
value = np.zeros(rows*cols)          # state values, one per state, initialized to 0
qtable = np.zeros((rows*cols, 5))    # action values: 5 actions per state, initialized to 0
# Greedy policy: index of the best action in each state. It is not optimal yet;
# with an all-zero qtable, argmax picks action 0 everywhere.
policy = np.argmax(qtable, axis=1)
print(f'initial value: {value}')
print(f'initial policy: {policy}')
grid_world.show_policy_list(policy)
```

```
initial value: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0.]
initial policy: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
⬆️⬆️⬆️⬆️⬆️
⬆️⬆️⬆️⬆️⬆️
⏫️⬆️⬆️⬆️⬆️
⬆️⬆️⬆️⬆️⬆️
⏫️✅⬆️⬆️⏫️
```
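Since the `GridWorld` class is not included in this notebook, here is a self-contained, vectorized sketch of the same value iteration on a 5×5 grid. The layout (forbidden cells and target) is read off the grid printed above, and the action encoding (0 up, 1 right, 2 down, 3 left, 4 stay) follows the `nowState`/`action`/`nextState` log that `getScore` prints during the loop. The reward values (`R_BOUNDARY`, `R_FORBIDDEN`, `R_TARGET`, `R_OTHER`) and the rule that hitting the boundary keeps the agent in place are assumptions, so the numbers need not match the notebook's output.

```python
import numpy as np

ROWS, COLS = 5, 5
# Layout read off the grid printed above ((row, col), 0-indexed from the top-left).
FORBIDDEN = {(2, 0), (4, 0), (4, 4)}
TARGET = (4, 1)
# Action encoding as in getScore's log: 0 up, 1 right, 2 down, 3 left, 4 stay.
MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1), (0, 0)]
# ASSUMED rewards; the notebook's GridWorld may use different values.
R_BOUNDARY, R_FORBIDDEN, R_TARGET, R_OTHER = -1.0, -1.0, 1.0, 0.0


def build_model():
    """Return next_state[s, a] (flat index) and reward[s, a] for a deterministic grid."""
    n = ROWS * COLS
    next_state = np.zeros((n, len(MOVES)), dtype=int)
    reward = np.zeros((n, len(MOVES)))
    for s in range(n):
        r, c = divmod(s, COLS)          # same flattening as i // cols, i % cols
        for a, (dr, dc) in enumerate(MOVES):
            nr, nc = r + dr, c + dc
            if not (0 <= nr < ROWS and 0 <= nc < COLS):
                # ASSUMPTION: hitting the boundary keeps the agent in place
                next_state[s, a], reward[s, a] = s, R_BOUNDARY
                continue
            next_state[s, a] = nr * COLS + nc
            if (nr, nc) == TARGET:
                reward[s, a] = R_TARGET
            elif (nr, nc) in FORBIDDEN:
                reward[s, a] = R_FORBIDDEN
            else:
                reward[s, a] = R_OTHER
    return next_state, reward


def value_iteration(gamma=0.9, tol=1e-6, max_iter=1000):
    next_state, reward = build_model()
    value = np.zeros(ROWS * COLS)
    for it in range(1, max_iter + 1):
        qtable = reward + gamma * value[next_state]   # q_k(s, a) for all states and actions at once
        policy = qtable.argmax(axis=1)                # policy update (greedy)
        new_value = qtable.max(axis=1)                # value update
        if np.max(np.abs(new_value - value)) < tol:
            return new_value, policy, it
        value = new_value
    return value, policy, max_iter


value, policy, iters = value_iteration()
print(f"converged after {iters} iterations")
print(value.reshape(ROWS, COLS).round(2))
print(policy.reshape(ROWS, COLS))
```

Precomputing `next_state` and `reward` once turns each iteration into a single fancy-indexing expression, replacing the per-state, per-action Python loop.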


```python
# value iteration
# 1. initialize: make pre_value differ from value so that the loop body runs at least once
pre_value = value.copy() + 1

grid_world.show()
grid_world.show_policy_list(policy)

# 2. value iteration
# cap on the number of iterations, to avoid an infinite loop
cut = 0
cut_max = 1000
# stop once the change in value falls below this threshold
threshold = 1e-3

# note: the stopping test (and the printed "Euclidean Distance") is the *squared* Euclidean distance
while np.sum((pre_value - value)**2) > threshold and cut < cut_max:
    print("Euclidean Distance:", np.sum((pre_value - value)**2))
    pre_value = value.copy()

    # loop over every state
    for i in range(rows * cols):
        # loop over every action
        for j in range(5):
            # reward and next state after taking action j in state i
            reward, next_state = grid_world.getScore(i, j)
            # q_k(s, a) = r + gamma * v_k(s')
            qtable[i, j] = reward + gamma * value[next_state]
    # the qtable is now fully updated
    # Step 1 (policy update): the policy is greedy, so pick the action with the largest q value
    policy = np.argmax(qtable, axis=1)
    # Step 2 (value update): each state's value is the largest q value in that state
    value = np.max(qtable, axis=1)

    print(f'value: \n {value}')
    grid_world.show_policy_list(policy)

    cut += 1
print(f'iteration times: {cut}')
print(f'final value: \n {value.reshape(rows, cols)}')
print(f'final policy: \n {policy}')
grid_world.show_policy_list(policy)
```


```
⬜️⬜️⬜️⬜️⬜️
⬜️⬜️⬜️⬜️⬜️
🚫⬜️⬜️⬜️⬜️
⬜️⬜️⬜️⬜️⬜️
🚫✅⬜️⬜️🚫
➡️⬇️⬇️⬇️⬇️
➡️⬇️⬇️⬇️⬇️
⏩️⬇️⬇️⬇️⬇️
➡️⬇️⬇️⬇️⬅️
⏩️✅⬅️⬅️⏪
```
Euclidean Distance: 25.0
nowState: (0, 0), action: 0, nextState: (-1, 0)
nowState: (0, 0), action: 1, nextState: (0, 1)
nowState: (0, 0), action: 2, nextState: (1, 0)
nowState: (0, 0), action: 3, nextState: (0, -1)
nowState: (0, 0), action: 4, nextState: (0, 0)
nowState: (0, 1), action: 0, nextState: (-1, 1)
nowState: (0, 1), action: 1, nextState: (0, 2)
nowState: (0, 1), action: 2, nextState: (1, 1)
nowState: (0, 1), action: 3, nextState: (0, 0)
nowState: (0, 1), action: 4, nextState: (0, 1)
nowState: (0, 2), action: 0, nextState: (-1, 2)
nowState: (0, 2), action: 1, nextState: (0, 3)
nowState: (0, 2), action: 2, nextState: (1, 2)
nowState: (0, 2), action: 3, nextState: (0, 1)
nowState: (0, 2), action: 4, nextState: (0, 2)
nowState: (0, 3), action: 0, nextState: (-1, 3)
nowState: (0, 3), action: 1, nextState: (0, 4)
nowState: (0, 3), action: 2, nextState: (1, 3)
nowState: (0, 3), ac