# 使用AlphaZero算法打造属于你自己的象棋AI!

![](https://ai-studio-static-online.cdn.bcebos.com/70482fc2643343959250d272a3a252bcacadc365770b4bada4bd4f779e8b735e)

## 一、从AlphaGo到AlphaZero，实现一套通用的游戏AI训练算法

![](https://ai-studio-static-online.cdn.bcebos.com/96abb154ac8f45c2b154cf197b088edeb526dca97cb44a0e87819e1d10d22ec9)

* 谷歌最早期的AlphaGo在2016年就已经可以击败围棋职业选手，但是当将AlphaGo应用到其他游戏上时却有很多困难，如：需要专家精心设计特征，需要人类玩家对弈的数据进行监督训练等。在迭代更新之后，一套全新的通用算法AlphaZero出现了。理论上，只要是双人完全信息博弈的游戏，都可以藉由这一套算法来进行AI训练，所以AlphaZero是一套通用算法（不考虑算力成本）。

* 对比于AlphaGo，AlphaZero主要对以下几点做了改进：

> 1、AlphaGo用到了人类专家对弈的棋谱进行训练，AlphaZero则完全从零进行自对弈训练。避免了找专家数据的麻烦。

> 2、AlphaGo用到了人类专家精心设计的特征，AlphaZero则只用到了棋盘的表示特征。让围棋小白也能训练出一个超强AI。

> 3、AlphaZero棋力更强，是一个通用算法，理论上双人完全信息博弈类游戏皆可藉由这一个算法搞定。

> 4、省去模型评估的部分，加速训练过程。

![](https://ai-studio-static-online.cdn.bcebos.com/ede8c74143f04aa6aadeb35695505d169641bcf3afaa4538aa98f6a481f4bfde)

如上图所示，AlphaZero通过几十万次的自我对弈，最终要强于其他游戏软件和AlphaGo。

## 二、本项目简介

1、本项目的代码从零编写过程录制成了视频上传到b站，想认真分析代码的同学可以进行参考。另外，项目简介同样也录制了视频。

2、本项目最好用多个进程进行训练，训练的方法如下：

* 在终端运行"**cd aichess**"将路径定位到aichess里，然后运行"**python collect.py**"进行自我对弈数据的收集。可以开四个这样的终端进行更快速的数据收集。

* 在终端运行"**cd aichess**"将路径定位到aichess里，然后运行"**python train.py**"进行模型的训练。这个只能开一个，因为我们的多进程是通过collect.py来实现的。

3、训练差不多了之后，就可以进行快乐的人机对战了，可以在aistudio上运行"play_with_ai.py"进行print模式的对弈，也可以把项目copy到本地运行"UIplay"进行UI界面的一个对战。

4、以下视频展示了训练了100次，蒙特卡洛搜索次数设置为4000次的一个人机对弈棋局，模型已经学会一步之内的将军防守，但不会预知两步之内的必杀将军。

<div align=center><img src="https://ai-studio-static-online.cdn.bcebos.com/3f018b0239ba42ff8324d4e9c417c96887fb952562c44c308024e55a1ee1949e"></div>

## 三、游戏棋盘表示

接下来我们编写游戏棋盘表示，虽然AlphaZero不需要人类添加的特征，但是需要知道走子的规则。在这个游戏棋盘表示里面，我们会实现棋盘类用于棋盘表示，和游戏类用于游戏逻辑的控制，包括自我对弈，人机对弈等。这个代码非常长，主要是象棋本身的逻辑比较复杂，所以一开始上手本项目可以先看蒙特卡洛、神经网络、自我对弈训练部分。


In [None]:
"""棋盘游戏控制"""


import numpy as np
import copy
import time
from config import CONFIG
from collections import deque   # 这个队列用来判断长将或长捉
import random


# The board as a nested list (10 rows x 9 columns): red occupies the top rows,
# black the bottom rows. Always deep-copy this template before mutating it.
state_list_init = [['红车', '红马', '红象', '红士', '红帅', '红士', '红象', '红马', '红车'],
                   ['一一', '一一', '一一', '一一', '一一', '一一', '一一', '一一', '一一'],
                   ['一一', '红炮', '一一', '一一', '一一', '一一', '一一', '红炮', '一一'],
                   ['红兵', '一一', '红兵', '一一', '红兵', '一一', '红兵', '一一', '红兵'],
                   ['一一', '一一', '一一', '一一', '一一', '一一', '一一', '一一', '一一'],
                   ['一一', '一一', '一一', '一一', '一一', '一一', '一一', '一一', '一一'],
                   ['黑兵', '一一', '黑兵', '一一', '黑兵', '一一', '黑兵', '一一', '黑兵'],
                   ['一一', '黑炮', '一一', '一一', '一一', '一一', '一一', '黑炮', '一一'],
                   ['一一', '一一', '一一', '一一', '一一', '一一', '一一', '一一', '一一'],
                   ['黑车', '黑马', '黑象', '黑士', '黑帅', '黑士', '黑象', '黑马', '黑车']]


# Board history template: a deque capped at 4 entries, pre-filled with four
# independent deep copies of the opening position.
state_deque_init = deque(
    (copy.deepcopy(state_list_init) for _ in range(4)),
    maxlen=4,
)


# Lookup table: piece name -> length-7 integer vector.
# Each of the 7 positions stands for one piece kind, in the order
# 车, 马, 象, 士, 帅, 炮, 兵. Red pieces carry +1 in their position, black
# pieces carry -1, and the empty square '一一' is all zeros.
# Insertion order: the seven red pieces, the seven black pieces, then '一一'.
string2array = {
    color + kind: sign * np.eye(7, dtype=int)[plane]
    for sign, color in ((1, '红'), (-1, '黑'))
    for plane, kind in enumerate('车马象士帅炮兵')
}
string2array['一一'] = np.zeros(7, dtype=int)


def array2string(array):
    """Return the piece name whose vector in ``string2array`` equals ``array``.

    Raises IndexError when no entry of ``string2array`` matches.
    """
    for name, vector in string2array.items():
        if (vector == array).all():
            return name
    raise IndexError('list index out of range')


# Apply a move to a board and return the resulting board.
def change_state(state_list, move):
    """Return a new board with ``move`` applied; ``state_list`` is not modified.

    Args:
        state_list: board as a list of rows, each row a list of piece-name
            strings ('一一' marks an empty square).
        move: four single-digit coordinates, e.g. the string '0010', read as
            source row, source column, target row, target column.

    Returns:
        A new list of new row lists. The piece on the source square is moved
        to the target square (replacing whatever was there) and the source
        square becomes '一一'.
    """
    # Rows contain only immutable strings, so copying each row yields a fully
    # independent board. This is far cheaper than copy.deepcopy, which matters
    # because the move generator calls this once per candidate move.
    copy_list = [list(row) for row in state_list]
    y, x, toy, tox = int(move[0]), int(move[1]), int(move[2]), int(move[3])
    copy_list[toy][tox] = copy_list[y][x]
    copy_list[y][x] = '一一'
    return copy_list


# Print the board row by row (used for visualisation).
def print_board(_state_array):
    """Print the 10 board rows, each as a list of 9 piece names.

    Every cell ``_state_array[i][j]`` is converted with ``array2string``.
    """
    # _state_array: [10, 9, 7], HWC
    for row_idx in range(10):
        print([array2string(_state_array[row_idx][col_idx]) for col_idx in range(9)])


# Convert the list-of-strings board into its array form.
def state_list2state_array(state_list):
    """Encode a 10x9 board of piece names as a float array of shape [10, 9, 7].

    Each square is filled with the vector ``string2array`` holds for the
    piece name found there.
    """
    encoded = np.zeros([10, 9, 7])
    for row, col in np.ndindex(10, 9):
        encoded[row, col] = string2array[state_list[row][col]]
    return encoded


# Enumerate every move the encoding can express: 2086 entries, which is also
# the length of the move-probability vector predicted by the network.
# First dict:  move_id -> move_action
# Second dict: move_action -> move_id
# Example: move_id 0 --> move_action '0010'
def get_all_legal_moves():
    """Build the two lookup tables between move ids and move strings.

    A move string is four digits: source row, source column, target row,
    target column. Ids are assigned in this order: for every source square
    (row-major) all targets in the same column, then all targets in the same
    row, then the eight knight jumps (off-board targets and the source square
    itself are skipped); after that the advisor moves and finally the
    elephant moves.

    Returns:
        (move_id2move_action, move_action2move_id)
    """
    # All advisor moves
    advisor_labels = ['0314', '1403', '0514', '1405', '2314', '1423', '2514', '1425',
                      '9384', '8493', '9584', '8495', '7384', '8473', '7584', '8475']
    # All elephant moves
    bishop_labels = ['2002', '0220', '2042', '4220', '0224', '2402', '4224', '2442',
                     '2406', '0624', '2446', '4624', '0628', '2806', '4628', '2846',
                     '7052', '5270', '7092', '9270', '5274', '7452', '9274', '7492',
                     '7456', '5674', '7496', '9674', '5678', '7856', '9678', '7896']
    # The horse moves in an "L" shape
    knight_jumps = [(-2, -1), (-1, -2), (-2, 1), (1, -2), (2, -1), (-1, 2), (2, 1), (1, 2)]

    actions = []
    for src_row in range(10):
        for src_col in range(9):
            candidates = [(r, src_col) for r in range(10)]
            candidates += [(src_row, c) for c in range(9)]
            candidates += [(src_row + dr, src_col + dc) for dr, dc in knight_jumps]
            for dst_row, dst_col in candidates:
                if (dst_row, dst_col) == (src_row, src_col):
                    continue
                if 0 <= dst_row < 10 and 0 <= dst_col < 9:
                    actions.append(''.join(map(str, (src_row, src_col, dst_row, dst_col))))
    actions.extend(advisor_labels)
    actions.extend(bishop_labels)

    _move_id2move_action = dict(enumerate(actions))
    _move_action2move_id = {action: idx for idx, action in enumerate(actions)}
    return _move_id2move_action, _move_action2move_id


move_id2move_action, move_action2move_id = get_all_legal_moves()


# Mirror a move left-to-right; used to augment the training data.
def flip_map(string):
    """Return the move with both column digits ``c`` replaced by ``8 - c``.

    The row digits (positions 0 and 2) are kept as they are; the column
    digits (positions 1 and 3) are mirrored.
    """
    mirrored = [
        str(string[pos]) if pos % 2 == 0 else str(8 - int(string[pos]))
        for pos in range(4)
    ]
    return ''.join(mirrored)


# Bounds check
def check_bounds(toY, toX):
    """Return True when row ``toY`` is one of 0..9 and column ``toX`` one of 0..8."""
    return toY in range(10) and toX in range(9)


# A piece may not move onto a square held by its own side.
def check_obstruct(piece, current_player_color):
    """Tell whether the mover may land on a square containing ``piece``.

    Returns True for an empty square ('一一') or a piece of the other colour
    and False for a piece of the mover's own colour. If the square is
    occupied and ``current_player_color`` is neither '红' nor '黑', the
    result is None.
    """
    if piece == '一一':
        return True
    # The square is occupied: only an opposing piece may be captured.
    if current_player_color == '红':
        return '黑' in piece
    if current_player_color == '黑':
        return '红' in piece
    return None


# Legal-move generation for the current position.
# The per-piece helpers below yield candidate target squares in a fixed
# order; get_legal_moves() filters them and converts them to move ids.

# Scan order for sliding pieces and the king: left, right, up, down.
_ORTHOGONAL_STEPS = ((0, -1), (0, 1), (-1, 0), (1, 0))


def _on_home_side(color, row):
    """Return True if ``row`` is on ``color``'s own side of the river."""
    return row >= 5 if color == '黑' else row <= 4


def _in_palace(color, row, col):
    """Return True if (row, col) lies inside ``color``'s palace."""
    in_rows = row >= 7 if color == '黑' else row <= 2
    return in_rows and 3 <= col <= 5


def _can_land(state_list, color, to_y, to_x):
    """Return True if the target is on the board and not held by ``color``."""
    return check_bounds(to_y, to_x) \
        and check_obstruct(state_list[to_y][to_x], current_player_color=color)


def _rook_targets(state_list, y, x, color):
    """Yield rook targets: slide until blocked, capturing an enemy blocker."""
    enemy = '红' if color == '黑' else '黑'
    for dy, dx in _ORTHOGONAL_STEPS:
        to_y, to_x = y + dy, x + dx
        while check_bounds(to_y, to_x):
            target = state_list[to_y][to_x]
            if target != '一一':
                # A rook cannot pass a piece; it may only capture an enemy.
                if enemy in target:
                    yield to_y, to_x
                break
            yield to_y, to_x
            to_y, to_x = to_y + dy, to_x + dx


def _cannon_targets(state_list, y, x, color):
    """Yield cannon targets: quiet slides, or a capture over exactly one screen."""
    enemy = '红' if color == '黑' else '黑'
    for dy, dx in _ORTHOGONAL_STEPS:
        screened = False  # becomes True once the first piece has been passed
        to_y, to_x = y + dy, x + dx
        while check_bounds(to_y, to_x):
            target = state_list[to_y][to_x]
            if not screened:
                if target != '一一':
                    screened = True
                else:
                    yield to_y, to_x
            elif target != '一一':
                # First piece behind the screen: capturable only if enemy.
                if enemy in target:
                    yield to_y, to_x
                break
            to_y, to_x = to_y + dy, to_x + dx


def _horse_targets(state_list, y, x, color):
    """Yield horse targets whose adjacent "leg" square is empty."""
    for i in (-1, 1):
        for j in (-1, 1):
            # (target row, target col, leg row, leg col)
            jumps = ((y + 2 * i, x + j, y + i, x),
                     (y + i, x + 2 * j, y, x + j))
            for to_y, to_x, leg_y, leg_x in jumps:
                if _can_land(state_list, color, to_y, to_x) \
                        and state_list[leg_y][leg_x] == '一一':
                    yield to_y, to_x


def _elephant_targets(state_list, y, x, color):
    """Yield elephant targets: two diagonal steps, own side, empty "eye"."""
    for i in (-2, 2):
        for dx in (i, -i):
            to_y, to_x = y + i, x + dx
            if _can_land(state_list, color, to_y, to_x) \
                    and _on_home_side(color, to_y) \
                    and state_list[y + i // 2][x + dx // 2] == '一一':
                yield to_y, to_x


def _advisor_targets(state_list, y, x, color):
    """Yield advisor targets: one diagonal step inside the palace."""
    for i in (-1, 1):
        for dx in (i, -i):
            to_y, to_x = y + i, x + dx
            if _can_land(state_list, color, to_y, to_x) \
                    and _in_palace(color, to_y, to_x):
                yield to_y, to_x


def _king_targets(state_list, y, x, color):
    """Yield king targets: one orthogonal step inside the palace."""
    for dy, dx in _ORTHOGONAL_STEPS:
        to_y, to_x = y + dy, x + dx
        if _can_land(state_list, color, to_y, to_x) \
                and _in_palace(color, to_y, to_x):
            yield to_y, to_x


def _pawn_targets(state_list, y, x, color):
    """Yield pawn targets: forward, plus sideways once across the river."""
    forward = -1 if color == '黑' else 1  # black moves up, red moves down
    candidates = [(y + forward, x)]
    if not _on_home_side(color, y):
        candidates += [(y, x + 1), (y, x - 1)]
    for to_y, to_x in candidates:
        if _can_land(state_list, color, to_y, to_x):
            yield to_y, to_x


# Second character of a piece name -> generator of its candidate targets.
_TARGET_GENERATORS = {
    '车': _rook_targets,
    '马': _horse_targets,
    '象': _elephant_targets,
    '士': _advisor_targets,
    '帅': _king_targets,
    '炮': _cannon_targets,
    '兵': _pawn_targets,
}


def _append_unless_repeat(moves, state_list, old_state_list, y, x, to_y, to_x):
    """Append the move string unless it would recreate ``old_state_list``."""
    m = str(y) + str(x) + str(to_y) + str(to_x)
    if change_state(state_list, m) != old_state_list:
        moves.append(m)


def get_legal_moves(state_deque, current_player_color):
    """Return the ids of all moves available to ``current_player_color``.

    Moves follow the movement rule of each piece (blocking, palace and river
    limits included). A move is dropped when the board it produces equals
    ``state_deque[-4]``, which stops a side from cycling back to an earlier
    position, e.g. a rook chasing a king back and forth. If the two kings
    face each other on an open file, the side to move may additionally
    capture the opposing king. No test is made for whether a move leaves the
    mover's own king attacked.

    Args:
        state_deque: board history; ``state_deque[-1]`` is the current board
            and ``state_deque[-4]`` the position used for the repetition
            check, so at least 4 entries are required.
        current_player_color: colour of the side to move, '红' or '黑'.

    Returns:
        List of move ids (looked up in ``move_action2move_id``), ordered by
        source square (row-major), with the king-capture move, if any, last.

    Raises:
        IndexError: if ``state_deque`` holds fewer than 4 boards.
    """
    state_list = state_deque[-1]
    old_state_list = state_deque[-4]

    moves = []  # move strings, in generation order
    black_king = None  # (row, col) of '黑帅'
    red_king = None    # (row, col) of '红帅'
    movable = current_player_color in ('红', '黑')

    # state_list is a nested list: len(state_list) == 10, len(state_list[0]) == 9
    for y in range(10):
        for x in range(9):
            piece = state_list[y][x]
            # King positions are recorded for both sides, whoever is to move.
            if piece == '黑帅':
                black_king = (y, x)
            elif piece == '红帅':
                red_king = (y, x)

            # Only the mover's own pieces generate moves.
            if not movable or piece[:1] != current_player_color:
                continue
            targets_of = _TARGET_GENERATORS.get(piece[1:])
            if targets_of is None:
                continue
            for to_y, to_x in targets_of(state_list, y, x, current_player_color):
                _append_unless_repeat(moves, state_list, old_state_list, y, x, to_y, to_x)

    # Kings facing each other: same column with nothing in between.
    face_to_face = False
    if red_king is not None and black_king is not None and red_king[1] == black_king[1]:
        column = red_king[1]
        face_to_face = all(state_list[row][column] == '一一'
                           for row in range(red_king[0] + 1, black_king[0]))

    if face_to_face:
        if current_player_color == '黑':
            _append_unless_repeat(moves, state_list, old_state_list,
                                  black_king[0], black_king[1], red_king[0], red_king[1])
        else:
            _append_unless_repeat(moves, state_list, old_state_list,
                                  red_king[0], red_king[1], black_king[0], black_king[1])

    return [move_action2move_id[move] for move in moves]


# 棋盘逻辑控制
class Board(object):
    """Mutable state of a single Chinese-chess game.

    Keeps the history of positions (``state_deque``), the side to move, the
    most recent move, the number of consecutive non-capturing moves and the
    winner once a '红帅' / '黑帅' piece has been captured.
    """

    def __init__(self):
        # Deep copies keep the module-level initial position untouched.
        self.state_list = copy.deepcopy(state_list_init)
        self.game_start = False
        self.winner = None
        self.state_deque = copy.deepcopy(state_deque_init)

    def init_board(self, start_player=1):
        """Reset the board so that a new game can start.

        :param start_player: id (1 or 2) of the player who moves first; that
            player is assigned the red pieces, since red always moves first.
        :raises ValueError: if ``start_player`` is neither 1 nor 2.
        """
        # Build the id <-> colour lookup tables for this game.
        if start_player == 1:
            self.id2color = {1: '红', 2: '黑'}
            self.color2id = {'红': 1, '黑': 2}
        elif start_player == 2:
            self.id2color = {2: '红', 1: '黑'}
            self.color2id = {'红': 2, '黑': 1}
        else:
            # Fail loudly instead of silently reusing mappings from a previous game.
            raise ValueError('start_player should be either 1 or 2')
        # The side to move is the starting player, i.e. red.
        self.current_player_color = self.id2color[start_player]
        self.current_player_id = self.color2id['红']
        # Restore the initial position and position history.
        self.state_list = copy.deepcopy(state_list_init)
        self.state_deque = copy.deepcopy(state_deque_init)
        # -1 means "no move has been played yet".
        self.last_move = -1
        # Number of consecutive moves without a capture (used for draw detection).
        self.kill_action = 0
        self.game_start = False
        self.action_count = 0   # total number of moves played
        self.winner = None

    @property
    def availables(self):
        """Ids of all legal moves for the side to move in the current position."""
        return get_legal_moves(self.state_deque, self.current_player_color)

    def current_state(self):
        """Encode the current position as a ``[9, 10, 9]`` (C, H, W) array.

        Planes 0-6 hold the piece layout of the latest position as produced by
        ``state_list2state_array`` (per the original notes: 1 for red pieces,
        -1 for black pieces).  Plane 7 marks the last move: -1 on the square
        the piece left, 1 on the square it arrived at, 0 elsewhere.  Plane 8 is
        all ones when an even number of moves has been played, otherwise zeros.
        """
        _current_state = np.zeros([9, 10, 9])
        # Always encode the pieces.  Previously this was skipped before the
        # first move, so the opening position looked like an empty board.
        _current_state[:7] = state_list2state_array(self.state_deque[-1]).transpose([2, 0, 1])  # [7, 10, 9]
        if self.game_start:
            # Only after the first move is there a last move to decode.
            move = move_id2move_action[self.last_move]
            start_position = int(move[0]), int(move[1])
            end_position = int(move[2]), int(move[3])
            _current_state[7][start_position[0]][start_position[1]] = -1
            _current_state[7][end_position[0]][end_position[1]] = 1
        # Flag whose turn it is through the parity of the move counter.
        if self.action_count % 2 == 0:
            _current_state[8][:, :] = 1.0

        return _current_state

    def do_move(self, move):
        """Play ``move`` (a move id) for the side to move and switch sides.

        Updates the capture counter, records a winner when a '红帅' / '黑帅' is
        captured, and appends the resulting position to ``state_deque``.
        """
        self.game_start = True
        self.action_count += 1
        # A move action is the four digits "start_y start_x end_y end_x".
        move_action = move_id2move_action[move]
        start_y, start_x = int(move_action[0]), int(move_action[1])
        end_y, end_x = int(move_action[2]), int(move_action[3])
        state_list = copy.deepcopy(self.state_deque[-1])
        captured = state_list[end_y][end_x]
        if captured != '一一':
            # A capture resets the no-capture counter.
            self.kill_action = 0
            if self.current_player_color == '黑' and captured == '红帅':
                self.winner = self.color2id['黑']
            elif self.current_player_color == '红' and captured == '黑帅':
                self.winner = self.color2id['红']
        else:
            self.kill_action += 1
        # Move the piece on the copied position.
        state_list[end_y][end_x] = state_list[start_y][start_x]
        state_list[start_y][start_x] = '一一'
        # Hand the turn to the other side.
        self.current_player_color = '黑' if self.current_player_color == '红' else '红'
        self.current_player_id = 1 if self.current_player_id == 2 else 2
        self.last_move = move
        self.state_deque.append(state_list)

    def has_a_winner(self):
        """Return ``(True, winner_id)`` if someone has won, else ``(False, -1)``.

        A draw is not a win, so it is reported as ``(False, -1)`` here; use
        :meth:`game_end` to find out whether the game is over.
        """
        if self.winner is not None:
            return True, self.winner
        return False, -1

    def game_end(self):
        """Return ``(ended, winner)``; ``winner`` is -1 for a draw or an unfinished game."""
        win, winner = self.has_a_winner()
        if win:
            return True, winner
        elif self.kill_action >= CONFIG['kill_action']:
            # Too many moves without a capture: the game is drawn.
            return True, -1
        return False, -1

    def get_current_player_color(self):
        """Colour ('红' or '黑') of the side to move."""
        return self.current_player_color

    def get_current_player_id(self):
        """Player id (1 or 2) of the side to move."""
        return self.current_player_id


# 在Board类基础上定义Game类，该类用于启动并控制一整局对局的完整流程，并收集对局过程中的数据，以及进行棋盘的展示
class Game(object):
    """Runs complete games on a board: two-player matches and MCTS self-play."""

    def __init__(self, board):
        self.board = board

    def graphic(self, board, player1_color, player2_color):
        """Print what each player holds followed by the latest position."""
        print('player1 take: ', player1_color)
        print('player2 take: ', player2_color)
        latest_position = board.state_deque[-1]
        print_board(state_list2state_array(latest_position))

    def start_play(self, player1, player2, start_player=1, is_shown=1):
        """Play one game between two agents and return the winner id (-1 for a tie).

        Used for human-vs-AI, human-vs-human and similar matches.  Each agent
        must offer ``set_player_ind(p)`` and ``get_action(board)``.
        """
        if start_player not in (1, 2):
            raise Exception('start_player should be either 1 (player1 first) '
                            'or 2 (player2 first)')
        self.board.init_board(start_player)
        player1.set_player_ind(1)
        player2.set_player_ind(2)
        players = {1: player1, 2: player2}
        if is_shown:
            self.graphic(self.board, player1.player, player2.player)

        while True:
            # The board reports whose turn it is; look up that player's agent.
            mover = players[self.board.get_current_player_id()]
            self.board.do_move(mover.get_action(self.board))
            if is_shown:
                self.graphic(self.board, player1.player, player2.player)
            end, winner = self.board.game_end()
            if not end:
                continue
            if winner == -1:
                print("Game end. Tie")
            else:
                print("Game end. Winner is", players[winner])
            return winner

    def start_self_play(self, player, is_shown=False, temp=1e-3):
        """Let ``player`` play against itself and collect training samples.

        Returns ``(winner, samples)``, where ``samples`` is a zip of
        (state, MCTS move probabilities, outcome) and the outcome is seen from
        the player who was to move in that state: 1 win, -1 loss, 0 tie.
        """
        self.board.init_board()     # default start_player=1
        states, mcts_probs, current_players = [], [], []
        step = 0
        while True:
            step += 1
            # Every 20th move, report how long choosing the move took.
            timed = step % 20 == 0
            if timed:
                start_time = time.time()
            move, move_probs = player.get_action(self.board,
                                                 temp=temp,
                                                 return_prob=1)
            if timed:
                print('走一步要花: ', time.time() - start_time)
            # Record the sample before the move changes the board.
            states.append(self.board.current_state())
            mcts_probs.append(move_probs)
            current_players.append(self.board.current_player_id)
            self.board.do_move(move)
            end, winner = self.board.game_end()
            if not end:
                continue
            # Label every stored state from the viewpoint of its mover.
            winner_z = np.zeros(len(current_players))
            if winner != -1:
                movers = np.array(current_players)
                winner_z[movers == winner] = 1.0
                winner_z[movers != winner] = -1.0
            # Drop the search tree so the next game starts fresh.
            player.reset_player()
            if is_shown:
                if winner == -1:
                    print('Game end. Tie')
                else:
                    print("Game end. Winner is:", winner)

            return winner, zip(states, mcts_probs, winner_z)


if __name__ == '__main__':
    """# 测试array2string
    _array = np.array([0, 0, 0, 0, 0, 0, 0])
    print(array2string(_array))"""

    """# 测试change_state
    new_state = change_state(state_list_init, move='0010')
    for row in range(10):
        print(new_state[row])"""

    """# 测试print_board
    _state_list = copy.deepcopy(state_list_init)
    print_board(state_list2state_array(_state_list))"""

    """# 测试get_legal_moves
    moves = get_legal_moves(state_deque_init, current_player_color='黑')
    move_actions = []
    for item in moves:
        move_actions.append(move_id2move_action[item])
    print(move_actions)"""

    # Smoke test for Game.start_play: two stand-in players that pick a
    # uniformly random legal move.  For manual play, read a move string with
    # input() and map it through move_action2move_id instead.
    class Human1:
        """Stand-in for player 1."""

        def set_player_ind(self, p):
            self.player = p

        def get_action(self, board):
            return random.choice(board.availables)


    class Human2:
        """Stand-in for player 2."""

        def set_player_ind(self, p):
            self.player = p

        def get_action(self, board):
            return random.choice(board.availables)

    human1 = Human1()
    human2 = Human2()
    game = Game(board=Board())
    # Play 20 silent games with player 2 moving first.
    for _ in range(20):
        game.start_play(human1, human2, start_player=2, is_shown=0)

## 四、神经网络预测

* 这一部分，让我们来创建神经网络。网络输入是当前的棋盘状态，然后经过前向传播返回走子概率向量和盘面价值。损失函数包含三部分：1、l2正则化，在优化器中定义；2、mse_loss，价值评估的损失（自我对弈的最终胜负结果和神经网络预测的盘面价值之间的误差）；3、概率向量一致性损失（蒙特卡洛树搜索得到的概率向量和神经网络预测之间的误差）。

* 我们的神经网络将用于蒙特卡洛树搜索中，实现对树搜索过程中遇见的节点进行一个价值的评估，和赋予当前节点所有子节点的先验概率的功能。

* 我们通过神经网络来对搜索的宽度和深度进行一个裁剪，同时我们使用蒙特卡洛树搜索得到的数据来对神经网络进行一个训练。神经网络和蒙特卡洛搜索树是一个相辅相成的关系。

![](https://ai-studio-static-online.cdn.bcebos.com/2a823c3c91b243428508697f5f788a55892a90126bde4a6e85d0dd2c5a7b0a0c)


In [None]:
"""策略价值网络"""


import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F


# 搭建残差块
class ResBlock(nn.Layer):
    """Residual block: conv-BN-ReLU, conv-BN, add the input, then ReLU.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep
    ``num_filters`` channels.
    """

    def __init__(self, num_filters=256):
        super().__init__()
        # Both convolutions share the same configuration.
        conv_kwargs = dict(in_channels=num_filters, out_channels=num_filters,
                           kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2D(**conv_kwargs)
        self.conv1_bn = nn.BatchNorm2D(num_features=num_filters)
        self.conv1_act = nn.ReLU()
        self.conv2 = nn.Conv2D(**conv_kwargs)
        self.conv2_bn = nn.BatchNorm2D(num_features=num_filters)
        self.conv2_act = nn.ReLU()

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum with the skip path."""
        branch = self.conv1_act(self.conv1_bn(self.conv1(x)))
        branch = self.conv2_bn(self.conv2(branch))
        # The skip connection is added before the final activation.
        return self.conv2_act(x + branch)


# 搭建骨干网络，输入：N, 9, 10, 9 --> N, C, H, W
class Net(nn.Layer):
    """Policy/value backbone: conv stem, residual tower, policy and value heads.

    Input is a batch of board encodings laid out as N, C, H, W = N, 9, 10, 9.

    :param num_channels: number of feature channels in the stem and tower.
    :param num_res_blocks: number of residual blocks in the tower.
    """

    def __init__(self, num_channels=256, num_res_blocks=7):
        super().__init__()
        # Stem: lift the 9 input planes to num_channels feature maps.
        self.conv_block = nn.Conv2D(in_channels=9, out_channels=num_channels, kernel_size=3, stride=1, padding=1)
        # The norm width must follow num_channels; it used to be hard-coded to
        # 256, which broke every non-default channel count.
        self.conv_block_bn = nn.BatchNorm2D(num_features=num_channels)
        self.conv_block_act = nn.ReLU()
        # Residual tower for feature extraction.
        self.res_blocks = nn.LayerList([ResBlock(num_filters=num_channels) for _ in range(num_res_blocks)])
        # Policy head: 1x1 conv to 16 planes, then a linear layer over 2086 move ids.
        self.policy_conv = nn.Conv2D(in_channels=num_channels, out_channels=16, kernel_size=1, stride=1)
        self.policy_bn = nn.BatchNorm2D(16)
        self.policy_act = nn.ReLU()
        self.policy_fc = nn.Linear(16 * 9 * 10, 2086)
        # Value head: 1x1 conv to 8 planes, then two linear layers down to one scalar.
        self.value_conv = nn.Conv2D(in_channels=num_channels, out_channels=8, kernel_size=1, stride=1)
        self.value_bn = nn.BatchNorm2D(8)
        self.value_act1 = nn.ReLU()
        self.value_fc1 = nn.Linear(8 * 9 * 10, 256)
        self.value_act2 = nn.ReLU()
        self.value_fc2 = nn.Linear(256, 1)

    def forward(self, x):
        """Return ``(policy, value)`` for a batch of encoded boards.

        ``policy`` holds log-probabilities over the 2086 move ids (log-softmax
        output) and ``value`` is a tanh-squashed scalar per sample.
        """
        # Shared trunk.
        x = self.conv_block(x)
        x = self.conv_block_bn(x)
        x = self.conv_block_act(x)
        for layer in self.res_blocks:
            x = layer(x)
        # Policy head.
        policy = self.policy_conv(x)
        policy = self.policy_bn(policy)
        policy = self.policy_act(policy)
        policy = paddle.reshape(policy, [-1, 16 * 10 * 9])
        policy = self.policy_fc(policy)
        policy = F.log_softmax(policy)
        # Value head.
        value = self.value_conv(x)
        value = self.value_bn(value)
        value = self.value_act1(value)
        value = paddle.reshape(value, [-1, 8 * 10 * 9])
        value = self.value_fc1(value)
        # Use the activation declared for this stage; value_act1 was applied
        # twice before, leaving value_act2 unused.
        value = self.value_act2(value)
        value = self.value_fc2(value)
        value = F.tanh(value)

        return policy, value


# 策略值网络，用来进行模型的训练
class PolicyValueNet:
    """Wraps ``Net`` with an Adam optimizer for inference, training and checkpointing.

    :param model_file: optional path to saved parameters to load at start-up.
    :param use_gpu: stored on the instance; nothing in this class reads it.
    """

    def __init__(self, model_file=None, use_gpu=True):
        self.use_gpu = use_gpu
        self.l2_const = 2e-3    # L2 regularisation strength, passed to Adam as weight_decay
        self.policy_value_net = Net()
        self.optimizer = paddle.optimizer.Adam(learning_rate=0.001,
                                               parameters=self.policy_value_net.parameters(),
                                               weight_decay=self.l2_const)
        if model_file:
            net_params = paddle.load(model_file)
            self.policy_value_net.set_state_dict(net_params)

    def policy_value(self, state_batch):
        """Return ``(action probabilities, state values)`` as numpy arrays for a batch of states."""
        self.policy_value_net.eval()
        state_batch = paddle.to_tensor(state_batch)
        # Inference only: skip building the autograd graph.
        with paddle.no_grad():
            log_act_probs, value = self.policy_value_net(state_batch)
        act_probs = np.exp(log_act_probs.numpy())
        return act_probs, value.numpy()

    def policy_value_fn(self, board):
        """Evaluate one board for the tree search.

        Returns an iterator of ``(move_id, probability)`` pairs restricted to
        the board's legal moves, and the network's value estimate as a numpy
        array.
        """
        self.policy_value_net.eval()
        legal_positions = board.availables
        # Add a batch dimension and make the array contiguous float32.
        current_state = np.ascontiguousarray(board.current_state().reshape(-1, 9, 10, 9)).astype('float32')
        current_state = paddle.to_tensor(current_state)
        # Inference only: skip building the autograd graph.
        with paddle.no_grad():
            log_act_probs, value = self.policy_value_net(current_state)
        act_probs = np.exp(log_act_probs.numpy().flatten())
        # Keep only the legal moves.
        act_probs = zip(legal_positions, act_probs[legal_positions])
        return act_probs, value.numpy()

    def get_policy_param(self):
        """Return the network's state dict."""
        net_params = self.policy_value_net.state_dict()
        return net_params

    def save_model(self, model_file):
        """Save the network parameters to ``model_file``."""
        net_params = self.get_policy_param()
        paddle.save(net_params, model_file)

    def train_step(self, state_batch, mcts_probs, winner_batch, lr=0.002):
        """Run one optimisation step and return ``(loss, entropy)``.

        The loss is the MSE between predicted values and ``winner_batch`` plus
        the cross-entropy between ``mcts_probs`` and the predicted policy.
        ``loss`` is returned as a numpy array, ``entropy`` as a numpy scalar.
        """
        self.policy_value_net.train()
        state_batch = paddle.to_tensor(state_batch)
        mcts_probs = paddle.to_tensor(mcts_probs)
        winner_batch = paddle.to_tensor(winner_batch)
        self.optimizer.clear_gradients()
        self.optimizer.set_lr(lr)
        log_act_probs, value = self.policy_value_net(state_batch)
        value = paddle.reshape(x=value, shape=[-1])
        # Value loss.
        value_loss = F.mse_loss(input=value, label=winner_batch)
        # Policy loss: cross-entropy against the search probabilities.
        policy_loss = -paddle.mean(paddle.sum(mcts_probs * log_act_probs, axis=1))
        # The L2 penalty is applied by the optimizer, so it is not added here.
        loss = value_loss + policy_loss
        loss.backward()
        self.optimizer.minimize(loss)
        # Policy entropy, for monitoring only.
        entropy = -paddle.mean(
            paddle.sum(paddle.exp(log_act_probs) * log_act_probs, axis=1)
        )
        # The reduction is a shape-[1] tensor on older Paddle and a 0-D tensor
        # on newer releases; flattening first makes the scalar extraction work
        # on both (plain [0] raises IndexError on a 0-D array).
        return loss.numpy(), entropy.numpy().reshape(-1)[0]


if __name__ == '__main__':
    # Shape check: push a batch of 8 all-ones boards (N, C, H, W = 8, 9, 10, 9)
    # through a freshly initialised network and print the output shapes.
    net = Net()
    test_data = paddle.ones([8, 9, 10, 9])
    x_act, x_val = net(test_data)
    print(x_act.shape)  # expected: 8, 2086 (one entry per move id)
    print(x_val.shape)  # expected: 8, 1 (one value per board)


## 五、蒙特卡洛树搜索

* 这一节，我们来实现非常关键的蒙特卡洛树搜索。在开始之前，我们先考量一下强化学习价值迭代过程，如下视频所示。这是一个走迷宫游戏，s8位置价值为1，价值越高，颜色越红。随着我们智能体的一次次价值迭代，我们可以发现越靠近s8的方块价值越早上升。因为s7走一步就到s8，这样s7就慢慢学到了具有高价值，然后是s4走一步到s7，s4的价值也就慢慢上升。也就是说，价值迭代是一个反向传播的过程。这和我们接下来蒙特卡洛树节点的更新非常相关，如果没有游戏终盘的胜：1，负：-1，我们也是不可能训练得了模型的。包括走子概率的学习，也要依赖于状态价值的评估。

* 所以，在训练的早期，神经网络的预测值是无效的，只有当搜索树搜索到胜负的时候，才能利用胜负的价值起到一个训练指导作用。在训练的后期，神经网络就会慢慢的从游戏终局往游戏开始学习价值。从而使得不需要搜索到游戏终局，也能用神经网络的预测值作为替代。

<div align=center><img src="https://ai-studio-static-online.cdn.bcebos.com/9db2eba30f56460ba3d47589f290c86f9f42792ff38640018f1159babda46718"></div>

为了更方便的理解博弈树，如下是甲乙二人的一个博弈示例。

<div align=center><img src="https://ai-studio-static-online.cdn.bcebos.com/bde6225c59c0442cae2e79f8271348af252bf399d3a946f8837750c1efd3f394"></div>


* 接下来我们正式进入蒙特卡洛树搜索，为了方便理解，我们以井字棋为例，使用puct算法进行4次模拟，每次模拟都需要走到叶节点才停止（也就是尚未扩展的节点）。其中Q值是随意设的，并不代表真实情况。

如下，我们进行第一次模拟，根据puct算法得出三个动作值都是0，我们从中选一个最大值，假设就是第一个，得到神经网络估计的Q值0.6，然后我们进行一次反向传播更新其所有父节点的各个参数值（Q，访问次数）等。
![](https://ai-studio-static-online.cdn.bcebos.com/58242c15151c41fd8dce8415bf32e97c687f388099e44c198f25edc1866ea637)

然后我们进行第二次模拟，根据上一次模拟更新的参数再次计算动作值，第一个动作值最大为1.85，我们就选第一个节点。注意：因为第一个节点已经不是叶节点了，所以我们要再进行下一个子节点的选择，同样初次访问，所有动作值是0，假设我们选了第一个子节点，返回0.3，然后对父节点更新时要将Q值添加负号。算出第一个节点的平均Q值为0.15。
![](https://ai-studio-static-online.cdn.bcebos.com/f993bfa6b1674c8fa885935b6a0c0ec715a929f333e142cfa8dced7770f2b354)

同样，我们再进行第三次模拟。
![](https://ai-studio-static-online.cdn.bcebos.com/04234035fb8a4e378bd6d3aa9b046932368da7888fc440f98aa6e31e301c5f72)

第四次模拟。
![](https://ai-studio-static-online.cdn.bcebos.com/6d7c9720a520437eb632d096992ea393204b7e0544c44a2eb29e1f27ed3c2048)

如下所示，蒙特卡洛树搜索一直循环（选节点、扩展和评估、反向更新树中各节点）直到指定的模拟次数，最后进行一步落子。请务必要注意到神经网络在此过程中的缩减搜索树宽度和深度的含义。
![](https://ai-studio-static-online.cdn.bcebos.com/5d63fd157329483cbe1fcb5542beb0baff33e8462b1a40afa276f399670ddde9)

* 随着我们蒙特卡洛树搜索模拟的次数增多，对子节点的估值就会越准确。在本项目的搜索树中，展开搜索树2000个节点之后，以节点的选择次数作为走子概率，然后加上狄利克雷噪声进行真实走子。真实走子之后，根节点变换到走子后的盘面。树搜索次数越多，走子的结果越可靠，平均值依概率收敛到期望。

![](https://ai-studio-static-online.cdn.bcebos.com/011d564912f04da0934e72790d7e781775fe697ee98941fba376ec8cbf9f7527)

加上狄利克雷噪声用于探索未知盘面。





In [None]:
"""蒙特卡洛树搜索"""


import numpy as np
import copy
from config import CONFIG


def softmax(x):
    """Return the softmax of ``x``, shifted by its maximum for numerical stability."""
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)


# 定义叶子节点
class TreeNode(object):
    """
    One node of the MCTS tree.

    ``_children`` maps an action to the TreeNode reached by taking it.  Each
    node tracks its mean action value Q, its prior probability P and the
    visit-count-adjusted exploration bonus u.
    """

    def __init__(self, parent, prior_p):
        """
        :param parent: parent node, or None for the root
        :param prior_p: prior probability of selecting this node
        """
        self._parent = parent
        self._children = {}     # action -> TreeNode
        self._n_visits = 0      # number of times this node was visited
        self._Q = 0             # mean value of the action leading to this node
        self._u = 0             # exploration bonus (PUCT)
        self._P = prior_p

    def expand(self, action_priors):
        """Add a child for every (action, prior) pair that is not present yet."""
        for action, prob in action_priors:
            if action in self._children:
                continue
            self._children[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """
        Pick the child with the largest Q + u.
        return: an (action, next_node) pair
        """
        def score(item):
            _, child = item
            return child.get_value(c_puct)

        return max(self._children.items(), key=score)

    def get_value(self, c_puct):
        """
        Compute this node's selection score: Q plus the prior-weighted bonus u.
        c_puct: weight of the exploration term, in (0, inf)
        """
        exploration = c_puct * self._P * np.sqrt(self._parent._n_visits)
        self._u = exploration / (1 + self._n_visits)
        return self._Q + self._u

    def update(self, leaf_value):
        """
        Fold one leaf evaluation into this node.
        leaf_value: evaluation of the subtree from the current player's view
        """
        self._n_visits += 1
        # Incremental running mean over all visits.
        delta = leaf_value - self._Q
        self._Q += 1.0 * delta / self._n_visits

    def update_recursive(self, leaf_value):
        """Like update(), but applied to this node and all of its ancestors."""
        parent = self._parent
        if parent:
            # Ancestors go first; the sign flips because the players alternate.
            parent.update_recursive(-leaf_value)
        self.update(leaf_value)

    def is_leaf(self):
        """True if this node has not been expanded yet."""
        return not self._children

    def is_root(self):
        """True if this node has no parent."""
        return self._parent is None


# 蒙特卡洛搜索树
class MCTS(object):
    """Monte Carlo tree search guided by a policy/value function."""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=2000):
        """
        :param policy_value_fn: callable taking a board and returning
            ((action, probability) pairs, value estimate for that board)
        :param c_puct: exploration weight used when selecting children
        :param n_playout: number of playouts run per move decision
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def _playout(self, state):
        """
        Run one playout from the root to a leaf and back the result up the tree.

        ``state`` is modified in place, so the caller must pass a copy.
        """
        node = self._root
        # Selection: descend greedily until an unexpanded node is reached.
        while not node.is_leaf():
            action, node = node.select(self._c_puct)
            state.do_move(action)

        # Check for the end of the game first: a finished game has an exact
        # result, so the network call is only made for unfinished positions.
        # Terminal leaves are revisited often, which used to cost one wasted
        # evaluation each time.
        end, winner = state.game_end()
        if not end:
            # Expansion + evaluation: priors for the children and a value in
            # [-1, 1] from the viewpoint of the player to move.
            action_probs, leaf_value = self._policy(state)
            node.expand(action_probs)
        elif winner == -1:
            # Tie
            leaf_value = 0.0
        else:
            leaf_value = (
                1.0 if winner == state.get_current_player_id() else -1.0
            )
        # Back-up: the sign is flipped because both players share one tree and
        # the leaf value is expressed for the player to move at the leaf.
        node.update_recursive(-leaf_value)

    def get_move_probs(self, state, temp=1e-3):
        """
        Run all playouts and return the root's actions with their probabilities.

        state: current game state (deep-copied for every playout)
        temp: temperature in (0, 1] controlling how peaked the distribution is
        """
        for _ in range(self._n_playout):
            state_copy = copy.deepcopy(state)
            self._playout(state_copy)

        # Turn the visit counts at the root into move probabilities.
        act_visits = [(act, node._n_visits)
                      for act, node in self._root._children.items()]
        acts, visits = zip(*act_visits)
        act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
        return acts, act_probs

    def update_with_move(self, last_move):
        """
        Advance the root to the child reached by ``last_move``, keeping its
        subtree; any move that is not a child of the root (e.g. -1) resets the
        tree.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

    def __str__(self):
        return 'MCTS'


# 基于MCTS的AI玩家
class MCTSPlayer(object):
    """AI player that chooses its moves with MCTS."""

    def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0):
        self.mcts = MCTS(policy_value_function, c_puct, n_playout)
        self._is_selfplay = is_selfplay
        self.agent = "AI"

    def set_player_ind(self, p):
        """Remember which player id this agent plays as."""
        self.player = p

    def reset_player(self):
        """Throw away the search tree (-1 never matches a child of the root)."""
        self.mcts.update_with_move(-1)

    def __str__(self):
        return 'MCTS {}'.format(self.player)

    def get_action(self, board, temp=1e-3, return_prob=0):
        """Search from ``board`` and return a move id.

        With ``return_prob`` set, returns ``(move, pi)`` where ``pi`` is the
        length-2086 vector of search probabilities (the pi vector of the
        AlphaGo Zero paper).
        """
        acts, probs = self.mcts.get_move_probs(board, temp)
        move_probs = np.zeros(2086)
        move_probs[list(acts)] = probs

        if self._is_selfplay:
            # Self-play: mix in Dirichlet noise to encourage exploration.
            noise = np.random.dirichlet(CONFIG['dirichlet'] * np.ones(len(probs)))
            move = np.random.choice(acts, p=0.75*probs + 0.25*noise)
            # Move the root to the chosen child so the subtree is reused.
            self.mcts.update_with_move(move)
        else:
            # With the default temp=1e-3 this is close to picking the most
            # probable move.
            move = np.random.choice(acts, p=probs)
            # Start the next search from a fresh tree.
            self.mcts.update_with_move(-1)

        return (move, move_probs) if return_prob else move

## 六、影分身训练

* 本项目的同步并行是基于数据库的方式来实现的。每个分身自我对弈结束都把自己的数据添加到同一个data_buffer中去。每次自我对弈结束，都拿到主体更新的最新模型进行下一次的对弈。主体每隔固定时间读取data_buffer进行一次模型训练。
![](https://ai-studio-static-online.cdn.bcebos.com/1c576a089d694fa89617bdfe7109f7da2c11bdeffbab4bc799124d3e98368cba)





In [None]:
"""自我对弈收集数据"""


from collections import deque
import copy
import os
import pickle
import time
from game import Board, Game, move_action2move_id, move_id2move_action, flip_map
from net import PolicyValueNet
from mcts import MCTSPlayer
from config import CONFIG


# 定义整个对弈收集数据流程
class CollectPipeline:
    """Self-play worker: plays games with the latest model and appends the
    resulting samples to the shared, pickled data buffer."""

    def __init__(self, init_model=None):
        """
        :param init_model: accepted for interface compatibility; this class
            does not read it (the model is loaded in ``load_model``).
        """
        # Game logic and board
        self.board = Board()
        self.game = Game(self.board)
        # Self-play parameters
        self.temp = 1   # temperature
        self.n_playout = CONFIG['play_out']  # simulations per move
        self.c_puct = CONFIG['c_puct']       # weight of the exploration term u
        self.buffer_size = CONFIG['buffer_size']    # replay buffer capacity
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.iters = 0

    def load_model(self, model_path=CONFIG['model_path']):
        """Load the trainer's latest model, or fall back to a fresh network,
        and build a self-play MCTS player on top of it."""
        try:
            self.policy_value_net = PolicyValueNet(model_file=model_path)
            print('已加载最新模型')
        except Exception:
            # Best-effort fallback to a fresh network if loading fails.
            # KeyboardInterrupt/SystemExit are no longer swallowed here.
            self.policy_value_net = PolicyValueNet()
            print('已加载初始模型')
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Double the data set by adding the left-right mirror image of every sample.

        :param play_data: iterable of (state, mcts_prob, winner) where state is
            indexed [plane, row, column] (shape [9, 10, 9]).
        :return: list with each original sample followed by its mirrored copy.
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            # Original sample
            extend_data.append((state, mcts_prob, winner))
            # Mirrored sample: reverse the column axis into a NEW array.  The
            # old code flipped "in place" through two transposed views of the
            # same buffer, which overwrote the original state and left both
            # samples as column-symmetric garbage.
            state_flip = state[:, :, ::-1].copy()
            # Move probabilities follow the same mirror mapping per move id.
            mcts_prob_flip = copy.deepcopy(mcts_prob)
            for i in range(len(mcts_prob_flip)):
                mcts_prob_flip[i] = mcts_prob[move_action2move_id[flip_map(move_id2move_action[i])]]
            extend_data.append((state_flip, mcts_prob_flip, winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and merge their samples into the
        buffer file.  Returns the updated iteration counter."""
        for i in range(n_games):
            self.load_model()   # pick up the latest model from the trainer
            winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp, is_shown=False)
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # Augment with mirrored samples
            play_data = self.get_equi_data(play_data)

            if os.path.exists(CONFIG['train_data_buffer_path']):
                # Retry until the buffer file can be read; the read presumably
                # fails while another process is writing it.
                # NOTE: pickle.load must only be used on files this project
                # wrote itself -- unpickling untrusted data can execute code.
                while True:
                    try:
                        with open(CONFIG['train_data_buffer_path'], 'rb') as data_dict:
                            data_file = pickle.load(data_dict)
                            self.data_buffer = data_file['data_buffer']
                            self.iters = data_file['iters']
                            del data_file
                            self.iters += 1
                            self.data_buffer.extend(play_data)
                        print('成功载入数据')
                        break
                    except Exception:
                        # Keep retrying on read errors, but let
                        # KeyboardInterrupt through so run() can stop the worker.
                        time.sleep(30)
            else:
                self.data_buffer.extend(play_data)
                self.iters += 1
            data_dict = {'data_buffer': self.data_buffer, 'iters': self.iters}
            with open(CONFIG['train_data_buffer_path'], 'wb') as data_file:
                pickle.dump(data_dict, data_file)
        return self.iters

    def run(self):
        """Collect self-play data until interrupted with Ctrl-C."""
        try:
            while True:
                iters = self.collect_selfplay_data()
                print('batch i: {}, episode_len: {}'.format(
                    iters, self.episode_len))
        except KeyboardInterrupt:
            print('\n\rquit')

# Start the self-play collection loop; it runs until interrupted with Ctrl-C.
# NOTE(review): CollectPipeline.__init__ does not read init_model -- the model
# file actually used is the load_model() default, CONFIG['model_path'].
collecting_pipeline = CollectPipeline(init_model='current_policy.model')
collecting_pipeline.run()



In [None]:
"""使用收集到数据进行训练"""


import random
import numpy as np
import pickle
import time
from net import PolicyValueNet
from config import CONFIG


# 定义整个训练流程
class TrainPipeline:
    """Train the policy-value network from self-play data stored on disk.

    Each iteration of ``run`` reloads the pickled data buffer found at
    ``CONFIG['train_data_buffer_path']``, performs one policy update on a
    random mini-batch when enough samples are available, and saves the model
    to ``CONFIG['model_path']``.
    """

    def __init__(self, init_model=None):
        """Set the training hyper-parameters and build the network.

        Args:
            init_model: Optional model file handed to ``PolicyValueNet``. When
                it is falsy, or loading it raises, a fresh network is created.
        """
        # Training parameters
        self.learn_rate = 1e-3
        self.lr_multiplier = 1  # adapted after each update from the KL divergence
        self.temp = 1.0
        self.batch_size = CONFIG['batch_size']  # mini-batch size
        self.epochs = CONFIG['epochs']  # max train_step calls per update
        self.kl_targ = CONFIG['kl_targ']  # KL divergence target
        self.check_freq = 100  # iterations between numbered checkpoints
        self.game_batch_num = CONFIG['game_batch_num']  # number of training iterations

        # Both are overwritten by run() from the buffer file; they are defined
        # here so the instance is fully initialised before the first load.
        self.data_buffer = []
        self.iters = 0

        if init_model:
            try:
                self.policy_value_net = PolicyValueNet(model_file=init_model)
                print('已加载上次最终模型')
            except Exception:
                # Fall back to training from scratch. Only Exception is caught
                # so Ctrl+C / SystemExit during loading are not swallowed.
                print('模型路径不存在，从零开始训练')
                self.policy_value_net = PolicyValueNet()
        else:
            print('从零开始训练')
            self.policy_value_net = PolicyValueNet()

    @staticmethod
    def _explained_variance(targets, predictions):
        """Return ``1 - Var(targets - predictions) / Var(targets)``.

        ``predictions`` is flattened before the subtraction. NaN is returned
        when ``targets`` has zero variance, where the ratio is undefined.
        """
        targets = np.asarray(targets)
        residual = targets - np.asarray(predictions).flatten()
        target_var = np.var(targets)
        if target_var == 0:
            return float('nan')
        return 1 - np.var(residual) / target_var

    def policy_updata(self):
        """Run one update of the policy-value network on a random mini-batch.

        Samples ``batch_size`` entries from ``data_buffer``; items 0, 1 and 2 of
        each entry are used as the state, the MCTS probabilities and the winner
        value. Up to ``epochs`` train steps are performed, stopping early when
        the KL divergence between the pre-update and current policy exceeds
        four times ``kl_targ``. ``lr_multiplier`` is then adjusted from the
        final KL value and a summary line is printed.

        ``epochs`` must be at least 1, because the reported values are assigned
        inside the training loop.

        Returns:
            The ``(loss, entropy)`` pair returned by the last ``train_step``.

        Raises:
            ValueError: Raised by ``random.sample`` when the buffer holds fewer
                than ``batch_size`` entries.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)

        state_batch = np.array(
            [data[0] for data in mini_batch]).astype('float32')
        mcts_probs_batch = np.array(
            [data[1] for data in mini_batch]).astype('float32')
        winner_batch = np.array(
            [data[2] for data in mini_batch]).astype('float32')

        # Policy and value before the update; the KL reference point
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)

        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                self.learn_rate * self.lr_multiplier
            )
            # Policy and value after this train step
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            kl = np.mean(np.sum(old_probs * (
                np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                                axis=1))
            if kl > self.kl_targ * 4:  # policy moved too far: stop early
                break

        # Adapt the learning-rate multiplier, kept roughly within [0.1, 10]
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = self._explained_variance(winner_batch, old_v)
        explained_var_new = self._explained_variance(winner_batch, new_v)

        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        entropy,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    def _load_data_buffer(self):
        """Load ``data_buffer`` and ``iters`` from the buffer file, retrying.

        Any ``Exception`` while opening or unpickling the file (for example the
        file is missing or cannot be unpickled yet) triggers a 30 second wait
        and another attempt. ``KeyboardInterrupt`` is deliberately not caught
        here so that Ctrl+C reaches the handler in ``run``.
        """
        while True:
            try:
                # NOTE(review): pickle.load can execute arbitrary code; the
                # buffer file must only ever be written by trusted collectors.
                with open(CONFIG['train_data_buffer_path'], 'rb') as buffer_file:
                    payload = pickle.load(buffer_file)
                    self.data_buffer = payload['data_buffer']
                    self.iters = payload['iters']
                    del payload
            except Exception:
                time.sleep(30)
            else:
                print('已载入数据')
                return

    def run(self):
        """Run the training loop.

        Performs ``game_batch_num`` iterations: wait 30 seconds, reload the
        data buffer, update the network if the buffer holds more than
        ``batch_size`` samples, and save the model. Every ``check_freq``
        iterations an additional numbered checkpoint is written under
        ``models/``. Ctrl+C stops the loop and prints a quit message.
        """
        try:
            for i in range(self.game_batch_num):
                time.sleep(30)  # wait 30 seconds between model updates
                self._load_data_buffer()
                print('step i {}: '.format(self.iters))
                if len(self.data_buffer) > self.batch_size:
                    self.policy_updata()
                # Save the current model (also when no update was performed)
                self.policy_value_net.save_model(CONFIG['model_path'])
                if (i + 1) % self.check_freq == 0:
                    print('current selfplay batch: {}'.format(i + 1))
                    self.policy_value_net.save_model('models/current_policy_batch{}.model'.format(i + 1))
        except KeyboardInterrupt:
            print('\n\rquit')


# Build the training pipeline, trying to resume from 'current_policy.model'
# (TrainPipeline creates a fresh network if that model cannot be loaded),
# then start the training loop.
training_pipeline = TrainPipeline(init_model='current_policy.model')
training_pipeline.run()


## 七、参考与致谢

* 本项目主要参考的资料如下，十分感谢大佬们的分享

1、程世东 https://zhuanlan.zhihu.com/p/34433581 （中国象棋cchesszero）

2、AI在打野 https://aistudio.baidu.com/aistudio/projectdetail/1403398 （用paddle打造的五子棋AI）

3、junxiaosong https://github.com/junxiaosong/AlphaZero_Gomoku （五子棋AlphaZero）

4、书籍：边做边学深度强化学习：PyTorch 程序设计实践

5、书籍：强化学习第二版

后续应该会对该AI继续训练下去，亲手造一个超强的下棋AI简直太酷了！