# AlphaZero MCTS Interactive Explainer

This notebook provides an interactive visualization of how AlphaZero's Monte Carlo Tree Search (MCTS) works.

## Key Concepts

AlphaZero's MCTS consists of three main phases:

1. **Selection**: Navigate from root to leaf using PUCT (Predictor + UCB) scores
2. **Expansion & Evaluation**: Evaluate the leaf position with the neural network and add its child nodes. These steps are combined because a single forward pass returns both the policy (priors for the new children) and the value (the evaluation of the leaf).
3. **Backup**: Propagate the evaluation back up to the root, updating visit counts and value sums and flipping the value's sign at each level because the players alternate

This differs from traditional MCTS, which has four phases (Selection, Expansion, Simulation, Backup). AlphaZero replaces the random simulation (rollout) phase with the neural network: its value output evaluates the leaf, and its policy output supplies priors for the new children.
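The three phases can be sketched end to end in a few lines. This is a minimal, self-contained sketch, not the notebook's implementation: `Node`, `fake_network`, `puct` and `run_iteration` are hypothetical names, the "network" returns a fixed uniform policy over three moves and a constant value of 0.25, and no game state or terminal check is tracked. What it illustrates is that one network call supplies both the priors for the new children and the value that is backed up.

```python
import math
from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass
class Node:
    """Hypothetical minimal node: statistics only, no board state."""
    prior: float
    visit_count: int = 0
    value_sum: float = 0.0  # from the perspective of the player to move here
    children: Dict[int, "Node"] = field(default_factory=dict)

    def q(self) -> float:
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def fake_network(state) -> Tuple[Dict[int, float], float]:
    """Hypothetical stand-in for the policy/value network: ONE call, BOTH outputs."""
    moves = [0, 1, 2]
    policy = {m: 1.0 / len(moves) for m in moves}
    value = 0.25
    return policy, value


def puct(parent: Node, child: Node, c_puct: float = 1.0) -> float:
    # Negate the child's Q: the child's side to move is the parent's opponent.
    u = c_puct * child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
    return -child.q() + u


def run_iteration(root: Node, state=None) -> None:
    # 1. Selection: follow the highest PUCT score down to a leaf.
    path = [root]
    node = root
    while node.children:
        parent = node
        node = max(parent.children.values(), key=lambda ch: puct(parent, ch))
        path.append(node)

    # 2. Expansion & evaluation: a single network call gives priors and value.
    policy, value = fake_network(state)
    for move, p in policy.items():
        node.children[move] = Node(prior=p)

    # 3. Backup: add the value, flipping its sign at each level.
    for n in reversed(path):
        n.visit_count += 1
        n.value_sum += value
        value = -value


root = Node(prior=1.0)
for _ in range(2):
    run_iteration(root)
print(root.visit_count, {m: c.visit_count for m, c in root.children.items()})
# 2 {0: 1, 1: 0, 2: 0}
```

On the second iteration all three children tie on PUCT, so `max` picks the first one (move 0); its value of 0.25 is then backed up to the root as -0.25.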

The PUCT formula used for selection is:

$$
PUCT(a) = Q(a) + c_{puct} \cdot P(a) \cdot \frac{\sqrt{N_{parent}}}{1 + N(a)}
$$

Where:
- $Q(a)$ = average value of action $a$, from the perspective of the player choosing at the parent (exploitation term)
- $P(a)$ = prior probability from neural network policy
- $N(a)$ = visit count of action $a$
- $N_{parent}$ = visit count of parent node
- $c_{puct}$ = exploration constant (1.0 in this notebook; larger values favor exploration)
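A quick worked evaluation shows how the two terms trade off. The statistics below are made up for illustration, and `puct` is a hypothetical helper that simply transcribes the formula, with $Q(a)$ already expressed from the parent's point of view.

```python
import math


def puct(q: float, prior: float, n_parent: int, n_a: int, c_puct: float = 1.0) -> float:
    """Transcription of PUCT(a) = Q(a) + c_puct * P(a) * sqrt(N_parent) / (1 + N(a))."""
    exploration = c_puct * prior * math.sqrt(n_parent) / (1 + n_a)
    return q + exploration


# Made-up statistics for two moves at a parent visited 16 times.
well_explored = puct(q=0.5, prior=0.3, n_parent=16, n_a=10)  # 0.5 + 0.3 * 4 / 11
under_explored = puct(q=0.0, prior=0.6, n_parent=16, n_a=1)  # 0.0 + 0.6 * 4 / 2

print(round(well_explored, 3), round(under_explored, 3))  # 0.609 1.2
```

The rarely visited move with the higher prior is selected next even though its current $Q(a)$ is lower; as $N(a)$ grows, its exploration bonus shrinks and $Q(a)$ dominates.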

In [1]:
# Import required libraries
import anywidget
import traitlets
import numpy as np
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import json

## MCTS Tree Structure

First, let's define the tree structure that represents the MCTS search tree.

In [2]:
@dataclass
class MCTSTreeNode:
    """Represents a node in the MCTS tree for visualization"""

    node_id: int
    parent_id: Optional[int]
    action: Optional[int]  # Action taken from parent to reach this node (0-8 for tic-tac-toe)
    board_state: List[int]  # Flattened 3x3 board: 0=empty, 1=X, -1=O
    player: int  # Current player to move: 1=X, -1=O

    # MCTS statistics
    visit_count: int = 0
    value_sum: float = 0.0
    prior: float = 0.0

    # For visualization
    x: float = 0.0
    y: float = 0.0
    children: List[int] = field(default_factory=list)

    @property
    def q_value(self) -> float:
        """Average backed-up value, from the perspective of the player to move at this node"""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def ucb_score(self, parent_visits: int, c_puct: float = 1.0) -> float:
        """Calculate PUCT score from the parent's perspective

        The parent is choosing which child to explore. Since this is a zero-sum game:
        - self.q_value is from the CHILD's perspective (the player to move at the
          child, i.e. the parent's opponent), not the player who just moved
        - The PARENT wants to choose the child that is worst for that opponent
        - So we negate: if the child has q_value=+0.8 (good for the child's player),
          then from the parent's view it is -0.8 (bad for the parent's player)
        """
        if parent_visits == 0:
            return float("inf")

        # Exploitation term: negate because parent and child are opponents
        q = -self.q_value

        # Exploration term: encourage exploring less-visited nodes
        u = c_puct * self.prior * np.sqrt(parent_visits) / (1 + self.visit_count)

        return q + u

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization"""
        return {
            "id": self.node_id,
            "parent": self.parent_id,
            "action": self.action,
            "board": self.board_state,
            "player": self.player,
            "visits": self.visit_count,
            "value": round(self.q_value, 3),
            "prior": round(self.prior, 3),
            "x": round(self.x, 2),
            "y": round(self.y, 2),
            "children": self.children,
        }
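The sign convention in `ucb_score` is easy to get backwards, so here is a small standalone check of it. `Child` and `score_from_parent` are hypothetical names that mirror the statistics and formula used by `MCTSTreeNode`, and the visit counts and value sums are made up: two children with equal priors and visits, one whose stored value favors the opponent (the child's side to move) and one that does not.

```python
import math
from dataclasses import dataclass


@dataclass
class Child:
    """Hypothetical minimal child record mirroring MCTSTreeNode's statistics."""
    visit_count: int
    value_sum: float  # accumulated from the perspective of the player to move at this child
    prior: float

    @property
    def q_value(self) -> float:
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def score_from_parent(child: Child, parent_visits: int, c_puct: float = 1.0) -> float:
    # Negate: the child's side to move is the parent's opponent.
    q = -child.q_value
    u = c_puct * child.prior * math.sqrt(parent_visits) / (1 + child.visit_count)
    return q + u


good_for_opponent = Child(visit_count=5, value_sum=4.0, prior=0.5)  # q_value = +0.8
bad_for_opponent = Child(visit_count=5, value_sum=-4.0, prior=0.5)  # q_value = -0.8

scores = {
    "good_for_opponent": score_from_parent(good_for_opponent, parent_visits=11),
    "bad_for_opponent": score_from_parent(bad_for_opponent, parent_visits=11),
}
best = max(scores, key=scores.get)
print(best)  # bad_for_opponent
```

Because the exploration terms are identical here, the 1.6 gap between the two scores comes entirely from the negated Q values.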

In [3]:
class MCTSTree:
    """Manages the MCTS tree for visualization"""

    def __init__(self, initial_board: List[int], initial_player: int = 1):
        self.nodes: Dict[int, MCTSTreeNode] = {}
        self.next_id = 0
        self.c_puct = 1.0

        # Create root node
        self.root_id = self._create_node(
            parent_id=None,
            action=None,
            board_state=initial_board,
            player=initial_player,
            prior=1.0,
        )

        # History of MCTS iterations for step-by-step playback
        self.iteration_history: List[dict] = []

        # Step-by-step execution state
        self.current_step_state: Optional[dict] = None
        self.current_step_phase: int = (
            0  # 0=not started, 1=selection done, 2=expansion done, 3=backup done
        )

    def _create_node(
        self,
        parent_id: Optional[int],
        action: Optional[int],
        board_state: List[int],
        player: int,
        prior: float,
    ) -> int:
        """Create a new node and return its ID"""
        node_id = self.next_id
        self.next_id += 1

        node = MCTSTreeNode(
            node_id=node_id,
            parent_id=parent_id,
            action=action,
            board_state=board_state,
            player=player,
            prior=prior,
        )

        self.nodes[node_id] = node

        # Add to parent's children
        if parent_id is not None:
            self.nodes[parent_id].children.append(node_id)

        return node_id

    def get_valid_moves(self, board_state: List[int]) -> List[int]:
        """Get list of valid moves (empty positions)"""
        return [i for i, cell in enumerate(board_state) if cell == 0]

    def make_move(self, board_state: List[int], action: int, player: int) -> List[int]:
        """Create new board state after making a move"""
        new_board = board_state.copy()
        new_board[action] = player
        return new_board

    def is_terminal(self, board_state: List[int]) -> bool:
        """Check if game is over"""
        # Check rows, columns, diagonals
        board = np.array(board_state).reshape(3, 3)

        # Check rows and columns
        for i in range(3):
            if abs(board[i, :].sum()) == 3 or abs(board[:, i].sum()) == 3:
                return True

        # Check diagonals
        if (
            abs(board.diagonal().sum()) == 3
            or abs(np.fliplr(board).diagonal().sum()) == 3
        ):
            return True

        # Check if board is full
        if not any(cell == 0 for cell in board_state):
            return True

        return False

    def get_game_result(self, board_state: List[int], player: int) -> float:
        """Get game result from perspective of given player"""
        board = np.array(board_state).reshape(3, 3)

        # Check rows and columns
        for i in range(3):
            if board[i, :].sum() == 3 or board[:, i].sum() == 3:
                return 1.0 if player == 1 else -1.0
            if board[i, :].sum() == -3 or board[:, i].sum() == -3:
                return -1.0 if player == 1 else 1.0

        # Check diagonals
        if board.diagonal().sum() == 3 or np.fliplr(board).diagonal().sum() == 3:
            return 1.0 if player == 1 else -1.0
        if board.diagonal().sum() == -3 or np.fliplr(board).diagonal().sum() == -3:
            return -1.0 if player == 1 else 1.0

        return 0.0  # Draw

    def simulate_policy(self, board_state: List[int]) -> Tuple[List[float], float]:
        """Simulate neural network predictions (policy and value)"""
        valid_moves = self.get_valid_moves(board_state)

        # Random policy over valid moves: Dirichlet(1, ..., 1) samples uniformly from the probability simplex
        policy = np.zeros(9)
        if valid_moves:
            probs = np.random.dirichlet([1.0] * len(valid_moves))
            for move, prob in zip(valid_moves, probs):
                policy[move] = prob

        # Placeholder value: uniform random in [-0.3, 0.3), not a real position evaluation
        value = np.random.uniform(-0.3, 0.3)

        return policy.tolist(), value

    def select_phase(self) -> Tuple[int, List[int]]:
        """Selection phase: traverse tree using PUCT until reaching a leaf"""
        path = [self.root_id]
        current_id = self.root_id

        while self.nodes[current_id].children:
            # Select child with highest PUCT score
            current_node = self.nodes[current_id]
            parent_visits = current_node.visit_count

            best_child_id = max(
                current_node.children,
                key=lambda cid: self.nodes[cid].ucb_score(parent_visits, self.c_puct),
            )

            current_id = best_child_id
            path.append(current_id)

        return current_id, path

    def expand_phase(self, node_id: int) -> Optional[int]:
        """Expansion phase: add children to leaf node"""
        node = self.nodes[node_id]

        # Don't expand terminal nodes
        if self.is_terminal(node.board_state):
            return None

        # Get policy predictions
        policy, _ = self.simulate_policy(node.board_state)
        valid_moves = self.get_valid_moves(node.board_state)

        # Create child nodes for all valid moves
        new_children = []
        for action in valid_moves:
            new_board = self.make_move(node.board_state, action, node.player)
            child_id = self._create_node(
                parent_id=node_id,
                action=action,
                board_state=new_board,
                player=-node.player,  # Switch player
                prior=policy[action],
            )
            new_children.append(child_id)

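        # Return the first new child (or None) to signal whether expansion happened;
        # the value estimate itself comes from the leaf, not from this child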
        return new_children[0] if new_children else None

    def evaluate_phase(self, node_id: int) -> float:
        """Evaluation phase: get value from neural network or terminal state"""
        node = self.nodes[node_id]

        # Check if terminal
        if self.is_terminal(node.board_state):
            return self.get_game_result(node.board_state, node.player)

        # Get value from neural network
        _, value = self.simulate_policy(node.board_state)
        return value

    def backup_phase(self, path: List[int], value: float):
        """Backup phase: propagate value up the tree"""
        # Propagate from leaf to root
        for node_id in reversed(path):
            node = self.nodes[node_id]
            node.visit_count += 1
            node.value_sum += value
            # Flip value for opponent's perspective
            value = -value

    def run_iteration(self) -> dict:
        """Run one complete MCTS iteration and record it"""
        iteration_data = {"iteration": len(self.iteration_history), "phases": []}

        # Phase 1: Selection
        leaf_id, path = self.select_phase()
        iteration_data["phases"].append(
            {"name": "selection", "path": path, "leaf": leaf_id}
        )

        # Phase 2: Expansion & Evaluation (combined)
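        # In AlphaZero, expansion and evaluation happen together because one network
        # call returns both the policy (priors for the new children) and the value
        # (evaluation of the leaf). This demo calls simulate_policy() separately in
        # evaluate_phase() and expand_phase(), which is equivalent for a random stand-in.
        # We evaluate the LEAF node (not its children), then expand it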
        value = self.evaluate_phase(leaf_id)
        expanded_node_id = self.expand_phase(leaf_id)

        iteration_data["phases"].append(
            {
                "name": "expansion",
                "expanded_node": expanded_node_id,
                "new_children": self.nodes[leaf_id].children
                if expanded_node_id is not None
                else [],
                "value": round(value, 3),
            }
        )

        # Phase 3: Backup
        self.backup_phase(path, value)
        iteration_data["phases"].append(
            {"name": "backup", "path": path, "value": round(value, 3)}
        )

        self.iteration_history.append(iteration_data)
        return iteration_data

    def start_step_by_step(self) -> dict:
        """Start a new step-by-step iteration"""
        self.current_step_state = {
            "iteration": len(self.iteration_history),
            "phases": [],
            "leaf_id": None,
            "path": None,
            "expanded_node_id": None,
            "value": None,
            "children_created": [],  # Track children created during expansion
        }
        self.current_step_phase = 0
        return {"status": "started", "current_phase": "ready"}

    def prev_step(self) -> dict:
        """Go back to previous phase in step-by-step iteration"""
        if self.current_step_state is None:
            return {"error": "No step-by-step iteration in progress."}

        if self.current_step_phase == 0:
            return {"error": "Already at the beginning."}

        elif self.current_step_phase == 1:
            # Back from selection to start
            self.current_step_state["phases"].pop()  # Remove selection phase
            self.current_step_state["leaf_id"] = None
            self.current_step_state["path"] = None
            self.current_step_phase = 0
            return {"status": "back_to_start", "current_phase": "ready"}

        elif self.current_step_phase == 2:
            # Back from expansion to selection
            # Need to remove the children that were created
            leaf_id = self.current_step_state["leaf_id"]
            children_to_remove = self.current_step_state.get("children_created", [])

            for child_id in children_to_remove:
                if child_id in self.nodes:
                    del self.nodes[child_id]

            # Clear children list from parent
            if leaf_id in self.nodes:
                self.nodes[leaf_id].children = []

            self.current_step_state["phases"].pop()  # Remove expansion phase
            self.current_step_state["expanded_node_id"] = None
            self.current_step_state["value"] = None
            self.current_step_state["children_created"] = []
            # Restore path to just the selection path
            self.current_step_state["path"] = (
                self.current_step_state["path"][
                    : self.current_step_state["path"].index(leaf_id) + 1
                ]
                if leaf_id in self.current_step_state.get("path", [])
                else []
            )
            self.current_step_phase = 1
            return {"status": "back_to_selection", "current_phase": "selection"}

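        elif self.current_step_phase == 3:
            # Back from backup (complete) to expansion: undo the backup.
            # Note: next_step() resets current_step_state to None right after the
            # backup, so as written this branch is never reached.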
            path = self.current_step_state["path"]
            value = self.current_step_state["value"]

            # Reverse the backup by subtracting visits and values
            for node_id in reversed(path):
                node = self.nodes[node_id]
                node.visit_count -= 1
                node.value_sum -= value
                value = -value

            # Remove from history if it was added
            if (
                self.iteration_history
                and self.iteration_history[-1]["iteration"]
                == self.current_step_state["iteration"]
            ):
                self.iteration_history.pop()

            self.current_step_state["phases"].pop()  # Remove backup phase
            self.current_step_phase = 2
            return {"status": "back_to_expansion", "current_phase": "expansion"}

        return {"error": "Invalid state."}

    def next_step(self) -> dict:
        """Execute the next phase in the step-by-step iteration"""
        if self.current_step_state is None:
            return {
                "error": "No step-by-step iteration in progress. Call start_step_by_step() first."
            }

        if self.current_step_phase == 0:
            # Phase 1: Selection
            leaf_id, path = self.select_phase()
            self.current_step_state["leaf_id"] = leaf_id
            self.current_step_state["path"] = path
            self.current_step_state["phases"].append(
                {"name": "selection", "path": path, "leaf": leaf_id}
            )
            self.current_step_phase = 1
            return {
                "status": "selection_complete",
                "current_phase": "selection",
                "data": self.current_step_state["phases"][-1],
            }

        elif self.current_step_phase == 1:
            # Phase 2: Expansion & Evaluation
            leaf_id = self.current_step_state["leaf_id"]
            path = self.current_step_state["path"]

            # Evaluate the leaf node first (before expansion)
            value = self.evaluate_phase(leaf_id)

            # Track children before expansion
            children_before = self.nodes[leaf_id].children.copy()

            # Then expand the leaf node
            expanded_node_id = self.expand_phase(leaf_id)
            if expanded_node_id is not None:
                # Track new children created
                children_after = self.nodes[leaf_id].children
                self.current_step_state["children_created"] = [
                    c for c in children_after if c not in children_before
                ]
            else:
                self.current_step_state["children_created"] = []

            self.current_step_state["expanded_node_id"] = expanded_node_id
            self.current_step_state["value"] = value
            self.current_step_state["path"] = path

            self.current_step_state["phases"].append(
                {
                    "name": "expansion",
                    "expanded_node": expanded_node_id,
                    "new_children": self.nodes[leaf_id].children
                    if expanded_node_id is not None
                    else [],
                    "value": round(value, 3),
                }
            )
            self.current_step_phase = 2
            return {
                "status": "expansion_complete",
                "current_phase": "expansion",
                "data": self.current_step_state["phases"][-1],
            }

        elif self.current_step_phase == 2:
            # Phase 3: Backup
            path = self.current_step_state["path"]
            value = self.current_step_state["value"]

            self.backup_phase(path, value)
            self.current_step_state["phases"].append(
                {"name": "backup", "path": path, "value": round(value, 3)}
            )

            # Save to history and reset
            iteration_data = {
                "iteration": self.current_step_state["iteration"],
                "phases": self.current_step_state["phases"],
            }
            self.iteration_history.append(iteration_data)

            self.current_step_phase = 3
            result = {
                "status": "iteration_complete",
                "current_phase": "backup",
                "data": self.current_step_state["phases"][-1],
            }

            # Reset for next iteration
            self.current_step_state = None
            self.current_step_phase = 0

            return result

        else:
            return {
                "error": "Iteration already complete. Call start_step_by_step() to begin a new one."
            }

    def calculate_layout(self):
        """Calculate x, y positions for all nodes using a tree layout"""
        # Simple layered tree layout
        if not self.nodes:
            return

        # BFS to assign levels
        levels = {self.root_id: 0}
        queue = [self.root_id]
        max_level = 0

        while queue:
            node_id = queue.pop(0)
            level = levels[node_id]
            max_level = max(max_level, level)

            for child_id in self.nodes[node_id].children:
                levels[child_id] = level + 1
                queue.append(child_id)

        # Group nodes by level
        nodes_by_level = {i: [] for i in range(max_level + 1)}
        for node_id, level in levels.items():
            nodes_by_level[level].append(node_id)

        # Assign positions
        y_spacing = 150

        for level, node_ids in nodes_by_level.items():
            num_nodes = len(node_ids)
            x_spacing = max(120, 800 / max(1, num_nodes - 1)) if num_nodes > 1 else 400
            start_x = 400 - (num_nodes - 1) * x_spacing / 2

            for i, node_id in enumerate(node_ids):
                self.nodes[node_id].x = start_x + i * x_spacing
                self.nodes[node_id].y = 50 + level * y_spacing

    def to_json(self) -> str:
        """Export tree to JSON for visualization"""
        self.calculate_layout()
        nodes_data = [node.to_dict() for node in self.nodes.values()]
        return json.dumps(
            {
                "nodes": nodes_data,
                "root": self.root_id,
                "iterations": self.iteration_history,
            }
        )
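`backup_phase` and the phase-3 branch of `prev_step` are mirror images: backup adds one visit and a sign-alternating value along the path, and the undo replays the same signs with subtraction. The sketch below isolates that pattern on a hypothetical flat `stats` table (`node_id -> [visit_count, value_sum]`) rather than `MCTSTreeNode` objects; the values are chosen as exact binary fractions so the round trip compares equal. With arbitrary floats, `a + v - v` can leave a tiny rounding residue.

```python
import copy
from typing import Dict, List

# Hypothetical flat statistics table: node_id -> [visit_count, value_sum]
Stats = Dict[int, List[float]]


def backup(stats: Stats, path: List[int], value: float) -> None:
    """Add one visit and the leaf value along the path, flipping the sign per ply."""
    for node_id in reversed(path):
        stats[node_id][0] += 1
        stats[node_id][1] += value
        value = -value  # the parent's player is the opponent


def undo_backup(stats: Stats, path: List[int], value: float) -> None:
    """Reverse backup() by replaying the same signs and subtracting."""
    for node_id in reversed(path):
        stats[node_id][0] -= 1
        stats[node_id][1] -= value
        value = -value


stats: Stats = {0: [3, 0.25], 4: [1, -0.5], 7: [0, 0.0]}
path = [0, 4, 7]  # root -> ... -> leaf
before = copy.deepcopy(stats)

backup(stats, path, 0.5)
print(stats)  # {0: [4, 0.75], 4: [2, -1.0], 7: [1, 0.5]}

undo_backup(stats, path, 0.5)
print(stats == before)  # True
```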

## Interactive Visualization Widget

Now let's create an anywidget component with an interactive SVG visualization.

In [4]:
class MCTSVisualizerWidget(anywidget.AnyWidget):
    _esm = """
    function render({ model, el }) {
        // Create container
        const container = document.createElement('div');
        container.style.cssText = 'font-family: Arial, sans-serif; padding: 20px; background: #f5f5f5; border-radius: 8px;';
        
        // Title
        const title = document.createElement('h2');
        title.textContent = 'AlphaZero MCTS Interactive Visualization';
        title.style.cssText = 'margin: 0 0 20px 0; color: #333;';
        container.appendChild(title);
        
        // Control panel
        const controls = document.createElement('div');
        controls.style.cssText = 'margin-bottom: 20px; display: flex; gap: 10px; flex-wrap: wrap; align-items: center;';
        
        // Run iterations button
        const runIterBtn = document.createElement('button');
        runIterBtn.textContent = 'Run 1 Iteration';
        runIterBtn.style.cssText = 'padding: 10px 20px; background: #4CAF50; color: white; border: none; border-radius: 4px; cursor: pointer; font-size: 14px; font-weight: bold;';
        runIterBtn.onmouseover = () => runIterBtn.style.background = '#45a049';
        runIterBtn.onmouseout = () => runIterBtn.style.background = '#4CAF50';
        runIterBtn.onclick = () => {
            model.set('run_iteration', model.get('run_iteration') + 1);
            model.save_changes();
        };
        controls.appendChild(runIterBtn);
        
        // Run 5 iterations button
        const run5Btn = document.createElement('button');
        run5Btn.textContent = 'Run 5 Iterations';
        run5Btn.style.cssText = 'padding: 10px 20px; background: #2196F3; color: white; border: none; border-radius: 4px; cursor: pointer; font-size: 14px; font-weight: bold;';
        run5Btn.onmouseover = () => run5Btn.style.background = '#0b7dda';
        run5Btn.onmouseout = () => run5Btn.style.background = '#2196F3';
        run5Btn.onclick = () => {
            model.set('run_iterations', 5);
            model.save_changes();
        };
        controls.appendChild(run5Btn);
        
        // Reset button
        const resetBtn = document.createElement('button');
        resetBtn.textContent = 'Reset Tree';
        resetBtn.style.cssText = 'padding: 10px 20px; background: #f44336; color: white; border: none; border-radius: 4px; cursor: pointer; font-size: 14px; font-weight: bold;';
        resetBtn.onmouseover = () => resetBtn.style.background = '#da190b';
        resetBtn.onmouseout = () => resetBtn.style.background = '#f44336';
        resetBtn.onclick = () => {
            model.set('reset_tree', model.get('reset_tree') + 1);
            model.save_changes();
        };
        controls.appendChild(resetBtn);
        
        // Separator
        const separator = document.createElement('div');
        separator.style.cssText = 'width: 2px; height: 30px; background: #ccc; margin: 0 10px;';
        controls.appendChild(separator);
        
        // Step-by-step controls
        const stepLabel = document.createElement('span');
        stepLabel.textContent = 'Step-by-Step:';
        stepLabel.style.cssText = 'font-weight: bold; color: #555;';
        controls.appendChild(stepLabel);
        
        const prevStepBtn = document.createElement('button');
        prevStepBtn.textContent = '← Previous Phase';
        prevStepBtn.style.cssText = 'padding: 8px 16px; background: #607D8B; color: white; border: none; border-radius: 4px; cursor: pointer; font-size: 13px; font-weight: bold;';
        prevStepBtn.disabled = true; // Initially disabled
        prevStepBtn.style.opacity = '0.5';
        prevStepBtn.style.cursor = 'not-allowed';
        prevStepBtn.onmouseover = () => {
            if (!prevStepBtn.disabled) prevStepBtn.style.background = '#455A64';
        };
        prevStepBtn.onmouseout = () => {
            if (!prevStepBtn.disabled) prevStepBtn.style.background = '#607D8B';
        };
        prevStepBtn.onclick = () => {
            if (!prevStepBtn.disabled) {
                model.set('prev_step', model.get('prev_step') + 1);
                model.save_changes();
            }
        };
        controls.appendChild(prevStepBtn);
        
        const nextStepBtn = document.createElement('button');
        nextStepBtn.textContent = 'Next Phase →';
        nextStepBtn.style.cssText = 'padding: 8px 16px; background: #9C27B0; color: white; border: none; border-radius: 4px; cursor: pointer; font-size: 13px; font-weight: bold;';
        nextStepBtn.onmouseover = () => nextStepBtn.style.background = '#7B1FA2';
        nextStepBtn.onmouseout = () => nextStepBtn.style.background = '#9C27B0';
        nextStepBtn.onclick = () => {
            model.set('next_step', model.get('next_step') + 1);
            model.save_changes();
        };
        controls.appendChild(nextStepBtn);
        
        container.appendChild(controls);
        
        // Info panel
        const info = document.createElement('div');
        info.id = 'info-panel';
        info.style.cssText = 'margin-bottom: 15px; padding: 15px; background: white; border-radius: 4px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); font-size: 14px; color: #555;';
        info.innerHTML = '<strong>Total Iterations:</strong> 0 | <strong>Current Phase:</strong> All';
        container.appendChild(info);
        
        // Step status panel
        const stepStatus = document.createElement('div');
        stepStatus.id = 'step-status';
        stepStatus.style.cssText = 'margin-bottom: 15px; padding: 12px; background: #FFF3E0; border-left: 4px solid #FF9800; border-radius: 4px; font-size: 13px; color: #555; display: none;';
        stepStatus.innerHTML = '';
        container.appendChild(stepStatus);
        
        // SVG container
        const svgContainer = document.createElement('div');
        svgContainer.style.cssText = 'background: white; border-radius: 4px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); overflow: auto; max-height: 600px;';
        
        const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg');
        svg.setAttribute('width', '100%');
        svg.setAttribute('height', '600');
        svg.setAttribute('viewBox', '0 0 800 600');
        svg.style.cssText = 'display: block;';
        
        svgContainer.appendChild(svg);
        container.appendChild(svgContainer);
        
        // Legend
        const legend = document.createElement('div');
        legend.style.cssText = 'margin-top: 20px; padding: 15px; background: white; border-radius: 4px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);';
        legend.innerHTML = `
            <strong style="color: #333; font-size: 16px;">Legend</strong><br>
            <div style="margin-top: 10px; display: grid; grid-template-columns: repeat(2, 1fr); gap: 10px; font-size: 13px;">
                <div><span style="display: inline-block; width: 16px; height: 16px; background: #4CAF50; border-radius: 50%; margin-right: 8px; vertical-align: middle;"></span>Selection Path</div>
                <div><span style="display: inline-block; width: 16px; height: 16px; background: #2196F3; border-radius: 50%; margin-right: 8px; vertical-align: middle;"></span>Expansion & Evaluation</div>
                <div><span style="display: inline-block; width: 16px; height: 16px; background: #9C27B0; border-radius: 50%; margin-right: 8px; vertical-align: middle;"></span>Backup Path</div>
                <div><strong>N:</strong> Visit count</div>
                <div><strong>Q:</strong> Average value</div>
                <div><strong>P:</strong> Prior probability (from NN policy)</div>
                <div><strong>Value:</strong> Position evaluation (from NN value head)</div>
            </div>
        `;
        container.appendChild(legend);
        
        el.appendChild(container);
        
        // Helper functions
        function drawBoard(x, y, board, size = 45) {
            const g = document.createElementNS('http://www.w3.org/2000/svg', 'g');
            
            // Board background
            const bg = document.createElementNS('http://www.w3.org/2000/svg', 'rect');
            bg.setAttribute('x', x - size/2);
            bg.setAttribute('y', y - size/2);
            bg.setAttribute('width', size);
            bg.setAttribute('height', size);
            bg.setAttribute('fill', '#fafafa');
            bg.setAttribute('stroke', '#333');
            bg.setAttribute('stroke-width', '2');
            bg.setAttribute('rx', '4');
            g.appendChild(bg);
            
            // Grid lines
            const cellSize = size / 3;
            for (let i = 1; i < 3; i++) {
                // Vertical lines
                const vline = document.createElementNS('http://www.w3.org/2000/svg', 'line');
                vline.setAttribute('x1', x - size/2 + i * cellSize);
                vline.setAttribute('y1', y - size/2);
                vline.setAttribute('x2', x - size/2 + i * cellSize);
                vline.setAttribute('y2', y + size/2);
                vline.setAttribute('stroke', '#999');
                vline.setAttribute('stroke-width', '1');
                g.appendChild(vline);
                
                // Horizontal lines
                const hline = document.createElementNS('http://www.w3.org/2000/svg', 'line');
                hline.setAttribute('x1', x - size/2);
                hline.setAttribute('y1', y - size/2 + i * cellSize);
                hline.setAttribute('x2', x + size/2);
                hline.setAttribute('y2', y - size/2 + i * cellSize);
                hline.setAttribute('stroke', '#999');
                hline.setAttribute('stroke-width', '1');
                g.appendChild(hline);
            }
            
            // Draw pieces
            for (let i = 0; i < 9; i++) {
                const row = Math.floor(i / 3);
                const col = i % 3;
                const cx = x - size/2 + col * cellSize + cellSize/2;
                const cy = y - size/2 + row * cellSize + cellSize/2;
                
                if (board[i] === 1) {
                    // X
                    const size_x = cellSize * 0.5;
                    const line1 = document.createElementNS('http://www.w3.org/2000/svg', 'line');
                    line1.setAttribute('x1', cx - size_x/2);
                    line1.setAttribute('y1', cy - size_x/2);
                    line1.setAttribute('x2', cx + size_x/2);
                    line1.setAttribute('y2', cy + size_x/2);
                    line1.setAttribute('stroke', '#2196F3');
                    line1.setAttribute('stroke-width', '3');
                    line1.setAttribute('stroke-linecap', 'round');
                    g.appendChild(line1);
                    
                    const line2 = document.createElementNS('http://www.w3.org/2000/svg', 'line');
                    line2.setAttribute('x1', cx + size_x/2);
                    line2.setAttribute('y1', cy - size_x/2);
                    line2.setAttribute('x2', cx - size_x/2);
                    line2.setAttribute('y2', cy + size_x/2);
                    line2.setAttribute('stroke', '#2196F3');
                    line2.setAttribute('stroke-width', '3');
                    line2.setAttribute('stroke-linecap', 'round');
                    g.appendChild(line2);
                } else if (board[i] === -1) {
                    // O
                    const circle = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
                    circle.setAttribute('cx', cx);
                    circle.setAttribute('cy', cy);
                    circle.setAttribute('r', cellSize * 0.25);
                    circle.setAttribute('fill', 'none');
                    circle.setAttribute('stroke', '#f44336');
                    circle.setAttribute('stroke-width', '3');
                    g.appendChild(circle);
                }
            }
            
            return g;
        }
        
        function drawNode(node, highlight = null) {
            const g = document.createElementNS('http://www.w3.org/2000/svg', 'g');
            g.setAttribute('class', 'node');
            
            // Node circle
            const circle = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
            circle.setAttribute('cx', node.x);
            circle.setAttribute('cy', node.y);
            circle.setAttribute('r', '30');
            
            let fillColor = '#e8f5e9';
            let strokeColor = '#4CAF50';
            let strokeWidth = '2';
            
            if (highlight === 'selection') {
                fillColor = '#c8e6c9';
                strokeColor = '#4CAF50';
                strokeWidth = '4';
            } else if (highlight === 'expansion') {
                fillColor = '#bbdefb';
                strokeColor = '#2196F3';
                strokeWidth = '4';
            } else if (highlight === 'backup') {
                fillColor = '#e1bee7';
                strokeColor = '#9C27B0';
                strokeWidth = '4';
            }
            
            circle.setAttribute('fill', fillColor);
            circle.setAttribute('stroke', strokeColor);
            circle.setAttribute('stroke-width', strokeWidth);
            g.appendChild(circle);
            
            // Node stats text
            const text = document.createElementNS('http://www.w3.org/2000/svg', 'text');
            text.setAttribute('x', node.x);
            text.setAttribute('y', node.y + 50);
            text.setAttribute('text-anchor', 'middle');
            text.setAttribute('font-size', '11');
            text.setAttribute('font-family', 'monospace');
            text.setAttribute('fill', '#333');
            text.innerHTML = `N:${node.visits} Q:${node.value.toFixed(2)}`;
            g.appendChild(text);
            
            const text2 = document.createElementNS('http://www.w3.org/2000/svg', 'text');
            text2.setAttribute('x', node.x);
            text2.setAttribute('y', node.y + 62);
            text2.setAttribute('text-anchor', 'middle');
            text2.setAttribute('font-size', '10');
            text2.setAttribute('font-family', 'monospace');
            text2.setAttribute('fill', '#666');
            text2.innerHTML = `P:${node.prior.toFixed(2)}`;
            g.appendChild(text2);
            
            // Mini board
            const boardGroup = drawBoard(node.x, node.y, node.board, 40);
            g.appendChild(boardGroup);
            
            return g;
        }
        
        function drawEdge(fromNode, toNode, highlight = false) {
            const line = document.createElementNS('http://www.w3.org/2000/svg', 'line');
            line.setAttribute('x1', fromNode.x);
            line.setAttribute('y1', fromNode.y + 30);
            line.setAttribute('x2', toNode.x);
            line.setAttribute('y2', toNode.y - 30);
            line.setAttribute('stroke', highlight ? '#FF5722' : '#ccc');
            line.setAttribute('stroke-width', highlight ? '3' : '2');
            line.setAttribute('marker-end', 'url(#arrowhead)');
            return line;
        }
        
        function updateVisualization() {
            const data = JSON.parse(model.get('tree_data'));
            const currentPhase = model.get('current_phase');
            const iterationCount = data.iterations.length;
            const stepStatusText = model.get('step_status');
            
            // Update info panel
            info.innerHTML = `<strong>Total Iterations:</strong> ${iterationCount} | <strong>Current Phase:</strong> ${currentPhase.charAt(0).toUpperCase() + currentPhase.slice(1)}`;
            
            // Update step status panel and button states
            if (stepStatusText && stepStatusText.length > 0) {
                stepStatus.innerHTML = stepStatusText;
                stepStatus.style.display = 'block';
                
                // Enable/disable previous button based on status
                if (stepStatusText.includes('Ready for iteration') || stepStatusText.includes('⚠️')) {
                    // At start of new iteration - can't go back
                    prevStepBtn.disabled = true;
                    prevStepBtn.style.opacity = '0.5';
                    prevStepBtn.style.cursor = 'not-allowed';
                } else {
                    // Can always go back if we're in the middle of an iteration
                    prevStepBtn.disabled = false;
                    prevStepBtn.style.opacity = '1';
                    prevStepBtn.style.cursor = 'pointer';
                }
            } else {
                stepStatus.style.display = 'none';
                prevStepBtn.disabled = true;
                prevStepBtn.style.opacity = '0.5';
                prevStepBtn.style.cursor = 'not-allowed';
            }
            
            // Clear SVG
            svg.innerHTML = '';
            
            // Add arrow marker definition
            const defs = document.createElementNS('http://www.w3.org/2000/svg', 'defs');
            const marker = document.createElementNS('http://www.w3.org/2000/svg', 'marker');
            marker.setAttribute('id', 'arrowhead');
            marker.setAttribute('markerWidth', '10');
            marker.setAttribute('markerHeight', '10');
            marker.setAttribute('refX', '9');
            marker.setAttribute('refY', '3');
            marker.setAttribute('orient', 'auto');
            const polygon = document.createElementNS('http://www.w3.org/2000/svg', 'polygon');
            polygon.setAttribute('points', '0 0, 10 3, 0 6');
            polygon.setAttribute('fill', '#ccc');
            marker.appendChild(polygon);
            defs.appendChild(marker);
            svg.appendChild(defs);
            
            // Create node lookup
            const nodeMap = {};
            data.nodes.forEach(node => {
                nodeMap[node.id] = node;
            });
            
            // Get last iteration for highlighting
            let highlightData = null;
            if (iterationCount > 0 && currentPhase !== 'all') {
                const lastIter = data.iterations[iterationCount - 1];
                highlightData = lastIter.phases.find(p => p.name === currentPhase);
            }
            
            // Draw edges first (so they appear behind nodes)
            data.nodes.forEach(node => {
                if (node.parent !== null) {
                    const parentNode = nodeMap[node.parent];
                    const edge = drawEdge(parentNode, node, false);
                    svg.appendChild(edge);
                }
            });
            
            // Highlight edges if needed
            if (highlightData) {
                if (highlightData.path) {
                    for (let i = 0; i < highlightData.path.length - 1; i++) {
                        const fromNode = nodeMap[highlightData.path[i]];
                        const toNode = nodeMap[highlightData.path[i + 1]];
                        const edge = drawEdge(fromNode, toNode, true);
                        svg.appendChild(edge);
                    }
                }
            }
            
            // Draw nodes
            data.nodes.forEach(node => {
                let highlight = null;
                
                if (highlightData) {
                    if (currentPhase === 'selection' && highlightData.path && highlightData.path.includes(node.id)) {
                        highlight = 'selection';
                    } else if (currentPhase === 'expansion' && highlightData.new_children && highlightData.new_children.includes(node.id)) {
                        highlight = 'expansion';
                    } else if (currentPhase === 'backup' && highlightData.path && highlightData.path.includes(node.id)) {
                        highlight = 'backup';
                    }
                }
                
                const nodeGroup = drawNode(node, highlight);
                svg.appendChild(nodeGroup);
            });
        }
        
        // Listen for data changes
        model.on('change:tree_data', updateVisualization);
        model.on('change:current_phase', updateVisualization);
        model.on('change:step_status', updateVisualization);
        
        // Initial render
        updateVisualization();
    }
    
    export default { render };
    """

    # Widget traits (state)
    tree_data = traitlets.Unicode("{}").tag(sync=True)
    run_iteration = traitlets.Int(0).tag(sync=True)
    run_iterations = traitlets.Int(0).tag(sync=True)
    reset_tree = traitlets.Int(0).tag(sync=True)
    current_phase = traitlets.Unicode("all").tag(sync=True)
    prev_step = traitlets.Int(0).tag(sync=True)
    next_step = traitlets.Int(0).tag(sync=True)
    step_status = traitlets.Unicode("").tag(sync=True)

    def __init__(self, initial_board=None):
        super().__init__()

        if initial_board is None:
            initial_board = [0] * 9  # Empty board

        self.tree = MCTSTree(initial_board)
        self.tree_data = self.tree.to_json()

        # Observe changes
        self.observe(self._on_run_iteration, "run_iteration")
        self.observe(self._on_run_iterations, "run_iterations")
        self.observe(self._on_reset, "reset_tree")
        self.observe(self._on_prev_step, "prev_step")
        self.observe(self._on_next_step, "next_step")

    def _on_run_iteration(self, change):
        """Run one MCTS iteration"""
        self.tree.run_iteration()
        self.tree_data = self.tree.to_json()
        self.current_phase = "all"
        self.step_status = ""

    def _on_run_iterations(self, change):
        """Run multiple MCTS iterations"""
        n = change["new"]
        if n > 0:
            for _ in range(n):
                self.tree.run_iteration()
            self.tree_data = self.tree.to_json()
            self.run_iterations = 0  # Reset counter
            self.current_phase = "all"
            self.step_status = ""

    def _on_reset(self, change):
        """Reset the tree"""
        initial_board = [0] * 9
        self.tree = MCTSTree(initial_board)
        self.tree_data = self.tree.to_json()
        self.current_phase = "all"
        self.step_status = ""

    def _on_prev_step(self, change):
        """Go back to previous step in step-by-step iteration"""
        # If we haven't started yet, can't go back
        if self.tree.current_step_state is None:
            self.step_status = (
                "⚠️ No iteration in progress. Click 'Next Phase' to start."
            )
            return

        result = self.tree.prev_step()

        if "error" in result:
            self.step_status = "⚠️ Cannot go back further."
        else:
            status = result["status"]
            phase = result["current_phase"]

            if status == "back_to_start":
                self.step_status = (
                    "Ready for iteration. Click 'Next Phase' to execute Selection."
                )
                self.current_phase = "all"
            elif status == "back_to_selection":
                self.step_status = "✓ Selection complete. Click 'Next Phase' for Expansion & Evaluation."
                self.current_phase = "selection"
            elif status == "back_to_expansion":
                self.step_status = (
                    "✓ Expansion & Evaluation complete. Click 'Next Phase' for Backup."
                )
                self.current_phase = "expansion"

            self.tree_data = self.tree.to_json()

    def _on_next_step(self, change):
        """Execute next step in step-by-step iteration"""
        # Auto-start new iteration if needed
        if self.tree.current_step_state is None or self.tree.current_step_phase == 3:
            self.tree.start_step_by_step()
            self.step_status = (
                "Ready for iteration. Click 'Next Phase' to execute Selection."
            )
            self.current_phase = "all"
            return

        result = self.tree.next_step()

        if "error" in result:
            # This shouldn't happen now, but keep as fallback
            self.step_status = "⚠️ Error: " + result.get("error", "Unknown error")
        else:
            status = result["status"]
            phase = result["current_phase"]

            if status == "selection_complete":
                self.step_status = "✓ Selection complete. Click 'Next Phase' for Expansion & Evaluation."
                self.current_phase = "selection"
            elif status == "expansion_complete":
                self.step_status = (
                    "✓ Expansion & Evaluation complete. Click 'Next Phase' for Backup."
                )
                self.current_phase = "expansion"
            elif status == "iteration_complete":
                self.step_status = "✓ Backup complete. Iteration finished! Click 'Next Phase' to start next iteration."
                self.current_phase = "backup"

            self.tree_data = self.tree.to_json()

## Create and Display the Interactive Widget

Now let's create an instance of the widget with an initial game state.

In [5]:
# Create widget with empty initial board
widget = MCTSVisualizerWidget()
widget


## How to Use This Widget

### Full Iteration Mode
1. **Run 1 Iteration**: Click this button to run a complete MCTS iteration (all 3 phases) at once
2. **Run 5 Iterations**: Quickly run 5 complete iterations
3. **Reset Tree**: Start over with a fresh empty board

### Step-by-Step Mode (Educational)
For learning purposes, you can execute one phase at a time and navigate forward/backward:
1. **Next Phase →**: Execute the next phase (Selection → Expansion & Evaluation → Backup → next iteration)
   - Automatically starts a new iteration when the previous one completes
2. **← Previous Phase**: Go back to the previous phase to review changes
3. Watch the tree and highlighted nodes change after each phase!

### Understanding the Visualization

- **Nodes**: Each circle represents a game state
  - The mini board inside shows the current position
  - **N**: Visit count (how many times this node was visited)
  - **Q**: Average value (value sum / N), the expected outcome from the perspective of the player to move at this node
  - **P**: Prior probability of the move leading to this node (the neural network's initial prediction)

- **Colors**:
  - 🟢 **Green**: Selection phase (nodes traversed using PUCT)
  - 🔵 **Blue**: Expansion & Evaluation phase (new nodes created + position evaluated via neural network)
  - 🟣 **Purple**: Backup phase (values propagated up the tree)

- **Board Symbols**:
  - **Blue X**: Player 1 (X)
  - **Red O**: Player -1 (O)

### Try Different Starting Positions

In [6]:
# Example: Start with a partially played game
# Board layout:
# X | O | -
# ---------
# - | X | -
# ---------
# O | - | -

partial_board = [
    1, -1, 0,   # Row 0: X, O, empty
    0, 1, 0,    # Row 1: empty, X, empty
    -1, 0, 0,   # Row 2: O, empty, empty
]

widget_partial = MCTSVisualizerWidget(initial_board=partial_board)
widget_partial

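
A custom starting position should be reachable in a real game. The helper below is a hypothetical sketch, not part of the widget (the notebook does not show how `MCTSTree` decides whose turn it is). It assumes X always moves first and checks only the piece counts, not whether someone has already won.

```python
from typing import List


def player_to_move(board: List[int]) -> int:
    """Return 1 (X) or -1 (O) for the side to move, assuming X moves first.

    Raises ValueError if the piece counts cannot occur in a legal game.
    """
    if len(board) != 9 or any(cell not in (-1, 0, 1) for cell in board):
        raise ValueError("board must have 9 cells containing only -1, 0, or 1")
    x_count = board.count(1)
    o_count = board.count(-1)
    if x_count == o_count:
        return 1   # equal counts: X's turn
    if x_count == o_count + 1:
        return -1  # X has one extra piece: O's turn
    raise ValueError(f"unreachable position: {x_count} X vs {o_count} O")


example_board = [
    1, -1, 0,
    0, 1, 0,
    -1, 0, 0,
]
print(player_to_move(example_board))  # 1, so X moves next
```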

## Deep Dive: The AlphaZero MCTS Algorithm

### Phase 1: Selection

Starting from the root, we traverse the tree by repeatedly selecting the child with the highest PUCT score:

```
PUCT(a) = Q(a) + c_puct × P(a) × √(N_parent) / (1 + N(a))
```

- **Q(a)**: Exploitation term: prefer actions that have worked well
  - **Important**: Q(a) is the **negation** of the value stored in the child node!
  - Each node stores its value from the perspective of its own player to move
  - But the **parent** is choosing, and it wants what is best for itself (which is worst for the opponent)
  - Example: if the child has Q=+0.8 (good for the child's player), the parent sees it as Q=-0.8 (bad for the parent)

- **U(a) = c_puct × P(a) × √(N_parent) / (1 + N(a))**: Exploration bonus
  - Higher for actions with a high prior P(a) from the neural network
  - Higher for actions visited less (low N(a))
  - Grows with parent visits (√N_parent)
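
One selection step can be sketched as follows. This is a minimal illustration, not the notebook's `MCTSTree` code: the `ChildStats` class and function names are hypothetical, and it assumes each child's `value_sum` is stored from the child's own perspective (so the parent negates Q) and that an unvisited child counts as Q = 0, a convention that varies between implementations.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class ChildStats:
    """Hypothetical per-child statistics; value_sum is from the CHILD's perspective."""
    action: int
    prior: float            # P(a) from the policy head
    visit_count: int = 0    # N(a)
    value_sum: float = 0.0  # sum of backed-up values, child's perspective

    def q(self) -> float:
        # ASSUMPTION: unvisited children count as Q = 0 (implementations differ here)
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def puct_score(child: ChildStats, parent_visits: int, c_puct: float = 1.0) -> float:
    q_parent = -child.q()  # negate: what is good for the child's player is bad for the parent
    u = c_puct * child.prior * math.sqrt(parent_visits) / (1 + child.visit_count)
    return q_parent + u


def select_child(children: List[ChildStats], parent_visits: int,
                 c_puct: float = 1.0) -> ChildStats:
    """Return the child with the highest PUCT score (ties go to the first one)."""
    return max(children, key=lambda ch: puct_score(ch, parent_visits, c_puct))
```

Selection repeats this step from the root, descending until it reaches a node that has no children yet.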

### Phase 2: Expansion & Evaluation (Combined)

When we reach a leaf node (one that has not been expanded yet), we make a single neural network forward pass on its position, which returns both:

1. **Value (v)**: Evaluation of the **leaf node's** position from the perspective of the player to move there
   - For terminal states: Direct game outcome (+1 win, -1 loss, 0 draw) instead of the network's prediction
   - For non-terminal states: Use the neural network's value-head prediction
   - **Important**: We evaluate the leaf node's position, not its children

2. **Policy (π)**: Probability distribution over actions
   - Create child nodes for all legal actions from the leaf
   - Assign each child its prior probability P(a) from the policy

Because one call on the leaf's position returns both the value (how good is this position?) and the policy (which moves look promising?), AlphaZero needs no random rollout: the value head replaces the separate Simulation phase of traditional MCTS.
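
The combined step might look like this for tic-tac-toe. It is a sketch under stated assumptions: `uniform_network` is a hypothetical stand-in for a trained policy/value network (uniform policy, neutral value), the `(board, player) -> (policy, value)` signature is our own choice, and illegal moves are masked out and the remaining priors renormalised, which is a common but not universal approach.

```python
from typing import Callable, Dict, List, Tuple

WIN_LINES = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
    (0, 4, 8), (2, 4, 6),             # diagonals
]

# Hypothetical network signature: (board, player) -> (policy over 9 cells, value in [-1, 1])
Network = Callable[[List[int], int], Tuple[List[float], float]]


def winner(board: List[int]) -> int:
    """Return 1 or -1 if that player has three in a row, else 0."""
    for a, b, c in WIN_LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0


def uniform_network(board: List[int], player: int) -> Tuple[List[float], float]:
    """Stand-in for a trained net: uniform policy, neutral value."""
    return [1.0 / 9] * 9, 0.0


def expand_and_evaluate(board: List[int], player: int,
                        net: Network) -> Tuple[Dict[int, float], float]:
    """Return (priors for legal actions, leaf value from `player`'s perspective)."""
    w = winner(board)
    if w != 0:
        # The previous move won, so in legal play w == -player and the value is -1
        return {}, float(w * player)
    legal = [i for i, cell in enumerate(board) if cell == 0]
    if not legal:
        return {}, 0.0  # full board, no winner: draw, nothing to expand
    policy, value = net(board, player)  # one forward pass gives both heads
    total = sum(policy[a] for a in legal)  # mask illegal moves, then renormalise
    priors = {a: (policy[a] / total if total > 0 else 1.0 / len(legal)) for a in legal}
    return priors, value
```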

### Phase 3: Backup

Propagate the value back up the path from leaf to root:
1. Start with the value obtained from the leaf node evaluation
2. For each node in the path (from leaf to root):
   - Increment visit count: `N ‚Üê N + 1`
   - Update value sum: `W ‚Üê W + v`
   - Flip sign for parent: `v ‚Üê -v` (opponent's perspective)

**Example**: If the leaf gets value +0.5 (good for current player), its parent gets -0.5 (bad from opponent's view), grandparent gets +0.5, etc.
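
The backup loop is short enough to sketch directly. `Stats` is a hypothetical minimal node holding only the two statistics that backup touches; `path` is assumed to be ordered from root to leaf, as the selection phase would record it.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Stats:
    visit_count: int = 0    # N
    value_sum: float = 0.0  # W, from this node's own perspective


def backup(path: List[Stats], leaf_value: float) -> None:
    """Propagate a leaf value to the root; `path` runs root -> leaf.

    leaf_value is from the perspective of the player to move at the leaf.
    """
    v = leaf_value
    for node in reversed(path):  # walk leaf -> root
        node.visit_count += 1    # N <- N + 1
        node.value_sum += v      # W <- W + v
        v = -v                   # the parent sees the opposite sign


path = [Stats() for _ in range(3)]  # root, child, leaf
backup(path, 0.5)
print([n.value_sum for n in path])  # [0.5, -0.5, 0.5]
```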

After running many MCTS iterations (simulations), play the root action with the highest visit count as the final move!
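
A sketch of the final move choice follows. Greedy selection (`temperature == 0`) is what this notebook describes. The sampling branch reflects our understanding of the published AlphaGo Zero / AlphaZero self-play setup, which picks moves in proportion to N(a)^(1/τ) early in a game for exploration; the function name and the visit counts in the example are made up for illustration.

```python
import random
from typing import Dict, Optional


def choose_move(root_visits: Dict[int, int], temperature: float = 0.0,
                rng: Optional[random.Random] = None) -> int:
    """Pick a root action from its visit counts N(a).

    temperature == 0: greedy, the most-visited action.
    temperature > 0:  sample with probability proportional to N(a) ** (1 / temperature).
    """
    if temperature == 0:
        return max(root_visits, key=root_visits.get)
    rng = rng or random.Random()
    actions = list(root_visits)
    weights = [root_visits[a] ** (1.0 / temperature) for a in actions]
    return rng.choices(actions, weights=weights, k=1)[0]


visits = {0: 12, 4: 61, 8: 27}  # illustrative counts, not output from the widget
print(choose_move(visits))  # 4
```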

---

### Why is Q Negated in the PUCT Formula?

This is a common point of confusion! Here's a detailed explanation:

**The Core Issue**: Each node stores statistics from **its own player's perspective**, but during selection, the **parent** is choosing which child to visit.

**Concrete Example**:
```
        Root (X's turn, N=100)
             |
      +------+------+------+
      |      |      |      |
    Child1 Child2 Child3  (All are O's turn)
    N=20   N=30   N=10
    Q=+0.8 Q=-0.2 Q=+0.3   (from O's perspective)
```

When Root (Player X) is selecting which child to explore:
- **Child1**: O thinks it's great (Q=+0.8), so X should think it's bad → **Q=-0.8** from X's view
- **Child2**: O thinks it's bad (Q=-0.2), so X should think it's good → **Q=+0.2** from X's view
- **Child3**: O thinks it's decent (Q=+0.3), so X should think it's poor → **Q=-0.3** from X's view

**Therefore**: We negate the child's Q-value to get the parent's perspective!

```python
PUCT(Child1) = -0.8 + exploration_bonus  # Bad for X, avoid unless exploration says otherwise
PUCT(Child2) = +0.2 + exploration_bonus  # Good for X, prefer this!
PUCT(Child3) = -0.3 + exploration_bonus  # Mediocre for X
```

This negation happens **during selection only**. During backup, we flip signs as values propagate up the tree to maintain each node's own perspective.
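
To put numbers on the example tree: the diagram gives N and Q but no priors, so the sketch below assumes a uniform prior P(a) = 1/3 for all three children and uses c_puct = 1.0, the typical value quoted at the top of the notebook. Both are assumptions made only to make the arithmetic concrete.

```python
import math

c_puct = 1.0      # typical value quoted at the top of the notebook
n_parent = 100    # root visit count from the diagram
prior = 1.0 / 3   # ASSUMED uniform prior; the diagram does not give P(a)

children = {
    # name: (N(a), Q stored from O's perspective)
    "Child1": (20, +0.8),
    "Child2": (30, -0.2),
    "Child3": (10, +0.3),
}

scores = {}
for name, (n, q_child) in children.items():
    q_parent = -q_child                                 # X's view of this child
    u = c_puct * prior * math.sqrt(n_parent) / (1 + n)  # exploration bonus
    scores[name] = q_parent + u
    print(f"{name}: Q_X={q_parent:+.2f}  U={u:.3f}  PUCT={scores[name]:+.3f}")

print("selected:", max(scores, key=scores.get))  # selected: Child2
```

Under these assumptions Child2 wins. Child3 has the fewest visits and so the largest exploration bonus (about 0.30), which almost exactly cancels its negative Q of -0.3 from X's view.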

## Export to Web

To use this visualization in a standalone webpage, you can extract the JavaScript code from the widget and integrate it with your HTML. The `_esm` string contains the complete rendering logic.

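The rendering code uses only standard SVG and JavaScript, so it is easy to embed. Note, however, that it reads all of its state through an anywidget-style `model` object (`model.get`, `model.on`), and the search itself runs in Python. A standalone page therefore has to supply a stand-in for `model` and can only display a snapshot of the tree.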
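
One possible static export is sketched below. It assumes that the `_esm` module's `render` receives `{ model, el }` as anywidget passes it, and that the widget calls only `get`, `set`, `on`, and `save_changes` on the model. The template and the `export_snapshot` function are hypothetical. Buttons that normally ask Python to run MCTS will do nothing on the static page, because no kernel sits behind the stand-in model.

```python
import json
from pathlib import Path
from string import Template

# Hypothetical page template. The inline module builds a stand-in for the anywidget
# `model` (ASSUMED: the widget only calls get/set/on/save_changes on it).
PAGE = Template("""<!DOCTYPE html>
<html>
<head><meta charset="utf-8"><title>MCTS snapshot</title></head>
<body>
<div id="app"></div>
<script type="module">
const esmSource = $esm;
const state = $state;
const listeners = {};
const model = {
  get: (key) => state[key],
  set: (key, value) => {
    state[key] = value;
    (listeners["change:" + key] || []).forEach((cb) => cb());
  },
  on: (event, cb) => { (listeners[event] = listeners[event] || []).push(cb); },
  save_changes: () => {},  // no Python kernel to sync with
};
const url = URL.createObjectURL(new Blob([esmSource], { type: "text/javascript" }));
const mod = await import(url);
mod.default.render({ model, el: document.getElementById("app") });
</script>
</body>
</html>
""")


def js_literal(value) -> str:
    """Encode a value as a JS literal that cannot close the <script> tag early."""
    return json.dumps(value).replace("</", "<\\/")


def export_snapshot(esm_source: str, state: dict, path: str) -> Path:
    """Write a static HTML page that renders one snapshot of the widget state."""
    html = PAGE.substitute(esm=js_literal(esm_source), state=js_literal(state))
    out = Path(path)
    out.write_text(html, encoding="utf-8")
    return out


# Usage in this notebook (state keys mirror the widget's synced traitlets):
# export_snapshot(
#     MCTSVisualizerWidget._esm,
#     {"tree_data": widget.tree_data, "current_phase": "all", "step_status": "",
#      "run_iteration": 0, "run_iterations": 0, "reset_tree": 0,
#      "prev_step": 0, "next_step": 0},
#     "mcts_snapshot.html",
# )
```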