In [1]:
!pip install datasets



In [2]:
!pip install polars



In [38]:
!pip install torch

Collecting torch
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl (766.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m766.7/766.7 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting networkx
  Downloading networkx-3.4.2-py3-none-any.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m84.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m59.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.4.127
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m99.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==1

In [126]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.5.1-py3-none-any.whl (818 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m818.9/818.9 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting lightning-utilities<2.0,>=0.10.0
  Downloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)
Collecting packaging<25.0,>=20.0
  Downloading packaging-24.2-py3-none-any.whl (65 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.5/65.5 kB[0m [31m36.5 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics<3.0,>=0.7.0
  Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m961.5/961.5 kB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.1-py3-none-any.whl (822 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.0/823.0 kB[0m [31m33.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packag

In [116]:
import datasets

In [114]:
def count_pieces(fen: str) -> int:
    """Count the pieces on the board described by a FEN string.

    Only the piece-placement field (the text before the first space) is
    inspected; every alphabetic character in it counts as one piece.

    Args:
        fen: A FEN string; anything after the first space is ignored.

    Returns:
        The number of alphabetic characters in the placement field.
    """
    placement = fen.split(" ", 1)[0]
    pieces = 0
    for symbol in placement:
        if symbol.isalpha():
            pieces += 1
    return pieces

assert count_pieces("QN4n1/6r1/3k4/8/b2K4/8/8/8 b - - 0 1") == 7

def batch_piece_count(batch):
    """Compute the piece count for every FEN in a batch.

    Args:
        batch: Mapping with a "fen" entry holding an iterable of FEN strings.

    Returns:
        A dict with a single "piece_count" list, aligned with batch["fen"].
    """
    counts = list(map(count_pieces, batch["fen"]))
    return {"piece_count": counts}

def prepare_ds(split, max_pieces=8, num_proc=30):
    """Load the Lichess evaluation dataset and keep only low-piece-count positions.

    The "line", "depth" and "knodes" columns are dropped, a "piece_count"
    column is added, and rows with more than `max_pieces` pieces are removed.

    Args:
        split: Dataset split name passed to `datasets.load_dataset`.
        max_pieces: Largest piece count (inclusive) a position may have to be
            kept. Defaults to 8, the previously hard-coded limit.
        num_proc: Number of worker processes used by the map and filter
            passes. Defaults to 30, the previously hard-coded value; lower it
            on machines with fewer cores.

    Returns:
        The filtered dataset with the extra "piece_count" column.
    """
    ds = datasets.load_dataset("Lichess/chess-position-evaluations", split=split)
    ds = ds.remove_columns(["line", "depth", "knodes"])
    # Batched map keeps per-row Python overhead low on a very large dataset.
    ds = ds.map(batch_piece_count, batched=True, batch_size=10000, num_proc=num_proc)
    ds = ds.filter(lambda x: x["piece_count"] <= max_pieces, num_proc=num_proc)
    return ds

In [14]:
prepare_ds("train").to_parquet("endgames.parquet")

Creating parquet from Arrow format:   0%|          | 0/47975 [00:00<?, ?ba/s]

2916439133

In [117]:
import polars as pl

In [118]:
df = pl.read_parquet("endgames.parquet")

In [119]:
df.group_by("piece_count").len().sort("piece_count")

piece_count,len
i64,u32
3,1302608
4,8308036
5,9705053
6,9458272
7,9717520
8,9482848


In [22]:
import numpy as np

In [33]:
# Plane index for each FEN piece letter: white pieces occupy planes 0-5,
# black pieces planes 6-11, both in the order P, N, B, R, Q, K.
PIECE_TO_PLANE = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,
}

def fen_to_onehot(fen: str) -> np.ndarray:
    """Encode a FEN position as a flat one-hot vector.

    The board becomes a (12, 8, 8) array (plane per piece type, row per
    '/'-separated rank in FEN order, column per file) that is flattened and
    followed by one side-to-move flag (1 when the turn field is "w", else 0).

    Only the first two space-separated fields are used, so a FEN carrying
    just the board and the side to move is accepted as well as a longer one.

    Args:
        fen: FEN string with at least the board and side-to-move fields.

    Returns:
        A uint8 array of length 12 * 8 * 8 + 1 = 769.

    Raises:
        ValueError: If the side-to-move field is missing.
        KeyError: If the board contains a character that is neither a digit
            nor a key of PIECE_TO_PLANE.
    """
    # Same split as before, but trailing fields are optional: nothing after
    # the turn field is ever read.
    fields = fen.split(" ", 2)
    if len(fields) < 2:
        raise ValueError(f"FEN must contain a board and a side to move: {fen!r}")
    board, turn = fields[0], fields[1]

    board_tensor = np.zeros((12, 8, 8), dtype=np.uint8)
    for (y, rank) in enumerate(board.split("/")):
        x = 0
        for ch in rank:
            if ch.isdigit():
                # A digit is a run of empty squares: just advance the file.
                x += int(ch)
            else:
                board_tensor[PIECE_TO_PLANE[ch], y, x] = 1
                x += 1
    flat_tensor = board_tensor.reshape(-1)
    turn_tensor = np.array([int(turn == "w")], dtype=np.uint8)
    return np.concatenate([flat_tensor, turn_tensor])

In [35]:
fen_to_onehot("4k3/8/8/8/8/8/8/3K4 w - - 0 1")

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [43]:
import torch.utils.data
import torch

In [63]:
class EndgameDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (encoded position, binary label) pairs.

    The wrapped table is indexed as ``df[column][row]`` and must expose the
    columns "fen", "mate" and "cp"; "mate" and "cp" entries may be None.
    """

    def __init__(self, df):
        # Table of positions; only the "fen", "mate" and "cp" columns are read.
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the float32 feature vector and float32 label for row `idx`.

        The label is 1.0 when the row has a positive "mate" value or a
        positive "cp" value, and 0.0 otherwise (including when both are None).
        """
        encoded = fen_to_onehot(self.df["fen"][idx])
        features = torch.tensor(encoded, dtype=torch.float32)

        mate = self.df["mate"][idx]
        cp = self.df["cp"][idx]
        # NOTE(review): presumably a positive score favours White - confirm
        # against the dataset card.
        if mate is not None and mate > 0:
            positive = True
        else:
            positive = cp is not None and cp > 0
        target = torch.tensor(int(positive), dtype=torch.float32)

        return features, target

In [121]:
full_dataset = EndgameDataset(df)

debug_size = 10_000
val_size = 100_000
train_size = len(full_dataset) - debug_size - val_size

# Seed the split explicitly: without a generator, random_split draws from the
# global RNG, so the train/debug/val partition (and every metric computed on
# it) would change on each kernel restart.
split_generator = torch.Generator().manual_seed(0)
train_dataset, debug_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, debug_size, val_size],
    generator=split_generator,
)

In [149]:
# Training batches are reshuffled each epoch and loaded by 30 worker processes.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=10_000,
    shuffle=True,
    num_workers=30,
)
# Small loader for quick checks; loads in the main process, in fixed order.
debug_loader = torch.utils.data.DataLoader(
    debug_dataset,
    batch_size=1000,
    shuffle=False,
)
# Validation batches are served in fixed order by 10 worker processes.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=10_000,
    shuffle=False,
    num_workers=10,
)

In [134]:
import torch.nn as nn
import lightning
import lightning.pytorch

In [147]:
def material_count_predict(x):
    """Baseline prediction from the material balance of one encoded position.

    `x` is read as twelve consecutive 64-entry planes. Each plane's entries
    are summed and weighted by a per-plane value: +1, +3, +3, +5, +9, 0 for
    the first six planes and the negated values for the last six. Entries
    past index 767 are ignored.

    Args:
        x: Sliceable sequence (list or 1-D tensor) of at least 768 numbers.

    Returns:
        1 if the weighted balance is greater than or equal to zero, else 0.
    """
    plane_values = (1, 3, 3, 5, 9, 0, -1, -3, -3, -5, -9, 0)
    plane_starts = range(0, 64 * len(plane_values), 64)

    balance = 0
    for start, value in zip(plane_starts, plane_values):
        balance += sum(x[start:start + 64]) * value
    return int(balance >= 0)

def material_count_predict_batch(inputs):
    """Apply `material_count_predict` to every row of a batch tensor.

    Args:
        inputs: Tensor whose first dimension iterates over encoded positions.

    Returns:
        An int tensor of 0/1 predictions with shape (batch, 1), placed on the
        same device as `inputs`.
    """
    per_position = [material_count_predict(row) for row in inputs]
    predictions = torch.tensor(per_position, dtype=torch.int, device=inputs.device)
    return predictions.unsqueeze(1)

# Score the material-count baseline on the debug split.
correct = 0
total = 0

# Pure inference: no gradients are needed.
with torch.no_grad():
    for inputs, labels in debug_loader:
        preds = material_count_predict_batch(inputs)
        # preds is (batch, 1); drop the trailing axis before comparing.
        matches = preds.squeeze() == labels
        correct += matches.sum().item()
        total += len(labels)

accuracy = correct / total
print(f"Material count prediction acc: {accuracy:.4f}")

Material count prediction acc: 0.7568


In [150]:
class EndgameNet(lightning.LightningModule):
    """Small fully connected binary classifier over encoded positions.

    The network maps a 769-feature input (64 * 6 * 2 + 1) through hidden
    layers of 64 and 16 units to a single logit, trained with
    BCEWithLogitsLoss and optimised with Adam (lr=1e-3).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(64 * 6 * 2 + 1, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        ]
        self.model = nn.Sequential(*layers)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for a batch of inputs."""
        return self.model(x)

    def _logits_and_loss(self, batch):
        """Run the forward pass; return (logits, column-shaped targets, loss)."""
        features, labels = batch
        logits = self(features)
        # Labels arrive as a 1-D batch; the logits have shape (batch, 1).
        targets = labels.unsqueeze(1)
        return logits, targets, self.loss_fn(logits, targets.float())

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, _, loss = self._logits_and_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accuracy at a 0.5 probability threshold."""
        logits, targets, loss = self._logits_and_loss(batch)
        predicted = torch.sigmoid(logits) > 0.5
        acc = (predicted == targets).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [None]:
# Train for up to 10 epochs, letting Lightning choose the available accelerator.
model = EndgameNet()
trainer = lightning.pytorch.Trainer(
    max_epochs=10,
    accelerator="auto",
)
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | Sequential        | 50.3 K | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
50.3 K    Trainable params
0         Non-trainable params
50.3 K    Total params
0.201     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                          | 0/? [00:00<?, ?it/s]

Training: |                                                                 | 0/? [00:00<?, ?it/s]