In [None]:
import python_bindings

In [None]:
# Call `greet` from the imported `python_bindings` module with a sample string argument;
# whatever it returns is shown as the cell output.
# NOTE(review): presumably a smoke test that the bindings import and run — confirm.
python_bindings.greet("a")

In [None]:
# Bitwise-AND two integer literals and display the result as a binary string.
# NOTE(review): both literals are hexadecimal (0x1000 == 4096, 0x1010 == 4112), so the
# cell shows '0b1000000000000'. If 4-bit masks were intended, the literals should be
# 0b1000 and 0b1010 — confirm.
a = 0x1000
b = 0x1010
bin(a & b)

In [None]:
import numpy as np

In [None]:
# Draw a (4, 3, 3) integer array with every entry in {0, 1, 2}.
# The generator is seeded so the displayed array is identical on every
# "Restart Kernel -> Run All" instead of changing from run to run.
rng = np.random.default_rng(seed=0)
reshaped_state = rng.integers(0, 3, size=(4, 3, 3))
reshaped_state

### JAX speedup

In [None]:
import jax
import jax.numpy as jnp

def check_win(state) -> jnp.ndarray:
    """Return the winner of a 3x3 board, or 0 when no line is complete.

    Args:
        state: Array-like with exactly 9 entries (flat or already 3x3). A cell
            value of 0 means "empty"; any non-zero value is treated as a
            player id (the notebook uses 1 and 2).

    Returns:
        A 0-d JAX array (not a Python ``int``) holding the id of the player
        that owns a complete row, column or diagonal, or 0 if there is none.
        If several complete lines belong to different players, the largest
        id is returned.
    """
    # jnp.asarray lets plain lists/tuples and NumPy arrays through as well.
    board = jnp.asarray(state).reshape(3, 3)

    # Stack every candidate line as one row of an (8, 3) array:
    # 3 rows, 3 columns, the main diagonal and the anti-diagonal.
    lines = jnp.concatenate(
        [
            board,
            board.T,
            jnp.diagonal(board)[None, :],
            jnp.diagonal(board[:, ::-1])[None, :],
        ]
    )

    # A line is won when its three cells are equal and its first cell is
    # not empty; the first cell then identifies the owner.
    first_cell = lines[:, 0]
    is_complete = jnp.all(lines == first_cell[:, None], axis=1) & (first_cell != 0)

    # Incomplete lines contribute 0, so the max is the winner id or 0.
    return jnp.max(first_cell * is_complete)

# JIT-compiled version of `check_win`: `jax.jit` wraps the same function so JAX can
# trace and compile it; used by the benchmark cells below to compare against the
# plain (eager) `check_win`.
check_win_jit = jax.jit(check_win)

In [None]:
def check_non_jit(rng, n=10):
    """Evaluate the eager (non-JIT) ``check_win`` on ``n`` random boards.

    Args:
        rng: JAX PRNG key that seeds board generation.
        n: Number of random boards to generate and evaluate.

    Returns:
        None. The function exists to be timed.
    """
    for _ in range(n):
        # Carry one half of the split forward and spend the other on this
        # board, so no key is consumed twice.
        rng, board_key = jax.random.split(rng)
        state = jax.random.randint(board_key, (9,), 0, 3)
        # JAX dispatches asynchronously; block so a surrounding timer
        # measures the computation rather than just the dispatch.
        check_win(state).block_until_ready()

def check_jit(rng, n=10):
    """Evaluate the JIT-compiled ``check_win_jit`` on ``n`` random boards.

    Args:
        rng: JAX PRNG key that seeds board generation.
        n: Number of random boards to generate and evaluate.

    Returns:
        None. The function exists to be timed.
    """
    for _ in range(n):
        # Carry one half of the split forward and spend the other on this
        # board, so no key is consumed twice.
        rng, board_key = jax.random.split(rng)
        state = jax.random.randint(board_key, (9,), 0, 3)
        # JAX dispatches asynchronously; block so a surrounding timer
        # measures the computation rather than just the dispatch.
        check_win_jit(state).block_until_ready()

In [None]:
# Baseline timing: the eager path. The timed statement also builds a fresh PRNG key from
# a NumPy-drawn seed on every call, so key construction is part of the measurement.
%timeit check_non_jit(jax.random.PRNGKey(np.random.randint(2**31)))

In [None]:
# Same measurement for the jitted path, for direct comparison with the eager timing.
# As there, the timed statement includes building a fresh PRNG key from a NumPy-drawn seed.
%timeit check_jit(jax.random.PRNGKey(np.random.randint(2**31)))