## 动态规划


#### 1、创建环境

In [1]:
def get_state(row, col):
    """Classify the grid cell at (row, col).

    Returns 'ground' for every cell outside row 3 and for (3, 0),
    'terminal' for (3, 11), and 'trap' for every other cell of row 3.
    """
    # Only row 3 contains anything other than plain ground.
    if row != 3:
        return 'ground'

    # Within row 3 the two special columns are looked up; the rest are traps.
    return {0: 'ground', 11: 'terminal'}.get(col, 'trap')

get_state(0, 0)
    

'ground'

#### 2、动作空间

In [2]:
def move(row, col, action):
    """Apply ``action`` at (row, col) and return ``(row, col, reward)``.

    Actions: 0 = up (row - 1), 1 = down (row + 1), 2 = left (col - 1),
    3 = right (col + 1); any other value leaves the position unchanged.
    The new position is clamped to rows 0..3 and columns 0..11.

    The reward is 0 (and the position is unchanged) when the starting cell
    is already a 'trap' or 'terminal' cell, -100 when the move lands on a
    'trap', and -1 otherwise.
    """
    # No further movement or reward once the episode has ended.
    if get_state(row, col) in ['trap', 'terminal']:
        return row, col, 0

    # Row/column offset for each known action; unknown actions do not move.
    offsets = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}
    d_row, d_col = offsets.get(action, (0, 0))

    # Clamp to the grid so moving into a wall keeps the agent in place.
    row = min(3, max(0, row + d_row))
    col = min(11, max(0, col + d_col))

    reward = -100 if get_state(row, col) == 'trap' else -1
    return row, col, reward

#### 3、Q函数


In [3]:
import numpy as np
# State-value table: one entry per grid cell (4 rows x 12 columns), all zeros initially.
values = np.zeros([4, 12])
# Policy table: for every cell, one probability per action (4 actions), 0.25 each initially.
pi = np.ones([4, 12, 4]) * 0.25


def get_Qsa(row, col, action):
    """Return the one-step value of taking ``action`` in cell (row, col).

    The result is the reward from ``move`` plus 0.9 times the current
    ``values`` entry of the cell that is reached; that future term is
    dropped (treated as 0) when the reached cell is a 'trap' or 'terminal'.
    """
    next_row, next_col, reward = move(row, col, action)

    # Discounted value of the successor cell (discount factor 0.9).
    future = values[next_row, next_col] * 0.9
    episode_over = get_state(next_row, next_col) in ['trap', 'terminal']

    # Nothing follows an ending cell, so only the reward counts there.
    return (0 if episode_over else future) + reward


# Policy evaluation
def get_values():
    """Return a new 4x12 value table computed from the current policy.

    For every cell, the value of each of the 4 actions is obtained from
    ``get_Qsa`` and weighted by that action's probability in ``pi``; the
    weighted values are summed into the cell's new value.
    """
    updated = np.zeros([4, 12])

    for row, col in np.ndindex(4, 12):
        q = np.array([get_Qsa(row, col, a) for a in range(4)], dtype=float)
        # Expectation of the action values under the policy for this cell.
        updated[row, col] = (q * pi[row, col]).sum()

    return updated

# Policy improvement
def get_pi():
    """Return a new 4x12x4 policy that is greedy with respect to ``get_Qsa``.

    For every cell, the actions whose value equals the maximum share the
    probability equally; every other action gets probability 0.
    """
    improved = np.zeros([4, 12, 4])

    for row, col in np.ndindex(4, 12):
        q = np.array([get_Qsa(row, col, a) for a in range(4)], dtype=float)

        # Indices of all actions tied for the best value.
        best = np.flatnonzero(q == q.max())
        for action in best:
            improved[row, col, action] = 1 / len(best)

    return improved
# get_Qsa(0, 0, 0)
# get_values()
get_pi()

array([[[0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ]],

       [[0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.25      , 0.25      , 0.25      , 0.25      ],
        [0.2

#### 4、循环迭代策略评估和策略提升

In [4]:
# Policy iteration: alternate evaluation and improvement for 10 rounds.
for _ in range(10):
    # Policy evaluation: 100 sweeps of get_values() under the current pi.
    for _ in range(100):
        values = get_values()
    # Policy improvement: rebuild pi from the freshly evaluated values.
    pi = get_pi()

# Notebook cell result: display the final value table and policy.
values, pi

(array([[-7.71232075, -7.45813417, -7.17570464, -6.86189404, -6.5132156 ,
         -6.12579511, -5.6953279 , -5.217031  , -4.68559   , -4.0951    ,
         -3.439     , -2.71      ],
        [-7.45813417, -7.17570464, -6.86189404, -6.5132156 , -6.12579511,
         -5.6953279 , -5.217031  , -4.68559   , -4.0951    , -3.439     ,
         -2.71      , -1.9       ],
        [-7.17570464, -6.86189404, -6.5132156 , -6.12579511, -5.6953279 ,
         -5.217031  , -4.68559   , -4.0951    , -3.439     , -2.71      ,
         -1.9       , -1.        ],
        [-7.45813417,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]]),
 array([[[0.  , 0.5 , 0.  , 0.5 ],
         [0.  , 0.5 , 0.  , 0.5 ],
         [0.  , 0.5 , 0.  , 0.5 ],
         [0.  , 0.5 , 0.  , 0.5 ],
         [0.  , 0.5 , 0.  , 0.5 ],
         [0.  , 0.5 , 0.  , 0.5 ],
         [0.  , 0.5 , 0.  , 0.5 ],
         [0.  , 0

#### 5、结果显示

In [5]:
# Print the game grid, handy for testing
def show(row, col, action):
    """Print the 4x12 grid with an arrow for ``action`` drawn at (row, col).

    ``action`` must be one of 0 (up), 1 (down), 2 (left) or 3 (right);
    any other value raises ``KeyError``.
    """
    # Flat, row-major grid: 37 '□' cells (rows 0-2 plus the first cell of
    # row 3), then 10 '○' cells, then the final '❤' cell.
    cells = ['□'] * 37 + ['○'] * 10 + ['❤']

    arrow = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]
    cells[row * 12 + col] = arrow

    # Emit the flat grid as 4 lines of 12 characters each.
    text = ''.join(cells)
    for start in range(0, 4 * 12, 12):
        print(text[start:start + 12])

from IPython import display
import time


def test():
    """Play one episode from (0, 0), sampling each action from ``pi``.

    Every step clears the notebook output, pauses briefly and redraws the
    grid with ``show``. The episode stops after at most 200 steps, or as
    soon as the agent reaches a 'trap' or 'terminal' cell.
    """
    # Start at (0, 0)
    row = 0
    col = 0

    # Play at most N steps
    for _ in range(200):

        # Choose an action according to the policy's probabilities for this cell
        action = np.random.choice(np.arange(4), size=1, p=pi[row, col])[0]

        # Draw this action
        display.clear_output(wait=True)
        time.sleep(0.1)
        show(row, col, action)

        # Execute the action
        row, col, reward = move(row, col, action)

        # Check the new state; stop if it is the goal or a trap
        if get_state(row, col) in ['trap', 'terminal']:
            break


test()

□□□□□□□□□□□□
□□□□□□□□□□□□
□□□□□□□□□□□↓
□○○○○○○○○○○❤


In [6]:
# Print the policy as a grid of arrows: for every cell, the action with the
# highest probability in pi (the first one on ties, per argmax).
arrows = {0:'↑', 1:'↓', 2:'←', 3:"→"}
for row in range(4):
    symbols = [arrows[pi[row, col].argmax()] for col in range(12)]
    print(''.join(symbols))

↓↓↓↓↓↓↓↓↓↓↓↓
↓↓↓↓↓↓↓↓↓↓↓↓
→→→→→→→→→→→↓
↑↑↑↑↑↑↑↑↑↑↑↑
