# Application of MCTS to Simple Games
## What is Monte Carlo Tree Search?
Monte Carlo tree search, or MCTS for short, is a heuristic search algorithm that uses random sampling to decide which action to take. It is widely used for decision-making and game playing. MCTS is a best-first search algorithm. It builds a tree incrementally, where each node is a state visited during the search, and it uses a selection policy to decide which part of the tree to explore next. The policy scores each node with a value estimated from the results of random simulations. It favors nodes that look promising, but it still tries nodes that have rarely been visited. The algorithm is iterative and runs until a computational budget, such as a predefined number of iterations or a time limit, is exhausted.

Let's look at the algorithm applied to a simple game of tic-tac-toe. Each node of the tree is a particular state of the game, i.e. a board configuration (see the figure below). The root node is the empty board. The goal of the algorithm is to estimate the value of each node from the results of random simulations, which are full games played from that node. Each iteration consists of four steps:
- Selection: starting from the root, repeatedly move to a child chosen by the selection policy. Stop on reaching a node that still has unvisited children, or a terminal node (a state from which no further move is possible).
- Expansion: add one of the unvisited children to the tree.
- Simulation: from the new node, play random moves until the game reaches a terminal node (a win, a loss or a draw).
- Backpropagation: send the result of the simulation back along the path from the new node to the root, and update the statistics (and therefore the value) of every node on that path.

![MCTS](https://int8.io/wp-content/uploads/2018/02/labeled-tic-tac-toe-game-tree-for-monte-carlo-tree-search.png)
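The four steps can be sketched in plain Python. This is a minimal illustration and not the `mctspy` implementation used below. To keep it short, it plays a toy game of Nim instead of tic-tac-toe: players alternately remove 1 to 3 stones, and whoever takes the last stone wins. The names `NimState`, `Node` and `mcts` are hypothetical. The sketch also makes a few assumed conventions. A simulation scores 1 for a win and 0 for a loss. That reward is credited to the player who made the move into each node. The final decision is the most visited child of the root, which is one common choice among several.

```python
import math
import random


class NimState:
    """Hypothetical toy game: remove 1-3 stones per turn; taking the last stone wins."""

    def __init__(self, stones, player=1):
        self.stones = stones
        self.player = player  # +1 or -1: the player about to move

    def legal_moves(self):
        return [k for k in (1, 2, 3) if k <= self.stones]

    def play(self, k):
        return NimState(self.stones - k, -self.player)

    def is_terminal(self):
        return self.stones == 0

    def winner(self):
        # only meaningful at a terminal state: whoever moved last took the last stone
        return -self.player


class Node:
    def __init__(self, state, parent=None, move=None):
        self.state = state
        self.parent = parent
        self.move = move                    # move that led from the parent to this node
        self.children = []
        self.untried = state.legal_moves()  # moves not expanded yet
        self.N = 0                          # visit count N(s)
        self.Q = 0.0                        # total reward Q(s) for the player who moved into s

    def uct_child(self, c=math.sqrt(2)):
        # every child has N >= 1 here, because it was backpropagated right after expansion
        return max(self.children,
                   key=lambda ch: ch.Q / ch.N + c * math.sqrt(math.log(self.N) / ch.N))


def mcts(root_state, iterations, rng):
    root = Node(root_state)
    for _ in range(iterations):
        node = root
        # 1. Selection: descend through fully expanded, non-terminal nodes
        while not node.untried and node.children:
            node = node.uct_child()
        # 2. Expansion: add one unvisited child (skipped at terminal nodes)
        if node.untried:
            move = node.untried.pop(rng.randrange(len(node.untried)))
            child = Node(node.state.play(move), parent=node, move=move)
            node.children.append(child)
            node = child
        # 3. Simulation: uniform random rollout; these states are not stored in the tree
        state = node.state
        while not state.is_terminal():
            state = state.play(rng.choice(state.legal_moves()))
        winner = state.winner()
        # 4. Backpropagation: update N(s) and Q(s) from the new node up to the root
        while node is not None:
            node.N += 1
            if winner == -node.state.player:  # the player who moved into this node won
                node.Q += 1.0
            node = node.parent
    # final decision: the most visited child of the root
    return max(root.children, key=lambda ch: ch.N).move


print(mcts(NimState(5), iterations=3000, rng=random.Random(0)))
```

A pile that is a multiple of 4 is lost for the player to move. From 5 stones, the winning move is therefore to take 1, and with a few thousand iterations the search should settle on that move.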

A node is said to be fully expanded when all of its children have been visited. During selection, the algorithm descends through fully expanded nodes. As soon as it reaches a node that is not fully expanded, it expands one of that node's unvisited children instead. Simulations that start from a newly expanded node are called rollouts. Only the starting node of a rollout is added to the tree and marked as visited; the intermediate states played during the rollout are not stored. The policy used to pick moves during a rollout is called the rollout policy, and it is often simply uniform random. The criterion used to pick a child during selection is called the tree policy.

## Backpropagation
The result of the simulation is backpropagated to all the nodes on the path from the node where the rollout started back up to the root node. Every node on that path updates two statistics:
- the number of times the node $s$ was visited, denoted by $N(s)$
- the total simulation reward of the node, denoted by $Q(s)$, which is the sum of the rewards of all simulations that passed through node $s$.

The tree policy has to balance two goals:
- exploration: trying nodes that have rarely been visited
- exploitation: favoring nodes that have performed well so far

The exploitation term is the average reward $Q(s)/N(s)$. The exploration term grows with the visit count $N(p)$ of the parent and shrinks as the node's own visit count $N(s)$ grows. The most common tree policy scores each node with the Upper Confidence Bound applied to Trees ($UCT$), calculated as follows:
\begin{equation}
UCT(s) = \frac{Q(s)}{N(s)} + c \sqrt{\frac{\log N(p)}{N(s)}}  
\end{equation}
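where $p$ is the parent node of $s$ and $c$ is a constant that controls the balance between exploration and exploitation. A common default is $c = \sqrt{2}$, which is the value suggested by the theory when rewards lie in $[0, 1]$. In practice, $c$ is often tuned for the game at hand.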
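Here is a small numeric check of the formula. `uct` is a hypothetical helper. It gives unvisited children an infinite score, a common convention assumed here that forces every child to be tried once. The parent is visited 50 times. A child with zero reward but only 2 visits then outranks better-explored children with higher average rewards:

```python
import math


def uct(q, n, parent_n, c=math.sqrt(2)):
    """UCT(s) = Q(s)/N(s) + c * sqrt(log N(p) / N(s)), with log the natural log.

    Unvisited children (n == 0) get an infinite score, so each one is tried once.
    """
    if n == 0:
        return math.inf
    return q / n + c * math.sqrt(math.log(parent_n) / n)


# Three children of a parent visited 50 times: (Q, N) for each child.
children = {"a": (6, 10), "b": (30, 38), "c": (0, 2)}
scores = {name: uct(q, n, parent_n=50) for name, (q, n) in children.items()}
for name, score in scores.items():
    print(name, round(score, 3))  # a 1.485, b 1.243, c 1.978
print("selected:", max(scores, key=scores.get))  # selected: c
```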

AlphaGo used a variant of the UCT formula above that also includes a prior probability of choosing each move. This prior probability was computed by a policy neural network.
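The idea of a prior can be illustrated with the selection score described for AlphaGo Zero, AlphaGo's successor. AlphaGo's own formula differs in its details, so treat this as a sketch under that assumption. `puct_score`, `prior` and `c_puct` are illustrative names. The priors here are plain numbers that stand in for the output of a policy network:

```python
import math


def puct_score(total_value, visits, prior, parent_visits, c_puct=1.0):
    """Selection score Q + U, in the form described for AlphaGo Zero (assumed):

    Q = total_value / visits   (0 for an unvisited move)
    U = c_puct * prior * sqrt(parent_visits) / (1 + visits)

    `prior` stands in for the policy network's probability of this move,
    and `parent_visits` for the summed visit counts of all sibling moves.
    """
    q = total_value / visits if visits > 0 else 0.0
    u = c_puct * prior * math.sqrt(parent_visits) / (1 + visits)
    return q + u


# Two unvisited moves: the prior alone decides which one is explored first.
print(puct_score(0, 0, prior=0.7, parent_visits=1))  # 0.7
print(puct_score(0, 0, prior=0.1, parent_visits=1))  # 0.1
```

In this score, an unvisited move does not automatically get the top score. A move with a very small prior may be tried late, which is how the network steers the search toward promising moves.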


These examples are reproduced from [this repository](https://github.com/int8/monte-carlo-tree-search)
## Tic Tac Toe

In [2]:
!pip install mctspy

Collecting mctspy
  Downloading mctspy-0.1.1-py3-none-any.whl.metadata (1.8 kB)
Downloading mctspy-0.1.1-py3-none-any.whl (6.3 kB)
Installing collected packages: mctspy
Successfully installed mctspy-0.1.1


In [3]:
import numpy as np
from mctspy.tree.nodes import TwoPlayersGameMonteCarloTreeSearchNode
from mctspy.tree.search import MonteCarloTreeSearch
from mctspy.games.examples.tictactoe import TicTacToeGameState

In [4]:
state = np.zeros((3,3))
initial_board_state = TicTacToeGameState(state = state, next_to_move=1)

root = TwoPlayersGameMonteCarloTreeSearchNode(state = initial_board_state)
mcts = MonteCarloTreeSearch(root)
best_node = mcts.best_action(10000)


In [5]:
best_node.state.board

array([[0., 0., 0.],
       [0., 1., 0.],
       [0., 0., 0.]])

In [6]:
best_node.state.next_to_move

-1

## Connect Four

In [7]:
from mctspy.games.examples.tictactoe import TicTacToeMove
from mctspy.games.common import TwoPlayersAbstractGameState


In [8]:
class TicTacToeGameState(TwoPlayersAbstractGameState):

    x = 1
    o = -1

    # `win` is the number of pieces in a row needed to win (defaults to the board size)
    def __init__(self, state, next_to_move=1, win=None):
        if len(state.shape) != 2 or state.shape[0] != state.shape[1]:
            raise ValueError("Only 2D square boards allowed")
        self.board = state
        self.board_size = state.shape[0]
        if win is None:
            win = self.board_size
        self.win = win
        self.next_to_move = next_to_move

    @property
    def game_result(self):
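        # check straight lines: sum every window of `win` consecutive cells along
        # each axis; a sum of +win (or -win) means one player fills the whole window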
        for i in range(self.board_size - self.win + 1):
            rowsum = np.sum(self.board[i:i+self.win], 0)
            colsum = np.sum(self.board[:,i:i+self.win], 1)
            if rowsum.max() == self.win or colsum.max() == self.win:
                return self.x
            if rowsum.min() == -self.win or colsum.min() == -self.win:
                return self.o
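        # check both diagonals of every win x win sub-board
        for i in range(self.board_size - self.win + 1):
            for j in range(self.board_size - self.win + 1):
                sub = self.board[i:i+self.win, j:j+self.win]
                diag_sum_tl = sub.trace()        # top-left to bottom-right diagonal
                diag_sum_tr = sub[::-1].trace()  # anti-diagonal (rows flipped)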
                if diag_sum_tl == self.win or diag_sum_tr == self.win:
                    return self.x
                if diag_sum_tl == -self.win or diag_sum_tr == -self.win:
                    return self.o

        # draw
        if np.all(self.board != 0):
            return 0.

        # if not over - no result
        return None

    def is_game_over(self):
        return self.game_result is not None

    def is_move_legal(self, move):
        # check if correct player moves
        if move.value != self.next_to_move:
            return False

        # check if inside the board on x-axis
        x_in_range = (0 <= move.x_coordinate < self.board_size)
        if not x_in_range:
            return False

        # check if inside the board on y-axis
        y_in_range = (0 <= move.y_coordinate < self.board_size)
        if not y_in_range:
            return False

        # finally check if board field not occupied yet
        return self.board[move.x_coordinate, move.y_coordinate] == 0

    def move(self, move):
        if not self.is_move_legal(move):
            raise ValueError(
                "move {0} on board {1} is not legal".format(move, self.board)
            )
        new_board = np.copy(self.board)
        new_board[move.x_coordinate, move.y_coordinate] = move.value
        if self.next_to_move == self.x:
            next_to_move = self.o
        else:
            next_to_move = self.x
        return type(self)(new_board, next_to_move, self.win)

    def get_legal_actions(self):
        indices = np.where(self.board == 0)
        return [
            TicTacToeMove(coords[0], coords[1], self.next_to_move)
            for coords in list(zip(indices[0], indices[1]))
        ]
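The window-sum trick in `game_result` works because cells hold +1, -1 or 0. A window of `win` cells can only sum to `+win` (or `-win`) if one player fills all of it. A plain-Python version of the same check can help cross-check the NumPy code on small boards. It scans the four line directions from every cell. `k_in_a_row_winner` is a hypothetical helper and not part of `mctspy`. Unlike `game_result`, it returns 0 both for a draw and for an unfinished game:

```python
def k_in_a_row_winner(board, k):
    """Return +1 or -1 if that player has k consecutive pieces
    horizontally, vertically or diagonally, else 0.

    `board` is a square list of lists with entries 1, -1 or 0.
    """
    n = len(board)
    directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
    for i in range(n):
        for j in range(n):
            for di, dj in directions:
                # skip windows whose last cell would fall off the board
                end_i, end_j = i + (k - 1) * di, j + (k - 1) * dj
                if not (0 <= end_i < n and 0 <= end_j < n):
                    continue
                s = sum(board[i + t * di][j + t * dj] for t in range(k))
                if s == k:
                    return 1
                if s == -k:
                    return -1
    return 0


board = [[0] * 7 for _ in range(7)]
for t in range(4):
    board[t][t] = 1  # four X pieces on a diagonal
print(k_in_a_row_winner(board, 4))  # 1
```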

In [9]:
class Connect4GameState(TicTacToeGameState):

    def is_move_legal(self, move):
        # check if correct player moves
        if move.value != self.next_to_move:
            return False

        # check if inside the board on x-axis
        x_in_range = (0 <= move.x_coordinate < self.board_size)
        if not x_in_range:
            return False

        # check if inside the board on y-axis
        y_in_range = (0 <= move.y_coordinate < self.board_size)
        if not y_in_range:
            return False

        # finally check that the cell is empty and either on the bottom row
        # or directly on top of another piece (pieces fall under gravity)
        return self.board[move.x_coordinate, move.y_coordinate] == 0 and (
            move.y_coordinate == 0
            or self.board[move.x_coordinate, move.y_coordinate - 1] != 0
        )

    def get_legal_actions(self):
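        # columns (first board axis) that are not full yet
        indices = np.where(np.count_nonzero(self.board, axis=1) != self.board_size)[0]
        # a new piece lands on top of the pieces already stacked in the column
        return [
            TicTacToeMove(i, np.count_nonzero(self.board[i, :]), self.next_to_move)
            for i in indices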
        ]
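The gravity rule in `get_legal_actions` relies on pieces always stacking from the bottom with no gaps. The landing height in a column therefore equals the number of pieces already in it. Here is the same idea in plain Python. It uses the same `board[x][y]` layout (x is the column, y the height, and y = 0 is the bottom). `connect4_legal_moves` is a hypothetical helper for illustration:

```python
def connect4_legal_moves(board):
    """Return (column, row) pairs where a dropped piece would land.

    board[x][y]: x is the column, y the height (0 = bottom).
    Assumes pieces stack without gaps, as the gravity rule guarantees.
    """
    moves = []
    for x, column in enumerate(board):
        height = sum(1 for v in column if v != 0)  # pieces already in this column
        if height < len(column):                   # column not full yet
            moves.append((x, height))
    return moves


print(connect4_legal_moves([[1, -1, 1], [0, 0, 0], [1, 0, 0]]))  # [(1, 0), (2, 1)]
```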

In [10]:
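# define initial state; the parent class only accepts square boards,
# so a 7x7 grid is used instead of the standard 7 columns x 6 rows,
# and the player who moves first is picked at random
state = np.zeros((7, 7))
board_state = Connect4GameState(state=state, next_to_move=np.random.choice([-1, 1]), win=4)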

# link pieces to icons
pieces = {0: " ", 1: "X", -1: "O"}


def stringify(row):
    '''print a single row of the board'''
    return " " + " | ".join(map(lambda x: pieces[int(x)], row)) + " "


def display(board):
    '''Display the whole board'''
    board = board.copy().T[::-1]
    for row in board[:-1]:
        print(stringify(row))
        print("-"*(len(row)*4-1))
    print(stringify(board[-1]))
    print()

In [11]:
display(board_state.board)
# keep playing until game terminates
while board_state.game_result is None:
    # calculate best move
    root = TwoPlayersGameMonteCarloTreeSearchNode(state=board_state)
    mcts = MonteCarloTreeSearch(root)
    best_node = mcts.best_action(simulations_number=50)  # or: total_simulation_seconds=1

    # update board
    board_state = best_node.state
    # display board
    display(board_state.board)

# print result: "X" or "O" for the winner, "Draw" otherwise
print(pieces[board_state.game_result] if board_state.game_result != 0 else "Draw")

   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   

   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   | X |   |   

   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
---------------------------
   |   |   |   |   |   |   
------------------