In [1]:
import chex
import jax
import jax.numpy as jnp
import mctx

In [2]:
Board = chex.Array
Start = chex.Array
Target = chex.Array
Goal = chex.Array
Action = chex.Array
Player = chex.Array
Reward = chex.Array
Done = chex.Array
Action_set = chex.Array
Pins = chex.Array
Num_players = chex.Array
Size = chex.Array

@chex.dataclass
class deterministic_MADN:
    board: Board  # shape (64,), values in {0, 1, 2, 3, 4} for empty, player 1, player 2, player 3, player 4
    num_players: Num_players
    current_player: Player  # scalar, 1, 2, 3, or 4
    pins : Pins  # shape (num_players,4), positions of the players' pins
    reward: Reward  # scalar, reward for the current player
    done: Done  # scalar, whether the game is over
    action_set: Action_set  # available actions, 1-6 each 3x until empty, then refilled
    start: Start  # shape (num_players,), starting indices of the players
    target: Target  # shape (num_players,), positions before the goals of the players
    goal: Goal  # shape (num_players,4), goal positions of the players
    board_size: Size  # scalar, size of the board (num_players * distance)
    total_board_size: Size  # scalar, size of the board + goal areas (num_players * distance + num_players * 4)
    rules = {
        'enable_initial_free_pin':False,
        'enable_circular_board':True,
        'enable_friendly_fire':False,
        'enable_start_on_6':True,
    }

In [3]:
def get_winner(board, goal_area) -> Player:
    '''
    returns the index of the winning player or 0 if tie or not Done
    '''
    goals = board[goal_area]
    player_goals = jnp.all(goals >= 0, axis=1)
    return jnp.where(jnp.any(player_goals), jnp.argmax(player_goals), -1)

def env_reset(_, num_players=jnp.int8(4), distance=jnp.int8(10)) -> deterministic_MADN:
    board_size = num_players * distance
    total_board_size = board_size + num_players * 4 # add goal areas
    num_pins = 4
    return deterministic_MADN(
        board = - jnp.ones(total_board_size, dtype=jnp.int8), # board is filled with -1 (empty) or 0-3 (player index)
        num_players = jnp.array(num_players, dtype=jnp.int8), # number of players
        pins = - jnp.ones((num_players,num_pins), dtype=jnp.int8), # index of each players' pins, -1 means in start area
        current_player=jnp.array(0, dtype=jnp.int8), # index of current player, 0-3
        done = jnp.bool_(False), # whether the game is over
        reward=jnp.array(0, dtype=jnp.int8), # reward for the current player
        action_set= num_pins * jnp.ones((num_players, 6), dtype=jnp.int8), # each player starts with 4 actions 1-6
        start = jnp.array(jnp.arange(num_players)*distance, dtype=jnp.int8), # starting positions of each player
        target = jnp.array((jnp.arange(num_players)*distance - 1)%board_size, dtype=jnp.int8),
        goal = jnp.reshape(jnp.arange(board_size, board_size + num_players*4, dtype=jnp.int8), (num_players, 4)),
        board_size=jnp.array(board_size, dtype=jnp.int8),
        total_board_size=jnp.array(total_board_size, dtype=jnp.int8),
    )

@jax.jit
def set_pins_on_board(board, pins):
    num_players, num_pins = pins.shape

    def body(idx, board):
        player = idx // num_pins
        pin = idx % num_pins
        pos = pins[player, pin]
        board = jax.lax.cond(
            pos != -1,
            lambda b: b.at[pos].set(player),
            lambda b: b,
            board
        )
        return board

    board = jax.lax.fori_loop(0, num_players * num_pins, body, board)
    return board

def refill_action_set(env:deterministic_MADN) -> chex.Array:
    '''
    Refills the action set for the current player if all actions are used up.
    '''
    return env.action_set.at[env.current_player].set(env.pins.shape[1] * jnp.ones(6, dtype=jnp.int8))

def no_step(env:deterministic_MADN) -> deterministic_MADN:
    """
    No-op step function for the environment.
    """
    act_set = refill_action_set(env)
    env = deterministic_MADN(
        board=env.board,
        num_players=env.num_players,
        pins=env.pins,
        current_player=(env.current_player + 1) % env.num_players,
        done=env.done,
        reward=env.reward,
        action_set= act_set,
        start=env.start,
        target=env.target,
        goal=env.goal,
        board_size=env.board_size,
        total_board_size=env.total_board_size,
    )
    return env, jnp.array(0, dtype=jnp.int8), env.done

@jax.jit
def valid_action(env:deterministic_MADN) -> chex.Array:
    '''
    Returns a mask of shape (4, 6) indicating which actions are valid for each pin of the current player
    '''
    #return valid_action for each pin of the current player
    current_player = env.current_player
    current_pins = env.pins[current_player]
    board = env.board
    target = env.target[current_player]
    goal = env.goal[current_player]
    action_set = env.action_set[current_player]
    valid_actions = jnp.where(action_set>0, True, False) # available actions for each pin

    # calculate possible actions
    current_positions = current_pins[:, None]
    moved_positions = current_pins[:, None] + jnp.arange(1, 7)
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - target

    # filter out invalid moves blocked by own pins
    result = (board[fitted_positions] != current_player) # check move to any board position

    result = jnp.where(
        (4 >= x) & (x > 0) & (current_positions <= env.target[current_player]),
        result | (board[goal[x-1]] != current_player), # if goal is possible, check if goal position is free
        result
    )
    # filter actions for pins in goal area
    result = jnp.where(
        jnp.isin(current_pins, goal)[:, None],
        (moved_positions <= goal[-1]) & (board[moved_positions%env.total_board_size] != current_player),
        result
    )

    # filter actions for pins in start area
    result = jnp.where(
        (current_pins == -1)[:, None],
        jnp.isin(jnp.arange(1, 7), jnp.array([1, 6])) & (env.board[env.start[current_player]] != env.current_player),
        result
    )
    return result & valid_actions # filter possible actions with available actions


#### Testing whether many small calls or own functions for each step variant is more efficient

In [4]:
@jax.jit
def env_step(env: deterministic_MADN, action: Action) -> deterministic_MADN:
    pin = action[0].astype(jnp.int8)
    move = action[1].astype(jnp.int8) # action is in {1, 2, 3, 4, 5, 6}
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

In [5]:
@jax.jit
def pin_0_move(env, steps):
    pin = 0
    move = steps # action is in {1, 2, 3, 4, 5, 6}
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

@jax.jit
def pin_1_move(env, steps):
    pin = 1
    move = steps
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

@jax.jit
def pin_2_move(env, steps):
    pin = 2
    move = steps
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

@jax.jit
def pin_3_move(env, steps):
    pin = 3
    move = steps
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

# @jax.jit
def step(env:deterministic_MADN, action_idx:Action) -> deterministic_MADN:
    pin = (action_idx // 6).astype(jnp.int8)
    move = (action_idx % 6 + 1).astype(jnp.int8)
    return jax.lax.switch(
        pin,
        [
            pin_0_move,
            pin_1_move,
            pin_2_move,
            pin_3_move,
        ],
        env,
        move
    )

In [6]:
def pin_0_move2(env, steps):
    pin = 0
    move = steps # action is in {1, 2, 3, 4, 5, 6}
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

def pin_1_move2(env, steps):
    pin = 1
    move = steps
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

def pin_2_move2(env, steps):
    pin = 2
    move = steps
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

def pin_3_move2(env, steps):
    pin = 3
    move = steps
    current_player = env.current_player
    # check if the action is valid
    invalid_action = ~valid_action(env)[pin, move-1]

    current_positions = env.pins[current_player, pin]
    moved_positions = current_positions + move
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - env.target[current_player]

    new_position = jnp.where(
        current_positions == -1,
        env.start[current_player], # move from start area to starting position
        jnp.where(jnp.isin(current_positions, env.goal[current_player]),
                    moved_positions,
                    jnp.where(
                        (4 >= x) & (x > 0) & (env.board[env.goal[current_player, x-1]] != current_player) & (current_positions <= env.target[current_player]),
                        env.goal[current_player, x-1], # move to goal position
                        fitted_positions
                    )
        )
    )
    
    # update pins
    # pin at new position
    pin_at_pos = env.board[new_position]
    # if a player is at the new position and it's not the current player, send that pin back to start area
    pins = env.pins.at[current_player, pin].set(jnp.where(invalid_action, env.pins[current_player, pin], new_position))
    pins = jax.lax.cond(
        (pin_at_pos != -1) & (pin_at_pos != current_player) & ~invalid_action, # if a player was at the new position and it's not the current player and the action is valid
        lambda p: p.at[pin_at_pos].set(jnp.where(p[pin_at_pos] == new_position, -1, p[pin_at_pos])), # send the pin of that player back to start area
        lambda p: p,
        pins
    )
    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board
    )

    # update action set, change one instance of the played action to 0 (no action)
    curr_state = env.action_set.at[current_player, move-1].get()
    action_set = env.action_set.at[current_player, move-1].set(jnp.where(invalid_action | (curr_state == 0), curr_state, curr_state-1))
    action_set = jax.lax.cond(
        jnp.all(action_set[current_player] == 0), # if all actions are 0, refill the action set
        lambda a: refill_action_set(env),
        lambda a: a,
        action_set
    )
    winner = get_winner(board, env.goal)
    reward = jnp.array(jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner==current_player)), dtype=jnp.int8) # reward is 0 if game is done, -1 if action is invalid, else the index of the winning player (1-4) or 0 if no winner yet
    # check if the game is done
    done = env.done | jnp.where(winner != -1, True, False)
    # player changes on invalid action
    current_player = jnp.where(done | (move == 6), current_player, (current_player + 1) % env.num_players) # if the game is not done or the player played a 6, switch to the next player

    env = deterministic_MADN(
        board=board,
        num_players=env.num_players,#/
        pins=pins,#
        current_player=current_player,
        done= done,#
        reward=reward,#
        action_set=action_set,#/
        start=env.start,#
        target=env.target,#
        goal=env.goal,#
        board_size=env.board_size,#
        total_board_size=env.total_board_size,#
    )
    return env, reward, done

@jax.jit
def step2(env:deterministic_MADN, action_idx:Action) -> deterministic_MADN:
    pin = (action_idx // 6).astype(jnp.int8)
    move = (action_idx % 6 + 1).astype(jnp.int8)
    return jax.lax.switch(
        pin,
        [
            pin_0_move2,
            pin_1_move2,
            pin_2_move2,
            pin_3_move2,
        ],
        env,
        move
    )

In [7]:
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(12))

In [8]:
%timeit -n 10000 env2, reward, done = env_step(env, jnp.array([1, 6], dtype=jnp.int8))

154 μs ± 23.5 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [10]:
%timeit -n 1000 env3, reward, done = step(env, jnp.array(11, dtype=jnp.int8))

1.6 ms ± 56.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
%timeit -n 10000 env4, reward, done = step2(env, jnp.array(11, dtype=jnp.int8))

169 μs ± 37 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
