# In [1]:
import numpy as np
import itertools
import copy
import multiprocessing as mp
from collections import Counter
from math import sqrt, log

# Stone values written to the board; negating a value switches to the other player.
playerA = 1  # holds the stones (original note: "持子"); default first player of Tree
playerB = -1


class Board:
    """
    Square game board: stores the stones, the empty points and the move history.

    A cell holds 0 when it is empty and the player's value (1 or -1) when it
    is occupied. ``available_points_list`` lists the empty ``(x, y)`` points
    and ``move_recode`` the points played through ``add_stone``, in order.
    """

    def __init__(self, size):
        """Create an empty ``size`` x ``size`` board."""
        self.board_size = size
        self.board = np.zeros([size, size])
        self.available_points_list = [(x, y) for x in range(size) for y in range(size)]
        self.move_recode = []

    def check_continus(self, stone_list, player_list):
        """
        Check whether a line of cells contains a long enough run of one player.

        :param stone_list: cell values along one row, column or diagonal.
        :param player_list: the wanted run, i.e. the player's value repeated
            ``continus_number`` times.
        :return: True if ``stone_list`` has a consecutive run of that value
            which is at least ``len(player_list)`` long, else False.
        """
        if not player_list:
            return False
        target = player_list[0]
        # A run is uniform by construction, so a mixed pattern can never match.
        if any(value != target for value in player_list):
            return False

        needed = len(player_list)
        # ">=" so that a run longer than required (e.g. two shorter runs joined
        # by one stone) still counts; an exact-length match would miss it.
        return any(
            key == target and sum(1 for _ in group) >= needed
            for key, group in itertools.groupby(stone_list)
        )

    def check_for_win(self, player, continus_number):
        """
        Check whether ``player`` has ``continus_number`` stones in a row.

        Every row, every column and every diagonal in both directions is
        scanned.

        :param player: stone value to look for.
        :param continus_number: run length needed to win.
        :return: True if such a run exists anywhere on the board.
        """
        rows, cols = self.board.shape
        player_list = [player] * continus_number
        mirrored = np.fliplr(self.board)
        # Offsets -(rows - 1) .. (cols - 1) cover every diagonal of the array,
        # including the ones above the main diagonal near the far corner.
        offsets = range(-(rows - 1), cols)

        lines = itertools.chain(
            self.board,  # rows
            self.board.T,  # columns
            (self.board.diagonal(offset) for offset in offsets),
            (mirrored.diagonal(offset) for offset in offsets),
        )
        return any(self.check_continus(list(line), player_list) for line in lines)

    def upgrade_available_points(self):
        """
        Rebuild ``available_points_list`` from the current board contents.

        Every cell equal to 0 becomes an ``(x, y)`` tuple of plain ints.
        """
        xs, ys = np.where(self.board == 0)
        self.available_points_list = [(int(x), int(y)) for x, y in zip(xs, ys)]

    def add_stone(self, x, y, player):
        """
        Place a stone for ``player`` at ``(x, y)``.

        The point is removed from ``available_points_list`` and appended to
        ``move_recode``.

        :raises ValueError: if ``player`` is not -1, 0 or 1, the point is off
            the board, or the cell is already occupied.
        """
        on_board = 0 <= x < self.board_size and 0 <= y < self.board_size
        if player not in [-1, 0, 1] or not on_board or self.board[x][y] != 0:
            raise ValueError('can not add stone here!!')

        self.board[x][y] = player
        self.available_points_list.remove((x, y))
        self.move_recode.append((x, y))


class Tree:
    """
    Exhaustive fixed-depth search over the moves available on a board.

    For every empty point of ``initial_board`` the tree plays that point for
    ``firstplayer``, enumerates every follow-up move sequence of the requested
    depth, counts who wins each sequence and turns the counts into a UCT score.
    """

    def __init__(self, initial_board, firstplayer=playerA, continus_number=3):
        """
        :param initial_board: board the search starts from; it is deep-copied
            before any stone is placed.
        :param firstplayer: stone value of the player to move first.
        :param continus_number: run length that wins the game.
        """
        self.firstplayer = firstplayer
        self.initial_board = initial_board
        self.continus_number = continus_number

    def winner_detection(self, board, path, player):
        """
        Play ``path`` on ``board`` with alternating players and report the winner.

        ``player`` places the first stone of the path; ``board`` is mutated.

        :return: the value of the first player to complete a winning run, or
            None if the path ends without a winner.
        """
        for point in path:
            board.add_stone(point[0], point[1], player=player)
            if board.check_for_win(player, self.continus_number):
                return player
            player = -player  # the other player moves next
        return None

    def fully_expanded(self, para):
        """
        Score one candidate first move by enumerating all follow-up paths.

        :param para: dict with keys ``'point'`` (the candidate ``(x, y)``),
            ``'deep'`` (number of follow-up moves per path) and
            ``'sum_root_path'`` (total path count used in the exploration term).
        :return: ``(point, winner_counter, uct)`` where ``winner_counter``
            counts the winner of each path (None for paths nobody won) and
            ``uct`` is ``wins / paths + sqrt(2 * ln(sum_root_path) / paths)``.
        """
        first = self.firstplayer
        board = copy.deepcopy(self.initial_board)
        board.add_stone(para['point'][0], para['point'][1], first)
        candidates = board.available_points_list

        # Late in a game fewer empty points than ``deep`` remain; cap the depth
        # so there is always at least one path and the division below is safe.
        depth = min(para['deep'], len(candidates))

        winner_recode = Counter()
        total_paths = 0
        # Iterate the permutations lazily instead of building the whole list.
        for path in itertools.permutations(candidates, depth):
            subboard = copy.deepcopy(board)
            winner_recode[self.winner_detection(subboard, path, player=-first)] += 1
            total_paths += 1

        win_rate = winner_recode[first] / total_paths
        # UCB1 exploration term; max() keeps log() defined for a caller that
        # passes a non-positive path count.
        exploration = sqrt(2 * log(max(para['sum_root_path'], 1)) / total_paths)
        UCT = win_rate + exploration

        return (para['point'], winner_recode, UCT)

    def traverse(self, deep, processes=8):
        """
        Return the empty point with the highest UCT score.

        Each candidate point is evaluated by ``fully_expanded`` in a pool of
        worker processes.

        :param deep: number of follow-up moves explored after each candidate.
        :param processes: size of the worker pool.
        :return: the best ``(x, y)`` point, the first one on ties, or None
            when the board has no empty point.
        """
        points = self.initial_board.available_points_list
        root_width = len(points)

        # Number of move sequences of length ``deep`` from the root, capped at
        # the number of empty points so the product never reaches zero.
        sum_root_path = 1
        for i in range(min(deep, root_width)):
            sum_root_path *= root_width - i

        fully_expanded_para_list = [
            {'point': point, 'sum_root_path': sum_root_path, 'deep': deep}
            for point in points
        ]
        if not fully_expanded_para_list:
            return None

        # The context manager shuts the workers down once map() has returned.
        with mp.Pool(processes=processes) as pool:
            recode_collection = pool.map(self.fully_expanded, fully_expanded_para_list)

        best_point = None
        best_rate = None
        for point, _, uct in recode_collection:
            if best_rate is None or uct > best_rate:
                best_point = point
                best_rate = uct
        return best_point

    # def monte_carlo_tree_search(self,deep):
    #     start_time = time.time()
    #     board = copy.deepcopy(self.initial_board)
    #     # while time.time() - start_time < 50:
    #     leaf = self.traverse(deep)
    #     print(leaf)

        # while resources_left(time, computational power):
        #     leaf = traverse(root)  # leaf = unvisited node
        #     simulation_result = rollout(leaf)
        #     backpropagate(leaf, simulation_result)
        # return best_child(root)


# if __name__ == "__main__":
#     import time

#     board = Board(3)
#     # for point in [(0,0),(1,1), (2,2),(3,3),(4,4)]:
#     #     board.add_stone(point[0], point[1], 1)
#     # result = board.check_for_win(1,5)
#     # print(result)
#     # print(board.available_points_list)
#     # board.check_for_win(1, 3)
#     # board.add_stone(1, 1, -1)

#     tree = Tree(board, continus_number=2)
#     a = time.time()
#     result = tree.monte_carlo_tree_search()
#     b = time.time()
#     print(b - a)


# In [2]:
import time

# Demo: pick the best opening move on an empty 3x3 board and time the search.
# The guard is required because Tree.traverse starts a multiprocessing pool:
# under the "spawn" start method each worker re-imports this module, and
# unguarded top-level code would start the search again inside every worker.
if __name__ == "__main__":
    board = Board(3)
    tree = Tree(board, continus_number=3)
    a = time.time()
    result = tree.traverse(deep=3)
    print(result)
    b = time.time()
    print(b - a)  # elapsed wall-clock seconds, including the print above

# Output:
# (0, 0)
# 0.514704704284668
