In [43]:
import polars as pl
import numpy as np
import torch.utils.data
import torch
import torch.nn as nn
import lightning
import lightning.pytorch
import random

In [2]:
# Open the endgame evaluation parquet file as a lazy polars scan; the cells below
# build queries on it and call .collect() to materialise their results.
scan = pl.scan_parquet("../../fishnet-position-dataset/endgame_evals.parquet")

In [3]:
# Number of rows for each piece_count value, smallest group first.
(
    scan
    .group_by("piece_count")
    .len()
    .sort("len")
    .collect()
)

piece_count,len
u8,u32
2,1850480
3,118569229
4,173072838
5,212005402
6,254848650
7,297170292
8,343192392


In [4]:
# Restrict to 8-piece rows and count how many fall under each value of "op1".
(
    scan
    .filter(pl.col("piece_count").eq(8))
    .group_by("op1")
    .len()
    .collect()
)

op1,len
bool,u32
False,152134204
True,191058188


In [7]:
# Length of the base encoding: 12 piece planes of 64 squares each, plus one
# trailing side-to-move flag.
DIM = 12 * 64 + 1
# Number of extra slots allocated after the base encoding (none at the moment).
EXTRA_DIM = 0

# Plane index assigned to each FEN piece letter (by FEN convention uppercase
# letters are White pieces and lowercase letters are Black pieces).
PIECE_TO_PLANE = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,
}

# Signed value per piece letter (positive for uppercase, negative for lowercase).
# fen_to_onehot does not read this table.
PIECE_VALUE = {
    'P': 1, 'N': 3, 'B': 3, 'R': 5, 'Q': 9, 'K': 0,
    'p': -1, 'n': -3, 'b': -3, 'r': -5, 'q': -9, 'k': 0,
}

def fen_to_onehot(fen: str) -> np.ndarray:
    """Encode the board and side-to-move fields of a FEN string as a flat 0/1 vector.

    Only the first two whitespace-separated fields are used; any further fields
    (castling rights, en passant square, move counters) are ignored and may be
    absent altogether.

    Args:
        fen: FEN string whose first field is the piece placement and whose
            second field is the side to move.

    Returns:
        An ``np.int8`` array of length ``DIM + EXTRA_DIM``. Entry
        ``64 * PIECE_TO_PLANE[piece] + square`` is 1 for every piece on the
        board, and entry ``DIM - 1`` is 1 when the side-to-move field is ``"w"``
        and 0 otherwise.

    Raises:
        ValueError: if ``fen`` has fewer than two whitespace-separated fields.
        KeyError: if the placement field contains a character that is neither a
            digit, ``"/"`` nor a key of ``PIECE_TO_PLANE``.
    """
    # Split on arbitrary whitespace so that short FENs (board + turn only) and
    # strings with leading or repeated spaces are parsed the same way as
    # canonical six-field FENs.
    fields = fen.split()
    if len(fields) < 2:
        raise ValueError(
            f"FEN needs at least board and side-to-move fields, got {fen!r}"
        )
    board, turn = fields[0], fields[1]

    encoded = np.zeros(DIM + EXTRA_DIM, dtype=np.int8)
    square = 0  # running square counter in the order the FEN lists squares
    for ch in board:
        if ch == "/":
            # Rank separator: carries no square of its own.
            continue
        if ch.isdigit():
            # A digit stands for that many consecutive empty squares.
            square += int(ch)
            continue
        # XOR with 0x38 flips the three rank bits of the 0..63 counter, so the
        # rank written first in the FEN lands on indices 56..63 and the rank
        # written last on indices 0..7; the file bits are left unchanged.
        encoded[64 * PIECE_TO_PLANE[ch] + (square ^ 0x38)] = 1
        square += 1

    encoded[DIM - 1] = int(turn == "w")
    return encoded

In [8]:
# Sanity check: encode a position with only the two kings and one uppercase rook,
# with "w" as the side to move, and display the resulting vector.
fen_to_onehot("4k3/8/8/8/8/8/8/R3K3 w - - 0 1")

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [9]:
# Keep the 8-piece rows whose "op1" flag is set, and derive a boolean "win" label:
# true when "cp" or "mate" is positive, with missing values treated as 0.
df = (
    scan
    .filter((pl.col("piece_count") == 8) & pl.col("op1"))
    .select(
        pl.col("fen"),
        (
            (pl.col("cp").fill_null(0) > 0)
            | (pl.col("mate").fill_null(0) > 0)
        ).alias("win"),
    )
    .collect()
)

In [10]:
# Display the collected frame (columns "fen" and "win").
df

fen,win
str,bool
"""r5k1/5p1p/4p3/8/8/4K1P1/7P/8 b…",false
"""6k1/5p1p/4p3/8/8/4K1P1/r6P/8 w…",false
"""6k1/5p1p/4p3/8/5K2/6P1/r6P/8 b…",false
"""8/7R/1Pp3p1/1nP2k2/2K5/8/8/8 b…",true
"""8/7R/1Pp5/1nP2kp1/2K5/8/8/8 w …",true
…,…
"""8/8/8/8/1p3R1k/3r1p1P/5P1K/8 b…",true
"""8/8/8/6k1/1p3R2/3r1p1P/5P1K/8 …",true
"""8/8/1p6/pkp2K2/5B2/1P3P2/8/8 b…",true
"""8/8/1p6/p1p2K2/1k3B2/1P3P2/8/8…",true


In [44]:
class EndgameDataset(torch.utils.data.IterableDataset):
    """Stream ``(encoded position, label)`` pairs from a frame of (fen, win) rows.

    The wrapped frame must support ``len()``, slicing with ``df[start:end]``,
    ``iter_slices(n)`` and ``iter_rows()``. For every row, element 0 is passed
    to ``fen_to_onehot`` and element 1 is used as the label; both are returned
    as ``float32`` tensors.

    Rows are read in order, ``shuffle_block`` rows at a time, and shuffled only
    inside each such block with Python's ``random`` module, so the ordering is
    randomised locally, not globally.

    Args:
        df: frame holding the rows to iterate over.
        shuffle_block: number of consecutive rows that are loaded and shuffled
            together. Defaults to 512.

    Raises:
        ValueError: if ``shuffle_block`` is smaller than 1.
    """

    def __init__(self, df, shuffle_block=512):
        if shuffle_block < 1:
            raise ValueError("shuffle_block must be a positive integer")
        self.df = df
        self.shuffle_block = shuffle_block

    def __len__(self):
        """Return the total number of rows in the wrapped frame (across all workers)."""
        return len(self.df)

    def __iter__(self):
        """Yield this process's share of the rows as ``(features, label)`` tensor pairs."""
        worker_info = torch.utils.data.get_worker_info()
        total = len(self.df)

        if worker_info is None:
            # Single-process data loading: this iterator covers every row.
            start, end = 0, total
        else:
            # Inside a DataLoader worker: take a contiguous, non-overlapping
            # range. The last worker also takes the rows left over by the
            # integer division.
            per_worker = total // worker_info.num_workers
            start = worker_info.id * per_worker
            is_last = worker_info.id == worker_info.num_workers - 1
            end = total if is_last else start + per_worker

        # Only iterate over this worker's slice, one shuffle block at a time.
        for sub in self.df[start:end].iter_slices(self.shuffle_block):
            rows = [
                (
                    torch.tensor(fen_to_onehot(row[0]), dtype=torch.float32),
                    torch.tensor(row[1], dtype=torch.float32),
                )
                for row in sub.iter_rows()
            ]
            random.shuffle(rows)
            yield from rows

In [45]:
# The first 100,000 rows are held out for validation; every row after them is
# used for training.
train_dataset = EndgameDataset(df.slice(100_000))
val_dataset = EndgameDataset(df.head(100_000))

In [46]:
# Batch sizes are multiples of 512; each loader spreads its dataset over its own
# pool of worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=512 * 20,
    num_workers=30,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=512 * 2,
    num_workers=10,
)

In [47]:
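# NOTE: disabled baseline. It reads input column DIM, which only exists when
# EXTRA_DIM >= 1 (EXTRA_DIM is 0 in this notebook) -- presumably a
# material-balance feature; confirm before re-enabling.
#def material_count_predict_batch(inputs):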
#    return (inputs[:, DIM] >= 0).clone().detach().unsqueeze(1)
#
#correct = 0
#total = 0
#
#with torch.no_grad():
#    for inputs, labels in val_loader:
#        preds = material_count_predict_batch(inputs)
#        correct += (preds.squeeze() == labels).sum().item()
#        total += len(labels)
#
#accuracy = correct / total
#print(f"Material count prediction acc: {accuracy:.4f}")  # ~75%

In [48]:
class EndgameNet(lightning.LightningModule):
    """Small fully connected binary classifier over encoded positions.

    The network maps a vector of length ``DIM + EXTRA_DIM`` through linear
    layers of width 64 and 16 (each followed by GELU) to a single raw logit,
    and is trained with ``BCEWithLogitsLoss``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(DIM + EXTRA_DIM, 64),
            nn.GELU(),
            nn.Linear(64, 16),
            nn.GELU(),
            nn.Linear(16, 1),
        ]
        self.model = nn.Sequential(*layers)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return one raw (pre-sigmoid) logit per input row."""
        return self.model(x)

    def _bce(self, logits, labels):
        """Binary cross-entropy between logits and labels reshaped to a float column."""
        return self.loss_fn(logits, labels.unsqueeze(1).float())

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch, then return it."""
        features, labels = batch
        loss = self._bce(self(features), labels)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the loss and the accuracy at a 0.5 probability threshold for one batch."""
        features, labels = batch
        logits = self(features)
        loss = self._bce(logits, labels)
        predicted_win = torch.sigmoid(logits) > 0.5
        # Labels are reshaped to a column so they line up with the logits.
        hits = predicted_win == labels.unsqueeze(1)
        acc = hits.float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [49]:
# Seed the random number generators before the model is built so that weight
# initialisation and the in-worker shuffling are repeatable from run to run;
# workers=True asks Lightning to derive a distinct seed for each DataLoader worker.
SEED = 42
lightning.pytorch.seed_everything(SEED, workers=True)

model = EndgameNet()
# A single pass limited to 10% of the training batches keeps this experiment short.
trainer = lightning.pytorch.Trainer(
    max_epochs=1,
    limit_train_batches=0.1,
    accelerator="auto",
)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | Sequential        | 50.3 K | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
50.3 K    Trainable params
0         Non-trainable params
50.3 K    Total params
0.201     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=1` reached.
