In [None]:
import jax
import jax.numpy as jnp
from jax import lax
import time
import numpy as np
import numba
import functools
from ttt.cython.c_tictactoe import check_winner

In [None]:
# Win conditions defined as indices on the board
def check_win1(board: jnp.ndarray) -> int:
    """Find the winner of a flat 9-cell board by scanning over the 8 lines.

    Args:
        board: 1-D array indexed 0-8 in row-major order. A cell equal to 0
            never counts towards a win; any other value is a player id.

    Returns:
        A scalar int32 array holding the largest player id that owns a
        complete line, or 0 when no line is complete.
    """
    win_conditions = jnp.array(
        [
            (0, 1, 2),  # rows
            (3, 4, 5),
            (6, 7, 8),
            (0, 3, 6),  # columns
            (1, 4, 7),  # middle column: cells 1, 4 and 7 (not 8)
            (2, 5, 8),
            (0, 4, 8),  # diagonals
            (2, 4, 6),
        ],
        dtype=jnp.int32,
    )

    def check_line(winner, line):
        # A line is won when its three cells are equal and not empty.
        line_win = lax.cond(
            (board[line[0]] == board[line[1]])
            & (board[line[1]] == board[line[2]])
            & (board[line[0]] != 0),
            lambda: board[line[0]].astype(jnp.int32),  # both branches must be int32
            lambda: jnp.array(0, dtype=jnp.int32),
        )
        # The carry keeps the highest player id seen so far; scan emits no per-step output.
        return jnp.maximum(winner, line_win), None

    # Start from 0 ("no winner"); the explicit dtype keeps the carry type fixed at int32.
    winner, _ = lax.scan(check_line, jnp.array(0, dtype=jnp.int32), win_conditions)
    return winner

In [None]:
def check_win2(board: jnp.ndarray) -> int:
    """Vectorised winner check on a board viewed as a 3x3 grid.

    Args:
        board: Array with 9 elements; 0 never counts towards a win.

    Returns:
        Scalar array with the largest value that fills a whole row, column
        or diagonal, or 0 when there is no such line.
    """
    grid = board.reshape(3, 3)
    left_edge = grid[:, 0]
    top_edge = grid[0, :]

    # A row/column is complete when every cell matches its first cell and that cell is non-zero.
    full_rows = (left_edge != 0) & jnp.all(grid == left_edge[:, None], axis=1)
    full_cols = (top_edge != 0) & jnp.all(grid == top_edge[None, :], axis=0)

    centre = grid[1, 1]
    top_left = grid[0, 0]
    top_right = grid[0, 2]
    main_diag = (top_left != 0) & (top_left == centre) & (top_left == grid[2, 2])
    anti_diag = (top_right != 0) & (top_right == centre) & (top_right == grid[2, 0])

    # Multiplying by the boolean mask zeroes out incomplete lines.
    best_straight = jnp.maximum(
        (left_edge * full_rows).max(),
        (top_edge * full_cols).max(),
    )
    best_diagonal = jnp.maximum(top_left * main_diag, top_right * anti_diag)
    return jnp.maximum(best_straight, best_diagonal)

In [None]:
def check_win2_np(board: np.ndarray) -> int:
    """NumPy version of the vectorised 3x3 winner check.

    Args:
        board: Array with 9 elements; 0 never counts towards a win.

    Returns:
        The largest value that fills a whole row, column or diagonal,
        or 0 when there is no such line.
    """
    grid = board.reshape(3, 3)
    left_edge = grid[:, 0]
    top_edge = grid[0, :]

    # A row/column is complete when every cell matches its first cell and that cell is non-zero.
    full_rows = (left_edge != 0) & np.all(grid == left_edge[:, None], axis=1)
    full_cols = (top_edge != 0) & np.all(grid == top_edge[None, :], axis=0)

    centre = grid[1, 1]
    top_left = grid[0, 0]
    top_right = grid[0, 2]
    main_diag = (top_left != 0) & (top_left == centre) & (top_left == grid[2, 2])
    anti_diag = (top_right != 0) & (top_right == centre) & (top_right == grid[2, 0])

    # Multiplying by the boolean mask zeroes out incomplete lines.
    candidates = (
        (left_edge * full_rows).max(),
        (top_edge * full_cols).max(),
        top_left * main_diag,
        top_right * anti_diag,
    )
    return max(candidates)

In [None]:
def check_win3(board: jnp.ndarray) -> int:
    """Early-exit winner check using Python branching on JAX arrays.

    Rows are examined first, then columns, then the two diagonals; the first
    complete line found decides the result. The `if` statements need concrete
    boolean values, so this variant is meant to run eagerly.

    Args:
        board: Array with 9 elements; 0 never counts towards a win.

    Returns:
        The value occupying the first complete line found (a scalar array),
        or the Python int 0 when no line is complete.
    """
    grid = board.reshape(3, 3)

    # Rows: every cell equals the row's first cell, and that cell is non-zero.
    row_heads = grid[:, 0]
    full_rows = (row_heads != 0) & jnp.all(grid == row_heads[:, None], axis=1)
    if jnp.any(full_rows):
        # argmax on a boolean mask gives the index of the first True entry.
        return row_heads[jnp.argmax(full_rows)]

    # Columns: same test along the other axis.
    col_heads = grid[0, :]
    full_cols = (col_heads != 0) & jnp.all(grid == col_heads[None, :], axis=0)
    if jnp.any(full_cols):
        return col_heads[jnp.argmax(full_cols)]

    # Diagonals both pass through the centre cell.
    centre = grid[1, 1]
    if grid[0, 0] == centre and centre == grid[2, 2] and grid[2, 2] != 0:
        return grid[0, 0]
    if grid[2, 0] == centre and centre == grid[0, 2] and grid[0, 2] != 0:
        return grid[2, 0]
    return 0

In [None]:
def check_win3_np(board: np.ndarray) -> int:
    """Early-exit winner check in NumPy.

    Rows are examined first, then columns, then the two diagonals; the first
    complete line found decides the result.

    Args:
        board: Array with 9 elements; 0 never counts towards a win.

    Returns:
        The value occupying the first complete line found, or 0 when no
        line is complete.
    """
    grid = board.reshape(3, 3)

    # Rows: every cell equals the row's first cell, and that cell is non-zero.
    row_heads = grid[:, 0]
    winning_rows = np.flatnonzero(
        (row_heads != 0) & np.all(grid == row_heads[:, None], axis=1)
    )
    if winning_rows.size:
        return grid[winning_rows[0], 0]

    # Columns: same test along the other axis.
    col_heads = grid[0, :]
    winning_cols = np.flatnonzero(
        (col_heads != 0) & np.all(grid == col_heads[None, :], axis=0)
    )
    if winning_cols.size:
        return grid[0, winning_cols[0]]

    # Diagonals both pass through the centre cell.
    centre = grid[1, 1]
    if grid[0, 0] == centre and centre == grid[2, 2] and grid[2, 2] != 0:
        return grid[0, 0]
    if grid[2, 0] == centre and centre == grid[0, 2] and grid[0, 2] != 0:
        return grid[2, 0]
    return 0

In [None]:
@numba.njit
def check_win3_numba(board: np.ndarray) -> int:
    """Early-exit winner check written as plain loops for numba's njit.

    Rows are examined first, then columns, then the two diagonals.

    Args:
        board: Array with 9 elements; 0 never counts towards a win.

    Returns:
        The value occupying the first complete line found, or 0 when no
        line is complete.
    """
    grid = board.reshape(3, 3)

    # Horizontal lines.
    for r in range(3):
        head = grid[r, 0]
        if head != 0 and head == grid[r, 1] and grid[r, 1] == grid[r, 2]:
            return head

    # Vertical lines.
    for c in range(3):
        top = grid[0, c]
        if top != 0 and top == grid[1, c] and grid[1, c] == grid[2, c]:
            return top

    # Both diagonals share the centre cell.
    centre = grid[1, 1]
    if grid[0, 0] != 0 and grid[0, 0] == centre and centre == grid[2, 2]:
        return grid[0, 0]
    if grid[2, 0] != 0 and grid[2, 0] == centre and centre == grid[0, 2]:
        return grid[2, 0]

    return 0

In [None]:
def check_win4(board: jnp.ndarray) -> int:
    """Winner check that gathers all 8 lines with one index table.

    Args:
        board: 1-D array indexed 0-8 in row-major order; 0 never counts
            towards a win.

    Returns:
        Scalar array with the largest value owning a complete line, or 0.
    """
    line_indices = jnp.array(
        [
            [0, 1, 2],  # Rows
            [3, 4, 5],
            [6, 7, 8],
            [0, 3, 6],  # Columns
            [1, 4, 7],
            [2, 5, 8],
            [0, 4, 8],  # Diagonals
            [2, 4, 6],
        ],
        dtype=jnp.int32,
    )

    # One length-8 vector per position within a line.
    first, second, third = (board[line_indices[:, k]] for k in range(3))

    # Complete line: three equal cells whose shared value is non-zero.
    complete = (first != 0) & (first == second) & (second == third)

    # Keep the owner of complete lines, 0 elsewhere, then take the largest.
    return jnp.where(complete, first, 0).max()

In [None]:
def check_win5(board: jnp.ndarray) -> int:
    """Winner check comparing each gathered line against its first cell.

    Args:
        board: 1-D array indexed 0-8 in row-major order; 0 never counts
            towards a win.

    Returns:
        Scalar array with the largest value owning a complete line, or 0.
    """
    line_indices = jnp.array(
        [
            [0, 1, 2],  # Rows
            [3, 4, 5],
            [6, 7, 8],
            [0, 3, 6],  # Columns
            [1, 4, 7],
            [2, 5, 8],
            [0, 4, 8],  # Diagonals
            [2, 4, 6],
        ],
        dtype=jnp.int32,
    )

    cells = board[line_indices]  # Shape: (8, 3)
    owners = cells[:, 0]

    # A line is uniform when every cell matches the first; it only counts if that cell is non-zero.
    uniform = jnp.all(cells == cells[:, :1], axis=1)
    occupied = owners != 0

    return jnp.max(owners * (uniform & occupied))

In [None]:
def benchmark(func):
    """Run `func` on 5 random boards drawn with JAX's PRNG, waiting for each result.

    The key sequence starts from `jax.random.PRNGKey(0)`, so the same 5
    boards are produced on every call.

    Args:
        func: Callable taking a length-9 integer array with values in {0, 1, 2}.

    Returns:
        None. The function exists to be timed.
    """
    key, rng = jax.random.split(jax.random.PRNGKey(0), 2)
    for _ in range(5):
        key, rng = jax.random.split(key, 2)
        board = jax.random.randint(rng, (9,), 0, 3)
        # JAX dispatches work asynchronously; block so the timing covers the
        # computation itself and not just the time needed to enqueue it.
        jax.block_until_ready(func(board))

def benchmark_np(func):
    """Run `func` on 5 random boards drawn from NumPy's global RNG.

    Args:
        func: Callable taking a length-9 integer array with values in {0, 1, 2}.

    Returns:
        None. The function exists to be timed.
    """
    rounds_left = 5
    while rounds_left:
        func(np.random.randint(0, 3, 9))
        rounds_left -= 1

@functools.partial(jax.jit, static_argnums=(0,))
def _benchmark_jax_traced(func, seed):
    """Jitted loop: draw 5 boards from `seed`, apply `func`, return the stacked results."""
    key, rng = jax.random.split(jax.random.PRNGKey(seed), 2)
    outputs = []
    for _ in range(5):
        key, rng = jax.random.split(key, 2)
        board = jax.random.randint(rng, (9,), 0, 3)
        outputs.append(jnp.asarray(func(board)))
    # Returning the results keeps them live in the compiled program.
    return jnp.stack(outputs)


def benchmark_jax(func):
    """Run `func` on 5 random boards inside a single jit-compiled program.

    The loop is compiled once per `func` (it is a static argument of the
    jitted helper). The seed is passed as a runtime argument and the results
    are returned from the compiled program, so the board generation and the
    calls to `func` are real outputs of the program rather than unused,
    constant-input work the compiler is free to drop or fold away. The
    wrapper then waits for those outputs, because calling
    `jax.block_until_ready` on tracers inside a jitted function cannot wait
    for anything.

    Args:
        func: Hashable callable taking a length-9 integer array and
            returning a scalar.

    Returns:
        None. The function exists to be timed.
    """
    jax.block_until_ready(_benchmark_jax_traced(func, 0))

In [None]:
# Wrap every JAX variant with jax.jit. Wrapping alone does not trace or compile anything.
jit_check_win1 = jax.jit(check_win1)
jit_check_win2 = jax.jit(check_win2)
jit_check_win3 = jax.jit(check_win3)
jit_check_win4 = jax.jit(check_win4)
jit_check_win5 = jax.jit(check_win5)
# Call each jitted variant once on a board whose top row is all 1s, so that
# tracing/compilation for this input shape happens here rather than in a timed cell.
jit_check_win1(jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))
jit_check_win2(jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))
# check_win3 branches with Python `if` on array values, which presumably is why
# its jitted call is left disabled.
# jit_check_win3(jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))
jit_check_win4(jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))
jit_check_win5(jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))

In [None]:
# Timings driven from Python: each benchmark() call draws 5 random boards and
# calls the checker once per board. check_win3 is timed only in its un-jitted form.
%timeit benchmark(check_win1)
%timeit benchmark(jit_check_win1)
%timeit benchmark(check_win2)
%timeit benchmark(jit_check_win2)
%timeit benchmark(check_win3)
%timeit benchmark(check_win4)
%timeit benchmark(jit_check_win4)
%timeit benchmark(check_win5)
%timeit benchmark(jit_check_win5)

In [None]:
# Timings with the whole 5-board loop traced under jax.jit (the checker is a
# static argument). Passing a jit_check_winN wrapper nests a jitted call inside that trace.
%timeit benchmark_jax(check_win1)
%timeit benchmark_jax(jit_check_win1)
%timeit benchmark_jax(check_win2)
%timeit benchmark_jax(jit_check_win2)
%timeit benchmark_jax(check_win4)
%timeit benchmark_jax(jit_check_win4)
%timeit benchmark_jax(check_win5)
%timeit benchmark_jax(jit_check_win5)

In [None]:
# NumPy counterparts, fed boards from np.random.randint.
%timeit benchmark_np(check_win2_np)
%timeit benchmark_np(check_win3_np)

In [None]:
check_win3_numba(np.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))

In [None]:
%timeit benchmark_np(check_win3_numba)

In [None]:
%timeit benchmark_np(check_winner)

In [None]:
jax.make_jaxpr(check_win1)(np.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))

In [None]:
jax.make_jaxpr(check_win2)(np.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))

In [None]:
jax.make_jaxpr(check_win4)(np.array([1, 1, 1, 0, 0, 0, 0, 0, 0]))

In [None]:
def benchmark_empty():
    """Baseline for `benchmark`: generate the same 5 boards but call no checker.

    Timing this gives the cost of the key splitting and board generation
    alone, which can be subtracted from the `benchmark(...)` timings.

    Returns:
        None. The function exists to be timed.
    """
    key, rng = jax.random.split(jax.random.PRNGKey(0), 2)
    for _ in range(5):
        key, rng = jax.random.split(key, 2)
        board = jax.random.randint(rng, (9,), 0, 3)
        # JAX dispatches asynchronously; wait for the board so the baseline
        # measures finished work rather than the time to enqueue it.
        jax.block_until_ready(board)

def benchmark_np_empty():
    """Baseline for `benchmark_np`: draw 5 random boards and discard them.

    Returns:
        None. The function exists to be timed.
    """
    rounds_left = 5
    while rounds_left:
        np.random.randint(0, 3, 9)
        rounds_left -= 1

@functools.partial(jax.jit, static_argnums=(0,))
def _benchmark_jax_empty_traced(func, seed):
    """Jitted baseline loop: draw 5 boards from `seed` and return them stacked."""
    del func  # only part of the jit cache key; never called
    key, rng = jax.random.split(jax.random.PRNGKey(seed), 2)
    boards = []
    for _ in range(5):
        key, rng = jax.random.split(key, 2)
        boards.append(jax.random.randint(rng, (9,), 0, 3))
    # Returning the boards keeps their generation live in the compiled program.
    return jnp.stack(boards)


def benchmark_jax_emtpy(func):
    """Baseline for `benchmark_jax`: generate 5 boards under jit, call no checker.

    The seed is a runtime argument and the boards are returned from the
    compiled program, so board generation is an output of the program rather
    than unused, constant-input work the compiler is free to drop or fold
    away. The wrapper waits for the boards before returning.

    The misspelled name is kept because the timing cell calls it.

    Args:
        func: Hashable object used only as a static jit argument; never called.

    Returns:
        None. The function exists to be timed.
    """
    jax.block_until_ready(_benchmark_jax_empty_traced(func, 0))

@numba.njit
def benchmark_numba_empty():
    """Baseline compiled with numba.njit: draw 5 random boards and discard them."""
    draws_left = 5
    while draws_left > 0:
        board = np.random.randint(0, 3, 9)
        draws_left -= 1

In [None]:
# Baselines: the same board-generation loops with no checker call.
%timeit benchmark_empty()
%timeit benchmark_np_empty()
%timeit benchmark_jax_emtpy(benchmark_empty)
%timeit benchmark_numba_empty()

In [None]:
# jax.profiler.start_server(port=6007)
# with jax.profiler.trace('jax_trace', create_perfetto_link=True):
#     key, rng = jax.random.split(jax.random.PRNGKey(0), 2)
#     for _ in range(1000):
#         key, rng = jax.random.split(key, 2)
#         board= jax.random.randint(rng, (9,), 0, 3)
#         jit_check_win4(board)