---
layout: code-post
title: AlphaZero for Brandub
description: We provide a from-scratch implementation of AlphaZero that learns to play Brandub.
tags: [tafl, neural nets]
---

In this post I'm going to explore how to implement [AlphaZero](https://arxiv.org/pdf/1712.01815.pdf)
to learn how to play Brandub (a tafl variant). I previously put together a [Brandub package](https://github.com/kevinnowland/brandub)
which I won't use directly, but whose code I will modify as needed. You can find the rules of the
game there.

The main purpose of this is to prove that I understand the algorithm well enough to implement
a very crude version of it. How crude? I doubt the best trained player after even a few days of training
will be able to play very well. My big question is whether the trained model can beat me, who has
never played a full game before despite all this coding that I've done.


## Algorithm description

### Game play technology

AlphaZero plays games using two main pieces of technology. The first is a neural network $f\_\theta(s)$ with
parameters $\theta$ which takes the current game state $s$ and outputs $(p, v)$, where $p$ is a probability
vector over all possible moves and $v \in [-1, 1]$ is an estimated game score for the current player, with -1 being a
loss, 0 a draw, and 1 a win. Only the neural network parameters $\theta$ will be learned.

One could play just using $f\_\theta$ by choosing moves according to the output move probabilities $p$,
but instead we rely on the second piece of technology, a Monte Carlo Tree Search (MCTS) algorithm. The tree search
uses the neural network to explore the space of moves and choose the best one, taking into account
how the game might proceed. The neural network output, as we will
see later, is extremely raw and can suggest moves which are illegal (moving an opponent's piece) or
even impossible (moving off the board). In addition to suggesting moves which will likely lead to victory
for the current player, the tree search algorithm helps encode the rules of the game
by only exploring legal moves. The tree
search is not learned directly, but does take as inputs the neural network $f\_\theta$ and the current game state.
The output is a policy $\pi$, a probability vector that is used to select the next move.
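
One detail worth making concrete is how a search can end up with a probability vector $\pi$. In the AlphaZero paper, $\pi$ is proportional to how often each move was visited at the root of the search tree, optionally sharpened by a temperature. The following is a minimal pure-Python sketch of only that last step. The name `policy_from_visits`, the `temperature` parameter, and the visit counts are illustrative assumptions, not code from this post.

```python
def policy_from_visits(visit_counts, temperature=1.0):
    """Turn root visit counts {move: count} into a policy {move: probability}.

    Hypothetical helper: probabilities are proportional to
    count ** (1 / temperature), so a lower temperature sharpens the policy.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    weights = {
        move: count ** (1.0 / temperature)
        for move, count in visit_counts.items()
    }
    total = sum(weights.values())
    if total == 0:
        raise ValueError("at least one move must have been visited")
    return {move: weight / total for move, weight in weights.items()}


# Made-up visit counts for three (row, column, move index) moves.
visits = {(3, 0, 0): 30, (3, 1, 0): 10, (0, 3, 18): 60}
pi = policy_from_visits(visits)
```

With a temperature of 1 the policy is just the normalized visit counts, so the most-visited move is also the most likely one to be played.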

### Learning

Given the neural network and MCTS algorithm, play proceeds until a game ends at step $T$.
At each step $t$ of the game we have neural network output $(p\_t, v\_t)$ as well as
the policy $\pi\_t$ governing move selection. The final game result is $s\_T \in \{-1, 0, 1\}$.
The loss $\ell\_t$ for a step is
$$
  \ell\_t = (v\_t - s\_T)^2 - \pi\_t^\top \log p\_t + c \|\theta\|^2,
$$
where $c$ is an $\ell^2$-regularization parameter. Thus we are looking at squared
error for the predicted value $v$, plus a cross-entropy term forcing the predicted move probabilities $p$ to look like
the MCTS policy $\pi$, plus a regularization term.
The overall loss function is the average of this over the $T$ moves making up the game.
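
To make the three terms of the per-step loss concrete, here is a small worked example in plain Python rather than PyTorch. The function name `step_loss`, its argument names, and all of the numbers are made up for illustration. `outcome` plays the role of $s\_T$ and `theta` stands in for a flattened list of network parameters.

```python
import math


def step_loss(v, outcome, pi, p, theta, c):
    """Per-step loss: squared value error + cross-entropy + L2 penalty.

    Illustrative sketch only; pi and p are equal-length lists of
    probabilities and theta is a flat list of parameters.
    """
    value_term = (v - outcome) ** 2
    # Terms with pi_a == 0 contribute nothing (0 * log p is taken as 0).
    cross_entropy = -sum(
        pi_a * math.log(p_a) for pi_a, p_a in zip(pi, p) if pi_a > 0
    )
    l2_term = c * sum(w * w for w in theta)
    return value_term + cross_entropy + l2_term


def game_loss(step_losses):
    """Average the per-step losses over the moves of one game."""
    return sum(step_losses) / len(step_losses)


# One hypothetical step: the net predicts v = 0.5 and the player later wins.
loss = step_loss(
    v=0.5,
    outcome=1,
    pi=[0.5, 0.5, 0.0],
    p=[0.25, 0.25, 0.5],
    theta=[1.0, -2.0],
    c=0.01,
)
```

Here the value term is $0.25$, the cross-entropy term is $\log 4$, and the regularization term is $0.05$.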

After each game, we backpropagate from the loss function through the neural network to
change the neural network parameters $\theta$. Note that while the policies $\pi\_t$ do depend on
$\theta$, since the MCTS takes the neural network as input, we will pretend they do not, so they are
treated as constants during backpropagation.


## Implementation

### Game State

The state of a game is encoded similarly but not identically to what was done in the brandub
package. The state tensor is a stack of 7x7 binary tensors plus a constant plane.
The first layer encodes the current player's pawns and the second layer is the current player's
monarch. The current player might be the attacker and not have a monarch, but that is fine:
the plane will simply be all zero. The next two layers repeat this for the opponent's pieces.
AlphaZero maintains 8 timesteps of history, but we will try to keep only 3, if that. While the connections
from these history planes into the neural net
will be learned like those of any other input, I suspect their main use will be checking whether a
position has repeated, leading to a draw. AlphaZero for chess contains two more constant planes,
each either 1 or 0, depending on whether the current position has repeated once or twice before.
Probably the order matters, although I am not sure. Since brandub ends in a draw with only a
single repeated position, we do not need to encode this information for the player.

In addition we will include a constant plane that is not repeated to indicate which player is
playing, either 0 for defense or 1 for attack.

Overall, this leads to a game state tensor of size $7\times7\times(4\times3 + 1)$, i.e.,
we have 637 features.
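
As a quick sanity check on the plane layout and the feature count, the bookkeeping can be written out in a few lines of plain Python. The constant and function names below are hypothetical and only mirror the layout described above: four piece planes per board, three boards of history, and one trailing player plane.

```python
BOARD_SIZE = 7
PLANES_PER_BOARD = 4  # own pawns, own monarch, opponent pawns, opponent monarch
HISTORY_LENGTH = 3    # current board plus the two boards before it

NUM_PLANES = PLANES_PER_BOARD * HISTORY_LENGTH + 1  # +1 for the player plane
PLAYER_PLANE = NUM_PLANES - 1
NUM_FEATURES = BOARD_SIZE * BOARD_SIZE * NUM_PLANES


def plane_index(steps_back, piece_plane):
    """Index of a piece plane, where steps_back=0 is the current board."""
    if not 0 <= steps_back < HISTORY_LENGTH:
        raise ValueError("steps_back out of range")
    if not 0 <= piece_plane < PLANES_PER_BOARD:
        raise ValueError("piece_plane out of range")
    return steps_back * PLANES_PER_BOARD + piece_plane
```

For example, `plane_index(1, 0)` is the plane holding the current player's pawns on the previous board.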

In [31]:
import torch

def get_initial_game_state():
    """returns pytorch tensor encoding initial brandub game state"""
    
    game_state = torch.zeros([7, 7, 13], dtype=torch.float64)
    
    # attacking pawns
    game_state[3, 0, 0] = 1
    game_state[3, 1, 0] = 1
    game_state[3, 5, 0] = 1
    game_state[3, 6, 0] = 1
    game_state[0, 3, 0] = 1
    game_state[1, 3, 0] = 1
    game_state[5, 3, 0] = 1
    game_state[6, 3, 0] = 1
    
    # defensive pawns
    game_state[3, 2, 2] = 1
    game_state[3, 4, 2] = 1
    game_state[2, 3, 2] = 1
    game_state[4, 3, 2] = 1
    
    # defensive monarch
    game_state[3, 3, 3] = 1
    
    # attack's turn
    game_state[:, :, 12] = 1
    
    return game_state

Now we write some functions to check for end game conditions. I won't be optimizing this
code much, and I won't bother to check types or board legality.

In [126]:
def get_monarch_index(game_state):
    """ return which game state layer encodes the defensive monarch 
    for the current board """
    if game_state[0, 0, 12] == 1:
        return 3
    else:
        return 1

    
def attack_victory(board, monarch_index):
    """ attack wins if the monarch is not on the current board"""
    
    return board[:, :, monarch_index].max() == 0


def defense_victory(board, monarch_index):
    """ defense wins if the monarch is in the forest """
    
    return board[0, 0, monarch_index] == 1 or \
        board[0, 6, monarch_index] == 1 or \
        board[6, 0, monarch_index] == 1 or \
        board[6, 6, monarch_index] == 1


def board_shadow(board, is_attack_turn):
    """ return a 7x7 tensor encoding current state with
    -1 being an attacking pawn, 1 a defensive pawn,
    2 the monarch.
    
    board is a 7x7x4 tensor
    is_attack_turn is a boolean
    
    This is used to look for draws with game states"""
    
    raw_shadow = board[:, :, 0] + 2 * board[:, :, 1] - \
        board[:, :, 2] - 2 * board[:, :, 3]
    
    if is_attack_turn:
        return -1 * raw_shadow
    else:
        return raw_shadow
    
    
def game_shadow(game_state):
    """ returns the board shadow for the current board
    and player. """
    
    return board_shadow(game_state[:, :, :4], game_state[0, 0, 12] == 1)
        

def drawn_game(board, game_state):
    """ take a new board and compare to the boards in the game state
    to see if it would result in a draw.
    
    NOTE: This assumes that the current player for board is DIFFERENT
    than the current player in the game_state. 
    
    NOTE: Do not bother checking most recent game state, as we
    assume the presented board is after a move. """
    
    is_attack_turn = game_state[0, 0, 12] == 0  # note: changing game state
    new_board_shadow = board_shadow(board, is_attack_turn)
    
    # check two moves ago
    board_shadow_1 = board_shadow(game_state[:, :, 4:8], is_attack_turn)
    
    if torch.all(torch.eq(new_board_shadow, board_shadow_1)):
        return True
    else:
        board_shadow_2 = board_shadow(game_state[:, :, 8:12],
                                     not is_attack_turn)
        return torch.all(torch.eq(new_board_shadow, board_shadow_2))

In [127]:
initial_game_state = get_initial_game_state()
print(game_shadow(initial_game_state))

tensor([[-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0.,  1., -0., -0., -0.],
        [-1., -1.,  1.,  2.,  1., -1., -1.],
        [-0., -0., -0.,  1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.],
        [-0., -0., -0., -1., -0., -0., -0.]], dtype=torch.float64)


### Movement

Movement will be encoded as a $7\times7\times24$ tensor. The first two coordinates
indicate the position on the board where a piece should be found and the
final coordinate encodes both the direction and the distance. Raw values 0 to 5 indicate
moving down, 6 to 11 moving up, 12 to 17 moving right, and 18 to 23 moving
left. The number of spaces to move is the raw value modulo 6, plus 1, so each direction
covers distances 1 through 6.
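
The encoding is easiest to check by decoding a few indices by hand. Below is a small sketch in plain Python. `decode_move` is a hypothetical helper, and it assumes "down" means increasing row index and "right" means increasing column index.

```python
# Unit steps as (row change, column change), in the order
# down, up, right, left used by the move encoding.
UNIT_STEPS = [(1, 0), (-1, 0), (0, 1), (0, -1)]


def decode_move(move_index):
    """Return the (row change, column change) for a move index in 0..23."""
    if not 0 <= move_index < 24:
        raise ValueError("move_index must be between 0 and 23")
    distance = move_index % 6 + 1   # 1 through 6 spaces
    direction = move_index // 6     # 0 down, 1 up, 2 right, 3 left
    row_step, col_step = UNIT_STEPS[direction]
    return row_step * distance, col_step * distance
```

So index 14 falls in the "right" block (12 to 17) and means moving $14 \bmod 6 + 1 = 3$ spaces to the right.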

This will be the shape of the probability vector put out by the neural net and
will include invalid moves. The first step of the MCTS will be to reduce this
only to valid moves.
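
One simple way to do this reduction is to zero out the probabilities of invalid moves and renormalize what is left. The sketch below works on flattened Python lists rather than tensors. `mask_to_legal` is a hypothetical helper, and the uniform fallback for the case where the network puts no mass on any legal move is my own assumption.

```python
def mask_to_legal(raw_probs, legal_mask):
    """Keep only legal moves and renormalize so the result sums to 1.

    raw_probs: flat list of probabilities from the network
    legal_mask: flat list of 0/1 flags of the same length
    """
    masked = [p * m for p, m in zip(raw_probs, legal_mask)]
    total = sum(masked)
    if total == 0:
        # Assumed fallback: spread the mass evenly over the legal moves.
        n_legal = sum(legal_mask)
        if n_legal == 0:
            raise ValueError("no legal moves")
        return [m / n_legal for m in legal_mask]
    return [p / total for p in masked]


# Four hypothetical moves, of which only the first and third are legal.
legal_probs = mask_to_legal([0.1, 0.4, 0.3, 0.2], [1, 0, 1, 0])
```

With the tensors in this post, the same idea would be an elementwise product of the raw probabilities with the binary legal-move tensor, followed by division by the sum.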

A single move will be an index tensor that we can sample using the `Categorical`
class in `torch.distributions.categorical`.
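
Sampling from a flattened probability vector yields a single flat index, which then has to be converted back into a (row, column, move) triple. The sketch below shows that conversion in plain Python, using `random.choices` as a stand-in for `Categorical`. The helper names `unravel` and `sample_move` are hypothetical, and the flattening is assumed to be row-major.

```python
import random

SHAPE = (7, 7, 24)  # rows, columns, move indices


def unravel(flat_index, shape=SHAPE):
    """Convert a row-major flat index into (row, column, move index)."""
    rows, cols, moves = shape
    if not 0 <= flat_index < rows * cols * moves:
        raise ValueError("flat_index out of range")
    row, rest = divmod(flat_index, cols * moves)
    col, move = divmod(rest, moves)
    return row, col, move


def sample_move(flat_probs, rng=random):
    """Sample one (row, column, move index) according to flat_probs."""
    indices = range(len(flat_probs))
    flat_index = rng.choices(indices, weights=flat_probs, k=1)[0]
    return unravel(flat_index)


# A one-hot distribution always returns the same move.
probs = [0.0] * (7 * 7 * 24)
probs[3 * 7 * 24 + 1 * 24 + 0] = 1.0
move = sample_move(probs)
```

Here the only move with any probability is the piece at row 3, column 1 moving one space down.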

We will need auxiliary functions to determine legal moves based on the
current board, to map from a raw probability vector from the network to probabilities
of only legal moves, and then a function which takes the game state and a legal
move and returns a new game state. We start with the functions which find all
valid moves.

In [240]:
def is_forest(position_2d):
    """ returns whether the first two coords of
    position_2d is a forest"""
    return position_2d[0] in [0, 6] and position_2d[1] in [0, 6]


def is_castle(position_2d):
    """ returns whether the first two coords of position_2d
    is a castle"""
    return position_2d[0] == 3 and position_2d[1] == 3


def valid_moves(board, position):
    """ find valid movement indices 0-23 for a piece
    that is (assumed to be) at position, which is list-like with length 2"""
    
    is_pawn = board[position[0], position[1], 1] == 0
    
    shadow = board[:, :, 0] + board[:, :, 1] + board[:, :, 2] + board[:, :, 3]
    
    def check_direction(direction_vector):
        """ get valid moves in the given direction
        direction must be [+/-1, 0] or [0, +/-1] torch tensors
        """

        valid_moves = []

        coord = 0 if direction_vector[0] != 0 else 1
        positive_direction = direction_vector[coord] == 1
        end_value = 6 if positive_direction else 0
        
        if coord == 0:
            if positive_direction:
                base = 0
            else:
                base = 6
        else:
            if positive_direction:
                base = 12
            else:
                base = 18

        keep_going = position[coord] != end_value
        i = 0
        while keep_going:
            i += 1

            new_pos = position + i * direction_vector

            # stop if run into a piece
            if shadow[tuple(new_pos)] == 1:
                break

            # ignore the castle
            if is_castle(new_pos):
                continue

            keep_going = new_pos[coord] != end_value

            # if pawn and at the wall, see if its a forest but don't add
            if not keep_going and is_pawn and is_forest(new_pos):
                break

            valid_moves.append(base + i - 1)

        return valid_moves
    
    direction_vectors = (
        torch.tensor([1, 0]),
        torch.tensor([-1, 0]),
        torch.tensor([0, 1]),
        torch.tensor([0, -1])
    )

    return [
        direction
        for dvec in direction_vectors
        for direction in check_direction(dvec)
    ]
    
    
def all_valid_moves(board):
    """ get all indices  of valid moves. Only pieces on the first
    two planes can move. 
    
    Returns 7x7x24 binary tensor """
    
    # get all pieces that can move
    positions = [pos[:2] for pos in board[:, :, :2].nonzero()]
    pos_moves = [
        (pos[0], pos[1], move_index)
        for pos in positions
        for move_index in valid_moves(board, pos)
    ]
    
    moves = torch.zeros(7, 7, 24)
    for pos_move in pos_moves:
        moves[pos_move] = 1
        
    return moves

Now we write the functions that allow us to move pieces and then remove any pieces that have
been captured.

In [None]:
def capture_pieces(board):
    """ remove any pieces on the board which are captured
    
    NOTE: placeholder, captures are not implemented yet and the
    board is returned unchanged """
    return board


def move_piece(board, move):
    """ take the board and do the move and return a new board
    after that move. Move is a tensor of size 3.
    
    NOTE: this will perform invalid moves, although it
    does guarantee that a piece was there.
    
    This will flip the attacker and defender."""
    
    # piece position
    i, j = move[0], move[1]
    
    # whether to move vertically or horizontally
    move_vertical = move[2] < 12
    
    # how much to move
    direction = 1 if move[2] % 12 < 6 else -1
    move_val = direction * (move[2] % 6 + 1)
    
    # find where the piece is
    plane = None
    k = 0
    while plane is None and k < 4:
        if board[i, j, k] == 1:
            plane = k
        k += 1
    
    assert plane is not None, "No piece found"
    
    # move the piece
    board_1 = torch.zeros(7, 7, 4, dtype=torch.float64)
    board_1[:, :, :] = board
    
    board_1[i, j, plane] = 0
    
    if move_vertical:
        board_1[i + move_val, j, plane] = 1
    else:
        board_1[i, j + move_val, plane] = 1
        
    
    # capture pieces
    board_2 = capture_pieces(board_1)
    
    
    # flip players
    new_board = torch.zeros(7, 7, 4, dtype=torch.float64)
    new_board[:, :, :2] = board_2[:, :, 2:]
    new_board[:, :, 2:] = board_2[:, :, :2]
    
    return new_board


def advance_game_state(board, game_state):
    """ return a new game state by adding the given
    board.
    
    The game may or may not be over at this point. """
    
    new_game_state = torch.zeros(7, 7, 13, dtype=torch.float64)
    
    # update boards: newest board first, then shift the history back
    new_game_state[:, :, :4] = board
    new_game_state[:, :, 4:8] = game_state[:, :, :4]
    new_game_state[:, :, 8:12] = game_state[:, :, 4:8]
    
    # update player
    new_game_state[:, :, 12] = 1 - game_state[:, :, 12]
    
    return new_game_state