# αMCTS

**Game trees** are a way to represent game states and actions. 

Each node in the tree represents a game state, and each edge represents a possible action that can be taken from that state.

In [1]:
import math


def draw_tic_tac_toe(board, size=300, padding=20):
    """
    Draw a tic-tac-toe board as SVG
    """
    cell_size = (size - padding * 2) / 3
    svg = f'<svg width="{size}" height="{size}" style="display:block;">'
    # Background
    svg += f'<rect x="{padding}" y="{padding}" width="{size - padding * 2}" height="{size - padding * 2}" fill="#fafafa00" stroke-width="0" rx="4"></rect>'
    # Grid lines
    for i in range(1, 3):
        # Vertical
        x = padding + i * cell_size
        svg += f'<line x1="{x}" y1="{padding}" x2="{x}" y2="{size - padding}" stroke="#666" stroke-width="3"></line>'
        # Horizontal
        y = padding + i * cell_size
        svg += f'<line x1="{padding}" y1="{y}" x2="{size - padding}" y2="{y}" stroke="#666" stroke-width="3"></line>'
    # Marks
    for i in range(9):
        row = i // 3
        col = i % 3
        cx = padding + col * cell_size + cell_size / 2
        cy = padding + row * cell_size + cell_size / 2
        val = board[i] if i < len(board) else 0
        if val == 1:
            size_x = cell_size * 0.5
            svg += f'<line x1="{cx - size_x / 2}" y1="{cy - size_x / 2}" x2="{cx + size_x / 2}" y2="{cy + size_x / 2}" stroke="#0969da" stroke-width="3" stroke-linecap="round"></line>'
            svg += f'<line x1="{cx + size_x / 2}" y1="{cy - size_x / 2}" x2="{cx - size_x / 2}" y2="{cy + size_x / 2}" stroke="#0969da" stroke-width="3" stroke-linecap="round"></line>'
        elif val == -1:
            svg += f'<circle cx="{cx}" cy="{cy}" r="{cell_size * 0.25}" fill="none" stroke="#DA7A09" stroke-width="3"></circle>'
    svg += "</svg>"
    return svg


def draw_game_tree(
    root_node,
    board_size=120,
    board_padding=12,
    horizontal_spacing=160,
    vertical_spacing=250,
    selected_node=None,
):
    """
    Draw a game tree with root state and its children.
    Supports recursive tree structure.

    Args:
        root_node: Dictionary with 'state', 'children' (list of dicts), and optional 'action'
        board_size: Size of each board
        board_padding: Padding for each board
        selected_node: Optional node dictionary to highlight with a circle

    Returns:
        HTML string with the complete game tree visualization
    """

    # Calculate tree dimensions and positions recursively
    def get_subtree_width(node):
        if not node.get("children"):
            return horizontal_spacing
        return sum(get_subtree_width(child) for child in node["children"])

    positions = {}

    def assign_positions(node, x, y):
        positions[id(node)] = (x, y)
        if not node.get("children"):
            return

        children = node["children"]
        total_width = sum(get_subtree_width(child) for child in children)
        start_x = x - total_width / 2

        current_x = start_x
        for child in children:
            child_width = get_subtree_width(child)
            assign_positions(child, current_x + child_width / 2, y + vertical_spacing)
            current_x += child_width

    tree_width = get_subtree_width(root_node) + 20

    def get_depth(node):
        if not node.get("children"):
            return 1
        return 1 + max(get_depth(child) for child in node["children"])

    tree_height = (get_depth(root_node) * vertical_spacing + board_size) - 200

    root_x = tree_width / 2
    root_y = 40
    assign_positions(root_node, root_x, root_y)

    html = f'<div style="width:{tree_width}px;height:{tree_height}px;position:relative;margin:20px auto;">'
    svg_connections = f'<svg width="{tree_width}" height="{tree_height}" style="position:absolute;top:0;left:0;pointer-events:none;">'
    svg_connections += '<defs><marker id="arrow" viewBox="0 0 10 10" refX="6" refY="5" markerWidth="6" markerHeight="6" orient="auto"><path d="M 0 0 L 10 5 L 0 10 z" fill="#999"></path></marker></defs>'

    nodes_html = ""

    def draw_node_recursive(node, depth=0):
        nonlocal svg_connections, nodes_html
        x, y = positions[id(node)]

        # Draw node
        delay = depth * 0.5
        node_svg = draw_tic_tac_toe(
            node["state"], size=board_size, padding=board_padding
        )
        is_selected = selected_node is node

        if is_selected:
            highlight_padding = 0
            highlight_size = board_size + highlight_padding * 2
            highlight_left = x - highlight_size / 2
            highlight_top = y - highlight_padding
            nodes_html += f'<div style="position:absolute;left:{highlight_left}px;top:{highlight_top}px;width:{highlight_size}px;height:{highlight_size}px;border-radius:50%;box-sizing:border-box;pointer-events:none;z-index:1;border:2px solid var(--text-lite);"></div>'

        nodes_html += f'<div class="" style="position:absolute;left:{x - board_size / 2}px;top:{y}px;z-index:2;">{node_svg}</div>'

        if node.get("children"):
            for child in node["children"]:
                cx, cy = positions[id(child)]

                # Edge
                start_y = y + board_size
                end_y = cy

                # Draw line
                svg_connections += f'<line x1="{x}" y1="{start_y}" x2="{cx}" y2="{end_y}" stroke="#999" stroke-width="2" marker-end="url(#arrow)"  ></line>'

                # Label
                if "action" in child:
                    # Midpoint for label
                    label_x = x + (cx - x) * 0.4
                    label_y = start_y + (end_y - start_y) * 0.4

                    text = f"{child['action']}"
                    text_width = 28

                    # White background for readability
                    svg_connections += f'<rect x="{label_x - text_width / 2}" y="{label_y - 15}" width="{text_width}" height="{text_width}" fill="#999" fill-opacity="1" rx="20" ry="20" />'

                    # text
                    svg_connections += f'<text x="{label_x}" y="{label_y - 10 + text_width / 2}" text-anchor="middle" fill="#fff" font-size="16" font-family="Noto Sans" >{text}</text>'

                draw_node_recursive(child, depth + 1)

    draw_node_recursive(root_node)

    svg_connections += "</svg>"
    html += svg_connections
    html += nodes_html
    html += "</div>"

    return html


# Example usage
from IPython.display import HTML, display

# Root state: O's turn to play
root_state = [
    0,
    1,
    0,  # Row 1: _, X, _
    0,
    1,
    -1,  # Row 2: _, X, O
    0,
    -1,
    1,
]  # Row 3: _, O, X

# Find all possible actions (empty cells)
possible_actions = [i for i in range(9) if root_state[i] == 0]

# Create child states for each possible action
child_states = []
for action in possible_actions:
    new_state = root_state.copy()
    new_state[action] = -1  # O plays (player -1)
    child_states.append({"action": action, "state": new_state})

# Add grandchildren to action 0 (second level of the tree)
# After O plays at position 0, it's X's turn
action_0_state = child_states[0]["state"]
action_0_possible_moves = [i for i in range(9) if action_0_state[i] == 0]

grandchildren = []
for action in action_0_possible_moves:
    new_state = action_0_state.copy()
    new_state[action] = 1  # X plays (player 1)
    grandchildren.append({"action": action, "state": new_state})

# Attach grandchildren to the first child (action 0)
child_states[0]["children"] = grandchildren

# Create root node
root_node = {"state": root_state, "children": child_states}

# Generate and display the game tree
tree_html = draw_game_tree(root_node)
display(HTML(tree_html))


**Monte Carlo Tree Search (MCTS)** is a method for building game trees by simulating random games from the current state. The basic idea is to use random simulations to estimate the value of each action, and then use these estimates to guide the search for the best action.

AlphaZero uses a modified version of MCTS. The key difference is that instead of using random simulations, **αMCTS** uses a neural network to evaluate the value of each state and the probability of each action. This allows αMCTS to make more informed decisions about which actions to explore.

Three main steps in αMCTS:
1. Select
2. Expand & Evaluate
3. Backup

```mermaid
---
config:
  theme: 'base'
  themeVariables:
    primaryColor: '#ffffff00'
    primaryTextColor: '#4D5461'
    primaryBorderColor: '#096bda'
    lineColor: '#4D5461'
    secondaryColor: '#006100'
    tertiaryColor: '#4D5461'
---
stateDiagram-v2
s1: Select
s2: Expand & Evaluate
s3: Backup
[*] --> s1
s1 --> s2
s2 --> s3
s3 --> s1 : # simulations
s3 --> [*]
```

Given a game state, αMCTS incrementally builds a search tree by exploring the actions available from that state. The current position is the root; repeated simulations expand the tree and update estimates of each action’s value. The resulting search tree then guides the choice of the best move in the current state.

1.	Selection — From the root, repeatedly choose the child that best trades off exploration and exploitation (per the selection rule) until you reach a leaf.
2.	Expansion & Evaluation — If the leaf is non-terminal, expand it by adding children for its legal actions. Query the neural network for a value estimate of the position and policy priors over those actions.
3.	Backup (Backpropagation) — Propagate the evaluated value back along the path to the root, updating each node’s visit count and action-value estimates.

Note: When we start the search, the tree only contains the root node (the current game state). We cannot select any child nodes yet, so we proceed directly to the expansion and evaluation step.

## Select


Starting from the root node, repeatedly select child nodes according to a selection policy until a leaf node is reached. The selection policy balances exploration and exploitation, using PUCT (Predictor + Upper Confidence Bound for Trees). The prior / predictor comes from the neural network's policy output.

In [2]:
# Root state: O's turn to play
root_state = [
    0,
    1,
    0,  # Row 1: _, X, _
    0,
    1,
    -1,  # Row 2: _, X, O
    0,
    -1,
    1,
]  # Row 3: _, O, X

# Find all possible actions (empty cells)
possible_actions = [i for i in range(9) if root_state[i] == 0]

# Create child states for each possible action
child_states = []
for action in possible_actions:
    new_state = root_state.copy()
    new_state[action] = -1  # O plays (player -1)
    child_states.append({"action": action, "state": new_state})

# Add grandchildren to action 0 (second level of the tree)
# After O plays at position 0, it's X's turn
action_0_state = child_states[0]["state"]
action_0_possible_moves = [i for i in range(9) if action_0_state[i] == 0]

grandchildren = []
for action in action_0_possible_moves:
    new_state = action_0_state.copy()
    new_state[action] = 1  # X plays (player 1)
    grandchildren.append({"action": action, "state": new_state})

# Attach grandchildren to the first child (action 0)
child_states[0]["children"] = grandchildren

# Create root node
root_node = {"state": root_state, "children": child_states}
# display(HTML(draw_game_tree(root_node, selected_node=child_states[0]["children"][0])))

# Animated selection sequence: root -> child 0 -> grandchild 0
# Build three frames by calling draw_game_tree with different selected nodes
frame_root = draw_game_tree(root_node, selected_node=root_node)
frame_child = draw_game_tree(root_node, selected_node=child_states[0])
selected_grandchild = None
if child_states and child_states[0].get("children"):
    selected_grandchild = child_states[0]["children"][0]
frame_grandchild = (
    draw_game_tree(root_node, selected_node=selected_grandchild)
    if selected_grandchild
    else frame_child
)

# CSS animation: show each frame for 2s in a looping sequence
html = """<style>
@keyframes cycleFade {
  0% { opacity: 1; transform: scale(1); }
  33% { opacity: 1; transform: scale(1); }
  34% { opacity: 0; transform: scale(1); }
  100% { opacity: 0; }
}
.animation-container { display: grid; width: 100%; max-width: 1200px; margin: 8px auto; }
.frame { grid-area: 1 / 1; opacity: 0; }
.frame:nth-child(1) { animation: cycleFade 6s infinite; animation-delay: 0s; }
.frame:nth-child(2) { animation: cycleFade 6s infinite; animation-delay: 2s; }
.frame:nth-child(3) { animation: cycleFade 6s infinite; animation-delay: 4s; }
</style>
<div class="animation-container">
  <div class="frame">{frame_root}</div>
  <div class="frame">{frame_child}</div>
  <div class="frame">{frame_grandchild}</div>
</div>
"""

# Use simple string replacement to avoid interpreting CSS braces as format fields
html_filled = (
    html.replace("{frame_root}", frame_root)
    .replace("{frame_child}", frame_child)
    .replace("{frame_grandchild}", frame_grandchild)
)

display(HTML(html_filled))

In [3]:
def puct_value(node, child, c_puct=1.0):
    # N(s): parent visits; n: child visits
    N_parent = node.visit_count
    n = child.visit_count

    # Q: exploitation
    Q = child.value_sum / (1 + n)

    # P: prior from policy head (already stored on the child)
    P = child.prior

    # U: exploration
    U = c_puct * P * (N_parent**0.5) / (1 + n)

    return Q + U


<br><br>

$$\text{PUCT}(s,a) \;=\; Q(s,a) \;+\; U(s,a)$$

$$Q(s,a) \;=\; \frac{\mathrm{W}(s,a)}{\,1+\mathrm{N}(s,a)\,}$$

$$U(s,a) \;=\; c_{\mathrm{puct}}\;\cdot\; \mathrm{prior}(s,a)\;\cdot\;\frac{\sqrt{ \mathrm{N}(s)}}{\,1+\mathrm{N}(s,a)\,}$$

Where
- $s$ = current state, $a$ = action (child).
- $\mathrm{W}(s,a)$ = cumulative backed-up values through edge (s,a).
- $\mathrm{N}(s,a)$ = number of visits to (s,a).
- $\mathrm{N}(s)  = \sum_b \mathrm{N}(s,b)$ = total visits of the parent 
- $\mathrm{prior}(s,a)$ = policy prior from the network at s.
- $c_{\mathrm{puct}}>0$ = exploration constant.
- We add 1 to denominators to avoid division by zero.

### Why does PUCT work?

PUCT (Predictor + Upper Confidence Bound for Trees) is AlphaZero’s adaptation of the UCB1 formula from classical multi-armed bandits to tree search guided by a neural network.


The numerator $\sqrt{N(s)}$ increases as we explore the parent node more — meaning we’re more confident overall in this position.

The denominator (1 + N(s,a)) ensures that the more an action is tried, the smaller its exploration bonus becomes.

This comes directly from the theoretical foundations of Upper Confidence Bounds (UCB) in bandit theory — the mathematical framework underlying MCTS.



**Consider the following example:**

| Term | $a_1$ | $a_2$ |
|------|-------|-------|
| P(s,a) | 0.8   | 0.2   |
| N(s)  | 100   | 100   |
| N(s,a) | 80    | 5    |

Calculating U(s,a) for both actions with c_puct = 1.0:

- For $a_1$:
  $$U(s,a_1) = 1.0 \cdot 0.8 \cdot \frac{\sqrt{100}}{1 + 80} = 0.8 \cdot \frac{10}{81} \approx 0.0988$$
- For $a_2$:
  $$U(s,a_2) = 1.0 \cdot 0.2 \cdot \frac{\sqrt{100}}{1 + 5} = 0.2 \cdot \frac{10}{6} \approx 0.3333$$

Although $a_1$ has a higher prior probability, its exploration bonus is much smaller due to its high visit count. In contrast, $a_2$ has a lower prior but a significantly larger exploration bonus, encouraging the algorithm to explore it further.

$c_{\mathrm{puct}}$ controls the balance between exploration and exploitation. It is set to 1 in AlphaZero.

**Why $\sqrt{N(s)}$ in the exploration term U(s,a)?**

Example with varying N(s):

Case 1: $N(s) = 10$, We have not explored the parent node much yet.
| Term | $a_1$ | $a_2$ |
|------|-------|-------|
| P(s,a) | 0.8   | 0.2   |
| N(s)  | 10    | 10    |
| N(s,a) | 10   | 10    |

- For $a_1$:
  $$U(s,a_1) = 1.0 \cdot 0.8 \cdot \frac{\sqrt{10}}{1 + 10} \approx 0.2162$$
- For $a_2$:
  $$U(s,a_2) = 1.0 \cdot 0.2 \cdot \frac{\sqrt{10}}{1 + 10} \approx 0.0541$$

Case 2: $N(s) = 1000$, We have explored the parent node extensively.
| Term | $a_1$ | $a_2$ |
|------|-------|-------|
| P(s,a) | 0.8   | 0.2   |
| N(s)  | 1000   | 1000   |
| N(s,a) | 10   | 10    |

- For $a_1$:
  $$U(s,a_1) = 1.0 \cdot 0.8 \cdot \frac{\sqrt{1000}}{1 + 10} \approx 2.3094$$
- For $a_2$:
  $$U(s,a_2) = 1.0 \cdot 0.2 \cdot \frac{\sqrt{1000}}{1 + 10} \approx 0.5774$$


The $\sqrt{N(s)}$ term ensures exploration pressure increases smoothly with overall search effort

<br>**Compare Action Selection Interactively**

In [4]:
import ipywidgets as widgets
from IPython.display import display, HTML
import math


def draw_puct_interactive(c_puct, p1, n1, v1, p2, n2, v2, extra_parent_visits):
    # Calculate Derived Values
    N_parent = n1 + n2 + extra_parent_visits

    # Action 1 Calculations
    q1 = v1 / (1 + n1)
    u1 = c_puct * p1 * (math.sqrt(N_parent)) / (1 + n1)
    score1 = q1 + u1

    # Action 2 Calculations
    q2 = v2 / (1 + n2)
    u2 = c_puct * p2 * (math.sqrt(N_parent)) / (1 + n2)
    score2 = q2 + u2

    # Visualization Config
    width = 800
    height = 450
    padding = 80
    bar_width = 120
    scale_y = 90
    zero_y = height / 2 + 30

    # Color palette matching theme
    col_q = "#0969da"
    col_u = "#DA7A09"
    col_selected = "#10b981"
    col_bg = "transparent"
    col_border = "#9b9ea7"
    col_text = "#4D5461"
    col_text_lite = "#9b9ea7"

    svg = f'<svg width="{width}" height="{height}" style="font-family: var(--primary-font);">'

    # Draw axis
    svg += f'<line x1="{padding}" y1="{zero_y}" x2="{width - padding}" y2="{zero_y}" stroke="{col_text_lite}" stroke-width="1.5" />'
    svg += f'<text x="{width - padding}" y="{zero_y - 10}" text-anchor="end" font-size="11" fill="{col_text_lite}">baseline (0)</text>'

    def draw_bar(x_center, q, u, score, label, n_val, p_val, is_winner):
        h_q = q * scale_y
        y_q = zero_y - h_q if q >= 0 else zero_y
        h_q_abs = abs(h_q)

        h_u = u * scale_y
        y_u = (zero_y - h_q) - h_u

        svg_content = ""

        # Draw bars
        if h_q_abs > 0:
            svg_content += f'<rect x="{x_center - bar_width / 2}" y="{y_q}" width="{bar_width}" height="{h_q_abs}" fill="{col_q}" />'
            if h_q_abs > 25:
                svg_content += f'<text x="{x_center}" y="{y_q + h_q_abs / 2 + 5}" text-anchor="middle" font-size="12" font-weight="500" fill="white">Q: {q:.2f}</text>'

        if h_u > 0:
            svg_content += f'<rect x="{x_center - bar_width / 2}" y="{y_u}" width="{bar_width}" height="{h_u}" fill="{col_u}" />'
            if h_u > 25:
                svg_content += f'<text x="{x_center}" y="{y_u + h_u / 2 + 5}" text-anchor="middle" font-size="12" font-weight="500" fill="white">U: {u:.2f}</text>'

        # Score badge
        badge_y = min(y_u, y_q) - 30
        badge_color = col_selected if is_winner else col_text_lite
        svg_content += f'<rect x="{x_center - 40}" y="{badge_y}" width="80" height="24" fill="{badge_color}" rx="12" />'
        svg_content += f'<text x="{x_center}" y="{badge_y + 16}" text-anchor="middle" font-size="14" font-weight="600" fill="white">{score:.3f}</text>'

        # Winner checkmark
        if is_winner:
            svg_content += f'<circle cx="{x_center}" cy="{badge_y - 18}" r="10" fill="{col_selected}" />'
            svg_content += f'<path d="M {x_center - 3} {badge_y - 18} L {x_center - 1} {badge_y - 15} L {x_center + 4} {badge_y - 22}" stroke="white" stroke-width="2" stroke-linecap="round" fill="none" />'

        # Bottom info
        card_y = height - 80

        svg_content += f'<text x="{x_center}" y="{card_y + 15}" text-anchor="middle" font-size="15" font-weight="600" fill="{col_text}">{label}</text>'
        svg_content += f'<text x="{x_center}" y="{card_y + 35}" text-anchor="middle" font-size="11" fill="{col_text_lite}">P = {p_val:.2f}</text>'
        svg_content += f'<text x="{x_center}" y="{card_y + 50}" text-anchor="middle" font-size="11" fill="{col_text_lite}">N = {n_val}</text>'

        return svg_content

    # Render Bars
    is_winner_1 = score1 >= score2
    svg += draw_bar(width * 0.3, q1, u1, score1, "Action 1", n1, p1, is_winner_1)
    svg += draw_bar(width * 0.7, q2, u2, score2, "Action 2", n2, p2, not is_winner_1)

    # Legend
    svg += f'''
    <g transform="translate(30, 30)">
        <rect x="0" y="0" width="160" height="75" fill="transparent" stroke="{col_border}" stroke-width="0" rx="4"/>
        <rect x="12" y="28" width="16" height="12" fill="{col_q}"/>
        <text x="32" y="37" font-size="11" fill="{col_text}">Q (Exploitation)</text>
        <rect x="12" y="48" width="16" height="12" fill="{col_u}"/>
        <text x="32" y="57" font-size="11" fill="{col_text}">U (Exploration)</text>
    </g>
    '''

    # Parent visits info
    svg += f'<text x="{width / 2}" y="25" text-anchor="middle" font-size="12" fill="{col_text_lite}">Total Parent Visits: <tspan font-weight="600" fill="{col_text}">{N_parent}</tspan></text>'

    svg += "</svg>"
    display(HTML(svg))


# Widget Styling
style = {"description_width": "150px"}
layout_wide = widgets.Layout(width="500px")
layout_normal = widgets.Layout(width="450px")

w_c_puct = widgets.FloatSlider(
    value=1.0,
    min=0.1,
    max=5.0,
    step=0.1,
    description="C_PUCT:",
    style=style,
    layout=layout_wide,
    readout_format=".2f",
)

# Action 1
w_p1 = widgets.FloatSlider(
    value=0.8,
    min=0.0,
    max=1.0,
    step=0.05,
    description="Prior P(s,a₁):",
    style=style,
    layout=layout_normal,
    readout_format=".2f",
)
w_n1 = widgets.IntSlider(
    value=80,
    min=0,
    max=200,
    step=1,
    description="Visits N(s,a₁):",
    style=style,
    layout=layout_normal,
)
w_v1 = widgets.FloatSlider(
    value=20.0,
    min=-50.0,
    max=50.0,
    step=1.0,
    description="Value W(s,a₁):",
    style=style,
    layout=layout_normal,
    readout_format=".1f",
)

# Action 2
w_p2 = widgets.FloatSlider(
    value=0.2,
    min=0.0,
    max=1.0,
    step=0.05,
    description="Prior P(s,a₂):",
    style=style,
    layout=layout_normal,
    readout_format=".2f",
)
w_n2 = widgets.IntSlider(
    value=5,
    min=0,
    max=200,
    step=1,
    description="Visits N(s,a₂):",
    style=style,
    layout=layout_normal,
)
w_v2 = widgets.FloatSlider(
    value=1.0,
    min=-50.0,
    max=50.0,
    step=1.0,
    description="Value W(s,a₂):",
    style=style,
    layout=layout_normal,
    readout_format=".1f",
)

w_extra = widgets.IntSlider(
    value=15,
    min=0,
    max=1000,
    description="Extra Parent Visits:",
    style=style,
    layout=layout_wide,
)

ui = widgets.VBox(
    [
        widgets.HTML(
            "<h3 style='margin: 0 0 5px 0; color: var(--text); font-family: var(--primary-font);'>PUCT Selection Visualizer</h3>"
        ),
        widgets.HTML(
            "<p style='margin: 0 0 15px 0; color: var(--text-lite); font-size: 13px; font-family: var(--primary-font);'>Explore how exploitation (Q) and exploration (U) balance action selection</p>"
        ),
        w_c_puct,
        w_extra,
        widgets.HBox(
            [
                widgets.VBox(
                    [
                        widgets.HTML(
                            "<div style='padding: 8px; background: transparent;font-family: var(--primary-font);'><strong style='color: var(--text);'>Action 1</strong></div>"
                        ),
                        w_p1,
                        w_n1,
                        w_v1,
                    ],
                    layout=widgets.Layout(
                        padding="4px",
                        width="auto",
                    ),
                ),
                widgets.VBox(
                    [
                        widgets.HTML(
                            "<div style='padding: 8px; background: transparent;font-family: var(--primary-font);'><strong style='color: var(--text);'>Action 2</strong></div>"
                        ),
                        w_p2,
                        w_n2,
                        w_v2,
                    ],
                    layout=widgets.Layout(
                        padding="4px",
                        width="auto",
                    ),
                ),
            ],
            layout=widgets.Layout(justify_content="space-between"),
        ),
        widgets.HTML("<div style='height: 4px;'></div>"),
    ],
    layout=widgets.Layout(),
)

out = widgets.interactive_output(
    draw_puct_interactive,
    {
        "c_puct": w_c_puct,
        "p1": w_p1,
        "n1": w_n1,
        "v1": w_v1,
        "p2": w_p2,
        "n2": w_n2,
        "v2": w_v2,
        "extra_parent_visits": w_extra,
    },
)

display(widgets.VBox([ui, out], layout=widgets.Layout(align_items="center")))

VBox(children=(VBox(children=(HTML(value="<h3 style='margin: 0 0 5px 0; color: var(--text); font-family: var(-…

---

## Expand & Evaluate

Once the Selection phase reaches a leaf node (a node at the edge of the current search tree), we check if the node is Terminal or Non-Terminal.

Terminal State is a game state where the game has ended, meaning there are no further moves or actions possible.
- If it is terminal, no expansion is needed. We assign it a value based on the game outcome (win/loss/draw).
- If it is non-terminal, we expand the node by adding child nodes for each legal action from that state.

If the game is not over, the node is expanded. This involves identifying all legal actions available from the current state and creating child nodes for them.
- The current leaf node becomes a parent.
- New child nodes are added to the tree, representing every possible move a player could make from that position.


The neural network, denoted as $f_{\theta}$, evaluates the current game state $s$ and outputs a tuple containing two pieces of inference:$$(\mathbf{p}, v) = f_{\theta}(s)$$

**Policy Vector** ($\mathbf{p}$): A probability distribution over all possible moves. This represents the "Prior" probability—how good the network thinks a move is based on its training, before doing any search. This value is stored in the newly created child nodes as prior.

**Value Scalar** ($\mathbf{v}$): A single number (usually between -1 and 1) estimating the expected outcome of the game from the current state $s$ (e.g., how likely the current player is to win).This value is sent backwards up the tree during the Backup phase to update the parent nodes' `value_sum` and `visit_count`.

We will look more deeply into the network in the next notebook.

*The input for the neural network is the game state, but it could also contain the game history, the player to move, etc.* 


```mermaid
---
config:
  theme: 'base'
  themeVariables:
    primaryColor: '#ffffff00'
    primaryTextColor: '#4D5461'
    primaryBorderColor: '#096bda'
    lineColor: '#4D5461'
    secondaryColor: '#006100'
    tertiaryColor: '#4D5461'
---
graph LR
    A[State] --> B[Neural Network]
    B --> C[Policy]
    B --> D[Value]
```

```mermaid
---
config:
  theme: 'base'
  themeVariables:
    primaryColor: '#e6f4ff'
    primaryTextColor: '#4D5461'
    primaryBorderColor: '#0969da'
    lineColor: '#9aa0a6'
    secondaryColor: '#fbfbfb7f'
    tertiaryColor: '#ffffff'
---
flowchart LR
    %% --- Nodes ---
    
    Start([Select Leaf Node]) --> Check{Is Terminal?}

    %% Branch 1: Game Over (Terminal)
    Check -- Yes --> Result[Get Game Result]
    Result -->|v = Outcome| Backup([Backup Phase])

    %% Branch 2: Expand & Evaluate (Non-Terminal)
    Check -- No --> Expand[Expand: Create Children]
    Expand --> NN[[Evaluate: Neural Network]]

    %% The Neural Network "Split"
    NN -->|Policy p| Updates[Initialize Child Priors]
    NN -->|Value v| Backup

    %% Loop Indication
    Backup -.->|Update Stats &<br>Next Sim| Start

    %% --- Styling ---
    
    %% Define Classes based on AlphaZero functions
    classDef flow fill:#fff,stroke:#4D5461,stroke-width:1px,color:#4D5461;
    classDef treeOp fill:#e6f4ff,stroke:#0969da,stroke-width:2px,color:#0969da;
    classDef brain fill:#fff0e0,stroke:#DA7A09,stroke-width:2px,color:#DA7A09;
    classDef term fill:#f0f0f0,stroke:#4D5461,stroke-width:2px,stroke-dasharray: 5 5;

    %% Apply Classes
    class Start,Backup treeOp;
    class Expand,Updates,Result flow;
    class Check term;
    class NN brain;
```

*in the following example, we have selected the action 0 (top left cell) for player O and then action 2(top right cell) for player X. The leaf node is not terminal, so we continue to expand it.*

In [5]:
# Root state: O's turn to play
root_state = [
    0,
    1,
    0,  # Row 1: _, X, _
    0,
    1,
    -1,  # Row 2: _, X, O
    0,
    -1,
    1,
]  # Row 3: _, O, X

# Find all possible actions (empty cells)
possible_actions = [i for i in range(9) if root_state[i] == 0]

# Create child states for each possible action
child_states = []
for action in possible_actions:
    new_state = root_state.copy()
    new_state[action] = -1  # O plays (player -1)
    child_states.append({"action": action, "state": new_state})

# Add grandchildren to action 0 (second level of the tree)
action_0_state = child_states[0]["state"]
action_0_possible_moves = [i for i in range(9) if action_0_state[i] == 0]

grandchildren = []
for action in action_0_possible_moves:
    new_state = action_0_state.copy()
    new_state[action] = 1  # X plays (player 1)
    grandchildren.append({"action": action, "state": new_state})

# Attach grandchildren to the first child (action 0)
child_states[0]["children"] = grandchildren

# Create root node
root_node = {"state": root_state, "children": child_states}
# Highlight the first grandchild as the current selection (if it exists)
selected_grandchild = grandchildren[0] if grandchildren else None

# First tree: custom layout
tree1_html = draw_game_tree(
    root_node,
    board_size=100,
    board_padding=12,
    horizontal_spacing=100,
    vertical_spacing=200,
    selected_node=selected_grandchild,
)

# great-grandchildren
great_grandchildren = []
action_1_state = child_states[0]["children"][0]["state"]
action_1_possible_moves = [i for i in range(9) if action_1_state[i] == 0]
for action in action_1_possible_moves:
    new_state = action_1_state.copy()
    new_state[action] = 1  # X plays (player 1)
    great_grandchildren.append({"action": action, "state": new_state})

# Attach great-grandchildren to the first child (action 0)
child_states[0]["children"][0]["children"] = great_grandchildren

# Second tree: default layout with a selected node highlighted
tree2_html = draw_game_tree(
    root_node,
    board_size=100,
    board_padding=12,
    horizontal_spacing=100,
    vertical_spacing=200,
)

# side by side display
combined_html = f"""
<div style="display: flex; justify-content: space-around; align-items: center;">
    <div style="flex: 1; margin: 10px; margin-right: 30px;align-self: start;">
        {tree1_html}
    </div>
    <div style="font-size: 78px; color: #666;">
        &rarr;
    </div>
    <div style="flex: 1; margin: 10px;">
        {tree2_html}
    </div>
</div>
"""

display(HTML(combined_html))


## Code

In [6]:
def score(board):
    # Check rows, columns, diagonals
    lines = [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],  # rows
        [0, 3, 6],
        [1, 4, 7],
        [2, 5, 8],  # columns
        [0, 4, 8],
        [2, 4, 6],  # diagonals
    ]

    for line in lines:
        if board[line[0]] == board[line[1]] == board[line[2]] != 0:
            return board[line[0]]

    return 0  # Draw or game not over

In [7]:
# Nodes of tree


class Node:
    def __init__(self, state, parent=None, action=None, prior=0):
        # Array of 9 elements representing the board (0 = empty, 1 = X, -1 = O)
        self.state = state  # same as TicTacToeGame.board
        # Reference to the parent node
        self.parent = parent
        # Action taken to reach this node from parent
        self.action = action
        # Dictionary mapping actions to child nodes
        self.children = {}
        # Number of times the node has been visited
        self.visit_count = 0
        # Sum of the values of the node
        self.value_sum = 0.0
        # Prior probability of selecting this node
        self.prior = prior

    def terminal(self):
        # Check if the game is over (win/loss/draw)
        return score(self.state) != 0 or all(cell != 0 for cell in self.state)

    def puct_value(self, child, c_puct=1.0):
        """
        Calculate the PUCT value for a child node based on its visit count,
        value sum, and prior probability.
        """
        N_parent = self.visit_count
        n = child.visit_count

        # Q: exploitation
        Q = child.value_sum / (1 + n)

        # P: prior from policy head (already stored on the child)
        P = child.prior

        # U: exploration
        U = c_puct * P * (N_parent**0.5) / (1 + n)

        return Q + U

    def best_child(self, c_puct=1.0):
        """
        Select the child with the highest PUCT value.
        """
        best_action = None
        best_value = float("-inf")

        for action, child in self.children.items():
            puct = self.puct_value(child, c_puct)
            if puct > best_value:
                best_value = puct
                best_action = action

        return best_action

    def expand(self, action_probs):
        """
        Expand the node by adding all possible child nodes.
        """
        if self.terminal():
            return

        # Find all possible/legal actions (empty cells)
        possible_actions = [i for i in range(9) if self.state[i] == 0]

        for action in possible_actions:
            new_state = self.state.copy()
            new_state[action] = (
                1 if sum(1 for x in new_state if x != 0) % 2 == 0 else -1
            )
            prior = action_probs.get(action, 0)
            self.children[action] = Node(
                new_state, parent=self, action=action, prior=prior
            )

    def backup(self, value):
        """
        Backpropagate the value up the tree, updating visit counts and value sums.
        """
        self.visit_count += 1
        self.value_sum += value
        if self.parent:
            self.parent.backup(-value)  # Switch perspective for the parent

    def viz_node(self, size=200, padding=20):
        """
        Visualize the node's board state as SVG.
        """
        return draw_tic_tac_toe(self.state, size=size, padding=padding)


In [8]:
import random


class TreeSearch:
    def __init__(self, root_state):
        self.root = Node(root_state)

    def search(self, num_simulations=100, c_puct=1.0):
        for _ in range(num_simulations):
            node = self.root
            # Selection
            while node.children:
                action = node.best_child(c_puct)
                node = node.children[action]

            # Expansion
            if not node.terminal():
                # Dummy uniform probabilities for expansion
                possible_actions = [i for i in range(9) if node.state[i] == 0]
                action_probs = {
                    action: 1 / len(possible_actions) for action in possible_actions
                }
                node.expand(action_probs)

                # Choose one child to simulate
                action = node.best_child(c_puct)
                node = node.children[action]

            # Evaluation
            policy, value = self.evaluate(node.state)

            # Backup
            node.backup(value)

    def evaluate(self, state):
        """
        Dummy evaluation function that returns uniform policy and random value.
        In a real scenario, this would use a neural network.
        """
        possible_actions = [i for i in range(9) if state[i] == 0]
        if not possible_actions:
            return {}, 0.0  # No possible actions, game over

        policy = {action: 1 / len(possible_actions) for action in possible_actions}
        value = random.uniform(-1, 1)  # Random value between -1 and 1
        return policy, value

    def viz_tree(self, selected_node=None):
        """
        Visualize the entire tree starting from the root node.
        """

        # Helper functions for Node/dict compatibility
        def get_children(node):
            if isinstance(node, Node):
                return list(node.children.values()) if node.children else []
            return node.get("children", [])

        def get_state(node):
            if isinstance(node, Node):
                return node.state
            return node.get("state", [])

        def get_action(node):
            if isinstance(node, Node):
                return node.action
            return node.get("action")

        def get_visit_count(node):
            if isinstance(node, Node):
                return node.visit_count
            return 0

        def get_value_sum(node):
            if isinstance(node, Node):
                return node.value_sum
            return 0.0

        def draw_game_tree(
            root_node,
            board_size=120,
            board_padding=12,
            horizontal_spacing=160,
            vertical_spacing=250,
            selected_node=None,
            interactive=False,
        ):
            """
            Draw a game tree with root state and its children.
            Supports recursive tree structure with optional interactivity.

            Args:
                root_node: Node object or dictionary with 'state', 'children', and optional 'action'
                board_size: Size of each board
                board_padding: Padding for each board
                horizontal_spacing: Horizontal spacing between nodes
                vertical_spacing: Vertical spacing between levels
                selected_node: Optional node to highlight with a circle
                interactive: If True, show depth-1 nodes sideways with hover-to-reveal subtrees

            Returns:
                HTML string with the complete game tree visualization
            """

            if interactive:
                # Horizontal layout with hover interactivity
                from IPython.display import HTML

                # Generate subtree HTML for each depth-1 child
                def generate_subtree_html(child_node):
                    """Generate HTML for a child's subtree"""
                    children = get_children(child_node)
                    if not children:
                        return "<div style='padding:20px;color:#999;'>No subtree</div>"

                    # Create a mini tree for this subtree
                    subtree_html = '<div style="display:flex;gap:20px;padding:20px;flex-wrap:wrap;">'
                    for grandchild in children:
                        gc_state = get_state(grandchild)
                        gc_action = get_action(grandchild)
                        gc_board = draw_tic_tac_toe(gc_state, size=80, padding=8)

                        gc_visits = get_visit_count(grandchild)
                        gc_value = get_value_sum(grandchild)

                        subtree_html += f"""
                        <div style="display:flex;flex-direction:column;align-items:center;gap:8px;">
                            <div style="font-size:11px;color:#666;">Action: {gc_action}</div>
                            {gc_board}
                            <div style="font-size:10px;color:#999;">V:{gc_visits} W:{gc_value:.1f}</div>
                        </div>
                        """
                    subtree_html += "</div>"
                    return subtree_html

                # Main layout
                root_state = get_state(root_node)
                root_board = draw_tic_tac_toe(
                    root_state, size=board_size, padding=board_padding
                )

                children = get_children(root_node)

                html = """
                <style>
                .interactive-tree-container {
                    display: flex;
                    gap: 40px;
                    align-items: center;
                    padding: 20px;
                    font-family: var(--primary-font, -apple-system, system-ui);
                }
                .root-node {
                    flex-shrink: 0;
                }
                .children-list {
                    display: flex;
                    flex-direction: column;
                    gap: 16px;
                    flex: 1;
                }
                .child-item {
                    position: relative;
                    display: flex;
                    align-items: center;
                    gap: 12px;
                    padding: 12px;
                    border: 2px solid #e0e0e0;
                    border-radius: 8px;
                    background: #fafafa;
                    cursor: pointer;
                    transition: all 0.2s ease;
                }
                .child-item:hover {
                    background: #f0f0f0;
                    border-color: #0969da;
                    transform: translateX(4px);
                }
                .child-board {
                    flex-shrink: 0;
                }
                .child-info {
                    display: flex;
                    flex-direction: column;
                    gap: 4px;
                    flex: 1;
                }
                .child-action {
                    font-weight: 600;
                    color: #333;
                    font-size: 14px;
                }
                .child-stats {
                    font-size: 11px;
                    color: #666;
                }
                .subtree-popup {
                    position: fixed;
                    display: none;
                    background: white;
                    border: 2px solid #0969da;
                    border-radius: 8px;
                    box-shadow: 0 8px 24px rgba(0,0,0,0.15);
                    z-index: 1000;
                    max-width: 600px;
                    max-height: 400px;
                    overflow: auto;
                }
                .child-item:hover .subtree-popup {
                    display: block;
                }
                .arrow-connector {
                    width: 30px;
                    height: 2px;
                    background: #999;
                    position: relative;
                }
                .arrow-connector::after {
                    content: '';
                    position: absolute;
                    right: -6px;
                    top: -4px;
                    width: 0;
                    height: 0;
                    border-left: 6px solid #999;
                    border-top: 5px solid transparent;
                    border-bottom: 5px solid transparent;
                }
                </style>
                """

                html += '<div class="interactive-tree-container">'

                # Root node
                html += f"""
                <div class="root-node">
                    <div style="text-align:center;margin-bottom:8px;font-size:12px;color:#666;font-weight:600;">Root</div>
                    {root_board}
                </div>
                """

                # Arrow
                html += '<div class="arrow-connector"></div>'

                # Children list
                html += '<div class="children-list">'

                for i, child in enumerate(children):
                    child_state = get_state(child)
                    child_action = get_action(child)
                    child_board = draw_tic_tac_toe(child_state, size=80, padding=8)
                    child_visits = get_visit_count(child)
                    child_value = get_value_sum(child)

                    subtree_html = generate_subtree_html(child)

                    html += f"""
                    <div class="child-item" id="child-{i}">
                        <div class="child-board">{child_board}</div>
                        <div class="child-info">
                            <div class="child-action">Action: {child_action}</div>
                            <div class="child-stats">Visits: {child_visits} | Value: {child_value:.2f}</div>
                        </div>
                        <div class="subtree-popup" id="popup-{i}" style="left: calc(100% + 20px); top: 0;">
                            <div style="padding:12px;border-bottom:1px solid #e0e0e0;background:#f8f9fa;font-weight:600;color:#333;">
                                Subtree for Action {child_action}
                            </div>
                            {subtree_html}
                        </div>
                    </div>
                    """

                html += "</div>"  # children-list
                html += "</div>"  # interactive-tree-container

                return html

            else:
                # Original vertical tree implementation
                def get_subtree_width(node):
                    children = get_children(node)
                    if not children:
                        return horizontal_spacing
                    return sum(get_subtree_width(child) for child in children)

                positions = {}

                def assign_positions(node, x, y):
                    positions[id(node)] = (x, y)
                    children = get_children(node)
                    if not children:
                        return

                    total_width = sum(get_subtree_width(child) for child in children)
                    start_x = x - total_width / 2

                    current_x = start_x
                    for child in children:
                        child_width = get_subtree_width(child)
                        assign_positions(
                            child, current_x + child_width / 2, y + vertical_spacing
                        )
                        current_x += child_width

                tree_width = get_subtree_width(root_node) + 20

                def get_depth(node):
                    children = get_children(node)
                    if not children:
                        return 1
                    return 1 + max(get_depth(child) for child in children)

                tree_height = (
                    get_depth(root_node) * vertical_spacing + board_size
                ) - 200

                root_x = tree_width / 2
                root_y = 40
                assign_positions(root_node, root_x, root_y)

                html = f'<div style="width:{tree_width}px;height:{tree_height}px;position:relative;margin:20px auto;">'
                svg_connections = f'<svg width="{tree_width}" height="{tree_height}" style="position:absolute;top:0;left:0;pointer-events:none;">'
                svg_connections += '<defs><marker id="arrow" viewBox="0 0 10 10" refX="6" refY="5" markerWidth="6" markerHeight="6" orient="auto"><path d="M 0 0 L 10 5 L 0 10 z" fill="#999"></path></marker></defs>'

                nodes_html = ""

                def draw_node_recursive(node, depth=0):
                    nonlocal svg_connections, nodes_html
                    x, y = positions[id(node)]

                    node_state = get_state(node)
                    node_svg = draw_tic_tac_toe(
                        node_state, size=board_size, padding=board_padding
                    )
                    is_selected = selected_node is node

                    if is_selected:
                        highlight_padding = 0
                        highlight_size = board_size + highlight_padding * 2
                        highlight_left = x - highlight_size / 2
                        highlight_top = y - highlight_padding
                        nodes_html += f'<div style="position:absolute;left:{highlight_left}px;top:{highlight_top}px;width:{highlight_size}px;height:{highlight_size}px;border-radius:50%;box-sizing:border-box;pointer-events:none;z-index:1;border:2px solid var(--text-lite);"></div>'

                    nodes_html += f'<div class="" style="position:absolute;left:{x - board_size / 2}px;top:{y}px;z-index:2;">{node_svg}</div>'

                    children = get_children(node)
                    if children:
                        for child in children:
                            cx, cy = positions[id(child)]

                            start_y = y + board_size
                            end_y = cy

                            svg_connections += f'<line x1="{x}" y1="{start_y}" x2="{cx}" y2="{end_y}" stroke="#999" stroke-width="2" marker-end="url(#arrow)"  ></line>'

                            child_action = get_action(child)
                            if child_action is not None:
                                label_x = x + (cx - x) * 0.4
                                label_y = start_y + (end_y - start_y) * 0.4

                                text = f"{child_action}"
                                text_width = 28

                                svg_connections += f'<rect x="{label_x - text_width / 2}" y="{label_y - 15}" width="{text_width}" height="{text_width}" fill="#999" fill-opacity="1" rx="20" ry="20" />'
                                svg_connections += f'<text x="{label_x}" y="{label_y - 10 + text_width / 2}" text-anchor="middle" fill="#fff" font-size="16" font-family="Noto Sans" >{text}</text>'

                            draw_node_recursive(child, depth + 1)

                draw_node_recursive(root_node)

                svg_connections += "</svg>"
                html += svg_connections
                html += nodes_html
                html += "</div>"

                return html

        return draw_game_tree(
            self.root,
            board_size=120,
            board_padding=12,
            horizontal_spacing=160,
            vertical_spacing=250,
            selected_node=selected_node,
        )

In [12]:
# test the TreeSearch and Node classes
initial_state = [
    0,
    1,
    0,  # Row 1: _, X, _
    0,
    1,
    -1,  # Row 2: _, X, O
    0,
    -1,
    1,
]  # Row 3: _, O, X

tree_search = TreeSearch(initial_state)
tree_search.search(num_simulations=1000, c_puct=1.0)

# Visualize the search tree
display(HTML(tree_search.root.viz_node()))

In [13]:
for child_action, child_node in tree_search.root.children.items():
    print(
        f"Action: {child_action}, Visits: {child_node.visit_count}, Value Sum: {child_node.value_sum:.2f}"
    )

Action: 0, Visits: 269, Value Sum: 0.87
Action: 2, Visits: 466, Value Sum: 7.33
Action: 3, Visits: 84, Value Sum: -5.44
Action: 6, Visits: 181, Value Sum: -2.40


In [14]:
tree_search_html = tree_search.viz_tree()
display(HTML(tree_search_html))