In [62]:
import chex
import jax.numpy as jnp
from jax.scipy.signal import correlate2d
import jax
from enum import Enum
from typing import List, Tuple
from jax import lax
from dataclasses import field, dataclass

In [2]:
GOMOKU_BOARD_SIZE = 20


class Player(Enum):
    PLAYER_1 = 1
    PLAYER_2 = 2

    # A methord to get the other
    @classmethod
    def other(cls, player: int):
        return lax.cond(
            player == cls.PLAYER_1.value,
            lambda _: cls.PLAYER_2.value,
            lambda _: cls.PLAYER_1.value,
            None,
        )


@chex.dataclass(frozen=True)
class Board:
    """
    This holds the current state of the board which is a 2D array of size GOMOKU_BOARD_SIZE x GOMOKU_BOARD_SIZE.
    Moves of the players are represented by 1 and 2 respectively.
    """

    # 0 for empty, 1 for player 1, 2 for player 2.
    state: chex.ArrayDevice = field(
        default_factory=lambda: jnp.zeros(
            (GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE), dtype=jnp.int32
        )
    )
    # Choosing not to retain enums as they are not stackable with jax.
    player: int = Player.PLAYER_1.value  # The player who is making the next move.
    # An equals method for testing.

    @classmethod
    def make(cls, state: chex.ArrayDevice, player: Player) -> "Board":
        """
        A helper function to create a board object.
        """
        return cls(state=state, player=player.value)

    @classmethod
    def expand(cls, board: "Board", num=1) -> "Board":
        """
        A helper function to create a board object with an added dimension
        Helps in creating inputs for vmapped functions
        """
        return cls(
            state=jnp.stack([board.state] * num, axis=0),
            player=jnp.stack([board.player] * num, axis=0),
        )

    @classmethod
    def make_canonical(cls, state: chex.ArrayDevice) -> "Board":
        """
        A helper function to create a board object.
        """
        return cls(state=state, player=Player.PLAYER_1.value)

    def __eq__(self, other):
        return jnp.all(self.state == other.state) and self.player == other.player

    def canonical(self):
        """
        This function returns a board such that it's always the player 1's turn.
        Useful when running self-play style search algorithms
        """

        def flip():
            # flip the player values
            player_1_spots = self.state == Player.PLAYER_1.value
            player_2_spots = self.state == Player.PLAYER_2.value
            state = jnp.zeros(
                (GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE), dtype=self.state.dtype
            )
            state = jnp.where(player_1_spots, Player.PLAYER_2.value, state)
            state = jnp.where(player_2_spots, Player.PLAYER_1.value, state)
            return Board.make(state=state, player=Player.PLAYER_1)

        return lax.cond(
            self.player == Player.PLAYER_1.value, lambda: self, lambda: flip()
        )

    def invalid_actions(self) -> chex.ArrayDevice:
        """
        This function returns an array of boolean values indicating which actions are invalid
        All moves which are invalid are marked as True
        """
        invalid_actions = jnp.where(self.state != 0, True, False)
        # Flatten the last two dimensions
        return invalid_actions.reshape(*invalid_actions.shape[:-2], -1)


@chex.dataclass(frozen=True)
class GameOutcome:
    """
    This holds the outcome of a game.
    """

    class WinType(Enum):
        ROW = 0
        COL = 1
        DIAG = 2
        ALT_DIAG = 3
        DRAW = 4

    winner: int = -1  # The winner of the game.
    row_loc: int = -1  # Row location of the winning move.
    col_loc: int = -1  # Column location of the winning move.
    win_type: int = -1  # Type of the winning move.


@chex.dataclass(frozen=True)
class StepOutput:
    """
    This holds the output of the step function.
    """

    board: Board
    reward: int
    done: bool
    # An equals method for testing.

    def __eq__(self, other):
        return (
            self.board == other.board
            and self.reward == other.reward
            and self.done == other.done
        )


class GomokuEnv:
    @classmethod
    def action_to_index(cls, action: int) -> (int, int):
        """
        This function takes an action and returns the corresponding 2D index.
        param action: An integer between 0 and 399.
        """
        row = action // GOMOKU_BOARD_SIZE
        col = action % GOMOKU_BOARD_SIZE
        return (row, col)

    @classmethod
    def kernels(cls, kernel_size=5) -> jnp.ndarray:
        """
        This function returns the kernels for checking if there is a winner.
        """
        row_kernel = jnp.zeros((kernel_size, kernel_size))
        # set the first row to 1.
        row_kernel = row_kernel.at[0, :].set(1)
        col_kernel = row_kernel.T
        diag_kernel = jnp.eye(kernel_size)
        alt_diag_kernel = jnp.fliplr(diag_kernel)
        # All kernels stacked together.
        kernels = jnp.stack([row_kernel, col_kernel, diag_kernel, alt_diag_kernel])
        return kernels

    @classmethod
    def _expanded_board(cls, board: Board) -> jnp.ndarray:
        """
        This function takes a board and returns an expanded version of it.
        The expanded board has one channel for each player.
        param board: A board object whose state needs to be expanded.
        """
        return jnp.stack(
            [board.state == Player.PLAYER_1.value, board.state == Player.PLAYER_2.value]
        ).astype(jnp.int32)

    @classmethod
    def _adjust_location(cls, location, kernel_size):
        """
        The win locations will indicate how many strides row-wise and column wise the kernel
        took to reach the winning move. We need to subtract kernel_size-1 from row and column for all
        the outputs except for the alternate diagonal kernel whose offset is just kernel_size-1 for the
        row index and 0 for the column index.
        """
        kernel_index = location[0]
        row = location[1]
        col = location[2]

        # Apply conditional logic
        col_adjuster = lambda: jax.lax.cond(
            kernel_index % 4 == 3,
            lambda _: kernel_size
            - 1,  # Revert the column index for the alternate diagonal kernel
            lambda _: 0,
            operand=None,
        )

        return jax.lax.cond(
            kernel_index == -1,
            lambda _: jnp.array(
                [-1, -1, -1]
            ),  # If there is no winner, return -1 for all values
            lambda _: jnp.array(
                [
                    kernel_index,
                    row - (kernel_size - 1),
                    col - (kernel_size - 1) + col_adjuster(),
                ]
            ),
            operand=None,
        )

    @classmethod
    def _win_locations(cls, board: Board) -> jnp.ndarray:
        """
        This function takes a board and finds the locations of the winning move.
        param board: A board object whose state needs to be expanded.
        returns: A 2D array of size Nx3 where N is the number of winning moves as computed by the location where the kernels were operated on.
        The three columns are the (kernel index per player, row index and column index).
        0-3 kernel index corresponds to player 1 and 4-8 kernel index corresponds to player 2.
        It returns an array of size (GOMOKU_BOARD_SIZE*GOMOKU_BOARD_SIZE, 3)
        """
        kernels = cls.kernels()
        board_expanded = cls._expanded_board(board)
        # valid mode of convolution is a mode that does not extend the board.
        kernel_outputs = jnp.array(
            [
                correlate2d(board_of_player, kernel, mode="full")
                for board_of_player in board_expanded
                for kernel in kernels
            ]
        )
        # Check if there is a winner using argwhere for the value is equal to the number of pieces in a row.
        locations = jnp.argwhere(
            kernel_outputs == kernels.shape[1],
            size=GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE,
            fill_value=-1,
        )

        # Vectorize the adjust_location function
        adjust_locations_vectorized = jax.vmap(cls._adjust_location, in_axes=(0, None))

        # Use the vectorized function on all locations
        adjusted_locations = adjust_locations_vectorized(locations, kernels.shape[1])

        return adjusted_locations

    @classmethod
    def assert_valid(cls, board: Board) -> bool:
        """
        A helper function to check if a board is valid. This checks the following:
        1. The board is a 2D array of size 20x20.
        2. The board has only 0, Player.PLAYER_1.value and Player.PLAYER_2.value
        3. It does not have multiple winners.
        4. Each player played the same number of moves or one more move.
        Raise a unique error if any of the above conditions are not satisfied.
        """
        # Check the shape of the board.
        if board.state.shape != (GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE):
            raise ValueError(
                f"The board state is not of shape ({GOMOKU_BOARD_SIZE}, {GOMOKU_BOARD_SIZE})."
            )
        # Check that the board has only 0, 1 and 2.
        if not jnp.all(
            jnp.isin(
                board.state,
                jnp.array(
                    [0, Player.PLAYER_1.value, Player.PLAYER_2.value], dtype=jnp.int32
                ),
            )
        ):
            raise ValueError(
                f"The board state has values other than 0, {Player.PLAYER_1.value} and {Player.PLAYER_2.value}."
            )
        # Check that the board does not have multiple winners.
        if jnp.count_nonzero(cls._win_locations(board)[:, 0] != -1) > 1:
            raise ValueError("The board has multiple winners.")
        # Check that the number of moves played by each player is the same or one more.
        if jnp.sum(board.state == Player.PLAYER_1.value) - jnp.sum(
            board.state == Player.PLAYER_2.value
        ) not in [0, 1]:
            raise ValueError(
                f"The number of moves played by each player is not the same or one more."
                f"Player 1: {jnp.sum(board.state == Player.PLAYER_1.value)},"
                f"Player 2: {jnp.sum(board.state == Player.PLAYER_2.value)}"
            )
        return True

    @classmethod
    def outcome(cls, board: Board) -> GameOutcome:
        """
        Determine the outcome of the game given the current board state.
        """

        def has_winner(locations):
            win_type_code = locations[0, 0]
            player = jax.lax.cond(
                win_type_code < 4,
                lambda _: Player.PLAYER_1.value,
                lambda _: Player.PLAYER_2.value,
                None,
            )
            win_type = win_type_code % 4  # 0-3 for player 1 and 4-7 for player 2
            row_loc = locations[0, 1]
            col_loc = locations[0, 2]
            return GameOutcome(
                winner=player, row_loc=row_loc, col_loc=col_loc, win_type=win_type
            )

        def no_winner(_):
            # If the board is full, it's a draw.
            draw = jnp.all(board.state != 0)
            return GameOutcome(
                winner=-1,
                row_loc=-1,
                col_loc=-1,
                win_type=jax.lax.cond(
                    draw, lambda _: GameOutcome.WinType.DRAW.value, lambda _: -1, None
                ),
            )

        locations = cls._win_locations(board)

        # Use lax.cond to determine if there is a winner or not.
        # If there is a winner, then at least one of the locations will be non -1.
        return lax.cond(
            (locations.all() != -1) & (jnp.count_nonzero(locations[:, 0] != -1) == 1),
            lambda locations: has_winner(locations),
            lambda locations: no_winner(locations),
            locations,
        )

    @classmethod
    def step(cls, board: Board, action: int, player: int) -> StepOutput:
        """
        This function takes a board state and an action and returns the new board state.
        This function also checks if the game is over, and returns the reward.
        param board: A valid board object.
        param action: An integer between 0 and 399.
        param player: The player who is making the move.
        """

        def game_over(_):
            return StepOutput(board=board, reward=-1, done=True)

        def valid_action(_):
            row, col = cls.action_to_index(action)
            return lax.cond(
                (board.state[row, col] != 0) | (player != board.player),
                game_over,
                update_state,
                None,
            )

        def update_state(_):
            row, col = cls.action_to_index(action)
            new_state = Board(
                state=board.state.at[row, col].set(player), player=Player.other(player)
            )
            outcome = cls.outcome(new_state)
            reward = lax.cond(
                outcome.winner == Player.PLAYER_1.value,
                lambda _: 1,
                lambda _: lax.cond(
                    outcome.winner == Player.PLAYER_2.value,
                    lambda _: -1,
                    lambda _: 0,
                    None,
                ),
                None,
            )
            return StepOutput(
                board=new_state, reward=reward, done=outcome.win_type != -1
            )

        # Check if the action is out of range or the game is already over.
        return lax.cond(
            (action < 0)
            | (action >= GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE)
            | (cls.outcome(board).win_type != -1),
            game_over,
            valid_action,
            None,
        )

In [3]:
# A class to produce gomoku scenarios easily for unit testing.


@chex.dataclass(frozen=True)
class GomokuStroke:
    class StrokeType(Enum):
        ROW = 0
        COL = 1
        DIAG = 2
        ALT_DIAG = 3

    """Describes a stroke to paint on the board."""
    player: int
    row: int
    col: int
    stroke_type: StrokeType
    length: int


class GomokuScenario:
    """
    This class is used to produce scenarios for unit testing. Give it a list of strokes and it will produce a board.
    """

    @classmethod
    def make_scenario(cls, strokes: List[GomokuStroke], end_player: Player):
        """
        This function takes a list of strokes and produces a board.
        The player field in the returned board is set to the end_player.
        """
        state = jnp.zeros((GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE), dtype=jnp.int32)
        for stroke in strokes:
            kernels = GomokuEnv.kernels(kernel_size=stroke.length)
            kernel = kernels[stroke.stroke_type.value]
            starting_column = (
                stroke.col
                if stroke.stroke_type != GomokuStroke.StrokeType.ALT_DIAG
                else stroke.col - stroke.length + 1
            )
            state = state.at[
                stroke.row : stroke.row + stroke.length,
                starting_column : starting_column + stroke.length,
            ].set(
                jnp.where(
                    kernel == 1,
                    stroke.player.value,
                    state[
                        stroke.row : stroke.row + stroke.length,
                        starting_column : starting_column + stroke.length,
                    ],
                )
            )
        return Board.make(state=state, player=end_player)

# Unit Tests


In [4]:
import unittest
import numpy as np

In [5]:
# Example usage and unit test
class TestGomokuScenarioMaker(unittest.TestCase):
    def test_scenario_with_all_stroke_types(self):
        strokes = [
            GomokuStroke(
                player=Player.PLAYER_1,
                row=0,
                col=0,
                stroke_type=GomokuStroke.StrokeType.ROW,
                length=5,
            ),
            GomokuStroke(
                player=Player.PLAYER_2,
                row=1,
                col=0,
                stroke_type=GomokuStroke.StrokeType.COL,
                length=2,
            ),
            GomokuStroke(
                player=Player.PLAYER_1,
                row=3,
                col=0,
                stroke_type=GomokuStroke.StrokeType.DIAG,
                length=3,
            ),
            GomokuStroke(
                player=Player.PLAYER_2,
                row=4,
                col=4,
                stroke_type=GomokuStroke.StrokeType.ALT_DIAG,
                length=5,
            ),
        ]
        scenario = GomokuScenario.make_scenario(strokes, Player.PLAYER_1)
        expected_array = jnp.array(
            [
                [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )

        np.testing.assert_array_equal(scenario.state, expected_array)
        self.assertEqual(scenario.player, Player.PLAYER_1.value)

    def test_scenario_with_one_alt_diag(self):
        scenario = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=6,
                    stroke_type=GomokuStroke.StrokeType.ALT_DIAG,
                    length=5,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=6,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        )

        expected_first_12_by_12 = jnp.array(
            [
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )

        np.testing.assert_array_equal(scenario.state[:12, :12], expected_first_12_by_12)


unittest.main(argv=["-k", "TestGomokuScenarioMaker"], verbosity=2, exit=False)

test_scenario_with_all_stroke_types (__main__.TestGomokuScenarioMaker.test_scenario_with_all_stroke_types) ... ok
test_scenario_with_one_alt_diag (__main__.TestGomokuScenarioMaker.test_scenario_with_one_alt_diag) ... ok

----------------------------------------------------------------------
Ran 2 tests in 0.834s

OK


<unittest.main.TestProgram at 0x1529e7110>

In [6]:
GOMOKU_BOARD_SIZE = 20


class TestGomokuEnv(unittest.TestCase):
    # Test the __action_to_index function.
    def test_action_to_index(self):
        self.assertEqual(GomokuEnv.action_to_index(0), (0, 0))
        self.assertEqual(GomokuEnv.action_to_index(399), (19, 19))
        self.assertEqual(GomokuEnv.action_to_index(200), (10, 0))
        self.assertEqual(GomokuEnv.action_to_index(399), (19, 19))

    # Test that the kernels are correct.
    def test_kernels(self):
        kernels = GomokuEnv.kernels()
        self.assertEqual(kernels.shape, (4, 5, 5))
        self.assertTrue(
            jnp.all(
                kernels[0]
                == jnp.array(
                    [
                        [1, 1, 1, 1, 1],
                        [0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0],
                    ]
                )
            )
        )
        self.assertTrue(
            jnp.all(
                kernels[1]
                == jnp.array(
                    [
                        [1, 0, 0, 0, 0],
                        [1, 0, 0, 0, 0],
                        [1, 0, 0, 0, 0],
                        [1, 0, 0, 0, 0],
                        [1, 0, 0, 0, 0],
                    ]
                )
            )
        )
        self.assertTrue(
            jnp.all(
                kernels[2]
                == jnp.array(
                    [
                        [1, 0, 0, 0, 0],
                        [0, 1, 0, 0, 0],
                        [0, 0, 1, 0, 0],
                        [0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 1],
                    ]
                )
            )
        )
        self.assertTrue(
            jnp.all(
                kernels[3]
                == jnp.array(
                    [
                        [0, 0, 0, 0, 1],
                        [0, 0, 0, 1, 0],
                        [0, 0, 1, 0, 0],
                        [0, 1, 0, 0, 0],
                        [1, 0, 0, 0, 0],
                    ]
                )
            )
        )

    # Create a few boards that are invalid and check that assert_valid raises an error.
    def test_assert_invalid(self):
        # Test that the board is of the right shape.
        with self.assertRaises(ValueError):
            GomokuEnv.assert_valid(Board(state=jnp.zeros((10, 10), dtype=jnp.int32)))
        # Test that the board has only 0, 1 and 2.
        with self.assertRaises(ValueError):
            GomokuEnv.assert_valid(
                Board(state=312 * jnp.ones((20, 20), dtype=jnp.int32))
            )
        # Test that the board has only one winner.
        with self.assertRaises(ValueError):
            GomokuEnv.assert_valid(
                GomokuScenario.make_scenario(
                    [
                        GomokuStroke(
                            player=Player.PLAYER_1,
                            row=0,
                            col=0,
                            stroke_type=GomokuStroke.StrokeType.ROW,
                            length=5,
                        ),
                        GomokuStroke(
                            player=Player.PLAYER_2,
                            row=1,
                            col=0,
                            stroke_type=GomokuStroke.StrokeType.COL,
                            length=5,
                        ),
                    ],
                    Player.PLAYER_1,
                )
            )  # This has two winners.
        # Test that the number of moves played by each player is the same or one more.
        with self.assertRaises(ValueError):
            GomokuEnv.assert_valid(
                GomokuScenario.make_scenario(
                    [
                        GomokuStroke(
                            player=Player.PLAYER_1,
                            row=0,
                            col=0,
                            stroke_type=GomokuStroke.StrokeType.ROW,
                            length=4,
                        ),
                        GomokuStroke(
                            player=Player.PLAYER_2,
                            row=1,
                            col=0,
                            stroke_type=GomokuStroke.StrokeType.COL,
                            length=3,
                        ),
                        GomokuStroke(
                            player=Player.PLAYER_1,
                            row=2,
                            col=0,
                            stroke_type=GomokuStroke.StrokeType.COL,
                            length=2,
                        ),
                    ],
                    Player.PLAYER_1,
                )
            )

    def test_expand(self):
        board = Board()
        expanded_board = Board.expand(board)
        self.assertEqual(
            expanded_board.state.shape, (1, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE)
        )
        self.assertEqual(expanded_board.player.shape, (1,))

    # Create a few boards that are valid and check that assert_valid does not raise an error.
    def test_assert_valid(self):
        GomokuEnv.assert_valid(
            GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=0,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.ROW,
                        length=5,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=1,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=4,
                    ),
                ],
                Player.PLAYER_1,
            )
        )
        GomokuEnv.assert_valid(
            GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=0,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.ROW,
                        length=4,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=1,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=4,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=6,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=1,
                    ),
                ],
                Player.PLAYER_1,
            )
        )

    # Test that we can asses a winning outcome and correctly locate it in an alternate diagonal scenario
    def test_winning_outcome_alt_diag(self):
        outcome = GomokuEnv.outcome(
            GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=0,
                        col=6,
                        stroke_type=GomokuStroke.StrokeType.ALT_DIAG,
                        length=5,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=6,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=4,
                    ),
                ],
                Player.PLAYER_1,
            )
        )
        self.assertEqual(outcome.winner, Player.PLAYER_1.value)
        self.assertEqual(outcome.win_type, GameOutcome.WinType.ALT_DIAG.value)
        self.assertEqual(outcome.row_loc, 0)
        self.assertEqual(outcome.col_loc, 6)

    def test_winning_outcome_alt_diag_bottom_right_edge(self):
        outcome = GomokuEnv.outcome(
            GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=15,
                        col=15,
                        stroke_type=GomokuStroke.StrokeType.ALT_DIAG,
                        length=5,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=6,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=4,
                    ),
                ],
                Player.PLAYER_1,
            )
        )
        self.assertEqual(outcome.winner, Player.PLAYER_1.value)
        self.assertEqual(outcome.win_type, GameOutcome.WinType.ALT_DIAG.value)
        self.assertEqual(outcome.row_loc, 15)
        self.assertEqual(outcome.col_loc, 15)

    # Test that we can asses a winning outcome and correctly locate it in a row scenario
    def test_winning_outcome_row(self):
        outcome = GomokuEnv.outcome(
            GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=10,
                        col=10,
                        stroke_type=GomokuStroke.StrokeType.ROW,
                        length=5,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=6,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=4,
                    ),
                ],
                Player.PLAYER_1,
            )
        )
        self.assertEqual(outcome.winner, Player.PLAYER_1.value)
        self.assertEqual(outcome.win_type, GameOutcome.WinType.ROW.value)
        self.assertEqual(outcome.row_loc, 10)
        self.assertEqual(outcome.col_loc, 10)

    def test_winning_outcome_row_edge(self):
        outcome = GomokuEnv.outcome(
            GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=0,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.ROW,
                        length=5,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=6,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=4,
                    ),
                ],
                Player.PLAYER_1,
            )
        )
        self.assertEqual(outcome.winner, Player.PLAYER_1.value)
        self.assertEqual(outcome.win_type, GameOutcome.WinType.ROW.value)
        self.assertEqual(outcome.row_loc, 0)
        self.assertEqual(outcome.col_loc, 0)

    def test_winning_outcome_col_edge(self):
        outcome = GomokuEnv.outcome(
            GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=0,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=5,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=6,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=4,
                    ),
                ],
                Player.PLAYER_1,
            )
        )
        self.assertEqual(outcome.winner, Player.PLAYER_2.value)
        self.assertEqual(outcome.win_type, GameOutcome.WinType.COL.value)
        self.assertEqual(outcome.row_loc, 0)
        self.assertEqual(outcome.col_loc, 0)

    # Test that we can assess a ongoing game state
    def test_ongoing_outcome(self):
        # Create a random but valid game state
        outcome = GomokuEnv.outcome(
            GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=10,
                        col=10,
                        stroke_type=GomokuStroke.StrokeType.ROW,
                        length=3,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=6,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.COL,
                        length=4,
                    ),
                ],
                Player.PLAYER_1,
            )
        )
        self.assertEqual(outcome.winner, -1)
        self.assertEqual(outcome.win_type, -1)
        self.assertEqual(outcome.row_loc, -1)
        self.assertEqual(outcome.col_loc, -1)

    # Test that the step function works as expected in multiple scenarios
    def test_multiple_steps(self):
        # Create an empty board
        board = Board()
        output = GomokuEnv.step(board, 0, Player.PLAYER_1.value)
        expected_output = StepOutput(
            board=GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=0,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.ROW,
                        length=1,
                    ),
                ],
                Player.PLAYER_2,
            ),
            reward=0,
            done=False,
        )
        self.assertEqual(output, expected_output)
        # Now try adding another move at the same location and see that the state is done and reward is -1
        abrupt_end_output = GomokuEnv.step(output.board, 0, Player.PLAYER_2.value)
        self.assertEqual(abrupt_end_output.reward, -1)
        self.assertEqual(abrupt_end_output.done, True)
        # Now let's play the game a few more steps
        output = GomokuEnv.step(output.board, 100, Player.PLAYER_2.value)
        output = GomokuEnv.step(output.board, 1, Player.PLAYER_1.value)
        output = GomokuEnv.step(output.board, 101, Player.PLAYER_2.value)
        output = GomokuEnv.step(output.board, 2, Player.PLAYER_1.value)
        output = GomokuEnv.step(output.board, 102, Player.PLAYER_2.value)
        output = GomokuEnv.step(output.board, 3, Player.PLAYER_1.value)
        output = GomokuEnv.step(output.board, 103, Player.PLAYER_2.value)
        output = GomokuEnv.step(output.board, 4, Player.PLAYER_1.value)
        expected_final_output = StepOutput(
            board=GomokuScenario.make_scenario(
                [
                    GomokuStroke(
                        player=Player.PLAYER_1,
                        row=0,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.ROW,
                        length=5,
                    ),
                    GomokuStroke(
                        player=Player.PLAYER_2,
                        row=5,
                        col=0,
                        stroke_type=GomokuStroke.StrokeType.ROW,
                        length=4,
                    ),
                ],
                Player.PLAYER_2,
            ),
            reward=1,
            done=True,
        )
        self.assertEqual(output, expected_final_output)

    # Test that the canonical function flips the board correctly
    def test_canonical(self):
        canonical_board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=5,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        )
        board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=5,
                ),
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_2,
        )
        self.assertEqual(board.canonical(), canonical_board)

        # When we have a player 1 perspective canonical board is the same is the original board.
        self.assertEqual(canonical_board.canonical(), canonical_board)

    def test_reward_fn(self):
        # Create a scenario where player 1 wins and reward is 1
        scenario = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        )
        # Take a step
        output = GomokuEnv.step(scenario, 4, Player.PLAYER_1.value)
        # Check that the reward is 1
        self.assertEqual(output.reward, 1)
        # Create a scenario where player 2 wins and reward is -1
        scenario = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_2,
        )
        # Take a step
        output = GomokuEnv.step(scenario, 4, Player.PLAYER_2.value)
        # Check that the reward is -1
        self.assertEqual(output.reward, -1)


unittest.main(argv=["-k", "TestGomokuEnv"], verbosity=2, exit=False)

test_action_to_index (__main__.TestGomokuEnv.test_action_to_index) ... ok
test_assert_invalid (__main__.TestGomokuEnv.test_assert_invalid) ... ok
test_assert_valid (__main__.TestGomokuEnv.test_assert_valid) ... ok
test_canonical (__main__.TestGomokuEnv.test_canonical) ... ok
test_expand (__main__.TestGomokuEnv.test_expand) ... ok
test_kernels (__main__.TestGomokuEnv.test_kernels) ... ok
test_multiple_steps (__main__.TestGomokuEnv.test_multiple_steps) ... ok
test_ongoing_outcome (__main__.TestGomokuEnv.test_ongoing_outcome) ... ok
test_reward_fn (__main__.TestGomokuEnv.test_reward_fn) ... ok
test_winning_outcome_alt_diag (__main__.TestGomokuEnv.test_winning_outcome_alt_diag) ... ok
test_winning_outcome_alt_diag_bottom_right_edge (__main__.TestGomokuEnv.test_winning_outcome_alt_diag_bottom_right_edge) ... ok
test_winning_outcome_col_edge (__main__.TestGomokuEnv.test_winning_outcome_col_edge) ... ok
test_winning_outcome_row (__main__.TestGomokuEnv.test_winning_outcome_row) ... ok
test_win

<unittest.main.TestProgram at 0x154a4b790>

# Experimenting with MCTS

To use the `mctx` library we need to define the following components:

### The `recurrent_fn(params, rng_key, action, embedding)-> (RecurrentFnOutput, new_embedding)`

- params: A pytree to hold various parameters that might be used in the recurrent function (maybe learned components?)
- rng_key: A psuedo random generator seed which will be used in the MCTS search
- action: A jax array specifying which action to take
- empbedding: A jax array specifying the embedding of the current state

The function should return a tuple of an object of the type `RecurrentFnOutput` and a new embedding. It has the following fields:

- reward: A jax array which represents the reward for taking the action
- discount: A jax array which represents the discount which needs to be applied to the value when added to the reward
- prior_logits: A jax array with the unnormalized probabilities of the action space given the new state we are in after taking the action
- value: A jax array with the value of this particular root state

### The `RootFnOutput`

This is the representation of the root state from which the mcts will be run. Usually the fields of this struct would be filled with
what is computed from a representation neural network. The fields are:

- prior_logits: A jax array with the un-normalized probability values of taking any action of all possible actions in the action space
- value: A jax array with the estimated value of this particular root state
- embedding: A jax array specifying the embedding of the current state

### The call to mcts to get an informed action

```
policy_output = mctx.gumbel_muzero_policy(params, rng_key, root, recurrent_fn,
                                          num_simulations=32)
```

## A naive implementation

Let's create a simple MCTS agent which does not have any prior information or learned components so:

- All the prior logits are equally distributed since we don't know anything about the prior: `naive_prior(state: chex.Array)`
  - The naive_prior would be as simple as taking all the empty-cells and giving them equal values of (1) and non-empty cells with (0)
- For the estimated value we can implement a constant value function which will return a constant value of 0.5 `constant_value() -> 0.5`
- The embedding will just be a jax array of the 2D board
- The recurrent function should return a **canonical observation**. i.e. The board should always look like it's the first player's turn. This is important because `Gomoku`
  is a game where the "environment dynamics" is the other player. Since we don't have anothe player, we play against ourself and hope that this gives enough information about
  the underlying game dynamics and that any rational player with an intent to win would also play the same way.


In [7]:
import mctx
from abc import ABC, abstractmethod

In [8]:
# Abstract Base Class for a MuZero Model
class Model(ABC):
    """
    An abstract interface for a MuZero model which has value and prior calls
    This is agnostic of the implementation
    """

    @abstractmethod
    def value(self, params: chex.ArrayTree, rng_key: int, x, done) -> chex.Array:
        """
        Given the state and whether the episode is done
        give the value of the current state. The potential
        expected reward from here.
        """
        pass

    @abstractmethod
    def prior(self, params: chex.ArrayTree, rng_key: int, x) -> chex.Array:
        """
        Given the current state, provide the prior on all possible moves
        for the mctx policy to use in order to perform a search
        """
        pass

    @abstractmethod
    def prior_and_value(self, params, rng_key, x, done):
        """
        Return a tuple of prior and value. Useful if it can be reimplemented
        in an efficient way for usecases that require both
        """
        return (self.prior(params, rng_key, x), self.value(params, rng_key, x, done))

    def step(
        self, params: chex.ArrayTree, rng_key: int, action: chex.Array, embedding
    ) -> (mctx.RecurrentFnOutput, chex.Array):
        """A naive step call which calls the gomoku step call and returns a observation and a recurrent fn output"""
        board = embedding
        # step the board
        step_output = GomokuEnv.step(board, action, board.player)
        return (
            mctx.RecurrentFnOutput(
                reward=step_output.reward,
                discount=-1 * jnp.ones_like(step_output.reward),
                prior_logits=self.prior(params, rng_key, step_output.board.state),
                value=self.value(
                    params, rng_key, step_output.board.state, step_output.done
                ),
            ),
            step_output.board,
        )


class NaiveMcts(Model):
    def prior(self, params: chex.ArrayTree, rng_key: int, board: chex.Array):
        """
        This function returns a uniform prior over all legal moves.
        """
        return jnp.where(board == 0, 100, 0).astype(jnp.float32).flatten()

    def value(
        self, params: chex.ArrayTree, rng_key: int, board: chex.Array, done: bool
    ):
        """A naive constant value function for all states"""
        return lax.cond(done, lambda _: 0, lambda _: 1, None)

    def prior_and_value(self, params, rng_key, x, done):
        return super().prior_and_value(params, rng_key, x, done)

# Roadblock

At this point I realize that all of the above code is not written with vmap or jit in mind. The code functions like step need to be vmappable and dataclasses like Board, StepOutcome, Stroke need to be mappable for this to happen. So let's make sure `GomokuEnv` is completely "vmap"pable and "jit"able.

A quote I found online that is interesting and warrants a redesign of the Board, and how GomokuEnv works:

> vmap supports pytrees when written in the form of a "struct of arrays" rather than an "array of structs". This is because JAX/XLA only has support for numeric array dtypes, so the "array of struct" form can't be operated on efficiently in JAX.

The approach I took:

- First remove any use of optionals and enums in all dataclasses. They are not stackable and castable to jax arrays. At least not that I know of


In [9]:
from jax.tree_util import tree_flatten, tree_unflatten
from jax import tree_map


def pytrees_stack(pytrees, axis=0):
    results = tree_map(lambda *values: jnp.stack(values, axis=axis), *pytrees)
    return results


def show_example(structured):
    flat, tree = tree_flatten(structured)
    unflattened = tree_unflatten(tree, flat)
    print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")


class TestVmappedGomokuEnv(unittest.TestCase):
    # Test that the vmapped step function works as expected in multiple scenarios
    def test_multiple_boards(self):
        # Use the vmapped step function to step multiple boards (straght, diag and alt diag, empty)
        # In the first board player 1 wins, in the second board player 2 wins and in the third board the game is ongoing
        board1 = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        )
        expected_board1 = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=4,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=1,
                ),
            ],
            Player.PLAYER_2,
        )
        board2 = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=3,
                ),
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=8,
                    col=8,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=2,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_2,
        )
        expected_board2 = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=3,
                ),
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=8,
                    col=8,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=2,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=5,
                ),
            ],
            Player.PLAYER_1,
        )
        board3 = GomokuScenario.make_scenario([], Player.PLAYER_1)
        expected_board3 = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=1,
                ),
            ],
            Player.PLAYER_2,
        )
        boards = pytrees_stack([board1, board2, board3])
        actions = jnp.array([4, 100, 0])
        players = jnp.array(
            [Player.PLAYER_1.value, Player.PLAYER_2.value, Player.PLAYER_1.value]
        )
        vmapped_step = jax.vmap(GomokuEnv.step, in_axes=(0, 0, 0))
        output = vmapped_step(boards, actions, players)
        # Check that the output is correct
        expected_output = pytrees_stack(
            [
                StepOutput(board=expected_board1, reward=1, done=True),
                StepOutput(board=expected_board2, reward=-1, done=True),
                StepOutput(board=expected_board3, reward=0, done=False),
            ]
        )
        np.testing.assert_array_equal(output.board.state, expected_output.board.state)


unittest.main(argv=["-k", "TestVmappedGomokuEnv"], verbosity=2, exit=False)

test_multiple_boards (__main__.TestVmappedGomokuEnv.test_multiple_boards) ... ok

----------------------------------------------------------------------
Ran 1 test in 1.558s

OK


<unittest.main.TestProgram at 0x152a09210>

# Outcome of fixing Roadblock

I had to go back and redo a lot of the `GomokuEnv` class. The following lessons learnt:

- Don't use `if`, `else` conditionals in jax if you want to compile them in. Use `cond`, `select`, `switch` etc
- Don't use enums in your dataclasses. Jax cannot stack them
- Don't use asserts in your jax code. They will not be amenable to jax transformations

JAX conditionals require predicates that have the same type signatures for both true and false conditions. During tracing JAX _WILL_ evaluate both predicates just to check if everything will work as needed and to understand the code path in both directions for compiling. So if you have conditionals which change behavior based on input shape etc that is code-smell. Don't do that.

I faced this issue with the win_locations function whose output is of varied size which then get's consumed by a vmappable step function in some conditional. This could not be compiled. Thanfully, `argwhere` has a mode where it will always return a fixed size output. So using that helped solve that issue. I hope the gradient landscape is not screwed up by this lol. I don't think it would be.


In [10]:
import unittest


def recurrent_fn_output_equals(
    output: mctx.RecurrentFnOutput, expected_output: mctx.RecurrentFnOutput
):
    return (
        np.allclose(output.value, expected_output.value)
        and np.allclose(output.prior_logits, expected_output.prior_logits)
        and np.allclose(output.discount, expected_output.discount)
        and np.allclose(output.reward, expected_output.reward)
    )


class TestMctsHelpers(unittest.TestCase):
    def setUp(self) -> None:
        self.mcts = NaiveMcts()

    # Test that when given a board it returns a prior which is effectively the flatten on that board
    def test_prior(self):
        custom_board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=5,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        )
        rng_key = jax.random.PRNGKey(0)
        prior = self.mcts.prior((), rng_key, custom_board.state)
        expected_prior = jnp.where(custom_board.state == 0, 100.0, 0.0).flatten()
        np.testing.assert_array_equal(prior, expected_prior)

    # Test that step call through the recurrent function approach
    def test_step(self):
        # Create a random scenario first
        custom_board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        ).canonical()
        # Now call the naive step function
        output, board = self.mcts.step((), 4, 4, custom_board)
        expected_board = GomokuEnv.step(custom_board, 4, Player.PLAYER_1.value).board
        expected_logits = jnp.where(expected_board.state == 0, 100.0, 0.0).flatten()
        expected_output = mctx.RecurrentFnOutput(
            value=0, prior_logits=expected_logits, discount=-1, reward=1.0
        )  # action 4 is a winning move in this scenario
        self.assertTrue(recurrent_fn_output_equals(output, expected_output))
        self.assertTrue(np.allclose(expected_board.state, board.state))

    def test_step_vmaped(self):
        # Create a random scenario
        custom_board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        ).canonical()
        # Stack the two boards such that the batched dimension is the first dimension
        stacked_board_states = jnp.stack([custom_board.state, custom_board.state])
        boards = Board(
            state=stacked_board_states,
            player=jnp.array([Player.PLAYER_1.value, Player.PLAYER_1.value]),
        )
        stacked_actions = jnp.array([4, 7])
        self.assertTrue(
            stacked_board_states.shape == (2, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE)
        )
        # Transform the naive step function to a vmaped version
        vmapped_step = jax.vmap(self.mcts.step, in_axes=(None, None, 0, 0))
        # Call the vmapped step function
        output, board = vmapped_step((), 4, stacked_actions, boards)
        expected_board1 = GomokuEnv.step(custom_board, 4, Player.PLAYER_1.value).board
        expected_board2 = GomokuEnv.step(custom_board, 7, Player.PLAYER_1.value).board
        expected_rewards = jnp.array([1.0, 0.0])
        expected_logits1 = jnp.where(expected_board1.state == 0, 100.0, 0.0)
        expected_logits2 = jnp.where(expected_board2.state == 0, 100.0, 0.0)
        expected_output = mctx.RecurrentFnOutput(
            value=jnp.array([0, 1]),
            prior_logits=jnp.stack(
                [expected_logits1.flatten(), expected_logits2.flatten()]
            ),
            discount=jnp.array([-1, -1]),
            reward=expected_rewards,
        )
        self.assertTrue(recurrent_fn_output_equals(output, expected_output))
        self.assertTrue(np.allclose(expected_board1.state, board.state[0]))
        self.assertTrue(np.allclose(expected_board2.state, board.state[1]))


# just run the tests in the class TestNaiveMctsHelpers
unittest.main(argv=["-k", "TestMctsHelpers"], verbosity=2, exit=False)

test_prior (__main__.TestMctsHelpers.test_prior) ... ok
test_step (__main__.TestMctsHelpers.test_step) ... ok
test_step_vmaped (__main__.TestMctsHelpers.test_step_vmaped) ... ok

----------------------------------------------------------------------
Ran 3 tests in 3.640s

OK


<unittest.main.TestProgram at 0x29827a350>

In [11]:
from functools import partial

In [12]:
GOMOKU_BOARD_SIZE = 5


def fn():
    # Create a gomoku scenario with 4 in a row for both players
    scenario = GomokuScenario.make_scenario(
        [
            GomokuStroke(
                player=Player.PLAYER_1,
                row=0,
                col=0,
                stroke_type=GomokuStroke.StrokeType.ROW,
                length=4,
            ),
            GomokuStroke(
                player=Player.PLAYER_2,
                row=1,
                col=0,
                stroke_type=GomokuStroke.StrokeType.COL,
                length=4,
            ),
        ],
        Player.PLAYER_1,
    ).canonical()

    prior_logits = jnp.where(scenario.state == 0, 1.0, 0.0).flatten().reshape(1, -1)
    value = jnp.array([100]).reshape(
        1,
    )
    embedding = Board(
        state=scenario.state.reshape(1, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE),
        player=jnp.array([Player.PLAYER_1.value]),
    )

    root = mctx.RootFnOutput(
        prior_logits=prior_logits, value=value, embedding=embedding
    )
    naive_mcts = NaiveMcts()
    recurrent_fn = jax.jit(jax.vmap(naive_mcts.step, in_axes=(None, None, 0, 0)))
    rng_key = jax.random.PRNGKey(4)
    rng_key_1, rng_key_2 = jax.random.split(rng_key, 2)
    policy_output = mctx.muzero_policy(
        (), rng_key_1, root, recurrent_fn, num_simulations=3000
    )
    return policy_output


policy_output = fn()



In [13]:
print(f"The recommended action is {policy_output.action}. The winning move is 4!")
assert policy_output.action == 4

The recommended action is [4]. The winning move is 4!


> The mctx repo has a very nice function to visualize a mctx simulation. Dropping it in here to help visualize!
> This needs a pygraphviz dependency.

Note: I could not easily add pygraphviz through poetry. It's not self contained apparently. So this dependency is not currently reflected in the toml file :(


In [14]:
import pygraphviz
from typing import Optional, Sequence


def convert_tree_to_graph(
    tree: mctx.Tree, action_labels: Optional[Sequence[str]] = None, batch_index: int = 0
) -> pygraphviz.AGraph:
    """Converts a search tree into a Graphviz graph.

    Args:
      tree: A `Tree` containing a batch of search data.
      action_labels: Optional labels for edges, defaults to the action index.
      batch_index: Index of the batch element to plot.

    Returns:
      A Graphviz graph representation of `tree`.
    """
    chex.assert_rank(tree.node_values, 2)
    batch_size = tree.node_values.shape[0]
    if action_labels is None:
        action_labels = range(tree.num_actions)
    elif len(action_labels) != tree.num_actions:
        raise ValueError(
            f"action_labels {action_labels} has the wrong number of actions "
            f"({len(action_labels)}). "
            f"Expecting {tree.num_actions}."
        )

    def node_to_str(node_i, reward=0, discount=1):
        return (
            f"{node_i}\n"
            f"Reward: {reward:.2f}\n"
            f"Discount: {discount:.2f}\n"
            f"Value: {tree.node_values[batch_index, node_i]:.2f}\n"
            f"Visits: {tree.node_visits[batch_index, node_i]}\n"
        )

    def edge_to_str(node_i, a_i):
        node_index = jnp.full([batch_size], node_i)
        probs = jax.nn.softmax(tree.children_prior_logits[batch_index, node_i])
        return (
            f"{action_labels[a_i]}\n"
            f"Q: {tree.qvalues(node_index)[batch_index, a_i]:.2f}\n"  # pytype: disable=unsupported-operands  # always-use-return-annotations
            f"p: {probs[a_i]:.2f}\n"
        )

    graph = pygraphviz.AGraph(directed=True)

    # Add root
    graph.add_node(0, label=node_to_str(node_i=0), color="green")
    # Add all other nodes and connect them up.
    for node_i in range(tree.num_simulations):
        for a_i in range(tree.num_actions):
            # Index of children, or -1 if not expanded
            children_i = tree.children_index[batch_index, node_i, a_i]
            if children_i >= 0:
                graph.add_node(
                    children_i,
                    label=node_to_str(
                        node_i=children_i,
                        reward=tree.children_rewards[batch_index, node_i, a_i],
                        discount=tree.children_discounts[batch_index, node_i, a_i],
                    ),
                    color="red",
                )
                graph.add_edge(node_i, children_i, label=edge_to_str(node_i, a_i))

    return graph

Example code to generate pngs with mcts search tree

```python
graph = convert_tree_to_graph(policy_output.search_tree)
output_file = "tree.png"
print(f"Saving tree diagram to: {output_file}")
graph.draw(output_file, prog="dot")
```

Example code to print the mcts search

```python
tree = policy_output.search_tree
for node_i in range(tree.num_simulations):
    print(f"Node {node_i} has {tree.node_visits[0, node_i]} visits and value {tree.node_values[0, node_i]}")
    for a_i in range(tree.num_actions):
        children_i = tree.children_index[0, node_i, a_i]
        if children_i >= 0:
            print(f"Action {a_i} has {tree.children_visits[0, node_i, a_i]} visits and value {tree.children_rewards[0, node_i, a_i]} and goes to node {children_i}")
```


# Moving on to building a neural network piece

Although I don't fully understand the reward and discount engineering here, I want to move on to building the neural network piece. The mcts search seems to continue searching irrespective of a negative reward. This is fair since it does not have a notion of when a game is done etc. However this needs to be somehow represented in the reward/value functions somehow to make the mcts search more efficent.

Following this thread on why choosing a negative discount would help in a two-player self-play setting later on will help clarify further:

- https://stats.stackexchange.com/questions/304386/discount-factor-for-self-play-in-reinforcement-learning
- https://github.com/deepmind/mctx/issues/24#issuecomment-1193281828


In [15]:
from flax import linen as nn
import optax
import plotly.express as px
import pandas as pd

In [16]:
# Implement a simple MLP model to simulate the value function for a given state and prior logits for all actions (same size as the state)
class SimpleMuZeroModel(Model, nn.Module):
    @nn.compact
    def __call__(self, x, prior_logits):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        value = nn.Dense(features=1)(x)
        policy_logits = nn.Dense(features=GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE)(x)
        return value, policy_logits

    # Override the value function interface
    def value(self, params, rng_key, x, done):
        # WARNING: Fix this prior_logits thing
        value, _ = self.apply(params, x.flatten(), jnp.zeros_like(x.flatten()))
        return value.squeeze(-1)

    # Override the prior function interface
    def prior(self, params, rng_key, x):
        # WARNING: Fix this prior_logits thing
        _, policy_logits = self.apply(params, x.flatten(), jnp.zeros_like(x.flatten()))
        return policy_logits

    def prior_and_value(self, params, rng_key, x, done):
        value, policy_logits = self.apply(
            params,
            x.flatten(),
            jnp.zeros_like(x.flatten()),
        )
        return policy_logits, value.squeeze(-1)

    @classmethod
    def random_model(cls, board_size):
        model = SimpleMuZeroModel()
        # Create an interactive Module instance by binding variables and RNGs
        prior = jnp.ones(board_size * board_size)
        state = jnp.zeros(board_size * board_size)
        params = model.init(jax.random.PRNGKey(0), state, prior)
        return model, params

    def step(
        self, params: chex.ArrayTree, rng_key: int, action: chex.Array, embedding
    ) -> (mctx.RecurrentFnOutput, chex.Array):
        """A naive step call which calls the gomoku step call and returns a observation and a recurrent fn output"""
        board = embedding
        # step the board
        step_output = GomokuEnv.step(board, action, board.player)
        # Convert the board to a canonical observation
        prior_logits, value = self.prior_and_value(
            params, rng_key, step_output.board.canonical().state, step_output.done
        )
        return (
            mctx.RecurrentFnOutput(
                reward=step_output.reward,
                discount=-1 * jnp.ones_like(step_output.reward),
                prior_logits=prior_logits,
                value=value,
            ),
            step_output.board,
        )


# Define a loss function for the model a simple L2 Loss term for the value and a cross entropy loss for the policy
def loss_fn(value, policy_logits, target_value, target_policy_logits):
    value_loss = jnp.mean((value - target_value) ** 2)
    policy_loss = optax.softmax_cross_entropy(
        policy_logits, target_policy_logits
    ).mean()
    return value_loss + policy_loss

In [17]:
# Run the mcts helper unit tests for SimpleMuZeroModel


class TestSimpleMuzeroHelpers(unittest.TestCase):
    def setUp(self) -> None:
        self.mcts, self.params = SimpleMuZeroModel.random_model(GOMOKU_BOARD_SIZE)
        self.model = self.mcts.bind(self.params)
        self.rng = jax.random.PRNGKey(0)

    def test_prior(self):
        # Test that the shape of the prior is (batch_size, GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE)
        custom_board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        ).canonical()
        prior = self.model.prior(self.params, self.rng, custom_board.state)

    def test_value(self):
        # Test that the shape of the value is (batch_size, 1)
        custom_board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        ).canonical()
        value = self.model.value(self.params, self.rng, custom_board.state, False)
        self.assertTrue(value.shape == ())

    def test_step(self):
        custom_board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        ).canonical()
        # Now call the naive step function
        output, board = self.mcts.step(self.params, self.rng, 4, custom_board)
        # Assert that the shape of the board is correct in what is returned
        self.assertTrue(board.state.shape, (GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE))
        self.assertTrue(np.all(board.player == 2))

    def test_vmapped_step(self):
        # Create a random scenario
        custom_board = GomokuScenario.make_scenario(
            [
                GomokuStroke(
                    player=Player.PLAYER_1,
                    row=0,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.ROW,
                    length=4,
                ),
                GomokuStroke(
                    player=Player.PLAYER_2,
                    row=1,
                    col=0,
                    stroke_type=GomokuStroke.StrokeType.COL,
                    length=4,
                ),
            ],
            Player.PLAYER_1,
        ).canonical()
        # Stack the two boards such that the batched dimension is the first dimension
        stacked_board_states = jnp.stack([custom_board.state, custom_board.state])
        boards = Board(
            state=stacked_board_states,
            player=jnp.array([Player.PLAYER_1.value, Player.PLAYER_1.value]),
        )
        stacked_actions = jnp.array([4, 7])
        self.assertTrue(
            stacked_board_states.shape == (2, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE)
        )
        # Transform the naive step function to a vmaped version
        vmapped_step = jax.vmap(self.mcts.step, in_axes=(None, None, 0, 0))
        # Call the vmapped step function
        output, board = vmapped_step(self.params, self.rng, stacked_actions, boards)
        # Assert that the output board has the shape 2, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE
        self.assertTrue(board.state.shape == (2, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE))
        self.assertTrue(np.all(board.player == 2))
        self.assertTrue(board.player.shape, (2,))


# Run all the tests again
unittest.main(argv=["-k", "TestSimpleMuzeroHelpers"], verbosity=2, exit=False)

test_prior (__main__.TestSimpleMuzeroHelpers.test_prior) ... ok
test_step (__main__.TestSimpleMuzeroHelpers.test_step) ... ok
test_value (__main__.TestSimpleMuzeroHelpers.test_value) ... ok
test_vmapped_step (__main__.TestSimpleMuzeroHelpers.test_vmapped_step) ... ok

----------------------------------------------------------------------
Ran 4 tests in 2.917s

OK


<unittest.main.TestProgram at 0x157d33b90>

In [18]:
# Create a toy dataset for training the model and ensuring that it converges as expected.
# Data should be something trivial like a quadratic loss function and a argmax policy function
# A mapping from a state vector to a value and policy logits.


def toy_dataset(num_samples=10000):
    """A toy dataset for testing."""
    # state vectors of size 400 of values between 0,1
    states = jax.random.uniform(
        key=jax.random.PRNGKey(0),
        shape=(num_samples, GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE),
    )
    # value is a quadratic function of the state
    value = jnp.sum(states**2, axis=-1, keepdims=True)
    # policy is a argmax function of the state
    policy_output = jnp.argmax(states, axis=-1)
    # policy_logits is a one hot encoding of the policy output
    policy_logits = jax.nn.one_hot(policy_output, GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE)
    # uniform prior logits
    prior_logits = jnp.ones((num_samples, GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE)) / (
        GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE
    )
    return states, value, policy_logits, prior_logits

In [19]:
# Create a training step function
def train_step(
    model,
    optimizer,
    params,
    optimizer_state,
    states,
    prior_logits,
    policy_logits,
    value,
):
    """Train one step: compute the loss, update params and return the loss, params and updated optimizer state."""

    # Define a loss function
    def forward(params, states, prior_logits, policy_logits, value):
        value_pred, policy_logits_pred = model.apply(params, states, prior_logits)
        return loss_fn(value_pred, policy_logits_pred, value, policy_logits)

    # IMPORTANT:
    # The jax default behavior when it comes to gradients is to only return the gradient wrt to the first argument
    # In this case we want to return the gradient wrt to params which is the model parameters that we want to update
    value_and_grad_loss = jax.value_and_grad(forward)
    loss, grad = value_and_grad_loss(params, states, prior_logits, policy_logits, value)
    # Update the optimizer state
    updates, optimizer_state = optimizer.update(grad, optimizer_state)
    # Update the parameters
    params = optax.apply_updates(params, updates)
    return (loss, params, optimizer_state)

In [20]:
def toy_train_loop(model):
    """A driver function for training the model on a toy dataset."""
    # Create a toy dataset
    states, value, policy_logits, prior_logits = toy_dataset()
    # Initialize the model with a random input
    params = model.init(jax.random.PRNGKey(0), states[:1], policy_logits[:1])
    # Print the model parameters shapes with tree_map
    print(f"Model parameters shapes:\n {tree_map(lambda x: x.shape, params)}")

    # Create an optimizer
    optimizer = optax.adam(learning_rate=1e-4)
    # Create an optimizer state
    optimizer_state = optimizer.init(params)

    states, value, policy_logits, prior_logits = toy_dataset()

    # partially apply the model and optimizer to the train_step function and then jit it
    train_step_fn = jax.jit(partial(train_step, model, optimizer))

    # Train the model for 1000 steps
    losses = []
    for i in range(1000):
        loss, params, optimizer_state = train_step_fn(
            params, optimizer_state, states, prior_logits, policy_logits, value
        )
        if i % 100 == 0:
            print(f"Loss({i}/1000): {loss}")
        losses.append(loss)
    return losses, params

In [21]:
# Train and plot the losses
def fn():
    losses, params = toy_train_loop(SimpleMuZeroModel())
    assert jnp.all(losses[-1] < 3)
    fig = px.line(
        x=range(len(losses)),
        y=losses,
        title="Losses for toy problem",
        labels={"x": "Step", "y": "Loss"},
    )
    fig.show()


fn()

Model parameters shapes:
 {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}
The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}
Loss(0/1000): 75.56990051269531
Loss(100/1000): 6.0869669914245605
Loss(200/1000): 3.684264898300171
Loss(300/1000): 3.5695204734802246
Loss(400/1000): 3.4499385356903076
Loss(500/1000): 3.300150156021118
Loss(600/1000): 3.1231250762939453
Loss(700/1000): 2.9212839603424072
Loss(800/1000): 2.6845343112945557
Loss(900/1000): 2.4209797382354736


Nice to see a the model converges on a simple dataset! That means the training routine is working as expected

# What's next?

Now we need to collect real data and write the policy/value iteration loop to update the parameters. This is how it's done in AlphaZero
<br>

<div style="text-align: center">
<img src="https://github.com/kartikarcot/gomoku/blob/master/backend/notebooks/alphazero.png?raw=true" width="400">
</div>

Now we are at a stage where we can update our `NaiveMcts` to `AlphaMcts`. `AlphaMcts` will use a neural network to perform inference at every stage to compute "action probabilities" and "winning probability" which will then be piped to the `mctx` library's montecarlo routine to guide the search. **Ideally this should then learn a good prior on a simple 5x5 board. Once that is the case we can scale it.**

> Proof that the model is working on 5x5 board will be when the model will always win in some basic unit tests with **less than 100 simulations**.

1. Horizontal brink of winning:

   ```
   2 | 2 | 2 | 0 | 0
   0 | 0 | 0 | 0 | 0
   1 | 1 | 1 | 1 | 0
   0 | 0 | 0 | 0 | 0
   2 | 0 | 0 | 0 | 0
   ```

2. Vertical brink of winning:

   ```
   2 | 1 | 2 | 0 | 0
   2 | 1 | 0 | 0 | 0
   2 | 1 | 0 | 0 | 0
   0 | 1 | 0 | 0 | 0
   0 | 0 | 0 | 0 | 0
   ```

3. Diagonal (top-left to bottom-right) brink of winning:

   ```
   1 | 0 | 2 | 2 | 2
   0 | 1 | 0 | 0 | 2
   0 | 0 | 1 | 0 | 0
   0 | 0 | 0 | 1 | 0
   0 | 0 | 0 | 0 | 0
   ```

4. Diagonal (bottom-left to top-right) brink of winning:

   ```
   2 | 2 | 2 | 0 | 0
   2 | 0 | 0 | 1 | 0
   0 | 0 | 1 | 0 | 0
   0 | 1 | 0 | 0 | 0
   1 | 0 | 0 | 0 | 0
   ```

5. Edge case - Horizontal at the top edge:

   ```
   1 | 1 | 1 | 1 | 0
   0 | 0 | 0 | 0 | 0
   0 | 0 | 0 | 0 | 2
   0 | 0 | 0 | 0 | 2
   0 | 0 | 0 | 2 | 2
   ```

6. Edge case - Vertical at the left edge:
   ```
   1 | 0 | 2 | 2 | 2
   1 | 0 | 0 | 0 | 2
   1 | 0 | 0 | 0 | 0
   1 | 0 | 0 | 0 | 0
   0 | 0 | 0 | 0 | 0
   ```

These cases cover various scenarios where player 1 is one move away from winning, either horizontally, vertically, or diagonally.


In [22]:
# Shifting back to a gomoku board size of 5x5 to test out the alpha go model in a smaller setting
GOMOKU_BOARD_SIZE = 5


def fn():
    # Let's simply run another test to see if the neural network performs assists the mctx search
    # without any errors
    # Create a gomoku scenario with 4 in a row for both players
    scenario = GomokuScenario.make_scenario(
        [
            GomokuStroke(
                player=Player.PLAYER_1,
                row=0,
                col=0,
                stroke_type=GomokuStroke.StrokeType.ROW,
                length=4,
            ),
            GomokuStroke(
                player=Player.PLAYER_2,
                row=1,
                col=0,
                stroke_type=GomokuStroke.StrokeType.COL,
                length=4,
            ),
        ],
        Player.PLAYER_1,
    ).canonical()

    prior_logits = jnp.where(scenario.state == 0, 1.0, 0.0).flatten().reshape(1, -1)
    value = jnp.array([100]).reshape(
        1,
    )
    embedding = Board(
        state=scenario.state.reshape(1, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE),
        player=jnp.array([Player.PLAYER_1.value]),
    )

    root = mctx.RootFnOutput(
        prior_logits=prior_logits, value=value, embedding=embedding
    )
    muzero_mcts, muzero_params = SimpleMuZeroModel.random_model(GOMOKU_BOARD_SIZE)
    recurrent_fn = jax.vmap(muzero_mcts.step, in_axes=(None, None, 0, 0))

    rng_key = jax.random.PRNGKey(4)
    rng_key_1, rng_key_2 = jax.random.split(rng_key, 2)
    policy_output = mctx.muzero_policy(
        muzero_params, rng_key_1, root, recurrent_fn, num_simulations=3000
    )
    return policy_output


policy_output = fn()


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



In [23]:
print(f"The recommended action is {policy_output.action}. The winning move is 4!")
assert policy_output.action == 4

The recommended action is [4]. The winning move is 4!


Now that a simple muzero model is implemented wit ha succesful step function that can be vmapped and used in a muzero policy, we can start building the policy iteration loop.


In [24]:
import jax.tree_util as tree_util

# Print the structure of the policy output struct
print("Policy Output Structure:")
print(tree_util.tree_structure(policy_output))

Policy Output Structure:
PyTreeDef(CustomNode(PolicyOutput[('action', 'action_weights', 'search_tree')], [*, *, CustomNode(Tree[('action_from_parent', 'children_discounts', 'children_index', 'children_prior_logits', 'children_rewards', 'children_values', 'children_visits', 'embeddings', 'extra_data', 'node_values', 'node_visits', 'parents', 'raw_values', 'root_invalid_actions')], [*, *, *, *, *, *, *, CustomNode(Board[('player', 'state')], [*, *]), None, *, *, *, *, *])]))


# Neural network training procedure

The way they train it is by recording self-play episodes $(s_t, \pi_t, z_t)$ where $s_t$ is the state of the board, $\pi_t$ is the policy recommended by the monte-carlo search and $z_t$ is who the winner was from the perspective of $s_t$. They collect many such episodes and then train the network's policy and value heads jointly like so:

$ L = (z - v*\theta)^2 - \pi^T\log P*\theta + c||\theta||^2 $

where, $v_\theta$ is the value head and $P_\theta$ is the policy head. So its a some of a $L2$ loss term for value head and a cross-entropy loss term for the policy head in-addition to a regularization term for the network parameters $(\theta)$.


In [126]:
from typing import Callable

# Let's write a simple routine to collect a training episode: A list of (s_t, pi_t, z_t) from a self-play game
# In order to do this we can first collect (s_t,pi_t) pairs by running the self-play game with the given network parameters
# and then in the final step we can record the winner z_t. Which can then be flipped for alternative previous time steps


@dataclass
class Episode:
    boards: chex.Array
    actions: chex.Array
    policies: chex.Array
    # NOTE: It is the value of the board for the current player who is playing
    # So when you see a board saying -1 with game ending and player 2 has won what it
    # means is that in this board player 2 has played and won and now it is player 1's turn
    # where his value is -1
    values: chex.Array
    canonical: bool = False


def mctx_step(
    prior_and_value_fn: Callable,
    recurrent_fn: Callable,
    step_fn: Callable,
    params: chex.ArrayTree,
    rng_key: int,
    board: Board,
    num_simulations: int = 3000,
):
    """
    A naive step call which calls the gomoku step call and returns a canonical observation and a recurrent fn output
    NOTE: ASSUMES that all functions are vmapped. Have not figured out how to lift the vmap above this
    TODO: Figure out how to lift the vmap above this
    """
    # Create a root fn output (for which first need the board in canonical format)
    vmapped_canonical = jax.vmap(Board.canonical, in_axes=(0,))
    prior_logits, value = prior_and_value_fn(params, rng_key, vmapped_canonical(board).state, False)
    root = mctx.RootFnOutput(prior_logits=prior_logits, value=value, embedding=board)
    policy_output = mctx.muzero_policy(
        params,
        rng_key,
        root,
        recurrent_fn,
        num_simulations=num_simulations,
        invalid_actions=board.invalid_actions(),
    )
    # Take the recommended action and step the board
    output = step_fn(board, policy_output.action, board.player)
    return output, policy_output


def reward_to_value(reward):
    """Given the reward array, find the last reward and keep flipping the sign for each previous step"""
    n = len(reward)
    # Flip the sign of the reward for each previous step
    signs = jnp.power(
        -1, jnp.arange(n - 1, -1, -1)
    )  # The arguments mean start at n-1, stop at -1, step backwards by 1
    return reward[-1] * signs  # Multiply the last reward by the signs


def strip_single_episode(state, player, action, action_weights, done, reward):
    """
    For each batch (game) we need to strip out the steps that are not valid (that come past the done flag)
    We can do this by finding the index of the first done flag per game in the batch and then slicing everything accordingly
    """
    cumsum = jnp.cumsum(done, axis=0)
    # Find the first done flag
    valid_mask = cumsum <= 1
    # Now we can strip the episode
    state = state[valid_mask]
    player = player[valid_mask]
    action = action[valid_mask]
    action_weights = action_weights[valid_mask]
    done = done[valid_mask]
    value = reward_to_value(reward[valid_mask])
    return Episode(
        Board(state=state, player=player),
        action,
        action_weights,
        value,
        canonical=False,
    )


def format_episodes(step_outputs, policy_outputs) -> List[Episode]:
    """Format an episode into a more readable format"""
    # The step_outputs state has shape (num_steps, batch_size, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE)
    # The action taken has shape (num_steps, batch_size)
    # The action weights has shape (num_steps, batch_size, GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE)
    # The done flag has shape (num_steps, batch_size)
    # The reward has shape (num_steps, batch_size)
    # The players have shape (num_steps, batch_size)
    state = jnp.swapaxes(step_outputs.board.state, 0, 1)
    players = jnp.swapaxes(step_outputs.board.player, 0, 1)
    action = jnp.swapaxes(policy_outputs.action, 0, 1)
    action_weights = jnp.swapaxes(policy_outputs.action_weights, 0, 1)
    done = jnp.swapaxes(step_outputs.done, 0, 1)
    reward = jnp.swapaxes(step_outputs.reward, 0, 1)

    # Iterate over the batch dimension and strip each episode
    stripped_episodes = [
        strip_single_episode(
            state[i], players[i], action[i], action_weights[i], done[i], reward[i]
        )
        for i in range(state.shape[0])
    ]
    return stripped_episodes


def self_play_episode(
    prior_and_value_fn: Callable,
    recurrent_fn: Callable,
    step_fn: Callable,
    params: chex.ArrayTree,
    rng_key: int,
    max_steps: int = GOMOKU_BOARD_SIZE * GOMOKU_BOARD_SIZE,
    num_games: int = 1,
    num_simulations: int = 100,  # 100 is faster to run although not as accurate
) -> List[Episode]:
    """
    Given a model and parameters for that model, run a self-play episode and return the states, policies and values
    for each time step as well as the winner of the game. You can also run multiple games at once by setting num_games
    to a value greater than 1.
    Since we may want to run multiple games at once then you need to ensure all the component
    functions are vmapped for batching to work.
    NOTE: There has been difficulty in JITing this function due to the use of lax.scan. Suggest jitting the component functions
    and then calling this function with the jitted functions.
    TODO: Figure out how to lift JIT and VMAP above this
    """
    # Create an empty board scenario
    board = Board.expand(Board(), num=num_games)

    # Use jax.lax.scan to collect the outputs and policy outputs
    # The scan function takes in a function that takes in the previous state and returns the next state
    def scan(prev, _):
        # Unpack the previous state
        board, rng_key = prev
        # Split the rng key
        rng_key, rng_key_next = jax.random.split(rng_key)
        output, policy_output = mctx_step(
            prior_and_value_fn,
            recurrent_fn,
            step_fn,
            params,
            rng_key,
            board,
            num_simulations,
        )
        return (output.board, rng_key_next), (output, policy_output)

    # Scan over the board until the game is done
    _, (outputs, policy_outputs) = lax.scan(
        scan, (board, rng_key), None, length=max_steps
    )

    # Now we can collect the states, policies and values into a single episode
    return (outputs, policy_outputs)


def make_episode_canonical(episode: Episode):
    """
    Given an episode, make each step canonical by updating board state and make the value component 1
    """
    vmapped_canonical = jax.vmap(Board.canonical, in_axes=(0,))
    canonical_board = vmapped_canonical(episode.boards)
    # If all values are zeros then values need not be updated
    # Otherwise if it ends with -1 then array should be multiplied by -1
    # This is because the value function is always from the perspective of the player 1.
    # and if the player 2 had one we get -1. Since we are always player 1 we need to flip the sign for the value
    # and then propagate the sign change backwards which is effectively multiplying by -1 on all values
    values = lax.cond(
        episode.values[-1] == 0,
        lambda _: episode.values,
        # In a canonical game the last step is always a win for player 1
        # as it is the last player to make a move and we consider him to be player 1
        lambda _: jnp.power(-1, jnp.arange(len(episode.values))),
        None,
    )
    return Episode(
        canonical_board, episode.actions, episode.policies, values, canonical=True
    )

In [127]:
GOMOKU_BOARD_SIZE = 5


class TestMctxEpisodeHelpers(unittest.TestCase):
    def setUp(self) -> None:
        self.model = SimpleMuZeroModel()
        self.prior_and_value_fn = jax.jit(
            jax.vmap(self.model.prior_and_value, in_axes=(None, None, 0, None))
        )
        self.recurrent_fn = jax.jit(
            jax.vmap(self.model.step, in_axes=(None, None, 0, 0))
        )
        self.jitted_step = jax.jit(jax.vmap(GomokuEnv.step, in_axes=(0, 0, 0)))
        self.params = SimpleMuZeroModel.random_model(GOMOKU_BOARD_SIZE)[1]
        self.rng_key = jax.random.PRNGKey(256)
        self.jit_self_play_episode = jax.jit(
            self_play_episode, static_argnums=(0, 1, 2, 5, 6)
        )

    # Test that the single step function returns outputs of appropriate shapes
    def test_mctx_step(self):
        board = Board.expand(Board())
        rng_key = jax.random.PRNGKey(4)
        output, policy_output = mctx_step(
            self.prior_and_value_fn,
            self.recurrent_fn,
            self.jitted_step,
            self.params,
            rng_key,
            board,
            100,
        )
        self.assertTrue(
            output.board.state.shape == (1, GOMOKU_BOARD_SIZE, GOMOKU_BOARD_SIZE)
        )
        self.assertTrue(output.reward.shape == (1,))
        self.assertTrue(output.done.shape == (1,))
        self.assertTrue(policy_output.action.shape == (1,))

    def test_format_episodes(self):
        episode_data = self.jit_self_play_episode(
            self.prior_and_value_fn,
            self.recurrent_fn,
            self.jitted_step,
            self.params,
            self.rng_key,
            max_steps=2,
            num_games=2,
        )
        episodes = format_episodes(*episode_data)
        self.assertTrue(len(episodes) == 2)
        for episode in episodes:
            self.assertTrue(episode.boards.state.shape == (2, 5, 5))
            self.assertTrue(episode.actions.shape == (2,))
            self.assertTrue(episode.policies.shape == (2, 25))
            self.assertTrue(episode.values.shape == (2,))
            self.assertFalse(episode.canonical)

    def test_make_canonical(self):
        episode_data = self.jit_self_play_episode(
            self.prior_and_value_fn,
            self.recurrent_fn,
            self.jitted_step,
            self.params,
            self.rng_key,
            max_steps=2,
            num_games=2,
        )
        episodes = format_episodes(*episode_data)
        episode = make_episode_canonical(episodes[0])
        self.assertTrue(episode.boards.state.shape == (2, 5, 5))
        self.assertTrue(np.all(episode.boards.player == 1))
        self.assertTrue(episode.actions.shape == (2,))
        self.assertTrue(episode.policies.shape == (2, 25))
        self.assertTrue(episode.values.shape == (2,))
        self.assertTrue(np.all(episode.values == 0))
        self.assertTrue(episode.canonical)


unittest.main(argv=["-k", "TestMctxEpisodeHelpers"], verbosity=2, exit=False)

test_format_episodes (__main__.TestMctxEpisodeHelpers.test_format_episodes) ... 

ok
test_make_canonical (__main__.TestMctxEpisodeHelpers.test_make_canonical) ... ok
test_mctx_step (__main__.TestMctxEpisodeHelpers.test_mctx_step) ... ok

----------------------------------------------------------------------
Ran 3 tests in 10.898s

OK


<unittest.main.TestProgram at 0x2d9d17c10>

In [128]:
GOMOKU_BOARD_SIZE = 5


def fn():
    model, params = SimpleMuZeroModel.random_model(GOMOKU_BOARD_SIZE)
    prior_and_value_fn = jax.jit(
        jax.vmap(model.prior_and_value, in_axes=(None, None, 0, None))
    )
    recurrent_fn = jax.jit(jax.vmap(model.step, in_axes=(None, None, 0, 0)))
    jitted_step = jax.jit(jax.vmap(GomokuEnv.step, in_axes=(0, 0, 0)))
    rng_key = jax.random.PRNGKey(256)
    jit_self_play_episode = jax.jit(self_play_episode, static_argnums=(0, 1, 2, 5, 6))
    episode_data = jit_self_play_episode(
        prior_and_value_fn, recurrent_fn, jitted_step, params, rng_key
    )
    episodes = format_episodes(*episode_data)
    return episodes


# Print the structure of the episodes the actions taken who won in the end
episodes = fn()
for i, episode in enumerate(episodes):
    print(f"Episode {i}")
    print(f"Episode winner: {episode.values[0]}")
    print(f"Episode actions: {episode.actions}")
    print(f"Episode states: {episode.boards.state}")
    print(f"Episode values: {episode.values}")

# Now print canonical episodes
print(
    "-------------------------------Canonical Episodes--------------------------------"
)
canonical_episodes = [make_episode_canonical(episode) for episode in episodes]
for i, episode in enumerate(canonical_episodes):
    print(f"Episode {i}")
    print(f"Episode winner: {episode.values[0]}")
    print(f"Episode actions: {episode.actions}")
    print(f"Episode states:\n {episode.boards.state}")
    print(f"Episode values: {episode.values}")

Episode 0
Episode winner: 0
Episode actions: [ 3 12 18 19 15 23  6 14  4 10  2  9 13  5 21  1 11  0  7 16 17 22 20  8
 24]
Episode states: [[[0 0 0 1 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]

 [[0 0 0 1 0]
  [0 0 0 0 0]
  [0 0 2 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]

 [[0 0 0 1 0]
  [0 0 0 0 0]
  [0 0 2 0 0]
  [0 0 0 1 0]
  [0 0 0 0 0]]

 [[0 0 0 1 0]
  [0 0 0 0 0]
  [0 0 2 0 0]
  [0 0 0 1 2]
  [0 0 0 0 0]]

 [[0 0 0 1 0]
  [0 0 0 0 0]
  [0 0 2 0 0]
  [1 0 0 1 2]
  [0 0 0 0 0]]

 [[0 0 0 1 0]
  [0 0 0 0 0]
  [0 0 2 0 0]
  [1 0 0 1 2]
  [0 0 0 2 0]]

 [[0 0 0 1 0]
  [0 1 0 0 0]
  [0 0 2 0 0]
  [1 0 0 1 2]
  [0 0 0 2 0]]

 [[0 0 0 1 0]
  [0 1 0 0 0]
  [0 0 2 0 2]
  [1 0 0 1 2]
  [0 0 0 2 0]]

 [[0 0 0 1 1]
  [0 1 0 0 0]
  [0 0 2 0 2]
  [1 0 0 1 2]
  [0 0 0 2 0]]

 [[0 0 0 1 1]
  [0 1 0 0 0]
  [2 0 2 0 2]
  [1 0 0 1 2]
  [0 0 0 2 0]]

 [[0 0 1 1 1]
  [0 1 0 0 0]
  [2 0 2 0 2]
  [1 0 0 1 2]
  [0 0 0 2 0]]

 [[0 0 1 1 1]
  [0 1 0 0 2]
  [2 0 2 0 2]
  [1 0 0 1 2]
  [0 0 0 2 0]

Now that we have all the machinery in place to collect data it's time to write a data collection routine. A routine to save the data. Then use the data to train a model. The main training routine per-se:

Outer Loop

- Run the current model to simulate $N$ games of $M$ length
- Format the data to episodes
- Convert a list of episodes to a dataset
  Inner Loop
  - Train the model on this dataset
    - The model will run $E$ epochs on the dataset
      - After every epoch do a validation routine (play 5 games of self-play)
        and report some metrics
      - Report validation losses to W&B
    - The model will follow a learn rate decay schedule $lr$
      - [ ] Will this be a learning rate decay on the outer-lopp or the inner loop ?
    - Save the checkpointed model
    - Update the model

Interesting insight:

> The first step of the outerloop is pmappable! I think even the model training routine is easily pmappable with flax!


In [129]:
# Simulate N games of M length is just a call to the self_play_episode function
# Simulate 10 games of 10 steps each
def simulate_fn(
    model, params, rng_key, num_games=10, max_steps=10, num_simulations=100
) -> List[Episode]:
    prior_and_value_fn = jax.jit(
        jax.vmap(model.prior_and_value, in_axes=(None, None, 0, None))
    )
    recurrent_fn = jax.jit(jax.vmap(model.step, in_axes=(None, None, 0, 0)))
    jitted_step = jax.jit(jax.vmap(GomokuEnv.step, in_axes=(0, 0, 0)))
    episode_data = self_play_episode(
        prior_and_value_fn,
        recurrent_fn,
        jitted_step,
        params,
        rng_key,
        max_steps=max_steps,
        num_games=num_games,
        num_simulations=num_simulations,
    )
    episodes = format_episodes(*episode_data)
    return episodes

In [53]:
class TestSimulateFn(unittest.TestCase):
    def setUp(self) -> None:
        self.model, self.params = SimpleMuZeroModel.random_model(GOMOKU_BOARD_SIZE)

    def test_simulate_fn(self):
        episodes = simulate_fn(
            self.model,
            self.params,
            jax.random.PRNGKey(256),
            num_games=10,
            max_steps=10,
            num_simulations=100,
        )
        for i, episode in enumerate(episodes):
            self.assertTrue(episode.actions.shape == (10,))
            self.assertTrue(episode.boards.state.shape == (10, 5, 5))
            self.assertTrue(episode.policies.shape == (10, 25))
            self.assertTrue(episode.values.shape == (10,))


unittest.main(argv=["-k", "TestSimulateFn"], verbosity=2, exit=False)

test_simulate_fn (__main__.TestSimulateFn.test_simulate_fn) ... ok

----------------------------------------------------------------------
Ran 1 test in 5.769s

OK


<unittest.main.TestProgram at 0x2e4b62bd0>

In [56]:
def fn():
    model, params = SimpleMuZeroModel.random_model(GOMOKU_BOARD_SIZE)
    episodes = simulate_fn(model, params, jax.random.PRNGKey(256))
    for i, episode in enumerate(episodes):
        # print actions
        print(f"Episode {i}")
        print(f"Episode winner: {episode.actions}")
    return episodes


episodes = fn()

Episode 0
Episode winner: [13 11  2 18 23 24 12  5  7 21]
Episode 1
Episode winner: [ 5  1  3 18  7 22 16  8 11 10]
Episode 2
Episode winner: [20  0  2 13 12 21  6 16 11  5]
Episode 3
Episode winner: [16  2  9  0  5 20  8  1 22 15]
Episode 4
Episode winner: [13 10 24 20 19 22 16  1 11 18]
Episode 5
Episode winner: [20  1 18 12 24  5 19 10 13 11]
Episode 6
Episode winner: [ 5 15 11 16 23 13  2  6 12 19]
Episode 7
Episode winner: [16 12 10 11 18 23  2 14 20 21]
Episode 8
Episode winner: [ 0  5 17  7 13  6  1 22 18 11]
Episode 9
Episode winner: [ 0 24 22  8  2 16 11 19  9  5]


In [86]:
# Let's say a dataset is a shuffled list of steps from multiple episodes
# We can create a dataset from the episodes by flattening the steps and shuffling them


# A dataclass to hold the dataset
@dataclass
class Dataset:
    states: chex.Array
    actions: chex.Array
    policies: chex.Array
    priors: chex.Array
    values: chex.Array


def make_dataset(
    episodes: List[Episode],
) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
    """
    Given a list of episodes, create a dataset by flattening the steps and shuffling them
    """
    # Assert that all episodes are canonical
    assert all(
        episode.canonical for episode in episodes
    ), "Not all episodes are canonical"

    # Flatten the episodes
    actions = jnp.concatenate([episode.actions for episode in episodes])
    policies = jnp.concatenate([episode.policies for episode in episodes])
    # TODO: Fix this to include the prior logits
    priors = jnp.ones_like(policies)
    values = jnp.concatenate([episode.values for episode in episodes])
    states = jnp.concatenate(
        [
            # Flatten the last two dimension of the boards tensor
            episode.boards.state.reshape(*episode.boards.state.shape[:-2], -1)
            for episode in episodes
        ]
    )
    # Shuffle the dataset
    perm = jax.random.permutation(jax.random.PRNGKey(0), len(actions))
    # Shuffle by using the permutation as an index
    return Dataset(
        states[perm], actions[perm], policies[perm], priors[perm], values[perm]
    )

In [87]:
GOMOKU_BOARD_SIZE = 5


class TestMakeDataset(unittest.TestCase):
    def setUp(self) -> None:
        self.model, self.params = SimpleMuZeroModel.random_model(GOMOKU_BOARD_SIZE)
        self.episodes = simulate_fn(
            self.model,
            self.params,
            jax.random.PRNGKey(256),
            num_games=10,
            max_steps=10,
            num_simulations=100,
        )
        self.episodes = [make_episode_canonical(episode) for episode in self.episodes]

    def test_make_dataset(self):
        dataset = make_dataset(self.episodes)
        self.assertTrue(dataset.states.shape == (100, 25))
        self.assertTrue(dataset.actions.shape == (100,))
        self.assertTrue(dataset.policies.shape == (100, 25))
        self.assertTrue(dataset.values.shape == (100,))
        self.assertTrue(dataset.priors.shape == (100, 25))


unittest.main(argv=["-k", "TestMakeDataset"], verbosity=2, exit=False)

test_make_dataset (__main__.TestMakeDataset.test_make_dataset) ... ok

----------------------------------------------------------------------
Ran 1 test in 5.816s

OK


<unittest.main.TestProgram at 0x298062e50>

In [205]:
# Rough implementation of the training loop
model, params = SimpleMuZeroModel.random_model(GOMOKU_BOARD_SIZE)
optimizer = optax.adam(learning_rate=1e-4)
optimizer_state = optimizer.init(params)
all_losses = []
rng_key = jax.random.PRNGKey(256)

In [96]:
# import progressbar
from tqdm import tqdm

In [203]:
def inner_train_loop(model, params, dataset, optimizer, optimizer_state):
    train_step_fn = jax.jit(partial(train_step, model, optimizer))
    losses = []
    # 1000 steps of training in tqdm
    for i in tqdm(range(1000)):
        loss, params, optimizer_state = train_step_fn(
            params,
            optimizer_state,
            dataset.states,
            dataset.priors,
            dataset.policies,
            dataset.values,
        )
        losses.append(loss)
    return losses, params

In [207]:
# Outer training loop
for i in range(50):
    rng_key, sub_key = jax.random.split(rng_key, 2)
    episodes = simulate_fn(
        model,
        params,
        sub_key,
        num_games=50,
        max_steps=25,
        num_simulations=100,
    )
    episodes = [make_episode_canonical(episode) for episode in episodes]
    dataset = make_dataset(episodes)
    losses, params = inner_train_loop(model, params, dataset, optimizer, optimizer_state)
    print(f"Finished training on dataset {i}")
    print(f"Final loss: {losses[-1]}")
    all_losses.extend(losses)

  0%|          | 1/1000 [00:00<04:28,  3.71it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 139.34it/s]


Finished training on dataset 0
Final loss: 3.074525833129883


  0%|          | 1/1000 [00:00<04:28,  3.73it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 137.34it/s]


Finished training on dataset 1
Final loss: 3.0948493480682373


  0%|          | 1/1000 [00:00<04:19,  3.85it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 141.32it/s]


Finished training on dataset 2
Final loss: 3.242175340652466


  0%|          | 1/1000 [00:00<04:34,  3.64it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 147.84it/s]


Finished training on dataset 3
Final loss: 3.2152419090270996


  0%|          | 1/1000 [00:00<04:35,  3.63it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 143.96it/s]


Finished training on dataset 4
Final loss: 3.2076423168182373


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 141.18it/s]


Finished training on dataset 5
Final loss: 3.299654722213745


  0%|          | 1/1000 [00:00<04:49,  3.46it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 127.60it/s]


Finished training on dataset 6
Final loss: 3.315676689147949


  0%|          | 1/1000 [00:00<04:38,  3.59it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:04<00:00, 200.05it/s]


Finished training on dataset 7
Final loss: 3.2626779079437256


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 138.54it/s]


Finished training on dataset 8
Final loss: 3.2883241176605225


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 148.34it/s]


Finished training on dataset 9
Final loss: 3.186957359313965


  0%|          | 1/1000 [00:00<04:31,  3.68it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 144.66it/s]


Finished training on dataset 10
Final loss: 3.246988534927368


  0%|          | 1/1000 [00:00<04:38,  3.58it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 142.47it/s]


Finished training on dataset 11
Final loss: 3.2401039600372314


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 142.46it/s]


Finished training on dataset 12
Final loss: 3.2851781845092773


  0%|          | 1/1000 [00:00<04:46,  3.49it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 137.18it/s]


Finished training on dataset 13
Final loss: 3.267923355102539


  0%|          | 1/1000 [00:00<04:26,  3.75it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 143.07it/s]


Finished training on dataset 14
Final loss: 3.291376829147339


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 149.36it/s]


Finished training on dataset 15
Final loss: 3.1994307041168213


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 137.69it/s]


Finished training on dataset 16
Final loss: 3.2609963417053223


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 132.18it/s]


Finished training on dataset 17
Final loss: 3.1836392879486084


  0%|          | 1/1000 [00:00<04:49,  3.46it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 135.34it/s]


Finished training on dataset 18
Final loss: 3.2620110511779785


  0%|          | 1/1000 [00:00<04:48,  3.46it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 138.71it/s]


Finished training on dataset 19
Final loss: 3.2220535278320312


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 141.41it/s]


Finished training on dataset 20
Final loss: 3.332808017730713


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 134.87it/s]


Finished training on dataset 21
Final loss: 3.15388560295105


  0%|          | 1/1000 [00:00<04:25,  3.76it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 137.76it/s]


Finished training on dataset 22
Final loss: 3.235524892807007


  0%|          | 1/1000 [00:00<04:49,  3.46it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 154.43it/s]


Finished training on dataset 23
Final loss: 3.270639181137085


  0%|          | 1/1000 [00:00<04:22,  3.81it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 152.29it/s]


Finished training on dataset 24
Final loss: 3.2989344596862793


  0%|          | 1/1000 [00:00<04:44,  3.51it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 147.44it/s]


Finished training on dataset 25
Final loss: 3.32116961479187


  0%|          | 1/1000 [00:00<04:36,  3.62it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 155.53it/s]


Finished training on dataset 26
Final loss: 3.178536891937256


  0%|          | 1/1000 [00:00<04:31,  3.68it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 142.91it/s]


Finished training on dataset 27
Final loss: 3.175217866897583


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 149.69it/s]


Finished training on dataset 28
Final loss: 3.2756409645080566


  0%|          | 1/1000 [00:00<04:20,  3.83it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 151.45it/s]


Finished training on dataset 29
Final loss: 3.2781529426574707


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 135.19it/s]


Finished training on dataset 30
Final loss: 3.4033195972442627


  0%|          | 1/1000 [00:00<04:41,  3.55it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 164.31it/s]


Finished training on dataset 31
Final loss: 3.3405046463012695


  0%|          | 1/1000 [00:00<04:22,  3.81it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 145.29it/s]


Finished training on dataset 32
Final loss: 3.213669776916504


  0%|          | 1/1000 [00:00<04:24,  3.78it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 134.88it/s]


Finished training on dataset 33
Final loss: 3.2822320461273193


  0%|          | 1/1000 [00:00<04:29,  3.70it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 151.34it/s]


Finished training on dataset 34
Final loss: 3.282973527908325


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 147.98it/s]


Finished training on dataset 35
Final loss: 3.198570489883423


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:04<00:00, 203.11it/s]


Finished training on dataset 36
Final loss: 3.2266080379486084


  0%|          | 1/1000 [00:00<04:46,  3.48it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 147.90it/s]


Finished training on dataset 37
Final loss: 3.318864345550537


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:07<00:00, 136.73it/s]


Finished training on dataset 38
Final loss: 3.3077762126922607


  0%|          | 1/1000 [00:00<04:06,  4.06it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 162.03it/s]


Finished training on dataset 39
Final loss: 3.350271463394165


  0%|          | 1/1000 [00:00<04:43,  3.53it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 144.49it/s]


Finished training on dataset 40
Final loss: 3.231992244720459


  0%|          | 1/1000 [00:00<04:12,  3.96it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 149.65it/s]


Finished training on dataset 41
Final loss: 3.20151948928833


  0%|          | 1/1000 [00:00<04:30,  3.69it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 152.43it/s]


Finished training on dataset 42
Final loss: 3.2964024543762207


  0%|          | 1/1000 [00:00<04:42,  3.53it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 147.14it/s]


Finished training on dataset 43
Final loss: 3.2633228302001953


  0%|          | 1/1000 [00:00<04:24,  3.78it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 149.01it/s]


Finished training on dataset 44
Final loss: 3.2956535816192627


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 151.27it/s]


Finished training on dataset 45
Final loss: 3.277076244354248


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 144.56it/s]


Finished training on dataset 46
Final loss: 3.269977569580078


  0%|          | 1/1000 [00:00<04:27,  3.73it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 147.77it/s]


Finished training on dataset 47
Final loss: 3.2659361362457275


  0%|          | 0/1000 [00:00<?, ?it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 144.18it/s]


Finished training on dataset 48
Final loss: 3.3390278816223145


  0%|          | 1/1000 [00:00<04:39,  3.58it/s]

The updates are on the following parameters {'params': {'Dense_0': {'bias': (256,), 'kernel': (25, 256)}, 'Dense_1': {'bias': (256,), 'kernel': (256, 256)}, 'Dense_2': {'bias': (1,), 'kernel': (256, 1)}, 'Dense_3': {'bias': (25,), 'kernel': (256, 25)}}}


100%|██████████| 1000/1000 [00:06<00:00, 147.42it/s]

Finished training on dataset 49
Final loss: 3.3250412940979004





In [208]:
# Plot the losses
fig = px.line(
    x=range(len(all_losses)),
    y=all_losses,
    title="Losses for toy problem",
    labels={"x": "Step", "y": "Loss"},
)
fig.show()

In [172]:
def simulate_step(model, params, board, num_simulations=100):
    prior_and_value_fn = jax.jit(
        jax.vmap(model.prior_and_value, in_axes=(None, None, 0, None))
    )
    recurrent_fn = jax.jit(jax.vmap(model.step, in_axes=(None, None, 0, 0)))
    jitted_step = jax.jit(jax.vmap(GomokuEnv.step, in_axes=(0, 0, 0)))
    output, policy_output = mctx_step(
        prior_and_value_fn,
        recurrent_fn,
        jitted_step,
        params,
        jax.random.PRNGKey(256),
        board,
        num_simulations,
    )
    return output, policy_output

In [219]:
# Create a board with 4 in a row for player 1
board = GomokuScenario.make_scenario(
    [
        GomokuStroke(
            player=Player.PLAYER_1,
            row=0,
            col=0,
            stroke_type=GomokuStroke.StrokeType.DIAG,
            length=4,
        ),
        GomokuStroke(
            player=Player.PLAYER_2,
            row=1,
            col=0,
            stroke_type=GomokuStroke.StrokeType.COL,
            length=4,
        ),
    ],
    Player.PLAYER_1,
).canonical()

In [220]:
Board.expand(board)

Board(state=Array([[[1, 0, 0, 0, 0],
        [2, 1, 0, 0, 0],
        [2, 0, 1, 0, 0],
        [2, 0, 0, 1, 0],
        [2, 0, 0, 0, 0]]], dtype=int32), player=Array([1], dtype=int32, weak_type=True))

In [223]:
output, policy_output = simulate_step(model, params, Board.expand(board), num_simulations=1000)

In [224]:
output.board.state

Array([[[1, 0, 0, 0, 0],
        [2, 1, 0, 0, 0],
        [2, 0, 1, 0, 0],
        [2, 0, 0, 1, 0],
        [2, 0, 0, 0, 1]]], dtype=int32)

In [226]:
# Run an episode of self play using simulate_fn

episodes = simulate_fn(
    model,
    params,
    jax.random.PRNGKey(256),
    num_games=1,
    max_steps=25,
    num_simulations=100,
)


In [231]:
episodes[0].boards.state

Array([[[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0],
        [0, 0, 0, 2, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0],
        [1, 0, 0, 2, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0],
        [1, 0, 0, 2, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 2],
        [0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0],
        [1, 0, 0, 2, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 1, 2],
        [0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0],
        [1, 0, 0, 2, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 1, 2],
        [0, 0, 0, 2, 0]],

       [[0, 0, 0, 0, 0],
        [1, 0, 0, 2, 1],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 2],
        [0, 0, 0, 2, 0]],

       [[0, 0, 0, 0, 0],
        [1, 0, 0, 2, 1],
        [0, 0, 0, 1, 2],
        [0, 0, 0, 1, 2],
        [0,