## 1. Input Representation

We cannot simply pass a single 2D grid of integer codes (one number per cell) to the network. Instead, we typically use one-hot encoding across multiple channels.

For a simple game like Tic-Tac-Toe, a state $s$ is often represented as a $(3, H, W)$ tensor:

1. Channel 0: Location of Player 1's pieces (binary).
2. Channel 1: Location of Player 2's pieces (binary).
3. Channel 2: Whose turn it is (filled with 1s if P1, 0s if P2).

Here $H$ and $W$ are the board's height and width (3x3 for Tic-Tac-Toe).
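The three channels listed above can be sketched in a few lines of plain Python. This is only an illustration: `encode_state` is a hypothetical helper, and it assumes the flat 9-element board convention used in this notebook (1 = X, -1 = O, 0 = empty). In practice the nested lists would be wrapped in a tensor before being fed to the network.

```python
def encode_state(board, player_to_move):
    """Encode a flat 9-cell board (1 = X, -1 = O, 0 = empty) as three 3x3 binary planes."""
    size = 3
    rows = [board[r * size:(r + 1) * size] for r in range(size)]

    # Channel 0: Player 1 (X) pieces, Channel 1: Player 2 (O) pieces
    p1_plane = [[1 if cell == 1 else 0 for cell in row] for row in rows]
    p2_plane = [[1 if cell == -1 else 0 for cell in row] for row in rows]

    # Channel 2: constant plane, all 1s if Player 1 is to move, else all 0s
    turn_value = 1 if player_to_move == 1 else 0
    turn_plane = [[turn_value] * size for _ in range(size)]

    return [p1_plane, p2_plane, turn_plane]  # shape (3, H, W)


planes = encode_state([0, 1, 0, 0, 1, -1, 0, -1, 1], player_to_move=1)
for name, plane in zip(("Ch 0 (P1)", "Ch 1 (P2)", "Ch 2 (turn)"), planes):
    print(name, plane)
```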

In complex games like Chess or Go, AlphaZero stacks the last $T$ board positions (History Planes) to capture move repetition and dynamic state.
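As a rough sketch of how history planes could be assembled, the hypothetical helper below concatenates the planes of the last `num_steps` positions, newest first. Filling the missing early positions with all-zero planes is an assumption made for this example, as are the function name and the nested-list representation.

```python
def stack_history(positions, num_steps):
    """Stack the planes of the last `num_steps` positions (num_steps >= 1).

    positions: list of encoded positions, oldest first; each is a list of 2D planes.
    Returns a flat list of planes, newest position first, zero-padded at the end.
    """
    planes_per_position = len(positions[-1])
    height = len(positions[-1][0])
    width = len(positions[-1][0][0])

    recent = positions[-num_steps:][::-1]  # newest first
    stacked = []
    for position in recent:
        stacked.extend(position)

    # Assumption: positions before the start of the game become zero planes
    missing = num_steps - len(recent)
    for _ in range(missing * planes_per_position):
        stacked.append([[0] * width for _ in range(height)])
    return stacked


empty = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
x_center = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
pos1 = [empty, empty]      # planes (X, O) before any move
pos2 = [x_center, empty]   # planes (X, O) after X plays the center
stacked = stack_history([pos1, pos2], num_steps=3)
print(len(stacked))  # 3 steps * 2 planes per position
```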

In [1]:
from IPython.display import display, HTML
import json
import random


class InteractiveTicTacToe:
    def __init__(self, initial_state=None, player_turn=1):
        if initial_state is None:
            initial_state = [0, 1, 0, 0, 1, -1, 0, -1, 1]

        self.state = list(initial_state)
        self.player_turn = 1 if player_turn == 1 else -1

        self.size = 200
        self.padding = 20
        self.cell_size = (self.size - self.padding * 2) / 3

        self.widget_id = f"ttt_{random.randint(1000, 9999)}"

    # -------- SVG helpers --------

    def _board_svg(self):
        """Base SVG for the board (grid + click areas + mark containers)."""
        svg = [
            f'<svg id="{self.widget_id}_board" width="{self.size}" height="{self.size}" '
            'style="display:block; margin: 0 auto;">'
        ]

        # Background
        svg.append(
            f'<rect x="{self.padding}" y="{self.padding}" '
            f'width="{self.size - self.padding * 2}" '
            f'height="{self.size - self.padding * 2}" '
            'fill="transparent" stroke-width="0" rx="0"></rect>'
        )

        # Grid lines
        for i in range(1, 3):
            x = self.padding + i * self.cell_size
            y = self.padding + i * self.cell_size
            svg.append(
                f'<line x1="{x}" y1="{self.padding}" x2="{x}" '
                f'y2="{self.size - self.padding}" stroke="#666" stroke-width="2"></line>'
            )
            svg.append(
                f'<line x1="{self.padding}" y1="{y}" '
                f'x2="{self.size - self.padding}" y2="{y}" stroke="#666" stroke-width="2"></line>'
            )

        # Clickable cells + mark containers
        for i in range(9):
            row = i // 3
            col = i % 3
            x = self.padding + col * self.cell_size
            y = self.padding + row * self.cell_size

            # Click area
            svg.append(
                f'<rect class="cell-{i}" x="{x}" y="{y}" '
                f'width="{self.cell_size}" height="{self.cell_size}" '
                'fill="transparent" stroke="none" style="cursor:pointer;" '
                f'onclick="handleCellClick_{self.widget_id}({i})"></rect>'
            )

            # Mark container (we’ll draw X/O inside via JS)
            svg.append(
                f'<g id="{self.widget_id}_mark_{i}" class="mark" '
                'style="pointer-events:none;"></g>'
            )

        svg.append("</svg>")
        return "".join(svg)

    def _encoding_svg(self):
        """SVG for the three encoding channels with per-cell IDs for fast updates."""
        channel_size = 140
        channel_padding = 14
        channel_cell_size = (channel_size - channel_padding * 2) / 3
        gap = 20
        total_width = 3 * channel_size + 2 * gap

        def render_channel(offset_x, ch_index, title, mask_fn, color):
            parts = []
            # Title
            parts.append(
                f'<text x="{offset_x + channel_size / 2}" '
                f'y="{channel_padding - 4}" text-anchor="middle" '
                'fill="#4D5461" font-family="sans-serif" font-size="12" font-weight="600">'
                f"{title}</text>"
            )

            # Grid lines
            for i in range(1, 3):
                x = offset_x + channel_padding + i * channel_cell_size
                y = channel_padding + i * channel_cell_size
                parts.append(
                    f'<line x1="{x}" y1="{channel_padding}" '
                    f'x2="{x}" y2="{channel_size - channel_padding}" '
                    'stroke="#999" stroke-width="1.5"></line>'
                )
                parts.append(
                    f'<line x1="{offset_x + channel_padding}" y1="{y}" '
                    f'x2="{offset_x + channel_size - channel_padding}" y2="{y}" '
                    'stroke="#999" stroke-width="1.5"></line>'
                )

            # Cells
            for i in range(9):
                row = i // 3
                col = i % 3
                x = offset_x + channel_padding + col * channel_cell_size
                y = channel_padding + row * channel_cell_size

                val = mask_fn(i)
                fill_opacity = "0.7" if val == 1 else "0.0"

                rect_id = f"{self.widget_id}_ch{ch_index}_cell{i}"
                text_id = f"{self.widget_id}_ch{ch_index}_text{i}"

                parts.append(
                    f'<rect id="{rect_id}" x="{x + 1}" y="{y + 1}" '
                    f'width="{channel_cell_size - 2}" height="{channel_cell_size - 2}" '
                    f'fill="{color}" fill-opacity="{fill_opacity}" '
                    'stroke="none" rx="0"></rect>'
                )
                parts.append(
                    f'<text id="{text_id}" x="{x + channel_cell_size / 2}" '
                    f'y="{y + channel_cell_size / 2 + 4}" text-anchor="middle" '
                    'fill="currentColor" font-family="monospace" font-size="11">'
                    f"{val}</text>"
                )

            return "".join(parts)

        # Initial mask functions (for first render only)
        def mask_x(i):
            return 1 if self.state[i] == 1 else 0

        def mask_o(i):
            return 1 if self.state[i] == -1 else 0

        def mask_turn(i):
            return 1 if self.player_turn == 1 else 0

        svg = [
            f'<svg id="{self.widget_id}_encoding" '
            f'width="{total_width}" height="{channel_size}" '
            'style="display:block; margin: 0 auto;">'
        ]

        svg.append(render_channel(0, 0, "Ch 0: X", mask_x, "#0969da"))
        svg.append(render_channel(channel_size + gap, 1, "Ch 1: O", mask_o, "#DA7A09"))
        svg.append(
            render_channel(2 * (channel_size + gap), 2, "Ch 2: Turn", mask_turn, "#EEE")
        )

        svg.append("</svg>")
        return "".join(svg)

    # -------- HTML + JS --------

    def generate_html(self):
        html = f"""
<div id="{self.widget_id}_container"
     style="font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; color: var(--text);">

  <div style="padding: 20px; margin-bottom: 20px;">
    <p style="text-align: center; font-size: 14px; margin-bottom: 8px;">
      Click empty cells to place current player's mark. Click occupied cells to clear them.
    </p>
    <p style="text-align: center; font-size: 13px; margin-bottom: 16px;">
      Current turn: <strong><span id="{self.widget_id}_turn">{"X" if self.player_turn == 1 else "O"}</span></strong>
    </p>
    {self._board_svg()}
    <div style="text-align: center; margin-top: 16px;">
      <button onclick="clearBoard_{self.widget_id}()"
              style="padding: 8px 16px; margin: 0 5px; background: #64748b; color: white;
                     border: none; border-radius: 4px; cursor: pointer; font-size: 14px; font-weight: 500;">
        Clear Board
      </button>
    </div>
  </div>

  <div style="padding: 20px; text-align: center;">
    <h3 style="margin-bottom: 8px;">State Vector</h3>
    <div id="{self.widget_id}_state"
         style="font-family: monospace; font-size: 14px; background: #f1f5f9; color: #333;
                padding: 12px; border-radius: 6px; width: fit-content; margin: 0 auto;">
      [{", ".join(map(str, self.state))}]
    </div>
    <p style="font-family: monospace; font-size: 12px; margin-top: 8px;">
      1 = X, -1 = O, 0 = Empty
    </p>
  </div>

  <div style="padding: 20px; margin-bottom: 20px;">
    <h3 style="text-align: center; margin-bottom: 8px;">Neural Network Input Encoding</h3>
    <p style="text-align: center; font-size: 13px; margin-bottom: 16px;">
      How the board state is represented as three channels
    </p>
    {self._encoding_svg()}
  </div>

</div>

<script>
(function() {{
  // ------- State -------
  let state_{self.widget_id} = {json.dumps(self.state)};
  let playerTurn_{self.widget_id} = {self.player_turn};

  // ------- Public handlers (for onclick attributes) -------
  function handleCellClick_{self.widget_id}(index) {{
    if (state_{self.widget_id}[index] === 0) {{
      // Place mark
      state_{self.widget_id}[index] = playerTurn_{self.widget_id};
      // Toggle turn
      playerTurn_{self.widget_id} = playerTurn_{self.widget_id} === 1 ? -1 : 1;
    }} else {{
      // Clear occupied cell, do not change turn
      state_{self.widget_id}[index] = 0;
    }}
    updateDisplay_{self.widget_id}();
  }}

  function clearBoard_{self.widget_id}() {{
    state_{self.widget_id} = [0,0,0,0,0,0,0,0,0];
    playerTurn_{self.widget_id} = 1;
    updateDisplay_{self.widget_id}();
  }}

  function toggleTurn_{self.widget_id}() {{
    playerTurn_{self.widget_id} = playerTurn_{self.widget_id} === 1 ? -1 : 1;
    updateDisplay_{self.widget_id}();
  }}

  // Expose to global for HTML onclick
  window.handleCellClick_{self.widget_id} = handleCellClick_{self.widget_id};
  window.clearBoard_{self.widget_id} = clearBoard_{self.widget_id};
  window.toggleTurn_{self.widget_id} = toggleTurn_{self.widget_id};

  // ------- Rendering helpers -------
  function updateDisplay_{self.widget_id}() {{
    updateBoard_{self.widget_id}();
    updateEncoding_{self.widget_id}();
    updateStateVector_{self.widget_id}();
    const turnEl = document.getElementById('{self.widget_id}_turn');
    if (turnEl) {{
      turnEl.textContent = playerTurn_{self.widget_id} === 1 ? 'X' : 'O';
    }}
  }}

  function updateBoard_{self.widget_id}() {{
    const svg = document.getElementById('{self.widget_id}_board');
    if (!svg) return;

    const size = {self.size};
    const padding = {self.padding};
    const cellSize = {self.cell_size};

    for (let i = 0; i < 9; i++) {{
      const g = document.getElementById('{self.widget_id}_mark_' + i);
      if (!g) continue;

      // Clear previous mark
      while (g.firstChild) {{
        g.removeChild(g.firstChild);
      }}

      const val = state_{self.widget_id}[i];
      if (val === 0) continue;

      const row = Math.floor(i / 3);
      const col = i % 3;
      const cx = padding + col * cellSize + cellSize / 2;
      const cy = padding + row * cellSize + cellSize / 2;

      if (val === 1) {{
        const sizeX = cellSize * 0.5;

        const line1 = document.createElementNS('http://www.w3.org/2000/svg', 'line');
        line1.setAttribute('x1', cx - sizeX / 2);
        line1.setAttribute('y1', cy - sizeX / 2);
        line1.setAttribute('x2', cx + sizeX / 2);
        line1.setAttribute('y2', cy + sizeX / 2);
        line1.setAttribute('stroke', '#0969da');
        line1.setAttribute('stroke-width', '3');
        line1.setAttribute('stroke-linecap', 'round');

        const line2 = document.createElementNS('http://www.w3.org/2000/svg', 'line');
        line2.setAttribute('x1', cx + sizeX / 2);
        line2.setAttribute('y1', cy - sizeX / 2);
        line2.setAttribute('x2', cx - sizeX / 2);
        line2.setAttribute('y2', cy + sizeX / 2);
        line2.setAttribute('stroke', '#0969da');
        line2.setAttribute('stroke-width', '3');
        line2.setAttribute('stroke-linecap', 'round');

        g.appendChild(line1);
        g.appendChild(line2);
      }} else if (val === -1) {{
        const circle = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
        circle.setAttribute('cx', cx);
        circle.setAttribute('cy', cy);
        circle.setAttribute('r', cellSize * 0.25);
        circle.setAttribute('fill', 'none');
        circle.setAttribute('stroke', '#DA7A09');
        circle.setAttribute('stroke-width', '3');
        g.appendChild(circle);
      }}
    }}
  }}

  function updateEncoding_{self.widget_id}() {{
    // Only update attributes of existing nodes — no innerHTML rebuilding
    const valTurn = playerTurn_{self.widget_id} === 1 ? 1 : 0;

    for (let i = 0; i < 9; i++) {{
      const valX = state_{self.widget_id}[i] === 1 ? 1 : 0;
      const valO = state_{self.widget_id}[i] === -1 ? 1 : 0;

      // Channel 0: X
      const rectX = document.getElementById('{self.widget_id}_ch0_cell' + i);
      const textX = document.getElementById('{self.widget_id}_ch0_text' + i);
      if (rectX) rectX.setAttribute('fill-opacity', valX === 1 ? '0.7' : '0.0');
      if (textX) textX.textContent = valX;

      // Channel 1: O
      const rectO = document.getElementById('{self.widget_id}_ch1_cell' + i);
      const textO = document.getElementById('{self.widget_id}_ch1_text' + i);
      if (rectO) rectO.setAttribute('fill-opacity', valO === 1 ? '0.7' : '0.0');
      if (textO) textO.textContent = valO;

      // Channel 2: Turn (same value for all cells)
      const rectT = document.getElementById('{self.widget_id}_ch2_cell' + i);
      const textT = document.getElementById('{self.widget_id}_ch2_text' + i);
      if (rectT) rectT.setAttribute('fill-opacity', valTurn === 1 ? '0.7' : '0.0');
      if (textT) textT.textContent = valTurn;
    }}
  }}

  function updateStateVector_{self.widget_id}() {{
    const el = document.getElementById('{self.widget_id}_state');
    if (!el) return;
    el.textContent = '[' + state_{self.widget_id}.join(', ') + ']';
  }}

  // Initial paint
  updateDisplay_{self.widget_id}();
}})();
</script>
"""
        return html

    def show(self):
        display(HTML(self.generate_html()))


ttt_instance = InteractiveTicTacToe()
ttt_instance.show()

<br><br>
## 2. Network Architecture

The AlphaZero network consists of three main stages:

1. Convolutional Block: Initial feature extraction.
2. Residual Tower (Backbone): A stack of Residual Blocks. 
3. Heads: The network splits into two separate output paths.


**Convolutional Block**<br>
Game boards have spatial structure, much like images. Convolutional layers are well suited to extracting local patterns, and because their weights are shared across positions, the same pattern is detected wherever it appears on the board.
A single convolution only sees a small neighborhood, so it cannot capture long-range dependencies on its own; stacking many layers grows the receptive field until it covers the whole board.
Other architectures are possible, but AlphaZero uses convolutional layers.


**Residual Blocks**<br>
Residual connections (skip connections) allow gradients to flow through the network unimpeded during backpropagation: $y = F(x) + x$, instead of just $y = F(x)$. This helps train very deep networks without suffering from vanishing gradients.
Each residual block wraps several convolutional layers (two in the implementation below) inside one such skip connection.
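To see why the skip connection helps, it is enough to differentiate $y = F(x) + x$ with respect to the block input. In the sketch below, $\mathcal{L}$ (the training loss) and $I$ (the identity) are symbols introduced here for illustration, and any activation applied after the addition is ignored:

$$\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\left(\frac{\partial F(x)}{\partial x} + I\right)$$

Even if $\partial F(x) / \partial x$ becomes very small, the identity term still passes the upstream gradient $\partial \mathcal{L} / \partial y$ through to earlier layers.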

**Multi-Headed Outputs**<br>
“Multi-head” here simply means one shared trunk network with two output heads:<br>
- policy head: $p(a \mid s)$
- value head: $v(s)$
  
Why is this better than two separate networks? It is more efficient (one forward pass instead of two), and the shared layers learn a common representation of the game state that serves both tasks, which tends to improve generalization.

```mermaid
---
config:
  theme: 'base'
  themeVariables:
    primaryColor: '#ffffff'
    primaryTextColor: '#4D5461'
    primaryBorderColor: '#096bda11'
    lineColor: '#096bda'
    secondaryColor: '#006100'
    tertiaryColor: '#4D5461'
---

flowchart TD
    A[Input] --> B[Conv + BN + ReLU]

    B --> C[Residual Block 1]
    C --> D[Residual Block 2]
    D --> E[...]
    E --> F[Residual Block N]

    %% Policy head
    F --> H[Policy Head:\nConv + BN + ReLU + FC]
    H --> I["Softmax over Moves\n(Policy π)"]

    %% Value head
    F --> J[Value Head:\nConv + BN + ReLU + FC]
    J --> K["Tanh - Scalar\n(Value v)"]

```

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(
            num_channels, num_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = nn.Conv2d(
            num_channels, num_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual  # Skip Connection
        out = F.relu(out)
        return out


class AlphaZeroNet(nn.Module):
    def __init__(self, game_channels, num_res_blocks, action_size, board_size):
        super().__init__()
        self.num_channels = 64

        # Initial Convolution
        self.conv_input = nn.Conv2d(
            game_channels, self.num_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn_input = nn.BatchNorm2d(self.num_channels)

        # Backbone
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(self.num_channels) for _ in range(num_res_blocks)]
        )

        # Policy Head
        self.p_conv = nn.Conv2d(self.num_channels, 2, kernel_size=1, bias=False)
        self.p_bn = nn.BatchNorm2d(2)
        self.p_fc = nn.Linear(2 * board_size * board_size, action_size)

        # Value Head
        self.v_conv = nn.Conv2d(self.num_channels, 1, kernel_size=1, bias=False)
        self.v_bn = nn.BatchNorm2d(1)
        self.v_fc1 = nn.Linear(1 * board_size * board_size, 64)
        self.v_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        # x: (Batch, Channels, Height, Width)
        x = F.relu(self.bn_input(self.conv_input(x)))

        for res_block in self.res_blocks:
            x = res_block(x)

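        # Policy Head
        p = F.relu(self.p_bn(self.p_conv(x)))
        p = p.view(p.size(0), -1)  # Flatten
        # Returns probabilities (convenient as MCTS priors). A loss that applies
        # log_softmax needs the pre-softmax output of p_fc instead.
        p = F.softmax(self.p_fc(p), dim=1)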

        # Value Head
        v = F.relu(self.v_bn(self.v_conv(x)))
        v = v.view(v.size(0), -1)  # Flatten
        v = F.relu(self.v_fc1(v))
        v = torch.tanh(self.v_fc2(v))

        return p, v

<br><br>
## 3. Data Generation: MCTS as a Teacher and a Tourist 


The system creates its own training data via self-play.

For every move in a self-play game, we execute an MCTS search. The result of the search is a set of visit counts $N(s, a)$ for the root node.

Then we convert these counts into a probability distribution $\pi(a|s)$ over actions, which serves as the target policy for training the neural network. For the value target, we use the eventual game outcome $z \in \{-1, 0, +1\}$ from the perspective of the current player at state $s$.
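One common way to turn visit counts into a policy is to raise them to the power $1/\tau$ and normalize, where $\tau$ is a temperature. The sketch below assumes that scheme; `visits_to_policy` and its arguments are hypothetical names, not part of this notebook's code.

```python
def visits_to_policy(visit_counts, temperature=1.0):
    """Turn root visit counts N(s, a) into a distribution pi(a|s).

    temperature = 1 is plain normalization; temperature -> 0 approaches argmax.
    """
    if temperature <= 1e-6:
        # Greedy limit: all mass on the most-visited action(s), split evenly on ties
        best = max(visit_counts)
        winners = [1.0 if n == best else 0.0 for n in visit_counts]
        return [w / sum(winners) for w in winners]

    scaled = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(scaled)
    if total == 0:
        raise ValueError("no visits recorded at the root")
    return [x / total for x in scaled]


counts = [10, 30, 0, 60]
print(visits_to_policy(counts))        # [0.1, 0.3, 0.0, 0.6]
print(visits_to_policy(counts, 0.0))   # [0.0, 0.0, 0.0, 1.0]
```

A higher temperature keeps more exploration in the self-play moves, while a temperature near zero makes the agent play its most-visited move.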

Concretely, we first generate self-play games using MCTS guided by the current neural network. For each move, we record the state $s$ and the MCTS-derived policy $\pi(a|s)$. When the game ends, we assign the final outcome $z$ to every state in the game, from the perspective of the player to move in that state.
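The labelling step can be sketched as follows, assuming players are encoded as +1 and -1 and that each recorded position also stores who was to move. `label_game` and the tuple layout are hypothetical choices for this example.

```python
def label_game(history, winner):
    """Attach the final outcome to every recorded position.

    history: list of (state, pi, player_to_move) tuples, player_to_move in {+1, -1}.
    winner:  +1 or -1 for the winning player, 0 for a draw.
    Returns a list of (state, pi, z) with z seen from the player to move.
    """
    # winner * player is +1 if the player to move eventually won,
    # -1 if they lost, and 0 for a draw
    return [(state, pi, winner * player) for state, pi, player in history]


history = [("s0", [1.0], 1), ("s1", [1.0], -1), ("s2", [1.0], 1)]
samples = label_game(history, winner=1)
print([z for _, _, z in samples])  # [1, -1, 1]
```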

This way the neural network learns from MCTS, which in turn is guided by the neural network, creating a virtuous cycle of improvement. Hence, MCTS acts as both a Teacher (providing better policies) and a Tourist (exploring the game space), while the neural network acts as both a Student and a Tour Guide.

```mermaid
---
config:
  theme: 'base'
  themeVariables:
    primaryColor: '#ffffff'
    primaryTextColor: '#4D5461'
    primaryBorderColor: '#096bda11'
    lineColor: '#096bda'
    secondaryColor: '#ffffff84'
    secondaryTextColor: '#4D5461'
    tertiaryColor: '#4d546110'
    tertiaryTextColor: '#4D5461'
---
flowchart TD
    Net[Neural Network]
    
    subgraph SelfPlay [Self-Play]
        MCTS[MCTS Search]
        Data["Game Data (s, π, z)"]
    end

    Buffer[Replay Buffer]

    subgraph Training [Training]
        Loss["Loss Function"]
        Update[Update Weights]
    end

    Net -->|Guides| MCTS
    MCTS -->|Generates| Data
    Data --> Buffer
    Buffer -->|Sample| Loss
    Loss --> Update
    Update -->|New Weights| Net

    
```
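The replay buffer in the diagram above could look roughly like the sketch below. The class name, the fixed capacity, and uniform sampling are assumptions made for illustration; real implementations vary in how they evict and sample data.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size store of (s, pi, z) training samples."""

    def __init__(self, capacity):
        # With maxlen set, the oldest samples are dropped once the buffer is full
        self.samples = deque(maxlen=capacity)

    def add_game(self, labelled_positions):
        """Append all (s, pi, z) tuples from one finished self-play game."""
        self.samples.extend(labelled_positions)

    def sample(self, batch_size):
        """Draw a uniform random batch without replacement."""
        k = min(batch_size, len(self.samples))
        return random.sample(list(self.samples), k)

    def __len__(self):
        return len(self.samples)


buffer = ReplayBuffer(capacity=1000)
buffer.add_game([("s0", [0.5, 0.5], 1), ("s1", [0.9, 0.1], -1)])
batch = buffer.sample(batch_size=2)
print(len(buffer), len(batch))
```

Keeping only recent games means the network trains mostly on data produced by recent versions of itself.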


<br><br>
**The following is an example of generating data for 7 moves of Tic-Tac-Toe using MCTS and visualizing the board states along with the MCTS action probabilities.**

The outcomes alternate in sign because each one is given from the perspective of the player to move in that state.

<br><br>

In [18]:
# for example
# we play a game of tic tac toe from an initial state

import random


def draw_tic_tac_toe(board, size=300, padding=10):
    """
    Draw a tic-tac-toe board as SVG
    """
    cell_size = (size - padding * 2) / 3
    svg = f'<svg width="{size}" height="{size}" style="display:block;">'
    # Background
    svg += f'<rect x="{padding}" y="{padding}" width="{size - padding * 2}" height="{size - padding * 2}" fill="#fafafa00" stroke-width="0" rx="4"></rect>'
    # Grid lines
    for i in range(1, 3):
        # Vertical
        x = padding + i * cell_size
        svg += f'<line x1="{x}" y1="{padding}" x2="{x}" y2="{size - padding}" stroke="#666" stroke-width="2"></line>'
        # Horizontal
        y = padding + i * cell_size
        svg += f'<line x1="{padding}" y1="{y}" x2="{size - padding}" y2="{y}" stroke="#666" stroke-width="2"></line>'
    # Marks
    for i in range(9):
        row = i // 3
        col = i % 3
        cx = padding + col * cell_size + cell_size / 2
        cy = padding + row * cell_size + cell_size / 2
        val = board[i] if i < len(board) else 0
        if val == 1:
            size_x = cell_size * 0.5
            svg += f'<line x1="{cx - size_x / 2}" y1="{cy - size_x / 2}" x2="{cx + size_x / 2}" y2="{cy + size_x / 2}" stroke="#0969da" stroke-width="2" stroke-linecap="round"></line>'
            svg += f'<line x1="{cx + size_x / 2}" y1="{cy - size_x / 2}" x2="{cx - size_x / 2}" y2="{cy + size_x / 2}" stroke="#0969da" stroke-width="2" stroke-linecap="round"></line>'
        elif val == -1:
            svg += f'<circle cx="{cx}" cy="{cy}" r="{cell_size * 0.25}" fill="none" stroke="#DA7A09" stroke-width="2"></circle>'
    svg += "</svg>"
    return svg


initial_state = [0, 0, 0, 0, 0, 0, 0, 0, 0]
player = 1  # 1 for X, -1 for O
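# Scripted demo game: X completes the middle row (cells 3-4-5) on move 7,
# so z = +1 for states where X is to move (the 8th entry is never played)
list_moves = [4, 0, 8, 1, 3, 6, 5, 2]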
viz_svg_list = []
probs_list = []
outcome_list = []
for j in range(7):
    viz = HTML(draw_tic_tac_toe(initial_state, size=100))

    empty_indices = [i for i, v in enumerate(initial_state) if v == 0]
    # Generate dummy MCTS action probabilities (random weights on empty cells, normalized)
    probs = [0.0] * 9
    for idx in empty_indices:
        probs[idx] = random.random()
    total = sum(probs)
    if total > 0:
        probs = [p / total for p in probs]

    move = list_moves.pop(0)
    initial_state[move] = player
    player *= -1

    viz_svg_list.append(viz)

    if j != 6:
        probs_list.append(probs)
        outcome_list.append(str("&nbsp;1" if player == -1 else "-1"))
    else:
        probs_list.append([])
        outcome_list.append(" ")

# make a table to show the moves and mcts stats
table_html = '<table style="border-collapse: collapse; width: auto; border: 1px solid #ddd !important;">'
table_html += '<tr><th style="border: 0px solid #ddd; padding: 10px 20px 10px 20px;">Move</th><th style="border: 0px solid #ddd; padding: 10px 20px 10px 20px;">Board State</th><th style="border: 0px solid #ddd; padding: 10px 20px 10px 20px;">MCTS Action Probabilities</th><th style="border: 0px solid #ddd; padding: 10px 20px 10px 20px;">Outcome</th></tr>'
for i, viz in enumerate(viz_svg_list):
    probs_str = ", ".join(f"{p:.2f}" for p in probs_list[i])
    row = f"""<td style="border: 0px solid #ddd; padding: 10px 20px 10px 20px; text-align: center; ">{i + 1}</td>"""
    row += f'<td style="border: 0px solid #ddd; padding: 0px 20px 0px 20px; text-align: center;align-items: center;display:flex;justify-content:center;">{viz.data}</td>'
    row += f'<td style="border: 0px solid #ddd; padding: 0px 20px 0px 20px; text-align: center; font-family: monospace;font-size:12px;">[{probs_str}]</td>'
    row += f"""<td style="border: 0px solid #ddd; padding: 0px 20px 0px 20px; text-align: center;font-family: monospace; font-size:16px;">{outcome_list[i]}</td>"""
    table_html += f"<tr>{row}</tr>"

table_html += "</table>"

display(HTML(table_html))

Move,Board State,MCTS Action Probabilities,Outcome
1,,"[0.17, 0.16, 0.15, 0.15, 0.05, 0.02, 0.03, 0.15, 0.12]",1.0
2,,"[0.13, 0.00, 0.21, 0.14, 0.00, 0.18, 0.15, 0.17, 0.02]",-1.0
3,,"[0.00, 0.04, 0.03, 0.23, 0.00, 0.37, 0.24, 0.08, 0.02]",1.0
4,,"[0.00, 0.41, 0.07, 0.02, 0.00, 0.01, 0.08, 0.41, 0.00]",-1.0
5,,"[0.00, 0.00, 0.41, 0.23, 0.00, 0.05, 0.20, 0.11, 0.00]",1.0
6,,"[0.00, 0.00, 0.19, 0.00, 0.00, 0.41, 0.34, 0.05, 0.00]",-1.0
7,,[],


<br><br><br>
**Here is a more comprehensive diagram of the self-play and training loop:**
<br><br>

```mermaid
---
config:
  theme: 'base'
  themeVariables:
    primaryColor: '#ffffff'
    primaryTextColor: '#4D5461'
    primaryBorderColor: '#096bda11'
    lineColor: '#096bda'
    secondaryColor: '#ffffff84'
    secondaryTextColor: '#4D5461'
    tertiaryColor: '#4d546110'
    tertiaryTextColor: '#4D5461'
---
flowchart TD
    %% Define Nodes
    State[Current Board State s]
    Net[Current Neural Network]
    
    subgraph MCTS_Process [MCTS Simulation]
        Search[Run MCTS Simulations]
        Visits[Get Visit Counts N]
    end

    Outcome[Game Ends: Outcome z]
    
    %% Define Data Tuple
    TargetPi["Target Policy π\n(Normalized Visits)"]
    TargetZ["Target Value z\n(+1 / -1 / 0)"]
    
    Buffer["Replay Buffer\n(s, π, z)"]

    %% Connections
    State --> Search
    Net -->|Guides| Search
    Search --> Visits
    Visits -->|Convert w/ Temp τ| TargetPi

    MCTS_Process --> Outcome
    
    Outcome -->|Assign Winner| TargetZ
    
    %% Aggregation
    State -.-> Buffer
    TargetPi -.-> Buffer
    TargetZ -.-> Buffer

    %% Training Loop
    subgraph Training [Network Training]
        Batch["Sample Batch\n(s, π, z)"]
        Loss["Calculate Loss\nMSE(z, v) + CrossEntropy(π, p)"]
        Update["Backprop & Update Weights"]
    end

    %% Connections for Training
    Buffer -->|Sample| Batch
    Batch --> Loss
    Net -.->|Current Weights| Loss
    Loss --> Update
    Update -->|New Weights| Net
```

## 4. The Loss Function

We train the network to minimize error on both heads simultaneously. The total loss is a sum of the Value loss, the Policy loss, and regularization.

$$L = (z - v)^2 - \boldsymbol{\pi}^T \log \mathbf{p} + c||\theta||^2$$

1. Value Loss $(z - v)^2$:
   - MSE (Mean Squared Error) between the predicted value $v$ and the actual game outcome $z$ ($+1$ for a win, $-1$ for a loss, $0$ for a draw).
   - This teaches the network to evaluate positions accurately.

2. Policy Loss $-\boldsymbol{\pi}^T \log \mathbf{p}$:
   - Cross-Entropy between the MCTS visit probabilities $\boldsymbol{\pi}$ (target) and the network's move probabilities $\mathbf{p}$ (prediction).
   - This teaches the network to imitate the search results.

3. Regularization $c||\theta||^2$:
   - Standard L2 weight decay to prevent overfitting to the specific games in the replay buffer.
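A small worked example can make the three terms concrete. The sketch below evaluates the formula for a single position in plain Python; the function name, the toy numbers, and the flat `theta` list are all made up for illustration.

```python
import math


def alphazero_loss(z, v, pi, p, theta=(), c=0.0):
    """Evaluate L = (z - v)^2 - pi^T log p + c * ||theta||^2 for one position."""
    value_loss = (z - v) ** 2
    # Terms with pi_a == 0 contribute nothing, so skip them to avoid log(0)
    policy_loss = -sum(t * math.log(q) for t, q in zip(pi, p) if t > 0)
    l2_penalty = c * sum(w * w for w in theta)
    total = value_loss + policy_loss + l2_penalty
    return total, value_loss, policy_loss, l2_penalty


total, value_loss, policy_loss, l2_penalty = alphazero_loss(
    z=1.0, v=0.5, pi=[0.5, 0.5, 0.0], p=[0.25, 0.75, 0.0]
)
print(f"value={value_loss:.4f} policy={policy_loss:.4f} total={total:.4f}")
# value=0.2500 policy=0.8370 total=1.0870
```

The policy term reaches its minimum when $\mathbf{p}$ equals $\boldsymbol{\pi}$, and the value term is zero when $v = z$.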

In [None]:
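class AlphaZeroLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred_policy, pred_value, target_policy, target_value):
        """
        pred_policy: raw logits (or log-probabilities) from the network.
            Do not pass softmax probabilities such as the `p` returned by
            AlphaZeroNet.forward above; use the pre-softmax output instead.
        target_policy: MCTS probabilities
        pred_value: Tanh output from network
        target_value: Actual game result {-1, 0, 1}
        """
        # Value loss: MSE between predicted value v and game outcome z
        value_loss = self.mse(pred_value.view(-1), target_value.view(-1))

        # Policy loss: cross-entropy against the soft MCTS targets
        log_predicted_policy = F.log_softmax(pred_policy, dim=1)
        policy_loss = -torch.sum(target_policy * log_predicted_policy, dim=1).mean()

        # The L2 term c||theta||^2 is not computed here; it is commonly applied
        # through the optimizer's weight_decay argument instead.
        total_loss = value_loss + policy_loss

        return total_loss, value_loss, policy_loss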

**Value Loss (Mean Squared Error):**

Mean Squared Error (MSE) between the predicted value $v$ and the target value $z$:

$$L_{\text{value}} = (z - v)^2$$

**Policy Loss (Cross-Entropy / Negative Log-Likelihood):**

`F.cross_entropy` would be convenient because it applies `log_softmax` internally, but it is most commonly used with integer class labels (recent PyTorch versions also accept class-probability targets). Since our target policy is a probability distribution (from MCTS visit counts), the code above computes the cross-entropy manually: $L_{\text{policy}} = -\sum_a \pi(a \mid s) \log p(a \mid s)$.