In [1]:
import polars as pl
import numpy as np
import torch.utils.data
import torch
import torch.nn as nn
import lightning
import lightning.pytorch

In [2]:
scan = pl.scan_parquet("../../fishnet-position-dataset/endgame_evals.parquet")

In [3]:
scan.group_by("piece_count").len().sort("len").collect()

piece_count,len
u8,u32
2,1850480
3,118569229
4,173072838
5,212005402
6,254848650
7,297170292
8,343192392


In [4]:
scan.filter(pl.col("piece_count") == 8).group_by("op1").len().collect()

op1,len
bool,u32
True,191058188
False,152134204


In [5]:
DIM = 12 * 64 + 1
EXTRA_DIM = 2

PIECE_TO_PLANE = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,
}

PIECE_VALUE = {
    'P': 1, 'N': 3, 'B': 3, 'R': 5, 'Q': 9, 'K': 0,
    'p': -1, 'n': -3, 'b': -3, 'r': -5, 'q': -9, 'k': 0,
}

def fen_to_onehot(fen: str) -> np.ndarray:
    board, turn, _ = fen.split(" ", 2)
    tensor = np.zeros(DIM + EXTRA_DIM, dtype=np.int8)
    x = 0
    material_balance = 0
    wk = 0
    bk = 0
    for ch in board:
        if ch.isdigit():
            x += int(ch)
        elif ch == "/":
            continue
        else:
            tensor[64 * PIECE_TO_PLANE[ch] + (x ^ 0x38)] = 1
            material_balance += PIECE_VALUE[ch]
            if ch == "k":
                bk = x
            elif ch == "K":
                wk = x
            x += 1
    tensor[DIM - 1] = int(turn == "w")
    tensor[DIM] = material_balance
    tensor[DIM + 1] = max(abs((wk & 7) - (bk & 7)), abs((wk >> 3) - (bk >> 3))) 
    return tensor

In [6]:
fen_to_onehot("4k3/8/8/8/8/8/8/R3K3 w - - 0 1")

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [7]:
df = scan.filter(pl.col("piece_count").eq(8) & pl.col("op1")).select(pl.col("fen"), (pl.col("cp").fill_null(0).gt(0) | pl.col("mate").fill_null(0).gt(0)).alias("win")).collect()

In [8]:
df

fen,win
str,bool
"""r5k1/5p1p/4p3/8/8/4K1P1/7P/8 b…",false
"""6k1/5p1p/4p3/8/8/4K1P1/r6P/8 w…",false
"""6k1/5p1p/4p3/8/5K2/6P1/r6P/8 b…",false
"""8/7R/1Pp3p1/1nP2k2/2K5/8/8/8 b…",true
"""8/7R/1Pp5/1nP2kp1/2K5/8/8/8 w …",true
…,…
"""8/8/8/8/1p3R1k/3r1p1P/5P1K/8 b…",true
"""8/8/8/6k1/1p3R2/3r1p1P/5P1K/8 …",true
"""8/8/1p6/pkp2K2/5B2/1P3P2/8/8 b…",true
"""8/8/1p6/p1p2K2/1k3B2/1P3P2/8/8…",true


In [9]:
class EndgameDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.fen = df["fen"]
        self.win = df["win"]

    def __len__(self):
        return len(self.fen)

    def __getitem__(self, idx):
        input_tensor = torch.tensor(fen_to_onehot(self.fen[idx]), dtype=torch.float32)
        label_tensor = torch.tensor(self.win[idx], dtype=torch.float32)
        return input_tensor, label_tensor

In [10]:
full_dataset = EndgameDataset(df)

debug_size = 10_000
val_size = 100_000
train_dataset, debug_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [len(full_dataset) - debug_size - val_size, debug_size, val_size])

In [11]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10_000, shuffle=True, num_workers=30)
debug_loader = torch.utils.data.DataLoader(debug_dataset, batch_size=1000, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1000, shuffle=False, num_workers=10)

In [12]:
def material_count_predict_batch(inputs):
    return (inputs[:, DIM] >= 0).clone().detach().unsqueeze(1)

correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in debug_loader:
        preds = material_count_predict_batch(inputs)
        correct += (preds.squeeze() == labels).sum().item()
        total += len(labels)

accuracy = correct / total
print(f"Material count prediction acc: {accuracy:.4f}")

Material count prediction acc: 0.7505


In [13]:
class EndgameNet(lightning.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(DIM + EXTRA_DIM, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y.unsqueeze(1).float())
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y.unsqueeze(1).float())
        preds = torch.sigmoid(logits) > 0.5
        acc = (preds == y.unsqueeze(1)).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [14]:
model = EndgameNet()
trainer = lightning.pytorch.Trainer(max_epochs=1, accelerator="auto")
trainer.fit(model, train_loader, val_loader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/niklas/.virtualenvs/guess-flst/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | Sequential        | 50.5 K | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
---

Sanity Checking: |                                                          | 0/? [00:00<?, ?it/s]

Training: |                                                                 | 0/? [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=1` reached.


In [15]:
model # expecting version 17

EndgameNet(
  (model): Sequential(
    (0): Linear(in_features=771, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=1, bias=True)
  )
  (loss_fn): BCEWithLogitsLoss()
)