Let's learn AlphaZero by using it to play 2D Tic-Tac-Toe!

# αMCTS

Game trees are a way to represent game states and actions. Each node in the tree represents a game state, and each edge represents a possible action that can be taken from that state.

Given a state of the game, the aim is to build a tree of possible future states by exploring possible actions. This tree can then be used to decide which action to take in the current state.

Monte Carlo Tree Search (MCTS) is a method for building game trees by simulating random games from the current state. The basic idea is to use random simulations to estimate the value of each action, and then use these estimates to guide the search for the best action.
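To make the contrast with αMCTS concrete, here is a minimal sketch of the random-simulation ("rollout") value estimate that classic MCTS relies on. The helpers `winner`, `random_rollout` and `estimate_value` are hypothetical names for this illustration. They use the board encoding used throughout this notebook: a list of 9 cells, row by row, with 0 = empty, 1 = X, -1 = O.

```python
import random

# The 8 winning lines of a 3x3 board, as cell indices 0..8 (row-major).
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
         (0, 3, 6), (1, 4, 7), (2, 5, 8),
         (0, 4, 8), (2, 4, 6)]


def winner(board):
    """Return 1 if X has three in a row, -1 if O does, else 0."""
    for a, b, c in LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0


def random_rollout(board, player, rng=random):
    """Play uniformly random moves until the game ends.

    `player` is the side to move (1 = X, -1 = O). Returns +1 if that side
    wins, -1 if it loses and 0 for a draw.
    """
    board = list(board)
    to_move = player
    while True:
        w = winner(board)
        empty = [i for i, v in enumerate(board) if v == 0]
        if w != 0 or not empty:
            return w * player  # w == 0 on a full board means a draw
        board[rng.choice(empty)] = to_move
        to_move = -to_move


def estimate_value(board, player, n_rollouts=1000, seed=0):
    """Classic MCTS-style estimate: the average outcome of many random games."""
    rng = random.Random(seed)
    return sum(random_rollout(board, player, rng) for _ in range(n_rollouts)) / n_rollouts
```

αMCTS replaces this averaging of random games with a single evaluation by the value network.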

AlphaZero uses a modified version of MCTS. The key difference is that instead of using random simulations, αMCTS uses a neural network to evaluate the value of each state and the probability of each action. This allows αMCTS to make more informed decisions about which actions to explore.

Each αMCTS simulation has three main steps:
1. Selection: Starting from the root node, repeatedly select a child node until reaching a leaf node (a node whose children have not been added yet). Selection balances exploration (trying less-visited nodes) and exploitation (choosing nodes with high value estimates).
2. Expansion: If the leaf node is not a terminal state, expand it by adding child nodes for its legal actions, each with a prior probability from the policy network.
3. Evaluation and Backpropagation: Instead of simulating a random game, use the neural network to estimate the value of the leaf node (the same network call provides the action probabilities used in expansion); if the leaf is terminal, use the actual game result instead. This value is then backpropagated along the path from the leaf to the root, updating each node's visit count and value sum, with its sign flipped at each level because the players alternate.
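Step 3 is where the two-player nature of the game matters. Below is a minimal sketch of backpropagation, assuming the convention that each node's `value_sum` is stored from the point of view of the player who made the move *into* that node (other conventions work too, as long as selection uses the same one). `TreeNode` and `backpropagate` are hypothetical names; `TreeNode` keeps only the fields this step touches.

```python
class TreeNode:
    """Hypothetical minimal node: only the fields backpropagation updates."""

    def __init__(self, parent=None):
        self.parent = parent
        self.visit_count = 0
        self.value_sum = 0.0


def backpropagate(leaf, value):
    """Add a leaf evaluation to every node on the path back to the root.

    Assumed convention: a node's value_sum is from the point of view of the
    player who moved into it. If the network returns v for the player to
    move at the leaf, call backpropagate(leaf, -v).
    """
    node = leaf
    while node is not None:
        node.visit_count += 1
        node.value_sum += value
        value = -value  # one level up, the other player made the move
        node = node.parent


# Usage: root -> child -> grandchild, where the network says the player
# to move at the grandchild is winning (v = +1).
root = TreeNode()
child = TreeNode(parent=root)
grandchild = TreeNode(parent=child)
backpropagate(grandchild, -1.0)
print(grandchild.value_sum, child.value_sum, root.value_sum)  # -1.0 1.0 -1.0
```

With this convention, a parent can simply pick the child with the highest average value, since that average is already expressed from the parent's mover's point of view.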

Let's implement a simple version of αMCTS for 2D Tic-Tac-Toe!

Let's assume we already have well-trained policy and value networks for Tic-Tac-Toe.
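Since we don't have trained networks at hand, the sketch below is a hypothetical placeholder with the interface the search needs: given a state, return prior probabilities for the legal actions and a value in [-1, 1] for the player to move. The names `masked_softmax` and `policy_value` are assumptions. A policy head typically outputs one logit per cell, and illegal (occupied) cells are masked out before normalizing, which is what `masked_softmax` illustrates.

```python
import math


def masked_softmax(logits, state):
    """Softmax over the legal (empty) cells only; occupied cells get no prior."""
    legal = [i for i, v in enumerate(state) if v == 0]
    if not legal:
        return {}
    m = max(logits[i] for i in legal)  # subtract the max for numerical stability
    exps = {i: math.exp(logits[i] - m) for i in legal}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}


def policy_value(state):
    """Hypothetical placeholder for the trained policy and value networks.

    Returns (priors, value): priors maps each legal action to a probability,
    and value is in [-1, 1] from the point of view of the player to move.
    A trained network would compute `logits` and `value` from `state`; this
    stand-in uses all-zero logits (uniform priors) and a neutral value of 0.
    """
    logits = [0.0] * 9
    value = 0.0
    return masked_softmax(logits, state), value


priors, value = policy_value([0, 1, 0, 0, 1, -1, 0, -1, 1])
print(priors, value)  # {0: 0.25, 2: 0.25, 3: 0.25, 6: 0.25} 0.0
```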

Let's start from the following mid-game position:

In [1]:
%%javascript
(function(element) {
    var container = element.get(0);
    var svgNS = 'http://www.w3.org/2000/svg';

    // Draw a Tic-Tac-Toe board given a board array of 9 elements (0 empty, 1 X, -1 O)
    window.drawTicTacToe = function(board, opts) {
        opts = opts || {};
        var size = opts.size || 300;
        var padding = opts.padding || 20;
        var cellSize = (size - padding * 2) / 3;
        
        // Always create a fresh SVG. Reusing an existing one would make later
        // calls (e.g. from the tree cell) clear this demo board and move it away.
        var svg = document.createElementNS(svgNS, 'svg');
        svg.classList.add('tictactoe');
        svg.setAttribute('width', size);
        svg.setAttribute('height', size);
        svg.style.display = 'block';
        svg.style.margin = '12px 0';
        container.appendChild(svg);

        // Board background with rounded corners
        var bg = document.createElementNS(svgNS, 'rect');
        bg.setAttribute('x', padding);
        bg.setAttribute('y', padding);
        bg.setAttribute('width', size - padding * 2);
        bg.setAttribute('height', size - padding * 2);
        bg.setAttribute('fill', '#fafafa00');
        bg.setAttribute('stroke-width', '0');
        bg.setAttribute('rx', '4');
        svg.appendChild(bg);

        // Grid lines
        for (var i = 1; i < 3; i++) {
            // Vertical lines
            var vline = document.createElementNS(svgNS, 'line');
            vline.setAttribute('x1', padding + i * cellSize);
            vline.setAttribute('y1', padding);
            vline.setAttribute('x2', padding + i * cellSize);
            vline.setAttribute('y2', size - padding);
            vline.setAttribute('stroke', '#666');
            vline.setAttribute('stroke-width', '3');
            svg.appendChild(vline);

            // Horizontal lines
            var hline = document.createElementNS(svgNS, 'line');
            hline.setAttribute('x1', padding);
            hline.setAttribute('y1', padding + i * cellSize);
            hline.setAttribute('x2', size - padding);
            hline.setAttribute('y2', padding + i * cellSize);
            hline.setAttribute('stroke', '#666');
            hline.setAttribute('stroke-width', '3');
            svg.appendChild(hline);
        }

        // Draw marks
        for (var i = 0; i < 9; i++) {
            var row = Math.floor(i / 3);
            var col = i % 3;
            var cx = padding + col * cellSize + cellSize / 2;
            var cy = padding + row * cellSize + cellSize / 2;
            var val = board[i] || 0;

            if (val === 1) {
                // Draw X in blue
                var sizeX = cellSize * 0.5;
                var line1 = document.createElementNS(svgNS, 'line');
                line1.setAttribute('x1', cx - sizeX / 2);
                line1.setAttribute('y1', cy - sizeX / 2);
                line1.setAttribute('x2', cx + sizeX / 2);
                line1.setAttribute('y2', cy + sizeX / 2);
                line1.setAttribute('stroke', '#0969da');
                line1.setAttribute('stroke-width', '5');
                line1.setAttribute('stroke-linecap', 'round');
                svg.appendChild(line1);

                var line2 = document.createElementNS(svgNS, 'line');
                line2.setAttribute('x1', cx + sizeX / 2);
                line2.setAttribute('y1', cy - sizeX / 2);
                line2.setAttribute('x2', cx - sizeX / 2);
                line2.setAttribute('y2', cy + sizeX / 2);
                line2.setAttribute('stroke', '#0969da');
                line2.setAttribute('stroke-width', '5');
                line2.setAttribute('stroke-linecap', 'round');
                svg.appendChild(line2);
            } else if (val === -1) {
                // Draw O in orange
                var circle = document.createElementNS(svgNS, 'circle');
                circle.setAttribute('cx', cx);
                circle.setAttribute('cy', cy);
                circle.setAttribute('r', cellSize * 0.25);
                circle.setAttribute('fill', 'none');
                circle.setAttribute('stroke', '#DA7A09');
                circle.setAttribute('stroke-width', '5');
                svg.appendChild(circle);
            }
        }
        return svg;
    };

    // demo: board sample (0 empty, 1 X, -1 O)
    var demo = [0,1,0,0,1,-1,0,-1,1];
    window.drawTicTacToe(demo, { size: 320, padding: 20 });
})(element);


It is O's turn to play: X has three marks and O has two. The empty cells 0, 2, 3, and 6 (numbered row by row from the top left) give 4 possible actions from this state.
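The same bookkeeping in Python, as a minimal sketch: the helper names are hypothetical, and `current_player` assumes X always moves first (consistent with the position above, where X has three marks and O has two).

```python
def current_player(state):
    """X (1) moves first, so it is X's turn when both sides have equally many marks."""
    return 1 if state.count(1) == state.count(-1) else -1


def legal_actions(state):
    """Empty cells (value 0) are the legal actions."""
    return [i for i, v in enumerate(state) if v == 0]


def apply_action(state, action):
    """Return a new state with the current player's mark placed at `action`."""
    assert state[action] == 0, "cell already occupied"
    child = list(state)
    child[action] = current_player(state)
    return child


root_state = [0, 1, 0, 0, 1, -1, 0, -1, 1]
print(current_player(root_state))  # -1, i.e. O to move
print(legal_actions(root_state))   # [0, 2, 3, 6]
```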

In [6]:
%%javascript
    (function (element) {
        var container = element.get(0);

        // Root state: O's turn
        var rootState = [0, 1, 0, 0, 1, -1, 0, -1, 1];

        // Calculate possible actions (empty cells)
        var possibleActions = [];
        for (var i = 0; i < 9; i++) {
            if (rootState[i] === 0) {
                possibleActions.push(i);
            }
        }

        // Create child states
        var childStates = possibleActions.map(function (action) {
            var newState = rootState.slice();
            newState[action] = -1; // O plays
            return { action: action, state: newState };
        });

        // Layout parameters
        var boardSize = 180;
        var boardPadding = 12;
        var horizontalSpacing = 320;
        var verticalSpacing = 350;
        var treeWidth = childStates.length * horizontalSpacing;
        var treeHeight = verticalSpacing * 2 + boardSize;

        // Create main container
        var treeDiv = document.createElement('div');
        treeDiv.style.width = treeWidth + 'px';
        treeDiv.style.height = treeHeight + 'px';
        treeDiv.style.position = 'relative';
        treeDiv.style.margin = '20px auto';
        container.appendChild(treeDiv);

        // Create SVG for connections
        var svgNS = 'http://www.w3.org/2000/svg';
        var connectionsSvg = document.createElementNS(svgNS, 'svg');
        connectionsSvg.setAttribute('width', treeWidth);
        connectionsSvg.setAttribute('height', treeHeight);
        connectionsSvg.style.position = 'absolute';
        connectionsSvg.style.top = '0';
        connectionsSvg.style.left = '0';
        treeDiv.appendChild(connectionsSvg);

        // Add arrow marker definition for lines
        var defs = document.createElementNS(svgNS, 'defs');
        var marker = document.createElementNS(svgNS, 'marker');
        marker.setAttribute('id', 'arrow');
        marker.setAttribute('viewBox', '0 0 10 10');
        marker.setAttribute('refX', '6');
        marker.setAttribute('refY', '5');
        marker.setAttribute('markerWidth', '6');
        marker.setAttribute('markerHeight', '6');
        marker.setAttribute('orient', 'auto');
        var markerPath = document.createElementNS(svgNS, 'path');
        markerPath.setAttribute('d', 'M 0 0 L 10 5 L 0 10 z');
        markerPath.setAttribute('fill', '#999');
        marker.appendChild(markerPath);
        defs.appendChild(marker);
        connectionsSvg.appendChild(defs);

        // Root position
        var rootX = treeWidth / 2;
        var rootY = 40;

        // Draw root board
        var rootDiv = document.createElement('div');
        rootDiv.style.position = 'absolute';
        rootDiv.style.left = (rootX - boardSize / 2) + 'px';
        rootDiv.style.top = rootY + 'px';
        treeDiv.appendChild(rootDiv);

        var rootSvg = window.drawTicTacToe(rootState, { size: boardSize, padding: boardPadding });
        rootSvg.style.margin = '0';
        rootDiv.appendChild(rootSvg);

        // Add circle around root
        var rootCircle = document.createElementNS(svgNS, 'circle');
        rootCircle.setAttribute('cx', rootX);
        rootCircle.setAttribute('cy', rootY + boardSize / 2);
        rootCircle.setAttribute('r', boardSize / 2 + 10);
        rootCircle.setAttribute('fill', 'none');
        rootCircle.setAttribute('stroke', '#ccc');
        rootCircle.setAttribute('stroke-width', '2');
        connectionsSvg.appendChild(rootCircle);

        // Child positions
        var childY = rootY + verticalSpacing;
        var startX = (treeWidth - (childStates.length - 1) * horizontalSpacing) / 2;

        // Draw children
        childStates.forEach(function (child, i) {
            var childX = startX + i * horizontalSpacing;

            // Draw curved connection path with arrow
            var path = document.createElementNS(svgNS, 'path');
            var controlY = (rootY + boardSize + childY) / 2;
            var d = `M ${rootX} ${rootY + boardSize} Q ${childX} ${controlY} ${childX} ${childY}`;
            path.setAttribute('d', d);
            path.setAttribute('stroke', '#999');
            path.setAttribute('stroke-width', '2');
            path.setAttribute('fill', 'none');
            path.setAttribute('marker-end', 'url(#arrow)');
            connectionsSvg.appendChild(path);

            // Draw child board
            var childDiv = document.createElement('div');
            childDiv.style.position = 'absolute';
            childDiv.style.left = (childX - boardSize / 2) + 'px';
            childDiv.style.top = childY + 'px';
            treeDiv.appendChild(childDiv);

            var childSvg = window.drawTicTacToe(child.state, { size: boardSize, padding: boardPadding });
            childSvg.style.margin = '0';
            childDiv.appendChild(childSvg);

            // Add circle around child
            var childCircle = document.createElementNS(svgNS, 'circle');
            childCircle.setAttribute('cx', childX);
            childCircle.setAttribute('cy', childY + boardSize / 2);
            childCircle.setAttribute('r', boardSize / 2 + 10);
            childCircle.setAttribute('fill', 'none');
            childCircle.setAttribute('stroke', '#ccc');
            childCircle.setAttribute('stroke-width', '2');
            connectionsSvg.appendChild(childCircle);
        });

    })(element);


In [3]:
# Nodes of tree
class Node:
    def __init__(self, state, parent=None, action=None, prior=0):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0.0
        self.prior = prior

**state**: Array of 9 elements representing the board (0 = empty, 1 = X, -1 = O)

**parent**: Reference to the parent node

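**action**: Integer from 0 to 8, the cell where a mark was placed to move from the parent to this node (None for the root)

**children**: Dictionary mapping actions to child nodes

**visit_count**: Number of times the node has been visited

**value_sum**: Sum of all values backpropagated through this node; dividing it by the visit count gives the node's average value

**prior**: Prior probability of the action leading to this node, as given by the policy network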
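For the Expansion step, each new child needs a state and a prior. Here is a minimal sketch that computes exactly that data; the function name `expand_children` is hypothetical, `priors` is assumed to come from the policy network restricted to legal actions, and X is assumed to move first.

```python
def expand_children(state, priors):
    """Compute the (child_state, prior) pairs the Expansion step would attach.

    state:  list of 9 ints (0 empty, 1 X, -1 O)
    priors: dict {action: probability} from the policy network (legal actions only)
    Returns {action: (child_state, prior)}.
    """
    player = 1 if state.count(1) == state.count(-1) else -1  # X moves first
    children = {}
    for action, prior in priors.items():
        child_state = list(state)
        child_state[action] = player
        children[action] = (child_state, prior)
    return children
```

With the `Node` class above, each entry would become `node.children[a] = Node(child_state, parent=node, action=a, prior=p)`.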

## Selection

At each node, we select the child that maximizes the PUCT score (Predictor + Upper Confidence bound applied to Trees):

$$ PUCT = Q + U $$

$$ Q = \frac{\text{value\_sum}}{1 + \text{visit\_count}} $$

$$ U = c_{\text{puct}} \cdot \text{prior} \cdot \frac{\sqrt{\text{parent\_visit\_count}}}{1 + \text{visit\_count}} $$

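Q is the *exploitation* term: roughly the average value of the child node. The 1 in its denominator avoids dividing by zero for unvisited children and pulls Q toward 0 while a child has only a few visits (AlphaZero itself uses value_sum / visit_count, with Q = 0 for unvisited children). U is the *exploration* term: it is large for actions the policy network rates highly and for children visited rarely compared with their parent, and it shrinks as the child's visit_count grows. c_puct is a constant that controls the balance between exploration and exploitation.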
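A worked example with made-up numbers makes the trade-off visible. `puct_score` is a hypothetical helper that follows the formulas above, including the 1 + visit_count in Q:

```python
import math


def puct_score(value_sum, visit_count, prior, parent_visit_count, c_puct):
    """PUCT = Q + U, with Q = value_sum / (1 + visit_count) as defined above."""
    q = value_sum / (1 + visit_count)
    u = c_puct * prior * math.sqrt(parent_visit_count) / (1 + visit_count)
    return q + u


# Hypothetical numbers: a parent visited 10 times, c_puct = 1.5
a = puct_score(value_sum=3.0, visit_count=5, prior=0.2, parent_visit_count=10, c_puct=1.5)
b = puct_score(value_sum=0.0, visit_count=0, prior=0.5, parent_visit_count=10, c_puct=1.5)
print(round(a, 3), round(b, 3))  # 0.658 2.372
```

Child A has a positive track record, but child B has never been visited and has a larger prior, so its exploration bonus dominates and B is selected next.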

Prior is the probability of selecting this action, as given by the policy network.

PUCT derives from the Upper Confidence Bound (UCB) algorithm used in multi-armed bandit problems.
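For comparison, the classic UCB1 rule scores a bandit arm $j$ by its average reward $\bar{x}_j$, the number of times it was played $n_j$, and the total number of plays $n$:

$$ \text{UCB1}_j = \bar{x}_j + \sqrt{\frac{2 \ln n}{n_j}} $$

Writing PUCT with $N(s)$ for the parent's visit count, $N(s, a)$ for the child's visit count and $P(s, a)$ for the prior:

$$ \text{PUCT}(s, a) = Q(s, a) + c_{\text{puct}} \cdot P(s, a) \cdot \frac{\sqrt{N(s)}}{1 + N(s, a)} $$

Two differences stand out: PUCT scales the exploration bonus by the prior $P(s, a)$, so the policy network steers the search, and it uses $\sqrt{N(s)} / (1 + N(s, a))$ instead of $\sqrt{\ln n / n_j}$, which stays finite for unvisited children (UCB1 has to play every arm once before the formula applies).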

Let's break it down.

In [4]:
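# Attach select_child to the Node class defined above as a method


def select_child(self, c_puct):
    """Return (action, child) for the child with the highest PUCT score.

    Assumes each child's value_sum is stored from the point of view of the
    player to move at `self`, so a higher Q is better for that player.
    """
    best_score = -float("inf")
    best_action = None
    best_child = None

    for action, child in self.children.items():
        q = child.value_sum / (1 + child.visit_count)  # exploitation term
        u = c_puct * child.prior * (self.visit_count**0.5) / (1 + child.visit_count)  # exploration term
        puct_score = q + u
        if puct_score > best_score:
            best_score = puct_score
            best_action = action
            best_child = child

    return best_action, best_child


Node.select_child = select_child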