In [234]:
import h5py
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import time
import torch
from uuid import uuid4

print(f"cuda available: {torch.cuda.is_available()}")
print(f"torch version: {torch.__version__}")
print(f"numpy version: {np.__version__}")

cuda available: False
torch version: 2.1.0
numpy version: 1.26.0


# Hexz

A board is represented by an (N, 10, 11) numpy array. Each 10x11 slice is a one-hot encoding of the presence of specific type of piece/obstacle/etc. The slices are:

* 0: flags by P1
* 1: cell value 1-5 for P1
* 2: cells blocked for P1 (any occupied cell or a cell next to a 5)
* 3: next value for P1 
* 4: flags by P2
* 5: cell value 1-5 for P2
* 6: cells blocked for P2
* 7: next value for P2
* 8: grass cells with value 1-5

An action is specified by a (2, 10, 11) numpy array. The first 10x11 slice represents a flag move,
the second one represents a regular cell move. A flag move must have a single 1 set, a normal move
must have a single value 1-5 set.

In [122]:
def valid_idx(r_c):
    """Returns True if (r, c) = r_c represents a valid hexz board index."""
    r, c = r_c
    return r >= 0 and r < 11 and c >= 0 and c < 10-r%2


def neighbors_map():
    """Returns a dict mapping all valid (r, c) indices to their neighbor indices.
    
    The neighbor indices are represented as (row, column) tuples."""
    result = {}
    for r in range(11):
        shift = r%2 # Depending on the row, neighbors below and above are shifted.
        for c in range(10-r%2):
            ns = filter(valid_idx, [
                (r, c+1),
                (r-1, c+shift),
                (r-1, c-1+shift),
                (r, c-1),
                (r+1, c-1+shift),
                (r+1, c+shift),
            ])
            nr, nc = zip(*ns)  # unzip
            result[(r, c)] = (np.array(nr), np.array(nc))
            
    return result


In [124]:
class Board:
    """Numpy representation of a hexz board."""

    # Used to quickly get the indices of neighbor cells.
    neighbors = neighbors_map()
    
    def __init__(self, other=None):
        """Generates a new randomly initialized board or returns a copy of other, if set."""
        if other:
            self.b = other.b.copy()
            self.nflags = list(other.nflags)
            return
        self.b = np.zeros((9, 11, 10))
        self.nflags = [3, 3]  # number of flags remaining per player
        # Even rows have 10 cells, odd rows only 9, so mark the last cell in odd rows as blocked for P1+P2.
        self.b[2, [1, 3, 5, 7, 9], 9] = 1
        self.b[6, [1, 3, 5, 7, 9], 9] = 1        
        # 2-tuple of valid indices in each slice.
        free_cells = (1 - self.b[2]).nonzero()
        # 15 randomly placed stones.
        rng = np.random.default_rng()
        stones = rng.choice(np.arange(0, len(free_cells[0])), replace=False, size=15)
        self.b[2, free_cells[0][stones], free_cells[1][stones]] = 1
        self.b[6, free_cells[0][stones], free_cells[1][stones]] = 1
        free_cells = (1 - self.b[2]).nonzero()
        # 5 grass cells
        grass = rng.choice(np.arange(0, len(free_cells[0])), replace=False, size=5)
        self.b[8, free_cells[0][grass], free_cells[1][grass]] = [1, 2, 3, 4, 5]
        self.b[2, free_cells[0][grass], free_cells[1][grass]] = 1
        self.b[6, free_cells[0][grass], free_cells[1][grass]] = 1
    
    # Helpers to retrieve slices of the board "by name".
    def flags(self, player):
        return self.b[0 + player * 4]
    def values(self, player):
        return self.b[1 + player * 4]
    def blocked(self, player):
        return self.b[2 + player * 4]
    def next_values(self, player):
        return self.b[3 + player * 4]
    def grass(self):
        return self.b[8]
    def quickview(self):
        return (self.b[0] * 8) + self.b[1] - (self.b[4] * 8) - self.b[5]
        
    def score(self):
        """Returns the current score as a 2-tuple."""
        return (self.b[1].sum(), self.b[5].sum())
    
    def result(self):
        """Returns the final result of the board.
        
        1 (player 0 won), 0 (draw), -1 (player 1 won).
        """
        p0, p1 = self.score()
        return np.sign(p0 - p1)
        
    def make_move(self, player, move):
        """Makes the given move.
        
        Args:
          player: 0 or 1
          move: a 4-tuple of (typ, r, c, val), where typ = 0 (flag) or 1 (normal)
        Does not check that it is a valid move. Should be called only
        with moves returned from `next_moves`.
        """
        typ, r, c, val = move
        self.b[typ + player*4, r, c] = val
        played_flag = typ == 0
        # Block played cell for both players.
        self.b[2, r, c] = 1
        self.b[6, r, c] = 1
        # Set next value to 0 for occupied cell.
        self.b[3, r, c] = 0
        self.b[7, r, c] = 0
        # Block neighboring cells if a 5 was played.
        nx, ny = Board.neighbors[(r, c)]
        # Update next value of neighboring cells. If we played a flag, the next value is 1.
        if played_flag:
            next_val = 1
            self.nflags[player] -= 1
        else:
            next_val = val + 1
        if next_val <= 5:
            for nr, nc in zip(nx, ny):
                if self.b[2 + player*4, nr, nc] == 0:
                    if next_val > 5:
                        self.b[3 + player*4, nr, nc] = 0
                    if self.b[3 + player*4, nr, nc] == 0:
                        self.b[3 + player*4, nr, nc] = next_val
                    elif self.b[3 + player*4, nr, nc] > next_val:
                        self.b[3 + player*4, nr, nc] = next_val
        else:
            # Played a 5: block neighboring cells and clear next value.
            self.b[2 + player*4, nx, ny] = 1
            self.b[3 + player*4, nx, ny] = 0  # Clear next value.

        # Occupy neighboring grass cells.
        if not played_flag:
            self.occupy_grass(player, r, c)
            
    def occupy_grass(self, player, r, c):
        """Occupies the neighboring grass cells of move_idx (a 3-tuple index into a move) for player.
        
        Expects that the move has already been played.
        """
        nx, ny = Board.neighbors[(r, c)]
        for i, j in zip(nx, ny):
            grass_val = self.b[8, i, j]
            if grass_val > 0 and grass_val <= self.b[1 + player*4, r, c]:
                # Occupy: first remove grass
                self.b[8, i, j] = 0                
                # the rest is exactly like playing a move.
                self.make_move(player, (1, r, c, grass_val))
        
    def next_moves(self, player):
        """Returns all possible next moves.
        
        A move is represented as a (2, 11, 10) ndarray. The first slice represents
        flag moves, the second one represents normal moves. A flag move will have exactly
        one element set to 1 in slice 0. A normal move will have exactly one element set to
        1-5 in slice 1.
        """
        moves = []
        # Do we have unoccupied cells and flags left? Then we can place another one.
        if self.nflags[player] > 0:
            # Flag in any unoccupied cell.
            rs, cs = np.nonzero(self.b[2 + player*4] == 0)  # funky way to get indices for all free cells.
            moves.extend((0, r, c, 1) for r, c in zip(rs, cs))
        # Collect all cells with a non-zero next value.
        rs, cs = np.nonzero(self.b[3 + player*4])
        moves.extend((1, r, c, self.b[3 + player*4, r, c]) for r, c in zip(rs, cs))
        return moves


In [125]:
%%timeit
rng = np.random.default_rng()
b = Board()
player = 0
moves = b.next_moves(player)
num_moves = 0
while moves:
    b.make_move(player, moves[rng.integers(0, len(moves))])
    num_moves += 1
    player = 1 - player
    moves = b.next_moves(player)
    if not moves:
        # No more moves for the player. See if the other player can continue.
        player = 1 - player
        moves = b.next_moves(player)
# print(f"Done after {num_moves} moves. Flags left: {b.nflags}. Score: {b.score()}")
# b.quickview()

761 µs ± 4.37 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [259]:
class Node:
    """Nodes of the MCTS tree."""
    def __init__(self, parent, player, move):
        self.parent = parent
        self.player = player
        self.move = move
        self.wins = 0.0
        self.visit_count = 0
        self.children = []
    
    def uct(self):
        if self.parent.visit_count == 0:
            return 0.5
        if self.visit_count == 0:
            win_rate = 0.5
            adjusted_count = 1
        else:
            win_rate = self.wins / self.visit_count
            adjusted_count = self.visit_count
        return (
            win_rate + 
            1.0 * np.sqrt(np.log(self.parent.visit_count) / adjusted_count)
        )

    def __str__(self):
        return f"Node(p={self.player}, m={self.move}, w={self.wins/self.visit_count:.3f}, n={self.visit_count}, u={self.uct():.3f}, cs={len(self.children)})"

    def __repr__(self):
        return str(self)
    
    def move_likelihoods(self):
        """Returns the move likelihoods for all children as a (2, 11, 10) ndarray.
        
        The ndarray indicates the likelihoods (based on visit count) for flags
        and normal moves. It sums to 1.
        """
        p = np.zeros((2, 11, 10))
        for child in self.children:
            typ, r, c, _ = child.move
            p[typ, r, c] = child.visit_count
        return p / p.sum()
    
    def best_child(self):
        """Returns the best among all children.
        
        The best child is the one with the greatest visit count, a common
        choice in the MCTS literature.
        """
        return max(self.children, default=None, key=lambda n: n.visit_count)
    

class Example:
    """Data holder for one step of a fully played MCTS game."""
    
    def __init__(self, game_id, board, move_probs, turn, result):
        """
        Args:
            game_id: arbitrary string identifying the game that this example belongs to.
            board: the board (Board.b) as an (N, 11, 10) ndarray.
            move_probs: (2, 11, 10) ndarray of move likelihoods.
            turn: 0 or 1, indicating which player's turn it was.
            result: -1, 0, 1, indicating the player that won (1 = player 0 won).
        """
        self.game_id = game_id
        self.board = board
        self.move_probs = move_probs
        self.turn = turn
        self.result = result
    
    @classmethod
    def save_all(self, path, examples):
        """Saves the examples in a HDF5 file at the given path."""
        with h5py.File(path, "w") as h5:
            for i, ex in enumerate(examples):
                grp = h5.create_group(f"{i:08}")
                grp.attrs["game_id"] = ex.game_id
                grp.create_dataset("board", data=ex.board)
                grp.create_dataset("move_probs", data=ex.move_probs)
                grp.create_dataset("turn", data=np.array([ex.turn]))
                grp.create_dataset("result", data=np.array([ex.result]))
    
    @classmethod
    def load_all(cls, path):
        """This method is for testing only.
        
        Use a HexzDataset to access the examples from PyTorch."""
        examples = []
        with h5py.File(path, "r") as h5:
            for n in h5:
                ex = h5[n]
                turn = ex["turn"][0]
                result = ex["result"][0]
                examples.append(
                    Example(ex.attrs["game_id"], ex["board"][:], ex["move_probs"][:], turn, result)
                )
        return examples
        
        
class MCTS:
    """Monte Carlo tree search."""
    def __init__(self, board, game_id=None):
        self.board = board
        # In the root node it's player 1's "fake" turn. 
        # This has the desired effect that the root's children will play
        # as player 0, who makes the first move.
        self.root = Node(None, 1, None)
        self.rng = np.random.default_rng()
        if not game_id:
            game_id = uuid4().hex[:12]
        self.game_id = game_id
        
    def rollout(self, board, player):
        """Play a random game till the end, starting with board and player on the move."""
        while True:
            moves = board.next_moves(player)
            if not moves:
                # No more moves for the player. See if the other player can continue.
                player = 1 - player
                moves = board.next_moves(player)
            if not moves:
                return board.result()
            board.make_move(player, moves[self.rng.integers(0, len(moves))])
            player = 1 - player

    def backpropagate(self, node, result):
        while node:
            node.visit_count += 1
            if node.player == 0:
                node.wins += (result + 1) / 2
            else:
                node.wins += (-result + 1) / 2
            node = node.parent

    def size(self):
        q = [self.root]
        s = 0
        while q:
            n = q.pop()
            s += 1
            q.extend(n.children)
        return s
        
    def run(self):
        """Runs a single round of the MCTS loop."""
        b = Board(self.board)
        n = self.root
        # Find leaf node.
        while n.children:
            best = None
            best_uct = -1
            for c in n.children:
                c_uct = c.uct()
                if c_uct > best_uct:
                    best = c
                    best_uct = c_uct
            b.make_move(best.player, best.move)
            n = best
        # Reached a leaf node: expand
        player = 1 - n.player  # Usually it's the other player's turn.
        moves = b.next_moves(player)
        if not moves:
            # No more moves for player. Try other player.
            player = 1 - player
            moves = b.next_moves(player)
        if not moves:
            # Game is over
            self.backpropagate(n, b.result())
            return
        # Rollout time!
        for move in moves:
            n.children.append(Node(n, player, move))
        c = n.children[self.rng.integers(0, len(n.children))]
        b.make_move(c.player, c.move)
        result = self.rollout(b, 1 - c.player)
        self.backpropagate(c, result)
        
    def play_game(self, runs_per_move=500):
        """Plays one full game and returns the move likelihoods per move and the final result.
        
        Args:
            runs_per_move: number of MCTS runs to make per move.
        """
        examples = []
        result = None
        n = 0
        started = time.perf_counter()
        while n < 200:
            for i in range(runs_per_move):
                self.run()
            best_child = self.root.best_child()
            if not best_child:
                # Game over
                result = self.board.result()
                break
            examples.append(Example(self.game_id, self.board.b.copy(), 
                                    self.root.move_likelihoods(), best_child.player, None))
            # Make the move.
            self.board.make_move(best_child.player, best_child.move)
            self.root = best_child
            self.root.parent = None  # Allow GC and avoid backprop further up.
            if n < 5 or n%10 == 0:
                print(f"Iteration {n}: visit_count:{best_child.visit_count} ",
                      f"move:{best_child.move} player:{best_child.player} score:{self.board.score()}")
            n += 1
        if n == 200:
            raise ValueError(f"Iterated {n} times. Something's fishy.")
        elapsed = time.perf_counter() - started
        print(f"Done in {elapsed:.3f}s after {n} moves. Final score: {self.board.score()}.")
        # Update all examples with result.
        for ex in examples:
            ex.result = result
        return examples


In [139]:
%%timeit
b = Board()
m = MCTS(b)
for _ in range(1000):
    m.run()
# m.root.children

877 ms ± 7.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [251]:
m = MCTS(Board())
xs = m.play_game(1000)

Iteration 0: visit_count:35  move:(0, 6, 3, 1) player:0 score:(0.0, 0.0)
Iteration 10: visit_count:259  move:(1, 9, 4, 3.0) player:0 score:(6.0, 3.0)
Iteration 20: visit_count:136  move:(1, 1, 5, 1.0) player:0 score:(16.0, 12.0)
Iteration 30: visit_count:120  move:(1, 1, 7, 3.0) player:0 score:(31.0, 24.0)
Iteration 40: visit_count:82  move:(1, 2, 8, 4.0) player:0 score:(48.0, 35.0)
Iteration 50: visit_count:59  move:(1, 0, 1, 5.0) player:0 score:(63.0, 55.0)
Iteration 60: visit_count:111  move:(1, 1, 4, 1.0) player:0 score:(78.0, 69.0)
Iteration 70: visit_count:202  move:(1, 8, 0, 4.0) player:0 score:(93.0, 84.0)
Done in 42.563s after 80 moves. Final score: (120.0, 84.0).


In [256]:
Example.save_all("funny.h5", xs)

In [261]:
Example.load_all("funny.h5")[0].game_id

'c4728ea5-a46c-4ddd-ae27-445db3bd5db2'

In [269]:
class HexzDataset(torch.utils.data.Dataset):
    def __init__(self, path):
        self.h5 = h5py.File(path, "r")
        self.length = len(self.h5)
        self.keys = list(self.h5.keys())
        
    def __getitem__(self, k):
        data = self.h5[self.keys[k]]
        # turn = data["turn"][0]
        board = torch.from_numpy(data["board"][:])
        result = torch.from_numpy(data["result"][:])
        # TODO: how can we implement the "multi-head" network here? We have two labels.
        # move_probs = torch.from_numpy(data["move_probs"][:])
        return (board, result)

    def __len__(self):
        return self.length
        

In [267]:
ds = HexzDataset("funny.h5")

In [268]:
ds[0]

NameError: name 'h5' is not defined