In [2]:
import numpy as np

import math

import torch

import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0) # set the seed as 0

from tqdm.notebook import trange

import random

## Monte Carlo Tree Search (MCTS) Implementation Documentation

### Overview

The MCTS algorithm is a decision-making algorithm that uses Monte Carlo simulations to build a tree of possible moves in a game, searching for the most promising actions to take. This implementation consists of two classes: `Node` representing a node in the MCTS tree, and `MCTS` implementing the MCTS algorithm.

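### Node Class

A `Node` holds one game state of the search tree. Alongside the state it records:

- the action that produced it from its parent and that action's prior probability;
- its children;
- the visit count and value sum accumulated during backpropagation.

Whenever a child is created, the move is played and the board is flipped with `change_perspective`, so each child's state is seen from the perspective of the player who moves next.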

### Key Decisions

#### 1. Node Expansion

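Nodes are expanded using the policy predicted by the neural network. The policy is masked to the legal moves and renormalized. One child is then created for every action with non-zero probability, and that probability is stored as the child's prior. This lets the search consider every plausible move while the priors steer it toward the moves the network considers most promising.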

#### 2. Selection Strategy

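Children are selected by a prior-weighted Upper Confidence Bound (PUCT) score that balances exploitation and exploration:

$$\text{UCB}(c) = Q(c) + C \cdot P(c) \cdot \frac{\sqrt{N(\text{parent})}}{1 + N(c)}$$

The terms are:

- $Q(c)$ is the child's mean value, rescaled from $[-1, 1]$ to $[0, 1]$ and flipped to the parent's perspective ($Q = 0$ for unvisited children).
- $P(c)$ is the child's prior.
- $N$ is a visit count.
- $C$ is the exploration constant `args['C']`.

Promising children are preferred, while rarely visited children with a high prior still get explored.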

#### 3. Simulation

In the rollout-based variant, a leaf's value is estimated by playing uniformly random moves until the game ends. The neural-network search below performs no rollouts. Instead, the network's value head evaluates the leaf directly, replacing the random rollout.

#### 4. Backpropagation

The leaf value is propagated from the leaf back up to the root. At each node, the visit count is incremented and the value is added to the value sum. The sign of the value is flipped at every level because consecutive nodes belong to opposing players. This way, the information gathered at the leaves shapes the value estimates throughout the tree.

#### 5. Terminal States Handling

Terminal states are detected right after selection (and, in the rollout variant, after every rollout move). A terminal leaf is not expanded. Instead, the game's result is negated, switching it to the opposing player's perspective, and backpropagated directly. This keeps the value of every node expressed from the perspective of the correct player.


### MCTS with the neural network

In [3]:
class Node:
    """One node of the neural-network-guided MCTS tree.

    ``expand`` plays each child's action as player 1 and then flips the board
    with ``game.change_perspective(state, -1)``, so every child's ``state`` is
    stored from the point of view of the player who moves next. Because of
    this convention, ``backpropagate`` negates the value at every level.
    """

    def __init__(self, game, args, state, player, parent=None, action_taken=None, prior=0):
        # parent and action_taken default to None for the root node.
        self.game = game  # game object providing the rules
        self.args = args  # search hyper-parameters; get_ucb reads args['C']
        self.state = state  # game state stored at this node
        self.parent = parent  # parent Node, or None for the root
        self.action_taken = action_taken  # action that led from the parent to this node
        self.player = player  # player id; passed unchanged from parent to child
        self.prior = prior  # prior probability of action_taken (from the policy)

        self.children = []  # child Nodes, filled in by expand()

        self.visit_count = 0  # number of backpropagations that passed through this node
        self.value_sum = 0  # sum of the values backpropagated through this node

    def is_fully_expanded(self):
        """Return True once ``expand`` has created at least one child.

        ``expand`` adds every child with a non-zero prior in a single call, so
        a node that has any children counts as fully expanded.
        """
        return len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (the first one wins ties).

        Returns None if the node has no children.
        """
        best_child = None
        best_ucb = -np.inf  # any finite score beats this

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child):
        """PUCT score of ``child`` from the perspective of the player to move here.

        score = Q + C * sqrt(N_parent) / (N_child + 1) * prior

        ``child.value_sum`` is accumulated from the child's (the opponent's)
        perspective, so its mean value is mapped from [-1, 1] to [0, 1] and
        then flipped with ``1 - ...``. Unvisited children get Q = 0.
        """
        if child.visit_count == 0:
            q_value = 0
        else:
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior

    def expand(self, policy):
        """Create one child for every action whose probability in ``policy`` is > 0.

        Args:
            policy: per-action probabilities indexed by action id (presumably the
                masked, renormalized network policy). Each probability becomes
                the corresponding child's prior.
        """
        for action, prob in enumerate(policy):
            if prob > 0:
                # Play the move as player 1, then flip the board so the child's
                # state is seen by the opponent, who moves next.
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, -1)

                child = Node(self.game, self.args, child_state, self.player, self, action, prob)
                self.children.append(child)

    def simulate(self):
        """Estimate this node's value with a uniformly random rollout.

        The network-guided ``MCTS.search`` in this notebook does not call this;
        it serves the rollout-based variant. A terminal node returns its negated
        game value. Otherwise the rollout's final value is returned, negated
        when the last move was made by rollout player -1.
        """
        value, winner, is_terminal = self.game.get_value_and_terminated(self.state, self.action_taken, self.player)
        value = -value  # NOTE(review): presumably flips to the perspective of the player to move here — confirm

        if is_terminal:
            return value

        rollout_state = self.state.copy()
        rollout_player = 1  # self.state is stored with the player to move as player 1
        while True:
            # Legal moves depend on who is about to play, and the rollout
            # alternates players; querying the constant self.player offered the
            # wrong player's moves on every other turn.
            valid_moves = self.game.get_valid_moves(rollout_state, rollout_player)
            action = np.random.choice(valid_moves)  # presumably a list of legal action ids
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
            value, winner, is_terminal = self.game.get_value_and_terminated(rollout_state, action, self.player)

            if is_terminal:
                if rollout_player == -1:
                    value = -value  # the opponent made the final move
                return value
            rollout_player = -rollout_player

    def backpropagate(self, value):
        """Add ``value`` to this node and its ancestors, negating it at each level.

        Consecutive levels belong to opposing players, so a value that is good
        for this node's player is bad for the parent's player. Implemented as a
        loop, so deep trees cannot hit Python's recursion limit.
        """
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1
            value = -value
            node = node.parent


# The MCTS class 
class MCTS:
    """Monte Carlo Tree Search guided by a policy/value network.

    Leaves are evaluated by ``model`` instead of random rollouts. The policy
    output supplies the priors for ``Node.expand``, and the value output is
    backpropagated directly.
    """

    def __init__(self, game, args, player, model):
        self.game = game  # game object providing the rules
        self.args = args  # hyper-parameters; search reads args['num_searches'] and hands args to every Node
        self.player = player  # player id given to the root Node and to the game's terminal check
        self.model = model  # callable: batched encoded state -> (policy_logits, value)
        print("Initialized MCTS")

    @torch.no_grad()
    def search(self, state):
        """Run ``args['num_searches']`` MCTS iterations starting from ``state``.

        Returns:
            numpy array of length ``game.action_size + 1`` (the extra slot is
            the pass action) with the root children's visit counts normalized
            to sum to 1. If the root never received any children (e.g.
            ``state`` is terminal), the all-zero array is returned instead of
            NaNs.
        """
        root = Node(self.game, self.args, state, self.player)

        for _ in range(self.args['num_searches']):
            node = root

            # Selection: descend through expanded nodes by highest UCB score.
            while node.is_fully_expanded():
                node = node.select()

            # A terminal leaf is scored by the game itself and backpropagated
            # immediately. NOTE(review): negation presumably converts the value
            # to the perspective of the player to move at `node` — confirm.
            value, _, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken, self.player)
            value = -value

            if not is_terminal:
                # Expansion + evaluation with the network (unsqueeze adds the batch dimension).
                policy, value = self.model(
                    torch.tensor(self.game.get_encoded_state(node.state)).unsqueeze(0)
                )
                policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()
                valid_moves = np.asarray(self.game.get_mask(node.state, node.player), dtype=policy.dtype)

                policy *= valid_moves  # zero out illegal moves
                policy_mass = np.sum(policy)
                if policy_mass > 0:
                    policy /= policy_mass
                elif np.sum(valid_moves) > 0:
                    # The network put (numerically) zero mass on every legal
                    # move. Dividing by zero would give NaN priors, expand()
                    # would create no children, and this leaf would be
                    # re-evaluated forever, so fall back to a uniform prior.
                    policy = valid_moves / np.sum(valid_moves)

                value = value.item()  # tensor -> Python number
                node.expand(policy)

            node.backpropagate(value)

        # Root visit counts -> action distribution (+1 slot for the pass action).
        action_probs = np.zeros(self.game.action_size + 1)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        total_visits = np.sum(action_probs)
        if total_visits > 0:  # no children (e.g. terminal root) -> keep zeros, avoid NaN
            action_probs /= total_visits
        return action_probs

### MCTS without the neural network

In [4]:
# WITHOUT THE NEURAL NETWORK

# class Node:
#     def __init__(self, game, args, state, player, parent=None, action_taken=None): # =None beacuse of the root node
#         self.game = game
#         self.args = args
#         self.state = state
#         self.parent = parent
#         self.action_taken = action_taken
#         self.player = player

#         self.children = []
#         self.expandable_moves = game.get_valid_moves(state, self.player)  #list

#         self.visit_count = 0
#         self.value_sum = 0

#     #for the expansion
#     def is_fully_expanded(self):
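#         # expandable_moves is a list of action ids, so np.sum(...) == 0 would also be true when only action 0 is left
#         return len(self.expandable_moves) == 0 and len(self.children) > 0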

#     #for the selection
#     def select(self):
#         #look of all of your children and for each child we calculate the ucb score and choose the one with the best score
#         best_child = None
#         best_ucb = -np.inf

#         for child in self.children:
#             ucb = self.get_ucb(child)
#             if ucb > best_ucb:
#                 best_child = child
#                 best_ucb = ucb
                
#         return best_child

#     # calculate the ucb score of a node
#     def get_ucb(self, child):
#         # 1- beacuse the next player to make a move is our opponent so we want to put him on a bad situation therefor the value is close to 0
#         # +1) / 2 is to become a probability, which means the range is [0,1], before it was [-1,1]
#         q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2 
#         return q_value + self.args['C'] * math.sqrt(math.log(self.visit_count) / child.visit_count) #the formula of ucb

#     #expansion
#     def expand(self):
#         #== 1 because 1 means the move is legal
#         #choose randommly an indice of a move to expand
#         action = np.random.choice(self.expandable_moves)
#         # print(self.expandable_moves)
#         #make this move not expandable anymore
#         self.expandable_moves.remove(action)
#         # print(action,":" ,self.expandable_moves)
#         child_state = self.state.copy()
#         child_state = self.game.get_next_state(child_state, action, 1) #we never change the player, we flip the state arround 
#         child_state = self.game.change_perspective(child_state, -1)

#         child = Node(self.game, self.args, child_state, self.player, self, action)  # parent=self so backpropagation reaches the root
#         #append the node
#         self.children.append(child)
#         return child

#     #simulation
#     def simulate(self):
#         #verify if it is terminal
#         value, winner, is_terminal = self.game.get_value_and_terminated(self.state, self.action_taken, self.player)
#         #flip arround
#         value = -value

#         if is_terminal:
#             return value

#         rollout_state = self.state.copy()
#         rollout_player = 1
#         while True:
#             valid_moves = self.game.get_valid_moves(rollout_state, self.player)
#             # if action is not None and action in valid_moves:
#             #     valid_moves.remove(action)
#             # print(valid_moves)    
#             action = np.random.choice(valid_moves)
#             rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
#             # print(rollout_state, action)
#             value, winner, is_terminal = self.game.get_value_and_terminated(rollout_state, action, self.player)
            
#             if is_terminal:
#                 if rollout_player == -1:
#                     value = -value
#                 return value
#             #flip the player
#             rollout_player = -rollout_player

#     #backpropagation
#     def backpropagate(self, value):
#         self.value_sum += value
#         self.visit_count += 1

#         value = -value
#         if self.parent is not None:
#             self.parent.backpropagate(value)


      
# class MCTS:
#     def __init__(self, game, args, player):
#         self.game = game
#         self.args = args
#         self.player = player

#     def search(self, state):
#         #define the root node
#         root = Node(self.game, self.args, state, self.player)
        
#         #iterations
#         for search in range(self.args['num_searches']):
#             node = root
#             #selection phase
#             while node.is_fully_expanded():
#                 node = node.select()
                
#             #check if the node selected is a terminal one 
#             value, winner, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken, self.player)
#             value = -value 
    
#             # check if the node is terminate and backpropagate immediately if not we expand and simulate
#             if not is_terminal: 
#                 #expansion phase
#                 node = node.expand()
#                 #simulations phase
#                 value = node.simulate()
                
#             #backpropagation phase
#             node.backpropagate(value)
        
#         # return the distibution of visit_counts
#         action_probs = np.zeros(self.game.action_size + 1)
#         for child in root.children:
#             action_probs[child.action_taken] = child.visit_count
#         action_probs /= np.sum(action_probs)
#         return action_probs