In [1]:
# imports

## allow interaction with the plots
%matplotlib widget

%load_ext autoreload
%autoreload 2

# add the main directory reference and import 'imports.py'
import sys
import os

if ".." not in sys.path:
    sys.path.insert(0, "..")
from imports import *

logger = logging.getLogger()
logger.setLevel(level=logging.WARN)

# Set the HF_HOME environment variable
os.environ["HF_HOME"] = "/run/media/HUNK/DATASETS/HF"

In [2]:
# Load dataset

from datasets import load_dataset

train_set = load_dataset(
    "Lichess/standard-chess-games",
    split="train",
    streaming=True,
    columns=["TimeControl", "movetext", "WhiteElo", "BlackElo"],
)


Resolving data files:   0%|          | 0/25131 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/25131 [00:00<?, ?it/s]

In [3]:
def min_time(x):
    time_string = x["TimeControl"]
    for i, char in enumerate(time_string):
        if not char.isdigit():
            initial_time_string = time_string[:i]
    if initial_time_string:
        return int(initial_time_string) >= 300
    return False


def min_elo(x):
    if not (x["WhiteElo"] and x["BlackElo"]):
        return False
    return (int(x["WhiteElo"]) + int(x["BlackElo"])) > 4000


train_set = train_set.filter(min_time)
train_set = train_set.filter(min_elo)


In [47]:
from lib.chess import *

sample = []
all_moves = generate_chess_moves()

for row in all_moves:
    print(row)

for row in train_set.take(5):
    sample.append(convert_pgn_to_move_strings(row["movetext"], seq_len=80))
print(sample)
assert all([[x in all_moves for x in game] for game in sample])

Ke1g1
Ke1c1
ke8g8
ke8c8
Pa2a3
Pa2a4
Pa2b3
Pa3a4
Pa3b4
Pa4a5
Pa4b5
Pa5a6
Pa5b6
Pa6a7
Pa6b7
Pa7a8
Pa7b8
Pa7a8=Q
Pa7a8=R
Pa7a8=B
Pa7a8=N
Pa7b8=Q
Pa7b8=R
Pa7b8=B
Pa7b8=N
Pb2b3
Pb2b4
Pb2a3
Pb2c3
Pb3b4
Pb3a4
Pb3c4
Pb4b5
Pb4a5
Pb4c5
Pb5b6
Pb5a6
Pb5c6
Pb6b7
Pb6a7
Pb6c7
Pb7b8
Pb7a8
Pb7c8
Pb7a8=Q
Pb7a8=R
Pb7a8=B
Pb7a8=N
Pb7b8=Q
Pb7b8=R
Pb7b8=B
Pb7b8=N
Pb7c8=Q
Pb7c8=R
Pb7c8=B
Pb7c8=N
Pc2c3
Pc2c4
Pc2b3
Pc2d3
Pc3c4
Pc3b4
Pc3d4
Pc4c5
Pc4b5
Pc4d5
Pc5c6
Pc5b6
Pc5d6
Pc6c7
Pc6b7
Pc6d7
Pc7c8
Pc7b8
Pc7d8
Pc7b8=Q
Pc7b8=R
Pc7b8=B
Pc7b8=N
Pc7c8=Q
Pc7c8=R
Pc7c8=B
Pc7c8=N
Pc7d8=Q
Pc7d8=R
Pc7d8=B
Pc7d8=N
Pd2d3
Pd2d4
Pd2c3
Pd2e3
Pd3d4
Pd3c4
Pd3e4
Pd4d5
Pd4c5
Pd4e5
Pd5d6
Pd5c6
Pd5e6
Pd6d7
Pd6c7
Pd6e7
Pd7d8
Pd7c8
Pd7e8
Pd7c8=Q
Pd7c8=R
Pd7c8=B
Pd7c8=N
Pd7d8=Q
Pd7d8=R
Pd7d8=B
Pd7d8=N
Pd7e8=Q
Pd7e8=R
Pd7e8=B
Pd7e8=N
Pe2e3
Pe2e4
Pe2d3
Pe2f3
Pe3e4
Pe3d4
Pe3f4
Pe4e5
Pe4d5
Pe4f5
Pe5e6
Pe5d6
Pe5f6
Pe6e7
Pe6d7
Pe6f7
Pe7e8
Pe7d8
Pe7f8
Pe7d8=Q
Pe7d8=R
Pe7d8=B
Pe7d8=N
Pe7e8=Q
Pe7e8=R
Pe7e8=B
Pe7e8=N
Pe7f8=Q
Pe7f8=R
Pe7f8=B
Pe

In [63]:
@dataclass(kw_only=True)
class PGNDataConfig(DataConfig):
    seq_len: int = 129


class PGNData(ClassifierData):
    def __init__(self, c: PGNDataConfig):
        super().__init__()
        self.save_config(c)
        self.le = LabelEncoder()
        self.le.fit(["<PAD>"] + generate_chess_moves())

        def transform(row):
            return {
                "game": torch.tensor(
                    self.le.transform(
                        convert_pgn_to_move_strings(
                            row["movetext"], seq_len=self.seq_len + 1, pad_token="<PAD>"
                        )
                    )
                )
            }

        info_columns = train_set.column_names
        info_columns.remove("movetext")
        games_only = train_set.remove_columns(info_columns)

        self.dataset = games_only.map(transform, remove_columns="movetext")


dt = PGNData(PGNDataConfig(num_workers=16, batch_size=125))


In [64]:
dt.classes
dt.preview()



Loader 0 (IterableDataset) Preview:
--------------------------------------------------
Constituent shapes:
batch[game]: torch.Size([32, 130]), torch.int64

First 5 samples:

Sample 0: 

tensor([1404, 5328, 1373, 5360, 1256, 5390, 1010, 5298, 1380, 5355, 1216, 5189,
        1434, 4331, 1466, 4728,  364, 4095, 1527, 5137,  777, 4944, 1064, 5383,
         222, 5042, 1411, 4155, 1319, 6154, 2089, 7008, 1342, 4068,  126, 5235,
          68, 5738, 3011, 5232, 1323, 7335, 1348, 5294,  232, 5267,  146, 5950,
         232, 5322, 1437, 7000, 2120, 5292,  146, 6925, 2330, 6139,  892, 4245,
        1497, 5350, 3236, 6067,  949, 7567, 3348, 4840, 3566, 7509, 2313, 5348,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,