<a href="https://colab.research.google.com/github/mcnica89/Markov-Chains-RL-W25/blob/main/Pig_Tac_Toe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
import jax
from einops import rearrange, reduce, repeat

import numpy as np
import time

#Global constants for size of board and for the rules of the game
N_ROWS = 4 #number of rows on the board
N_COLS = 4 #number of columns on the board
N_POSSIBLE = N_ROWS*N_COLS #maximum possible number of moves at any stage (one per square)
P_WIN_ROLLAGAIN = 2.0/3.0 #probability to succeed when you roll again (d6 outcomes 3,4,5,6 in the rules)
MAX_GAME_LENGTH = 4*N_ROWS*N_COLS #maximum length a game can go to avoid an infinite game

# Rules of Pig-Tac-Toe



## Object of the game
Players place X's and O's on a 4 by 4 grid and try to get 3 of their symbols in a line (either horizontally, vertically, or diagonally). Unlike Tic-Tac-Toe, the game does *not* end when the first 3-in-a-line happens. Instead, the game is played until the board is filled, and any 3-in-a-line is worth 1 victory point. The player with the most victory points at the end of the game wins. A single X or O may contribute to scoring multiple victory points. (For example, having 4 X's in a row effectively counts as 2 victory points). If both players have the same number of victory points at the end of the game, they rejoice in their shared victory.

## How to place Xs and Os
On a turn, a player does the following with their symbol: (Rules described for the X player below, the O player is identical with the symbol changed)

**Step 1.** Place "potential" X: Choose an empty square on the board and place a "potential" X there.

**Step 2.** Decide whether to play it safe and "BANK" or to take a risk and "ROLL":
  - If they choose to "BANK", all the potential X's from this turn become permanent X's on the board. Then the turn is over and it becomes the other player's turn.
  - If they choose to "ROLL", they roll a d6.
    - On a dice outcome of 3,4,5,6 they succeed their roll and go back to Step 1 and can place another potential X, and then decide whether to ROLL/BANK again.
    - On a dice outcome of 1 or 2, they fail their roll and they lose all their potential Xs from this turn, which are removed from the board. It then becomes the other player's turn.

## Special Starting Rule
The first player is determined randomly (e.g., by a coin flip or die roll). To offset the first-player advantage, the player who plays second begins the game with one square chosen at random already filled in.   

## Game End Condition
The game ends when all 16 squares are filled, at which point victory points are counted. In the computer implementation, to avoid games going on infinitely long, there is also a maximum game length limit imposed (normally 64 rounds max). If this maximum limit is hit, the game instantly ends and the winner is the player with the most victory points at this time.

## Other Project Details

- Project is open ended: make the best possible AI you can by combining course ideas however you want to. Show your work including a recorded mini presentation on what you did.
- Will have a benchmark of some simple AIs to play against on Gradescope.
- For fun only final tournament amongst class members.
- Can work by yourself, or with a partner. Expectations higher if working with partner.
- See CourseLink for details

# Helper Functions for 3 in a line games


In [2]:
def create_k_in_a_lines(k=3, n_rows=None, n_cols=None):
    '''Creates a (n_k_in_a_lines, n_rows, n_cols) array with all the possible k in a lines. Slow! Try to only run it once...'''
    #Input: k is the length of the k-in-a-rows to generate (must be >= 1)
    #       n_rows, n_cols are the board dimensions; when left as None they default to the global N_ROWS, N_COLS
    #Output: Boolean array of shape (n_k_in_a_lines, n_rows, n_cols) with all possible k_in_a_lines.
    #        Lines are ordered: horizontal, vertical, down-right diagonals, up-right diagonals.
    #        If no line of length k fits on the board the result has shape (0, n_rows, n_cols).
    #Raises: ValueError if k < 1 (a line needs at least one square)

    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    if n_rows is None:
        n_rows = N_ROWS
    if n_cols is None:
        n_cols = N_COLS

    def line_mask(start_r, start_c, step_r, step_c):
        # Board that is True on the k squares starting at (start_r, start_c) and stepping by (step_r, step_c)
        mask = np.zeros((n_rows, n_cols), dtype=bool)
        for i in range(k):
            mask[start_r + i*step_r, start_c + i*step_c] = True
        return mask

    k_in_a_line_list = []  # List to store the line masks

    # Horizontal lines
    for r in range(n_rows):
        for c in range(n_cols - k + 1):
            k_in_a_line_list.append(line_mask(r, c, 0, 1))

    # Vertical lines
    for c in range(n_cols):
        for r in range(n_rows - k + 1):
            k_in_a_line_list.append(line_mask(r, c, 1, 0))

    # Diagonal (top-left to bottom-right): row and column both increase
    for r in range(n_rows - k + 1):
        for c in range(n_cols - k + 1):
            k_in_a_line_list.append(line_mask(r, c, 1, 1))

    # Diagonal (bottom-left to top-right): row decreases while column increases
    for r in range(k - 1, n_rows):
        for c in range(n_cols - k + 1):
            k_in_a_line_list.append(line_mask(r, c, -1, 1))

    # np.stack cannot stack an empty list, so return an explicitly empty array when nothing fits
    if not k_in_a_line_list:
        return np.zeros((0, n_rows, n_cols), dtype=bool)

    # Stack to Final Shape: (n_k_in_a_lines, n_rows, n_cols)
    return np.stack(k_in_a_line_list)

In [3]:
#Complement of every possible 3-in-a-line: False on the three squares of the line and True everywhere else.
#Shape (n_3_in_a_rows, N_ROWS, N_COLS) For a 3x3 board n_3_in_a_rows is 8
THREE_IN_A_ROW_MASK = jnp.logical_not(create_k_in_a_lines(k=3))

@jax.jit
def which_three_in_a_line(board_bool):
  '''Reports, for a board (or batch of boards), which of the possible 3-in-a-lines are completely covered.'''
  #Input: board_bool shape (..., N_ROWS, N_COLS) marking the squares that belong to the subset being checked
  #Output: bool array of shape (..., n_3_in_a_lines); entry is True iff every square of that line is set in board_bool

  #Insert a length-1 "line" axis: (..., 1, N_ROWS, N_COLS) broadcasts against the (n_3_in_a_lines, N_ROWS, N_COLS) mask
  board_with_line_axis = board_bool[..., None, :, :]

  #The mask is already True off the line, so the OR is True on every square
  #exactly when the board covers all three squares of the line.
  covered_or_off_line = jnp.logical_or(THREE_IN_A_ROW_MASK, board_with_line_axis)

  #Collapse the row and column axes with an "all", leaving (..., n_3_in_a_lines)
  return jnp.all(covered_or_off_line, axis=(-2, -1))

#Examples:
#board = jnp.array([[1,1,1,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]])
#print(which_three_in_a_line(board))

In [4]:
#POSSIBLE_MOVES_ROWCOL is a static int array of shape (N_ROWS, N_COLS, N_ROWS, N_COLS) with all possible moves from an *empty* board.
#Entry [row_ix, col_ix, r, c] is 1 exactly when (r, c) == (row_ix, col_ix) and 0 otherwise, so
#POSSIBLE_MOVES_ROWCOL[row_ix, col_ix] is the board you get when you play in (row_ix, col_ix).
#This is the identity matrix on the flattened squares, reshaped back to row/col axes
#(the array is symmetric under swapping the first and last pair of axes).
POSSIBLE_MOVES_ROWCOL = jnp.eye(N_ROWS * N_COLS, dtype=int).reshape(N_ROWS, N_COLS, N_ROWS, N_COLS)

@jax.jit
def possible_moves(board_int, fill_val):
  '''Enumerates every square as a candidate move: returns the board after each move and whether each move is legal.'''
  #Input: board_int is an int array of shape (...,N_ROWS,N_COLS) representing the board. Any 0 is an empty square that can be played in.
  #       fill_val (shape (...), matching the batch axes of board_int) is the value written into the chosen empty square.
  #Output: Two arrays: moves, legal
  # moves is shape (...,N_POSSIBLE, N_ROWS, N_COLS) of int; moves[..., ix, :, :] is the board after playing in square ix.
  #   If that move is illegal, the entry is just the original board_int.
  # legal is shape (...,N_POSSIBLE) of bool; a move is legal if and only if its square is empty (i.e. == 0)

  #Work with separate (row, col) axes first and flatten to a single N_POSSIBLE axis at the end.
  is_empty = (board_int == 0) #Shape (..., N_ROWS, N_COLS)

  #Trailing singleton axes let the board and the emptiness mask broadcast against
  #POSSIBLE_MOVES_ROWCOL: result axes are (..., row_game, col_game, row_move, col_move)
  is_empty_b = is_empty[..., None, None]
  board_b = board_int[..., None, None]
  fill_b = jnp.reshape(fill_val, jnp.shape(fill_val) + (1, 1, 1, 1))

  #An empty square receives fill_val when it is the square being played; occupied squares are never modified,
  #so an illegal move leaves the whole board untouched.
  board_after_move = jnp.where(is_empty_b, board_b + fill_b*POSSIBLE_MOVES_ROWCOL, board_b)

  #Bring the move axes to the front of the board axes and merge them into one axis.
  moves_flat = rearrange(board_after_move, '... row_game col_game row_move col_move -> ... (row_move col_move) row_game col_game')

  #Flatten the (row, col) legality mask in the same row-major order as the moves.
  legal_flat = is_empty.reshape(is_empty.shape[:-2] + (-1,))

  return moves_flat, legal_flat


#Examples:
# board = jnp.array([[1,1,1,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]])
# fill_val = 2
# moves, legal = possible_moves(board, fill_val)
# for ix in range(N_POSSIBLE):
#     print("----")
#     print(legal[ix])
#     print(moves[ix,:,:])

@jax.jit
def P1_P2_swap(board):
  '''Exchanges the two players' pieces on the board so that an AI can always assume it is Player 1.'''
  #Input: board of any shape holding piece codes (1/2 are real pieces, -1/-2 are potential pieces)
  #Output: same shape with 1 <-> 2 and -1 <-> -2 exchanged; every other value is left as it is

  swapped = board
  #Every test is made against the *original* board, so no square can be swapped twice
  for old_val, new_val in ((-1, -2), (-2, -1), (1, 2), (2, 1)):
    swapped = jnp.where(board == old_val, new_val, swapped)
  return swapped


#Examples:
# board = jnp.array([[1,1,1,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]])
# print(P1_P2_swap(board))

In [5]:
@jax.jit
def select_moves_by_ix(moves, ixs):
  '''Picks, for every game, the board that corresponds to the chosen move index.'''
  #INPUT: moves is shape (...,N_POSSIBLE, N_ROWS,N_COLS) with the possible moves
  #       ixs is shape (...) with the index that was chosen in each game
  #OUTPUT: (...,N_ROWS,N_COLS) with the chosen move. Similar to moves[ixs] but works with any number of batch axes.

  #Three trailing singleton axes line the indices up with the (N_POSSIBLE, N_ROWS, N_COLS) axes of moves
  ixs_b = ixs[..., None, None, None]

  #Gather along the N_POSSIBLE axis (third from the end); the result has shape (..., 1, N_ROWS, N_COLS)
  picked = jnp.take_along_axis(moves, ixs_b, axis=-3)

  #Drop the length-1 axis left where N_POSSIBLE used to be
  return jnp.squeeze(picked, axis=-3)

#Examples:
# key = jax.random.PRNGKey(0)
# fill_val = 2
# board = jnp.array([[1,1,1,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]])
# moves, legal = possible_moves(board, fill_val)
# ix = jnp.where(legal==True)[0][0] #Play the first available legal move!
# print(select_moves_by_ix(moves,ix))

# The Random Move AI

In [6]:
def random_move_AI(key, moves, legal, board):
  '''Baseline AI: plays a uniformly random legal square and flips a fair coin to decide whether to roll again.'''
  #Input: key is a JAX random key object
  #  moves is shape (N_POSSIBLE, N_ROW, N_COL) with the possible moves (not used by this AI)
  #  legal is shape (N_POSSIBLE) with whether or not each move is possible
  #  board is shape (N_ROW, N_COL) with the current board, i.e. the situation *before* this turn (not used by this AI)

  #Output: move_ix, roll_again
  #  move_ix is an ix for a *legal* move
  #  roll_again is 1 to roll again and 0 to stop

  #Note: The AI always plays as if it is Player 1. The game loop swaps the board before calling the AI to make sure this is the case!

  #Independent randomness for the two decisions (the first piece of the split is not needed)
  _, move_key, roll_key = jax.random.split(key, 3)

  #Draw a square using "legal" as the weights, so an illegal move has probability 0
  candidate_ixs = np.arange(N_POSSIBLE, dtype=int)
  move_ix = jax.random.choice(move_key, candidate_ixs, p=legal)

  #Fair coin between 0 (stop) and 1 (roll again)
  roll_options = np.arange(2, dtype=int)
  roll_again = jax.random.choice(roll_key, roll_options)

  return move_ix, roll_again

#Examples:
# key = jax.random.PRNGKey(0)
# fill_val = 2
# board = jnp.array([[1,1,1,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]])
# moves, legal = possible_moves(board, fill_val)
# ix, roll_again = random_move_AI(key,moves,legal,board)
# print(select_moves_by_ix(moves,ix))

# Pig-Tac-Toe Game Simulation (One Game at a Time)


In [7]:
#Function to print out a board state as text
def print_board(board):
    '''Prints a single (N_ROWS, N_COLS) board, one text row per board row, followed by a blank line.'''
    #What each board entry represents:
    #  1 = Player 1 actual piece (X)
    # -1 = Player 1 potential piece (x)
    #  2 = Player 2 actual piece (O)
    # -2 = Player 2 potential piece (o)
    #  0 = empty square (.)
    glyph_for = {1: 'X', 2: 'O', -1: "x", -2: "o", 0: "."}

    #Dashed divider printed underneath every row
    divider = "-"*(N_COLS*4-3)

    for row in range(N_ROWS):
        row_glyphs = [glyph_for[int(board[row, col])] for col in range(N_COLS)]
        print(" | ".join(row_glyphs))
        print(divider)
    print()


In [8]:
#########
# PIG TAC TOE Game loop
# Simulates ONE game at a time
#########

num_episodes = 1 #number of games to simulate
VERBOSE = True #whether or not to print game details as you go.

AIs = [random_move_AI, random_move_AI] #AIs for Player 1 and Player 2
key = jax.random.PRNGKey(int(time.time())) #random key used for JAX

for episode in range(num_episodes):

  if VERBOSE:
    print(f"\n====Episode {episode+1}====")

  ##############
  #Initialize things for the game
  ##############
  key, subkey = jax.random.split(key)
  current_player = ( jax.random.uniform(subkey)<0.5 ) + 1 #random starting player: 1 or 2
  board = jnp.zeros((N_ROWS, N_COLS), dtype=int) #board starts out empty
  game_over = False
  #One random key per turn for the AI decisions and one per turn for the dice rolls
  key, subkey_ai, subkey_roll = jax.random.split(key,3)
  ai_key = jax.random.split(subkey_ai, MAX_GAME_LENGTH)
  roll_key = jax.random.split(subkey_roll, MAX_GAME_LENGTH)

  #Special Starting Rule: non-starting player gets a random starting piece!
  opposite_player = 3 - current_player
  moves, legal  = possible_moves(board, fill_val = opposite_player)
  key, subkey = jax.random.split(key)
  random_move_ix, _ = random_move_AI(subkey, moves, legal, board)
  board = select_moves_by_ix(moves, random_move_ix) #board now contains a random move for the opposite player
  if VERBOSE:
    print("Starting Board: ")
    print_board(board)

  ############
  # Main Loop for the Turns
  ############
  for turn in range(MAX_GAME_LENGTH):
    #Once the game has finished the remaining turns do nothing
    if game_over:
      continue

    #######
    #STEP 1: Play a potential piece
    #######
    #
    # The potential piece is added as a negative value!
    # So -1 represents a potential player 1 piece, and -2 is a potential player 2 piece.

    moves, legal = possible_moves(board, fill_val=-current_player) #Get all possible moves and their legality
    #The AI chooses a move_ix from the available moves and whether it wants to roll again
    if current_player == 1:
      move_ix, roll_again = AIs[0](ai_key[turn], moves, legal, board)
    elif current_player == 2:
      #If its player 2's turn we swap P1 <-> P2 so that the AI can think from P1's point of view
      move_ix, roll_again = AIs[1](ai_key[turn], P1_P2_swap(moves), legal, P1_P2_swap(board))

    board = select_moves_by_ix(moves, move_ix) #Update the board

    if VERBOSE:
      print(f"Turn #{turn}")
      print_board(board)

    #########
    #STEP 2: Choose to roll again or stop
    #########
    if roll_again: #if the AI chose to roll again...
      if VERBOSE:
        print("  Chose to roll again!")

      #Did it succeed?
      roll_outcome = (jax.random.uniform(roll_key[turn]) < P_WIN_ROLLAGAIN)
      if roll_outcome:
        #if you win the roll just keep going!
        if VERBOSE:
          print("  ..and suceeded!")
      else:
        if VERBOSE:
          print("  ..and failed!")
        board = jnp.where(board < 0, 0, board) #delete potential pieces
        current_player = 3 - current_player #change player

    elif roll_again == 0:
      if VERBOSE:
        print("  Chose to stop!")
      board = jnp.abs(board) #potential pieces become real!
      current_player = 3 - current_player #change player

    #check if game is over by counting the number of empty squares
    empty_squares = jnp.sum(board == 0, axis=(-2, -1))
    game_over = (empty_squares == 0)


  ####
  # After game clean up
  ####
  #Make any potential pieces real. This is needed if the game ends with the player choosing to roll again
  #even if the board is full, or if the game ends at the max-length limit.
  board = jnp.abs(board)
  if VERBOSE:
    print("Game End")
    print_board(board)

  #count up victory points, +1 for Player 1 and -1 for Player 2
  #multiply by 1 to convert the bool to int
  which_victory_points = 1*which_three_in_a_line(board == 1) - 1*which_three_in_a_line(board == 2)
  net_victory_points = jnp.sum(which_victory_points, axis=-1)

  if VERBOSE:
    print(f" Final victory point diff: {net_victory_points}")


====Episode 1====
Starting Board: 
. | . | . | .
-------------
. | . | . | .
-------------
. | . | . | .
-------------
. | O | . | .
-------------

Turn #0
. | . | . | .
-------------
. | . | . | .
-------------
. | . | . | .
-------------
. | O | x | .
-------------

  Chose to roll again!
  ..and suceeded!
Turn #1
. | . | . | .
-------------
. | x | . | .
-------------
. | . | . | .
-------------
. | O | x | .
-------------

  Chose to roll again!
  ..and suceeded!
Turn #2
. | . | . | .
-------------
. | x | . | .
-------------
. | . | . | .
-------------
. | O | x | x
-------------

  Chose to roll again!
  ..and suceeded!
Turn #3
. | . | . | .
-------------
. | x | . | .
-------------
. | . | . | .
-------------
x | O | x | x
-------------

  Chose to stop!
Turn #4
. | . | . | .
-------------
. | X | . | .
-------------
. | o | . | .
-------------
X | O | X | X
-------------

  Chose to stop!
Turn #5
. | . | . | .
-------------
. | X | . | .
-------------
. | O | . | x
-----------

# Game Simulation (Multiple Games at once using batch axes/parallelization/vectorization)

In [9]:
#All the helper functions were set up so they work with a batch axis
#So you can add an axis, e.g. board is (N_BATCH,N_ROWS,N_COLS)
# and the function will apply to all the games board[0,:,:], board[1,:,:] ... etc
#In these examples we apply them to 2 games at once

# board_2 = jnp.array([[[1,1,1,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]],
#                      [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,0]]])

##Testing which three in a line:
# print("===Test which_three_in_a_line===")
# print(which_three_in_a_line(board_2))

##Testing possible_moves
# print("===Test possible_moves===")
# fill_val = jnp.array([2,3])
# moves, legal = possible_moves(board_2, fill_val)
# for ix in range(N_POSSIBLE):
#     print("----")
#     print(legal[:,ix])
#     print(moves[:,ix,:,:])

##Testing select_moves_by_ix
# key = jax.random.PRNGKey(0)
# fill_val = jnp.array([2,3])
# moves, legal = possible_moves(board_2, fill_val)
# ixs = jnp.zeros(2,dtype=int)
# batch_dim = jnp.arange(2)
# ixs = ixs.at[batch_dim].set(jnp.array([jnp.where(legal[i] == True)[0][0] for i in batch_dim]))
# print(select_moves_by_ix(moves,ixs))

In [10]:
@jax.jit
def vmaped_random_move_AI(key, moves, legal, board):
  '''Batched random_move_AI: makes an independent random decision for each game in the batch.'''
  #key is a *single* JAX random key
  #
  #All the other inputs carry an N_BATCH dimension as their 0th axis, i.e.
  #moves is (N_BATCH,N_POSSIBLE,N_ROW,N_COL), legal is (N_BATCH,N_POSSIBLE)
  #and board is (N_BATCH, N_ROW, N_COL)
  #
  #Returns what random_move_AI returns (move_ix, roll_again), each with a leading N_BATCH axis

  #Everything in legal except its last (N_POSSIBLE) axis is batch
  batch_shape = legal.shape[:-1]

  #One key per game so the games do not share randomness
  per_game_keys = jax.random.split(key, batch_shape)

  #Map the single-game AI over the 0th axis of every argument
  batched_ai = jax.vmap(random_move_AI, in_axes=0)
  return batched_ai(per_game_keys, moves, legal, board)

# #Example:
# board_2 = jnp.array([[[1,1,1,0],[1,0,0,0],[1,0,0,0],[0,0,0,0]],
#                      [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,0]]])
# key = jax.random.PRNGKey(0)
# fill_val = jnp.array([2,3])
# moves, legal = possible_moves(board_2, fill_val)
# move_ixs, roll_again = vmaped_random_move_AI(key, moves, legal, board_2)
# print("Move_ixs chosen: ", move_ixs)
# print("Roll again chosen: ",roll_again)
# print(select_moves_by_ix(moves,move_ixs))

In [17]:
#########
# PIG TAC TOE Multi Game game loop
# Simulates "N_BATCH" games at a time
# If in google colab, can select Runtime > Change runtime type > GPU or TPU for extra speedup here. (Warning: this will restart the sessions. Also you may need to pip install einops on TPU for some reason?)
#########

N_BATCH = 1_000 #Number of games simulated at once.

num_episodes = 1 #Number of times to run it. So total games is num_episodes * N_BATCH

AIs = [vmaped_random_move_AI, vmaped_random_move_AI] #batched AIs for Player 1 and Player 2
key = jax.random.PRNGKey(int(time.time()))

VERBOSE = False #whether or not to print stuff out (only game 0 of the batch is printed)

for episode in range(num_episodes):

  if VERBOSE:
    print(f"\n====Episode {episode+1}====")

  ##############
  #Initialize things for the game
  ##############

  key, subkey = jax.random.split(key)
  #choose random starting player. Note that is a vector of size N_BATCH since each game can have a different current_player
  #(the shape is passed as a proper 1-tuple)
  current_player = ( jax.random.uniform(subkey, shape=(N_BATCH,)) < 0.5 ) + 1
  #Board. The zeroth axis is the batch axis so board[batch_ix] is a single game
  board = jnp.zeros((N_BATCH, N_ROWS, N_COLS), dtype=int)
  #Whether or not each game is over.
  game_over = jnp.zeros(N_BATCH, dtype=bool)
  #Random keys: one per turn for the AI decisions and one per turn for the dice rolls
  key, subkey_ai, subkey_roll = jax.random.split(key,3)
  ai_key = jax.random.split(subkey_ai, MAX_GAME_LENGTH)
  roll_key = jax.random.split(subkey_roll, MAX_GAME_LENGTH)

  #Special Starting Rule: non-starting player gets a random starting piece!
  opposite_player = 3 - current_player
  moves, legal  = possible_moves(board, opposite_player)
  key, subkey = jax.random.split(key)
  random_move_ix, _ = vmaped_random_move_AI(subkey, moves, legal, board)
  board = select_moves_by_ix(moves, random_move_ix)
  if VERBOSE:
    print("Starting Board: ")
    print_board(board[0])


  ############
  # Main Loop for the Turns
  ############
  for turn in range(MAX_GAME_LENGTH):

    #######
    #STEP 1: Play a potential piece
    #######
    #
    # The potential piece is added as a negative value!
    # So -1 represents a potential player 1 piece, and -2 is a potential player 2 piece.
    moves, legal = possible_moves(board, -current_player)
    # Easiest solution to batching the AIs is to just run both. (Otherwise have to do a bunch of rearranging)
    # Both AIs receive the same key; in each game only the result of the player whose turn it is gets used.
    #
    # What would Player 1 do here?
    move_ix_1, roll_again_1 = AIs[0](ai_key[turn], moves, legal, board)
    # What would Player 2 do here? Note we swap 1s <-> 2s so the AI can assume it is P1's turn always.
    move_ix_2, roll_again_2 = AIs[1](ai_key[turn], P1_P2_swap(moves), legal, P1_P2_swap(board))

    # Set the moves based on whose turn it is currently!
    move_ix = jnp.where(current_player == 1, move_ix_1, move_ix_2)
    roll_again = jnp.where(current_player == 1, roll_again_1, roll_again_2)

    #Only update the games that are still in progress; if the game is over just do nothing.
    broadcastable_game_on = rearrange(game_over==False, '... -> ... 1 1')
    board = jnp.where(broadcastable_game_on, select_moves_by_ix(moves, move_ix), board)

    if VERBOSE:
      print(f"Turn #{turn}")
      print_board(board[0])

    #########
    #STEP 2: Choose to roll again or stop
    #########
    # Get what the roll outcome would be in all the games.
    roll_outcome = (jax.random.uniform(roll_key[turn], shape=(N_BATCH,)) < P_WIN_ROLLAGAIN)

    #Use jnp.where to do the "if" statements in a vectorized way.

    #IF roll_again == 1 & roll_outcome == 0, then FAILED the roll.
    condition = (roll_again == 1) & (roll_outcome == 0) & (game_over == False)
    current_player = jnp.where( condition,  3-current_player, current_player)
    #
    broadcastable_condition = rearrange(condition, '... -> ... 1 1')
    board = jnp.where(broadcastable_condition,  jnp.where(board < 0, 0, board), board ) #remove all the potential pieces (where its <0)

    #IF roll_again == 1 & roll_outcome == 1, then SUCCEED the roll.
    #(nothing happens here)

    #IF roll_again==0, then chose to STOP.
    condition = (roll_again == 0) & (game_over == False)
    current_player = jnp.where( condition,  3-current_player, current_player)
    #
    broadcastable_condition = rearrange(condition, '... -> ... 1 1')
    board = jnp.where(broadcastable_condition,  jnp.abs(board), board ) #potential pieces become real

    #check for which games are over by counting empty squares.
    empty_squares = reduce(board == 0, '... row col -> ...', 'sum')
    game_over = (empty_squares == 0)
    if jnp.all(game_over):
      break #exit early if all the games are finished..


  ####
  # After game clean up
  ####
  #Make any potential pieces real. This is needed if the game ends with the player choosing to roll again
  #even if the board is full, or if the game ends at the max-length limit.
  board = jnp.abs(board)
  if VERBOSE:
    print("Game End")
    #print_board handles one (N_ROWS, N_COLS) board, so show game 0 like the other printouts
    print_board(board[0])

  #count up victory points, +1 for Player 1 and -1 for Player 2. Multiply by 1 to convert bool to int.
  which_victory_points = 1*which_three_in_a_line(board == 1) - 1*which_three_in_a_line(board == 2)
  net_victory_points = reduce(which_victory_points, '... n_line -> ...', 'sum')

  #the winner is whoever has more points, so final reward is +1 if Player1 > Player 2, 0 points if tied, and -1 if Player1 < Player2. The "sign" function does this.
  game_winner_reward = jnp.sign(net_victory_points)

  if VERBOSE:
    print(f" Final victory point diff: {net_victory_points}")

In [18]:
#Display the net victory points (Player 1's lines minus Player 2's lines) of every game in the last simulated batch
net_victory_points

Array([ -2,   1,   8,  -3,  -7,  -4,   2,  -3,  -1,   9,  -1,  -2,   0,
         2,   1,   3,  -1,  -3,  -6,  -7,  -6,   9,   0,   4,  -3, -11,
         4,  -8,   1,   5,   0,   3,  -2,  -5,  11,   2,   0,  -1,   2,
        -1,  -2,  -3,   0,  -1,  -4,  -1,  -2,  -1,  -3,   7,   0,   1,
         1, -10,  -8,   8,  -2, -12,   7,  -3,   0,   8,  -5,  -4,   0,
       -12,   5,  -3,  -5,   1,   0,  -2,   6,   4,   8,  -1,   0,   2,
        -4,  -1,  -3,  -3,   2,  -2,  -6,  -2,  -1,  -7,   0,   1,   0,
        -5,   4,   1,  -7,  -1,   9,  -1,   3,   6,   7,  12,   4,  -3,
         9,   3,  -7,   2,  -9,   1,  -3,   5,  -6,  -2,  -3,   3,  -6,
        -3,   3,   7,   4,  -8,  -3,   2, -12,   8,  -1,  -5,  -2,  -5,
        -1,  -4,   5,  -3,   4,   0,  -5,   1,   2,  -1,  -4,   8,   1,
        -5,  -1,  -6,  -4,   1,   6,  -8,   0,  -4,  -6,   0,  -6,  -3,
        -3,   0,   6,  -1,   5,   0,  -3,  -3,  10,  -8,   1,  -1,   0,
         3,   3,   1,   1,   0, -11,  -3,   4,  -4,  -5,  -1,   