In [1]:
import jax
import jax.numpy as jnp
from jax import lax
import time

In [2]:
# Win conditions defined as indices on the board
def check_win1(board: jnp.ndarray) -> int:
    win_conditions = jnp.array(
        [
            (0, 1, 2),
            (3, 4, 5),
            (6, 7, 8),
            (0, 3, 6),
            (1, 4, 8),
            (2, 5, 8),
            (0, 4, 8),
            (2, 4, 6),
        ],
        dtype=jnp.int32,
    )

    def check_line(winner, line):
        line_win = lax.cond(
            (board[line[0]] == board[line[1]])
            & (board[line[1]] == board[line[2]])
            & (board[line[0]] != 0),
            lambda: board[line[0]].astype(jnp.int32),  # Ensuring int32 output
            lambda: jnp.array(0, dtype=jnp.int32),  # Ensuring int32 output
        )
        return jnp.maximum(winner, line_win), None

    # Use `jnp.array(0)` as the initial carry value, which represents "no winner"
    winner, _ = lax.scan(check_line, jnp.array(0), win_conditions)
    return winner  # Returns 1 if player wins, 2 if opponent wins, 0 otherwise

In [3]:
def check_win2(board: jnp.ndarray) -> int:
    reshaped_state = board.reshape(3, 3)
    # 0 for tie, 1 for player 1, 2 for player 2
    # Check for row-wise winners
    rows_equal = jnp.all(reshaped_state == reshaped_state[:, [0]], axis=1) & (
        reshaped_state[:, 0] != 0
    )
    row_winners = reshaped_state[:, 0] * rows_equal

    # Check for column-wise winners
    cols_equal = jnp.all(reshaped_state == reshaped_state[[0], :], axis=0) & (
        reshaped_state[0, :] != 0
    )
    col_winners = reshaped_state[0, :] * cols_equal

    # Check for diagonal winners
    diagonal_1 = (
        (reshaped_state[0, 0] != 0)
        & (reshaped_state[0, 0] == reshaped_state[1, 1])
        & (reshaped_state[0, 0] == reshaped_state[2, 2])
    )
    diagonal_1_winners = reshaped_state[0, 0] * diagonal_1

    diagonal_2 = (
        (reshaped_state[0, 2] != 0)
        & (reshaped_state[0, 2] == reshaped_state[1, 1])
        & (reshaped_state[0, 2] == reshaped_state[2, 0])
    )
    diagonal_2_winners = reshaped_state[0, 2] * diagonal_2

    # Return the maximum winner (1 for player 1, 2 for player 2, 0 for no winner)
    return jnp.maximum(
        row_winners.max(),
        jnp.maximum(
            col_winners.max(),
            jnp.maximum(diagonal_1_winners, diagonal_2_winners),
        ),
    )

In [4]:
def benchmark(func):
    key, rng = jax.random.split(jax.random.PRNGKey(0), 2)
    for _ in range(5):
        key, rng = jax.random.split(key, 2)
        board= jax.random.randint(rng, (9,), 0, 3)
        func(board)

In [7]:
jit_check_win1 = jax.jit(check_win1)
jit_check_win2 = jax.jit(check_win2)
jit_check_win1(jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))
jit_check_win2(jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))

Array(1, dtype=int32)

In [8]:
%timeit benchmark(check_win1)
%timeit benchmark(jit_check_win1)
%timeit benchmark(check_win2)
%timeit benchmark(jit_check_win2)

227 ms ± 9.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
767 μs ± 2.74 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
14.8 ms ± 537 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
766 μs ± 3.62 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
