In [4]:
import random, math
import import_ipynb

In [7]:
from game_envs import TwoPlayerGameEnv, NimEnv
from game_algorithms import minimax_pure, minimax_ab

In [None]:
# -------------------- MCTS with Tree Reuse -------------------- #
class MCTSNode:
    """A node of the Monte Carlo search tree.

    Holds a game ``state`` together with the search statistics collected for
    it. ``state[1]`` is read as the player to move (see ``player_to_move``).

    Attributes:
        state: Game state represented by this node.
        parent: Parent node, or ``None`` for a root.
        move: Move that led from ``parent`` to this node (``None`` at a root).
        children: Mapping ``move -> MCTSNode`` of already-expanded children.
        untried_moves: Moves from this state not yet expanded into children.
        visits: Number of simulation results backed up through this node.
        value_sum: Sum of the simulation results backed up through this node.
    """

    def __init__(self, state, parent=None, move=None, untried_moves=None):
        self.state = state
        self.parent = parent
        self.move = move
        self.children = {}
        # Copy so add_child() never mutates the caller's sequence.
        self.untried_moves = list(untried_moves) if untried_moves else []
        self.visits = 0
        self.value_sum = 0.0

    @property
    def player_to_move(self):
        """Return the player whose turn it is, taken from ``state[1]``."""
        return self.state[1]

    def add_child(self, move, child_state, child_untried):
        """Create, register and return the child reached by playing ``move``.

        ``move`` is removed from ``untried_moves`` if it is listed there.
        """
        child = MCTSNode(child_state, parent=self, move=move, untried_moves=child_untried)
        self.children[move] = child
        if move in self.untried_moves:
            self.untried_moves.remove(move)
        return child

    def best_child(self, criteron='value'):
        """Return the child selected by ``criteron``.

        Args:
            criteron: ``'value'`` selects the child with the best mean value
                for the player to move here: highest when that player is 1,
                lowest otherwise. ``'visits'`` selects the most-visited child.
                The parameter's spelling is kept for backward compatibility.

        Returns:
            The chosen child node. Ties go to the first child in insertion
            order.

        Raises:
            ValueError: if this node has no children, or ``criteron`` is not
                ``'value'`` or ``'visits'``.
        """
        if not self.children:
            raise ValueError("best_child() called on a node with no children")
        candidates = self.children.values()

        if criteron == 'visits':
            return max(candidates, key=lambda n: n.visits)

        if criteron == 'value':
            maximizing = self.player_to_move == 1
            # An unvisited child has no estimate. Rank it last for the mover
            # instead of dividing by zero.
            worst = -math.inf if maximizing else math.inf

            def mean_value(n):
                return n.value_sum / n.visits if n.visits else worst

            if maximizing:
                return max(candidates, key=mean_value)
            return min(candidates, key=mean_value)

        raise ValueError(f"unknown criterion {criteron!r}; expected 'value' or 'visits'")

class MCTS:
    """Monte Carlo Tree Search player with subtree reuse between moves.

    The environment is used through ``env.moves(state)``,
    ``env.make_move(state, move)`` and ``env.is_terminal(state)``. The last
    returns an ``(is_terminal, value)`` pair. Each simulation result is added
    unchanged to every node on the selected path. The search therefore treats
    stored values as player 1's score: player 1 maximises it and the other
    player minimises it (see ``_ucb_score`` and ``MCTSNode.best_child``).

    Args:
        env: Game environment.
        exploration_c: UCB1 exploration constant ``C``.
        criterion: How the final move is chosen from the root's children.
            Forwarded to ``MCTSNode.best_child`` (``'visits'`` or ``'value'``).
    """

    def __init__(self, env: TwoPlayerGameEnv, exploration_c=math.sqrt(2), criterion='visits'):
        self.env = env
        self.C = exploration_c
        # Current search root; advanced by update_root() to reuse subtrees.
        self.root = None
        # First root ever created by search(); render_tree() prints from here.
        self.original_root = None
        # Final-move criterion passed to MCTSNode.best_child. The attribute
        # keeps its original spelling for backward compatibility.
        self.criteron = criterion

    def render_tree(self, node=None, indent="", depth=None):
        """Print a subtree, one line per node, indented by depth.

        Args:
            node: Subtree root to print. Defaults to ``original_root``.
            indent: Prefix for this node's line. It grows by three spaces
                per level.
            depth: Number of levels below ``node`` to print. ``None`` prints
                the whole subtree. When the limit is hit, a ``...`` line
                is printed instead of children.
        """
        if node is None:
            node = self.original_root
        if node is None:
            print("No tree available")
            return
        avg_val = node.value_sum / node.visits if node.visits > 0 else 0
        print(f"{indent}Move={node.move}, Player={node.player_to_move}, "
              f"Visits={node.visits}, ValueSum={node.value_sum:.2f}, Avg={avg_val:.2f}")
        if depth == 0:
            print(indent + "   ...")
            return
        child_depth = None if depth is None else depth - 1
        for child in node.children.values():
            self.render_tree(child, indent + "   ", child_depth)

    def _legal_moves(self, state):
        """Return the moves to expand from ``state``, or ``[]`` if it is terminal.

        This keeps terminal nodes from being expanded even if ``env.moves``
        still reports moves for a finished game.
        """
        is_leaf, _ = self.env.is_terminal(state)
        return [] if is_leaf else self.env.moves(state)

    def _ucb_score(self, parent, child):
        """Return the UCB1 score of ``child`` for the player moving at ``parent``.

        Unvisited children score ``inf``, so they are always preferred.
        """
        if child.visits == 0:
            return float('inf')
        mean_value = child.value_sum / child.visits
        # Values are stored from player 1's perspective. Player 1 (maximiser)
        # prefers high means and the other player (minimiser) prefers low
        # ones, so negate the mean for the minimiser.
        exploitation = mean_value if parent.player_to_move == 1 else -mean_value
        # max(1, ...) keeps log() defined for an unvisited parent.
        exploration = self.C * math.sqrt(math.log(max(1, parent.visits)) / child.visits)
        return exploitation + exploration

    def _select(self, node):
        """Descend from ``node`` by UCB1 and return the node to expand or evaluate.

        Descent stops at any of these:
        - a terminal node;
        - a node that still has untried moves;
        - a node with no children at all, i.e. a non-terminal dead end.
          Without this check, ``max`` would fail on an empty child set.
        """
        while True:
            is_leaf, _ = self.env.is_terminal(node.state)
            if is_leaf or node.untried_moves or not node.children:
                return node
            parent = node
            node = max(parent.children.values(), key=lambda ch: self._ucb_score(parent, ch))

    def _expand(self, node):
        """Expand one random untried move of ``node`` and return the new child.

        Returns ``node`` itself when it has nothing left to expand.
        """
        if not node.untried_moves:
            return node
        move = random.choice(node.untried_moves)
        next_state = self.env.make_move(node.state, move)
        return node.add_child(move, next_state, self._legal_moves(next_state))

    def _simulate(self, state):
        """Play uniformly random moves from ``state`` and return the outcome.

        Returns the value reported by ``env.is_terminal``, or ``0`` if a
        non-terminal state has no legal moves.
        """
        while True:
            is_leaf, value = self.env.is_terminal(state)
            if is_leaf:
                return value
            moves = self.env.moves(state)
            if not moves:
                return 0
            state = self.env.make_move(state, random.choice(moves))

    def _backprop(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and each ancestor up to the root."""
        while node is not None:
            node.visits += 1
            node.value_sum += reward
            node = node.parent

    def search(self, root_state, n_simulations=1000):
        """Run ``n_simulations`` MCTS iterations from ``root_state`` and pick a move.

        The current tree is reused when ``self.root`` already holds
        ``root_state`` (e.g. after ``update_root``). Otherwise a fresh root
        is created. The very first root is also remembered as
        ``original_root``.

        Returns:
            The move of the child chosen by ``best_child(criteron=self.criteron)``,
            or ``None`` if the root ended up with no children.
        """
        if self.root is None or self.root.state != root_state:
            self.root = MCTSNode(root_state, untried_moves=self._legal_moves(root_state))
            if self.original_root is None:
                self.original_root = self.root

        for _ in range(n_simulations):
            leaf = self._select(self.root)
            leaf = self._expand(leaf)
            reward = self._simulate(leaf.state)
            self._backprop(leaf, reward)

        if not self.root.children:
            return None
        return self.root.best_child(criteron=self.criteron).move

    def update_root(self, move):
        """Advance the root along ``move`` so its subtree is reused.

        If there is no root, or ``move`` was never expanded from it, the tree
        is dropped and the next ``search`` builds a new one. The new root is
        detached from its parent, so later backups stop at it. The old
        parent's ``children`` still link to it, so ``original_root`` can still
        render the full tree.
        """
        if self.root is not None and move in self.root.children:
            self.root = self.root.children[move]
            self.root.parent = None
        else:
            self.root = None
