In [4]:
import numpy as np
from collections import defaultdict

# Gomoku game design (tic-tac-toe is the 3x3, three-in-a-row special case)

0


In [2]:
class gomoku_move():
    """A single stone placement: board coordinates plus the player's mark.

    Attributes:
        x_cor: first board index (row) of the placement.
        y_cor: second board index (column) of the placement.
        value: the mark being placed, stored as given by the caller.
    """

    def __init__(self, x_cor, y_cor, value) -> None:
        # Store all three fields in one go; no validation is performed here.
        self.x_cor, self.y_cor, self.value = x_cor, y_cor, value

class gomoku_game():
    """State of a k-in-a-row game (gomoku, tic-tac-toe, ...) on a square board.

    The board is a 2-D numpy array whose cells are ``1`` (x), ``-1`` (o) or
    ``0`` (empty).  ``move`` never modifies the current state; it returns a
    new state object built from a copy of the board.
    """
    x = 1
    o = -1

    def __init__(self, state, next_person, win=5) -> None:
        """
        Input:
            state: the game board, a 2-D array assumed to be square
                   (e.g. 15 * 15); only ``state.shape[0]`` is read.
            next_person: the mark of the player to move next (1 or -1).
            win: how many equal marks in a line are needed to win.
        """
        self.board = state
        self.board_size = state.shape[0]
        self.win = win  # how many in a line lead to a win
        self.next_person = next_person

    def check_win(self):
        """
        Return the result of the position.

        1 means x wins
        -1 means o wins
        0 means draw (board full, nobody has a line)
        None means no result yet.

        Every length-``win`` window is summed; because cells are in
        {-1, 0, 1} a sum of ``win`` (``-win``) means the window is entirely
        x (o).
        """
        n_windows = self.board_size - self.win + 1  # <= 0 when board < win
        for i in range(n_windows):
            # Horizontal: columns i..i+win-1 of every row.
            row_sum = np.sum(self.board[:, i:i + self.win], axis=1)
            # Vertical: rows i..i+win-1 of every column.
            colum_sum = np.sum(self.board[i:i + self.win], axis=0)
            if row_sum.max() == self.win or colum_sum.max() == self.win:
                return self.x

            if row_sum.min() == -self.win or colum_sum.min() == -self.win:
                return self.o

        for i in range(n_windows):
            for j in range(n_windows):
                sub_board = self.board[i:i + self.win, j:j + self.win]
                dig_sum = np.trace(sub_board)            # main diagonal
                dig_sum_inv = np.trace(sub_board[::-1])  # anti-diagonal
                if dig_sum == self.win or dig_sum_inv == self.win:
                    return self.x
                if dig_sum == -self.win or dig_sum_inv == -self.win:
                    return self.o

        # draw: no empty cell left and nobody completed a line
        if np.all(self.board != 0):
            return 0

        return None

    def is_game_over(self):
        """
        Check whether the game is over.

        Returns True when ``check_win`` reports a win or a draw, False when
        the game can still continue.
        """
        return self.check_win() is not None

    def check_move_legal(self, move: "gomoku_move"):
        """Return whether ``move`` may be played in this state.

        A move is legal when it is made by the player to move, lies inside
        the board and targets an empty cell.
        """
        if move.value != self.next_person:
            return False

        x_in_range = (0 <= move.x_cor < self.board_size)
        y_in_range = (0 <= move.y_cor < self.board_size)
        if not x_in_range or not y_in_range:
            return False

        return self.board[move.x_cor, move.y_cor] == 0

    def move(self, move: "gomoku_move"):
        """Return the state reached by playing ``move``.

        Raises:
            ValueError: if the move is not legal in this state.
        """
        if not self.check_move_legal(move):
            raise ValueError("It is ilegal move")
        new_board = self.board.copy()
        new_board[move.x_cor, move.y_cor] = move.value
        next_person = - self.next_person

        # Carry the win length over; otherwise the successor silently falls
        # back to the default of 5 and e.g. 3x3 games can never be won.
        return type(self)(new_board, next_person, self.win)

    def get_legal_action(self):
        """Return one ``gomoku_move`` for the player to move per empty cell."""
        rows, cols = np.where(self.board == 0)
        return [gomoku_move(r, c, self.next_person)
                for r, c in zip(rows, cols)]

In [5]:
# Smoke test: build an empty 15x15 board with x (1) to move.
state = np.zeros((15,15))

ttc = gomoku_game(state=state, next_person=1)
# The return value is discarded here; the call only exercises the method.
ttc.is_game_over()
# check_win() returns None while the position has no result yet.
print(ttc.check_win())


None


# Monte Carlo Tree node Design

In [5]:
class MC_TreeNode():
    """One node of the Monte Carlo search tree, wrapping a game state.

    The wrapped ``state`` must provide ``next_person``, ``is_game_over()``,
    ``check_win()``, ``get_legal_action()`` and ``move(action)``.
    """

    def __init__(self, state: "gomoku_game", parent=None) -> None:
        self.state = state
        self.parent: "MC_TreeNode" = parent
        self.children = []

        self.number_of_visit = 0
        # Filled lazily by find_untried_action_of_node().
        self.untried_action = None
        # Maps a rollout result (1, -1 or 0) to how often it was seen.
        self.result = defaultdict(int)

    def is_terminal_node(self):
        """Return whether the wrapped game state is finished."""
        return self.state.is_game_over()

    def is_fully_expand(self):
        """Return whether every legal action already has a child node."""
        return len(self.find_untried_action_of_node()) == 0

    def find_untried_action_of_node(self) -> list:
        """Return the list of actions that have not been expanded yet.

        The list is created from the state's legal actions on first use and
        is consumed in place by ``expand``.
        """
        if self.untried_action is None:
            self.untried_action = self.state.get_legal_action()
        return self.untried_action

    def q(self):
        """Return wins minus losses from the parent's mover's point of view.

        The player to move in the parent is the one who chose this node, so
        results are scored for that player.  Requires a parent, i.e. must
        not be called on the root.
        """
        win = self.result[self.parent.state.next_person]
        loss = self.result[-self.parent.state.next_person]
        return win - loss

    def n(self):
        """Return how many simulations passed through this node."""
        return self.number_of_visit

    def expand(self):
        """Create, attach and return the child for one untried action."""
        action = self.find_untried_action_of_node().pop()
        next_state = self.state.move(action)
        child_node = MC_TreeNode(next_state, parent=self)
        self.children.append(child_node)
        return child_node

    def rollout_policy(self, possible_moves):
        """Pick one of ``possible_moves`` uniformly at random."""
        return possible_moves[np.random.randint(len(possible_moves))]

    def rollout(self):
        """Play random moves from this state to the end; return the result."""
        current_state = self.state
        while not current_state.is_game_over():
            possible_action = current_state.get_legal_action()
            action = self.rollout_policy(possible_action)
            current_state = current_state.move(action)
        return current_state.check_win()

    def backpropagate(self, result):
        """Record ``result`` on this node and on every ancestor."""
        self.number_of_visit += 1
        self.result[result] += 1
        if self.parent is not None:
            self.parent.backpropagate(result)

    def best_children(self, c_para=1.5):
        """Return the child with the highest UCT score.

        score = q / n + c_para * sqrt(ln(parent visits) / n)

        A child that has never been visited gets an infinite score so it is
        tried before any visited sibling (and no division by zero occurs).

        Raises:
            ValueError: if the node has no children.
        """
        if not self.children:
            raise ValueError("best_children() called on a node without children")

        weight = []
        for each_children in self.children:
            visits = each_children.n()
            if visits == 0:
                weight.append(np.inf)
                continue
            exploit = each_children.q() / visits
            explore = c_para * np.sqrt(np.log(self.n()) / visits)
            weight.append(exploit + explore)
        return self.children[int(np.argmax(weight))]

# Monte Carlo Tree search

In [6]:
class MC_TreeSearch():
    """Monte Carlo tree search driver rooted at a given tree node."""

    def __init__(self, node: "MC_TreeNode") -> None:
        self.root = node

    def best_move(self, simulation_number = 1000):
        """Run ``simulation_number`` simulations and return the best child.

        Each simulation selects/expands a node, plays a random rollout from
        it and backpropagates the result.  The final choice is made with the
        exploration constant set to 0 so that only the observed win rate
        decides; keeping the exploration bonus would favour children that
        were visited less often.
        """
        for _ in range(0, simulation_number):
            v = self.policy_in_searchtree()  # v is the node to simulate from
            reward = v.rollout()
            v.backpropagate(reward)
        return self.root.best_children(c_para=0.0)

    def policy_in_searchtree(self) -> "MC_TreeNode":
        """
        Select a node to run a rollout from.

        Walks down from the root following the best child until it reaches a
        node that still has untried actions (which is then expanded and the
        new child returned) or a terminal node (returned as is).
        """
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expand():
                return current_node.expand()
            current_node = current_node.best_children()
        return current_node
        

In [7]:
# Self-play on a 3x3 board with three in a row to win (tic-tac-toe).
# The game class defined in this notebook is gomoku_game; win=3 is passed
# explicitly every time a state is built so each search root uses the
# tic-tac-toe rule.
state = np.zeros((3, 3))
board_state = gomoku_game(state=state, next_person=1, win=3)

next_person = 1
# Test the state we are about to search, so a finished game is never
# handed to the tree search.
while not board_state.is_game_over():
    root = MC_TreeNode(state=board_state, parent=None)
    mcts = MC_TreeSearch(root)
    next_move = mcts.best_move(10000)
    print(next_move.state.board)
    # Hand the resulting board to the other player.
    next_person = -1 * next_person
    board_state = gomoku_game(next_move.state.board, next_person=next_person, win=3)


[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 1.]]
[[ 0.  0.  0.]
 [ 0. -1.  0.]
 [ 0.  0.  1.]]
[[ 0.  0.  0.]
 [ 0. -1.  0.]
 [ 0.  1.  1.]]
[[ 0.  0.  0.]
 [ 0. -1.  0.]
 [-1.  1.  1.]]
[[ 0.  0.  1.]
 [ 0. -1.  0.]
 [-1.  1.  1.]]
[[ 0.  0.  1.]
 [ 0. -1. -1.]
 [-1.  1.  1.]]
[[ 0.  0.  1.]
 [ 1. -1. -1.]
 [-1.  1.  1.]]
[[ 0. -1.  1.]
 [ 1. -1. -1.]
 [-1.  1.  1.]]
[[ 1. -1.  1.]
 [ 1. -1. -1.]
 [-1.  1.  1.]]


In [8]:
# Single-position check: x (1) to move can win at once by completing the
# top row.  Uses gomoku_game (the game class defined in this notebook) with
# a three-in-a-row win condition.
state = np.array([[1,1,0],
                  [-1,-1,0],
                  [1,-1,-1]])
board_state = gomoku_game(state=state, next_person=1, win=3)

root = MC_TreeNode(state=board_state,parent=None)
mcts = MC_TreeSearch(root)
next_move = mcts.best_move(10000)
print(next_move.state.board)

[[ 1  1  1]
 [-1 -1  0]
 [ 1 -1 -1]]
