In [1]:
from classic_madn import *
import jax
import jax.numpy as jnp
from visualize_madn import *

## Swap Logic

In [2]:
# returns a boolean array indicating valid swap positions
def valid_swap(env):
    """Mask of legal swap targets for the current player's pins.

    Returns a boolean array of shape (num_pins, main_board_len): entry [i, p]
    is True when pin i of the current player may swap places with the pin on
    main-board field p.

    A swap is allowed when
      * field p holds an opponent pin (neither empty (-1) nor our own), and
      * pin i is on the main board, i.e. not in the base (-1) and not already
        in the current player's goal area (swapping out of the goal would put
        the opponent pin into our goal).
    """
    current_player = env.current_player
    num_players = env.num_players
    current_pins = env.pins[current_player]
    goal = env.goal[current_player]

    # Only the main ring is swappable: the trailing 4 * num_players cells of
    # the board are excluded (they hold the goal areas).
    main_board = env.board[:-num_players * 4]
    occupied_by_opponent = ~jnp.isin(main_board, jnp.array([-1, current_player]))

    pin_on_main_board = (current_pins != -1) & ~jnp.isin(current_pins, goal)

    return pin_on_main_board[:, None] & occupied_by_opponent[None, :]

# test valid_swap on three hand-made positions (player 0 to move)
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
swap_scenarios = [
    # 1) nothing to swap: the opponent has all pins in the base
    jnp.array([[8, 9, 10, 11],
               [-1, -1, -1, -1]]),
    # 2) nothing to swap: the opponent has pins only in the goal
    jnp.array([[8, 9, 10, 11],
               [-1, 26, 25, 24]]),
    # 3) only fields holding opponent pins are swap targets
    jnp.array([[0, -1, 2, -1],
               [8, 9, 10, 11]]),
]
for scenario_no, scenario_pins in enumerate(swap_scenarios, start=1):
    env.pins = scenario_pins
    env.board = set_pins_on_board(env.board, env.pins)
    print(f"Ausgabe {scenario_no}:\n", valid_swap(env))


Ausgabe 1:
 [[False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]]
Ausgabe 2:
 [[False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False 

In [3]:
def swap_action(env, pin_idx, swap_pos, return_pins=False):
    """Swap one of the current player's pins with an opponent pin.

    Args:
        env: game state (reads current_player, board, pins and whatever
            `valid_swap` reads).
        pin_idx: index of the current player's pin to swap.
        swap_pos: main-board field holding the opponent pin.
        return_pins: if True, also return the updated pin array. Defaults to
            False so callers that only expect the board keep working.

    Returns:
        The new board (unchanged if the swap is invalid). With
        ``return_pins=True`` a ``(board, pins)`` tuple; the pins are likewise
        unchanged for an invalid swap.
    """
    current_player = env.current_player
    invalid_action = ~valid_swap(env)[pin_idx, swap_pos]

    swapped_player = env.board[swap_pos]
    pin_pos = env.pins[current_player, pin_idx]

    # Exchange the two occupants on the board.
    board = env.board.at[swap_pos].set(current_player).at[pin_pos].set(swapped_player)

    # Mirror the exchange in the pin array: our pin moves to swap_pos, and the
    # opponent pin that stood on swap_pos moves to our old field.
    pins = env.pins.at[current_player, pin_idx].set(swap_pos)
    opponent_pins = pins[swapped_player]
    pins = pins.at[swapped_player].set(
        jnp.where(opponent_pins == swap_pos, pin_pos, opponent_pins)
    )

    new_board, new_pins = jax.lax.cond(
        invalid_action,
        lambda: (env.board, env.pins),
        lambda: (board, pins),
    )
    if return_pins:
        return new_board, new_pins
    return new_board

#test swap action
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[0, -1, 2, -1],
                      [8, 9, 10, 11]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))

# invalid swap: field 1 is empty, so swap_action must return the board unchanged
board_before = env.board
new_board = swap_action(env, pin_idx=0, swap_pos=1)
assert jnp.array_equal(new_board, board_before), "invalid swap must not change the board"
env.board = new_board  # display the board actually returned by swap_action
print("Invalid Swap Attempt:")
print(matrix_to_string(board_to_matrix(env)))

# valid swap: pin 0 (field 0) swaps with the opponent pin on field 10
# NOTE: only the board is returned here; env.pins is not updated by this call.
new_board = swap_action(env, pin_idx=0, swap_pos=10)
env.board = new_board
print("Valid Swap:")
print(matrix_to_string(board_to_matrix(env)))

                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m♠[0m  □  [94m♠[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Invalid Swap Attempt:
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m♠[0m  □  [94m♠[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Valid Swap:
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [91m♥[0m  □  [94m♠[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [94m♠[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   



## -4 action (move a pin four fields backwards)

In [4]:
# The -4 action moves a pin four fields backwards and relies on modulo to wrap
# around the circular board. Both Python ints and jnp integer arrays (positions
# may be stored as int8) use floored modulo, so the normal step logic can be
# reused. The board size 52 here is just an example value.
assert (0 - 4) % 52 == 48  # should be 48
assert jnp.array_equal((jnp.array([0, 1, 3], dtype=jnp.int8) - 4) % 52,
                       jnp.array([48, 49, 51]))

## Hot 7 action (distribute the 7 steps over the pins and hit every pin on the traversed path)

In [5]:
# Enumerate the ways to split the hot-7 steps over the 4 pins (one row per
# split); in the visible output every row sums to 7.
all_pin_distributions(7, 4)

Array([[0, 0, 0, 7],
       [0, 0, 1, 6],
       [0, 0, 2, 5],
       [0, 0, 3, 4],
       [0, 0, 4, 3],
       [0, 0, 5, 2],
       [0, 0, 6, 1],
       [0, 0, 7, 0],
       [0, 1, 0, 6],
       [0, 1, 1, 5],
       [0, 1, 2, 4],
       [0, 1, 3, 3],
       [0, 1, 4, 2],
       [0, 1, 5, 1],
       [0, 1, 6, 0],
       [0, 2, 0, 5],
       [0, 2, 1, 4],
       [0, 2, 2, 3],
       [0, 2, 3, 2],
       [0, 2, 4, 1],
       [0, 2, 5, 0],
       [0, 3, 0, 4],
       [0, 3, 1, 3],
       [0, 3, 2, 2],
       [0, 3, 3, 1],
       [0, 3, 4, 0],
       [0, 4, 0, 3],
       [0, 4, 1, 2],
       [0, 4, 2, 1],
       [0, 4, 3, 0],
       [0, 5, 0, 2],
       [0, 5, 1, 1],
       [0, 5, 2, 0],
       [0, 6, 0, 1],
       [0, 6, 1, 0],
       [0, 7, 0, 0],
       [1, 0, 0, 6],
       [1, 0, 1, 5],
       [1, 0, 2, 4],
       [1, 0, 3, 3],
       [1, 0, 4, 2],
       [1, 0, 5, 1],
       [1, 0, 6, 0],
       [1, 1, 0, 5],
       [1, 1, 1, 4],
       [1, 1, 2, 3],
       [1, 1, 3, 2],
       [1, 1,

In [6]:
def get_all_paths_compact(start, end, N):
    """Fields traversed by each pin moving from ``start[k]`` to ``end[k]``.

    Positions wrap modulo ``N`` (circular main board) only when both the start
    and the end position lie on the main board (< N). Otherwise positions are
    counted linearly, e.g. for moves inside a goal area.

    Args:
        start: 1-D array of start positions.
        end: 1-D array of end positions (same length as ``start``).
        N: main-board size used for the modulo arithmetic (required).

    Returns:
        ``(path_values, valid_positions)``, both of shape
        ``(len(start), max_len)``, where ``max_len`` is the longest path.
        Row k of ``path_values`` lists the fields visited by pin k (start
        exclusive, end inclusive). Entries beyond that pin's path length are
        padding and are False in ``valid_positions``. Pins with
        ``start == end`` get an all-False row. ``path_values[valid_positions]``
        gives the flat list of traversed fields.
    """
    start = jnp.asarray(start)
    end = jnp.asarray(end)
    if start.shape[0] == 0:
        # jnp.max over an empty array would raise; there simply is no path.
        return jnp.zeros((0, 0), dtype=start.dtype), jnp.zeros((0, 0), dtype=bool)

    moves = end != start
    wraps = (start < N) & (end < N)

    distance = (end - start) % N
    distance = jnp.where(distance == 0, N, distance)  # full lap = N steps
    distance = jnp.where(moves, distance, 0)          # no movement -> empty path

    max_len = jnp.max(distance)

    # One row per pin, one column per step.
    rows, steps = jnp.meshgrid(jnp.arange(start.shape[0]), jnp.arange(max_len), indexing='ij')
    linear = start[rows] + steps + 1
    path_values = jnp.where(wraps[rows], linear % N, linear)

    valid_positions = moves[rows] & (steps < distance[rows])

    return path_values, valid_positions

In [7]:
def calc_paths(start, end, goal, target, N):
    '''
    Union of all fields traversed by the pins moving from ``start`` to ``end``.

    Pins that stay within one area (both positions on the main board, or both
    in the goal) contribute their direct path. Pins that change area contribute
    their path up to ``target`` plus the goal fields from ``goal[0]`` up to the
    furthest end position among those pins.

    Args:
        start: (4,) array of start positions
        end: (4,) array of end positions
        goal: (4,) array of goal positions
        target: int, target field (entrance to the goal area)
        N: int, board size

    Returns:
        Sorted 1-D array of the unique traversed fields.
    '''
    start_in_goal = jnp.isin(start, goal)
    end_in_goal = jnp.isin(end, goal)
    changes_area = start_in_goal != end_in_goal

    # Direct paths for pins that stay within their area.
    direct_paths, direct_mask = get_all_paths_compact(start, end, N)
    direct_fields = direct_paths[~changes_area[:, None] & direct_mask]

    # Paths up to the goal entrance for pins that change area.
    to_target_paths, to_target_mask = get_all_paths_compact(start, jnp.full_like(end, target), N)
    to_target_fields = to_target_paths[changes_area[:, None] & to_target_mask]

    pieces = [direct_fields, to_target_fields]
    if jnp.any(changes_area):
        furthest = jnp.max(jnp.where(changes_area, end, goal[0]))
        pieces.append(jnp.arange(goal[0], furthest + 1, dtype=jnp.int8))

    return jnp.unique(jnp.concatenate(pieces))

In [8]:
def prototype(start, end, goal, target, N):
    '''
    Determine which of the current player's pins are hit by their own team's moves.

    The path of every pin is computed with ``get_all_paths_compact``. A pin
    that changes into the goal walks to ``target`` and then through
    ``goal[0] .. end``. Pin i is reported as hit when both its start and its
    end field lie on the union of the paths of all *other* pins.

    Args:
        start: (n,) array of start positions (n == 4 in the game)
        end: (n,) array of end positions
        goal: (4,) array of goal positions
        target: int, target field (entrance to the goal area)
        N: int, board size

    Returns:
        (n,) boolean array, True where the pin is hit.
    '''
    start_in_goal = jnp.isin(start, goal)
    end_in_goal = jnp.isin(end, goal)
    num_pins = len(start)

    paths = []
    for i in range(num_pins):
        if start_in_goal[i] == end_in_goal[i]:
            p, _ = get_all_paths_compact(jnp.array([start[i]]), jnp.array([end[i]]), N)
            paths.append(p[0])
        else:
            p, _ = get_all_paths_compact(jnp.array([start[i]]), jnp.array([target]), N)
            goal_range = jnp.arange(goal[0], end[i] + 1, dtype=jnp.int8)
            paths.append(jnp.concatenate([p[0], goal_range]))

    hit = []
    for i in range(num_pins):
        others = paths[:i] + paths[i + 1:]
        # With a single pin there are no other paths to be hit by.
        other_fields = jnp.concatenate(others) if others else jnp.zeros((0,), dtype=jnp.int32)
        hit.append(jnp.all(jnp.isin(jnp.array([start[i], end[i]]), other_fields)))
    return jnp.array(hit)



In [9]:
# calc_paths example: pin 0 enters the goal (16 -> 20), pin 3 moves inside it (21 -> 23)
start, end, goal = (
    jnp.array([16, 17, 20, 21]),
    jnp.array([20, 17, 20, 23]),
    jnp.array([20, 21, 22, 23]),
)
calc_paths(start=start, end=end, goal=goal, target=19, N=20)

Array([17, 18, 19, 20, 22, 23], dtype=int32)

In [10]:
start = jnp.array([16, 17, 0, 20])
end = jnp.array([22, 17, 2, 20])

# Pin 0 walks 17..19 and then through goal fields 20..22, covering pin 1
# (on 17) and pin 3 (on 20). Pins 0 and 2 are not covered by any other path.
own_hits = prototype(start, end, goal=jnp.array([20, 21, 22, 23]), target=19, N=20)
assert own_hits.tolist() == [False, True, False, True], own_hits
own_hits
# should be False, True, False, True

Array([False,  True, False,  True], dtype=bool)

In [11]:
# calc_paths example with target 9 and goal 24..27; pin 2 wraps around the board (18 -> 5)
start, end, goal = (
    jnp.array([8, 3, 18, 3]),
    jnp.array([25, 6, 5, 9]),
    jnp.array([24, 25, 26, 27]),
)
paths = calc_paths(start=start, end=end, goal=goal, target=9, N=20)
paths

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 19, 24, 25], dtype=int32)

In [12]:
# Check which probe fields lie on the union path `paths` defined in the cell above.
jnp.isin(jnp.array([[12, 2, 1, 4],
               [14,2,3,7]]), paths)

Array([[False,  True,  True,  True],
       [False,  True,  True,  True]], dtype=bool)

In [13]:
def val_action_7(env: "classic_MADN", seven_dist) -> "chex.Array":
    '''
    Check whether a hot-7 step distribution is valid for the current player.

    Args:
        env: game state (reads current_player, pins, target, goal and
            rules['enable_circular_board']).
        seven_dist: (4,) array with the number of steps for each pin of the
            current player.

    Returns:
        Scalar boolean: True only if the move is valid for *every* pin
        (per-pin validity is computed internally and reduced with jnp.all).
    '''
    current_player = env.current_player
    target = env.target[current_player]
    goal = env.goal[current_player]

    current_positions = env.pins[current_player]
    moved_positions = current_positions + seven_dist

    # On a non-circular board, a pin that has not yet passed its target must
    # not overshoot the last goal field (target + 4).
    per_pin_ok = jax.lax.cond(
        env.rules['enable_circular_board'],
        lambda: jnp.ones_like(current_positions, dtype=bool),
        lambda: ~((current_positions <= target) & (moved_positions > (target + 4))),
    )
    # TODO: a pin entering the goal (0 < moved - target <= 4) should also
    # require goal[moved - target - 1] to be free of own pins; not enforced.

    # Pins already in the goal may only advance up to the last goal field.
    per_pin_ok = jnp.where(
        jnp.isin(current_positions, goal),
        moved_positions <= goal[-1],
        per_pin_ok,
    )

    # Pins in the start area (-1) must not be moved by a hot 7.
    base_pin_ok = jnp.where(current_positions == -1, moved_positions == -1, True)

    # Not checked here: jnp.sum(seven_dist) <= 7 and distinct destination fields.
    return jnp.all(per_pin_ok & base_pin_ok)

def step_7(env: "classic_MADN", seven_dist: "Action") -> "tuple[classic_MADN, chex.Array, chex.Array]":
    '''
    Execute a hot-7 move for the current player.

    The steps in ``seven_dist`` are applied to all of the current player's pins
    at once. Every pin standing on a field traversed by these moves is sent
    back to the start area (-1). Other players' pins are matched against the
    path union from ``calc_paths``; own pins use the rule in ``prototype``.

    Args:
        env: current game state.
        seven_dist: (4,) array with the number of steps for each of the current
            player's pins.

    Returns:
        ``(new_env, reward, done)``. For an invalid move (see ``val_action_7``)
        the pins and the board stay unchanged and the reward is -1.
    '''
    current_player = env.current_player
    goal = env.goal[current_player]
    target = env.target[current_player]

    invalid_action = ~jnp.all(val_action_7(env, seven_dist))
    print("Invalid Action:", invalid_action)  # debug output

    current_pins = env.pins
    current_positions = current_pins[current_player]
    moved_positions = current_positions + seven_dist
    fitted_positions = moved_positions % env.board_size
    x = moved_positions - target

    new_positions = jnp.where(
        current_positions == -1,
        current_positions,                 # pins in the start area stay put
        jnp.where(
            jnp.isin(current_positions, goal),
            moved_positions,               # already in the goal: move linearly
            jnp.where(
                (4 >= x) & (x > 0) & (current_positions <= target),
                goal[x - 1],               # enters the goal area
                fitted_positions,          # stays on the (circular) main board
            ),
        ),
    )

    pins = current_pins.at[current_player].set(
        jnp.where(invalid_action, current_positions, new_positions)
    )

    # All fields traversed by the moving pins; any pin standing on one is hit.
    # The old and new positions of the current player's pins are handled by
    # `prototype`, and goal fields are handled inside calc_paths / prototype.
    hit_paths = calc_paths(current_positions, new_positions, goal, target, env.board_size)
    hit_pins = jnp.isin(env.pins, hit_paths)
    hit_pins = hit_pins.at[current_player].set(
        prototype(current_positions, new_positions, goal, target, env.board_size)
    )

    # Hit pins return to the start area, but only if the move is valid.
    pins = jnp.where(hit_pins & ~invalid_action, -1, pins)

    board = jax.lax.cond(
        ~invalid_action,
        lambda b: set_pins_on_board(-jnp.ones_like(b, dtype=jnp.int8), pins),
        lambda b: b,
        env.board,
    )

    winner = get_winner(board, env.goal)
    # reward: 0 once the game is already done, -1 for an invalid move,
    # otherwise 1 if the current player has won and 0 if not.
    reward = jnp.array(
        jnp.where(env.done, 0, jnp.where(invalid_action, -1, winner == current_player)),
        dtype=jnp.int8,
    )
    done = env.done | (winner != -1)

    # The bonus turn applies only to an actual 6. Without the die check, the
    # turn would never pass once the rule is enabled.
    # NOTE(review): assumes env.die holds the rolled value — confirm.
    bonus_turn = env.rules['enable_bonus_turn_on_6'] & (env.die == 6)
    current_player = jnp.where(
        done | bonus_turn,
        current_player,
        (current_player + 1) % env.num_players,
    )

    env = classic_MADN(
        board=board,
        num_players=env.num_players,
        pins=pins,
        current_player=current_player,
        done=done,
        reward=reward,
        start=env.start,
        target=env.target,
        goal=env.goal,
        die=env.die,
        board_size=env.board_size,
        total_board_size=env.total_board_size,
        rules=env.rules,
    )
    return env, reward, done

In [14]:
# test val_action_7
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[0, -1, 2, 1],
                      [8, 9, 10, 11]])
env.board = set_pins_on_board(env.board, env.pins)
# pin 1 is in the base (-1) but is asked to move 1 step -> expected False
print(val_action_7(env, jnp.array([2,1,6,3])))
# the base pin gets 0 steps, the other pins move on the main board -> expected True
print(val_action_7(env, jnp.array([6,0,0,1])))

False
True


In [15]:
#test step_7
# player 0: pin 0 runs 7 -> 10 over opponent pins on 8, 9, 10; pin 2 moves 1 -> 2 onto own pin 1
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[7, 2, 1, 4],
                        [8, 9, 10, 11]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))
env, reward, done = step_7(env, jnp.array([3,0, 1, 3]))
print(matrix_to_string(board_to_matrix(env)))
print(env.pins)

                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  [94m♠[0m  [94m♠[0m  □  [94m♠[0m  □  □  [94m♠[0m  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Invalid Action: False
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  □  [94m♠[0m  □  □  □  □  [94m♠[0m  □  □    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [94m♠[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

[[10 -1  2  7]
 [-1 -1 -1 11]]


In [16]:
#test step_7
# player 0: pin 0 runs 0 -> 6 over own pins on 1 and 2
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[0, -1, 2, 1],
                      [8, 9, 10, 11]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))
env, reward, done = step_7(env, jnp.array([6,0,0,1]))
print(matrix_to_string(board_to_matrix(env)))
print(env.pins)

                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m♠[0m  [94m♠[0m  [94m♠[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Invalid Action: False
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  □  □  □  □  □  [94m♠[0m  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

[[ 6 -1 -1 -1]
 [ 8  9 10 11]]


In [17]:
#test step_7
# player 0: pins 0 (0 -> 4) and 2 (2 -> 4) end on the same field, pin 3 moves 1 -> 3
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[0, -1, 2, 1],
                      [8, 9, 10, 11]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))
env, reward, done = step_7(env, jnp.array([4,0,2,2]))
print(matrix_to_string(board_to_matrix(env)))
print(env.pins)

                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m♠[0m  [94m♠[0m  [94m♠[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Invalid Action: False
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  □  □  □  [94m♠[0m  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

[[ 4 -1 -1 -1]
 [ 8  9 10 11]]


In [18]:
#test step_7
# player 1 to move: pin 0 16 -> 17, pin 1 17 -> 21, which wraps to field 1 on the 20-field board
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[-1, -1, 2, 1],
                      [16, 17, 0, -1]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))
env.current_player = 1
env, reward, done = step_7(env, jnp.array([1,4,0,0]))
print(matrix_to_string(board_to_matrix(env)))
print(env.pins)

                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [91m♥[0m  [94m♠[0m  [94m♠[0m  □  □  □  □  □  □  □    
                                    
    □  □  [91m♥[0m  [91m♥[0m  □  □  □  □  □  [91m□[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Invalid Action: False
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  [91m♥[0m  [94m♠[0m  □  □  □  □  □  □  □    
                                    
    □  □  [91m♥[0m  □  □  □  □  □  □  [91m□[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

[[-1 -1  2 -1]
 [17  1 -1 -1]]


In [19]:
#test step_7
# player 0: pin 0 overshoots the goal entrance (16 + 8 = 24) and wraps to field 4
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[16, 17, 0, 20],
                      [8, 9, 10, 11]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))
env, reward, done = step_7(env, jnp.array([8,0,2,0]))
print(matrix_to_string(board_to_matrix(env)))
print(env.pins)

                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m♠[0m  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  [94m♠[0m  [94m♠[0m  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m♠[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Invalid Action: False
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  □  □  □  [94m♠[0m  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m♠[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

[[ 4 -1 -1 20]
 [ 8  9 10 11]]


In [20]:
#test step_7
# player 0: pin 0 enters the goal (16 + 5 -> goal field 21) passing own pins on 17 and 20
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[16, 17, 0, 20],
                      [8, 9, 10, 11]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))
env, reward, done = step_7(env, jnp.array([5,0,2,0]))
print(matrix_to_string(board_to_matrix(env)))
print(env.pins)

                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m♠[0m  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  [94m♠[0m  [94m♠[0m  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m♠[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Invalid Action: False
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  □  [94m♠[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m♠[0m  [94m□[0m  [94m□[0m                   

[[21 -1  2 -1]
 [ 8  9 10 11]]


In [21]:
#test step_7
# player 0: moves inside the goal (21 -> 23, 20 -> 21) combined with a goal entry (17 -> 21)
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[21, 17, 0, 20],
                      [8, 9, 10, 11]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))
env, reward, done = step_7(env, jnp.array([2,4,2,1]))
print(matrix_to_string(board_to_matrix(env)))
print(env.pins)

                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m♠[0m  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  [94m♠[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m♠[0m  [94m♠[0m  [94m□[0m  [94m□[0m                   

Invalid Action: False
                   [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  □  [94m♠[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  [91m♥[0m  [91m♥[0m    
       [94m□[0m  [94m♠[0m  [94m□[0m  [94m♠[0m                   

[[23 21  2 -1]
 [ 8  9 10 11]]


In [22]:
#test step_7
# player 1: pin 0 enters the goal (4 + 7 -> goal field 25) passing player 0's pins on 5..8
env = env_reset(0, num_players=jnp.int8(2), distance=jnp.int8(10))
env.pins = jnp.array([[5,6,7,8],
                      [4, 9, 24, 25]])
env.board = set_pins_on_board(env.board, env.pins)
env.current_player = 1
print(matrix_to_string(board_to_matrix(env)))
env, reward, done = step_7(env, jnp.array([7,0,0,0]))
print(matrix_to_string(board_to_matrix(env)))
print(env.pins)

                   [91m□[0m  [91m□[0m  [91m♥[0m  [91m♥[0m       
    [94m□[0m  □  □  □  [91m♥[0m  [94m♠[0m  [94m♠[0m  [94m♠[0m  [94m♠[0m  [91m♥[0m    
                                    
    □  □  □  □  □  □  □  □  □  [91m□[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

Invalid Action: False
                   [91m□[0m  [91m□[0m  [91m♥[0m  [91m□[0m       
    [94m□[0m  □  □  □  □  □  □  □  □  □    
                                    
    □  □  □  □  □  □  □  □  □  [91m□[0m    
       [94m□[0m  [94m□[0m  [94m□[0m  [94m□[0m                   

[[-1 -1 -1 -1]
 [25 -1 -1 -1]]


In [23]:
# Render a 3-player board (distance 10) with pins in the base, on the board and
# in a goal; the printed env.goal shows the goal fields start at 30.
env = env_reset(0, num_players=jnp.int8(3), distance=jnp.int8(10))
print(env.goal)
env.pins = jnp.array([[21, 17, 0, 20],
                      [8, 9, 10, 11],
                      [38, 1, 2, -1]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))

[[30 31 32 33]
 [34 35 36 37]
 [38 39 40 41]]
                      [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m♠[0m  [93m♦[0m  [93m♦[0m  □  □  □  □  □  [91m♥[0m  [91m♥[0m  [91m♥[0m    
 [94m□[0m  □  .  .  .  .  .  .  .  .  [91m♥[0m  .    
 [94m□[0m  □  .  .  .  .  .  .  .  □  .  .    
 [94m□[0m  □  .  .  .  .  .  .  □  .  .  .    
 [94m□[0m  □  .  .  .  .  .  □  .  .  .  .    
    □  .  .  .  .  □  .  .  .  .  .    
    □  .  .  .  □  .  .  .  .  .  .    
    □  .  .  [94m♠[0m  .  .  .  .  .  .  .    
    □  .  □  .  .  .  .  .  .  .  .    
    [94m♠[0m  □  [93m♦[0m  [93m□[0m  [93m□[0m  [93m□[0m  .  .  .  .  .    
    [94m♠[0m  .  .  .  .  .  .  .  .  .  .    
                                       



In [24]:
# Render a 4-player board (distance 16) with a mixed position; the printed
# env.goal shows the goal fields start at 64.
env = env_reset(0, num_players=jnp.int8(4), distance=jnp.int8(16))
print(env.goal)
env.pins = jnp.array([[21, 17, 42, 20],
                      [8, 45, 46, 47],
                      [1,2,50,51],
                      [5,54,53,3]])
env.board = set_pins_on_board(env.board, env.pins)
print(matrix_to_string(board_to_matrix(env)))

[[64 65 66 67]
 [68 69 70 71]
 [72 73 74 75]
 [76 77 78 79]]
                                        [91m□[0m  [91m□[0m  [91m□[0m  [91m□[0m       
    [94m□[0m  [93m♦[0m  [93m♦[0m  [92m♣[0m  □  [92m♣[0m  □  □  [91m♥[0m  □  □  □  □  □  □  □  [91m□[0m    
 [94m□[0m  □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  [94m♠[0m    
 [94m□[0m  □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  □    
 [94m□[0m  □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  □    
 [94m□[0m  □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  [94m♠[0m    
    □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  [94m♠[0m    
    □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  □    
    □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  □    
    □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  □    
    □  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  □    
    [92m♣[0m  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  □    
    [92m♣[0m  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  □    
    □