In [None]:
class SearchTree:
    """Monte Carlo tree search over board states.

    The tree is rooted at the board passed to the constructor. Each
    iteration of :meth:`mcts` selects a node with UCB1, plays a short random
    rollout from it, scores the resulting board with :meth:`evaluation` and
    propagates that score back up to the root.

    The game rules come from the module-level helpers
    ``__valid_positions__``, ``place_piece`` and ``switch_piece``, which are
    defined outside this class.

    NOTE(review): every tree level is generated with ``self.player`` as the
    side to move (see :meth:`branches`), and rollouts also start with
    ``self.player``; opponent replies are only modelled inside rollouts.
    Confirm whether that is intended.
    """

    class Node:
        """A single tree node holding a board state and its statistics."""

        def __init__(self, state, move=None, parent=None, children=None, w=0, n=0, next=None):
            """Create a node.

            Args:
                state: Board state represented by this node.
                move: Move that produced ``state`` from the parent's state.
                parent: Parent node, or ``None`` for the root.
                children: List of child nodes, or ``None`` if not expanded.
                w: Accumulated evaluation score.
                n: Number of visits.
                next: Child chosen by the most recent selection step.
            """
            self.state = state
            self.move = move
            self.parent = parent
            self.children = children
            # Honour the constructor arguments (they used to be ignored).
            self.w = w
            self.n = n
            self.next = next

    def __init__(self, board, player=2):
        """Build a tree rooted at ``board`` searching moves for ``player``."""
        self.root = self.Node(board)
        self.player = player

    def return_parent(self, node):
        """Return the parent of ``node`` (``None`` for the root)."""
        return node.parent

    def return_children(self, node):
        """Return the children of ``node`` (``None`` if not expanded yet)."""
        return node.children

    def select_rollout_node(self, node):
        """Descend from ``node`` to the node the next rollout starts from.

        Follows :meth:`selection` until an unvisited node is reached. A
        visited node without any children (no legal moves) is returned
        as-is so the search does not fail on terminal positions.
        """
        while node.n != 0:
            best_child = self.selection(node)
            if best_child is None:
                # Terminal position: nothing below to descend into.
                return node
            node = best_child
        return node

    def branches(self, node):
        """Create one fresh child node per legal move of ``self.player``.

        The returned nodes are not attached to ``node.children``; callers
        decide whether to store them.
        """
        moves = __valid_positions__(node.state, self.player)
        # Assumes place_piece returns a new board and leaves node.state
        # untouched -- TODO confirm against its definition.
        return [
            self.Node(place_piece(node.state, move[0], move[1], self.player), move, node, None, 0, 0)
            for move in moves
        ]

    def evaluation(self, state):
        """Return the piece-count balance of ``state`` scaled to [-100, 100].

        The score is ``100 * (#2 - #1) / (#1 + #2)``; a board with no pieces
        scores ``0.0`` instead of dividing by zero.

        NOTE(review): the score is always from piece 2's point of view,
        independent of ``self.player`` -- confirm this is intended when the
        tree is built with ``player=1``.
        """
        ones = np.sum(state == 1)
        twos = np.sum(state == 2)
        total = ones + twos
        if total == 0:
            return 0.0
        return 100 * (twos - ones) / total

    def ucb1(self, node, c=2):
        """Return the UCB1 score of ``node`` relative to its parent.

        Unvisited nodes score ``inf`` so they are tried first.
        """
        if node.n == 0:
            return np.inf
        exploitation = node.w / node.n
        # Clamp to 1 so an inconsistent parent count cannot produce log(0).
        parent_visits = max(node.parent.n, 1)
        exploration = c * (np.log(parent_visits) / node.n) ** 0.5
        return exploitation + exploration

    def selection(self, node):
        """Return the child of ``node`` with the highest UCB1 score.

        Children are generated once and stored on ``node`` so their
        statistics persist between iterations. The chosen child is also
        recorded in ``node.next``. Returns ``None`` when ``node`` has no
        children.
        """
        if node.children is None:
            node.children = self.branches(node)

        best_child = None
        best_score = -np.inf
        for child in node.children:
            score = self.ucb1(child)
            if score > best_score:
                best_score = score
                best_child = child
        node.next = best_child
        return best_child

    def simulation(self, node, depth=5):
        """Play up to ``depth`` random moves from ``node`` and back up the result.

        The rollout alternates pieces via ``switch_piece`` starting with
        ``self.player`` and stops early when the side to move has no legal
        move. The final board is scored with :meth:`evaluation` and the
        score is passed to :meth:`backpropagation`.
        """
        current_state = copy.deepcopy(node.state)
        current_piece = self.player
        for _ in range(depth):
            moves = __valid_positions__(current_state, current_piece)
            if len(moves) == 0:
                break
            move = random.choice(moves)
            current_state = place_piece(current_state, move[0], move[1], current_piece)
            current_piece = switch_piece(current_piece)
        value = self.evaluation(current_state)
        self.backpropagation(node, value)

    def backpropagation(self, node, val):
        """Add ``val`` and one visit to ``node`` and every ancestor.

        The root is included: its visit count drives both the descent in
        :meth:`select_rollout_node` and the exploration term of its children.
        """
        while node is not None:
            node.w += val
            node.n += 1
            node = node.parent

    def expansion(self, node):
        """Attach children to ``node`` right after its first visit.

        Existing children are kept so their statistics are not discarded.
        """
        if node.n == 1 and node.children is None:
            node.children = self.branches(node)

    def mcts(self, iterations=1000):
        """Run ``iterations`` rounds of select / simulate / expand from the root."""
        for _ in range(iterations):
            selected_node = self.select_rollout_node(self.root)
            self.simulation(selected_node)
            self.expansion(selected_node)

    def _final_rank(self, child):
        """Sort key for the final move: visit count, then mean score."""
        mean = child.w / child.n if child.n else -np.inf
        return (child.n, mean)

    def return_move(self):
        """Run the search and return the move of the most visited root child.

        The final choice uses visit counts rather than UCB1 so the
        exploration bonus does not influence the move actually played.
        Returns ``None`` (without printing) when the root has no legal move.
        """
        self.mcts()
        root = self.root
        if root.children is None:
            root.children = self.branches(root)
        if len(root.children) == 0:
            return None
        selected_node = max(root.children, key=self._final_rank)
        root.next = selected_node
        move = selected_node.move
        print("Move played : ", (move[0], move[1]))
        return move