In [1]:
import numpy as np
import math
np.__version__

'1.23.5'

In [2]:
class TicTacToe:
    """Rules and board helpers for tic-tac-toe.

    A board is a ``(row_count, column_count)`` NumPy array created by
    ``get_initial_state``: ``0`` marks an empty cell and a player's own value
    (``1`` or its negation ``-1``) marks an occupied one. An action is a flat
    cell index in row-major order (``row * column_count + column``).
    """

    def __init__(self):
        # TicTacToe is a 3x3 board game
        self.row_count = 3
        self.column_count = 3
        # One action per cell on the board
        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        """Return a new, empty (all-zero) board."""
        return np.zeros((self.row_count, self.column_count))

    def get_next_state(self, state, action, player):
        """Place ``player`` on the cell addressed by ``action``.

        NOTE: ``state`` is modified in place and the same array is returned.
        Callers that still need the previous board must pass a copy.
        """
        row = action // self.column_count
        column = action % self.column_count
        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat ``uint8`` mask with 1 for every empty cell, else 0."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return whether the move ``action`` completed a line for its player.

        The player is read from the cell addressed by ``action``. The row and
        column through that cell and both diagonals are checked.

        Returns False when ``action`` is None (no move has been played yet, as
        for the root of a search tree) or when the addressed cell is empty.
        """
        if action is None:
            # Nothing was played, so nobody can have just won.
            return False

        row = action // self.column_count
        column = action % self.column_count
        player = state[row, column]

        if player == 0:
            # An empty cell belongs to nobody; without this guard any line
            # summing to 0 would be reported as a win.
            return False

        return (
            np.sum(state[row, :]) == player * self.column_count  # check row
            or np.sum(state[:, column]) == player * self.row_count  # check column
            or np.sum(np.diag(state)) == player * self.row_count  # check diagonal
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count  # check anti-diagonal
        )

    def get_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` for the board after ``action``.

        - ``(1, True)``  if the move won the game for the player who made it,
        - ``(0, True)``  if no empty cell is left (draw),
        - ``(0, False)`` otherwise.

        The value is always from the point of view of the player who made the
        move; this method never returns a negative value.
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player (players are encoded as negations of each other)."""
        return -player

    def get_opponent_value(self, value):
        """Return ``value`` as seen by the opponent (its negation)."""
        return -value

    def change_perspective(self, state, player):
        """Return a new board equal to ``state * player``.

        With ``player=-1`` every stone switches sign, i.e. the board is shown
        from the other player's point of view; with ``player=1`` the values
        are unchanged.
        """
        return state * player

In [None]:
class Node:
    def __init__(self, game, args, state, parent=None, action_taken=None) -> None:
        """
        One position in the Monte Carlo Tree Search (MCTS) tree.

        Parameters:
        - game: the game object supplying the rules (valid moves, next state, ...)
        - args: dictionary of MCTS settings; 'C' is read by get_ucb
        - state: the board stored at this node
        - parent: the node this one was expanded from (None for a root)
        - action_taken: the action that led from the parent to this node
        """
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        # Mask of actions that have not been turned into a child yet.
        self.expandable_moves = game.get_valid_moves(state)

        self.value_sum = 0
        self.visit_count = 0

    def is_fully_expanded(self):
        """
        Tells whether no expandable action is left and at least one child exists.
        """
        has_children = len(self.children) > 0
        nothing_left = np.sum(self.expandable_moves) == 0
        return nothing_left and has_children

    def select_child(self):
        """
        Returns the child with the largest UCB score (the first one on ties),
        or None when there are no children.
        """
        winner, top_score = None, -np.inf

        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > top_score:
                winner, top_score = candidate, score

        return winner

    def get_ucb(self, child):
        """
        Computes the UCB score of a child as seen from this node.
        """
        mean_value = child.value_sum / child.visit_count
        # Shift the mean by 1 and halve it, then take the complement with 1.
        exploitation = 1 - (mean_value + 1) / 2
        c = self.args['C']
        exploration = math.sqrt(math.log(self.visit_count) / child.visit_count)
        return exploitation + c * exploration

    def expand(self):
        """
        Creates, stores and returns a child for one randomly chosen expandable action.
        """
        candidates = np.where(self.expandable_moves == 1)[0]
        action = np.random.choice(candidates)
        self.expandable_moves[action] = 0

        # Play the move as player 1 on a copy, then flip the board for the child.
        next_state = self.game.get_next_state(self.state.copy(), action, 1)
        next_state = self.game.change_perspective(next_state, player=-1)

        new_child = Node(self.game, self.args, next_state, self, action)
        self.children.append(new_child)
        return new_child

    def simulate(self):
        """
        Plays uniformly random moves from this node's state until the game ends
        and returns the resulting value.

        If this node's own state is already terminal, the negated game value is
        returned. Otherwise the rollout starts with player 1 to move, and the
        final value is negated when player -1 made the last move.
        """
        outcome, finished = self.game.get_value_and_terminated(self.state, self.action_taken)
        outcome = self.game.get_opponent_value(outcome)
        if finished:
            return outcome

        board = self.state.copy()
        mover = 1
        while True:
            legal = self.game.get_valid_moves(board)
            move = np.random.choice(np.where(legal == 1)[0])
            board = self.game.get_next_state(board, move, mover)
            outcome, finished = self.game.get_value_and_terminated(board, move)
            if not finished:
                mover = self.game.get_opponent(mover)
                continue
            if mover == -1:
                return self.game.get_opponent_value(outcome)
            return outcome

    def backpropagate(self, value):
        """
        Adds the value to this node and every ancestor, negating it at each
        step up, and increments each visit count.
        """
        current = self
        while current is not None:
            current.value_sum += value
            current.visit_count += 1
            value = current.game.get_opponent_value(value)
            current = current.parent
        
class MCTS:
    def __init__(self, game, args):
        """
        A class representing the Monte Carlo Tree Search (MCTS) algorithm.

        Parameters:
        - game: an instance of the game being played
        - args: a dictionary of arguments for the MCTS algorithm;
          'num_searches' is read here, and the dictionary is also handed to
          every Node that is created
        """
        self.game = game
        self.args = args

    def search(self, state):
        """
        Runs the MCTS algorithm from the given state and returns the best action to take.

        Each of the args['num_searches'] iterations selects a node that is not
        fully expanded, expands and simulates it unless it is terminal, and
        backpropagates the resulting value.

        Returns the action of the most visited child of the root (the first
        one on ties), or None if no child was created (zero searches, or no
        valid move in `state`).
        """
        root = Node(self.game, self.args, state)

        for _ in range(self.args['num_searches']):
            node = root

            # Selection: walk down while every move of the node has a child.
            while node.is_fully_expanded():
                node = node.select_child()

            if node.action_taken is None:
                # The root was not reached by a move, so there is no action to
                # evaluate a win for; it is terminal only if no move is left.
                value = 0
                is_terminal = np.sum(self.game.get_valid_moves(node.state)) == 0
            else:
                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
                value = self.game.get_opponent_value(value)

            if not is_terminal:
                node = node.expand()
                value = node.simulate()

            node.backpropagate(value)

        if not root.children:
            return None

        # The most visited root child is the move the search favours.
        best_child = max(root.children, key=lambda child: child.visit_count)
        return best_child.action_taken


In [3]:
# Interactive two-player game on the console: players 1 and -1 alternate
# entering a cell index until one of them wins or the board is full.
tictactoe = TicTacToe()
player = 1

# Empty settings dictionary; not read by the loop below.
args = {}

state = tictactoe.get_initial_state()

while True:
    print(state)
    valid_moves = tictactoe.get_valid_moves(state)
    print("valid_moves:", [i for i, is_free in enumerate(valid_moves) if is_free == 1])

    try:
        action = int(input(f"{player} action: "))
    except ValueError:
        # Non-numeric input: ask again instead of crashing.
        print("Invalid move")
        continue

    # Reject indices outside the board explicitly; a negative index would
    # otherwise wrap around silently and a too-large one would raise IndexError.
    if not 0 <= action < tictactoe.action_size or valid_moves[action] == 0:
        print("Invalid move")
        continue

    state = tictactoe.get_next_state(state, action, player)
    value, is_terminal = tictactoe.get_value_and_terminated(state, action)
    if is_terminal:
        print(state)
        if value == 1:
            print(player, " won")
        else:
            print("Draw")
        break

    player = tictactoe.get_opponent(player)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid_moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
[[0. 1. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid_moves: [0, 2, 3, 4, 5, 6, 7, 8]
Invalid move
[[0. 1. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid_moves: [0, 2, 3, 4, 5, 6, 7, 8]
[[ 0.  1. -1.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]]
valid_moves: [0, 3, 4, 5, 6, 7, 8]
[[ 0.  1. -1.]
 [ 0.  1.  0.]
 [ 0.  0.  0.]]
valid_moves: [0, 3, 5, 6, 7, 8]
[[ 0.  1. -1.]
 [ 0.  1. -1.]
 [ 0.  0.  0.]]
valid_moves: [0, 3, 6, 7, 8]
[[ 0.  1. -1.]
 [ 0.  1. -1.]
 [ 0.  1.  0.]]
1  won
