---
layout: code-post
title: AlphaZero for Brandub
description: We provide a from-scratch implementation of AlphaZero to learn how to play Brandub.
tags: [tafl, neural nets]
---

In this post I'm going to explore how to implement [AlphaZero](https://arxiv.org/pdf/1712.01815.pdf)
to learn how to play Brandub (a tafl variant). I previously put together a [Brandub package](https://github.com/kevinnowland/brandub)
which I won't use directly, but whose code I will modify as needed. You can find the rules of the
game there.

The main purpose of this is to prove that I understand the algorithm well enough to implement
a very crude version of it. How crude? I doubt the best trained player after even a few days of training
will be able to play very well. My big question is whether the trained model can beat me, who has
never played a full game before despite all this coding that I've done.


## Algorithm description

### Game play technology

AlphaZero plays games using two main pieces of technology. The first is a neural network $f\_\theta(s)$ with
parameters $\theta$ which takes the current game state $s$ and outputs $(p, v)$, where $p$ is a probability
vector over all possible moves and $v \in [-1, 1]$ is an estimated game score for the current player, with -1 being a
loss, 0 a draw, and 1 a win. Only the neural network parameters $\theta$ will be learned.

One could play just using $f\_\theta$ by choosing moves according to the output move probabilities $p$,
but instead we rely on the second piece of technology, a Monte Carlo Tree Search (MCTS) algorithm. The tree search
uses the neural network to explore the space of moves and choose the best one, taking into account
how the game might proceed. The neural network output, as we will
see later, is extremely raw and can suggest moves which are illegal (moving an opponent's piece) or
even impossible (moving off the board). In addition to suggesting moves which will likely lead to victory
for the current player, the tree search algorithm helps encode the rules of the game
by only exploring legal moves. The tree
search is not learned directly, but does take as inputs the neural network $f\_\theta$ and the current game state.
The output is a policy $\pi$, a probability vector that is used to select the next move.
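
As a rough sketch of the legal-move filtering idea, the raw probabilities can be zeroed out on illegal moves and rescaled. A plain dict stands in for the move tensor used later, the move labels and numbers are made up for illustration, and the uniform fallback is an assumption of this sketch rather than something the rest of the post relies on.

```python
def mask_and_renormalize(raw_probs, legal_moves):
    """Keep only legal moves and rescale so the probabilities sum to 1.

    raw_probs: dict mapping a (hypothetical) move label to a raw probability
    legal_moves: non-empty set of move labels allowed by the rules
    """
    masked = {m: p for m, p in raw_probs.items() if m in legal_moves}
    total = sum(masked.values())
    if total == 0:
        # Assumed fallback: uniform over legal moves if they received no mass
        return {m: 1 / len(legal_moves) for m in legal_moves}
    return {m: p / total for m, p in masked.items()}


raw = {"a": 0.5, "b": 0.25, "off_board": 0.25}
policy = mask_and_renormalize(raw, {"a", "b"})
print(policy)
```

The impossible move disappears and the remaining probabilities keep their relative sizes.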

### Learning

Given the neural network and MCTS algorithm, play proceeds until a game ends at step $T$.
At each step $t$ of the game we have neural network output $(p\_t, v\_t)$ as well as
the policy $\pi\_t$ governing move selection. The final game result is $s\_T \in \{-1, 0, 1\}$,
signed from the point of view of the player moving at step $t$.
The loss $\ell\_t$ for a step is
$$
  \ell\_t = (v\_t - s\_T)^2 - \pi\_t^\top \log p\_t + c \|\theta\|^2,
$$
where $c$ is an $\ell^2$-regularization parameter. Thus we are looking at squared
error for the predicted value $v$, plus a cross-entropy term forcing the predicted move probabilities $p$ to look like
the MCTS policy $\pi$, plus a regularization term.
The overall loss function is the average of this over the $T$ moves making up the game.

After each game, we backpropagate from the loss function through the neural network to
change the neural network parameters $\theta$. Note that while the policies $\pi\_t$ do depend on
$\theta$, since the MCTS takes the neural network as input, we will pretend they do not, so they are
treated as constants during backpropagation.
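
To make the per-step loss concrete, here is a minimal pure-Python sketch with made-up numbers. The function name, the flat parameter list `theta`, and the default value of `c` are assumptions for illustration only; a real implementation would compute this on tensors so that gradients can flow.

```python
import math


def step_loss(v, outcome, pi, p, theta, c=1e-4):
    """Loss for one step: squared value error + cross-entropy + L2 penalty.

    v: predicted value, outcome: final result in {-1, 0, 1}
    pi: MCTS policy, p: network move probabilities (p > 0 wherever pi > 0)
    theta: flat list of parameters, c: regularization strength (assumed value)
    """
    value_term = (v - outcome) ** 2
    # Entries with pi_i == 0 contribute nothing, so skip them to avoid log(0)
    policy_term = -sum(pi_i * math.log(p_i) for pi_i, p_i in zip(pi, p) if pi_i > 0)
    reg_term = c * sum(w * w for w in theta)
    return value_term + policy_term + reg_term


loss = step_loss(v=0.5, outcome=1, pi=[0.5, 0.5], p=[0.5, 0.5], theta=[1.0, 2.0], c=0.1)
print(loss)  # 0.25 + log(2) + 0.5
```

The three terms here are 0.25 for the value error, $\log 2$ for the cross-entropy, and 0.5 for the penalty.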


## Implementation

### Game State

The state of a game is encoded similarly, but not identically, to what was done in the brandub
package. The state tensor is a stack of 7x7 binary tensors plus a constant plane.
The first layer encodes the current player's pawns and the second layer is the current player's
monarch. The current player might be the attacker and not have a monarch, but that is fine:
the input plane will simply be all zero. The next two layers encode the opponent's pawns and monarch in the same way.
AlphaZero maintains 8 timesteps of history, but we will try to keep only 3, if that. While the connections
from these history planes into the neural net
will be learned like those of any other input, I suspect their main use will be checking whether a
position has repeated, leading to a draw. AlphaZero for chess contains two more constant planes,
either 1 or 0, depending on whether the current position has repeated once or twice before.
Probably the order matters, although I am not sure. Since brandub ends in a draw with only a
single repeated position, we do not need to encode this information for the player.

In addition we will include a constant plane that is not repeated to indicate which player is
playing, either 0 for defense or 1 for attack.

Overall, this leads to a game state tensor of size $7\times7\times(4\times3 + 1)$, i.e.,
we have 637 features.
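
As a quick sanity check on the layout, here is a small sketch of which planes hold which board. The constant and function names are hypothetical; only the plane ordering comes from the description above.

```python
BOARD_SIZE = 7
PLANES_PER_BOARD = 4   # current pawns, current monarch, opponent pawns, opponent monarch
HISTORY_LENGTH = 3     # the current board plus the two boards before it


def board_planes(steps_back):
    """Return the (start, stop) plane range of the board `steps_back` steps ago."""
    start = PLANES_PER_BOARD * steps_back
    return start, start + PLANES_PER_BOARD


PLAYER_PLANE = PLANES_PER_BOARD * HISTORY_LENGTH  # index 12, the constant player plane
NUM_FEATURES = BOARD_SIZE * BOARD_SIZE * (PLAYER_PLANE + 1)

print(board_planes(0), board_planes(1), board_planes(2))  # (0, 4) (4, 8) (8, 12)
print(NUM_FEATURES)  # 637
```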

In [1]:
import torch

def get_initial_game_state():
    """returns pytorch tensor encoding initial brandub game state"""
    
    game_state = torch.zeros([7, 7, 13], dtype=torch.float64)
    
    # attacking pawns
    game_state[3, 0, 0] = 1
    game_state[3, 1, 0] = 1
    game_state[3, 5, 0] = 1
    game_state[3, 6, 0] = 1
    game_state[0, 3, 0] = 1
    game_state[1, 3, 0] = 1
    game_state[5, 3, 0] = 1
    game_state[6, 3, 0] = 1
    
    # defensive pawns
    game_state[3, 2, 2] = 1
    game_state[3, 4, 2] = 1
    game_state[2, 3, 2] = 1
    game_state[4, 3, 2] = 1
    
    # defensive monarch
    game_state[3, 3, 3] = 1
    
    # attack's turn
    game_state[:, :, 12] = 1
    
    return game_state

Now we write some functions to check for end game conditions. I won't be optimizing this
code much, but to keep it lean I also won't bother to check types or board legality.

In [386]:
def get_monarch_index(game_state):
    """ return which game state layer encodes the defensive monarch 
    for the current board """
    if game_state[0, 0, 12] == 1:
        return 3
    else:
        return 1

    
def attack_victory(board, monarch_index):
    """ attack wins if the monarch is not on the current board"""
    
    return board[:, :, monarch_index].max() == 0


def defense_victory(board, monarch_index):
    """ defense wins if the monarch is in the forest """
    
    return board[0, 0, monarch_index] == 1 or \
        board[0, 6, monarch_index] == 1 or \
        board[6, 0, monarch_index] == 1 or \
        board[6, 6, monarch_index] == 1


def get_relative_end_value(board, game_state):
    """ returns 1, -1, 0, or None based on if the current
    player is winning. In this setup, the board
    is a new board we might want to potentially add to the
    game_state before the players are switched. """
    
    attacker_is_playing = game_state[0, 0, 12] == 1
    monarch_index = 3 if attacker_is_playing else 1
    
    if attack_victory(board, monarch_index):
        if attacker_is_playing:
            return 1
        else:
            return -1
    elif defense_victory(board, monarch_index):
        if attacker_is_playing:
            return -1
        else:
            return 1
    elif drawn_game(board, game_state):
        return 0
    else:
        return None

def get_end_value(board, game_state):
    """ return 1, -1, 0, or None based on attacker
    winning, defender winning, draw, no result """
    
    attacker_is_playing = game_state[0, 0, 12] == 1
    monarch_index = 3 if attacker_is_playing else 1
    
    if attack_victory(board, monarch_index):
        return 1
    elif defense_victory(board, monarch_index):
        return -1
    elif drawn_game(board, game_state):
        return 0
    else:
        return None


def board_shadow(board, is_attack_turn):
    """ return a 7x7 tensor encoding current state with
    -1 being an attacking pawn, 1 a defensive pawn,
    2 the monarch.
    
    board is a 7x7x4 tensor
    is_attack_turn is a boolean
    
    This is used to look for draws with game states"""
    
    raw_shadow = board[:, :, 0] + 2 * board[:, :, 1] - \
        board[:, :, 2] - 2 * board[:, :, 3]
    
    if is_attack_turn:
        return -1 * raw_shadow
    else:
        return raw_shadow
    
    
def game_shadow(game_state):
    """ returns the board shadow for the current board
    and player. """
    
    return board_shadow(game_state[:, :, :4], game_state[0, 0, 12] == 1)
        

def drawn_game(board, game_state):
    """ take a new board and compare to the boards in the game state
    to see if it would result in a draw.
    
    NOTE: This assumes that the current player for board is THE SAME AS
    the current player in the game_state. 
    
    NOTE: Do not bother checking most recent game state, as we
    assume the presented board is after a move. """
    
    is_attack_turn = game_state[0, 0, 12] == 1
    new_board_shadow = board_shadow(board, is_attack_turn)
    
    # History boards are stored from the perspective of whoever was about
    # to move at the time, so planes 4:8 belong to the opponent and
    # planes 8:12 to the current player.
    board_shadow_1 = board_shadow(game_state[:, :, 4:8], not is_attack_turn)
    
    if torch.all(torch.eq(new_board_shadow, board_shadow_1)):
        return True
    else:
        board_shadow_2 = board_shadow(game_state[:, :, 8:12],
                                     is_attack_turn)
        return torch.all(torch.eq(new_board_shadow, board_shadow_2))

In [3]:
initial_game_state = get_initial_game_state()
print(game_shadow(initial_game_state))

tensor([[-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0.,  1., -0., -0., -0.],
        [-1., -1.,  1.,  2.,  1., -1., -1.],
        [-0., -0., -0.,  1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.]], dtype=torch.float64)


### Movement

Movement will be encoded as a $7\times7\times24$ tensor. The first two coordinates
indicate the position on the board where a piece should be found and the
final coordinate indicates the direction and distance. The value modulo 6, plus 1,
is the number of spaces to move. Raw values between 0 and 5 indicate
moving down, 6 to 11 moving up, 12 to 17 moving right, and 18 to 23 moving
left. $(7, 7, 24)$ will also be the shape of the probability tensor put out by the neural net, which
will include invalid moves. The first step of the MCTS will be to reduce this
to only valid moves. A single move will be an index tensor that we can sample
using the `Categorical` object in `torch.distributions.categorical`.
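
The encoding is easier to see with a small decoder. This is a minimal sketch in plain Python; the helper names `decode_move` and `encode_move` are hypothetical and are not used by the code that follows.

```python
DIRECTIONS = ("down", "up", "right", "left")
# (row step, column step) for each direction; row indices grow downward
STEPS = {"down": (1, 0), "up": (-1, 0), "right": (0, 1), "left": (0, -1)}


def decode_move(row, col, move_index):
    """Return (direction, distance, destination) for a move index in 0..23."""
    direction = DIRECTIONS[move_index // 6]
    distance = move_index % 6 + 1
    d_row, d_col = STEPS[direction]
    return direction, distance, (row + distance * d_row, col + distance * d_col)


def encode_move(direction, distance):
    """Inverse of the direction/distance part of decode_move."""
    return DIRECTIONS.index(direction) * 6 + distance - 1


print(decode_move(3, 6, 0))   # ('down', 1, (4, 6))
print(decode_move(3, 4, 13))  # ('right', 2, (3, 6))
```

So the index tensor `[3, 6, 0]` means "the piece at row 3, column 6 moves down one square".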

We will need auxiliary functions to find valid moves and then perform
those moves. Again, since we're trying to remove unnecessary cruft, we will
not verify many things, such as piece existence. In the following code block
we will write auxiliary functions involved with finding valid moves
and the new position of a piece given some move.

In [187]:
def is_forest(position_2d):
    """ returns whether the first two coords of
    position_2d is a forest"""
    return position_2d[0] in [0, 6] and position_2d[1] in [0, 6]


def is_castle(position_2d):
    """ returns whether the first two coords of position_2d
    is a castle"""
    return position_2d[0] == 3 and position_2d[1] == 3
   

def find_piece_plane(board, position_2d, planes=[2, 3, 0, 1]):
    """ returns the plane of piece at given (i, j) position """
    for k in planes:
        if board[position_2d[0], position_2d[1], k] == 1:
            return k
        
        
def find_piece_at_position(board, position_2d):
    """ returns the (i, j, k) position of the piece at position (i, j) 
    if one is present else returns None """
    
    k = find_piece_plane(board, position_2d, [2, 3, 0, 1])
    if k is not None:
        return (position_2d[0], position_2d[1], k)
    else:
        return None

        
def valid_moves(board, position):
    """ find valid movement indices 0-23 for a piece
    that is (assumed to be) at position, which is a tensor with length 2"""
    
    is_pawn = board[position[0], position[1], 1] == 0
    
    shadow = board[:, :, 0] + board[:, :, 1] + board[:, :, 2] + board[:, :, 3]
    
    def check_direction(direction_vector):
        """ get valid moves in the given direction
        direction must be [+/-1, 0] or [0, +/-1] torch tensors
        """

        valid_moves = []

        coord = 0 if direction_vector[0] != 0 else 1
        positive_direction = direction_vector[coord] == 1
        end_value = 6 if positive_direction else 0
        
        if coord == 0:
            if positive_direction:
                base = 0
            else:
                base = 6
        else:
            if positive_direction:
                base = 12
            else:
                base = 18

        keep_going = position[coord] != end_value
        i = 0
        while keep_going:
            i += 1

            new_pos = position + i * direction_vector

            # stop if run into a piece
            if shadow[tuple(new_pos)] == 1:
                break

            # ignore the castle
            if is_castle(new_pos):
                continue

            keep_going = new_pos[coord] != end_value

            # if pawn and at the wall, see if its a forest but don't add
            if not keep_going and is_pawn and is_forest(new_pos):
                break

            valid_moves.append(base + i - 1)

        return valid_moves
    
    direction_vectors = (
        torch.tensor([1, 0]),
        torch.tensor([-1, 0]),
        torch.tensor([0, 1]),
        torch.tensor([0, -1])
    )

    return [
        direction
        for dvec in direction_vectors
        for direction in check_direction(dvec)
    ]


def valid_move_indices(board):
    move_indices = [
        (pos[0], pos[1], move_index)
        for pos in board[:, :, :2].nonzero(as_tuple=False)
        for move_index in valid_moves(board, pos[:2])
    ]
    
    return move_indices


def valid_move_tensor(board):
    """ get tensor of all valid moves. Only pieces on the first
    two planes can move. 
    
    Returns 7x7x24 binary tensor """
    
    # get all pieces that can move
    move_indices = valid_move_indices(board)
    
    move_tensor = torch.zeros(7, 7, 24)
    for move_index in move_indices:
        move_tensor[move_index] = 1
        
    return move_tensor


def find_new_position(board, move):
    """ find the piece at the given
    position and return a tuple containing
    the pieces new location.
    
    NOTE: assumes piece is on the first or second plane """
    
    # whether to move vertically or horizontally
    move_vertical = move[2] < 12
    
    # how much to move
    direction = 1 if move[2] % 12 < 6 else -1
    move_val = direction * (move[2] % 6 + 1)
    
    # find where the piece is
    plane = find_piece_plane(board, move[:2], [0, 1])
    
    if move_vertical:
        return (move[0] + move_val, move[1], plane)
    else:
        return (move[0], move[1] + move_val, plane)

Now we write functions that allow us to advance game state by performing a move
and then capturing the pieces.

In [348]:
def is_piece_captured(board, position):
    """ determine if a piece at (i, j, k) is captured.
    
    NOTE: does not check that a piece is at the position
    
    NOTE: This does not validate that the board is a
    valid board. """
    
    i = position[0]
    j = position[1]

    topography_inds_corners = [(0, 0), (0, 6), (6, 0), (6, 6)]

    if position[2] == 0 or position[2] == 2:
        # pawn logic
        
        if position[2] == 0:
            enemy_board = board[:, :, 2] + board[:, :, 3]
        else:
            enemy_board = board[:, :, 0] + board[:, :, 1]
        
        if board[3, 3, 1] + board[3, 3, 3] == 0:
            # no monarch? then castle is threat
            topography_inds = topography_inds_corners + [(3, 3)]
        else:
            topography_inds = topography_inds_corners

        topography = torch.zeros(7, 7)
        for ind in topography_inds:
            topography[ind] = 1
        bad_things = enemy_board + topography

        if i == 0 or i == 6:
            # on top or bottom (can't be in corner)
            if bad_things[i, j-1] == 1 and bad_things[i, j+1] == 1:
                return True
            else:
                return False
        elif j == 0 or j == 6:
            # on left or right side (can't be in corner)
            if bad_things[i-1, j] == 1 and bad_things[i+1, j] == 1:
                return True
            else:
                return False
        else:
            # otherwise just check
            if bad_things[i, j-1] == 1 and bad_things[i, j+1] == 1:
                return True
            elif bad_things[i+1, j] == 1 and bad_things[i-1, j] == 1:
                return True
            else:
                return False

    elif position[2] == 1 or position[2] == 3:
        # monarch logic
        
        if position[2] == 1:
            enemy_board = board[:, :, 2]
        else:
            enemy_board = board[:, :, 0]

        if i != 3 or j != 3:
            # empty castles are threats
            topography_inds = topography_inds_corners + [(3, 3)]
        else:
            topography_inds = topography_inds_corners

        topography = torch.zeros(7, 7)
        for ind in topography_inds:
            topography[ind] = 1
        bad_things = enemy_board + topography

        if (i, j) in [(3, 2), (3, 3), (3, 4), (2, 3), (4, 3)]:
            # in or next to castle have to be surrounded
            if bad_things[i+1, j] == 1 and bad_things[i-1, j] == 1 \
                and bad_things[i, j+1] == 1 and bad_things[i, j-1] == 1:
                return True
            else:
                return False
        else:
            if (i, j) in [(0, 0), (0, 6), (6, 0), (6, 6)]:
                # safe in corner
                return False
            elif i == 0 or i == 6:
                # top or bottom: non corner
                if bad_things[i, j+1] == 1 and bad_things[i, j-1] == 1:
                    return True
                else:
                    return False
            elif j == 0 or j == 6:
                # left or right: non corner
                if bad_things[i+1, j] == 1 and bad_things[i-1, j] == 1:
                    return True
                else:
                    return False
            else:
                # any other spot on board
                if bad_things[i+1, j] == 1 and bad_things[i-1, j] == 1:
                    return True
                elif bad_things[i, j+1] == 1 and bad_things[i, j-1] == 1:
                    return True
                else:
                    return False
    else:
        msg = "position[2] must be in {0, 1, 2, 3}"
        raise Exception(msg)


def game_state_move(game_state, move):
    """ advance game state by moving piece, performing captures, 
    flipping players, and returning a new game state. Also
    returns if the game is over as needed.
    
    NOTE: assumes the move is valid but does not check 
    
    This takes about 0.55 ms"""
    
    # perform the move
    old_board = game_state[:, :, :4]
    new_pos = find_new_position(game_state[:, :, :4], move)
    old_pos = (move[0], move[1], new_pos[2])
    
    new_board_ = torch.zeros(7, 7, 4)
    new_board_[:, :, :] = old_board
    new_board_[old_pos] = 0
    new_board_[new_pos] = 1
    
    # check any pieces near the given piece to see if they are captured
    # if so, remove them from the new_board
    i = new_pos[0]
    j = new_pos[1]
    
    for check_pos in [(i, j), (i-1, j), (i+1, j), (i, j-1), (i, j+1)]:
        
        if check_pos[0] >= 0 and check_pos[0] <= 6 \
            and check_pos[1] >=0 and check_pos[1] <= 6:
            
            piece_pos_ = find_piece_at_position(new_board_, check_pos)
            
            if piece_pos_ is not None:
                if is_piece_captured(new_board_, piece_pos_):
                    new_board_[piece_pos_] = 0
                    
    
    # see if the new_board_ is in a victory state
    # return if so
    end_game_value = get_relative_end_value(new_board_, game_state)
        
    if end_game_value is not None:
        return None, end_game_value
    else:
        # flip players
        new_board = torch.zeros(7, 7, 4)
        new_board[:, :, :2] = new_board_[:, :, 2:]
        new_board[:, :, 2:] = new_board_[:, :, :2]
    
        # form the new game_state
        new_game_state = torch.zeros(7, 7, 13, dtype=torch.float64)

        # update boards
        new_game_state[:, :, :4] = new_board
        new_game_state[:, :, 4:8] = game_state[:, :, :4]
        new_game_state[:, :, 8:12] = game_state[:, :, 4:8]

        # update player
        new_game_state[:, :, 12] = 1 - game_state[:, :, 12]
    
    
        return new_game_state, None

We could now play a game starting from the initial game state provided by `get_initial_game_state()`
and applying moves using `game_state_move()`. We would have to be careful to only supply
valid moves though, because the board could easily get into an invalid state since we are
not verifying many things. As written, `game_state_move()` returns two values, exactly one of which will
be `None`. If the game is not over, the first value is the `new_game_state` and the second
value is `None`. If the game is over, the first value is `None`
and the second value is 1 if the player who submitted the move won, -1 if they lost because
of their move, and 0 if the game ended in a draw.

### Monte Carlo Tree Search

Now that we have a playable game that can also find valid moves, we should be able to implement the
Monte Carlo Tree Search algorithm. The main references I'm using for this are [this blog post](https://www.analyticsvidhya.com/blog/2019/01/monte-carlo-tree-search-introduction-algorithm-deepmind-alphago/) for explicit examples of how MCTS proceeds,
the [mctspy package](https://github.com/int8/monte-carlo-tree-search) from which I will liberally take code, and an AlphaZero 
explainer [here](https://web.stanford.edu/~surag/posts/alphazero.html) from a grad student. It appears that one of the main
differences between the AlphaZero MCTS implementation and the standard implementation is that game state values
are not obtained via rollout and are instead obtained directly from the neural net evaluator.
We can test it out initially with a dummy evaluator that generates random move probability
tensors and values.

To program this, we will create a new `Node` class that is the base unit for the tree created during MCTS.
This will not be functional programming, as I believe it makes more sense to keep nodes around and update
the relevant values they hold versus a more functional approach. I'm open to hearing about how this is wrong,
though!

The node will know about its children, its parent, the probability that it was chosen by its parent
initially, the predicted value of that position for whoever is playing it, and the raw probabilities 
for the next step. A node also needs to know how to backpropagate a value to its parent, as the
node's perceived value will change at every step. The trickiest part is thinking about
what value to give to a node and what value to propagate. The neural network $f\_\theta$ always
provides the value from the current player's perspective. Since the player of the child node
is the opponent of its parent node, the child node's value to the parent is the opposite of
what the child perceives.

But what does that mean for the backpropagation? A parent node chooses the node with
the highest value (since the child values are given according to what the parent thinks).
The child does the same thing, choosing what is a grandchild node to the original parent node.
The grandchild will send the negative of its value to the child node because the child node needs
to become less attractive to the original parent as the child chose a node to help itself, not
its parent.
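
The sign flipping is easier to follow on a toy chain of three nodes. This is a minimal sketch in plain Python with made-up numbers; `TinyNode` is a hypothetical stand-in that only tracks a running value and a parent.

```python
class TinyNode:
    """Stripped-down stand-in for a tree node: a running value and a parent."""

    def __init__(self, parent=None):
        self.parent = parent
        self.total = 0.0

    def backup(self, value):
        """Add `value` here, then hand the negated value to the parent."""
        self.total += value
        if self.parent is not None:
            self.parent.backup(-value)


root = TinyNode()
child = TinyNode(parent=root)
grandchild = TinyNode(parent=child)

# Suppose the grandchild is worth +0.8 to the player who picks it at the child.
# Following the reasoning above, the child receives the negative of that value.
child.backup(-0.8)
print(child.total, root.total)  # -0.8 0.8
```

The child becomes less attractive to the player choosing at the root, and the sign flips once more one level further up.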

In [477]:
import copy


class Node:
    
    def __init__(self,
                 is_attacker,
                 move_taken,
                 parent,
                 game_state,
                 p,
                 v,
                 raw_child_p_tensor,
                 game_over):
        
        self.__is_attacker = is_attacker
        self.__move_taken = move_taken
        self.__parent = parent
        self.__children = []
        self.__game_state = game_state
        self.__p = p
        self.__v = v
        self.__initial_v = copy.deepcopy(v)
        self.__raw_child_p_tensor = raw_child_p_tensor
        self.__game_over = game_over
        self.__n = torch.tensor(0.0)
    
    @property
    def is_attacker(self):
        return self.__is_attacker
    
    @property
    def move_taken(self):
        return self.__move_taken
    
    @property
    def parent(self):
        return self.__parent
    
    @property
    def children(self):
        return self.__children
    
    @property
    def game_state(self):
        return self.__game_state
    
    @property
    def board(self):
        return self.__game_state[:, :, :4]
    
    @property
    def p(self):
        return self.__p
    
    @property
    def v(self):
        return self.__v
    
    @property
    def initial_v(self):
        return self.__initial_v
    
    @property
    def n(self):
        return self.__n
    
    @property
    def raw_child_p_tensor(self):
        return self.__raw_child_p_tensor
    
    @property
    def u(self):
        return self.v + 1 * self.p * (torch.sqrt(self.parent.n) / (1 + self.n))
    
    @property
    def game_over(self):
        return self.__game_over
    
    @property
    def has_children(self):
        return len(self.children) > 0
    
    def increment_n(self):
        self.__n += 1
        
    def add_children(self, evaluator):
        """ use the evaluator (the neural net) to
        add child nodes for each possible move"""
        
        # find all posible valid moves and their raw probabilities
        # then convert to new probabilities
        move_tensor = valid_move_tensor(self.board)
        
        temp_p = move_tensor * self.raw_child_p_tensor
        new_p = temp_p / temp_p.sum()
        
        for move_index in move_tensor.nonzero(as_tuple=False):
            
            new_game_state, end_game_value = game_state_move(self.game_state,
                                                             move_index)
            
            node_p = new_p[move_index[0], move_index[1], move_index[2]]
            
            if new_game_state is not None:
                node_child_p_tensor, node_v = evaluator(new_game_state)
                new_node = Node(not self.is_attacker,
                                move_index,
                                self,
                                new_game_state,
                                node_p,
                                -node_v,
                                node_child_p_tensor,
                                False)
            else:
                new_node = Node(self.is_attacker,
                                move_index,
                                self,
                                None,
                                node_p,
                                end_game_value,
                                None,
                                True)
            
            self.__children.append(new_node)
            
    def backpropagate(self, new_v):
        """ add new_v to our v, then pass the opposite value to the parent """
        self.__v += new_v
        if self.parent is not None:
            # the parent belongs to the other player, so the sign flips
            self.parent.backpropagate(-new_v)
            
    def choose_favorite_child(self):
        """ returns the child with the highest u value """
        return max(self.children, key=lambda c: c.u)
    
    def detach_parent(self):
        """ set parent to None """
        self.__parent = None
    
    def get_child(self, move):
        """ return child node based on the move taken """
        
        for child in self.children:
            if torch.all(torch.eq(child.move_taken, move)):
                return copy.deepcopy(child)

Note how the backpropagation adds the value it receives to the node but then passes the opposite 
value to the node's own parent. In the code below we perform an iteration as we reasoned above: the
chosen grandchild node gives its parent the negative of its own value and the chain of
backpropagations continues.
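
Each iteration also descends the tree with `choose_favorite_child`, which maximizes the `u` property of `Node`. Here is a minimal plain-Python sketch of that score with made-up numbers. The parameter `c` is hypothetical and simply exposes the constant 1 that is hard-coded in `u`.

```python
import math


def selection_score(value, prior, parent_visits, visits, c=1.0):
    """Value plus an exploration bonus that shrinks as the node is visited more."""
    return value + c * prior * math.sqrt(parent_visits) / (1 + visits)


# (value, prior, visits) for three hypothetical children of a node visited 9 times
children = {"a": (0.2, 0.5, 4), "b": (0.1, 0.3, 0), "c": (-0.4, 0.2, 1)}
scores = {
    name: selection_score(v, p, parent_visits=9, visits=n)
    for name, (v, p, n) in children.items()
}
best = max(scores, key=scores.get)
print(scores, best)
```

Child `b` wins here despite its lower value, because it has never been visited and so keeps its full exploration bonus.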

In [478]:
def mcts_iteration(root_node, evaluator):
    """ run one iteration of the MCTS """
        
    node = root_node
        
    keep_going = True
    while keep_going:
        
        if node.n == 0 and node.parent is not None:
            node.parent.backpropagate(-node.v)
            node.increment_n()
            keep_going = False
        else:
            if node.has_children:
                node.increment_n()
                node = node.choose_favorite_child()
            else:
                if node.game_over:
                    node.parent.backpropagate(-node.v)
                    node.increment_n()
                    keep_going = False
                else:
                    node.add_children(evaluator)
                    node.increment_n()
                    node = node.choose_favorite_child()
            

            
def get_move_probabilities(root_node, evaluator, num_iterations):
    """ run the mcts with a given number of iterations 
    
    returns moves and their probabilities """
    
    for _ in range(num_iterations):
        mcts_iteration(root_node, evaluator)
        
    policy_tensor = torch.zeros(7, 7, 24, dtype=torch.float64)
    
    moves = [
        child.move_taken
        for child in root_node.children
        if child.n > 0
    ]
    
    visits = torch.tensor([
        child.n
        for child in root_node.children
        if child.n > 0
    ], dtype=torch.float64)
        
    return moves, visits / visits.sum()


def convert_to_policy_tensor(moves, probs):
    """ convert moves and probs lists into a policy tensor"""
    
    move_tensor = torch.zeros(7, 7, 24, dtype=torch.float64)
    
    for i in range(len(moves)):
        move_tensor[tuple(moves[i])] = probs[i]
    
    return move_tensor
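
The step from visit counts to a move can be sketched without tensors. This is a minimal plain-Python illustration: the visit counts are made up, `visits_to_policy` is a hypothetical helper, and `random.choices` stands in for the `Categorical` sampling used when actually playing.

```python
import random


def visits_to_policy(visit_counts):
    """Turn {move: visit count} into {move: probability}, dropping unvisited moves."""
    visited = {m: n for m, n in visit_counts.items() if n > 0}
    total = sum(visited.values())
    return {m: n / total for m, n in visited.items()}


# Hypothetical visit counts for three root children after 20 iterations
counts = {(3, 6, 0): 12, (3, 0, 6): 8, (0, 3, 12): 0}
policy = visits_to_policy(counts)

rng = random.Random(0)  # seeded only to make the sketch reproducible
moves = list(policy)
move = rng.choices(moves, weights=[policy[m] for m in moves], k=1)[0]
print(policy, move)
```

Moves that were never visited get probability zero, so they can never be sampled.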

In [495]:
from torch.distributions.categorical import Categorical


def get_initial_node(evaluator):
    """ get the initial node """
    
    initial_state = get_initial_game_state()
    p, v = evaluator(initial_state)
    
    return Node(is_attacker=True,
                move_taken=None,
                parent=None,
                game_state=initial_state,
                p=p,
                v=v,
                raw_child_p_tensor=p,
                game_over=False)


def play_game(evaluator, num_iterations):
    """ initialize and play one game. returns policy tensors,
    predicted movement probabilities, properly signed game
    results, and predicted game values. """
    
    
    predicted_movement_probs = []
    predicted_game_values = []
    policy_tensors = []
    
    node = get_initial_node(evaluator)
    while not node.game_over:
        
        if node.is_attacker:
            print('\nattack turn:\n', game_shadow(node.game_state))
        else:
            print('\ndefense turn:\n', game_shadow(node.game_state))
        moves, probs = get_move_probabilities(node,
                                              evaluator,
                                              num_iterations)
        
        # record values
        predicted_movement_probs.append(node.raw_child_p_tensor)
        predicted_game_values.append(node.initial_v)
        policy_tensors.append(convert_to_policy_tensor(moves, probs))
        
        # make move
        cat = Categorical(probs=probs)
        move = moves[cat.sample()]
        print('move:', move)
        node = node.get_child(move)
        node.detach_parent()
        
    # endgame node preserves who was the last player
    if node.initial_v == 0:
        game_values = [0 for _ in range(len(predicted_game_values))]
    elif node.is_attacker:
        game_values = [
            node.initial_v * torch.pow(torch.tensor(-1, dtype=torch.float64), i % 2)
            for i in range(len(predicted_game_values))
        ]
    else:
        game_values = [
            node.initial_v * torch.pow(torch.tensor(-1, dtype=torch.float64), (i % 2) + 1)
            for i in range(len(predicted_game_values))
        ]
    
    return predicted_movement_probs, policy_tensors, \
        predicted_game_values, game_values
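
The sign bookkeeping at the end of `play_game` is compact, so here is a minimal plain-Python sketch of the same idea. The function name and arguments are hypothetical. It assumes, as in the game loop, that the attacker moves at even steps.

```python
def outcomes_by_step(num_steps, last_mover_is_attacker, last_mover_result):
    """Final result from the point of view of whoever moved at each step.

    Steps are 0-indexed and the attacker moves at even steps.
    last_mover_result is 1, -1 or 0 from the last mover's perspective.
    """
    attacker_result = last_mover_result if last_mover_is_attacker else -last_mover_result
    return [attacker_result if i % 2 == 0 else -attacker_result for i in range(num_steps)]


# Defense makes the winning move at the end of a 4-step game
print(outcomes_by_step(4, last_mover_is_attacker=False, last_mover_result=1))  # [-1, 1, -1, 1]
```

Every attacker step is labeled with a loss and every defense step with a win, which is the target the value head is trained against.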

In [500]:
pmp, pt, pgv, gv = play_game(dummy_evaluator, 20)


attack turn:
 tensor([[-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0.,  1., -0., -0., -0.],
        [-1., -1.,  1.,  2.,  1., -1., -1.],
        [-0., -0., -0.,  1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.]], dtype=torch.float64)
move: tensor([3, 6, 0])

defense turn:
 tensor([[ 0.,  0.,  0., -1.,  0.,  0.,  0.],
        [ 0.,  0.,  0., -1.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  1.,  0.,  0.,  0.],
        [-1., -1.,  1.,  2.,  1., -1.,  0.],
        [ 0.,  0.,  0.,  1.,  0.,  0., -1.],
        [ 0.,  0.,  0., -1.,  0.,  0.,  0.],
        [ 0.,  0.,  0., -1.,  0.,  0.,  0.]], dtype=torch.float64)
move: tensor([3, 4, 0])

attack turn:
 tensor([[-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0.,  1., -0., -0., -0.],
        [-1., -1.,  1.,  2., -0., -1., -0.],
        [-0., -0., -0.,  1.,  1., -0., -1.],
       

The `dummy_evaluator` used in the game above ignores its input and returns random move probabilities and a random value:

In [379]:
from torch.distributions.uniform import Uniform

def dummy_evaluator(*args):
    """ get a dummy 7x7x24 tensor of movement probabilities """
    u_p = Uniform(torch.zeros(7, 7, 24), torch.ones(7, 7, 24))
    p = u_p.sample()
    
    u_v = Uniform(-1, 1)
    v = u_v.sample()
    
    return p / p.sum(), v



