In [1]:
import numpy as np

In [2]:
rng = np.random.default_rng()
reshaped_state = rng.integers(0, 3, size=(4, 3, 3))
reshaped_state

array([[[0, 0, 1],
        [2, 0, 0],
        [2, 1, 1]],

       [[0, 2, 1],
        [1, 1, 2],
        [1, 0, 1]],

       [[0, 1, 1],
        [2, 2, 2],
        [0, 2, 2]],

       [[0, 0, 2],
        [2, 0, 2],
        [1, 2, 0]]])

In [3]:
reshaped_state[:, 0]

array([[0, 0, 1],
       [0, 2, 1],
       [0, 1, 1],
       [0, 0, 2]])

In [4]:
rows_equal = (reshaped_state[:, :, 0] != 0) & (np.diff(reshaped_state, axis=2) == 0).all(axis=2)
rows_equal

array([[False, False, False],
       [False, False, False],
       [False,  True, False],
       [False, False, False]])

In [5]:
cols_equal = (reshaped_state[:, 0, :] != 0) & (np.diff(reshaped_state, axis=1) == 0).all(axis=1)
cols_equal

array([[False, False, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]])

DIAGONALS

In [6]:
reshaped_state[:, 0, 0]

array([0, 0, 0, 0])

In [7]:
reshaped_state[:, np.eye(3, dtype=bool)]

array([[0, 0, 1],
       [0, 1, 1],
       [0, 2, 2],
       [0, 0, 0]])

In [8]:
diagonal_1 = (reshaped_state[:, 0, 0] != 0) & (reshaped_state[:, 0, 0] == reshaped_state[:, 1, 1]) & (reshaped_state[:, 0, 0] == reshaped_state[:, 2, 2])
diagonal_1

array([False, False, False, False])

OTHER DIAGONAL

In [9]:
diagonal_2 = (reshaped_state[:, 0, 2] != 0) & (reshaped_state[:, 0, 2] == reshaped_state[:, 1, 1]) & (reshaped_state[:, 0, 2] == reshaped_state[:, 2, 0])
diagonal_2

array([False,  True, False, False])

In [10]:
rows_equal = np.all(reshaped_state == reshaped_state[:, :, [0]], axis=2) & (
    reshaped_state[:, :, 0] != 0
)
row_winners = reshaped_state[:, :, 0] * rows_equal
cols_equal = np.all(reshaped_state == reshaped_state[:, [0], :], axis=1) & (
    reshaped_state[:, 0, :] != 0
)
col_winners = reshaped_state[:, 0, :] * cols_equal
diagonal_1 = (
    (reshaped_state[:, 0, 0] != 0)
    & (reshaped_state[:, 0, 0] == reshaped_state[:, 1, 1])
    & (reshaped_state[:, 0, 0] == reshaped_state[:, 2, 2])
)
diagonal_1_winners = reshaped_state[:, 0, 0] * diagonal_1
diagonal_2 = (
    (reshaped_state[:, 0, 2] != 0)
    & (reshaped_state[:, 0, 2] == reshaped_state[:, 1, 1])
    & (reshaped_state[:, 0, 2] == reshaped_state[:, 2, 0])
)
diagonal_2_winners = reshaped_state[:, 0, 2] * diagonal_2

In [11]:
np.maximum(
    row_winners.max(axis=1),
    np.maximum(
        col_winners.max(axis=1),
        np.maximum(diagonal_1_winners, diagonal_2_winners),
    ),
)

array([0, 1, 2, 0])

### JAX speedup

In [None]:
import jax
import jax.numpy as jnp

def check_win(state) -> int:
    reshaped_state = state.reshape(3, 3)  # Reshape the state into a 2D (3x3) array
    # Check for row-wise winners
    rows_equal = jnp.all(reshaped_state == reshaped_state[:, [0]], axis=1) & (
        reshaped_state[:, 0] != 0
    )
    row_winners = reshaped_state[:, 0] * rows_equal
    
    # Check for column-wise winners
    cols_equal = jnp.all(reshaped_state == reshaped_state[[0], :], axis=0) & (
        reshaped_state[0, :] != 0
    )
    col_winners = reshaped_state[0, :] * cols_equal
    
    # Check for diagonal winners
    diagonal_1 = (
        (reshaped_state[0, 0] != 0)
        & (reshaped_state[0, 0] == reshaped_state[1, 1])
        & (reshaped_state[0, 0] == reshaped_state[2, 2])
    )
    diagonal_1_winners = reshaped_state[0, 0] * diagonal_1
    
    diagonal_2 = (
        (reshaped_state[0, 2] != 0)
        & (reshaped_state[0, 2] == reshaped_state[1, 1])
        & (reshaped_state[0, 2] == reshaped_state[2, 0])
    )
    diagonal_2_winners = reshaped_state[0, 2] * diagonal_2

    # Return the maximum winner (1 for player 1, 2 for player 2, 0 for no winner)
    return jnp.maximum(
        row_winners.max(),
        jnp.maximum(
            col_winners.max(),
            jnp.maximum(diagonal_1_winners, diagonal_2_winners),
        ),
    )

# JIT-compiled version
check_win_jit = jax.jit(check_win)

In [None]:
def check_non_jit(rng, n=10):
    for _ in range(n):
        rng, _ = jax.random.split(rng)
        state = jax.random.randint(rng, (9,), 0, 3)
        check_win(state)

def check_jit(rng, n=10):
    for _ in range(n):
        rng, _ = jax.random.split(rng)
        state = jax.random.randint(rng, (9,), 0, 3)
        check_win_jit(state)

In [22]:
%timeit check_non_jit(jax.random.PRNGKey(np.random.randint(2**31)))

2.66 s ± 29.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
%timeit check_jit(jax.random.PRNGKey(np.random.randint(2**31)))

97.4 ms ± 670 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
