<center><img src='https://i.postimg.cc/TPR1n1rp/AI-Tech-PL-RGB.png' height="60"></center>

AI TECH - Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (Academy of Innovative Applications of Digital Technologies). Operational Programme Digital Poland for 2014-2020
<hr>

<center><img src='https://i.postimg.cc/Gpq2KRQz/logotypy-aitech.jpg'></center>

<center>
Project co-financed by the European Union under the European Regional Development Fund
Operational Programme Digital Poland for 2014-2020,
Priority Axis 3 "Digital competences of society", Measure 3.2 "Innovative solutions for digital activation"
Project title: „Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (AI Tech)”
</center>

# Lab ??: Monte-Carlo Tree Search

In this lab, you'll implement a Monte-Carlo Tree Search (MCTS) planner and play Connect4 against it. MCTS repeats four phases, which you'll implement in this order:

1. Selection,
2. Expansion,
3. Simulation,
4. Backpropagation.


![MCTS diagram](https://www.researchgate.net/profile/Jacek-Mandziuk/publication/319126544/figure/fig1/AS:614155115581472@1523437398677/MCTS-algorithm-overview-6.png)

In [11]:
import copy
import random
import time
from dataclasses import dataclass
from functools import partial

import numpy as np

def prompt_for_integer(prompt):
    while True:
        try:
            return(int(input(prompt)))
        except ValueError:
            print("[!] Please enter a valid integer")

## Connect4

* The game is played on a board with 6 rows and 7 columns.
* Two players alternate turns, each dropping one of their stones (discs) into a column that is not yet full.
* The goal is to connect 4 of your own stones in a row, column, or diagonal while preventing the opponent from doing the same.
* If the board fills up before either player connects four, the game is a draw.

```
 | Player 1: x; Player -1: o |\
 |---------------------------||
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 ||
 |---|---|---|---|---|---|---||
[[ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | x | - | - | - ]|
 [ - | - | - | x | - | - | - ]|
 [ - | - | - | o | - | - | - ]|
 [ - | - | o | x | - | o | - ]]
```
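The winning condition can be sketched independently of the board class used in this lab. The snippet below is a minimal illustration, assuming a board stored as a plain list of rows with 0 for empty cells; the helper name `has_win` is hypothetical and is not part of `Connect4Board`.

```python
def has_win(grid, player, win_length=4):
    """Checks whether `player` has `win_length` stones in a line.

    Hypothetical helper for illustration: `grid` is a list of rows
    (top row first) and 0 marks an empty cell.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    # Line directions: right, down, down-right, down-left
    directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
    for r in range(n_rows):
        for c in range(n_cols):
            for dr, dc in directions:
                end_r = r + (win_length - 1) * dr
                end_c = c + (win_length - 1) * dc
                if not (0 <= end_r < n_rows and 0 <= end_c < n_cols):
                    continue  # The line would leave the board
                if all(grid[r + k * dr][c + k * dc] == player
                       for k in range(win_length)):
                    return True
    return False


grid = [[0] * 7 for _ in range(6)]
# Player 1 holds a diagonal rising from the bottom-left corner
for r, c in [(5, 0), (4, 1), (3, 2), (2, 3)]:
    grid[r][c] = 1

print(has_win(grid, 1), has_win(grid, -1))
```

Scanning every cell in four directions is simple rather than fast; it is only meant to make the rule concrete.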

In [12]:
class Connect4Board():
    """Connect4 board with the game logic."""

    def __init__(self, pieces=None, win_length=4):
        """Initialize the board with the pieces position."""
        if pieces is None:
            self._pieces = np.zeros([6, 7]) # Default height = 6 x width = 7
        else:
            self._pieces = pieces
        self._win_length = win_length

        # Cache winner
        self._winner = self._get_winner()

    def add_stone(self, column):
        """Create copy of the board containing a new stone."""
        available_idx, = np.where(self.pieces[:, column] == 0)
        if len(available_idx) == 0:
            raise ValueError(f'Can\'t play column {column} on board {self}')

        assert not self.is_terminal, 'Can\'t play on a terminal board'

        pieces_ = np.copy(self.pieces)
        pieces_[available_idx[-1]][column] = self.player_turn
        return Connect4Board(pieces_, self._win_length)

    def _get_winner(self):
        """Checks if the board is terminal and returns a winner.
        
        Returns:
            1 if player 1 wins.
            -1 if player -1 wins.
            0 when draw.
            None when the game not ended yet.
        """
        # Check if the previous player move won
        previous_player = self.player_turn * -1
        player_pieces = self.pieces == previous_player
        # Check rows, columns, and diagonals for a win
        if (self._is_straight_winner(player_pieces) or
            self._is_straight_winner(player_pieces.transpose()) or
            self._is_diagonal_winner(player_pieces)):
            return previous_player

        # Return a draw
        if not self.valid_moves:
            return 0

        # Game is not ended yet
        return None

    def _is_diagonal_winner(self, player_pieces):
        """Checks if `player_pieces` contains a diagonal win."""
        for i in range(len(player_pieces) - self.win_length + 1):
            for j in range(len(player_pieces[0]) - self.win_length + 1):
                if all(player_pieces[i + x][j + x] for x in range(self.win_length)):
                    return True
            for j in range(self.win_length - 1, len(player_pieces[0])):
                if all(player_pieces[i + x][j - x] for x in range(self.win_length)):
                    return True
        return False

    def _is_straight_winner(self, player_pieces):
        """Checks if `player_pieces` contains a horizontal win.

        Pass the transposed array to check for a vertical win.
        """
        # Slide a window of `win_length` columns across the board
        n_windows = player_pieces.shape[1] - self.win_length + 1
        return any(
            player_pieces[:, i:i + self.win_length].sum(axis=1).max() >= self.win_length
            for i in range(n_windows))

    @property
    def is_terminal(self):
        return self.winner is not None

    @property
    def pieces(self):
        return self._pieces

    @property
    def player_turn(self):
        sum_pieces = np.sum(self.pieces)
        if sum_pieces == 0:
            return 1
        elif sum_pieces == 1:
            return -1
        else:
            raise ValueError(f'Invalid state on board {self}')

    @property
    def valid_moves(self):
        # Any zero value in the top row is a valid move
        return list(np.where(self.pieces[0] == 0)[0])

    @property
    def win_length(self):
        return self._win_length

    @property
    def winner(self):
        """Returns a winner or None.

        Values:
            1 if player 1 wins.
            -1 if player -1 wins.
            0 when draw.
            None when the game not ended yet.
        """
        return self._winner

    def __str__(self):
        def piece_sign(piece):
            if piece == 1.0:
                return ' x '
            elif piece == -1.0:
                return ' o '
            else:
                return ' - '
        formatter = dict(float=piece_sign)

        return ('\n | Player 1: x; Player -1: o |\\\n' +
                ' |' + ''.join([f'----' for i in range(self.pieces.shape[1] - 1)]) + '---||\n' +
                ' |' + ''.join([f' {i} |' for i in range(self.pieces.shape[1])]) + '|\n' +
                ' |' + ''.join(['---|' for _ in range(self.pieces.shape[1])]) + '|\n' +  
                np.array2string(self.pieces, sign=' ', separator='|', formatter=formatter) + '\n')

    def __eq__(self, other):
        if not isinstance(other, Connect4Board):
            return False
        return (np.all(self.pieces == other.pieces) and
                self.win_length == other.win_length)

    def __hash__(self):
        return hash((self.pieces.tobytes(), self.win_length))

In [13]:
def human(board):
    """Returns the next board after the human move."""
    print(board)
    while True:
        column = prompt_for_integer(
            f'Player {board.player_turn} adds stone in a column [0-{board.pieces.shape[1]-1}]: ')
        try:
            return board.add_stone(column)
        except ValueError as err:
            print(f'[!] Value Error: {err}')
            continue
        except IndexError as err:
            print(f'[!] Index Error: Column {column} is an illegal move')
            continue

def play(player_first, player_second, board=None):
    """Alternately calls two players starting from `board` until termination."""
    if board is None:
        board = Connect4Board()
    
    players = [player_first, player_second]
    player_idx = 0
    while not board.is_terminal:
        board = players[player_idx](board)
        player_idx = (player_idx + 1) % 2

    print(board)
    if board.winner == 0:
        print('No one wins...')
    else:
        print(f'Player {board.winner} wins!')

In [14]:
play(human, human)


 | Player 1: x; Player -1: o |\
 |---------------------------||
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 ||
 |---|---|---|---|---|---|---||
[[ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]]




## Tree

Our tree is represented by a dictionary whose keys are boards and whose values are tree nodes. Before creating a new node, always check whether its board is already in the tree!

### Exercise

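Answer the questions.

- Is it tree search or graph search?
  > Answer: It is graph search. Nodes are keyed by board, so a position reached through different move orders maps to the same node (nodes are path-independent).
- How does it change the planner's performance?
  > Answer: In a graph, statistics gathered along one path are reused on every other path to the same position, which speeds up convergence.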
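To make the point about shared nodes concrete, here is a minimal sketch with a tiny 3x3 board stored as nested lists. The helpers `board_key` and `drop` are hypothetical and only stand in for the hashing and move logic of the lab's board class.

```python
def board_key(rows):
    """Turns a list-of-lists board into a hashable dictionary key."""
    return tuple(tuple(row) for row in rows)


def drop(rows, column, player):
    """Returns a copy of `rows` with a stone in the lowest empty cell."""
    rows = [list(row) for row in rows]
    for row in reversed(rows):  # Bottom row first
        if row[column] == 0:
            row[column] = player
            return rows
    raise ValueError(f'Column {column} is full')


empty = [[0, 0, 0] for _ in range(3)]

# Two different move orders: columns 0, 1, 2 versus columns 2, 1, 0
first = drop(drop(drop(empty, 0, 1), 1, -1), 2, 1)
second = drop(drop(drop(empty, 2, 1), 1, -1), 0, 1)

tree = {}
tree.setdefault(board_key(first), {'visit_count': 0})['visit_count'] += 1
tree.setdefault(board_key(second), {'visit_count': 0})['visit_count'] += 1

# Both paths end in the same position, so they share one entry
print(len(tree), tree[board_key(first)]['visit_count'])
```

In a pure tree search the two move orders would create two separate nodes, each seeing only half of the visits.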

In [15]:
tree = dict() # Keys are boards and values are tree nodes

@dataclass
class TreeNode():
    """Class for keeping track of a tree node statistics."""
    board: Connect4Board
    children: tuple = None # None means unexplored (non-expanded)
    total_reward: float = 0.0
    visit_count: int = 0

def get_node(tree, board):
    """Either gets a node from `tree` or adds and returns a new node."""
    if board in tree:
        return tree[board]
    
    node = TreeNode(board)
    tree[board] = node
    return node

## 1. Selection

Starting from the root `node`, select successive child nodes until a leaf node is reached. The root is the current game state, and a leaf is either a terminal state or a node from which no simulation (playout/rollout) has been initiated yet; we call such a node "unexplored". To expand the game tree towards the most promising moves, we select children according to the UCT (Upper Confidence Bound applied to Trees) selection rule. We choose the children that maximize:

$$
\underbrace{\frac{w_i}{n_i}}_{\text{exploitation}} + c \cdot \underbrace{\sqrt{\frac{\ln N_i}{n_i}}}_{\text{exploration}}
$$

where $w_i$ is the total reward and $n_i$ the visit count of the considered child node, $N_i$ is the visit count of its parent, all taken after the $i$-th MCTS iteration, and $c$ is the exploration weight (`expl_weight` in the code).

The first component of the formula above corresponds to exploitation: it is high for nodes with a high win rate. The second component corresponds to exploration: it is high for nodes that have been visited rarely compared to their parent.
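A short worked example may help to see how the exploration weight shifts the choice. The statistics below are purely illustrative, and `uct_score` is a hypothetical helper that evaluates the formula above for a single child.

```python
import math


def uct_score(total_reward, visit_count, parent_visit_count, expl_weight):
    """UCT value of a child: mean reward plus an exploration bonus."""
    exploitation = total_reward / visit_count
    exploration = math.sqrt(math.log(parent_visit_count) / visit_count)
    return exploitation + expl_weight * exploration


# Illustrative statistics: (total reward, visit count) for three children
children = [(3.0, 3), (10.0, 9), (1.0, 5)]
parent_visits = 16

for c in (0.0, 2.0):
    scores = [uct_score(w, n, parent_visits, c) for w, n in children]
    best = max(range(len(scores)), key=scores.__getitem__)
    print(f'c={c}: scores={[round(s, 2) for s in scores]}, best child={best}')
```

With $c = 0$ the choice is purely greedy and child 1 (mean reward of about 1.11) wins. With $c = 2$ the rarely visited child 0 overtakes it, because its exploration bonus is larger.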

### Exercise

Implement the selection phase below.

In [117]:
def select(tree, node, expl_weight):
    """Returns a path to an unexplored (or terminal) descendant of `node`."""
    path = [node]

    # Descend while the current node is expanded and non-terminal
    while node.children:
        best_score = -float('inf')
        best_child = None

        for child in node.children:
            if child.children is None:
                # Unexplored child: stop here so it gets expanded and simulated
                path.append(child)
                return path

            # UCT formula; children without any visits cannot be scored
            if child.visit_count > 0:
                score = (child.total_reward / child.visit_count
                         + expl_weight * np.sqrt(
                             np.log(node.visit_count) / child.visit_count))
                if score > best_score:
                    best_score = score
                    best_child = child

        if best_child is None:
            # No child could be scored (all are expanded but unvisited,
            # e.g. terminal states), so pick a random one
            path.append(random.choice(node.children))
            return path

        node = best_child
        path.append(node)

    return path



In [118]:
# TEST find unexplored

# Set-up
tree_test = dict()
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
node = get_node(tree_test, board_test)
node.children = (TreeNode(board_test), TreeNode(board_test), TreeNode(board_test))
node.children[0].children = tuple()
node.children[1].children = tuple()

# Run
path = select(tree_test, node, 2.)

# Test
assert path[0] == node
assert path[1] == node.children[2], [path[1] is child for child in node.children]

# TEST pick greedy best

# Set-up
tree_test = dict()
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
node = get_node(tree_test, board_test)
node.children = (TreeNode(board_test), TreeNode(board_test), TreeNode(board_test))
node.visit_count = 16
node.children[0].children = (TreeNode(board_test),)
node.children[0].total_reward = 3.
node.children[0].visit_count = 3
node.children[1].children = (TreeNode(board_test),)
node.children[1].total_reward = 10.
node.children[1].visit_count = 9.
node.children[2].children = (TreeNode(board_test),)
node.children[2].total_reward = 1.
node.children[2].visit_count = 5

# Run
path = select(tree_test, node, 0.)

# Test
assert path[0] == node
assert path[1] == node.children[1]
assert path[2] == node.children[1].children[0]

# TEST pick UCB best

# Set-up
# Same as above...

# Run
path = select(tree_test, node, 2.)

# Test
assert path[0] == node
assert path[1] == node.children[0]
assert path[2] == node.children[0].children[0]

## 2. Expansion

Unless the leaf `node` ends the game (i.e. win/loss/draw) for either player, create the node's children and add them to the node. Child nodes are any valid moves from the game position defined by the leaf.

> Hint: A terminal state should have an empty `children` tuple.

> Hint 2: It's important to check if a tree node isn't in the tree already before creating a new one!
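Both hints can be illustrated on a toy game that is much smaller than Connect4. In the sketch below the state is just a number of stones and a move removes 1 or 2 of them; the names `Node`, `get_or_create`, and `expand_node` are hypothetical, and the player to move is left out of the state for brevity.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical node for a toy game where the state is a stone count."""
    state: int
    children: Optional[tuple] = None  # None means not expanded yet


def get_or_create(tree, state):
    """Returns the node stored for `state`, creating it only if missing."""
    if state not in tree:
        tree[state] = Node(state)
    return tree[state]


def expand_node(tree, node):
    """Creates children for every legal move (take 1 or 2 stones)."""
    moves = [m for m in (1, 2) if m <= node.state]
    # With no stones left there are no moves, so a terminal node ends up
    # with an empty tuple instead of staying at None
    node.children = tuple(get_or_create(tree, node.state - m) for m in moves)


tree = {}
root = get_or_create(tree, 4)
expand_node(tree, root)              # Children: states 3 and 2
expand_node(tree, root.children[0])  # From 3: states 2 and 1

# State 2 is reachable from both 4 and 3, yet it is stored only once
print(root.children[1] is root.children[0].children[0])
print(sorted(tree))
```

Looking nodes up before creating them is what keeps the shared state a single object, so statistics collected through either parent end up in the same place.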

### Exercise

Implement the expansion phase below.

In [119]:
def expand(tree, node):
    """Expands `node` children."""
    board = node.board
    if node.children is not None:
        assert board.is_terminal
        return # Already expanded, it can happen for terminal boards

    node.children = tuple(
        get_node(tree, board.add_stone(move)) for move in board.valid_moves
    ) if not board.is_terminal else tuple()

In [120]:
# TEST expand terminal

# Set-up
tree_test = dict()
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
board_test._winner = 0 # Draw is a terminal state
node = get_node(tree_test, board_test)

# Run
expand(tree_test, node)

# Test
assert node.children == tuple()

# TEST expand non-terminal

# Set-up
tree_test = dict()
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
node = get_node(tree_test, board_test)

# Run
expand(tree_test, node)

# Test
assert node.children == tuple(get_node(tree_test, board_test.add_stone(column))
                              for column in board_test.valid_moves)

## 3. Simulation

Complete one random simulation (playout/rollout) from the leaf `node` by choosing uniformly random moves until the game is decided (i.e. one player wins or the game is drawn). Return the winner (1 if player 1 wins; -1 if player -1 wins; 0 for a draw).
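The idea of a random playout can be tried out on a toy game first. The sketch below assumes a simple stone-taking game (not Connect4): players 1 and -1 alternate taking 1 or 2 stones, and whoever takes the last stone wins. The function `random_playout` is a hypothetical helper.

```python
import random


def random_playout(stones, player, rng):
    """Plays uniformly random moves until no stones remain.

    Toy game assumed for illustration: `player` (1 or -1) moves first and
    whoever takes the last stone wins. Returns the winner, 1 or -1.
    """
    if stones < 1:
        raise ValueError('The playout needs at least one stone')
    while True:
        move = rng.choice([m for m in (1, 2) if m <= stones])
        stones -= move
        if stones == 0:
            return player  # `player` just took the last stone
        player = -player


rng = random.Random(0)  # Local generator, so the global seed is untouched
results = [random_playout(5, 1, rng) for _ in range(1000)]
print('Player 1 win rate:', results.count(1) / len(results))
```

A single playout is a very noisy signal. MCTS relies on averaging many of them, which is what the `total_reward` and `visit_count` statistics accumulate.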

### Exercise

Implement the simulation phase below.

In [121]:
def simulate(node):
    """Returns a winner of a random simulation (to completion) from `node`."""
    board = node.board
    # Play uniformly random moves until the game ends
    while not board.is_terminal:
        board = board.add_stone(random.choice(board.valid_moves))
    return board.winner

In [122]:
# TEST simulate, winner is 0

# Set-up
tree_test = dict()
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
node = get_node(tree_test, board_test)
random.seed(42)

# Run
winner = simulate(node)

# Test
assert winner == 0

# TEST simulate, winner is 1

# Set-up
tree_test = dict()
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
node = get_node(tree_test, board_test)
random.seed(7)

# Run
winner = simulate(node)

# Test
assert winner == 1

# TEST simulate, winner is -1

# Set-up
tree_test = dict()
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
node = get_node(tree_test, board_test)
random.seed(666)

# Run
winner = simulate(node)

# Test
assert winner == -1

## 4. Backpropagation

Use the result of the simulation to update the statistics of the nodes on the path from the leaf to the root. Flip the sign of the reward at each step up the path, so that the children of a node have positive values when the player to move at that node wins.
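The sign convention is easy to get wrong, so here is a small sketch of it in isolation. It assumes that each node is credited from the point of view of the player who moved into it, which is the opponent of the player to move at that node; `signed_rewards` is a hypothetical helper that works on a list of players instead of real tree nodes.

```python
def signed_rewards(players_to_move, winner):
    """Returns the reward credited to each node on a root-to-leaf path.

    `players_to_move[k]` is the player (1 or -1) to move at the k-th node.
    `winner` is 1, -1, or 0 for a draw.
    """
    # Reward at the leaf, seen by the player who moved into the leaf
    reward = -winner * players_to_move[-1]
    rewards = []
    for _ in reversed(players_to_move):
        rewards.append(reward)
        reward = -reward  # One step up, the other player made the move
    return rewards[::-1]


# Player 1 moves at the root, player -1 at its child, player 1 at the leaf
print(signed_rewards([1, -1, 1], winner=1))
print(signed_rewards([1, -1, 1], winner=-1))
print(signed_rewards([1, -1, 1], winner=0))
```

When player 1 wins, the middle node is the only one reached by a move of player 1, so it is the only one that receives +1.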

### Exercise

Implement the backpropagation phase below.

In [123]:
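def backprop(path, winner):
    """Back-propagates the reward throughout `path`."""
    # Reward seen by the player who moved into the leaf, i.e. the
    # opponent of the player to move on the leaf board
    reward = -winner * path[-1].board.player_turn
    for node in reversed(path):
        node.visit_count += 1
        node.total_reward += reward
        reward = -reward  # The parent was reached by the other player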

In [124]:
# TEST backprop, winner is 1

# Set-up
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
path = [TreeNode(board_test), TreeNode(board_test), TreeNode(board_test)]

# Run
backprop(path, 1)

# Test
assert path[0].total_reward == -1.
assert path[0].visit_count == 1
assert path[1].total_reward == 1.
assert path[1].visit_count == 1
assert path[2].total_reward == -1.
assert path[2].visit_count == 1

# TEST backprop, winner is 0

# Set-up
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
path = [TreeNode(board_test), TreeNode(board_test), TreeNode(board_test)]

# Run
backprop(path, 0)

# Test
assert path[0].total_reward == 0.
assert path[0].visit_count == 1
assert path[1].total_reward == 0.
assert path[1].visit_count == 1
assert path[2].total_reward == 0.
assert path[2].visit_count == 1

# TEST backprop, winner is -1

# Set-up
board_test = Connect4Board(np.zeros([3, 3]), win_length=3)
path = [TreeNode(board_test), TreeNode(board_test), TreeNode(board_test)]

# Run
backprop(path, -1)

# Test
assert path[0].total_reward == 1.
assert path[0].visit_count == 1
assert path[1].total_reward == -1.
assert path[1].visit_count == 1
assert path[2].total_reward == 1.
assert path[2].visit_count == 1

## MCTS

![MCTS diagram](https://www.researchgate.net/profile/Jacek-Mandziuk/publication/319126544/figure/fig1/AS:614155115581472@1523437398677/MCTS-algorithm-overview-6.png)

### Exercise

- Put the four phases together.
- Play against the MCTS planner with different numbers of iterations `n_iter` and exploration weights `expl_weight`.
- How does it behave? Can you beat it?

In [125]:
def mcts(tree, board, n_iter=200, expl_weight=2.):
    """Returns the best move (next board) after `n_iter` MCTS iterations."""
    root = get_node(tree, board)
    for _ in range(n_iter):
        path = select(tree, root, expl_weight)

        selected_node = path[-1]
        expand(tree, selected_node)

        winner = simulate(selected_node)
        backprop(path, winner)

    # `expand` creates children in the order of `valid_moves`, so zip pairs
    # each move with its node; max(..., 1) guards against unvisited children
    print('DEBUG | <column>: (<value>, <count>) ' + '; '.join(
        f'{move}: ({child.total_reward / max(child.visit_count, 1):.2f}, {child.visit_count})'
        for move, child in zip(board.valid_moves, root.children)))

    # Play the most visited move
    return max(root.children, key=lambda node: node.visit_count).board

In [137]:
random.seed(time.time())
tree = dict() # Keys are boards and values are tree nodes
# play(human, partial(mcts, tree, n_iter=200, expl_weight=2.)) # n_iter=200
play(human, partial(mcts, tree, n_iter=1000, expl_weight=.5))


 | Player 1: x; Player -1: o |\
 |---------------------------||
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 ||
 |---|---|---|---|---|---|---||
[[ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]]

Player 1 adds stone in a column [0-6]: 1
DEBUG | <column>: (<value>, <count>) 0: (0.02, 883); 1: (-1.00, 2); 2: (-1.00, 2); 3: (-1.00, 2); 4: (-1.00, 2); 5: (-0.38, 16); 6: (-0.14, 92)

 | Player 1: x; Player -1: o |\
 |---------------------------||
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 ||
 |---|---|---|---|---|---|---||
[[ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ - | - | - | - | - | - | - ]|
 [ o | x | - | - | - | - | - ]]

Player 1 adds stone in a column [0-6]: 1
DEBUG | <column>: (<value>, <count>) 0: (-0.21, 38); 1: (-0.04, 421); 2: (-1.00, 2); 3: (-1.00, 2); 4: (-0.06, 293); 5: (0.03, 231

### Extra

Calculate and print additional statistics after each MCTS move:

- Tree size (number of nodes),
- Depth of the tree (longest path),
- Breadth of the tree (number of leaves),
- Values of the root children,
- Counts of the root children,
- ...?
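As a starting point, the first three statistics could be computed with one traversal. The sketch below is only one possible approach: `StatNode` is a hypothetical stand-in for a tree node with a `children` field, and a leaf is assumed to be any node that is unexpanded (`None`) or terminal (empty tuple).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StatNode:
    """Hypothetical stand-in for a tree node; only `children` matters here."""
    children: Optional[tuple] = None  # None means not expanded yet


def tree_stats(root):
    """Returns (size, depth, n_leaves) of the graph reachable from `root`.

    Depth is the longest path counted in edges. Every node is processed
    once, so shared children are not double counted. Assumes no cycles,
    which holds when every move adds a stone to the board.
    """
    depth_of = {}  # id(node) -> longest path below that node
    n_leaves = 0

    def visit(node):
        nonlocal n_leaves
        if id(node) in depth_of:
            return depth_of[id(node)]
        kids = node.children or ()
        if not kids:
            n_leaves += 1
            depth = 0
        else:
            depth = 1 + max(visit(child) for child in kids)
        depth_of[id(node)] = depth
        return depth

    depth = visit(root)
    return len(depth_of), depth, n_leaves


shared = StatNode()                  # Unexpanded node with two parents
left = StatNode(children=(shared,))
right = StatNode(children=(shared, StatNode(children=())))
root_stats = tree_stats(StatNode(children=(left, right)))
print(root_stats)
```

Starting the traversal at the current root rather than iterating over the whole dictionary reports only the part of the tree that is still reachable from the current position.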