# Adversarial sequential games

## Tic-Tac-Toe!

In [None]:
from copy import deepcopy
from enum import Enum
from typing import Any, NamedTuple, Tuple

import numpy as np
from numpy.typing import ArrayLike

# %matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# ASCII template of the empty board, used by TicTacToe as the rendering
# background: spaces are playable cells, any other character is drawn as a
# wall. The leading and trailing newlines produce empty first/last lines,
# which TicTacToe.__init__ drops.
TIC_TAC_TOE = """
 | | 
-----
 | | 
-----
 | | 
"""

class State:
    """Hashable wrapper around a tic-tac-toe board array.

    Two states are equal when their arrays have the same shape and contents,
    and the hash depends only on the (integer-cast) contents, so states can
    be used as dictionary keys. Assumes the wrapped value is a NumPy array
    (``astype`` and ``copy`` are called on it).
    """

    def __init__(self, array: ArrayLike):
        self._array = array

    def __hash__(self):
        return hash(tuple(self._array.astype(int).flatten()))

    def __eq__(self, other):
        # Let Python handle comparisons with unrelated types instead of
        # failing on ``other._array``; return a plain bool, not np.bool_.
        if not isinstance(other, State):
            return NotImplemented
        return bool(np.array_equal(self._array, other._array))

    def __str__(self):
        return str(self._array)

    def __repr__(self):
        return repr(self._array)

    @property
    def array(self):
        """The wrapped board array (mutable; shared, not copied)."""
        return self._array

    def copy(self):
        """Return a new State holding a copy of the board array."""
        return State(self._array.copy())
    
class Player(Enum):
    """The two tic-tac-toe players.

    Each member's value is the mark written into the board array, where 0
    stands for an empty cell.
    """

    # Plays at max-player nodes in TicTacToeTree.
    CrossPlayer = 1
    # Plays at min-player nodes in TicTacToeTree.
    CirclePlayer = 2
    
class Action(NamedTuple):
    """A tic-tac-toe move: the cell to mark.

    ``x`` indexes the board array's first axis and ``y`` its second, as
    done in ``TicTacToe.get_next_state``.
    """

    x: int  # first-axis (row) index
    y: int  # second-axis (column) index

class TicTacToe:
    """Tic-tac-toe rules plus a matplotlib renderer.

    A game state is a ``State`` wrapping a 3x3 integer array: 0 marks an empty
    cell, otherwise the cell holds the ``Player.value`` of the player who
    marked it (1 = cross, 2 = circle).
    """

    def __init__(self, tic_tac_toe_str: str = TIC_TAC_TOE):
        """Build the rendering background from an ASCII board template.

        Spaces become free cells (0) and any other character a wall (0.7).
        The first and last template lines are dropped; for ``TIC_TAC_TOE``
        these are the empty lines produced by its leading/trailing newlines.
        """
        template = [
            [0 if char == " " else 0.7 for char in line]
            for line in tic_tac_toe_str.split("\n")
        ]
        self._tic_tac_toe = np.array(template[1:-1], dtype=np.float32)
        self._ax = None
        self._fig = None
        self._image = None
        # Scatter artists showing the marks of the last rendered state.
        self._marks = []

    def reset(self):
        """Return the empty 3x3 starting board."""
        return State(np.zeros(shape=[3, 3], dtype=np.int32))

    def get_next_state(self, state: State, action: Tuple[Player, Action]) -> Tuple[State, float, bool]:
        """Apply ``action`` to a copy of ``state``.

        Args:
            state: current board; the targeted cell must be empty.
            action: ``(player, cell)``; ``cell.x`` indexes the board's first
                axis and ``cell.y`` its second.

        Returns:
            ``(next_state, reward, done)``: ``reward`` is +1 if cross just
            won, -1 if circle just won, else 0; ``done`` is True on a win or
            when no empty cell is left.
        """
        player, cell = action
        assert state.array[cell.x][cell.y] == 0
        next_state = state.copy()
        board = next_state.array
        board[cell.x][cell.y] = player.value
        # Cells hold 0, 1 or 2, so a line's product equals value**3 exactly
        # when all three cells carry this player's mark.
        target = player.value ** 3
        won = (np.any(np.prod(board, axis=0) == target)
               or np.any(np.prod(board, axis=1) == target)
               or np.prod(np.diagonal(board)) == target
               or np.prod(np.diagonal(np.fliplr(board))) == target)
        if won:
            # Maps CrossPlayer (1) to +1 and CirclePlayer (2) to -1.
            return next_state, -2 * player.value + 3, True
        if not np.any(board == 0):
            # Board full and nobody won: draw.
            return next_state, 0, True
        return next_state, 0, False

    def render(self, state: State) -> Any:
        """Draw ``state`` on a persistent figure and display it inline.

        The figure is created on first use and reused afterwards. Marks drawn
        for a previously rendered state are removed first, so rendering a
        reset or earlier state shows only that state's marks.
        """
        if self._ax is None:
            fig, ax = plt.subplots(1)
            # The window title lives on the figure manager
            # (FigureCanvasBase.set_window_title was removed in Matplotlib 3.6).
            if getattr(fig.canvas, "manager", None) is not None:
                fig.canvas.manager.set_window_title("tic-tac-toe")
            ax.set_aspect("equal")  # set the x and y axes to the same scale
            plt.xticks([])  # remove the tick marks by setting to an empty list
            plt.yticks([])  # remove the tick marks by setting to an empty list
            ax.invert_yaxis()  # invert the y-axis so the first row of data is at the top
            self._ax = ax
            self._fig = fig
            plt.ion()
        if self._image is None:
            self._image = self._ax.imshow(self._tic_tac_toe, cmap='Greys', vmin=0, vmax=1)
        else:
            self._image.set_data(self._tic_tac_toe)
        for mark in self._marks:
            mark.remove()
        self._marks = []
        # Cell (row, col) sits at image coordinates (2 * col, 2 * row) because
        # the template interleaves cells with wall characters.
        for row in range(3):
            for col in range(3):
                if state.array[row][col] == 1:
                    self._marks.append(
                        self._ax.scatter(2 * col, 2 * row, s=500, c='blue', marker='x'))
                elif state.array[row][col] == 2:
                    self._marks.append(
                        self._ax.scatter(2 * col, 2 * row, s=500, facecolors='none', edgecolors='red'))
        display(self._fig)
        clear_output(wait = True)
        plt.pause(1)
        
# Draw the empty starting board once to check the renderer.
tic_tac_toe = TicTacToe()
tic_tac_toe.render(tic_tac_toe.reset())

## Minimax algorithm

![minimax](./img/Minimax.png)

In [None]:
from __future__ import annotations
from typing import Dict
from typing import Any, Tuple, List

class Tree:
    """Lazily expanded game tree used by the search functions.

    Subclasses implement ``generate_children``. Generated children are cached
    on their parent node, and ``_nodes`` maps each game state (``data``) to
    the node registered for it.
    """

    class Node:
        """A game position annotated with search results.

        Args:
            data: hashable game state; equality and hashing delegate to it.
            max_player: True when the maximizing player is to move.
            terminal: True when the game is over in this position.
            terminal_value: payoff of the position when ``terminal``.
            best_child: move chosen by the last search. The search functions
                in this notebook store a ``(child_node, action_str)`` tuple
                here.
        """

        def __init__(self,
                     data: Any,
                     max_player: bool = True,
                     terminal: bool = False,
                     terminal_value: float = 0,
                     best_child: Tree.Node = None):
            self._data = data
            self._max_player = max_player
            self._terminal = terminal
            self._terminal_value = terminal_value
            self._best_child = best_child
            self._children: List[Tuple[Tree.Node, str]] = []

        @property
        def data(self):
            return self._data

        @property
        def max_player(self):
            return self._max_player

        @property
        def terminal(self):
            return self._terminal

        @property
        def terminal_value(self):
            return self._terminal_value

        @property
        def best_child(self):
            return self._best_child

        def __eq__(self, other: Tree.Node):
            # Unrelated types fall back to Python's default comparison
            # instead of failing on ``other._data``.
            if not isinstance(other, Tree.Node):
                return NotImplemented
            return self._data == other._data

        def __hash__(self):
            return hash(self._data)

        def __str__(self):
            return str(self._data)

        def __repr__(self):
            best = self._best_child
            if best is None:
                best_repr = None
            elif isinstance(best, tuple):
                # (child_node, action_str) as stored by the searches.
                best_repr = repr(best[0]._data)
            else:
                best_repr = repr(best._data)
            return 'Node(data: {}, max_player: {}, terminal: {}, best child: {})'.format(
                repr(self._data),
                'true' if self._max_player else 'false',
                'true [{}]'.format(self._terminal_value) if self._terminal else 'false',
                best_repr)

    def __init__(self):
        self._nodes: Dict[Any, Tree.Node] = {}

    def get_node(self, data: Any):
        """Return the node registered for ``data``, creating a default one if absent."""
        if data not in self._nodes:
            self._nodes[data] = Tree.Node(data)
        return self._nodes[data]

    def get_children(self, node: Node) -> List[Tuple[Node, str]]:
        """Return ``(child_node, action_description)`` pairs for ``node``.

        Children are generated when ``node``'s data is unknown or ``node`` has
        no cached children; ``node`` then becomes the registered node for its
        data. The returned list is the registered node's children.
        """
        if node.data not in self._nodes or len(node._children) == 0:
            node._children = list(self.generate_children(node))
            # Turns must alternate between the two players.
            assert all(bool(c[0].max_player) != bool(node.max_player)
                       for c in node._children)
            self._nodes[node.data] = node
        return self._nodes[node.data]._children

    def generate_children(self, node: Node) -> List[Tuple[Node, str]]:
        """Produce ``(child_node, action_description)`` pairs; subclasses implement it."""
        raise NotImplementedError

    def is_terminal(self, node: Node) -> bool:
        return node.terminal

    def render(self, node: Node) -> None:
        """Display ``node``; no-op by default."""
        pass

In [None]:
from typing import Callable

def minimax(node : Tree.Node,
            tree: Tree,
            depth : int,
            maximizing_player : bool,
            evaluate : Callable[[Tree.Node], float]):
    """Depth-limited minimax search.

    Stores the chosen ``(child, action)`` pair in ``node._best_child``; among
    equally valued children the last one examined wins.

    Args:
        node: position to search from.
        tree: game tree providing children and terminal tests.
        depth: remaining plies; at 0 the node is evaluated directly.
        maximizing_player: True if the side to move maximizes.
        evaluate: scores terminal or depth-limited positions.

    Returns:
        The minimax value of ``node``.
    """
    if depth == 0 or tree.is_terminal(node):
        return evaluate(node)

    best = -float('inf') if maximizing_player else float('inf')
    for entry in tree.get_children(node):
        score = minimax(entry[0], tree, depth - 1, not maximizing_player, evaluate)
        better_or_equal = score >= best if maximizing_player else score <= best
        if better_or_equal:
            node._best_child = entry
            best = score
    return best

In [None]:
class TicTacToeTree(Tree):
    """Game tree whose nodes hold tic-tac-toe ``State``s.

    Cross moves at max-player nodes and circle at min-player nodes; the rules
    come from the wrapped ``TicTacToe`` instance.
    """

    def __init__(self, tic_tac_toe):
        super().__init__()
        self._tic_tac_toe = tic_tac_toe

    def generate_children(self, node: Tree.Node) -> List[Tuple[Tree.Node, str]]:
        """Yield ``(child_node, description)`` for every empty cell of ``node``'s board."""
        board = node.data
        mover = Player.CrossPlayer if node.max_player else Player.CirclePlayer
        mover_name = 'cross' if node.max_player else 'circle'
        free_rows, free_cols = np.asarray(board.array == 0).nonzero()
        for row, col in zip(free_rows, free_cols):
            successor, reward, finished = self._tic_tac_toe.get_next_state(
                board, (mover, Action(x=row, y=col)))
            child = Tree.Node(data=successor,
                              max_player=not node.max_player,
                              terminal=finished,
                              terminal_value=reward)
            yield child, '{} at ({}, {})'.format(mover_name, str(row), str(col))

    def render(self, node: Tree.Node) -> None:
        """Draw the board held by ``node``."""
        self._tic_tac_toe.render(node.data)

In [None]:
# Solve the whole game with plain minimax (depth 1000 exceeds the 9 plies a
# tic-tac-toe game can last), then replay the moves stored in best_child.
tic_tac_toe = TicTacToe(TIC_TAC_TOE)
tic_tac_toe_tree = TicTacToeTree(tic_tac_toe)
minimax(node=Tree.Node(data=tic_tac_toe.reset()),
        tree = tic_tac_toe_tree,
        depth=1000,
        maximizing_player=True,
        evaluate = lambda n : n.terminal_value)
# get_children registered the searched root under its board, so get_node
# returns that same annotated node rather than creating a fresh one.
node = tic_tac_toe_tree.get_node(data=tic_tac_toe.reset())
tic_tac_toe.render(node.data)
while not node.terminal:
    print('Action: {}'.format(node.best_child[1]))
    node = node.best_child[0]
    tic_tac_toe.render(node.data)

## Alpha-Beta Pruning

![Alpha-Beta Pruning](img/AB_pruning.png)

In [None]:
from typing import Callable

def alphabeta(node : Tree.Node,
              tree: Tree,
              depth : int,
              alpha : float,
              beta : float,
              maximizing_player : bool,
              evaluate : Callable[[Tree.Node], float]):
    """Depth-limited minimax with (fail-soft) alpha-beta pruning.

    Stores the chosen ``(child, action)`` pair in ``node._best_child``.

    Args:
        node: position to search from.
        tree: game tree providing children and terminal tests.
        depth: remaining plies; at 0 the node is evaluated directly.
        alpha: value the maximizer is already guaranteed elsewhere.
        beta: value the minimizer is already guaranteed elsewhere.
        maximizing_player: True if the side to move maximizes.
        evaluate: scores terminal or depth-limited positions.

    Returns:
        The value of ``node``; exact when it lies strictly between ``alpha``
        and ``beta``, otherwise only a bound.
    """
    if depth == 0 or tree.is_terminal(node):
        return evaluate(node)
    best = None
    if maximizing_player:
        value = -float('inf')
        for child in tree.get_children(node):
            tentative = alphabeta(child[0], tree, depth - 1, alpha, beta, False, evaluate)
            # Switch only on a strict improvement: a later child that was cut
            # off returns just an upper bound, so tying the current value does
            # not make it as good as the current best child.
            if best is None or tentative > value:
                best = child
                node._best_child = child
                value = tentative
            if value >= beta:
                break
            alpha = max(alpha, value)
        return value
    else:
        value = float('inf')
        for child in tree.get_children(node):
            tentative = alphabeta(child[0], tree, depth - 1, alpha, beta, True, evaluate)
            # Symmetric to the max branch: a cut-off child only returns a
            # lower bound, so ties must not replace the current best child.
            if best is None or tentative < value:
                best = child
                node._best_child = child
                value = tentative
            if value <= alpha:
                break
            beta = min(beta, value)
        return value

In [None]:
# Same full-game solve as the minimax cell, now with alpha-beta pruning and
# the widest possible initial window, then replay the stored best moves.
tic_tac_toe = TicTacToe(TIC_TAC_TOE)
tic_tac_toe_tree = TicTacToeTree(tic_tac_toe)
alphabeta(node=Tree.Node(data=tic_tac_toe.reset()),
          tree = tic_tac_toe_tree,
          depth=1000,
          alpha=-float("inf"),
          beta=float("inf"),
          maximizing_player=True,
          evaluate = lambda n : n.terminal_value)
# get_children registered the searched root under its board, so get_node
# returns that same annotated node.
node = tic_tac_toe_tree.get_node(data=tic_tac_toe.reset())
tic_tac_toe.render(node.data)
while not node.terminal:
    print('Action: {}'.format(node.best_child[1]))
    node = node.best_child[0]
    tic_tac_toe.render(node.data)

## Playing against a random player

In [None]:
import random

# Fresh game for alpha-beta (cross) against a uniformly random circle player.
tic_tac_toe = TicTacToe(TIC_TAC_TOE)
tic_tac_toe_tree = TicTacToeTree(tic_tac_toe)

def call_alphabeta_pruning(tic_tac_toe_tree: TicTacToeTree,
                           node: Tree.Node) -> None:
    """Choose the move at ``node`` with a full-depth alpha-beta search.

    The search is run for whichever side is to move at ``node``, so the
    helper works on circle's turns as well as cross's. The chosen move is
    left in ``node.best_child``.
    """
    alphabeta(node=node,
              tree=tic_tac_toe_tree,
              depth=1000,
              alpha=-float("inf"),
              beta=float("inf"),
              maximizing_player=node.max_player,
              evaluate=lambda n: n.terminal_value)
    
def call_random_player(tic_tac_toe_tree: TicTacToeTree,
                       node: Tree.Node) -> None:
    """Pick one of ``node``'s children uniformly at random as its best child.

    Stores the chosen ``(child_node, action_str)`` pair in
    ``node._best_child``.

    Raises:
        ValueError: if ``node`` has no children (e.g. it is terminal).
    """
    children = tic_tac_toe_tree.get_children(node)
    if not children:
        raise ValueError('cannot pick a random move: node has no children')
    node._best_child = random.choice(children)

# Play one game: cross searches with alpha-beta on its turns, circle picks a
# uniformly random move; each helper leaves its choice in node.best_child.
node = tic_tac_toe_tree.get_node(data=tic_tac_toe.reset())
tic_tac_toe.render(node.data)

while not node.terminal:
    if node.max_player:
        call_alphabeta_pruning(tic_tac_toe_tree, node)
    else:
        call_random_player(tic_tac_toe_tree, node)
        
    node = node.best_child[0]
    tic_tac_toe.render(node.data)

## Playing _optimally_ against a random player

In [None]:
from __future__ import annotations
from typing import Dict

class ProbabilisticGameGraph:
    """View of a game tree as a one-player stochastic problem.

    ``StateNode``s wrap ``Tree.Node``s where the solving player is to move.
    Each ``ActionNode`` is one of that player's moves and lists the
    ``(StateNode, probability)`` pairs that can follow it once the opponent
    has replied according to ``opponent_policy``.
    """

    class StateNode:
        """A position for the solving player, with its current best action and value."""

        def __init__(self, data: Tree.Node):
            self._data = data
            self._best_action = None
            self._best_value = None
            self._successors: List[ProbabilisticGameGraph.ActionNode] = []

        @property
        def data(self):
            return self._data

        @property
        def best_action(self):
            return self._best_action

        @property
        def best_value(self):
            return self._best_value

        def __eq__(self, other: ProbabilisticGameGraph.StateNode):
            # Unrelated types fall back to Python's default comparison.
            if not isinstance(other, ProbabilisticGameGraph.StateNode):
                return NotImplemented
            return self._data == other._data

        def __hash__(self):
            return hash(self._data)

        def __str__(self):
            return str(self._data)

        def __repr__(self):
            return 'Node(data: {}, best action: {}, best value: {})'.format(
                repr(self._data),
                repr(self._best_action) if self._best_action is not None else None,
                repr(self._best_value) if self._best_value is not None else None)

    class ActionNode:
        """One move of the solving player; ``data`` is ``(child_tree_node, action_str)``."""

        def __init__(self, data: Any):
            self._data = data
            self._successors: List[Tuple[ProbabilisticGameGraph.StateNode, float]] = []

        @property
        def data(self):
            return self._data

        def __eq__(self, other: ProbabilisticGameGraph.ActionNode):
            if not isinstance(other, ProbabilisticGameGraph.ActionNode):
                return NotImplemented
            return self._data == other._data

        def __hash__(self):
            return hash(self._data)

        def __str__(self):
            return str(self._data)

        def __repr__(self):
            return 'ActionNode(data: {})'.format(
                repr(self._data))

    def __init__(self,
                 game_tree: Tree,
                 opponent_policy: Callable[[Tree.Node],
                                           List[Tuple[float, Tree.Node]]]):
        """
        Args:
            game_tree: tree supplying children and terminal tests.
            opponent_policy: maps a position where the opponent is to move to
                ``(probability, next_tree_node)`` pairs.
        """
        self._nodes: Dict[Any, ProbabilisticGameGraph.StateNode] = {}
        self._game_tree = game_tree
        self._opponent_policy = opponent_policy

    def get_node(self, data: Any):
        """Return the StateNode registered for ``data``, creating it if absent."""
        if data not in self._nodes:
            self._nodes[data] = ProbabilisticGameGraph.StateNode(data)
        return self._nodes[data]

    def get_successors(self, node: StateNode) -> List[ActionNode]:
        """Return ``node``'s action nodes, generating them on first request."""
        if node.data not in self._nodes or len(node._successors) == 0:
            node._successors = list(self.generate_successors(node))
            self._nodes[node.data] = node
        return self._nodes[node.data]._successors

    def generate_successors(self, node: StateNode) -> List[ActionNode]:
        """Yield one ActionNode per child of ``node``'s tree node."""
        for tree_node, action_str in self._game_tree.get_children(node.data):
            action_node = ProbabilisticGameGraph.ActionNode(data=(tree_node, action_str))
            if self._game_tree.is_terminal(tree_node):
                # The game ends with this move: there is no opponent reply.
                action_node._successors.append((self.get_node(tree_node), 1.0))
            else:
                action_node._successors.extend(
                    (self.get_node(next_tree_node), probability)
                    for probability, next_tree_node in self._opponent_policy(tree_node))
            yield action_node

    def is_goal(self, node: StateNode) -> bool:
        return self._game_tree.is_terminal(node.data)

    def render(self, node: StateNode) -> None:
        self._game_tree.render(node.data)

In [None]:
from typing import Optional, Set

class GameRTDP:
    """Real-Time Dynamic Programming over a ``ProbabilisticGameGraph``.

    Each trial walks from the root, always taking the action with the best
    expected successor value (the solving player maximizes), backing that
    value up into the visited state and sampling the next state from the
    action's outcome distribution.
    """

    def __init__(
        self,
        graph: ProbabilisticGameGraph,
        heuristic: Optional[
            Callable[[ProbabilisticGameGraph.StateNode], float]
        ] = None,
        max_steps: int = 1000,
        trials_number: int = 100,
        verbose: bool = False,
        render: bool = False,
    ) -> None:
        """
        Args:
            graph: the graph to solve.
            heuristic: initial value estimate for a newly seen state;
                defaults to 0.0 everywhere.
            max_steps: cap on the number of moves per trial.
            trials_number: number of trials run by ``solve_from``.
            verbose: print every newly seen state.
            render: stored but not used by this class.
        """
        if heuristic is None:
            heuristic = (lambda _: 0.0)
        self._graph = graph
        self._heuristic = heuristic
        self._max_steps = max_steps
        self._trials_number = trials_number
        self._verbose = verbose
        self._render = render
        self._values = {}  # not read by this class

    def solve_from(self, tree_node: Tree.Node) -> None:
        """Run all trials from ``tree_node`` and record the chosen move.

        Afterwards ``tree_node._best_child`` is the ``data`` of the root's
        best action, i.e. a ``(child_tree_node, action_str)`` pair.
        """
        root = self._graph.get_node(tree_node)
        # States seen during this call. Each one gets its heuristic value the
        # first time it is seen here, even if an earlier call valued it.
        seen = {root}
        root._best_value = self._heuristic(root)

        completed = 0
        while completed < self._trials_number:
            self.trial(root, self._expand, seen)
            completed += 1

        tree_node._best_child = root._best_action.data

    def _expand(self, node, explored):
        """Return ``node``'s actions after initializing values of unseen successor states."""
        successors = self._graph.get_successors(node)
        for action in successors:
            for next_state, _ in action._successors:
                if next_state in explored:
                    continue
                if self._verbose:
                    print('New node {}'.format(str(next_state)))
                next_state._best_value = self._heuristic(next_state)
                explored.add(next_state)
        return list(successors)

    def trial(self,
              root_node: ProbabilisticGameGraph.StateNode,
              extender : Callable[[ProbabilisticGameGraph.StateNode,
                                   Set[ProbabilisticGameGraph.StateNode]],
                                  List[ProbabilisticGameGraph.ActionNode]],
              explored: Set[ProbabilisticGameGraph.StateNode]) -> None:
        """Walk greedily from ``root_node`` until a goal state or ``max_steps`` moves."""
        current = root_node
        taken = 0
        while not self._graph.is_goal(current):
            if taken >= self._max_steps:
                break
            chosen, value = self.greedy_action(current, extender, explored)
            self.update(current, chosen, value)
            current = self.pick_next_state(chosen)
            taken += 1

    def greedy_action(self,
                      node: ProbabilisticGameGraph.StateNode,
                      extender : Callable[[ProbabilisticGameGraph.StateNode,
                                           Set[ProbabilisticGameGraph.StateNode]],
                                          List[ProbabilisticGameGraph.ActionNode]],
                      explored: Set[ProbabilisticGameGraph.StateNode]):
        """Return ``(action, expected_value)`` maximizing the expected successor value.

        Ties keep the first such action.
        """
        chosen = None
        top_value = -float('inf')
        for candidate in extender(node, explored):
            expected = sum(probability * successor._best_value
                           for successor, probability in candidate._successors)
            if expected > top_value:
                chosen, top_value = candidate, expected
        assert chosen is not None
        return chosen, top_value

    def update(self,
               state_node: ProbabilisticGameGraph.StateNode,
               action_node: ProbabilisticGameGraph.ActionNode,
               value: float):
        """Record ``action_node`` and its value as ``state_node``'s current best."""
        state_node._best_action = action_node
        state_node._best_value = value

    def pick_next_state(self, action_node: ProbabilisticGameGraph.ActionNode):
        """Sample a successor state of ``action_node`` according to its probabilities."""
        outcomes = action_node._successors
        states = [state for state, _ in outcomes]
        probabilities = [probability for _, probability in outcomes]
        return random.choices(states, weights=probabilities, k=1)[0]

In [None]:
import random

# Fresh game for RTDP (cross) against a uniformly random circle player.
tic_tac_toe = TicTacToe(TIC_TAC_TOE)
tic_tac_toe_tree = TicTacToeTree(tic_tac_toe)

def call_game_rtdp(game_graph: ProbabilisticGameGraph,
                   node: Tree.Node,
                   max_value: float) -> None:
    """Choose the solving player's move at ``node`` with RTDP on ``game_graph``.

    Terminal states are valued by their terminal value; every other state
    starts at ``max_value``. The chosen move ends up in ``node.best_child``.
    """
    def initial_estimate(state_node):
        tree_node = state_node.data
        return tree_node.terminal_value if tree_node.terminal else max_value

    solver = GameRTDP(
        graph=game_graph,
        heuristic=initial_estimate,
        max_steps=1000,
        trials_number=100,
        verbose=False,
        render=False,
    )
    solver.solve_from(node)
    
def call_random_player(tic_tac_toe_tree: TicTacToeTree,
                       node: Tree.Node) -> None:
    """Pick one of ``node``'s children uniformly at random as its best child.

    Stores the chosen ``(child_node, action_str)`` pair in
    ``node._best_child``.

    Raises:
        ValueError: if ``node`` has no children (e.g. it is terminal).
    """
    children = tic_tac_toe_tree.get_children(node)
    if not children:
        raise ValueError('cannot pick a random move: node has no children')
    node._best_child = random.choice(children)
    
def random_player_policy(node: Tree.Node) -> List[Tuple[float, Tree.Node]]:
    """Uniform opponent model: every reply at ``node`` is equally likely.

    Children come from the module-level ``tic_tac_toe_tree``; the list is
    fetched once and reused. Returns ``(probability, next_tree_node)`` pairs,
    or an empty list when ``node`` has no children.
    """
    children = tic_tac_toe_tree.get_children(node)
    if not children:
        return []
    probability = 1. / float(len(children))
    return [(probability, child) for child, _ in children]

node = tic_tac_toe_tree.get_node(data=tic_tac_toe.reset())
tic_tac_toe.render(node.data)

# One graph is shared by all of cross's turns; it models circle's replies
# with random_player_policy.
game_graph = ProbabilisticGameGraph(game_tree=tic_tac_toe_tree,
                                    opponent_policy=random_player_policy)

while not node.terminal:
    if node.max_player:
        # max_value=1 is the largest terminal value, i.e. an optimistic
        # starting estimate for non-terminal states.
        call_game_rtdp(game_graph, node, 1)
    else:
        call_random_player(tic_tac_toe_tree, node)
        
    node = node.best_child[0]
    tic_tac_toe.render(node.data)

## Back to epic combats: RTDP vs RTDP!

In [None]:
# Fresh game where both players choose their moves with RTDP.
tic_tac_toe = TicTacToe(TIC_TAC_TOE)
tic_tac_toe_tree = TicTacToeTree(tic_tac_toe)

def opponent_policy(tree_node: Tree.Node,
                    opponent_game_graph: ProbabilisticGameGraph) -> List[Tuple[float, Tree.Node]]:
    """Predict the opponent's reply distribution at ``tree_node``.

    If ``opponent_game_graph`` (the graph solved for the opponent's previous
    turn) already holds a best action for this position, the opponent is
    assumed to play it for sure. Otherwise every reply from the module-level
    ``tic_tac_toe_tree`` is equally likely.

    Returns:
        ``(probability, next_tree_node)`` pairs; empty when ``tree_node`` has
        no children and no known best action.
    """
    if opponent_game_graph is not None:
        known = opponent_game_graph._nodes.get(tree_node)
        if known is not None and known._best_action is not None:
            return [(1., known._best_action.data[0])]
    children = tic_tac_toe_tree.get_children(tree_node)
    if not children:
        return []
    probability = 1. / float(len(children))
    return [(probability, child) for child, _ in children]

def call_game_rtdp(game_tree: Tree,
                   node: Tree.Node,
                   opponent_game_graph: ProbabilisticGameGraph,
                   max_value: float,
                   max_or_min_player: bool) -> ProbabilisticGameGraph:
    """Choose the move at ``node`` with RTDP for the side to move.

    The solver always maximizes, so terminal values are negated for the
    minimizing player (``max_or_min_player`` False). The opponent is modelled
    by ``opponent_policy`` using ``opponent_game_graph``. The chosen move ends
    up in ``node.best_child``.

    Args:
        game_tree: tree supplying children and terminal tests.
        node: position to move from.
        opponent_game_graph: graph solved for the opponent's last turn, or None.
        max_value: largest achievable value, used as the optimistic estimate
            of non-terminal states.
        max_or_min_player: True when the maximizing player is to move.

    Returns:
        The graph built for this move, so the next mover can use it as its
        opponent model.
    """
    sign = 2 * int(max_or_min_player) - 1
    game_graph = ProbabilisticGameGraph(
        game_tree=game_tree,
        opponent_policy=lambda tree_node : opponent_policy(tree_node, opponent_game_graph)
    )

    def heuristic(state_node):
        tree_node = state_node.data
        if tree_node.terminal:
            return sign * tree_node.terminal_value
        # Values are already from the mover's own perspective, so the
        # optimistic estimate is max_value for both players; negating it
        # would make the minimizing player's estimate pessimistic.
        return max_value

    rtdp = GameRTDP(
        graph=game_graph,
        heuristic=heuristic,
        max_steps=1000,
        trials_number=100,
        verbose=False,
        render=False)
    rtdp.solve_from(node)
    return game_graph

node = tic_tac_toe_tree.get_node(data=tic_tac_toe.reset())
tic_tac_toe.render(node.data)
# Graph solved for the previous mover; the next mover uses it to predict that
# player's replies (None before the first move).
current_game_graph = None

while not node.terminal:
    print('Player {}\'s turn'.format(
        'Cross' if node.max_player else 'Circle'
    ))
    current_game_graph = call_game_rtdp(
        tic_tac_toe_tree,
        node,
        current_game_graph,
        1,
        node.max_player
    )
    
    node = node.best_child[0]
    tic_tac_toe.render(node.data)

## Trying a more difficult game: connect-4

In [None]:
from muzero.games.connect4 import Connect4

# Instantiate muzero's Connect4 environment and show it with its own render().
connect4 = Connect4()
connect4.render()

In [None]:
import numpy as np

# %matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

class Board:
    """Hashable wrapper around a Connect-4 board array.

    Two boards are equal when their arrays have the same shape and contents,
    and the hash depends only on the (integer-cast) contents, so boards can
    be used as dictionary keys. Assumes the wrapped value is a NumPy array
    (``astype`` and ``copy`` are called on it).
    """

    def __init__(self, board: ArrayLike):
        self._board = board

    def __hash__(self):
        return hash(tuple(self._board.astype(int).flatten()))

    def __eq__(self, other):
        # Let Python handle comparisons with unrelated types instead of
        # failing on ``other._board``; return a plain bool, not np.bool_.
        if not isinstance(other, Board):
            return NotImplemented
        return bool(np.array_equal(self._board, other._board))

    def __str__(self):
        return str(self._board)

    def __repr__(self):
        return repr(self._board)

    @property
    def board(self):
        """The wrapped board array (mutable; shared, not copied)."""
        return self._board

    def copy(self):
        """Return a new Board holding a copy of the array."""
        return Board(self._board.copy())
    

class Connect4Tree(Tree):
    """Game tree over muzero's ``Connect4`` environment.

    Node data are ``Board`` objects wrapping the environment's board array.
    Max-player nodes are green's turn (stone value +1) and min-player nodes
    red's turn (-1). A node's ``terminal_value`` is the reward
    ``Connect4.step`` returned for the move that produced it (0 if the game
    is not over).
    """

    def __init__(self):
        super().__init__()
        self._ax = None
        self._fig = None
        self._image = None
        # Scatter artists showing the stones of the last rendered board.
        self._marks = []

    def reset(self) -> Tree.Node:
        """Return the registered node for the board of a freshly reset game."""
        connect4 = Connect4()
        connect4.reset()
        return self.get_node(data=Board(connect4.board))

    def generate_children(self, node: Tree.Node) -> List[Tuple[Tree.Node, str]]:
        """Yield ``(child_node, description)`` for every legal column."""
        mover = 2 * int(node.max_player) - 1
        connect4 = Connect4()
        connect4.board = node.data.board.copy()
        for action in connect4.legal_actions():
            # ``step`` leaves the resulting position in ``connect4.board``.
            # Start every move from a private copy so the parent's array
            # (``node.data``, also used as a dict key) is never written to.
            connect4.board = node.data.board.copy()
            connect4.player = mover
            _, reward, done = connect4.step(action)
            next_board = connect4.board.copy()

            yield (
                Tree.Node(
                    data=Board(next_board),  # obs is not what we want here
                    max_player=not node.max_player,
                    terminal=done,
                    terminal_value=reward if done else 0
                ),
                'Player {} on column {}'.format(
                    'green' if node.max_player else 'red',
                    action
                )
            )

    def render(self, node: Tree.Node) -> None:
        """Draw ``node``'s board on a persistent figure and display it inline.

        Stones of value 1 are drawn green and -1 red, with the board's rows
        in reverse order. The figure is created on first use; stones drawn
        for a previously rendered board are removed first.
        """
        board = node.data.board[::-1]
        nb_rows, nb_cols = board.shape
        board_to_render = np.zeros(shape=(2 * nb_rows + 1, 2 * nb_cols + 1),
                                   dtype=np.float32)
        # Grid lines (0.7) on every even row and column; cells sit at odd indices.
        board_to_render[::2, :] = 0.7
        board_to_render[:, ::2] = 0.7

        if self._ax is None:
            fig, ax = plt.subplots(1)
            # The window title lives on the figure manager
            # (FigureCanvasBase.set_window_title was removed in Matplotlib 3.6).
            if getattr(fig.canvas, "manager", None) is not None:
                fig.canvas.manager.set_window_title("connect-4")
            ax.set_aspect("equal")  # set the x and y axes to the same scale
            plt.xticks([])  # remove the tick marks by setting to an empty list
            plt.yticks([])  # remove the tick marks by setting to an empty list
            ax.invert_yaxis()  # invert the y-axis so the first row of data is at the top
            self._ax = ax
            self._fig = fig
            plt.ion()
        if self._image is None:
            self._image = self._ax.imshow(board_to_render, cmap='Greys', vmin=0, vmax=1)
        else:
            self._image.set_data(board_to_render)

        for mark in self._marks:
            mark.remove()
        self._marks = []
        for r in range(nb_rows):
            for c in range(nb_cols):
                if board[r, c] == 1:
                    self._marks.append(
                        self._ax.scatter(2*c + 1, 2*r + 1, facecolors='green', edgecolors='green'))
                elif board[r, c] == -1:
                    self._marks.append(
                        self._ax.scatter(2*c + 1, 2*r + 1, facecolors='red', edgecolors='red'))

        display(self._fig)
        clear_output(wait = True)
        plt.pause(1)

# Draw the board of a freshly reset Connect-4 game.
connect4_tree = Connect4Tree()
connect4_tree.render(connect4_tree.reset())

In [None]:
# ``connect4_tree`` owns the figure used for rendering the game.
connect4_tree = Connect4Tree()

node = connect4_tree.reset()
connect4_tree.render(node)


def connect4_value(n):
    """Score ``n`` from green's (the maximizer's) point of view.

    Connect4Tree stores the reward ``Connect4.step`` returned for the move
    that produced ``n``. NOTE(review): this assumes that reward is 1 when the
    player who just moved wins and 0 otherwise; confirm against
    muzero/games/connect4.py. The player who just moved is the one *not* to
    move at ``n``, so red's wins are negated.
    """
    return -n.terminal_value if n.max_player else n.terminal_value


while not node.terminal:
    # Search afresh every turn for whichever side is to move: a best_child
    # left over from an earlier search may come from a pruned subtree. A new
    # search tree per move starts with an empty node cache, while
    # connect4_tree keeps rendering into the same figure.
    search_tree = Connect4Tree()
    alphabeta(node=node,
              tree=search_tree,
              depth=5,
              alpha=-float("inf"),
              beta=float("inf"),
              maximizing_player=node.max_player,
              evaluate=connect4_value)
    node = node.best_child[0]
    connect4_tree.render(node)

### MuZero

![MCTS](img/MCTS.png)

![MuZero](img/muzero.gif)

In [None]:
# Install muzero's requirements from inside its checkout, load the
# TensorBoard extension, and return to the notebook directory.
%cd muzero/
!pip install -r requirements.txt
# NOTE(review): the reason for removing pyarrow is not stated here
# (presumably a dependency conflict) — confirm.
!pip uninstall -y pyarrow
%load_ext tensorboard
%cd ..

In [None]:
# Show TensorBoard for the logs written under ./muzero/results.
%tensorboard --logdir ./muzero/results

In [None]:
# Train on connect4 through muzero's command-line script, run from inside its
# directory. The JSON argument presumably overrides config fields
# (here training_steps) — confirm against muzero.py.
%cd muzero/
!python muzero.py connect4 '{"training_steps": 100}'
%cd ..

In [None]:
import os
import sys

current_dir = os.path.abspath('')
muzero_dir = os.path.join(current_dir, 'muzero')

# Make the checkout importable in this interpreter: ``current_dir`` exposes
# it as the ``muzero`` package and must come first, because ``muzero_dir``
# itself contains a ``muzero.py`` module that would otherwise shadow the
# package. ``muzero_dir`` is added for muzero's own modules, which presumably
# import their siblings by top-level name (NOTE(review): confirm).
sys.path[:0] = [p for p in (current_dir, muzero_dir) if p not in sys.path]

# Changing PYTHONPATH does not affect this already-running interpreter, but
# processes started from it afterwards inherit it.
os.environ["PYTHONPATH"] = os.pathsep.join(
    p for p in (current_dir, muzero_dir, os.environ.get("PYTHONPATH", "")) if p
)

from muzero.muzero import MuZero
from muzero.games.connect4 import MuZeroConfig, Game


config = MuZeroConfig()
config.training_steps = 100
muzero = MuZero('connect4', config)
muzero.train()