# αMCTS

Game trees are a way to represent game states and actions. Each node in the tree represents a game state, and each edge represents a possible action that can be taken from that state.

The aim is, given a state of the game, to build a tree of possible future states by exploring possible actions. This tree can then be used to decide which action to take in the current state.
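To make the idea concrete, here is a small sketch that builds the *entire* game tree of Tic-Tac-Toe by trying every action from every state, and counts the finished games (the leaves). It assumes the board encoding used later in this notebook (a list of 9 cells, 0 = empty, 1 = X, -1 = O, X moves first); `LINES`, `winner` and `count_games` are illustrative helpers, not part of the original code.

```python
# Winning lines on a 3x3 board, cells indexed 0..8 row by row
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
         (0, 3, 6), (1, 4, 7), (2, 5, 8),
         (0, 4, 8), (2, 4, 6)]


def winner(board):
    """Return 1 (X) or -1 (O) if someone has three in a row, else 0."""
    for a, b, c in LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0


def count_games(board, player):
    """Expand every action from `board` recursively and count finished games (leaves)."""
    if winner(board) != 0 or 0 not in board:
        return 1  # terminal state: a leaf of the game tree
    total = 0
    for action in range(9):
        if board[action] == 0:
            child = board.copy()
            child[action] = player  # edge: `player` marks cell `action`
            total += count_games(child, -player)
    return total


print(count_games([0] * 9, 1))  # X moves first; prints 255168
```

Tic-Tac-Toe is small enough to enumerate like this; in games such as chess or Go the full tree is astronomically large, which is why methods like MCTS only grow the most promising parts of it.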

Monte Carlo Tree Search (MCTS) is a method for building game trees by simulating random games from the current state. The basic idea is to use random simulations to estimate the value of each action, and then use these estimates to guide the search for the best action.
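As a rough sketch of the "random simulations" idea, the value of each action can be estimated by averaging the outcomes of uniformly random playouts. This is flat Monte Carlo evaluation, a simplification: full MCTS also grows a tree and uses these estimates to steer later simulations, which this version omits. The helpers `random_rollout` and `estimate_action_values`, the rollout count and the seed are illustrative assumptions.

```python
import random

# Winning lines on a 3x3 board, cells indexed 0..8 row by row
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
         (0, 3, 6), (1, 4, 7), (2, 5, 8),
         (0, 4, 8), (2, 4, 6)]


def winner(board):
    """Return 1 (X) or -1 (O) for three in a row, else 0."""
    for a, b, c in LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0


def random_rollout(board, player, rng):
    """Play uniformly random moves until the game ends; return the winner (0 = draw)."""
    board = board.copy()
    while winner(board) == 0 and 0 in board:
        action = rng.choice([i for i in range(9) if board[i] == 0])
        board[action] = player
        player = -player
    return winner(board)


def estimate_action_values(board, player, n_rollouts=200, seed=0):
    """Average rollout outcome of each legal action, from `player`'s point of view."""
    rng = random.Random(seed)
    values = {}
    for action in [i for i in range(9) if board[i] == 0]:
        child = board.copy()
        child[action] = player
        # Multiplying by `player` makes +1 mean "good for the player choosing the action"
        total = sum(player * random_rollout(child, -player, rng) for _ in range(n_rollouts))
        values[action] = total / n_rollouts
    return values


demo = [0, 1, 0, 0, 1, -1, 0, -1, 1]  # O (-1) to move
values = estimate_action_values(demo, player=-1)
print(values)
```

On the example board used below (O to move), action 0 should come out on top because it blocks X's diagonal, while action 3 loses in every playout and scores exactly -1.0.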

AlphaZero uses a modified version of MCTS, referred to here as αMCTS. The key difference is that instead of running random simulations, αMCTS uses a neural network to estimate the value of each state and the probability of each action. This lets αMCTS make more informed decisions about which actions to explore.

αMCTS has three main steps:
1. Selection: Starting from the root node, recursively select child nodes until a leaf node is reached. The selection balances exploration (trying less-visited nodes) and exploitation (choosing nodes with high value estimates).
2. Expansion: If the leaf node is not a terminal state, expand it by adding a child node for each legal action from the leaf.
3. Evaluation and Backpropagation: Instead of simulating a random game (as classic MCTS does), use the neural network to evaluate the value of the new node and the probabilities of each action; the probabilities become the priors of the new child nodes. If the leaf is terminal, the actual game result is used as its value instead. The value is then backpropagated up the tree to update the value estimates and visit counts of all nodes on the path from the root to the new node.
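The three steps fit together roughly as follows. This is a minimal, self-contained sketch under several assumptions not in the original: `uniform_policy_value` is a placeholder standing in for the trained networks (uniform priors, value 0), each node stores its value from the point of view of the player who moved *into* it (so the sign flips at every level during backpropagation), and `c_puct=1.5` and 400 simulations are arbitrary choices. The final move is the most-visited root action.

```python
import math

# Winning lines on a 3x3 board, cells indexed 0..8 row by row
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
         (0, 3, 6), (1, 4, 7), (2, 5, 8),
         (0, 4, 8), (2, 4, 6)]


def winner(board):
    for a, b, c in LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0


def uniform_policy_value(board, player):
    """Placeholder for the trained networks: uniform priors over empty cells, value 0."""
    legal = [i for i in range(9) if board[i] == 0]
    return {a: 1.0 / len(legal) for a in legal}, 0.0


class Node:
    def __init__(self, state, player, prior=0.0):
        self.state = state        # list of 9 ints: 0 empty, 1 X, -1 O
        self.player = player      # player to move in this state
        self.prior = prior        # probability the policy gave to the action into this node
        self.children = {}        # action -> Node
        self.visit_count = 0
        self.value_sum = 0.0      # from the view of the player who moved INTO this node


def puct(parent, child, c_puct):
    q = child.value_sum / (1 + child.visit_count)
    u = c_puct * child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
    return q + u


def run_simulation(root, evaluate, c_puct=1.5):
    # 1. Selection: follow the highest PUCT score down to a leaf
    node, path = root, [root]
    while node.children:
        parent = node
        node = max(parent.children.values(), key=lambda ch: puct(parent, ch, c_puct))
        path.append(node)

    # 2. Expansion (plus evaluation of the leaf)
    w = winner(node.state)
    if w != 0 or 0 not in node.state:
        value = w * node.player  # terminal: exact result for the player to move
    else:
        priors, value = evaluate(node.state, node.player)
        for action, p in priors.items():
            child_state = node.state.copy()
            child_state[action] = node.player
            node.children[action] = Node(child_state, -node.player, prior=p)

    # 3. Backpropagation: `value` is for the player to move at `node`;
    #    each node is credited from the view of the player who moved into it
    for n in reversed(path):
        n.visit_count += 1
        n.value_sum += -value
        value = -value  # switch perspective for the level above


def search(root_state, player, evaluate, n_simulations=400):
    root = Node(root_state, player)
    for _ in range(n_simulations):
        run_simulation(root, evaluate)
    # Choose the most-visited root action
    best = max(root.children, key=lambda a: root.children[a].visit_count)
    return best, root


demo = [0, 1, 0, 0, 1, -1, 0, -1, 1]  # O (-1) to move
best, root = search(demo, -1, uniform_policy_value)
print(best, {a: ch.visit_count for a, ch in root.children.items()})
```

Even with a policy that knows nothing, the terminal results found by the search should steer most visits toward action 0, the only move that does not lose. A trained network mainly makes this happen with far fewer simulations.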

Let's implement a simple version of αMCTS for 2D Tic-Tac-Toe!

Let's assume we already have good policy and value networks trained for Tic-Tac-Toe.
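Concretely, "having the networks" means having a function that maps a board to move probabilities and a value. A common pattern, assumed here rather than taken from the original, is a policy head that outputs one logit per cell, with occupied cells masked out before a softmax. `fake_network` is a hypothetical placeholder that treats every move as equally good; a trained model would replace it.

```python
import math


def masked_policy(logits, board):
    """Softmax over legal moves only; occupied cells get no probability."""
    legal = [i for i in range(9) if board[i] == 0]
    m = max(logits[i] for i in legal)  # subtract the max for numerical stability
    exps = {i: math.exp(logits[i] - m) for i in legal}
    z = sum(exps.values())
    return {i: e / z for i, e in exps.items()}


def fake_network(board):
    """Hypothetical stand-in for trained networks: (9 policy logits, value in [-1, 1])."""
    # A real network would compute these from the board.
    return [0.0] * 9, 0.0


demo = [0, 1, 0, 0, 1, -1, 0, -1, 1]
logits, value = fake_network(demo)
print(masked_policy(logits, demo), value)  # {0: 0.25, 2: 0.25, 3: 0.25, 6: 0.25} 0.0
```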

Let's start from an arbitrary mid-game state:


In [1]:
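def draw_tic_tac_toe(board, size=300, padding=20):
    """Return an SVG string drawing a 3x3 board (0 = empty, 1 = X, -1 = O), cells indexed row by row."""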
    cell_size = (size - padding * 2) / 3
    svg = f'<svg width="{size}" height="{size}" style="display:block;margin:12px 0;">'
    # Background
    svg += f'<rect x="{padding}" y="{padding}" width="{size - padding * 2}" height="{size - padding * 2}" fill="#fafafa00" stroke-width="0" rx="4"></rect>'
    # Grid lines
    for i in range(1, 3):
        # Vertical
        x = padding + i * cell_size
        svg += f'<line x1="{x}" y1="{padding}" x2="{x}" y2="{size - padding}" stroke="#666" stroke-width="3"></line>'
        # Horizontal
        y = padding + i * cell_size
        svg += f'<line x1="{padding}" y1="{y}" x2="{size - padding}" y2="{y}" stroke="#666" stroke-width="3"></line>'
    # Marks
    for i in range(9):
        row = i // 3
        col = i % 3
        cx = padding + col * cell_size + cell_size / 2
        cy = padding + row * cell_size + cell_size / 2
        val = board[i] if i < len(board) else 0
        if val == 1:
            size_x = cell_size * 0.5
            svg += f'<line x1="{cx - size_x / 2}" y1="{cy - size_x / 2}" x2="{cx + size_x / 2}" y2="{cy + size_x / 2}" stroke="#0969da" stroke-width="5" stroke-linecap="round"></line>'
            svg += f'<line x1="{cx + size_x / 2}" y1="{cy - size_x / 2}" x2="{cx - size_x / 2}" y2="{cy + size_x / 2}" stroke="#0969da" stroke-width="5" stroke-linecap="round"></line>'
        elif val == -1:
            svg += f'<circle cx="{cx}" cy="{cy}" r="{cell_size * 0.25}" fill="none" stroke="#DA7A09" stroke-width="5"></circle>'
    svg += "</svg>"
    return svg


from IPython.display import HTML, display

# demo: board sample (0 empty, 1 X, -1 O)
demo = [0, 1, 0, 0, 1, -1, 0, -1, 1]
display(HTML(draw_tic_tac_toe(demo, size=320, padding=20)))

It is O's turn to play (X has three marks, O has two), and there are 4 possible actions from this state: the empty cells 0, 2, 3 and 6.
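Both facts can be read off the board. Assuming X moves first (the usual convention, and consistent with this example), it is O's turn whenever X has one more mark than O, and the legal actions are the empty cells. `current_player` and `legal_actions` are helper names introduced here for illustration.

```python
def current_player(board):
    """X (1) moves first, so it is X's turn when both players have the same number of marks."""
    return 1 if board.count(1) == board.count(-1) else -1


def legal_actions(board):
    """Empty cells (value 0) are the legal actions."""
    return [i for i, v in enumerate(board) if v == 0]


demo = [0, 1, 0, 0, 1, -1, 0, -1, 1]
print(current_player(demo), legal_actions(demo))  # -1 [0, 2, 3, 6]
```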

In [None]:

# Root state: O's turn
root_state = [0, 1, 0, 0, 1, -1, 0, -1, 1]

# Calculate possible actions (empty cells)
possible_actions = [i for i in range(9) if root_state[i] == 0]

# Create child states
child_states = []
for action in possible_actions:
    new_state = root_state.copy()
    new_state[action] = -1  # O plays
    child_states.append({"action": action, "state": new_state})

# Layout parameters
board_size = 180
board_padding = 12
horizontal_spacing = 320
vertical_spacing = 350
tree_width = len(child_states) * horizontal_spacing
tree_height = vertical_spacing * 2 + board_size

# Build HTML
html = f'<div style="width:{tree_width}px;height:{tree_height}px;position:relative;margin:20px auto;">'

# SVG for connections
svg_connections = f'<svg width="{tree_width}" height="{tree_height}" style="position:absolute;top:0;left:0;">'

# defs for arrow
svg_connections += '<defs><marker id="arrow" viewBox="0 0 10 10" refX="6" refY="5" markerWidth="6" markerHeight="6" orient="auto"><path d="M 0 0 L 10 5 L 0 10 z" fill="#999"></path></marker></defs>'

root_x = tree_width / 2
root_y = 40

# Root board
root_svg = draw_tic_tac_toe(root_state, size=board_size, padding=board_padding)
html += f'<div style="position:absolute;left:{root_x - board_size / 2}px;top:{root_y}px;">{root_svg}</div>'

# Circle around root
svg_connections += f'<circle cx="{root_x}" cy="{root_y + board_size / 2}" r="{board_size / 2 + 10}" fill="none" stroke="#ccc" stroke-width="2"></circle>'

child_y = root_y + vertical_spacing
start_x = (tree_width - (len(child_states) - 1) * horizontal_spacing) / 2

# Draw children
for i, child in enumerate(child_states):
    child_x = start_x + i * horizontal_spacing

    # Draw curved connection path with arrow
    control_y = (root_y + board_size + child_y) / 2
    d = f"M {root_x} {root_y + board_size} Q {child_x} {control_y} {child_x} {child_y}"
    svg_connections += f'<path d="{d}" stroke="#999" stroke-width="2" fill="none" marker-end="url(#arrow)"></path>'

    # Draw child board
    child_svg = draw_tic_tac_toe(child["state"], size=board_size, padding=board_padding)
    html += f'<div style="position:absolute;left:{child_x - board_size / 2}px;top:{child_y}px;">{child_svg}</div>'

    # Add circle around child
    svg_connections += f'<circle cx="{child_x}" cy="{child_y + board_size / 2}" r="{board_size / 2 + 10}" fill="none" stroke="#ccc" stroke-width="2"></circle>'

svg_connections += "</svg>"
html += svg_connections
html += "</div>"

display(HTML(html))

In [3]:
# A node of the search tree
class Node:
    def __init__(self, state, parent=None, action=None, prior=0):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0.0
        self.prior = prior

**state**: List of 9 elements representing the board, indexed row by row (0 = empty, 1 = X, -1 = O)

**parent**: Reference to the parent node

**action**: The move that led from the parent to this node, as an integer from 0 to 8 (the cell where the mark was placed); `None` for the root

**children**: Dictionary mapping actions to child nodes

**visit_count**: Number of times the node has been visited

**value_sum**: Sum of all values backed up through this node; dividing by the visit count gives its average value

**prior**: Prior probability of the action leading to this node, as given by the policy network
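These fields combine into a few derived quantities. The sketch below repeats the `Node` class and adds two hypothetical convenience methods, `is_expanded` and `value`; neither is part of the original code, and the backed-up values are made up for illustration. Note that `value` here is the plain mean, whereas the selection formula below divides by `1 + visit_count`.

```python
class Node:
    def __init__(self, state, parent=None, action=None, prior=0):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0.0
        self.prior = prior

    def is_expanded(self):
        # A node is expanded once its children have been added
        return len(self.children) > 0

    def value(self):
        # Mean backed-up value; 0 for a node that has never been visited
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0


root = Node([0, 1, 0, 0, 1, -1, 0, -1, 1])
child = Node([-1, 1, 0, 0, 1, -1, 0, -1, 1], parent=root, action=0, prior=0.25)
root.children[0] = child
# Pretend three simulations passed through the child, backing up values 1, 0 and 0
for v in (1.0, 0.0, 0.0):
    child.visit_count += 1
    child.value_sum += v
print(root.is_expanded(), child.is_expanded(), round(child.value(), 3))  # True False 0.333
```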

## Selection

At each node, we select the child that maximizes the PUCT score (Predictor + Upper Confidence bounds applied to Trees).

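$$ \text{PUCT}(s, a) = Q(s, a) + U(s, a) $$

$$ Q(s, a) = \frac{W(s, a)}{1 + N(s, a)} $$

$$ U(s, a) = c_{\text{puct}} \cdot P(s, a) \cdot \frac{\sqrt{N(s)}}{1 + N(s, a)} $$

where $W(s, a)$ is the child's `value_sum`, $N(s, a)$ is the child's `visit_count`, $N(s)$ is the parent's `visit_count`, and $P(s, a)$ is the child's `prior`.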

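Q is (approximately) the average value of the child, the *exploitation* term; the 1 added to the denominator avoids dividing by zero for children that have never been visited, at the cost of pulling Q slightly toward 0 while visit counts are small. U is the *exploration* term: it is large for actions with a high prior and few visits, and it shrinks as the child is visited more often. $c_{\text{puct}}$ is a constant that controls the balance between exploration and exploitation.

The prior is the probability the policy network assigns to this action.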

PUCT derives from the Upper Confidence Bound (UCB) algorithm used in multi-armed bandit problems.
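A small worked example shows how the two terms trade off. The numbers below (a parent visited 10 times, two children A and B, and the two values of `c_puct`) are made up for illustration; `puct_score` simply evaluates the formulas above.

```python
import math


def puct_score(value_sum, visit_count, prior, parent_visits, c_puct):
    """PUCT = Q + U, with Q and U computed as in the formulas above."""
    q = value_sum / (1 + visit_count)
    u = c_puct * prior * math.sqrt(parent_visits) / (1 + visit_count)
    return q + u


# Two hypothetical children of a node that has been visited 10 times
children = {
    "A": dict(value_sum=3.0, visit_count=4, prior=0.2),  # well explored, good value
    "B": dict(value_sum=0.0, visit_count=0, prior=0.5),  # never tried, high prior
}

for c_puct in (0.1, 1.5):
    scores = {k: puct_score(parent_visits=10, c_puct=c_puct, **v) for k, v in children.items()}
    best = max(scores, key=scores.get)
    print(c_puct, {k: round(s, 3) for k, s in scores.items()}, "->", best)
```

With `c_puct = 0.1`, A wins (0.613 vs 0.158) because its Q of 0.6 dominates; with `c_puct = 1.5`, the untried, high-prior child B wins (2.372 vs 0.79) because its exploration bonus dominates.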

Let's break it down.

In [4]:
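# Add a selection method to the Node class


def select_child(self, c_puct):
    """Return (action, child) for the child with the highest PUCT score."""
    best_score = float("-inf")
    best_action = None
    best_child = None

    for action, child in self.children.items():
        # Exploitation: (smoothed) average value of the child
        q = child.value_sum / (1 + child.visit_count)
        # Exploration: large for high-prior, rarely visited children
        u = c_puct * child.prior * (self.visit_count**0.5) / (1 + child.visit_count)
        puct_score = q + u
        if puct_score > best_score:
            best_score = puct_score
            best_action = action
            best_child = child

    return best_action, best_child


# Attach the function to the class so it can be called as node.select_child(c_puct)
Node.select_child = select_child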
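As a quick check on the example board, the sketch below gives the root's four children hypothetical priors (favoring the block at cell 0, which a good policy network should prefer) and calls `select_child`. It repeats `Node` and `select_child` so that it runs on its own. One assumption worth stating: maximizing `child.value_sum` directly is only correct if each child's values are stored from the point of view of the player choosing at the parent (here O); otherwise the sign has to be flipped.

```python
class Node:
    def __init__(self, state, parent=None, action=None, prior=0):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0.0
        self.prior = prior


def select_child(self, c_puct):
    """Return (action, child) for the child with the highest PUCT score."""
    best_score, best_action, best_child = float("-inf"), None, None
    for action, child in self.children.items():
        q = child.value_sum / (1 + child.visit_count)
        u = c_puct * child.prior * self.visit_count**0.5 / (1 + child.visit_count)
        if q + u > best_score:
            best_score, best_action, best_child = q + u, action, child
    return best_action, best_child


Node.select_child = select_child

demo = [0, 1, 0, 0, 1, -1, 0, -1, 1]  # O (-1) to move
root = Node(demo)
# Hypothetical priors: this stand-in policy favors blocking X's diagonal at cell 0
priors = {0: 0.7, 2: 0.1, 3: 0.1, 6: 0.1}
for action, p in priors.items():
    state = demo.copy()
    state[action] = -1  # O plays
    root.children[action] = Node(state, parent=root, action=action, prior=p)

root.visit_count = 1  # the root has been visited once (when it was expanded)
first, _ = root.select_child(c_puct=1.5)
print(first)  # 0: every Q is 0, so the highest prior wins

# Pretend cell 0 has since been visited 20 times with an average value of 0
root.children[0].visit_count = 20
root.visit_count = 21
second, _ = root.select_child(c_puct=1.5)
print(second)  # 2: the unvisited children now have the larger exploration bonus
```

Ties (cells 2, 3 and 6 score the same in the second call) go to the first child in insertion order, because the comparison uses a strict `>`.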