## 1. Input Representation (Feature Engineering)

Unlike Computer Vision tasks, where the input is already a tensor of pixel values (e.g., RGB channels), board games have no canonical tensor form. We have to design one that exposes the game's rules and spatial structure to the network.

We could pass a single 2D grid of $0$s, $1$s, and $-1$s, but that packs three categories (empty, Player 1, Player 2) into one ordered number, and the network would have to learn to separate them again. Instead, we typically use One-Hot Encoding across multiple binary channels (planes), one per player or piece type.

For a simple game like Tic-Tac-Toe, a state $s$ is often represented as a $(3, H, W)$ tensor with $H = W = 3$:

1. Channel 0: Location of Player 1's pieces (binary).
2. Channel 1: Location of Player 2's pieces (binary).
3. Channel 2: Whose turn it is (filled with 1s if P1, 0s if P2).
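A minimal sketch of this encoding, assuming the board is stored as a flat list of 9 cells with `1` for Player 1 (X), `-1` for Player 2 (O), and `0` for empty, the same convention as the widget below. The function name `encode_state` is hypothetical.

```python
import numpy as np


def encode_state(board, player_turn):
    """Encode a flat Tic-Tac-Toe board into a (3, 3, 3) float32 tensor.

    board: 9 cells, 1 = Player 1 (X), -1 = Player 2 (O), 0 = empty
    player_turn: 1 if Player 1 is to move, -1 if Player 2 is to move
    """
    grid = np.asarray(board, dtype=np.int8).reshape(3, 3)
    planes = np.zeros((3, 3, 3), dtype=np.float32)
    planes[0] = grid == 1    # Channel 0: Player 1's pieces
    planes[1] = grid == -1   # Channel 1: Player 2's pieces
    planes[2] = 1.0 if player_turn == 1 else 0.0  # Channel 2: constant turn plane
    return planes


x = encode_state([0, 1, 0, 0, 1, -1, 0, -1, 1], player_turn=-1)
print(x.shape)  # (3, 3, 3)
```

Each plane is binary, so "empty" is simply a cell that is 0 in both piece planes.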

Note: In complex games like Chess or Go, AlphaZero stacks the last $T$ board positions (History Planes, $T = 8$ in the paper). A single position does not capture everything the rules depend on, such as repetitions in Chess or ko in Go.
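A sketch of history stacking under assumed conventions: keep the last `T` boards, give each one two binary planes (Player 1, Player 2), zero-pad when fewer than `T` positions exist, and append a turn plane. For `T = 8` this gives $2 \cdot 8 + 1 = 17$ planes, the same count AlphaGo Zero used for Go, although AlphaGo Zero orders its planes relative to the player to move rather than by fixed player. `stack_history` is a hypothetical helper.

```python
import numpy as np


def stack_history(history, player_turn, T=8):
    """Build a (2T + 1, H, W) input from a list of past boards.

    history: list of (H, W) arrays with 1 / -1 / 0 entries, oldest first
    player_turn: 1 if Player 1 is to move, -1 otherwise
    """
    h, w = history[-1].shape
    planes = np.zeros((2 * T + 1, h, w), dtype=np.float32)
    recent = history[-T:][::-1]  # most recent board first
    for t, board in enumerate(recent):
        planes[2 * t] = board == 1       # Player 1 stones, t steps ago
        planes[2 * t + 1] = board == -1  # Player 2 stones, t steps ago
    # Early in the game there are fewer than T boards; the rest stays zero-padded.
    planes[-1] = 1.0 if player_turn == 1 else 0.0
    return planes
```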

In [21]:
import json

import ipywidgets as widgets
from IPython.display import HTML, display


class InteractiveTicTacToe:
    def __init__(self, initial_state=None, player_turn=1):
        if initial_state is None:
            initial_state = [0, 1, 0, 0, 1, -1, 0, -1, 1]
        self.state = list(initial_state)
        self.player_turn = player_turn
        self.size = 200
        self.padding = 20
        self.cell_size = (self.size - self.padding * 2) / 3

        # Unique ID for this instance
        import random

        self.widget_id = f"ttt_{random.randint(1000, 9999)}"

    def viz_node_interactive(self):
        """Generate interactive SVG for the main board"""
        svg = f'<svg id="{self.widget_id}_board" width="{self.size}" height="{self.size}" style="display:block; margin: 0 auto;">'

        # Background
        svg += f'<rect x="{self.padding}" y="{self.padding}" width="{self.size - self.padding * 2}" height="{self.size - self.padding * 2}" fill="#fafafa" stroke-width="0" rx="0"></rect>'

        # Grid lines
        for i in range(1, 3):
            x = self.padding + i * self.cell_size
            svg += f'<line x1="{x}" y1="{self.padding}" x2="{x}" y2="{self.size - self.padding}" stroke="#666" stroke-width="2"></line>'
            y = self.padding + i * self.cell_size
            svg += f'<line x1="{self.padding}" y1="{y}" x2="{self.size - self.padding}" y2="{y}" stroke="#666" stroke-width="2"></line>'

        # Clickable cells and marks
        for i in range(9):
            row = i // 3
            col = i % 3
            cx = self.padding + col * self.cell_size + self.cell_size / 2
            cy = self.padding + row * self.cell_size + self.cell_size / 2
            x = self.padding + col * self.cell_size
            y = self.padding + row * self.cell_size
            val = self.state[i]

            # Clickable area with hover effect
            svg += f'<rect class="cell-{i}" x="{x}" y="{y}" width="{self.cell_size}" height="{self.cell_size}" fill="transparent" stroke="none" style="cursor:pointer;" onclick="handleCellClick_{self.widget_id}({i})"></rect>'

            # Draw X
            if val == 1:
                size_x = self.cell_size * 0.5
                svg += f'<g class="mark-{i}" style="pointer-events:none;">'
                svg += f'<line x1="{cx - size_x / 2}" y1="{cy - size_x / 2}" x2="{cx + size_x / 2}" y2="{cy + size_x / 2}" stroke="#0969da" stroke-width="3" stroke-linecap="round"></line>'
                svg += f'<line x1="{cx + size_x / 2}" y1="{cy - size_x / 2}" x2="{cx - size_x / 2}" y2="{cy + size_x / 2}" stroke="#0969da" stroke-width="3" stroke-linecap="round"></line>'
                svg += "</g>"
            # Draw O
            elif val == -1:
                svg += f'<circle class="mark-{i}" cx="{cx}" cy="{cy}" r="{self.cell_size * 0.25}" fill="none" stroke="#DA7A09" stroke-width="3" style="pointer-events:none;"></circle>'

        svg += "</svg>"
        return svg

    def viz_encoding_channels(self):
        """Generate SVG for the three encoding channels"""
        channel_size = 140
        channel_padding = 14
        channel_cell_size = (channel_size - channel_padding * 2) / 3
        gap = 20
        total_width = 3 * channel_size + 2 * gap

        def render_channel(offset_x, title, mask_fn, color):
            svg = ""
            svg += f'<text x="{offset_x + channel_size / 2}" y="{channel_padding - 4}" text-anchor="middle" fill="#4D5461" font-family="sans-serif" font-size="12" font-weight="600">{title}</text>'

            # Grid lines
            for i in range(1, 3):
                x = offset_x + channel_padding + i * channel_cell_size
                svg += f'<line x1="{x}" y1="{channel_padding}" x2="{x}" y2="{channel_size - channel_padding}" stroke="#999" stroke-width="1.5"></line>'
                y = channel_padding + i * channel_cell_size
                svg += f'<line x1="{offset_x + channel_padding}" y1="{y}" x2="{offset_x + channel_size - channel_padding}" y2="{y}" stroke="#999" stroke-width="1.5"></line>'

            # Cells
            for i in range(9):
                row = i // 3
                col = i % 3
                x = offset_x + channel_padding + col * channel_cell_size
                y = channel_padding + row * channel_cell_size

                if title == "Ch 2: Turn":
                    val = 1 if self.player_turn == 1 else 0
                else:
                    val = mask_fn(self.state[i])

                fill_opacity = "0.7" if val == 1 else "0.0"
                svg += f'<rect x="{x + 1}" y="{y + 1}" width="{channel_cell_size - 2}" height="{channel_cell_size - 2}" fill="{color}" fill-opacity="{fill_opacity}" stroke="none" rx="0"></rect>'
                svg += f'<text x="{x + channel_cell_size / 2}" y="{y + channel_cell_size / 2 + 4}" text-anchor="middle" fill="currentColor" font-family="monospace" font-size="11">{val}</text>'

            return svg

        svg = f'<svg id="{self.widget_id}_encoding" width="{total_width}" height="{channel_size}" style="display:block; margin: 0 auto;">'
        svg += render_channel(0, "Ch 0: X", lambda x: 1 if x == 1 else 0, "#0969da")
        svg += render_channel(
            channel_size + gap, "Ch 1: O", lambda x: 1 if x == -1 else 0, "#DA7A09"
        )
        svg += render_channel(
            2 * (channel_size + gap), "Ch 2: Turn", lambda x: 0, "#EEE"
        )
        svg += "</svg>"
        return svg

    def generate_html(self):
        """Generate complete HTML with inline JavaScript"""
        html = f"""
        <div id="{self.widget_id}_container" style="font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; color: var(--text);">
            <div style="padding: 20px; margin-bottom: 20px;">
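                <p style="text-align: center; font-size: 14px; margin-bottom: 8px;">Click an empty cell to place the current player's mark. Click an occupied cell to clear it.</p>
                <p style="text-align: center; font-size: 14px; margin-bottom: 16px;">Current turn: <strong id="{self.widget_id}_turn">{'X' if self.player_turn == 1 else 'O'}</strong></p>
                {self.viz_node_interactive()}
                <div style="text-align: center; margin-top: 16px;">
                    <button onclick="clearBoard_{self.widget_id}()" style="padding: 8px 16px; margin: 0 5px; background: #64748b; color: white; border: none; border-radius: 0px; cursor: pointer; font-size: 14px; font-weight: 500;">Clear Board</button>
                    <button onclick="toggleTurn_{self.widget_id}()" style="padding: 8px 16px; margin: 0 5px; background: #64748b; color: white; border: none; border-radius: 0px; cursor: pointer; font-size: 14px; font-weight: 500;">Toggle Turn</button>
                </div>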
            </div>

            <div style="padding: 20px;text-align: center;">
                <h3 style=" margin-bottom: 8px;">State Vector</h3>
                <div id="{self.widget_id}_state" style="font-family: monospace; font-size: 14px; background: #f1f5f9; padding: 12px; border-radius: 6px; width: fit-content; margin: 0 auto;">
                    [{", ".join(map(str, self.state))}]
                </div>
                <p style="font-family: monospace;font-size: 12px; margin-top: 8px;">1 = X, -1 = O, 0 = Empty</p>
            </div>
            
            <div style="padding: 20px; margin-bottom: 20px;">
                <h3 style="text-align: center;  margin-bottom: 8px;">Neural Network Input Encoding</h3>
                <p style="text-align: center;  font-size: 13px; margin-bottom: 16px;">How the board state is represented as three channels</p>
                {self.viz_encoding_channels()}
            </div>
            
            
        </div>
        
        <script>
        (function() {{
            // State management
            let state_{self.widget_id} = {json.dumps(self.state)};
            let playerTurn_{self.widget_id} = {self.player_turn};
            
            // Handle cell click
            window.handleCellClick_{self.widget_id} = function(index) {{
                // Only allow move if cell is empty
                if (state_{self.widget_id}[index] === 0) {{
                    state_{self.widget_id}[index] = playerTurn_{self.widget_id};
                    // Toggle turn after move
                    playerTurn_{self.widget_id} = playerTurn_{self.widget_id} === 1 ? -1 : 1;
                    document.getElementById('{self.widget_id}_turn').textContent = playerTurn_{self.widget_id} === 1 ? 'X' : 'O';
                }} else {{
                    // If cell is occupied, clear it (alternative behavior)
                    state_{self.widget_id}[index] = 0;
                }}
                updateDisplay_{self.widget_id}();
            }};
            
            // Clear board
            window.clearBoard_{self.widget_id} = function() {{
                state_{self.widget_id} = [0, 0, 0, 0, 0, 0, 0, 0, 0];
                playerTurn_{self.widget_id} = 1;  // Reset to X's turn
                document.getElementById('{self.widget_id}_turn').textContent = 'X';
                updateDisplay_{self.widget_id}();
            }};
            
            // Toggle turn
            window.toggleTurn_{self.widget_id} = function() {{
                playerTurn_{self.widget_id} = playerTurn_{self.widget_id} === 1 ? -1 : 1;
                document.getElementById('{self.widget_id}_turn').textContent = playerTurn_{self.widget_id} === 1 ? 'X' : 'O';
                updateDisplay_{self.widget_id}();
            }};
            
            // Update all displays
            window.updateDisplay_{self.widget_id} = function() {{
                updateBoard_{self.widget_id}();
                updateEncoding_{self.widget_id}();
                updateStateVector_{self.widget_id}();
            }};
            
            // Update board SVG
            function updateBoard_{self.widget_id}() {{
                const svg = document.getElementById('{self.widget_id}_board');
                const size = {self.size};
                const padding = {self.padding};
                const cellSize = {self.cell_size};
                
                // Remove old marks
                for (let i = 0; i < 9; i++) {{
                    const oldMark = svg.querySelector('.mark-' + i);
                    if (oldMark) oldMark.remove();
                }}
                
                // Draw new marks
                for (let i = 0; i < 9; i++) {{
                    const row = Math.floor(i / 3);
                    const col = i % 3;
                    const cx = padding + col * cellSize + cellSize / 2;
                    const cy = padding + row * cellSize + cellSize / 2;
                    const val = state_{self.widget_id}[i];
                    
                    if (val === 1) {{
                        const sizeX = cellSize * 0.5;
                        const g = document.createElementNS('http://www.w3.org/2000/svg', 'g');
                        g.setAttribute('class', 'mark-' + i);
                        g.style.pointerEvents = 'none';
                        
                        const line1 = document.createElementNS('http://www.w3.org/2000/svg', 'line');
                        line1.setAttribute('x1', cx - sizeX / 2);
                        line1.setAttribute('y1', cy - sizeX / 2);
                        line1.setAttribute('x2', cx + sizeX / 2);
                        line1.setAttribute('y2', cy + sizeX / 2);
                        line1.setAttribute('stroke', '#0969da');
                        line1.setAttribute('stroke-width', '3');
                        line1.setAttribute('stroke-linecap', 'round');
                        
                        const line2 = document.createElementNS('http://www.w3.org/2000/svg', 'line');
                        line2.setAttribute('x1', cx + sizeX / 2);
                        line2.setAttribute('y1', cy - sizeX / 2);
                        line2.setAttribute('x2', cx - sizeX / 2);
                        line2.setAttribute('y2', cy + sizeX / 2);
                        line2.setAttribute('stroke', '#0969da');
                        line2.setAttribute('stroke-width', '3');
                        line2.setAttribute('stroke-linecap', 'round');
                        
                        g.appendChild(line1);
                        g.appendChild(line2);
                        svg.appendChild(g);
                    }} else if (val === -1) {{
                        const circle = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
                        circle.setAttribute('class', 'mark-' + i);
                        circle.setAttribute('cx', cx);
                        circle.setAttribute('cy', cy);
                        circle.setAttribute('r', cellSize * 0.25);
                        circle.setAttribute('fill', 'none');
                        circle.setAttribute('stroke', '#DA7A09');
                        circle.setAttribute('stroke-width', '3');
                        circle.style.pointerEvents = 'none';
                        svg.appendChild(circle);
                    }}
                }}
            }}
            
            // Update encoding channels
            function updateEncoding_{self.widget_id}() {{
                // Rebuild the whole encoding SVG (simple, and cheap for a 3x3 board)
                const svg = document.getElementById('{self.widget_id}_encoding');
                svg.innerHTML = generateEncodingSVG_{self.widget_id}();
            }}
            
            function generateEncodingSVG_{self.widget_id}() {{
                const channelSize = 140;
                const channelPadding = 14;
                const channelCellSize = (channelSize - channelPadding * 2) / 3;
                const gap = 20;
                let svg = '';
                
                function renderChannel(offsetX, title, maskFn, color) {{
                    let s = '';
                    s += `<text x="${{offsetX + channelSize / 2}}" y="${{channelPadding - 4}}" text-anchor="middle" fill="#4D5461" font-family="sans-serif" font-size="12" font-weight="600">${{title}}</text>`;
                    
                    for (let i = 1; i < 3; i++) {{
                        const x = offsetX + channelPadding + i * channelCellSize;
                        s += `<line x1="${{x}}" y1="${{channelPadding}}" x2="${{x}}" y2="${{channelSize - channelPadding}}" stroke="#999" stroke-width="1.5"></line>`;
                        const y = channelPadding + i * channelCellSize;
                        s += `<line x1="${{offsetX + channelPadding}}" y1="${{y}}" x2="${{offsetX + channelSize - channelPadding}}" y2="${{y}}" stroke="#999" stroke-width="1.5"></line>`;
                    }}
                    
                    for (let i = 0; i < 9; i++) {{
                        const row = Math.floor(i / 3);
                        const col = i % 3;
                        const x = offsetX + channelPadding + col * channelCellSize;
                        const y = channelPadding + row * channelCellSize;
                        const val = maskFn(i);
                        const fillOpacity = val === 1 ? 0.7 : 0.0;
                        
                        s += `<rect x="${{x + 1}}" y="${{y + 1}}" width="${{channelCellSize - 2}}" height="${{channelCellSize - 2}}" fill="${{color}}" fill-opacity="${{fillOpacity}}" stroke="none" rx="0"></rect>`;
                        s += `<text x="${{x + channelCellSize / 2}}" y="${{y + channelCellSize / 2 + 4}}" text-anchor="middle" fill="currentColor" font-family="monospace" font-size="11">${{val}}</text>`;
                    }}
                    
                    return s;
                }}
                
                svg += renderChannel(0, "Ch 0: X", (i) => state_{self.widget_id}[i] === 1 ? 1 : 0, "#0969da");
                svg += renderChannel(channelSize + gap, "Ch 1: O", (i) => state_{self.widget_id}[i] === -1 ? 1 : 0, "#DA7A09");
                svg += renderChannel(2 * (channelSize + gap), "Ch 2: Turn", () => playerTurn_{self.widget_id} === 1 ? 1 : 0, "#EEE");
                
                return svg;
            }}
            
            // Update state vector display
            function updateStateVector_{self.widget_id}() {{
                document.getElementById('{self.widget_id}_state').textContent = '[' + state_{self.widget_id}.join(', ') + ']';
            }}
        }})();
        </script>
        """
        return html

    def show(self):
        """Display the interactive widget"""
        display(HTML(self.generate_html()))


# Usage example:
ttt_instance = InteractiveTicTacToe()
ttt_instance.show()

## 2. The Architecture: Dual-Headed ResNet

The AlphaZero network consists of three main stages:

1. Convolutional Block: Initial feature extraction.
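2. Residual Tower (Backbone): A stack of Residual Blocks that does most of the computation.
3. Heads: The network splits into two separate output paths: a policy head that outputs move probabilities $\mathbf{p}$, and a value head that outputs a scalar evaluation $v \in [-1, 1]$ of the position.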

### Why Residual Blocks?

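In board games, a piece at $(0, 0)$ can matter for a square at $(7, 7)$ (like a Bishop sniping across the board), so the network needs a receptive field that spans the whole board. Each $3 \times 3$ convolution (stride 1) widens the receptive field by one cell on each side, so $n$ stacked layers see a $(2n+1) \times (2n+1)$ window, and linking opposite corners of an $8 \times 8$ board takes at least 7 layers. This is why we need deep networks.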
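The layer count needed to link opposite corners can be checked empirically. The sketch below (a hypothetical helper, using single-channel $3 \times 3$ convolutions with no bias or activations so the stack is linear) asks whether output cell $(7, 7)$ of an $8 \times 8$ board receives any gradient from input cell $(0, 0)$.

```python
import torch
import torch.nn as nn


def corner_gradient(num_layers, board=8):
    """|d out[7, 7] / d in[0, 0]| for a stack of 3x3 convolutions (padding 1)."""
    torch.manual_seed(0)
    net = nn.Sequential(
        *[nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False) for _ in range(num_layers)]
    )
    x = torch.zeros(1, 1, board, board, requires_grad=True)
    net(x)[0, 0, board - 1, board - 1].backward()
    return x.grad[0, 0, 0, 0].abs().item()


for n in (6, 7):
    print(n, corner_gradient(n) > 0)  # prints "6 False", then "7 True"
```

With 6 layers the gradient is exactly zero because $(0, 0)$ lies outside the receptive field. Each `ResidualBlock` below contains two such convolutions, so it widens the receptive field by two cells per side.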

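However, naively stacking many layers makes the network hard to train: gradients tend to vanish (or explode) as they are multiplied through each layer's Jacobian. Residual connections (skip connections) add a block's input to its output, $y = F(x) + x$, so $\frac{\partial y}{\partial x} = \frac{\partial F}{\partial x} + I$. The identity term gives gradients a direct path back through the network during backpropagation, even when $\frac{\partial F}{\partial x}$ is small.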
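A toy experiment illustrates the identity term. Hypothetical setup: 30 fully connected `tanh` layers of width 16 with default PyTorch initialization, run with and without skip connections, measuring the norm of the gradient that reaches the input. This stands in for a convolutional tower and is not an AlphaZero measurement; exact numbers depend on the seed and initialization.

```python
import torch
import torch.nn as nn


def input_grad_norm(residual, depth=30, width=16):
    torch.manual_seed(0)  # same weights and input for both variants
    layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
    x = torch.randn(1, width, requires_grad=True)
    h = x
    for layer in layers:
        f = torch.tanh(layer(h))
        h = h + f if residual else f  # y = F(x) + x  versus  y = F(x)
    h.sum().backward()
    return x.grad.norm().item()


plain, skip = input_grad_norm(False), input_grad_norm(True)
print(f"plain: {plain:.2e}  residual: {skip:.2e}")
```

Expect the plain stack's input gradient to be many orders of magnitude smaller than the residual one.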

### Batch Normalization

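AlphaZero applies Batch Normalization after every convolution. BatchNorm normalizes each channel with mini-batch statistics, which stabilizes and speeds up the training of deep networks. This is useful in Reinforcement Learning, where the distribution of training data keeps shifting as the agent improves and generates different games.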
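One practical consequence, sketched below with a randomly generated batch (shapes are arbitrary placeholders): in training mode `nn.BatchNorm2d` normalizes each channel with the current mini-batch statistics, while in evaluation mode it uses running averages collected during training. Since self-play usually evaluates one position or a small batch at a time, it is common to call `model.eval()` for inference and `model.train()` for the update step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(32, 4, 3, 3) * 5 + 2  # batch far from zero mean / unit variance

bn.train()
y_train = bn(x)  # batch statistics; also updates running_mean / running_var
print(y_train.mean(dim=(0, 2, 3)))  # approximately 0 for every channel

bn.eval()
y_eval = bn(x)  # running statistics (only one update so far, momentum 0.1)
print(torch.allclose(y_train, y_eval))  # False
```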

```mermaid
graph TD
    Input["Input Tensor (B, C, H, W)"] --> ConvBlock[Conv2d + BN + ReLU]
    ConvBlock --> Res1[Residual Block 1]
    Res1 --> Res2[Residual Block 2]
    Res2 --> ResN[... Residual Block N]
    
    ResN --> Split{Split}
    
    subgraph Policy Head
        Split --> P_Conv["Conv 1x1 (2 filters)"]
        P_Conv --> P_BN[BatchNorm]
        P_BN --> P_Relu[ReLU]
        P_Relu --> P_Flat[Flatten]
        P_Flat --> P_FC[Linear]
        P_FC --> P_Out["Softmax (Probabilities)"]
    end
    
    subgraph Value Head
        Split --> V_Conv["Conv 1x1 (1 filter)"]
        V_Conv --> V_BN[BatchNorm]
        V_BN --> V_Relu[ReLU]
        V_Relu --> V_Flat[Flatten]
        V_Flat --> V_FC1[Linear + ReLU]
        V_FC1 --> V_FC2[Linear]
        V_FC2 --> V_Out["Tanh (Scalar -1 to 1)"]
    end
    
    style Split fill:#f9f,stroke:#333,stroke-width:2px
```

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(
            num_channels, num_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = nn.Conv2d(
            num_channels, num_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual  # Skip Connection
        out = F.relu(out)
        return out


class AlphaZeroNet(nn.Module):
    def __init__(self, game_channels, num_res_blocks, action_size, board_size):
        super().__init__()
        self.num_channels = 64

        # Initial Convolution
        self.conv_input = nn.Conv2d(
            game_channels, self.num_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn_input = nn.BatchNorm2d(self.num_channels)

        # Backbone
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(self.num_channels) for _ in range(num_res_blocks)]
        )

        # Policy Head
        self.p_conv = nn.Conv2d(self.num_channels, 2, kernel_size=1, bias=False)
        self.p_bn = nn.BatchNorm2d(2)
        self.p_fc = nn.Linear(2 * board_size * board_size, action_size)

        # Value Head
        self.v_conv = nn.Conv2d(self.num_channels, 1, kernel_size=1, bias=False)
        self.v_bn = nn.BatchNorm2d(1)
        self.v_fc1 = nn.Linear(1 * board_size * board_size, 64)
        self.v_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        # x: (Batch, Channels, Height, Width)
        x = F.relu(self.bn_input(self.conv_input(x)))

        for res_block in self.res_blocks:
            x = res_block(x)

        # Policy Head
        p = F.relu(self.p_bn(self.p_conv(x)))
        p = p.view(p.size(0), -1)  # Flatten
        p = F.softmax(self.p_fc(p), dim=1)

        # Value Head
        v = F.relu(self.v_bn(self.v_conv(x)))
        v = v.view(v.size(0), -1)  # Flatten
        v = F.relu(self.v_fc1(v))
        v = torch.tanh(self.v_fc2(v))

        return p, v

## 3. Data Generation: MCTS as a Teacher

The network creates its own training data via self-play.

For every move in a self-play game, we run an MCTS search from the current position. The result of the search is a set of visit counts $N(s, a)$ for the actions $a$ available at the root state $s$.

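Key Concept: The MCTS output ($\pi$) is usually a stronger policy than the Neural Network's raw output ($p$), because the search looks ahead from the current position. Training the network to predict the MCTS distribution distills that improvement back into $p$. Alternating search (policy improvement) with training on the search results and game outcomes (policy evaluation) is a form of approximate Policy Iteration.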
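A sketch of turning one finished self-play game into training targets, assuming each recorded step stores the encoded state, the root visit counts, and the player to move (`1` or `-1`), and that `winner` is `1`, `-1`, or `0` for a draw. The policy target here is the normalized visit distribution ($\tau = 1$), and the value target $z$ is the final result from the perspective of the player to move in that position. `make_examples` is a hypothetical helper.

```python
import numpy as np


def make_examples(trajectory, winner):
    """Convert one self-play game into (s, pi, z) training tuples.

    trajectory: list of (state, visit_counts, player_to_move)
    winner: 1 if Player 1 won, -1 if Player 2 won, 0 for a draw
    """
    examples = []
    for state, visits, player in trajectory:
        visits = np.asarray(visits, dtype=np.float64)
        pi = visits / visits.sum()  # MCTS policy target with tau = 1
        z = float(winner * player)  # +1 if the player to move went on to win
        examples.append((state, pi, z))
    return examples
```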

### Temperature ($\tau$)

To control exploration during data generation, we convert visit counts $N$ into probabilities $\pi$ using a temperature parameter $\tau$:

$$\pi(a|s) = \frac{N(s, a)^{1/\tau}}{\sum_b N(s, b)^{1/\tau}}$$

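- High $\tau$ (e.g., 1.0): At $\tau = 1$, $\pi$ is proportional to the raw visit counts, so less-visited moves are still sampled; $\tau > 1$ flattens the distribution further. Used early in the game (the first 30 moves in AlphaGo Zero) to ensure diverse opening play.
- Low $\tau$ ($\tau \rightarrow 0$): Sharpens the distribution toward the most-visited move (Argmax). Used later in the game for competitive play.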
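The temperature formula can be implemented with a small numerical safeguard: dividing the visit counts by their maximum before exponentiating prevents overflow when $1/\tau$ is large, and $\tau = 0$ is treated as the argmax limit. `visits_to_policy` is a hypothetical name, and splitting ties evenly in the argmax case is an assumption.

```python
import numpy as np


def visits_to_policy(visits, tau):
    """pi(a|s) proportional to N(s, a)^(1/tau); tau == 0 means argmax.

    Assumes at least one action has N(s, a) > 0 (true after any search).
    """
    visits = np.asarray(visits, dtype=np.float64)
    if tau == 0:
        best = visits == visits.max()
        return best / best.sum()  # one-hot, ties split evenly
    scaled = (visits / visits.max()) ** (1.0 / tau)  # largest entry is 1, so no overflow
    return scaled / scaled.sum()


visits = [10, 20, 150, 5, 40]
print(visits_to_policy(visits, 1.0).round(3))
print(visits_to_policy(visits, 0))  # [0. 0. 1. 0. 0.]
```

Scaling by the maximum does not change the result, because the common factor $(\max_b N)^{-1/\tau}$ cancels between numerator and denominator.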

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot_temperature(temp):
    # Dummy visit counts for 5 actions
    visits = np.array([10, 20, 150, 5, 40])
    actions = ["A1", "A2", "A3", "A4", "A5"]

    # Handle near-zero temperature for argmax behavior
    if temp < 0.05:
        probs = np.zeros_like(visits, dtype=float)
        probs[np.argmax(visits)] = 1.0
        title_suffix = "(Argmax / Exploitation)"
    else:
        # Apply temperature formula
        exponentiated = visits ** (1 / temp)
        probs = exponentiated / np.sum(exponentiated)
        title_suffix = "(Exploration)"

    # Plotting
    plt.figure(figsize=(8, 4))
    bars = plt.bar(actions, probs, color="#0969da")
    plt.ylim(0, 1.1)
    plt.ylabel("Probability π")
    plt.title(f"Target Policy Distribution at τ={temp:.2f} {title_suffix}")

    # Label each bar with its visit count just above the bar, so short bars stay readable
    for bar, visit in zip(bars, visits):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.02,
            f"N={visit}",
            ha="center",
            va="bottom",
            fontweight="bold",
        )

    plt.show()


w_temp = widgets.FloatSlider(
    value=1.0, min=0.01, max=2.0, step=0.05, description="Temp (τ):"
)

widgets.interactive(plot_temperature, temp=w_temp)


## 4. The Loss Function

We train the network to minimize error on both heads simultaneously. The total loss is a sum of the Value loss, the Policy loss, and regularization.

$$L = (z - v)^2 - \boldsymbol{\pi}^T \log \mathbf{p} + c||\theta||^2$$

1. Value Loss $(z - v)^2$:
   - MSE (Mean Squared Error) between the predicted value $v$ and the actual game outcome $z$ ($+1$ for a win, $-1$ for a loss, $0$ for a draw, from the perspective of the player to move).
   - This teaches the network to evaluate positions accurately.

2. Policy Loss $-\boldsymbol{\pi}^T \log \mathbf{p}$:
   - Cross-Entropy between the MCTS visit probabilities $\boldsymbol{\pi}$ (target) and the network's move probabilities $\mathbf{p}$ (prediction).
   - This teaches the network to imitate the search results.

3. Regularization $c||\theta||^2$:
   - Standard L2 weight decay to prevent overfitting to the specific games in the replay buffer.
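A small worked example with made-up numbers, ignoring the regularization term: suppose the game was won ($z = 1$), the value head predicted $v = 0.6$, and in a position with two legal moves the search produced $\boldsymbol{\pi} = (0.7, 0.3)$ while the network predicted $\mathbf{p} = (0.8, 0.2)$.

$$
\begin{aligned}
(z - v)^2 &= (1 - 0.6)^2 = 0.16 \\
-\boldsymbol{\pi}^T \log \mathbf{p} &= -(0.7 \ln 0.8 + 0.3 \ln 0.2) \approx 0.156 + 0.483 = 0.639 \\
L &\approx 0.16 + 0.639 = 0.799
\end{aligned}
$$

The policy term cannot drop below the entropy of the target, $H(\boldsymbol{\pi}) = -(0.7 \ln 0.7 + 0.3 \ln 0.3) \approx 0.611$, which it reaches exactly when $\mathbf{p} = \boldsymbol{\pi}$.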

In [4]:
class AlphaZeroLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred_policy, pred_value, target_policy, target_value):
        """
        pred_policy:   (B, A) move probabilities, i.e. the softmax output of AlphaZeroNet
        pred_value:    (B, 1) tanh output of AlphaZeroNet, in [-1, 1]
        target_policy: (B, A) MCTS visit distribution pi (each row sums to 1)
        target_value:  (B,) or (B, 1) game outcome z in {-1, 0, +1}
        """
        # 1. Value Loss (MSE): mean of (z - v)^2 over the batch
        value_loss = self.mse(
            pred_value.view(-1), target_value.view(-1).to(pred_value.dtype)
        )

        # 2. Policy Loss (Cross Entropy): mean of -sum(pi * log p) over the batch.
        # pred_policy holds probabilities, so a small epsilon guards against log(0).
        policy_loss = -torch.sum(
            target_policy * torch.log(pred_policy + 1e-8), dim=1
        ).mean()

        # Total Loss. The L2 term c * ||theta||^2 is not computed here; it is
        # usually applied through the optimizer's weight_decay argument.
        return value_loss + policy_loss
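`AlphaZeroNet` applies softmax and `AlphaZeroLoss` then takes a logarithm. A common, numerically safer alternative (an assumption here, not what the code above does) is to have the policy head return raw logits and use `F.log_softmax`, which removes the need for the `1e-8` epsilon. The sketch below checks on random data that the forms agree, and shows the L2 term delegated to the optimizer. In `torch.optim.SGD`, `weight_decay=wd` adds `wd * θ` to the gradient, which corresponds to $c = wd / 2$ in the loss formula; all values are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 9)                           # hypothetical policy-head logits (B=4, A=9)
target_pi = torch.softmax(torch.randn(4, 9), dim=1)  # stand-in MCTS distributions

# Epsilon form, as in AlphaZeroLoss (operates on probabilities)
loss_eps = -(target_pi * torch.log(torch.softmax(logits, dim=1) + 1e-8)).sum(dim=1).mean()
# Log-space form: log_softmax is computed stably, directly from the logits
loss_log = -(target_pi * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
# Built-in equivalent; probability targets require PyTorch >= 1.10
loss_ce = F.cross_entropy(logits, target_pi)
print(torch.allclose(loss_eps, loss_log, atol=1e-5), torch.allclose(loss_log, loss_ce, atol=1e-6))  # True True

# L2 regularization via the optimizer; model and hyperparameters are placeholders
model = nn.Linear(9, 9)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
```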

## 5. The Training Loop

The AlphaZero training process is cyclical:

1. Self-Play: The current best network plays games against itself (using MCTS). The games are saved as $(s, \pi, z)$ tuples.
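2. Symmetries: For board games whose rules are invariant under rotation and reflection (Go, Tic-Tac-Toe), we augment the data: rotating the board by 90 degrees or flipping it produces another valid state, and the policy target $\pi$ must be transformed the same way. The 8 symmetries of the square expand the dataset by 8x without new simulations. (AlphaGo Zero did this; AlphaZero dropped it, since Chess and Shogi are not symmetric.)
3. Training: We sample a mini-batch from the replay buffer and take a gradient step on $\theta$ (SGD with momentum in the original papers; Adam is a common alternative).
4. Evaluation: (Optional) In the original AlphaGo Zero, the new network played against the current best one and replaced it only if it won more than 55% of the games. In AlphaZero, this step was removed, and a single network was updated continuously.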
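A sketch of the 8-fold augmentation for Tic-Tac-Toe, assuming (as in the encoding section) that the action index equals the row-major cell index, so the policy vector can be reshaped onto the board and transformed identically. For games whose actions are not laid out one per cell, the policy would need its own index mapping. `dihedral_augment` is a hypothetical helper.

```python
import numpy as np


def dihedral_augment(board, pi):
    """Return the 8 (board, pi) pairs produced by rotations and reflections.

    board: (N, N) array; pi: length N*N policy indexed row-major like the board
    """
    n = board.shape[0]
    pi_grid = np.asarray(pi).reshape(n, n)
    pairs = []
    for k in range(4):  # 0, 90, 180, 270 degree rotations
        b = np.rot90(board, k)
        p = np.rot90(pi_grid, k)
        pairs.append((b, p.reshape(-1)))
        pairs.append((np.fliplr(b), np.fliplr(p).reshape(-1)))  # mirrored copy
    return pairs
```

Applying the same transform to both arrays keeps each move probability attached to its cell, which is the property the augmentation depends on.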

This cycle repeats thousands of times, with the MCTS search effectively acting as a policy improvement operator at every step.
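The replay buffer that connects self-play to training can be as simple as a bounded queue. A minimal sketch, assuming examples are `(s, pi, z)` tuples; the class name and capacity are hypothetical, and sampling uniformly from the most recent examples is an assumption in the spirit of the fixed window of recent games used by AlphaGo Zero.

```python
import random
from collections import deque


class ReplayBuffer:
    """Keeps the most recent `capacity` (s, pi, z) examples."""

    def __init__(self, capacity=10_000):
        self.data = deque(maxlen=capacity)  # oldest examples are dropped first

    def add_game(self, examples):
        self.data.extend(examples)

    def sample(self, batch_size):
        # Uniform sampling without replacement, capped at the buffer size
        return random.sample(list(self.data), min(batch_size, len(self.data)))

    def __len__(self):
        return len(self.data)
```

One cycle then becomes: play games, call `add_game` with each game's examples, and run several optimizer steps on `sample(batch_size)`.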