<a href="https://colab.research.google.com/github/mcnica89/MATH4060/blob/main/FinalProject_Cant_Stop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import itertools
import jax.numpy as jnp
from jax import random as jrandom
from jax import nn as jnn
from jax import jit
import random
import time
import sys
import jax

In [2]:
'''Implementation of the board game "CANT STOP"

How the game is represented in Python:

-----Game parameters (that do not change while the game is being played)

N_PLAYERS
   A positive integer
  This is the number of players playing
  In the classic game rules, N_Players = 4

N_COL_TO_WIN
   A positive integer
   This is the number of columns you need to claim to win the game
   In the classic game rules, N_Col_To_Win = 3

N_MAX_RUNNERS
   A positive integer
   This is the maximum number of runners you can have
   In the classic game rules, N_Max_Runners = 3

PLAYER_COL_STATE_INIT
   A vector of shape (11,) of non-negative integers
   This is the number of squares in each game column
   In the classic game rules, this is [3,5,7,9,11,13,11,9,7,5,3]

NOTE on column labelling:
   In the game, the columns are labeled 2-12 (corresponding to dice rolls)
   In Python, the columns locations are indexed 0-10
   This means that to translate from column in Python to columns in the game,
   one must often add or subtract 2 from the column indices. 

---Variables: (that represent what is going on in the game as it is played)

active_player_index
   An index from the range [0,N_PLAYERS-1] indicating whose turn it currently is

player_col_state
   An array of shape (N_players,11) of integers
   Each row is the number of squares remaining for that player in each col
   NOTE: 
     This is the number of squares REMAINING, these start at PLAYER_COL_STATE_INIT
     and count DOWN to zero as the game progresses. When they get to zero, the player has claimed the column
   WARNING:
     We will not prohibit these from being negative even though it doesn't mean anything in the game
     (This can happen if the player goes past the number needed to claim the column)

illegal_col
   A vector of shape (11,) of boolean
   Contains the information on which columns are no longer in play
   (True = the column has been claimed by a player and is not legal to play in anymore)

runner_col_state
  A vector of shape (11,) of non-negative integers
  Indicates the current state of how far the runners have advanced in each column
  A zero indicates that there is no runner in that column at all
  NOTES:
   1. count_nonzero(runner_col_state) should not exceed N_MAX_RUNNERS for legal runner states
   2. Since player_col_state counts DOWN to 0, runner_col_state is SUBTRACTED from player_col_state when the player chooses to stop rolling

dice_rolls
  A vector of shape (4,) of integers [1,6] indicating the outcome of the 4 dice rolls

runner_col_choices
  A vector of shape (N_choices, 11) of non-negative integers
  Indicates the available CHOICES the player has of where the runners could be
  This corresponds to legal choices for choosing pairings of the dice
  NOTES:
    1. By the rules of the game, N_choices can be at most 6
    2. If N_choices = 0, then this indicates that there are no legal moves and the player has busted

roll_again
  A boolean on whether or not the player wants to roll again'''



# Helper functions

In [3]:
def randomDice(random_key):
  '''Rolls four six-sided dice'''
  #Input:
  #  random_key = a jax random key (used to generate the dice rolls)
  #Output:
  #  An unsigned 8-bit int array of shape (4,), each entry in the range 1 to 6
  N_DICE = 4
  LOWEST_FACE = 1
  ONE_PAST_HIGHEST_FACE = 7 #the upper bound is exclusive in JAX, so [1,7) means faces 1..6
  return jrandom.randint(random_key, (N_DICE,), LOWEST_FACE, ONE_PAST_HIGHEST_FACE, dtype=jnp.uint8)

In [4]:
@jit
def are_runners_legal(runner_col_states, illegal_col, N_MAX_RUNNERS=3):
  '''Decides, for every runner state in a batch, whether it is a legal state'''
  #Input:
  #  runner_col_states = an int array of size (N,11) of runner positions (one state per row)
  #  illegal_col = a boolean vector of size (11,) with which columns are illegal
  #  N_MAX_RUNNERS = the most runners that may be on the board at once
  #Output:
  #  a boolean vector of size (N,), True where the corresponding runner state is legal

  #Condition 1: no more than N_MAX_RUNNERS columns hold a runner
  N_runners = (runner_col_states != 0).sum(axis=1)
  few_enough_runners = N_runners <= N_MAX_RUNNERS

  #Condition 2: every column is fine, i.e. it is either empty of runners or still open
  col_has_no_runner = (runner_col_states == 0)
  col_is_open = (illegal_col == False)
  every_col_is_fine = jnp.all(col_has_no_runner | col_is_open, axis=1)

  #A state is legal only when both conditions hold
  return few_enough_runners & every_col_is_fine

def are_runners_legal_tests():
  '''Checks are_runners_legal on a small hand-built batch of runner states'''
  
  #The odd-indexed columns (1,3,5,7,9) are marked as illegal
  illegal_col = jnp.array([0,1,0,1,0,1,0,1,0,1,0],dtype=jnp.dtype('b'))
  N_MAX_RUNNERS = 3

  #Each case is a pair (runner positions, expected legality)
  cases = [
    ([1,1,0,0,0,0,0,0,0,0,0], False), #runner in illegal column -> not legal
    ([1,1,1,1,0,0,0,0,0,0,0], False), #runner in illegal column AND too many runners -> not legal
    ([1,0,1,0,1,0,1,0,0,0,0], False), #all runners in legal columns but too many runners -> not legal
    ([1,0,1,0,1,0,0,0,0,0,0], True),  #all runners in legal columns and exactly 3 runners -> legal
  ]
  runner_col_states = jnp.array([runners for runners, _ in cases],dtype=jnp.dtype('u1'))
  expected = jnp.array([is_legal for _, is_legal in cases])

  observed = are_runners_legal(runner_col_states,illegal_col,N_MAX_RUNNERS)
  assert jnp.all(observed == expected)
  return True

assert are_runners_legal_tests()



In [5]:
@jit
def generate_all_choices_and_legality(dice_rolls,illegal_col,runner_col_state, N_MAX_RUNNERS=3): 
  '''Computes ALL the possible moves based on the dice and 
     whether or not they are legal based on the current state and dice'''
  #Input: 
  #  dice_rolls = a vector of size (4,) with 4 dice rolls
  #  illegal_col = a boolean vector of size (11,) with which columns are illegal
  #  runner_col_state = an int vector of size (11,) of the current runners
  #  N_MAX_RUNNERS = the most runners that may be on the board at once
  #Output: A tuple with two components
  #  An array of size (9, 11) where each row is an option for where the runners would be
  #  A boolean vector of size (9,) where each entry tells you if the move is legal or not
  #  Row order: rows 0-2 use BOTH pairs, rows 3-5 use only the pair containing the first dice,
  #             rows 6-8 use only the complementary pair (rows k, k+3, k+6 belong to the same pairing)
  #  NOTE:
  #     The size is 9 rows because there are 9 possible dice pairings in Can't Stop:
  #     1. 3 ways to pair the dice and use BOTH pairings (i.e. {{1,2},{3,4}}, {{1,3},{2,4}}, {{1,4},{2,3}} )
  #     2. 6 ways to use a single pair of dice  (i.e. {1,2},{1,3},{1,4},{2,3},{2,4},{3,4} )
  #     (In the code the 6 single pairings are further subdivided into 3 pairings that involve the first dice and 3 pairings that don't)
  #       (i.e. pairings with_1: {1,2},{1,3},{1,4} and pairings without_1: {3,4},{2,4},{2,3} )
  #     (This is done to make it convenient to match up with the 3 ways to use BOTH pairings)
  #     Note that you can only choose to use the single pair of dice in the situation 
  #     that using both would be illegal. This means that at most 6 of the 9 moves can ever be True

  #Find all the sums of the possible ways to pair the dice that includes the first dice
  #  This matrix is a hardcoded way to get the 3 pairings of 4 dice that involve the first entry 
  #  (It's a matrix multiplication b/c we are doing linear combinations of the four dice!)
  pairing_matrix_with_1 = jnp.array([[1,1,0,0],[1,0,1,0],[1,0,0,1]],dtype=jnp.dtype('u1'))
  #All the other pairings (ones that don't involve the first dice)
  #(It's important that these are synchronized to be the complement of pairing_matrix_with_1) 
  pairing_matrix_without_1 = 1 - pairing_matrix_with_1 

  #Possible sums you get by pairings that involve the first dice, and by their complements
  dice_sums_with_1 = jnp.matmul(pairing_matrix_with_1,dice_rolls) 
  dice_sums_without_1 = jnp.matmul(pairing_matrix_without_1,dice_rolls)

  #Convert from the dice numbers to the column notation 
  #   (We subtract 2 here since dice roll 2 is column 0 in our notation)
  dice_sums_with_1_cols = jnn.one_hot(dice_sums_with_1 - 2, 11,dtype=jnp.dtype('u1')) 
  dice_sums_without_1_cols = jnn.one_hot(dice_sums_without_1 - 2, 11,dtype=jnp.dtype('u1'))

  #Calculate all the 9 possible moves of playing both pairs (i.e. double) and with any single pair
  # (We will work out which are legal moves afterwards!)
  double_runner_choices = runner_col_state + dice_sums_with_1_cols + dice_sums_without_1_cols
  single_runner_choices_with_1 = runner_col_state + dice_sums_with_1_cols 
  single_runner_choices_without_1 = runner_col_state + dice_sums_without_1_cols

  #Compute if the choices with both pairings played (i.e. double) are legal
  are_double_runners_legal = are_runners_legal(double_runner_choices,illegal_col, N_MAX_RUNNERS)
  are_double_runners_illegal = jnp.logical_not(are_double_runners_legal)

  #The moves with a single pair are only legal if the corresponding move with both pairs is illegal 
  #  (i.e. it's legal to play only one pair iff after you play it, playing the other pair is not legal)
  #  This means we can compute if they are legal on their own first and then
  #  logical_and it with the double runners

  #  first check if they would be ok on their own.
  are_single_runners_with_1_ok = are_runners_legal(single_runner_choices_with_1,illegal_col, N_MAX_RUNNERS)
  are_single_runners_without_1_ok = are_runners_legal(single_runner_choices_without_1,illegal_col, N_MAX_RUNNERS)

  #  then we logical_and it with the double runners to only legalize these moves if playing both was illegal
  are_single_runners_with_1_legal = jnp.logical_and(are_double_runners_illegal,are_single_runners_with_1_ok)
  are_single_runners_without_1_legal = jnp.logical_and(are_double_runners_illegal,are_single_runners_without_1_ok)

  #Combine everything together to be outputted
  #  (vstack is used rather than row_stack: row_stack is only a legacy alias of vstack)
  all_runner_choices = jnp.vstack([double_runner_choices,single_runner_choices_with_1,single_runner_choices_without_1])
  all_runner_choices_legal = jnp.concatenate([are_double_runners_legal, are_single_runners_with_1_legal, are_single_runners_without_1_legal])

  return all_runner_choices, all_runner_choices_legal

def generate_legal_choices(dice_rolls,illegal_col,runner_col_state, N_MAX_RUNNERS=3):
  '''Lists the distinct legal runner placements for the current state and dice'''
  #Input: 
  #  dice_rolls = a vector of size (4,) with 4 dice rolls
  #  illegal_col = a boolean vector of size (11,) with which columns are illegal
  #  runner_col_state = an int vector of size (11,) of the current runners
  #  N_MAX_RUNNERS = the most runners that may be on the board at once
  #Output: 
  #  An array of shape (N_choices, 11) with possible places the runners could be
  #  Note that N_choices could be 0 (a bust) and is at most 6

  #Start from all 9 candidate moves together with their legality flags
  all_choices, is_legal = generate_all_choices_and_legality(dice_rolls,illegal_col,runner_col_state, N_MAX_RUNNERS)

  #Keep only the rows flagged as legal...
  legal_rows = jnp.flatnonzero(is_legal)
  legal_choices = all_choices[legal_rows]

  #...and drop repeats (different pairings can lead to the same runner placement)
  return jnp.unique(legal_choices,axis=0)

def generate_choices_tests():
  '''Checks generate_legal_choices against two hand-worked positions'''
  u1 = jnp.dtype('u1')
  N_MAX_RUNNERS = 3

  #The same roll is used in both scenarios
  dice_rolls = jnp.array([1,2,3,4],dtype=u1)

  #Each scenario is (runner_col_state, illegal_col, expected legal choices)
  scenarios = [
    #Three runners already placed, every column open: only the runner in column index 2 can advance
    ([0,0,1,0,0,0,0,0,0,1,1],
     [0,0,0,0,0,0,0,0,0,0,0],
     [[0,0,2,0,0,0,0,0,0,1,1]]),
    #No runners placed, column index 1 is illegal
    ([0,0,0,0,0,0,0,0,0,0,0],
     [0,1,0,0,0,0,0,0,0,0,0],
     [[0,0,0,0,0,1,0,0,0,0,0],[0,0,0,2,0,0,0,0,0,0,0],[0,0,1,0,1,0,0,0,0,0,0]]),
  ]

  for runner_col_state, illegal_col, expected_answer in scenarios:
    observed = generate_legal_choices(dice_rolls,
                                      jnp.array(illegal_col,dtype=u1),
                                      jnp.array(runner_col_state,dtype=u1),
                                      N_MAX_RUNNERS)
    assert jnp.all(observed == jnp.array(expected_answer,dtype=u1))
  
  return True

assert generate_choices_tests()

In [6]:
@jit
def calculate_player_N_col_claimed(player_col_state):
  '''Counts, for each player, how many columns they have claimed (their "score")'''
  #Input:
  #  player_col_state = An int array of size (N_players, 11) showing how many entries REMAINING until column is claimed for each player
  #Output: 
  #  An int vector of size (N_players,) with the number of columns each player has claimed. (In normal rules, first to 3 columns wins) 

  #A column is claimed once nothing remains in it (zero, or negative if it was overshot)
  is_col_claimed = (player_col_state <= 0)
  return jnp.count_nonzero(is_col_claimed, axis=1)

@jit
def calculate_illegal_col(player_col_state):
  '''Works out which columns are ILLEGAL (i.e. already claimed by some player) from the board state''' 
  #Input:
  #  player_col_state = An int array of size (N_players, 11) showing how many entries REMAINING until column is claimed for each player
  #Output: 
  #  A boolean vector of size (11,), True for every column that at least one player has claimed 

  #A column is claimed by a player once nothing remains in it (zero, or negative if it was overshot)
  is_col_claimed_by_player = (player_col_state <= 0)
  return jnp.any(is_col_claimed_by_player, axis=0)

def calculate_col_tests():
  '''Checks the two column-summary helpers on a 3-player example board'''
  #Player 0 has claimed nothing, player 1 has claimed column index 1,
  #and player 2 has claimed column indices 2, 3 and 4
  board = jnp.array([
    [2,3,4,5,6,7,8,9,10,11,12],
    [2,0,4,5,6,7,8,9,10,11,12],
    [2,3,0,0,0,7,8,9,10,11,12]])

  expected_N_col_claimed = jnp.array([0,1,3])
  assert jnp.all(calculate_player_N_col_claimed(board) == expected_N_col_claimed)

  expected_illegal_col = jnp.array([0,1,1,1,1,0,0,0,0,0,0])
  assert jnp.all(calculate_illegal_col(board) == expected_illegal_col)

  return True

assert calculate_col_tests()

In [7]:
@jit
def update_player_col_state(active_player_index, player_col_state, runner_col_state):
  '''Advances the active player's pieces by the amount on the runners 
    (This is called when a player banks their runners and ends their turn by choice)''' 
  #Input:
  #  active_player_index = index of whose turn it is
  #  player_col_state = int array of size (N_player, 11) with squares remaining in each column
  #  runner_col_state = int vector of size (11,) with runner locations
  #Output:
  #  A copy of player_col_state in which the active player's row has been moved up by the runners
  #  (the other rows are unchanged)

  squares_left_before = player_col_state[active_player_index]

  #The move is a subtraction since the state counts DOWN towards 0...
  squares_left_after = squares_left_before - runner_col_state
  #...but wherever the runner reaches or passes the top we store 0 rather than a negative number
  stops_short_of_top = runner_col_state < squares_left_before
  new_row = jnp.where(stops_short_of_top, squares_left_after, 0)

  return player_col_state.at[active_player_index].set(new_row)

def update_player_col_test():
  '''Checks update_player_col_state on a 3-player example board'''
  board_before = jnp.array([
    [2,3,4,5,6,7,8,9,10,11,12],
    [2,0,4,5,6,7,8,9,10,11,12],
    [2,3,0,0,0,7,8,9,10,11,12]])
  #One step in every column, except the last column where the runner goes all the way
  runner_col_state = jnp.array([1,1,1,1,1,1,1,1,1,1,12])
  active_player_index = 1

  #Only row 1 may change: each entry drops by 1, the already claimed column stays at 0,
  #and the last column lands exactly on 0
  board_after_expected = jnp.array([
    [2,3,4,5,6,7,8,9,10,11,12],
    [1,0,3,4,5,6,7,8,9,10,0],
    [2,3,0,0,0,7,8,9,10,11,12]])
  
  board_after = update_player_col_state(active_player_index, board_before, runner_col_state)
  assert jnp.all(board_after == board_after_expected)
  return True

assert update_player_col_test()

In [8]:
@jit
def prob_to_miss_targets(targets):
  '''Probability that a roll of 4 dice cannot be paired up to land on any of the target cols'''
  #Input:
  #  targets = an array of shape (11,), positive at the columns that are targets
  #Output:
  #  A real number: the fraction of the 6**4 possible rolls of 4 dice for which
  #  none of the 6 possible dice pairs sums to a target column

  #The 6 ways to choose 2 out of 4 dice, one row per pairing (a 1 means that dice is used)
  #  Pairing 0 = choose dice 1 and dice 2
  #  Pairing 1 = choose dice 1 and dice 3 
  #  ... 
  #  Pairing 5 = choose dice 3 and dice 4
  pick_two_of_four = jnp.array([[1,1,0,0],[1,0,1,0],[1,0,0,1],[0,1,1,0],[0,1,0,1],[0,0,1,1]])

  #All possible dice rolls as an array of shape (4,6,6,6,6)
  #  i.e. the entry [:,a,b,c,d] = [a,b,c,d] is 4 dice rolls and a,b,c,d all run from 0 to 5
  #  (Because the faces run 0-5 instead of 1-6, the sum of two of them is already a column index 0-10)
  all_rolls = jnp.indices((6,6,6,6)) 

  #Shape (6,6,6,6,6): the (p,a,b,c,d) entry is the value of pairing p when the dice come up a,b,c,d
  #  (contracts the "which dice" axis of pick_two_of_four against the first axis of all_rolls)
  pairing_values = jnp.tensordot(pick_two_of_four, all_rolls, axes=1)
  
  #The same array in a one hot encoding, now of shape (6,6,6,6,6,11)
  pairing_values_one_hot = jnn.one_hot(pairing_values,11)

  #Shape (6,6,6,6,6): the (p,a,b,c,d) entry is the entry of targets at the column hit by pairing p
  target_value_hit = jnp.tensordot(pairing_values_one_hot, targets, axes=1)

  #Shape (6,6,6,6): True when at least one of the 6 pairings lands on a target
  roll_hits_a_target = jnp.any(target_value_hit > 0, axis=0)

  #Count the rolls that get a hit, and turn the number of misses into a probability
  N_rolls_that_hit = jnp.count_nonzero(roll_hits_a_target)
  N_ROLLS = 1296 #= 6**4 is the total number of possible dice rolls
  return (N_ROLLS - N_rolls_that_hit)/N_ROLLS

def prob_to_miss_targets_tests():
  '''Compares prob_to_miss_targets against probabilities worked out by hand'''

  #No 2's: 
  #  You can only miss getting a 2 if the roll has fewer than two 1's, i.e. either:
  #  exactly 4 non-1's rolled in the dice roll (Probability (5/6)^4)
  #  exactly 3 non-1's in the dice roll (Probability 4*(5/6)^3*(1/6))
  p_miss_2 = (5/6)**4+4*(5/6)**3*(1/6)

  #No 3's: 
  #  You can only get a 3 if you have at least one 1 and at least one 2
  #  so you miss getting a 3 if you have NO 1's or NO 2's in the roll
  #  By inclusion-exclusion, the probability of NO 1's or NO 2's is:
  #  P(NO 1's) + P(NO 2's) - P(NO 1's and NO 2's)
  p_miss_3 = (5/6)**4 + (5/6)**4 - (4/6)**4

  #No 2's and no 3's:
  #  To miss getting a 2 or a 3 as possible sums you need either
  #  No 1's OR (exactly one 1 AND no 2's)
  #  Since these are disjoint we can just add the probabilities
  p_miss_2_and_3 = (5/6)**4 + 4*(1/6)*(4/6)**3

  #Each case is (targets, probability worked out by hand)
  cases = [
    ([1,0,0,0,0,0,0,0,0,0,0], p_miss_2),
    ([0,1,0,0,0,0,0,0,0,0,0], p_miss_3),
    ([1,1,0,0,0,0,0,0,0,0,0], p_miss_2_and_3),
  ]

  for targets, p_expected in cases:
    p_func = prob_to_miss_targets(jnp.array(targets))
    assert jnp.isclose(p_func,p_expected) 

  return True

assert prob_to_miss_targets_tests()

# Main game simulator

In [9]:
def dummyAI(active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_options, N_COL_TO_WIN = 3, N_MAX_RUNNERS = 3):
  '''Baseline game AI: ignores the state of the game and decides everything at random'''
  #Input:
  #  active_player_index = An int with whose turn it currently is (which player the AI is playing for)
  #  player_col_state = An int array of size (N_players, 11) showing how many entries REMAINING until column is claimed for each player
  #  player_N_col_claimed = An int vector of shape (N_player,) with how many columns each player has claimed 
  #  illegal_col = A boolean vector of size (11,) showing which columns have already been claimed
  #  runner_col_options = An int array of size (N_choices, 11) showing the options of where the runners could be to the player
  #  N_COL_TO_WIN, N_MAX_RUNNERS = Integers that can specify variants of the game rules 
  #  (Everything except the number of rows of runner_col_options is ignored by this AI)
  #Output: A tuple (choice_index, roll_again)
  #  1st entry: choice_index = An integer in [0,N_choices-1] with which runner choice is to be played
  #  2nd entry: roll_again = A boolean on whether or not to roll again 

  #Pick one of the options uniformly at random
  #  (random.randint includes BOTH endpoints, so the largest index is N_choices-1)
  highest_choice_index = jnp.shape(runner_col_options)[0] - 1
  choice_index = random.randint(0,highest_choice_index)

  #Flip a fair coin on whether to keep rolling
  roll_again = (random.randint(0,1) == 1)

  return choice_index, roll_again

In [12]:
def simulate_game(random_key, Player_AI, Verbose = False, N_PLAYERS =4, N_COL_TO_WIN=3, N_MAX_RUNNERS=3, PLAYER_COL_STATE_INIT=[3,5,7,9,11,13,11,9,7,5,3]):
  '''Run a simulation of the game Can't Stop!'''
  #Input:
  #   random_key = jrandom key used for choosing the starting player and for the dice rolls
  #   Player_AI = List with the functions for player AIs (indexed by player number, so it needs at least N_PLAYERS entries)
  #               Each AI is called with (active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_choices, N_COL_TO_WIN, N_MAX_RUNNERS)
  #               and must return a tuple (choice_index, roll_again)
  #   Verbose = whether or not to print out a play-by-play of the game
  #   N_PLAYERS, N_COL_TO_WIN, N_MAX_RUNNERS, PLAYER_COL_STATE_INIT = game parameters (the defaults are the classic rules)
  #Output:
  #  A boolean array of shape (N_PLAYERS,) which is True at the player who won
  #NOTE: 
  #  PLAYER_COL_STATE_INIT has a list as its default value; this is safe only because it is read and never modified here
  
  #Initialize game state
  player_col_state = jnp.tile(jnp.array(PLAYER_COL_STATE_INIT,dtype=jnp.dtype('i1')),(N_PLAYERS,1))
  player_N_col_claimed = calculate_player_N_col_claimed(player_col_state)
  illegal_col = calculate_illegal_col(player_col_state)
  
  #Choose a random player to start
  random_key, subrandom_key = jrandom.split(random_key)
  #  Note that the actual player who starts is one player later than the one chosen here.
  #  The draw has shape (1,), so we take its single element BEFORE converting to a Python int
  #  (calling int() directly on a non-scalar array is deprecated in NumPy and rejected by recent JAX versions)
  active_player_index = int(jrandom.randint(subrandom_key,[1], 0,N_PLAYERS)[0])

  #Main loop that goes until someone wins the game
  while jnp.all( player_N_col_claimed < N_COL_TO_WIN): 
    
    #Update whose turn it is
    active_player_index = (active_player_index + 1) % N_PLAYERS
    
    if Verbose : print("Player ",active_player_index,":") 
    if Verbose : print("--Player Column State: \n",player_col_state)

    #reset runners and set busted and roll again flags
    runner_col_state = jnp.zeros(11,dtype=jnp.dtype('u1')) 
    busted_state = False
    roll_again_state = True 

    #Loop while player is rolling on their turn
    while roll_again_state:
      random_key, subrandom_key = jrandom.split(random_key)
      dice_rolls = randomDice(subrandom_key)

      if Verbose : print("----Roll: ",dice_rolls)

      runner_col_choices = generate_legal_choices(dice_rolls, illegal_col, runner_col_state, N_MAX_RUNNERS)
      if Verbose : print("----Options:\n",runner_col_choices) 

      if len(runner_col_choices) > 0:
        #If there is at least one option, the AI of the active player makes a choice
        active_player_AI = Player_AI[active_player_index]
        choice_index, roll_again_state = active_player_AI(active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_choices, N_COL_TO_WIN, N_MAX_RUNNERS)
        runner_col_state = runner_col_choices[choice_index]
        if Verbose : print("----Runners chosen:\n", runner_col_state) 
        if Verbose : print("----Roll again choice: ", roll_again_state)
      else:
        #No options to play! Busted!
        if Verbose : print("----Busted!")
        busted_state = True
        roll_again_state = False

    #End of the players turn:
    if not busted_state:
      #In this case we stopped rolling before we busted! So we advance our position by our runners
      #(After a bust the runners are simply discarded and the board stays as it was)
      player_col_state = update_player_col_state(active_player_index,player_col_state,runner_col_state) 

    #Update illegal columns and player columns claimed
    player_N_col_claimed = calculate_player_N_col_claimed(player_col_state)
    illegal_col = calculate_illegal_col(player_col_state) 

  #At the end of this loop, one player has won!
 
  if Verbose : 
    print("GAME OVER!") 
    print("Final board state: ", player_col_state)
    print("Final claimed column count: ", player_N_col_claimed)
    print("Winners: ", ( player_N_col_claimed >= N_COL_TO_WIN ))

  return ( player_N_col_claimed >= N_COL_TO_WIN )

In [13]:
def oneVerboseGame(AIs):
  '''Plays a single game with the play-by-play printed out'''
  #Input:
  #  AIs = List with the functions for player AIs
  #The seed comes from the clock, so every run plays out a different game
  clock_seed = int(time.time())
  simulate_game(jrandom.PRNGKey(clock_seed), AIs, Verbose=True)

oneVerboseGame([dummyAI]*4)

Player  1 :
--Player Column State: 
 [[ 3  5  7  9 11 13 11  9  7  5  3]
 [ 3  5  7  9 11 13 11  9  7  5  3]
 [ 3  5  7  9 11 13 11  9  7  5  3]
 [ 3  5  7  9 11 13 11  9  7  5  3]]
----Roll:  [6 6 2 2]
----Options:
 [[0 0 0 0 0 0 2 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 1]]
----Runners chosen:
 [0 0 1 0 0 0 0 0 0 0 1]
----Roll again choice:  True
----Roll:  [3 4 6 2]
----Options:
 [[0 0 1 0 0 0 0 0 1 0 1]
 [0 0 1 0 0 0 0 1 0 0 1]
 [0 0 1 0 0 0 1 0 0 0 1]
 [0 0 1 0 0 1 0 0 0 0 1]
 [0 0 1 0 1 0 0 0 0 0 1]
 [0 0 1 1 0 0 0 0 0 0 1]]
----Runners chosen:
 [0 0 1 0 0 0 0 1 0 0 1]
----Roll again choice:  False
Player  2 :
--Player Column State: 
 [[ 3  5  7  9 11 13 11  9  7  5  3]
 [ 3  5  6  9 11 13 11  8  7  5  2]
 [ 3  5  7  9 11 13 11  9  7  5  3]
 [ 3  5  7  9 11 13 11  9  7  5  3]]
----Roll:  [2 2 5 3]
----Options:
 [[0 0 0 1 0 1 0 0 0 0 0]
 [0 0 1 0 0 0 1 0 0 0 0]]
----Runners chosen:
 [0 0 1 0 0 0 1 0 0 0 0]
----Roll again choice:  True
----Roll:  [5 2 6 3]
----Options:
 [[0 0 1 0 0 0 1 0 0 

In [15]:
#Play a bunch of games and record final result
def play_N_games(N_games, AIs, seed=None):
  '''Simulates N_games games between the given AIs and tallies how often each player wins'''
  #Input:
  #  N_games = number of games to simulate
  #  AIs = List with the functions for player AIs (one per player; its length sets the number of players)
  #  seed = optional int used to seed the dice; when left as None the current time is used,
  #         so pass a fixed seed to get reproducible results
  #Output:
  #  An int vector of shape (len(AIs),) with the number of games won by each player
  #  (the progress and the final tally are also printed)
  N_players = len(AIs)
  winners = jnp.zeros(N_players,dtype=jnp.dtype('i4'))

  if seed is None:
    seed = int(time.time())
  random_key = jrandom.PRNGKey(seed)

  #Report progress roughly every 10% of the games
  #  (an integer step, rounded up and at least 1, so this also works when N_games is not a multiple of 10)
  progress_step = max(1, (N_games + 9) // 10)

  for i in range(N_games):
    if i % progress_step == 0:
      print( int(100*(i / N_games)), "% done")
    random_key, subrandom_key = jrandom.split(random_key)
    winners += simulate_game(subrandom_key,AIs,N_PLAYERS=N_players)
  print("After ", N_games, " winners are ",winners)
  return winners

#The trailing semicolon keeps the notebook from echoing the returned tally (it is already printed)
play_N_games(100,[dummyAI,dummyAI,dummyAI,dummyAI]);

0 % done
10 % done
20 % done
30 % done
40 % done
50 % done
60 % done
70 % done
80 % done
90 % done
After  100  winners are  [27 19 31 23]
