In [1]:
import numpy as np
from connect_four import Game, GameType 

In [2]:
# Mid-game starting position for the search: a 6-row x 7-column board.
# Presumably 0 marks an empty cell and 1/2 mark the two players' pieces
# (TODO confirm against connect_four). Columns 4 and 6 contain only zeros, and
# they are the only actions the MCTS reports in the output of the game below.
game_init = np.array(
    [
        [2, 2, 2, 1, 0, 1, 0],
        [2, 1, 1, 1, 0, 2, 0],
        [1, 2, 2, 2, 0, 1, 0],
        [2, 1, 1, 1, 0, 2, 0],
        [1, 1, 1, 2, 0, 2, 0],
        [2, 2, 1, 2, 0, 1, 0],
    ]
)

In [3]:
# MCTS-driven player vs. a random player, starting from `game_init`.
# NOTE(review): mcts_maxiter presumably caps the iterations of each tree search — confirm in connect_four.
game = Game(game_state=game_init, game_type=GameType.MCTS_VS_RANDOM, mcts_maxiter=1000)

#### Running the MCTS for several "epochs"
The tree nodes store the state of the game at the moment they are created. These states result from (first) the action chosen to create the node and (second) the random action of player 2. Node states are frozen and never modified afterwards. This means a single tree only sees a limited sample of all the possible outcomes.

We run the MCTS for several "epochs" before choosing an action to take. Each epoch is a single, independent run of the tree search algorithm, so for the root node we obtain one set of action-value pairs per epoch. We average each action's q-values across epochs and select the action with the highest mean.


In [16]:
from typing import Dict
from connect_four.utils import display_circles


class QValuesDict(dict):
    """Dict mapping action -> running mean of the q-values returned by MCTS.

    Each call to :meth:`update` folds one epoch's ``{action: q_value}`` result
    into the per-action means. Means are tracked *per key*: an action that is
    missing from some epochs is averaged only over the epochs that reported it,
    instead of being pulled towards 0 by the epochs in which it was absent.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Total number of update() calls (kept for backward compatibility).
        self._updates = 0
        # Number of observations folded into each action's mean so far.
        self._counts: Dict[int, int] = {}
        # Copies of every raw epoch result, in arrival order.
        self._cache = []
        super().__init__(*args, **kwargs)

    def update(self, new: Dict[int, float]) -> None:
        """Fold one epoch's q-values into the running per-action means.

        Args:
            new: mapping of action -> q-value produced by a single MCTS run.
        """
        # Store a copy so later mutation of `new` by the caller cannot alter history.
        self._cache.append(dict(new))
        for key, value in new.items():
            count = self._counts.get(key, 0)
            if count == 0:
                # First observation for this action: the mean is the value itself.
                mean = value
            else:
                # Running mean: mean_{n+1} = (n * mean_n + x_{n+1}) / (n + 1).
                mean = (count * self.get(key, value) + value) / (count + 1)
            self[key] = mean
            self._counts[key] = count + 1

        self._updates += 1


def select_best_action_dope(self, epochs: int = 25):
    """Pick the root action with the highest q-value averaged over several MCTS runs.

    Each epoch runs ``self.mcts.run`` once on a fresh snapshot of the current
    board and folds the returned ``{action: q_value}`` mapping into a
    ``QValuesDict`` running average. Progress and intermediate averages are
    printed to stdout.

    Args:
        epochs: number of independent tree searches to average over.

    Returns:
        The action (as a plain ``int``) with the highest mean q-value.

    Raises:
        ValueError: if no q-values were collected (e.g. ``epochs <= 0`` or
            every search returned an empty mapping).
    """
    print("Running tree search to choose action ...")
    qvalues = QValuesDict()
    for e in range(epochs):
        print(f"\tRunning epoch {e} to find best action ...")
        # A new snapshot is taken every epoch; presumably a copy, so each search
        # starts from the current board (TODO confirm snapshot() semantics).
        epoch_qvalues = self.mcts.run(game_state=self.game_board.snapshot())
        qvalues.update(epoch_qvalues)
        print(qvalues)
    print(f"Qvalues: {qvalues}")
    if not qvalues:
        raise ValueError(
            f"No q-values collected after {epochs} epoch(s); cannot choose an action."
        )
    # Compute the argmax once so the announced and returned action are identical.
    best_action = int(max(qvalues, key=qvalues.get))
    print(f"Choosing {best_action}")
    return best_action

def play(self):
    """Play one complete game from a freshly initialised board, then announce the result.

    Every turn asks ``first_move`` and then ``second_move`` for an action and
    passes both to ``game_board.play``; this repeats until the board reports
    ``is_finished``. A falsy result from ``check_winner`` is reported as a draw.
    """
    self.init_gameboard()

    while True:
        if self.game_board.is_finished:
            break
        # Resolve the board's play method first, then query the two movers in
        # order; both actions are chosen before either is applied to the board.
        apply_turn = self.game_board.play
        chosen_first = self.first_move()
        chosen_second = self.second_move()
        apply_turn(first_action=chosen_first, second_action=chosen_second)

    outcome = self.game_board.check_winner()
    label = outcome if outcome else 'DRAW'
    print(f"The winner of the game is: {label}")



In [17]:
# Monkey-patch the experimental implementations onto the Game class. The
# already-created `game` instance picks them up, because methods are looked up
# on the class at call time.
# NOTE(review): assumes Game's move logic routes through select_best_action — confirm in connect_four.
Game.select_best_action = select_best_action_dope
Game.play = play

In [18]:
# Play one full game with the patched methods; the output below shows the
# running q-value averages printed after each search epoch.
game.play()

Running tree search to choose action ...
	Running epoch 0 to find best action ...
{np.int64(4): 0.7142857142857143, np.int64(6): -0.247557003257329}
	Running epoch 1 to find best action ...
{np.int64(4): 0.7093912511471399, np.int64(6): 0.3086792844876582}
	Running epoch 2 to find best action ...
{np.int64(4): 0.7098752919294185, np.int64(6): 0.44615803560797396}
	Running epoch 3 to find best action ...
{np.int64(4): 0.7085428325834274, np.int64(6): 0.4770973402653025}
	Running epoch 4 to find best action ...
{np.int64(4): 0.7089218081542841, np.int64(6): 0.4329093993058381}
	Running epoch 5 to find best action ...
{np.int64(4): 0.7081999765224489, np.int64(6): 0.4783263030287461}
	Running epoch 6 to find best action ...
{np.int64(4): 0.7088649452949395, np.int64(6): 0.4431254310962462}
	Running epoch 7 to find best action ...
{np.int64(4): 0.7282334808295701, np.int64(6): 0.48752899089234303}
	Running epoch 8 to find best action ...
{np.int64(4): 0.7266837289913638, np.int64(6): 0.476