# Thought Process 

## Value Network

This part accepts a board state and gives us a score as the output. 
The simple implementation used a basic function that compared the number of 
pieces on each side and assigned the position a value based on that. 

However, in this case we have a neural network that takes in a position and gives an estimate of the 
value of the position. The idea is that the value network works like a classifier, giving something like a 
probability that White wins or that Black wins (this can be represented as a single signed score, where negative values favour Black).

## Policy Network: 
This part accepts a board state as an input and gives a set of probabilities, 
one per move, where a higher probability means that the 
move is more likely to lead to a win. 

## MCTS: 
The basic idea is to start with an empty tree, then use MCTS to build up a portion of the game tree by running 
a number of simulations, where each simulation adds a node to the tree.  


## Notes: 

In the original implementation we used a 2D list, which is fine, but here 
using a NumPy tensor is going to be much faster, and thus we need to re-represent the 
board (the game state).

In [None]:
from __future__ import annotations

import copy
import random
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import Engine

NUM_MOVES = 8 ** 4 

def move_to_index(mv: Engine.move) -> int:
    """Map a move to its slot in the flat (from-square, to-square) encoding.

    Squares are numbered ``row * 8 + col``; the result is
    ``from_square * 64 + to_square``.  Only the four coordinate attributes
    of ``mv`` are read, so moves that share start and end squares share an
    index.
    """
    from_square = 8 * mv.start_row + mv.start_col
    to_square = 8 * mv.end_row + mv.end_col
    return 64 * from_square + to_square


In [None]:
class PolicyNetSmall(nn.Module):
    """Three-layer convolutional policy head producing one logit per move slot.

    Input is a ``(batch, 13, 8, 8)`` tensor; output is ``(batch, NUM_MOVES)``
    raw logits (no softmax applied).  With ``NUM_MOVES == 4096`` the model has
    roughly 16.8 M parameters, almost all of them in the final linear layer
    (the three convolutions together hold only about 59 K).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so seeded initialisation and
        # state-dict keys stay stable.
        self.conv1 = nn.Conv2d(13, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.flat = nn.Flatten()
        self.fc = nn.Linear(64 * 8 * 8, NUM_MOVES)

    def forward(self, x):
        """Return raw move logits of shape ``(batch, NUM_MOVES)``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        return self.fc(self.flat(x))

class ValueNet(nn.Module):
    """Convolutional value head with batch normalisation (~9.3 M parameters).

    Input is a ``(batch, 13, 8, 8)`` tensor; output is a ``(batch,)`` tensor
    squashed into ``[-1, 1]`` by ``tanh``.  Most of the parameters sit in the
    ``256 * 8 * 8 -> 512`` linear layer.  Because of the BatchNorm layers the
    output depends on whether the module is in train or eval mode.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so seeded initialisation and
        # state-dict keys stay stable.
        self.conv1 = nn.Conv2d(13, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.flat = nn.Flatten()
        self.fc = nn.Linear(256 * 8 * 8, 512)
        self.val = nn.Linear(512, 1)

    def forward(self, x):
        """Return one scalar in ``[-1, 1]`` per input position."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        hidden = F.relu(self.fc(self.flat(x)))
        # squeeze(1) drops the singleton feature axis: (batch, 1) -> (batch,)
        return torch.tanh(self.val(hidden)).squeeze(1)


# DualNet 

In [None]:
class DualNet:
    """Owns a PolicyNetSmall and a ValueNet plus the optimiser that trains them.

    Provides single-position inference (``predict``), joint training on
    self-play data (``train``), separate pretraining of each network, and
    save/load of the two state dicts.

    Mode handling: ValueNet contains BatchNorm layers, so ``predict`` always
    switches both networks to eval mode (otherwise a single position would be
    normalised with its own statistics and would overwrite the running
    statistics), and every training method switches them back to train mode.
    """

    def __init__(self, device: str | None = None):
        """Build both networks on ``device`` (CUDA if available, else CPU)."""
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.policy = PolicyNetSmall().to(self.device)
        self.value  = ValueNet().to(self.device)
        # One Adam instance covers both networks; see _step() for how the
        # network that is not being trained is kept untouched.
        self.opt = torch.optim.Adam(list(self.policy.parameters()) + list(self.value.parameters()), lr=1e-3)
        self.kl  = nn.KLDivLoss(reduction="batchmean")
        self.mse = nn.MSELoss()

    # ── internal helpers ────────────────────────────────────────────────
    def _set_training(self, training: bool) -> None:
        """Put both networks into train (True) or eval (False) mode."""
        self.policy.train(training)
        self.value.train(training)

    def _epoch_batches(self, n: int, batch_sz: int):
        """Yield index tensors covering ``range(n)`` once in random order."""
        perm = torch.randperm(n, device=self.device)
        for start in range(0, n, batch_sz):
            yield perm[start:start + batch_sz]

    def _step(self, loss: torch.Tensor) -> None:
        """Back-propagate ``loss`` and apply one optimiser step."""
        loss.backward()
        self.opt.step()
        # Clear gradients to None (not zero) so that parameters which did
        # not take part in this loss are skipped by Adam instead of being
        # moved by leftover momentum.
        self.opt.zero_grad(set_to_none=True)

    #  inference
    @torch.no_grad()
    def predict(self, state: Engine.GameState) -> Tuple[Dict[Engine.move, float], float]:
        """Evaluate one position.

        Returns ``(move_probs, value)`` where ``move_probs`` maps every legal
        move to a probability (summing to 1 over the legal moves) and
        ``value`` is the ValueNet output in ``[-1, 1]``.  If the position has
        no legal moves the dict is empty.
        """
        self._set_training(False)
        board = torch.tensor(state.get_current_state(), dtype=torch.float32, device=self.device).unsqueeze(0)
        logits = self.policy(board).squeeze(0)
        v = float(self.value(board).item())

        legal_moves = state.get_all_valid_moves()
        if not legal_moves:
            return {}, v
        # Softmax over the legal moves' logits only.  This equals a softmax
        # over all NUM_MOVES logits followed by renormalisation, but cannot
        # underflow to an all-zero distribution.
        idx = torch.tensor([move_to_index(mv) for mv in legal_moves], dtype=torch.long, device=self.device)
        probs = torch.softmax(logits[idx], dim=0).cpu().tolist()
        return dict(zip(legal_moves, probs)), v

    #  joint self‑play training (MCTS targets)
    def train(self, batch: List[Tuple[np.ndarray, Dict[Engine.move, float], int]], epochs=1, batch_sz=64):
        """Train both networks on ``(board, move -> prob, outcome)`` samples.

        The policy is fitted to the move distribution with a KL loss and the
        value network to ``outcome`` with an MSE loss; the two losses are
        summed.  Does nothing for an empty batch.
        """
        if not batch:
            return
        self._set_training(True)
        boards, targets, outcomes = [], [], []
        for board, pi, outcome in batch:
            boards.append(torch.tensor(board, dtype=torch.float32))
            vec = np.zeros(NUM_MOVES, dtype=np.float32)
            for mv, p in pi.items():
                # Accumulate: distinct moves may share an index.
                vec[move_to_index(mv)] += p
            targets.append(torch.tensor(vec, dtype=torch.float32))
            outcomes.append(float(outcome))
        X = torch.stack(boards).to(self.device)
        P = torch.stack(targets).to(self.device)
        V = torch.tensor(outcomes, dtype=torch.float32).to(self.device)

        self.opt.zero_grad(set_to_none=True)
        for _ in range(epochs):
            for idx in self._epoch_batches(len(batch), batch_sz):
                xb, pb, vb = X[idx], P[idx], V[idx]
                loss_p = self.kl(F.log_softmax(self.policy(xb), dim=1), pb)
                loss_v = self.mse(self.value(xb), vb)
                self._step(loss_p + loss_v)

    #  policy pretraining
    def pretrain_policy(self, examples: List[Tuple[np.ndarray, np.ndarray]], epochs=3, batch_sz=256):
        """Fit the policy network to ``(board, target distribution)`` pairs.

        Each target is a length-``NUM_MOVES`` probability vector.  The value
        network's weights are not updated.  Does nothing for empty input.
        """
        if not examples:
            return
        self._set_training(True)
        X = torch.stack([torch.tensor(b, dtype=torch.float32) for b, _ in examples]).to(self.device)
        Y = torch.stack([torch.tensor(y, dtype=torch.float32) for _, y in examples]).to(self.device)
        self.opt.zero_grad(set_to_none=True)
        for _ in range(epochs):
            for idx in self._epoch_batches(len(examples), batch_sz):
                logits = self.policy(X[idx])
                self._step(self.kl(F.log_softmax(logits, dim=1), Y[idx]))

    #  value pretraining
    def pretrain_value(self, examples: List[Tuple[np.ndarray, float]], epochs=1, batch_sz=256):
        """Fit the value network to ``(board, scalar target)`` pairs with MSE.

        The policy network's weights are not updated.  Does nothing for
        empty input.
        """
        if not examples:
            return
        self._set_training(True)
        X = torch.stack([torch.tensor(b, dtype=torch.float32) for b, _ in examples]).to(self.device)
        y = torch.tensor([v for _, v in examples], dtype=torch.float32).to(self.device)
        self.opt.zero_grad(set_to_none=True)
        for _ in range(epochs):
            for idx in self._epoch_batches(len(examples), batch_sz):
                self._step(self.mse(self.value(X[idx]), y[idx]))

    #  persistence
    def save(self, path_root="dualnet"):
        """Write ``{path_root}_policy.pth`` and ``{path_root}_value.pth``."""
        torch.save(self.policy.state_dict(), f"{path_root}_policy.pth")
        torch.save(self.value.state_dict(),  f"{path_root}_value.pth")

    def load(self, path_root="dualnet"):
        """Load both state dicts written by ``save`` and switch to eval mode.

        The optimiser state is not persisted, so it is left as it was.
        """
        self.policy.load_state_dict(torch.load(f"{path_root}_policy.pth", map_location=self.device))
        self.value.load_state_dict(torch.load(f"{path_root}_value.pth",  map_location=self.device))
        self._set_training(False)


# Bootstrapping labels



In [None]:
# Material values in centipawns, keyed by the piece letter (second character
# of a two-character piece code such as "wp" or "bN"; pawns are lowercase "p").
# The king carries no material value.
PV = {"p":100, "N":320, "B":330, "R":500, "Q":900, "K":0}
# Simple piece-square tables (centipawns), written from White's point of view:
# row 0 is the far side of the board (where White promotes) and row 7 is
# White's home rank.  Black looks the tables up mirrored by row (7 - r).
# NOTE(review): this assumes Engine stores Black's back rank at board row 0,
# which is the orientation the pawn table implies - confirm against Engine.
# All six tables use this same orientation; previously the bishop, rook and
# king tables were stored upside-down relative to the pawn, knight and queen
# tables (rewarding a king on the far back rank and penalising it at home).
PST_P = np.array([
    [ 0,  0,  0,  0,  0,  0,  0,  0],
    [50, 50, 50, 50, 50, 50, 50, 50],
    [10, 10, 20, 30, 30, 20, 10, 10],
    [ 5,  5, 10, 25, 25, 10,  5,  5],
    [ 0,  0,  0, 20, 20,  0,  0,  0],
    [ 5, -5,-10,  0,  0,-10, -5,  5],
    [ 5, 10, 10,-20,-20, 10, 10,  5],
    [ 0,  0,  0,  0,  0,  0,  0,  0],
], dtype=np.int16)
PST_N = np.array([
    [-50,-40,-30,-30,-30,-30,-40,-50],
    [-40,-20,  0,  0,  0,  0,-20,-40],
    [-30,  0, 10, 15, 15, 10,  0,-30],
    [-30,  5, 15, 20, 20, 15,  5,-30],
    [-30,  0, 15, 20, 20, 15,  0,-30],
    [-30,  5, 10, 15, 15, 10,  5,-30],
    [-40,-20,  0,  5,  5,  0,-20,-40],
    [-50,-40,-30,-30,-30,-30,-40,-50],
], dtype=np.int16)
PST_B = np.array([
    [-20,-10,-10,-10,-10,-10,-10,-20],
    [-10,  0,  0,  0,  0,  0,  0,-10],
    [-10,  0,  5, 10, 10,  5,  0,-10],
    [-10,  5,  5, 10, 10,  5,  5,-10],
    [-10,  0, 10, 10, 10, 10,  0,-10],
    [-10, 10, 10, 10, 10, 10, 10,-10],
    [-10,  5,  0,  0,  0,  0,  5,-10],
    [-20,-10,-10,-10,-10,-10,-10,-20],
], dtype=np.int16)
PST_R = np.array([
    [ 0,  0,  0,  0,  0,  0,  0,  0],
    [ 5, 10, 10, 10, 10, 10, 10,  5],
    [-5,  0,  0,  0,  0,  0,  0, -5],
    [-5,  0,  0,  0,  0,  0,  0, -5],
    [-5,  0,  0,  0,  0,  0,  0, -5],
    [-5,  0,  0,  0,  0,  0,  0, -5],
    [-5,  0,  0,  0,  0,  0,  0, -5],
    [ 0,  0,  0,  5,  5,  0,  0,  0],
], dtype=np.int16)
PST_Q = np.array([
    [-20,-10,-10, -5, -5,-10,-10,-20],
    [-10,  0,  0,  0,  0,  0,  0,-10],
    [-10,  0,  5,  5,  5,  5,  0,-10],
    [ -5,  0,  5,  5,  5,  5,  0, -5],
    [  0,  0,  5,  5,  5,  5,  0, -5],
    [-10,  5,  5,  5,  5,  5,  0,-10],
    [-10,  0,  5,  0,  0,  0,  0,-10],
    [-20,-10,-10, -5, -5,-10,-10,-20],
], dtype=np.int16)
PST_K = np.array([
    [-30,-40,-40,-50,-50,-40,-40,-30],
    [-30,-40,-40,-50,-50,-40,-40,-30],
    [-30,-40,-40,-50,-50,-40,-40,-30],
    [-30,-40,-40,-50,-50,-40,-40,-30],
    [-20,-30,-30,-40,-40,-30,-30,-20],
    [-10,-20,-20,-20,-20,-20,-20,-10],
    [ 20, 20,  0,  0,  0,  0, 20, 20],
    [ 20, 30, 10,  0,  0, 10, 30, 20],
], dtype=np.int16)
# Lookup from piece letter to its table; keys match PV.
PST_MAP = {"p": PST_P, "N": PST_N, "B": PST_B, "R": PST_R, "Q": PST_Q, "K": PST_K}

class BootstrapPretrainer:
    """Generates bootstrap training labels from a static material + PST eval.

    Policy labels are a softmax over the one-ply change in static evaluation;
    value labels are the tanh-squashed static evaluation itself.
    """

    def __init__(self, temperature: float = 0.7):
        """Store the softmax temperature, clamped to at least 1e-3.

        The temperature divides evaluation differences expressed in
        centipawns, so small values make the distribution close to one-hot.
        """
        self.t = max(temperature, 1e-3)

    @staticmethod
    def eval_board_white(board: List[List[str]]) -> int:
        """Static eval in centipawns from White's perspective.

        ``board`` is indexed ``board[r][c]`` for ``r, c`` in ``0..7``; empty
        squares are ``"--"`` and pieces are two characters: colour (``'w'``
        for White, anything else counts as Black) followed by the piece
        letter used as key into ``PV`` / ``PST_MAP``.
        """
        score = 0
        for r in range(8):
            for c in range(8):
                p = board[r][c]
                if p == "--":
                    continue
                is_white = p[0] == 'w'
                table = PST_MAP[p[1]]
                # Black reads the White-oriented table mirrored by row.
                # int() keeps the running total a Python int rather than a
                # NumPy int16 scalar.
                pst = int(table[r, c] if is_white else table[7 - r, c])
                piece_score = PV[p[1]] + pst
                score += piece_score if is_white else -piece_score
        return score

    def policy_targets(self, state: Engine.GameState) -> Dict[Engine.move, float]:
        """Return a probability distribution over legal moves using softmax of
        *evaluation improvement* for the side to move.

        Each legal move is played on a deep copy of ``state`` and scored by
        the resulting change in ``eval_board_white`` (negated when Black is
        to move).  Returns an empty dict when there are no legal moves.
        """
        moves = state.get_all_valid_moves()
        if not moves:
            return {}
        base = self.eval_board_white(state.board)
        sign = 1 if state.white_to_move else -1  # side-to-move perspective
        scores = []
        for mv in moves:
            s2 = copy.deepcopy(state)
            s2.make_move(mv)
            scores.append(sign * (self.eval_board_white(s2.board) - base))
        # softmax with temperature; subtracting the max keeps exp() stable
        x = np.array(scores, dtype=np.float32) / self.t
        x -= x.max()
        e = np.exp(x)
        Z = e.sum()
        if not np.isfinite(Z) or Z <= 0:
            probs = np.full(len(moves), 1/len(moves), dtype=np.float32)
        else:
            probs = e / Z
        return dict(zip(moves, probs.tolist()))

    def generate_policy_dataset(self, n_positions: int = 5000, max_plies: int = 80, eps: float = 0.2) -> Tuple[List[Tuple[np.ndarray, np.ndarray]], List[Tuple[np.ndarray, float]]]:
        """Play quick pseudo-games and collect bootstrap training examples.

        Games start from ``Engine.GameState()`` and are played epsilon-greedy
        on the static-eval policy: with probability ``eps`` a uniformly
        random legal move, otherwise the most probable one.  Each visited
        position yields

        * a policy example ``(encoded board, length-NUM_MOVES distribution)``
        * a value example ``(encoded board, tanh(static eval / 400))``, the
          eval being from White's perspective.

        Collection stops once ``n_positions`` examples exist.  Returns two
        empty lists when ``n_positions`` or ``max_plies`` is not positive,
        and stops early if a game contributes no examples, instead of
        looping forever.
        """
        policy_examples: List[Tuple[np.ndarray, np.ndarray]] = []
        value_examples:  List[Tuple[np.ndarray, float]] = []
        if n_positions <= 0 or max_plies <= 0:
            return policy_examples, value_examples
        while len(policy_examples) < n_positions:
            collected_before = len(policy_examples)
            gs = Engine.GameState()
            for _ in range(max_plies):
                if gs.check_mate or gs.stale_mate:
                    break
                # policy_targets generates the legal moves itself and returns
                # an empty dict when there are none.
                dist = self.policy_targets(gs)
                if not dist:
                    break
                vec = np.zeros(NUM_MOVES, dtype=np.float32)
                for mv, p in dist.items():
                    # Accumulate: distinct moves may share an index.
                    vec[move_to_index(mv)] += p
                # Encode once; both example lists reference the same array.
                encoded = gs.get_current_state()
                policy_examples.append((encoded, vec))
                # value target from static eval scaled to [-1,1]
                v = np.tanh(self.eval_board_white(gs.board) / 400.0)
                value_examples.append((encoded, float(v)))
                # choose action epsilon-greedy from dist
                if random.random() < eps:
                    mv = random.choice(list(dist.keys()))
                else:
                    mv = max(dist.items(), key=lambda kv: kv[1])[0]
                gs.make_move(mv)
                if len(policy_examples) >= n_positions:
                    break
            if len(policy_examples) == collected_before:
                # A fresh game produced nothing; further games would not
                # either, so give up rather than spin.
                break
        return policy_examples, value_examples


# MCTS using dual net

In [None]:
class MCTSNode:
    """One node of the search tree.

    Holds the game state reached at this node, a link to the parent, the
    children keyed by move, the visit counter and summed backed-up value,
    the prior assigned at creation, and the move that led here.
    """
    __slots__ = ("state", "parent", "children", "visit_count", "total_value", "prior", "move")

    def __init__(self, state: Engine.GameState, parent: Optional['MCTSNode']=None,
                 prior=0.0, move: Optional[Engine.move]=None):
        self.state = state
        self.parent = parent
        self.prior = prior
        self.move = move
        self.children: Dict[Engine.move,'MCTSNode'] = {}
        self.visit_count = 0
        self.total_value = 0.0

    def q(self):
        """Mean backed-up value, or 0.0 for a node that was never visited."""
        if self.visit_count:
            return self.total_value / self.visit_count
        return 0.0

    def is_leaf(self):
        """True while the node has no children."""
        return len(self.children) == 0

class MCTS:
    def __init__(self, net: DualNet, sims=200, c_puct=1.4):
        self.net, self.N, self.c = net, sims, c_puct
    def search(self, root_state: Engine.GameState) -> Dict[Engine.move,float]:
        root = MCTSNode(root_state)
        priors, _ = self.net.predict(root_state)
        for mv,p in priors.items():
            root.children[mv] = MCTSNode(self._next(root_state,mv), parent=root, prior=p, move=mv)
        for _ in range(self.N):
            leaf = self._select(root)
            value = self._expand_eval(leaf)
            self._backprop(leaf, value)
        visits = np.array([ch.visit_count for ch in root.children.values()], dtype=np.float32)
        if visits.sum()==0: return {}
        probs = visits/visits.sum()
        return dict(zip(root.children.keys(), probs))
    def _select(self,n):
        while not n.is_leaf():
            best, best_s = None, -1e9; sqrt_p = np.sqrt(n.visit_count)
            for ch in n.children.values():
                u = self.c*ch.prior*sqrt_p/(1+ch.visit_count); s = ch.q()+u
                if s>best_s: best,best_s=ch,s
            n = best
        return n
    def _expand_eval(self,node):
        t = self._term(node.state)
        if t is not None: return t
        priors,val = self.net.predict(node.state)
        for mv,p in priors.items():
            node.children[mv] = MCTSNode(self._next(node.state,mv), parent=node, prior=p, move=mv)
        return val
    def _backprop(self,node,val):
        while node:
            node.visit_count+=1; node.total_value+=val; val=-val; node=node.parent
    @staticmethod
    def _next(state,mv):
        s = copy.deepcopy(state); s.make_move(mv); return s
    @staticmethod
    def _term(state):
        if state.check_mate: return 1 if not state.white_to_move else -1
        if state.stale_mate: return 0
        return None

In [None]:

def self_play(net: DualNet, sims=300, max_plies: int = 512) -> List[Tuple[np.ndarray, Dict[Engine.move,float], int]]:
    """Play one game against itself with MCTS and return training samples.

    At every position a fresh search of ``sims`` simulations is run and the
    most-visited move is played.  Each visited position yields
    ``(encoded board, search distribution, z)`` where ``z`` is the final
    result from White's perspective: +1 if Black was checkmated, -1 if White
    was, 0 otherwise.

    ``max_plies`` bounds the game length.  Move selection is deterministic,
    so without a bound a game could cycle forever if the engine does not
    itself end drawn games; a game cut off by the bound is labelled by
    whatever ``check_mate`` reports at that point (normally a draw, 0).
    """
    gs = Engine.GameState()
    tree = MCTS(net, sims)
    positions = []
    while len(positions) < max_plies and not (gs.check_mate or gs.stale_mate):
        pi = tree.search(gs)
        if not pi:
            # No legal moves (or nothing visited): the game is over.
            break
        positions.append((gs.get_current_state(), pi))
        mv = max(pi.items(), key=lambda kv: kv[1])[0]
        gs.make_move(mv)
    z = 0
    if gs.check_mate:
        # The side to move is the one that was mated.
        z = 1 if not gs.white_to_move else -1
    return [(b, p, z) for b, p in positions]

In [None]:
# End-to-end driver: build the networks, pretrain them on static-eval labels,
# then run one self-play game and train on it.  Statement order matters: each
# step consumes the state left by the previous one.
dual = DualNet()

# 1) Bootstrap pretraining so the policy is NOT random.
#    Labels come from BootstrapPretrainer's static evaluation; the resulting
#    weights are checkpointed as dualnet_bootstrap_policy.pth / _value.pth.
bootstrap = BootstrapPretrainer(temperature=0.7)
pol_ds, val_ds = bootstrap.generate_policy_dataset(n_positions=3000, max_plies=80, eps=0.2)
print(f"Bootstrap: policy {len(pol_ds)}  value {len(val_ds)}")
dual.pretrain_policy(pol_ds, epochs=2)
dual.pretrain_value(val_ds,  epochs=1)
dual.save("dualnet_bootstrap")

# 2) Self-play using the pretrained policy.
#    A single game is played with 80 simulations per move, both networks are
#    trained on its positions, and the result is saved as
#    dualnet_policy.pth / dualnet_value.pth.
batch = self_play(dual, sims=80)
print(f"Self-play positions: {len(batch)}")
dual.train(batch, epochs=1)
dual.save("dualnet")
