In [1]:
import os
# Step up one directory so the relative "data/..." and "bin/..." paths used
# further down resolve from there. The guard keeps the cell idempotent:
# re-running it in the same kernel must not climb another level each time.
if "PROJECT_ROOT" not in globals():
    os.chdir("..")
    PROJECT_ROOT = os.getcwd()

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "serif"

import subprocess
from tqdm import tqdm

## Load data

In [2]:
def mirror_piece(pieces):
    """
    Move every piece code into the other group of six, keeping its offset.

    Codes 0..5 become 6..11 and codes 6..11 become 0..5. Presumably the two
    groups are the two colours with six piece types each -- confirm against
    the dump format. Works on plain integers and elementwise on tensors.
    """
    group = pieces // 6     # which block of six the code sits in
    offset = pieces % 6     # position inside that block
    return (1 - group) * 6 + offset

def mirror_square(squares):
    """
    XOR each square index with 56, i.e. invert bits 3..5 and keep bits 0..2.

    If squares are numbered 8 * row + column (TODO confirm), this maps
    row r to row 7 - r and leaves the column unchanged. Works on plain
    integers and elementwise on integer tensors.
    """
    upper_three_bits = 0b111000  # == 56
    return squares ^ upper_three_bits

def load_batch(batch_dir="data/batch"):
    """
    Load one dumped batch of positions and convert it to feature indices.

    Reads square.npy, piece.npy, result.npy and side.npy from `batch_dir`.
    Every (piece, square) slot becomes the index piece * 64 + square; the same
    slot is also encoded after mirror_piece / mirror_square. Slots whose square
    is -1 are treated as empty and mapped to the padding index 768.

    Parameters
    ----------
    batch_dir : str
        Directory containing the four .npy files. Defaults to "data/batch".

    Returns
    -------
    idx_p : LongTensor, same shape as square.npy
        Unmirrored indices for rows with side == 1, mirrored ones for side == 0.
    idx_o : LongTensor, same shape as square.npy
        The other encoding of each row (mirrored for side == 1, unmirrored otherwise).
    target : FloatTensor, shape of result.npy
        The negated contents of result.npy.
    """
    SQUARES_PER_PIECE = 64
    PAD_INDEX = 768  # one past the last real feature (12 * 64)

    def read(name):
        return torch.tensor(np.load(f"{batch_dir}/{name}.npy"))

    squares = read("square").to(torch.long)   # stored as int8, (K, 32)
    pieces  = read("piece").to(torch.long)    # stored as int8, (K, 32)
    results = read("result").to(torch.float)  # stored as int8, (K,)
    side    = read("side").to(torch.long)[:, None]  # (K, 1) so it broadcasts over slots; 1=white, 0=black

    # Empty slots are flagged through the square array only.
    valid = squares != -1

    idx_w = pieces * SQUARES_PER_PIECE + squares
    idx_b = mirror_piece(pieces) * SQUARES_PER_PIECE + mirror_square(squares)
    idx_w = torch.where(valid, idx_w, PAD_INDEX)
    idx_b = torch.where(valid, idx_b, PAD_INDEX)

    # side is 0/1, so this arithmetic selects one encoding per row.
    idx_p = side * idx_w + (1 - side) * idx_b
    idx_o = side * idx_b + (1 - side) * idx_w

    # NOTE(review): the sign flip presumably converts the stored result to the
    # viewpoint the network is trained on -- confirm against the dump tool.
    return idx_p, idx_o, -results

## Model

In [3]:
class Squirrel(nn.Module):
    """
    Value network: the 'body' and 'critic' parts of Squirrel in PyTorch.

    Takes two integer index tensors of shape (B, S) -- idx_p for the own
    perspective and idx_o for the opponent perspective -- where unused slots
    hold the padding index 768. The notebook feeds S = 32.
    """
    def __init__(self, emb_dim=256):
        super().__init__()

        # B0: one row per feature index (0..767) plus a padding row at 768.
        # Each position's rows are summed; the padding row is left out of the sum.
        self.embed = nn.EmbeddingBag(
            num_embeddings=769,
            embedding_dim=emb_dim,
            mode="sum",
            sparse=False,
            padding_idx=768,
        )

        # B1p / B1o: separate projections for the two perspectives
        self.b1p = nn.Linear(emb_dim, 32)
        self.b1o = nn.Linear(emb_dim, 32)

        # B2 mixes the concatenated branches; C1 is the scalar critic head
        self.b2 = nn.Linear(64, 32)
        self.c1 = nn.Linear(32, 1)

        self.relu = nn.ReLU()

    def forward(self, idx_p, idx_o):
        """
        idx_p : (B, S) own-perspective indices
        idx_o : (B, S) opponent-perspective indices
        Returns the critic value, shape (B, 1)
        """
        # Shared embedding, summed per position, then rectified
        acc_p = self.relu(self.embed(idx_p))
        acc_o = self.relu(self.embed(idx_o))

        # ReLU is elementwise, so rectifying after the concat is the same as
        # rectifying each 32-wide branch first.
        branches = torch.cat([self.b1p(acc_p), self.b1o(acc_o)], dim=1)  # (B, 64)
        hidden = self.relu(self.b2(self.relu(branches)))                 # (B, 32)

        # Critic: plain linear read-out, no activation
        return self.c1(hidden)                                           # (B, 1)

## Train

In [4]:
def train_epoch(batch_size=1024):
    """
    Run one pass over the training split in mini-batches, then score the test split.

    Works on the notebook globals `model`, `optimizer`, `Xp_train`, `Xo_train`,
    `Y_train`, `Xp_test`, `Xo_test` and `Y_test`. Samples are visited in their
    stored order; no shuffling happens here.

    Parameters
    ----------
    batch_size : int
        Number of samples per optimizer step; the last batch may be shorter.

    Returns
    -------
    (avg_train_loss, test_loss) : tuple of float
        Per-sample mean squared error over the training pass (measured while
        the weights are being updated) and over the whole test split afterwards.

    Raises
    ------
    ValueError
        If the training split is empty.
    """
    num_samples = Xp_train.size(0)
    if num_samples == 0:
        # Fail early with a clear message instead of dividing by zero below.
        raise ValueError("train_epoch: the training split is empty")

    model.train()
    weighted_loss_sum = 0.0

    for start in range(0, num_samples, batch_size):
        stop = start + batch_size
        Xpb = Xp_train[start:stop]
        Xob = Xo_train[start:stop]
        Yb = Y_train[start:stop]

        optimizer.zero_grad()
        loss = F.mse_loss(model(Xpb, Xob).squeeze(-1), Yb)
        loss.backward()
        optimizer.step()

        # mse_loss is a batch mean; weight it by the batch size so a short
        # final batch counts proportionally in the epoch average.
        weighted_loss_sum += loss.item() * Xpb.size(0)

    avg_train_loss = weighted_loss_sum / num_samples

    # Evaluate on the test split in a single forward pass, without gradients
    model.eval()
    with torch.no_grad():
        test_loss = F.mse_loss(model(Xp_test, Xo_test).squeeze(-1), Y_test).item()

    return avg_train_loss, test_loss

## Save

In [5]:
def save_weights(out_dir="data/weights"):
    """
    Export the weights of the global `model` as float32 .npy files in `out_dir`.

    One file is written per array, named "<layer>-<dims>-<w|a>.npy", where
    <dims> is the shape of the saved array joined with "-" ("w" = weight,
    "a" = bias). For the default model this yields exactly the historical
    names (B0-768-256-w, B0-256-a, B1p-32-256-w, ...). The names follow the
    convention used for model loading in C++.

    NOTE(review): deriving the names from the array shapes assumes the loader's
    naming tracks the shapes, as the existing names do -- confirm against the
    C++ side before exporting a model with non-default layer sizes.

    Parameters
    ----------
    out_dir : str
        Target directory; created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)

    def save(layer, kind, tensor):
        arr = tensor.detach().cpu().numpy().astype(np.float32)
        dims = "-".join(str(d) for d in arr.shape)
        np.save(f"{out_dir}/{layer}-{dims}-{kind}.npy", arr)

    # B0: the embedding table is (768 + 1, emb_dim). The last row is the
    # padding row, so only the first 768 rows are saved -- as stored, i.e.
    # (768, emb_dim), with no transpose.
    emb_weight = model.embed.weight[:768]
    save("B0", "w", emb_weight)

    # The embedding has no bias; export zeros of matching width in its place.
    save("B0", "a", torch.zeros(emb_weight.shape[1], dtype=torch.float32))

    # Linear layers are saved in PyTorch's (out_features, in_features) layout.
    # B2 takes the 64-wide concat of the two 32-wide B1 branches.
    for layer, linear in (
        ("B1p", model.b1p),
        ("B1o", model.b1o),
        ("B2", model.b2),
        ("C1", model.c1),
    ):
        save(layer, "w", linear.weight)
        save(layer, "a", linear.bias)

## Main

In [7]:
# Fresh network with the default layer sizes, plus the Adam optimizer that
# train_epoch() steps. Both are module-level on purpose: train_epoch() and
# save_weights() read them as globals.
model = Squirrel()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

In [None]:
# Optional: uncomment to resume from the checkpoint written by the training loop below.
# model.load_state_dict(torch.load("data/weights/00.pt"))

In [None]:
# Endless generate-and-train loop; interrupt the kernel to stop it.
# Each round: run the two external tools, load the dumped batch, draw a random
# train/test sample from it, fit one epoch, then overwrite the exported weights
# and the checkpoint.
N = 1024 * 6  # positions drawn per round, split evenly into train and test

epoch = 0
while True:
    epoch += 1

    # os.path.join keeps the tool paths valid on non-Windows hosts as well.
    res_play = subprocess.run([os.path.join("bin", "play"), "50", "8"], capture_output=True, text=True)
    res_dump = subprocess.run([os.path.join("bin", "dump"), "100000", "0"], capture_output=True, text=True)
    # Output is captured, so a failing tool would otherwise go unnoticed and
    # the round would train on whatever batch is already on disk.
    for tool, res in (("play", res_play), ("dump", res_dump)):
        if res.returncode != 0:
            print(f"[warn] bin/{tool} exited with code {res.returncode}: {res.stderr.strip()}")

    Xp, Xo, Y = load_batch()

    # Never ask for more positions than the batch holds, so the test split
    # cannot end up empty when the dump is smaller than N.
    n_used = min(N, Y.size(0))
    n_train = n_used // 2
    idx = torch.randperm(Y.size(0))
    idx_train = idx[:n_train]
    idx_test = idx[n_train:n_used]

    # train_epoch() reads these as globals.
    Xp_train, Xp_test = Xp[idx_train], Xp[idx_test]
    Xo_train, Xo_test = Xo[idx_train], Xo[idx_test]
    Y_train, Y_test = Y[idx_train], Y[idx_test]

    # Release the full batch before training; only the sampled splits are needed.
    del Xp, Xo, Y

    # columns: epoch | train | test
    tr, te = train_epoch()
    # epoch was already incremented at the top of the loop, so print it as is.
    print(f"{epoch:2d}    | {tr:.4f} | {te:.4f}", end='\r')

    save_weights()
    torch.save(model.state_dict(), "data/weights/00.pt")

77    | 0.2139 | 0.2128

In [None]:
# # positions    305043
# avg len        103.06 +- 57.93
# P(-1 0 +1)     0.27 0.27 0.46
# corr(res, val) 0.19

In [None]:
# # positions    113421
# avg len        112.08 +- 76.65
# P(-1 0 +1)     0.28 0.24 0.48
# corr(res, val) 0.15