# Monte Carlo Tree search

The four steps of a Monte Carlo Tree Search (MCTS) algorithm:
    
#### 1) Select
At each depth of the tree, starting from the root node, select the child that has the highest Upper Confidence Bound (UCB) score, and follow the selected trajectory down the tree until a leaf is reached.
   
#### 2) Expand
From the leaf chosen by maximising the UCB at every level, expand the leaf node by one time step, computing all possible child nodes and their prior probabilities.

*Example: in Connect-X, the prior over the children is chosen to be uniform across all possible next moves, i.e. 1/N with N = number of possible moves.*

#### 3) Simulate
For each child node, simulate all future actions (and environment reactions) using a RANDOM decision policy, until either:
 - the end of the game/problem is reached, if the game/problem HAS a FINITE END, or
 - D time steps have been taken into the future, where D is a GLOBAL parameter chosen by the user of the model which represents the RANGE in the future over which we want to optimise our Long Range Reward *(i.e. the truncated end of the summation of rewards in the future)*.
 
#### 4) Backpropagate
For each of the trajectories obtained until D (or until the end of the game) from each of the child nodes, backpropagate (update) the samples of the nodes traversed all the way back to the root node, recalculating UCB scores.

Once the samples (and the UCB scores derived from them) are updated for all relevant nodes, repeat the whole loop from step 1 for as many simulations as we want.

# Connect X

In [1]:
import copy
import itertools
import math
import random
import time

In [2]:
import numpy as np

In [3]:
# --- globals
NUM_COLS = 4
NUM_ROWS = 1
CONNECTX = 2
EMPTY = 0

In [4]:
def window(seq, n=CONNECTX):
    """Yield overlapping tuples of ``n`` consecutive items taken from ``seq``.

        s -> (s0, s1, ..., s[n-1]), (s1, s2, ..., sn), ...

    ``seq`` may be any iterable. If it holds fewer than ``n`` items,
    nothing is yielded.
    """
    stream = iter(seq)
    # Prime the first frame with (up to) the first n items.
    frame = tuple(itertools.islice(stream, n))
    if len(frame) == n:
        yield frame
    # Slide by one for every remaining item: drop the oldest, append the newest.
    for item in stream:
        frame = (*frame[1:], item)
        yield frame

In [5]:
# # test window function
# tmp = [1,2,3,4,5,6]
# list(window(tmp, n=3))

In [6]:
def all_equal(iterator):
    """Tell whether every item produced by ``iterator`` equals the first one.

    An empty iterable counts as all-equal, so the result is True.
    """
    items = iter(iterator)
    missing = object()  # private marker: distinguishes "no items" from any real value
    head = next(items, missing)
    if head is missing:
        return True
    for other in items:
        if not (head == other):
            return False
    return True

def all_full(iterator):
    """Tell whether ``iterator`` contains no ``EMPTY`` cell."""
    has_gap = EMPTY in iterator
    return not has_gap

##### Test functions 

In [7]:
# all_equal(window(tmp))

In [8]:
# # test all_equal: it is to check when X elements (for one window) are equal == to win the game
# tmp2 = [0,1,1,0]
# print(list(window(tmp2)))
# print(all_equal(list(window(tmp2))[1]))
# print(all_equal(window(tmp2)))

## Game class

In [9]:
class Game:
    """Connect-X board on which players "X" and "O" alternately drop pieces.

    The board is a list of rows. Row 0 is the bottom row: it is the first one
    filled by ``place`` and the last one printed by ``__repr__``. Unoccupied
    cells hold ``EMPTY``.
    """

    def __init__(self, num_cols=NUM_COLS, num_rows=NUM_ROWS, num_connect=CONNECTX):
        """Create an empty board.

        Args:
            num_cols: Number of columns of the board.
            num_rows: Number of rows of the board.
            num_connect: Number of aligned identical pieces that wins the game.
        """
        self.num_cols = num_cols
        self.num_rows = num_rows
        self.num_connect = num_connect
        # Every row is its own list so that cells can be set independently.
        self.board = [[EMPTY] * num_cols for _ in range(num_rows)]
        # Number of pieces already stacked in each column.
        self._heights = [0] * num_cols
        self.PIECES = {"X": "O", "O": "X"}  # current player: next player

        # Piece placed by the latest call to place(); None until the first move.
        self.last_played = None

    @property
    def heights(self):
        """Per-column count of pieces already placed."""
        return self._heights

    def place(self, piece, col):
        """Drop ``piece`` into column ``col`` and record it as the last move.

        Args:
            piece: "X" or "O"; after the first move it must be the opponent
                of the piece that played last.
            col: Zero-based column index.

        Raises:
            AssertionError: if ``piece`` is not a known player, or it is not
                that player's turn.
            IndexError: if ``col`` is outside the board or the column is full.
        """
        error_msg = f"Only allowed players are :{list(self.PIECES.keys())}\n"
        assert piece in self.PIECES, error_msg

        if self.last_played is not None:
            error_msg = (
                f"Next player need to be :{self.PIECES[self.last_played]}\n"
                f"Passed instead:{piece}\n"
            )
            assert piece == self.PIECES[self.last_played], error_msg

        # Check the column explicitly: plain list indexing would accept a
        # negative value and silently play in a different column.
        if not 0 <= col < len(self._heights):
            raise IndexError(
                f"Column {col} is outside the board "
                f"(valid columns: 0..{len(self._heights) - 1})"
            )

        row = self._heights[col]  # lowest free row of this column
        if row >= len(self.board):
            raise IndexError(f"Column {col} is already full")

        self.board[row][col] = piece
        self._heights[col] += 1

        self.last_played = piece

    def is_win(self):
        """Return True if ``num_connect`` identical pieces are aligned anywhere."""
        x = self.num_connect
        board = self.board
        if self.is_horizontal_win(board, x) or self.is_diagonal_win(board, x):
            return True

        # After a rotation the columns become rows and the opposite diagonal
        # direction becomes the one scanned by is_diagonal_win, so the same
        # two checks cover the vertical and anti-diagonal cases.
        rotated = self.rotate(board)
        return (
            self.is_horizontal_win(rotated, x)
            or self.is_diagonal_win(rotated, x)
        )

    def get_valid_moves(self):
        """Return the indices of the columns whose top cell is still empty."""
        upper_row = self.board[-1]
        return [n for n, val in enumerate(upper_row) if val == EMPTY]

    @staticmethod
    def rotate(board):
        """
        Rotate the board anti-clockwise
        """
        return [list(row) for row in zip(*board)][::-1]

    @staticmethod
    def is_horizontal_win(board, num_connect):
        """Return True if some row of ``board`` holds ``num_connect`` identical
        pieces side by side."""
        for row in reversed(board):
            # The window has the size of the connect length, not of the board.
            for quad in window(row, n=num_connect):
                # all_full rules out windows that still contain an empty cell,
                # which would otherwise count as "equal" when all are empty.
                if all_full(quad) and all_equal(quad):
                    return True

        return False

    @staticmethod
    def is_diagonal_win(board, num_connect):
        """Return True if ``num_connect`` identical pieces lie on a diagonal
        along which row and column indices increase together."""
        num_rows = len(board)
        num_cols = len(board[0])
        # Only start cells from which a full diagonal still fits on the board.
        for row in range(num_rows + 1 - num_connect):
            for col in range(num_cols + 1 - num_connect):
                diag = [
                    board[row + i][col + i]
                    for i in range(num_connect)
                ]
                if all_full(diag) and all_equal(diag):
                    return True

        return False

    def is_terminal(self):
        """Return True when the game is won or no valid move is left."""
        return self.is_win() or not self.get_valid_moves()

    def get_reward(self):
        """
        Reward is:
        -1: If lose
         0: If draw or if game is not done.
        """
        # A player only ever learns that they lost: is_win is asked after the
        # opponent has moved and BEFORE the player moves. If the player wins,
        # it is the opponent who asks is_win next and sees the loss (-1).
        return -1 if self.is_win() else 0

    def __repr__(self):
        """Draw the board as an ASCII grid, top row first."""
        lines = []
        for row in reversed(self.board):
            cells = [f"{'' if val == EMPTY else val:^3}" for val in row]
            line = "|" + "|".join(cells) + "|"
            # Divider of the same width as the row: "+---+---+ ... +".
            div = "".join(itertools.islice(itertools.cycle("+---"), len(line)))
            lines.append(div)
            lines.append(line)

        # Close the grid below the bottom row.
        lines.append(div)

        return "\n".join(lines)

##### Test game class

In [10]:
# Smoke test of Game: on a single-row board, play "X" in column 1 and then
# make three more moves chosen at random, alternating the players.
simulated_game = Game(num_rows=1)
simulated_game.place("X", 1)
# display(simulated_game)
# print()

for _ in range(3):

    valid_moves = simulated_game.get_valid_moves()
    num_moves = len(valid_moves)
    print('valid moves:', num_moves)

    random_move = random.choice(valid_moves)
    print('random move:', random_move)

    # PIECES maps the piece that just played to the piece that plays next.
    next_player = simulated_game.PIECES[simulated_game.last_played]
    print(f'next_player:"{next_player}"')
    print()
    simulated_game.place(next_player, random_move) 
# Last expression of the cell: the notebook displays the board's repr.
simulated_game


valid moves: 3
random move: 2
next_player:"O"

valid moves: 2
random move: 3
next_player:"X"

valid moves: 1
random move: 0
next_player:"O"



+---+---+---+---+
| O | X | O | X |
+---+---+---+---+

In [14]:
print(simulated_game.is_win())
print(simulated_game.is_terminal())

False
True


In [11]:
# game_obj.place('X', 2)
# print(game_obj)
# print()
# game_obj.place('O', 2)
# print(game_obj)

In [None]:
# game_obj.is_horizontal_win(game_obj.board, num_connect=game_obj.num_connect)

In [None]:
# list(next(reversed(game_obj.board)))

In [None]:
# num_cols = 4
# cols = [EMPTY,] * num_cols
# cols

In [None]:
# num_rows=3
# board = [list(cols) for _ in range(num_rows)]
# board

## Monte Carlo Tree Search

In [15]:
from anytree import NodeMixin, RenderTree

In [16]:
def ucb_score(node):
    """Score the transition from ``node.parent`` to ``node``.

    The score adds an exploration bonus, which grows with the parent's visit
    count and shrinks as the node itself gets visited, to the negated node
    value.

    Raises:
        ValueError: if ``node`` is a root, which has no parent to score from.
    """
    if node.is_root:
        raise ValueError("Cannot compute UCB score for a root node.")

    exploration = node.prior * (
        math.sqrt(node.parent.visit_count) / (node.visit_count + 1)
    )
    # The value is negated -- presumably because it is stored from the point
    # of view of the player to move at ``node``, i.e. the opponent of the
    # player choosing at the parent (TODO confirm).
    exploitation = -node.value
    return exploitation + exploration


In [None]:
# obj = NodeMixin()

In [None]:
# [x for x in dir(obj) if not x.startswith('_')]

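'depth' is the distance from the 'root' node (i.e. the number of levels above me) -- note that these trees are drawn upside down, with the root at the top
'height' is the distance from 'me' down to the furthest 'leaf' node below me (the end of the tree)
'siblings' are the other nodes that share my parent (so they sit at the same 'depth' as me)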

In [18]:
# sum([])

In [112]:
class Node(NodeMixin):
    """One game state in the Monte Carlo search tree.

    The tree structure (parent, children, ``depth``, ``is_leaf``, ``is_root``,
    ``iter_path_reverse`` ...) is provided by anytree's ``NodeMixin``.
    A single tree is shared by both players, which is why the rewards are
    combined with alternating signs in ``backpropagate``.
    """

    DISCOUNT_FACTOR = 0.8  # weight of future rewards relative to the present

    def __init__(self, game, to_play, prior=0, action=None):
        """Create a node without parent and without children.

        Args:
            game: Game state represented by this node.
            to_play: Piece of the player who moves next from this state.
            prior: Probability of reaching this node from its parent.
            action: Move played in the parent state that led to this node.
        """
        super().__init__()
        self.game = game
        self.to_play = to_play
        self.prior = prior
        self.action = action

        # One long-term-reward estimate is appended for every backpropagation
        # pass that goes through this node; ``value`` is their average.
        self.samples = []
        self.visit_count = 0

    def get_next_board(self, move):
        """Return a copy of the game in which ``to_play`` has played ``move``.

        The node's own game is left untouched.
        """
        game = copy.deepcopy(self.game)
        game.place(self.to_play, move)
        return game

    @property
    def value(self):
        """
        Estimate of the expected long term reward.
        """
        # Average of the backpropagated samples; 0 for a node never visited.
        if not self.visit_count:
            return 0
        return sum(self.samples) / self.visit_count

    @property
    def best_child(self):
        """Child with the highest UCB score (None if there are no children).

        This is the child the next simulation descends into.
        """
        return max(self.children, key=ucb_score, default=None)

    def attach_child(self, node):
        """Append ``node`` to this node's children."""
        self.children = self.children + (node,)

    def __repr__(self):
        """
        Debugger pretty print node info
        """
        prior = f"{self.prior:.2g}"

        try:
            ucb_val = f"{ucb_score(self):.3g}"
        except ValueError:
            # ucb_score refuses root nodes: there is no parent to score from.
            ucb_val = None

        return (
            f"{self.game} - Prior: {prior}  "
            f"Count: {self.visit_count}  Value(L): {self.value:.3g}  "
            f" UCB: {ucb_val}  Next_player: '{self.to_play}'"
        )

    # --- Functions for MCTS:
    def select(self):
        """Follow the best child at every level and return the leaf reached."""
        node = self
        while not node.is_leaf:
            node = node.best_child

        return node

    def expand(self):
        """Attach one child per valid move, each with a uniform prior.

        Does nothing if the node already has children or if its game is over.
        """
        if not self.is_leaf:
            # already expanded
            return

        if self.game.is_terminal():
            # there are no more moves in the game, it has finished
            # either because of no more spaces, or because of winning
            return

        valid_moves = self.game.get_valid_moves()
        # Both values are the same for every child, so compute them once.
        next_player = self.game.PIECES[self.to_play]  # switch the player
        prior = 1 / len(valid_moves)
        for move in valid_moves:
            child = Node(
                game=self.get_next_board(move),
                to_play=next_player,
                prior=prior,
                action=move,
            )
            self.attach_child(child)

    def simulate(self, max_tree_depth=20):
        """Play random moves from this node's state and record every step.

        The node's own game is not modified; the rollout runs on a copy.

        Args:
            max_tree_depth: Future-timestep end term of the summation used
                for the long term reward; the rollout stops once the step
                counter exceeds ``max_tree_depth - self.depth``.

        Returns:
            Tuple ``(rewards, discounts, signs)`` of three lists with one
            entry per simulated move ``i``: the reward after the move,
            ``DISCOUNT_FACTOR ** i`` and ``(-1) ** i``.

        Raises:
            ValueError: if the node already has children.
        """
        if not self.is_leaf:
            # already expanded
            raise ValueError(f"How come this node:\n{self}\n has already been expanded?")

        # NOTE(review): with "<=" the rollout may take one move more than
        # (max_tree_depth - depth) -- confirm whether that is intended.
        horizon = max_tree_depth - self.depth
        simulated_game = copy.deepcopy(self.game)

        rewards = []
        discounts = []
        signs = []
        i = 0
        while i <= horizon and not simulated_game.is_terminal():
            random_move = random.choice(simulated_game.get_valid_moves())

            # On a game without any move yet there is no last player to
            # alternate from, so the node's own player opens the rollout.
            last_played = simulated_game.last_played
            if last_played is None:
                next_player = self.to_play
            else:
                next_player = simulated_game.PIECES[last_played]
            simulated_game.place(next_player, random_move)

            rewards.append(simulated_game.get_reward())
            discounts.append(self.DISCOUNT_FACTOR ** i)
            signs.append((-1) ** i)

            i += 1

        return rewards, discounts, signs

    def backpropagate(self, rewards: list, discounts: list, signs: list):
        """Update this node and all its ancestors with one rollout result.

        For every node on the path from this node up to the root, the node's
        immediate reward is prepended to the rollout, the discounts are
        shifted by one step and the signs are flipped; the normalised
        discounted sum is stored as a new sample and the visit count is
        incremented. The lists passed in are not modified.

        Args:
            rewards: Rewards collected by ``simulate``.
            discounts: Discount factors matching ``rewards``.
            signs: Alternating signs matching ``rewards``.
        """
        gamma = self.DISCOUNT_FACTOR

        # Walks through the parents back to the root; in a Monte Carlo tree
        # every node has a UNIQUE parent.
        for node in self.iter_path_reverse():
            immediate_reward = node.game.get_reward()

            rewards = [immediate_reward] + rewards
            discounts = [1] + [x * gamma for x in discounts]
            # WE HAVE A SINGLE TREE FOR BOTH PLAYERS, hence the signs have to
            # alternate from one level to the next.
            signs = [1] + [x * (-1) for x in signs]

            normalisation = sum(discounts)

            L = sum(r * d * s for r, d, s in zip(rewards, discounts, signs)) / normalisation
            node.samples.append(L)

            node.visit_count += 1

        # One can never win at present (so reward != 1 because of gamma),
        # but one can lose at present.

    # --- for visuals:
    def _render_str(self, maxlevel=None):
        """Return the subtree below this node as text, one board per node.

        Args:
            maxlevel: Passed to ``RenderTree`` to limit the rendered depth.
        """
        return "".join(
            f"{pre!s}{node.game!r}\n"
            for pre, _, node in RenderTree(self, maxlevel=maxlevel)
        )

    def render(self, maxlevel=None):
        """Print the subtree below this node, limited to ``maxlevel`` levels."""
        print(self._render_str(maxlevel=maxlevel))

In [53]:
one_game = Game()

# one_game.board   # Numeric representation of board
one_game   # Nice visual representation of board

+---+---+---+---+
|   |   |   |   |
+---+---+---+---+

In [54]:
root_node = Node(one_game, 'X')
root_node

+---+---+---+---+
|   |   |   |   |
+---+---+---+---+ - Prior: 0  Count: 0  Value(L): 0   UCB: None  Next_player: 'X'

In [55]:
root_node.select()

+---+---+---+---+
|   |   |   |   |
+---+---+---+---+ - Prior: 0  Count: 0  Value(L): 0   UCB: None  Next_player: 'X'

In [56]:
root_node.expand()

In [57]:
root_node.children

(+---+---+---+---+
 | X |   |   |   |
 +---+---+---+---+ - Prior: 0.25  Count: 0  Value(L): 0   UCB: 0  Next_player: 'O',
 +---+---+---+---+
 |   | X |   |   |
 +---+---+---+---+ - Prior: 0.25  Count: 0  Value(L): 0   UCB: 0  Next_player: 'O',
 +---+---+---+---+
 |   |   | X |   |
 +---+---+---+---+ - Prior: 0.25  Count: 0  Value(L): 0   UCB: 0  Next_player: 'O',
 +---+---+---+---+
 |   |   |   | X |
 +---+---+---+---+ - Prior: 0.25  Count: 0  Value(L): 0   UCB: 0  Next_player: 'O')

In [58]:
node = root_node.children[0]

In [59]:
node

+---+---+---+---+
| X |   |   |   |
+---+---+---+---+ - Prior: 0.25  Count: 0  Value(L): 0   UCB: 0  Next_player: 'O'

In [60]:
r, d, s = node.simulate()

valid moves: 3
random move: 2
player:"O"



+---+---+---+---+
| X |   | O |   |
+---+---+---+---+


1
valid moves: 2
random move: 1
player:"X"



+---+---+---+---+
| X | X | O |   |
+---+---+---+---+


2


In [61]:
print(r)
print(d)
print(s)

[0, -1]
[1.0, 0.8]
[1, -1]


In [62]:
node

+---+---+---+---+
| X |   |   |   |
+---+---+---+---+ - Prior: 0.25  Count: 0  Value(L): 0   UCB: 0  Next_player: 'O'

In [63]:
node.backpropagate(rewards=r, discounts=d, signs=s)

In [64]:
node

+---+---+---+---+
| X |   |   |   |
+---+---+---+---+ - Prior: 0.25  Count: 1  Value(L): -0.262   UCB: 0.387  Next_player: 'O'

In [44]:
# root_node.children[0].game.is_win()

In [46]:
# node.get_next_board(node.game.get_valid_moves()[0])

In [47]:
# list(node.iter_path_reverse())

In [None]:
# type(node.get_next_board(node.game.get_valid_moves()[0]))

In [None]:
# root_node.children[0].descendants

In [None]:
# root_node.descendants

### MonteCarlo tree search steps

In [82]:
game = Game()

In [83]:
# game.place("X", 2)

In [84]:
root_node = Node(game, 'O')

In [85]:
# Run 1000 MCTS iterations from root_node: select a leaf, expand it, then
# simulate from and backpropagate through each of its new children.
for _ in range(1000):
    node = root_node.select()
    node.expand()
    if node.is_leaf:
        # expand() added no child because the game is over at this node:
        # backpropagate its immediate reward only, with an empty rollout.
        node.backpropagate([], [], [])
    else:
        for child in node.children:
            r, d, s = child.simulate()
            child.backpropagate(rewards=r, discounts=d, signs=s)

what my mind thinks: what is the value of making this move as player x

what it is instead: the value of having x in that position as player 'o'

In [86]:
root_node

+---+---+---+---+
|   |   | X |   |
+---+---+---+---+ - Prior: 0  Count: 1005  Value(L): -0.258   UCB: None  Next_player: 'O'

In [79]:
# root_node.samples

In [None]:
After the simulation has finished, what determines which

In [87]:
root_node.children

(+---+---+---+---+
 | O |   | X |   |
 +---+---+---+---+ - Prior: 0.33  Count: 267  Value(L): 0.444   UCB: -0.405  Next_player: 'X',
 +---+---+---+---+
 |   | O | X |   |
 +---+---+---+---+ - Prior: 0.33  Count: 369  Value(L): 0.434   UCB: -0.405  Next_player: 'X',
 +---+---+---+---+
 |   |   | X | O |
 +---+---+---+---+ - Prior: 0.33  Count: 369  Value(L): 0.434   UCB: -0.405  Next_player: 'X')

In [71]:
# root_node.children[1]

+---+---+---+---+
|   | X |   |   |
+---+---+---+---+ - Prior: 0.25  Count: 10  Value(L): -0.21   UCB: 0.326  Next_player: 'O'

In [69]:
# root_node.children[1].children

(+---+---+---+---+
 | O | X |   |   |
 +---+---+---+---+ - Prior: 0.33  Count: 3  Value(L): 0.296   UCB: -0.0328  Next_player: 'X',
 +---+---+---+---+
 |   | X | O |   |
 +---+---+---+---+ - Prior: 0.33  Count: 3  Value(L): 0.296   UCB: -0.0328  Next_player: 'X',
 +---+---+---+---+
 |   | X |   | O |
 +---+---+---+---+ - Prior: 0.33  Count: 3  Value(L): 0.444   UCB: -0.181  Next_player: 'X')

In [None]:
# THE END OF THE TREE
# root_node.leaves

# PLAYING THE GAME 


In [89]:
from IPython.display import clear_output

In [118]:
from abc import ABC, abstractmethod

In [None]:
class GeneralAgent(ABC):
    """Abstract base class of the players; subclasses implement ``choose_action``."""

    def __init__(self, piece, name=None):
        """
        Args:
            piece: Piece this agent plays with.
            name: Display name; defaults to "<ClassName>-<piece>".
        """
        self.piece = piece
        # Always set the attribute, whether or not a name was supplied.
        if name is None:
            name = f"{self.__class__.__name__}-{piece}"
        self.name = name

    @abstractmethod
    def choose_action(self, game):
        """Return the move to play in ``game``.

        Subclasses must override this method. The body is a random policy
        they can reuse through ``super().choose_action(game)``.
        """
        action = random.choice(game.get_valid_moves())
        return action

In [113]:
class RandomAgent:
    """Player (agent) that picks its move at random among the valid ones."""

    def __init__(self, piece):
        # Piece this agent plays with.
        self.piece = piece

    def choose_action(self, game):
        """Return one of ``game.get_valid_moves()``, chosen at random."""
        candidates = game.get_valid_moves()
        return random.choice(candidates)

class MCTSAgent:
    """Player that chooses its move with a Monte Carlo Tree Search."""

    def __init__(self, piece, num_simulations=1000, max_tree_depth=20):
        self.piece = piece
        self.num_simulations = num_simulations  # search iterations per move
        self.max_tree_depth = max_tree_depth  # forwarded to Node.simulate

    def choose_action(self, game):
        """
        Runs MCTS
        """
        # A fresh tree is built for every move, rooted at the current state.
        root = Node(game, self.piece)
        for _ in range(self.num_simulations):
            leaf = root.select()
            leaf.expand()
            if leaf.is_leaf:
                # Still a leaf after expand(): the game is over at this node,
                # so only its immediate reward is propagated.
                leaf.backpropagate([], [], [])
                continue
            for child in leaf.children:
                rewards, discounts, signs = child.simulate(
                    max_tree_depth=self.max_tree_depth
                )
                child.backpropagate(rewards=rewards, discounts=discounts, signs=signs)

        # Play the move whose child has been visited most often.
        most_visited = max(root.children, key=lambda child: child.visit_count)
        return most_visited.action

    
class HumanAgent:
    """Player whose moves are typed in on standard input."""

    def __init__(self, piece):
        self.piece = piece

    def choose_action(self, game):
        """Ask the user for a move until one of the valid columns is entered.

        Returns:
            The chosen column, guaranteed to be in ``game.get_valid_moves()``.

        Raises:
            ValueError: if the game has no valid move left to choose from.
        """
        valid_moves = game.get_valid_moves()
        if not valid_moves:
            # Without this guard the prompt below could never be satisfied.
            raise ValueError("No valid moves left to choose from.")

        while True:
            answer = input(f"Choose from {valid_moves}:")
            try:
                move = int(answer)
            except ValueError:
                # A typo must not crash the whole game: ask again.
                print(f"'{answer}' is not a whole number, try again.")
                continue
            if move in valid_moves:
                return move
            print(f"Column {move} is not available, try again.")

In [91]:
# tmp = {'a':3, 'b':3, 'c':2}
# max(tmp, key=lambda x: tmp[x])

'a'

###  MCTS vs random

In [95]:
import time

In [116]:
# game = Game()  # only connect-2, with 4 columns and 1 row

# Connect-4 on a 7x6 board: MCTS ("X", 100 simulations per move) against a
# random player ("O").
game = Game(num_cols=7, num_rows=6, num_connect=4)
player1 = RandomAgent("O")
player2 = MCTSAgent("X", num_simulations=100)
turns_randomfirst = itertools.cycle([player1, player2]) # cycling the players
turns_mctsfirst = itertools.cycle([player2, player1]) # cycling the players


# Pick which of the two orders is used: here the MCTS agent moves first.
turns = turns_mctsfirst
# turns = turns_randomfirst

print(game)
while True:
    player = next(turns)
    action = player.choose_action(game)
    game.place(player.piece, action)
    # Redraw the board in place instead of appending a new one to the output.
    clear_output(wait=True)
    print(game)
    print(f"Played: {action}, by {player.__class__.__name__}")
    if game.is_terminal():
        # A terminal game is either won by the player who just moved or a draw.
        if game.is_win():
            print(f"Player {player.__class__.__name__} won!")
        else:
            print("Game finished with draw")
        break
    # Pause so that each move can be followed on screen.
    time.sleep(1)

+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   | O |   |   |   |
+---+---+---+---+---+---+---+
| O | O | X | X | X | X |   |
+---+---+---+---+---+---+---+
Played: 5, by MCTSAgent
Player MCTSAgent won!


### MCTS vs HUMAN

In [115]:
# Connect-4 on a 7x6 board: MCTS ("X", 1000 simulations per move, rollouts
# limited by max_tree_depth=42) against a human player ("O").
game = Game(num_cols=7, num_rows=6, num_connect=4)
player1 = HumanAgent("O")
player2 = MCTSAgent("X", num_simulations=1000, max_tree_depth=42)
turns_randomfirst = itertools.cycle([player1, player2]) # cycling the players
turns_mctsfirst = itertools.cycle([player2, player1]) # cycling the players


# Pick which of the two orders is used: here the MCTS agent moves first.
turns = turns_mctsfirst
# turns = turns_randomfirst

print(game)
while True:
    player = next(turns)
    action = player.choose_action(game)
    game.place(player.piece, action)
    # Redraw the board in place instead of appending a new one to the output.
    clear_output(wait=True)
    print(game)
    print(f"Played: {action}, by {player.__class__.__name__}")
    if game.is_terminal():
        # A terminal game is either won by the player who just moved or a draw.
        if game.is_win():
            print(f"Player {player.__class__.__name__} won!")
        else:
            print("Game finished with draw")
        break
    # Pause so that each move can be followed on screen.
    time.sleep(1)

+---+---+---+---+---+---+---+
| O | O |   | O |   | O | X |
+---+---+---+---+---+---+---+
| O | X |   | X |   | X | O |
+---+---+---+---+---+---+---+
| X | O | O | X |   | X | O |
+---+---+---+---+---+---+---+
| O | X | X | O |   | X | X |
+---+---+---+---+---+---+---+
| X | O | O | X | O | O | X |
+---+---+---+---+---+---+---+
| X | O | O | X | X | O | X |
+---+---+---+---+---+---+---+
Played: 2, by HumanAgent
Player HumanAgent won!


In [117]:
# Connect-4 on a 7x6 board: two MCTS agents with the same number of
# simulations but different max_tree_depth (20 for "O", 42 for "X").
game = Game(num_cols=7, num_rows=6, num_connect=4)
player1 = MCTSAgent("O", num_simulations=1000, max_tree_depth=20)
player2 = MCTSAgent("X", num_simulations=1000, max_tree_depth=42)
turns_randomfirst = itertools.cycle([player1, player2]) # cycling the players
turns_mctsfirst = itertools.cycle([player2, player1]) # cycling the players


# Pick which of the two orders is used: here player2 ("X") moves first.
turns = turns_mctsfirst
# turns = turns_randomfirst

print(game)
while True:
    player = next(turns)
    action = player.choose_action(game)
    game.place(player.piece, action)
    # Redraw the board in place instead of appending a new one to the output.
    clear_output(wait=True)
    print(game)
    # NOTE(review): both players are MCTSAgent, so the class name printed
    # below does not tell them apart -- player.piece would.
    print(f"Played: {action}, by {player.__class__.__name__}")
    if game.is_terminal():
        # A terminal game is either won by the player who just moved or a draw.
        if game.is_win():
            print(f"Player {player.__class__.__name__} won!")
        else:
            print("Game finished with draw")
        break
    # Pause so that each move can be followed on screen.
    time.sleep(1)

+---+---+---+---+---+---+---+
|   |   | X |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | O | X |   | O |   |
+---+---+---+---+---+---+---+
| X |   | O | O |   | O |   |
+---+---+---+---+---+---+---+
| O |   | O | O |   | O |   |
+---+---+---+---+---+---+---+
| X | X | X | X |   | X |   |
+---+---+---+---+---+---+---+
| X | X | O | X |   | X | O |
+---+---+---+---+---+---+---+
Played: 1, by MCTSAgent
Player MCTSAgent won!


In [None]:
# 7 columns x 6 rows, connect 4: a human ("X") moves first against MCTS ("O").
game = Game(7, 6, 4)
player1 = HumanAgent("X")
player2 = MCTSAgent("O", num_simulations=1000)
# Endless alternation of the two players, starting with player1.
turns = itertools.cycle([player1, player2])

In [None]:


print(game)
while True:
    player = next(turns)
    action = player.choose_action(game)
    game.place(player.piece, action)
    # Redraw the board in place instead of appending a new one to the output.
    clear_output(wait=True)
    print(game)
    print(f"Played: {action}, by {player.__class__.__name__}")
    # Stop on a draw as well as on a win: once the board is full there is no
    # valid move left, and asking the next player for one would fail.
    if game.is_terminal():
        if game.is_win():
            print(f"Player {player.__class__.__name__} won!")
        else:
            print("Game finished with draw")
        break

In [None]:
node = root_node.select()
node

In [None]:
node.path

In [None]:
node.expand()
node.children

In [None]:
for child in node.children:
    # backpropagate() requires the rollout lists produced by simulate().
    r, d, s = child.simulate()
    child.backpropagate(rewards=r, discounts=d, signs=s)

In [None]:
root_node   # perspective of X player

In [None]:
root_node.children   # perspective of O player

In [None]:
root_node.children[1].children

# Example

In [None]:
from IPython.display import clear_output

In [None]:
class RandomAgent:
    """Player that picks its move at random among the valid ones."""

    def __init__(self, piece):
        # Piece this agent plays with.
        self.piece = piece

    def choose_action(self, game):
        """Return one of ``game.get_valid_moves()``, chosen at random."""
        candidates = game.get_valid_moves()
        return random.choice(candidates)

class MCTSAgent:
    """Player that chooses its move with a Monte Carlo Tree Search."""

    def __init__(self, piece, num_simulations=1000, max_tree_depth=20):
        """
        Args:
            piece: Piece this agent plays with.
            num_simulations: Number of search iterations run for every move.
            max_tree_depth: Rollout limit forwarded to ``Node.simulate``.
        """
        self.piece = piece
        self.num_simulations = num_simulations
        self.max_tree_depth = max_tree_depth

    def choose_action(self, game):
        """
        Runs MCTS
        """
        root = Node(game, self.piece)
        for _ in range(self.num_simulations):
            node = root.select()
            node.expand()
            if node.is_leaf:
                # Still a leaf after expand(): the game is over at this node,
                # so only its immediate reward is propagated.
                node.backpropagate([], [], [])
                continue
            # backpropagate() needs a rollout: simulate from every new child
            # and propagate each result up to the root.
            for child in node.children:
                r, d, s = child.simulate(max_tree_depth=self.max_tree_depth)
                child.backpropagate(rewards=r, discounts=d, signs=s)

        # Play the move whose child has been visited most often.
        best = max(root.children, key=lambda x: x.visit_count)
        return best.action

    
class HumanAgent:
    """Player whose moves are typed in on standard input."""

    def __init__(self, piece):
        self.piece = piece

    def choose_action(self, game):
        """Ask the user for a move until one of the valid columns is entered.

        Returns:
            The chosen column, guaranteed to be in ``game.get_valid_moves()``.

        Raises:
            ValueError: if the game has no valid move left to choose from.
        """
        valid_moves = game.get_valid_moves()
        if not valid_moves:
            # Without this guard the prompt below could never be satisfied.
            raise ValueError("No valid moves left to choose from.")

        while True:
            answer = input(f"Choose from {valid_moves}:")
            try:
                move = int(answer)
            except ValueError:
                # A typo must not crash the whole game: ask again.
                print(f"'{answer}' is not a whole number, try again.")
                continue
            if move in valid_moves:
                return move
            print(f"Column {move} is not available, try again.")

In [None]:
# 7 columns x 6 rows, connect 4: a human ("X") moves first against MCTS ("O").
game = Game(7, 6, 4)
player1 = HumanAgent("X")
player2 = MCTSAgent("O", num_simulations=1000)
# Endless alternation of the two players, starting with player1.
turns = itertools.cycle([player1, player2])

In [None]:
print(game)
while True:
    player = next(turns)
    action = player.choose_action(game)
    game.place(player.piece, action)
    # Redraw the board in place instead of appending a new one to the output.
    clear_output(wait=True)
    print(game)
    print(f"Played: {action}, by {player.__class__.__name__}")
    # Stop on a draw as well as on a win: once the board is full there is no
    # valid move left, and asking the next player for one would fail.
    if game.is_terminal():
        if game.is_win():
            print(f"Player {player.__class__.__name__} won!")
        else:
            print("Game finished with draw")
        break

In [None]:
game