<a href="https://colab.research.google.com/github/mcnica89/Markov-Chains-RL-W25/blob/main/Assignment4_TicTacToe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
import jax
from einops import rearrange, reduce, repeat

import numpy as np
import time

#Board geometry shared by every cell of the notebook
N_ROWS, N_COLS = 3, 3
N_POSSIBLE = N_ROWS * N_COLS #one candidate move per square, so this bounds the number of moves at any stage
MAX_GAME_LENGTH = N_POSSIBLE #a game can last at most one turn per square

# Helper Functions

In [2]:
def create_k_in_a_lines(k=3):
    '''Build one boolean mask per straight run of k squares that fits on the board.

    Uses plain Python loops, so it is slow; call it once and keep the result.
    '''
    #Input: k is the length of the k-in-a-lines to generate
    #Output: Boolean array of shape (n_k_in_a_lines, N_ROWS, N_COLS). Each (N_ROWS, N_COLS)
    #        slice is True exactly on the squares of one line. Lines are emitted in the order:
    #        horizontal, vertical, down-right diagonals, up-right diagonals.
    #        If no line of length k fits on the board, the result has shape (0, N_ROWS, N_COLS).

    k_in_a_line_list = []  # List to store winning line_arrays

    # Horizontal lines
    for r in range(N_ROWS):
        for c in range(N_COLS - k + 1):
            line_array = np.zeros((N_ROWS, N_COLS), dtype=bool)
            line_array[r, c:c+k] = True
            k_in_a_line_list.append(line_array)

    # Vertical lines
    for c in range(N_COLS):
        for r in range(N_ROWS - k + 1):
            line_array = np.zeros((N_ROWS, N_COLS), dtype=bool)
            line_array[r:r+k, c] = True
            k_in_a_line_list.append(line_array)

    # Diagonal where row and column indices increase together
    # (top-left to bottom-right when row 0 is drawn at the top)
    for r in range(N_ROWS - k + 1):
        for c in range(N_COLS - k + 1):
            line_array = np.zeros((N_ROWS, N_COLS), dtype=bool)
            for i in range(k):
                line_array[r + i, c + i] = True
            k_in_a_line_list.append(line_array)

    # Diagonal where the row index decreases as the column index increases
    # (bottom-left to top-right when row 0 is drawn at the top)
    for r in range(k - 1, N_ROWS):
        for c in range(N_COLS - k + 1):
            line_array = np.zeros((N_ROWS, N_COLS), dtype=bool)
            for i in range(k):
                line_array[r - i, c + i] = True
            k_in_a_line_list.append(line_array)

    # np.stack refuses an empty list, so return an explicitly empty mask
    # when k is too large for any line to fit on the board.
    if not k_in_a_line_list:
        return np.zeros((0, N_ROWS, N_COLS), dtype=bool)

    # Stack to Final Shape: (n_k_in_a_lines, N_ROWS, N_COLS)
    stacked_k_in_a_lines = np.stack(k_in_a_line_list)
    return stacked_k_in_a_lines

In [3]:
#Complement of every 3-in-a-line pattern: False on the squares of the line, True everywhere else.
#Shape (n_3_in_a_lines, N_ROWS, N_COLS). For a 3x3 board n_3_in_a_lines is 8
THREE_IN_A_LINE_MASK = ~jnp.asarray(create_k_in_a_lines(k=3))

@jax.jit
def which_three_in_a_line(board_bool):
  '''Report, for every candidate 3-in-a-line, whether board_bool covers all of its squares.'''
  #Input: board_bool shape (..., N_ROWS, N_COLS), marking a subset of the squares of a board
  #Output: bool array of shape (..., n_3_in_a_lines), one entry per line in THREE_IN_A_LINE_MASK

  #Insert a length-1 "line" axis in front of the row/col axes -> shape (..., 1, N_ROWS, N_COLS)
  board_per_line = board_bool[..., None, :, :]

  #The mask is the logical_not of each line pattern, so after the OR a (row, col) slice is
  #entirely True exactly when the board covers every square of that line.
  #Shape (..., n_3_in_a_lines, N_ROWS, N_COLS) by broadcasting
  covered = jnp.logical_or(THREE_IN_A_LINE_MASK, board_per_line)

  #Collapse the row and column axes with an "all"
  return jnp.all(covered, axis=(-2, -1))

@jax.jit
def any_three_in_a_line(board_bool):
  '''Tell whether board_bool contains at least one complete 3-in-a-line.'''
  #Input: board_bool shape (..., N_ROWS, N_COLS), marking a subset of the squares of a board
  #Output: bool array of shape (...), True where any of the candidate lines is complete

  lines_present = which_three_in_a_line(board_bool) #Shape (..., n_3_in_a_lines)

  #Collapse the line axis with an "any"
  return jnp.any(lines_present, axis=-1)

#Example: a board whose top row and left column are both completely filled
board = jnp.array([[1,1,1],[1,0,0],[1,0,0]])
print(which_three_in_a_line(board)) #one flag per candidate line
print(any_three_in_a_line(board)) #a single flag: is at least one line complete?

[ True False False  True False False False False]
True


In [4]:
#POSSIBLE_MOVES_ROWCOL is a static int array of shape (N_ROWS, N_COLS, N_ROWS, N_COLS) holding every move that can be made on an *empty* board.
#Each move is a board with a single 1 on the played square and 0 on all other squares.
#Entry [a, b, c, d] is 1 exactly when (a, b) == (c, d), so both [row_ix, col_ix, :, :] and [:, :, row_ix, col_ix]
#are the board obtained by playing in (row_ix, col_ix).
#Built as an identity matrix over the flattened squares, unflattened on both of its axes.
POSSIBLE_MOVES_ROWCOL = jnp.eye(N_ROWS*N_COLS, dtype=int).reshape(N_ROWS, N_COLS, N_ROWS, N_COLS)

@jax.jit
def possible_moves(board_int, fill_val):
  '''Enumerate the board obtained by playing fill_val in each square, together with which of those plays are legal.'''
  #Input:
  # board_int: int array of shape (..., N_ROWS, N_COLS) representing the board; a 0 marks an empty square.
  # fill_val: the value written into the square that gets played (shape (...), or a scalar).
  #Output: Two arrays: moves, legal
  # moves: int array of shape (..., N_POSSIBLE, N_ROWS, N_COLS). Entry i along the N_POSSIBLE axis is the board
  #        after playing in square i (squares numbered row by row). For an illegal move it is the unchanged board_int.
  # legal: bool array of shape (..., N_POSSIBLE). A move is legal if and only if its square is empty (i.e. == 0).

  #Work with separate (row, col) axes first; flatten to a single N_POSSIBLE = N_ROWS*N_COLS axis at the end.
  is_empty = (board_int == 0) #Shape (..., N_ROWS, N_COLS)

  #Append singleton axes so everything broadcasts against POSSIBLE_MOVES_ROWCOL,
  #whose leading two axes index the game square and trailing two axes index the move.
  board_b = rearrange(board_int, '... row col -> ... row col 1 1')
  empty_b = rearrange(is_empty, '... row col -> ... row col 1 1')
  fill_b = rearrange(fill_val, '... -> ... 1 1 1 1')

  #On an empty square the move is added on top of the board; occupied squares keep their value.
  board_with_move = board_b + fill_b*POSSIBLE_MOVES_ROWCOL
  after_move = jnp.where(empty_b, board_with_move, board_b)

  #Bring the move axes to the front of the board axes and merge them into the promised N_POSSIBLE axis.
  moves_flat = rearrange(after_move, '... row_game col_game row_move col_move -> ... (row_move col_move) row_game col_game')
  legal_flat = rearrange(is_empty, '... row_move col_move  -> ... (row_move col_move)')

  return moves_flat, legal_flat


#Example: list every candidate move for player 2 on a partly filled board
board = jnp.array([[1,1,1],[1,0,0],[1,0,0]])
fill_val = 2 #value written into the square that gets played
moves, legal = possible_moves(board, fill_val)
for ix in range(N_POSSIBLE):
    print("----")
    print(legal[ix]) #whether move number ix is legal
    print(moves[ix,:,:]) #the board after move number ix

----
False
[[1 1 1]
 [1 0 0]
 [1 0 0]]
----
False
[[1 1 1]
 [1 0 0]
 [1 0 0]]
----
False
[[1 1 1]
 [1 0 0]
 [1 0 0]]
----
False
[[1 1 1]
 [1 0 0]
 [1 0 0]]
----
True
[[1 1 1]
 [1 2 0]
 [1 0 0]]
----
True
[[1 1 1]
 [1 0 2]
 [1 0 0]]
----
False
[[1 1 1]
 [1 0 0]
 [1 0 0]]
----
True
[[1 1 1]
 [1 0 0]
 [1 2 0]]
----
True
[[1 1 1]
 [1 0 0]
 [1 0 2]]


In [5]:
@jax.jit
def select_moves_by_ix(moves, ixs):
  '''Pick, out of the candidate boards in moves, the one designated by each entry of ixs.'''
  #INPUT: moves is shape (..., N_POSSIBLE, N_ROWS, N_COLS) with the possible moves
  #       ixs is shape (...) with the index that was chosen along the N_POSSIBLE axis
  #OUTPUT: shape (..., N_ROWS, N_COLS) with the chosen move. Like moves[ixs], but also works with leading batch axes.

  #Three trailing singleton axes let the indices broadcast against the
  #(N_POSSIBLE, N_ROWS, N_COLS) tail of moves.
  picker = ixs[..., None, None, None]

  #Gather along the N_POSSIBLE axis (-3); the result has shape (..., 1, N_ROWS, N_COLS)
  picked = jnp.take_along_axis(moves, picker, axis=-3)

  #Drop the leftover length-1 axis
  return jnp.squeeze(picked, axis=-3)

#Example: play the first legal move for player 2
key = jax.random.PRNGKey(0) #not used by the rest of this example
fill_val = 2
board = jnp.array([[1,1,1],[1,0,0],[1,0,0]])
moves, legal = possible_moves(board, fill_val)
ix = jnp.where(legal==True)[0][0] #Play the first available legal move!
print(select_moves_by_ix(moves,ix)) #the board after that move

[[1 1 1]
 [1 2 0]
 [1 0 0]]


In [6]:
@jax.jit
def P1_P2_swap(board):
  '''Exchange the two players' marks (1 <-> 2) so an AI can always reason as if it were Player 1.'''
  #Input: board is an array whose entries equal to 1 or 2 mark the two players' pieces
  #Output: array of the same shape with the 1s and 2s exchanged; every other entry is left untouched
  is_p1 = (board == 1)
  is_p2 = (board == 2)

  #Both masks are computed from the original board, so squares rewritten in
  #the first step are not swapped back by the second.
  twos_to_ones = jnp.where(is_p2, 1, board)
  return jnp.where(is_p1, 2, twos_to_ones)


#Example: the 1s and 2s trade places, the 0s (empty squares) stay where they are
board = jnp.array([[1,1,1],[2,0,0],[2,0,0]])
print(P1_P2_swap(board))

[[2 2 2]
 [1 0 0]
 [1 0 0]]


# The Random Move AI

In [7]:
def random_move_AI(key, moves, legal):
  '''Baseline AI: picks one of the legal moves at random.'''
  #Input: key is a JAX random key object
  #  moves is shape (N_POSSIBLE, N_ROW, N_COL) with the possible moves (not consulted by this AI)
  #  legal is shape (N_POSSIBLE) with whether or not each move is legal
  #
  #Output: move_ix, an ix for a *legal* move

  #Derive a subkey for the draw instead of consuming the caller's key directly
  _, draw_key = jax.random.split(key)

  #legal is handed over as the weights p, so illegal moves carry zero weight
  #and every legal move carries the same weight
  candidate_ixs = np.arange(N_POSSIBLE,dtype=int)
  return jax.random.choice(draw_key, candidate_ixs, p=legal)

#Example: let the random AI choose a move for player 2 (fixed seed)
key = jax.random.PRNGKey(0)
fill_val = 2
board = jnp.array([[1,1,1],[1,0,0],[1,0,0]])
moves, legal = possible_moves(board, fill_val)
ix = random_move_AI(key,moves,legal) #index of the chosen move
print(select_moves_by_ix(moves,ix)) #the board after that move

[[1 1 1]
 [1 0 0]
 [1 0 2]]


# 1-Turn-Greedy AI

In [8]:
def greedy_AI(key, moves, legal):
  '''AI that looks one move ahead: 1. takes a 3-in-a-line if one is available, 2. otherwise plays in a square
  where the enemy could complete a 3-in-a-line, 3. otherwise plays a random legal move.'''
  #Input: key is a JAX random key object
  #  moves is shape (N_POSSIBLE, N_ROW, N_COL) with the possible moves
  #  legal is shape (N_POSSIBLE) with whether or not each move is legal
  #
  #Output: move_ix, an ix for a *legal* move

  #The AI always plays from the point of view of player 1.
  own_mark, rival_mark = 1, 2

  wins_now = any_three_in_a_line(moves==own_mark) #Shape (N_POSSIBLE)
  if jnp.any(wins_now):
    #At least one move wins on the spot: choose among those
    candidates = wins_now
  else:
    #No immediate win, so find the squares where the enemy could win instead.
    #The elementwise minimum over all candidate boards recovers the board before our move.
    board_before = reduce(moves, 'move_ix row col -> row col','min')
    #The boards that would result if the enemy moved instead of us
    rival_moves, _ = possible_moves(board_before, fill_val=rival_mark)
    candidates = any_three_in_a_line(rival_moves==rival_mark)

  #Only ever choose from legal moves
  candidates = jnp.logical_and( candidates, legal )

  if not jnp.any(candidates):
    #Nothing to win and nothing to block: fall back to all legal moves
    candidates = legal

  #candidates acts as the weights p, so only flagged moves can be drawn
  _, draw_key = jax.random.split(key)
  return jax.random.choice(draw_key, np.arange(N_POSSIBLE,dtype=int), p=candidates)

# Your AI Here

In [9]:
#Placeholder: put the parameters used by your AI here!
params = None

def value_function(params, board):
  '''Template for a value function; currently a stub that always returns None.'''
  #Replace the body with code that uses params to compute a value for board
  return None

def my_AI(key, moves, legal):
  '''Template AI: for now it picks one of the legal moves at random.'''
  #Input: key is a JAX random key object
  #  moves is shape (N_POSSIBLE, N_ROW, N_COL) with the possible moves
  #  legal is shape (N_POSSIBLE) with whether or not each move is legal
  #
  #Output: move_ix, an ix for a *legal* move

  #Derive a subkey for the draw instead of consuming the caller's key directly
  _, draw_key = jax.random.split(key)

  #Choose a move!
  #Replace the random choice below with code that chooses according to your value function (HINT: use jnp.argmax)
  candidate_ixs = np.arange(N_POSSIBLE,dtype=int)
  return jax.random.choice(draw_key, candidate_ixs, p=legal)

#Example: let my_AI choose a move for player 2 (fixed seed)
key = jax.random.PRNGKey(0)
fill_val = 2
board = jnp.array([[1,1,1],[1,0,0],[1,0,0]])
moves, legal = possible_moves(board, fill_val)
ix = my_AI(key,moves,legal) #index of the chosen move
print(select_moves_by_ix(moves,ix)) #the board after that move

[[1 1 1]
 [1 0 0]
 [1 0 2]]


# Tic-Tac-Toe Game Simulation (One Game at a Time)


In [10]:
#Text rendering of a board state
def print_board(board):
    '''Print board as rows of symbols, with a dashed rule under every row and a blank line at the end.'''
    #Input: board is indexable as board[row, col], with entries 0 (empty), 1 (Player 1) or 2 (Player 2)
    glyphs = {0:".",1:'X', 2:'O'} #0 = empty square, 1 = Player 1's piece (X), 2 = Player 2's piece (O)
    divider = "-"*(N_COLS*4-3) #same width as a row of N_COLS symbols joined by " | "

    for r in range(N_ROWS):
        cells = [glyphs[int(board[r,c])] for c in range(N_COLS)]
        print(" | ".join(cells))
        print(divider)
    print()

In [21]:
#########
# TIC TAC TOE Game loop
# Simulates ONE game at a time
#########

num_episodes = 1 #number of games to simulate
VERBOSE = True #whether or not to print game details as you go.

AIs = [random_move_AI, greedy_AI] #AIs[0] plays as Player 1, AIs[1] plays as Player 2
key = jax.random.PRNGKey(int(time.time())) #JAX random key, seeded from the current time

for episode in range(num_episodes):

  if VERBOSE:
    print(f"\n====Episode {episode+1}====")

  ##############
  #Set up a fresh game
  ##############
  key, subkey = jax.random.split(key)
  current_player = ( jax.random.uniform(subkey)<0.5 ) + 1 #random starting player: 1 or 2
  board = jnp.zeros((N_ROWS, N_COLS), dtype=int) #board starts out empty
  game_over = False
  key, subkey_ai = jax.random.split(key,2)
  ai_key = jax.random.split(subkey_ai, MAX_GAME_LENGTH) #one random key per turn to hand to the AIs

  ############
  # Main loop over the turns
  ############
  for turn in range(MAX_GAME_LENGTH):
    if game_over:
      continue #the game has been decided; the remaining turns are not played

    #Reminder: on the board, 0=Empty, 1=Player 1, 2=Player 2
    # so the candidate moves are obtained by filling in with current_player.
    moves, legal = possible_moves(board, fill_val=current_player) #all possible moves and their legality

    #The AI of the current player chooses a move_ix from the available moves
    turn_key = ai_key[turn]
    if current_player == 1:
      move_ix = AIs[0](turn_key, moves, legal)
    elif current_player == 2:
      #Swap the 1s <-> 2s so the AI can think as if it were Player 1
      move_ix = AIs[1](turn_key, P1_P2_swap(moves), legal)

    board = select_moves_by_ix(moves, move_ix) #update the board

    if VERBOSE:
      print(f"Turn #{turn}")
      print_board(board)

    #The game ends as soon as the current player holds any 3 in a line
    game_over = any_three_in_a_line(board == current_player)

    current_player = 3-current_player #switch players: 1 <-> 2

  ####
  # After-game clean up
  ####
  if VERBOSE:
    print("Game End")
    print_board(board)

  #Tic Tac Toe can also be thought of in terms of victory points, where there is at most 1 victory point in play!
  #Each line held by Player 1 counts +1 and each line held by Player 2 counts -1.
  p1_lines = which_three_in_a_line(board == 1)
  p2_lines = which_three_in_a_line(board == 2)
  which_victory_points = 1.0*p1_lines - 1.0*p2_lines
  net_victory_points = reduce(which_victory_points, '... n_line -> ...', 'sum')

  if VERBOSE:
    print(f" Final victory point diff: {net_victory_points}")


====Episode 1====
Turn #0
. | . | .
---------
. | . | O
---------
. | . | .
---------

Turn #1
. | . | X
---------
. | . | O
---------
. | . | .
---------

Turn #2
. | . | X
---------
. | . | O
---------
. | O | .
---------

Turn #3
. | X | X
---------
. | . | O
---------
. | O | .
---------

Turn #4
O | X | X
---------
. | . | O
---------
. | O | .
---------

Turn #5
O | X | X
---------
. | . | O
---------
X | O | .
---------

Turn #6
O | X | X
---------
. | O | O
---------
X | O | .
---------

Turn #7
O | X | X
---------
. | O | O
---------
X | O | X
---------

Turn #8
O | X | X
---------
O | O | O
---------
X | O | X
---------

Game End
O | X | X
---------
O | O | O
---------
X | O | X
---------

 Final victory point diff: -1.0
