In [1]:
# imports

## allow interaction with the plots
%matplotlib widget

%load_ext autoreload
%autoreload 2

# add the main directory reference and import 'imports.py'
import sys
import os

if ".." not in sys.path:
    sys.path.insert(0, "..")
from imports import *

logger = logging.getLogger()
logger.setLevel(level=logging.WARN)

__builtins__.verbosity = 4
# Set the HF_HOME environment variable
os.environ["HF_HOME"] = "/run/media/HUNK/DATASETS/HF"


In [4]:
class params:
    """Shared settings handed to the `*Config.create(params, ...)` factories.

    Used purely as a namespace of class attributes; it is never instantiated
    in this notebook.
    """

    # -- batching / data --------------------------------------------------
    num_workers = 24
    batch_size = 64  # a commented-out alternative value was 160
    seq_len = 140
    files_per_epoch = 800

    # -- training schedule ------------------------------------------------
    max_epochs = 5
    lr = 0.005

    # -- trainer bookkeeping (NOTE(review): exact meaning lives in the
    #    trainer implementation -- confirm there) ---------------------------
    flush_epoch_units = False
    loss_every = 100


In [2]:
# load models
#
# Builds the two models compared in this notebook (RWKV and GPT) with the same
# configuration and restores each one from a checkpoint file under `out/`.

from lib.chess import chess_move_labels as le
# NOTE(review): both star imports may export the same names (e.g. `GPTConfig`);
# the later import wins, so keep this order -- confirm which module each name
# used below actually comes from.
from models.nanoRWKV import *
from models.nanoGPT import *

# vocab_size is the number of classes known to the move label encoder;
# seq_len=140 matches the `seq_len=140` encoded in the checkpoint filename.
rwkv_model = RWKV(GPTConfig(vocab_size=len(le.classes_), bias=True, seq_len=140))

rwkv_model.load_from_pth(
    "out/RWKV__seq_len=140__vocab_size=7797__n_layer=12__n_head=12__n_embd=768__dropout=0.0__bias=True__lr=0.005__weight_decay=0.01__epoch=164-239.pth"
)

# NOTE(review): `device` is not defined in this cell; presumably provided by
# `from imports import *` -- confirm.
rwkv_model.to(device)

# Same configuration as the RWKV model so the two are directly comparable.
gpt_model = GPT(GPTConfig(vocab_size=len(le.classes_), bias=True, seq_len=140))

gpt_model.load_from_pth(
    "out/GPT__seq_len=140__vocab_size=7797__n_layer=12__n_head=12__n_embd=768__dropout=0.0__bias=True__lr=0.005__weight_decay=0.01__epoch=86-132.pth"
)
# Last expression of the cell: its value (the module repr) is displayed as output.
gpt_model.to(device)

number of parameters: 91.13M


  checkpoint = torch.load(pth_path)


number of parameters: 91.04M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(7797, 768)
    (wpe): Embedding(140, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=768, out_features=7797, bias=False)
)

In [5]:
# synthetic data
#
# Builds `synth_dt`, a PGN dataset over the files in `resources/synthetic`.

# The commented-out block below would write 500 random games of
# `params.seq_len + 1` moves into `resources/synthetic`.
# NOTE(review): presumably a one-off generation step that only needs to be
# re-enabled when the directory is missing -- confirm.

# from lib.chess import generate_random_games_to_pgn

# with change_dir("resources/synthetic"):
#     generate_random_games_to_pgn(500, game_length=params.seq_len + 1)

from data.pgnSeq import *

# `files_per_epoch=20` is passed explicitly next to `params` (which itself
# carries files_per_epoch = 800).
# NOTE(review): presumably the keyword overrides the value from `params` -- confirm
# in PGNDataConfig.create.
synth_dt = PGNData(
    PGNDataConfig.create(params, directory="resources/synthetic", files_per_epoch=20)
)

# Uncomment to print a preview of one sample.
# synth_dt.preview(1)

In [13]:
# preparatory function defs

# customize trainer
from extras.train import SeqTrainer
from lib.optim import WarmupCosineLR


def eval_legal_movemaking(model, data):
    """Measure how often `model` predicts a legal chess move.

    Runs one evaluation pass of `model` over `data`, decodes every input
    sequence and its predicted sequence back to move labels, replays each game
    on a fresh `chess.Board`, and after every played move checks whether the
    prediction made at that position is legal on the resulting board.

    The loader evaluated is the second one returned by `data.loaders()` when
    there is more than one, otherwise the first.

    Args:
        model: Model to evaluate. Must provide `pred(output)`, which is applied
            to each element of the batch outputs to obtain predicted token ids.
        data: Dataset wrapper providing `loaders()` and a label encoder `le`
            with `inverse_transform`.

    Returns:
        None. Prints `legal total percentage` on one line; the percentage is
        `nan` when no prediction could be scored. Sequences that fail to
        decode are skipped and reported through a logging warning.
    """
    # Local import so the function does not depend on a star import for it.
    import logging

    trainer = SeqTrainer(
        TrainerConfig.create(
            params,
        )
    )
    trainer.init(model, data.loaders())

    # Parallel lists: games[i] and predictions[i] always describe the same sequence.
    games = []
    predictions = []
    skipped = 0

    def batch_fun(outputs, batch, batch_num):
        nonlocal skipped
        for pred_seq, input_seq in zip(map(model.pred, outputs), batch[0]):
            try:
                original = data.le.inverse_transform(input_seq.flatten().cpu())
                decoded = data.le.inverse_transform(pred_seq.flatten().cpu())
            except Exception:
                # Best effort: a sequence that cannot be decoded is dropped,
                # but it is counted instead of vanishing silently.
                skipped += 1
                continue
            # Append only after BOTH decodes succeeded; appending the input
            # before decoding the prediction could leave the lists misaligned
            # and pair later games with the wrong predictions.
            games.append(original)
            predictions.append(decoded)

    loaders = data.loaders()
    loader = loaders[1] if len(loaders) > 1 else loaders[0]

    trainer.eval(loader=loader, batch_fun=batch_fun)

    if skipped:
        logging.getLogger(__name__).warning(
            "eval_legal_movemaking: skipped %d sequence(s) that could not be decoded",
            skipped,
        )

    total = 0
    legal = 0

    for game, preds in zip(games, predictions):
        board = chess.Board()
        for move, pred in zip(game, preds):
            if move == "<PAD>":
                # Padding marks the end of the real game.
                break
            # NOTE(review): labels presumably carry a one-character prefix in
            # front of the UCI string, hence the [1:] -- confirm in lib.chess.
            board.push_uci(move[1:])
            if pred == "<PAD>":
                # No move predicted at this position; not scored.
                continue
            total += 1
            try:
                candidate = chess.Move.from_uci(pred[1:])
            except ValueError:
                # Not parseable as UCI at all, so it cannot be a legal move.
                continue
            # Exact membership in the generated move list (kept as a list on
            # purpose: `in board.legal_moves` goes through Board.is_legal,
            # which can accept alternative castling spellings).
            if candidate in list(board.legal_moves):
                legal += 1

    # Guard against an empty evaluation (no data, or everything skipped).
    percentage = legal / total * 100 if total else float("nan")
    print(legal, total, percentage)

In [8]:
# Legal-move rate of each model on the games in `resources/synthetic`.
# Each call prints one line: legal predictions, scored predictions, percentage.
eval_legal_movemaking(rwkv_model, synth_dt)
eval_legal_movemaking(gpt_model, synth_dt)


In addition, using fork() with Python in general is a recipe for mysterious
deadlocks and crashes.

The most likely reason you are seeing this error is because you are using the
multiprocessing module on Linux, which uses fork() by default. This will be
fixed in Python 3.14. Until then, you want to use the "spawn" context instead.

See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details.

or by setting POLARS_ALLOW_FORKING_THREAD=1.

  self.pid = os.fork()
In addition, using fork() with Python in general is a recipe for mysterious
deadlocks and crashes.

The most likely reason you are seeing this error is because you are using the
multiprocessing module on Linux, which uses fork() by default. This will be
fixed in Python 3.14. Until then, you want to use the "spawn" context instead.

See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details.

or by setting POLARS_ALLOW_FORKING_THREAD=1.

  self.pid = os.fork()


213906 273515 78.20631409611904
22023 110911 19.856461487138336


In [14]:
# Dataset over the lichess_elite PGN directories, used for evaluation below.
elite_dt = PGNData(
    PGNDataConfig.create(
        params,
        # `directory` is marked irrelevant by the author: eval_legal_movemaking
        # evaluates loaders()[1] whenever more than one loader exists.
        # NOTE(review): presumably loaders()[1] is the loader built from
        # `val_directory` -- confirm in PGNData.loaders.
        directory="resources/lichess_elite",  # irrelevant
        val_directory="resources/lichess_elite_val",
        files_per_epoch=None,
        max_games_per_file=999,
    )
)

# Print a preview of one sample per loader.
elite_dt.preview(1)



Loader 0 (IterableDataset) Preview:
--------------------------------------------------


Traceback (most recent call last):
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/shutil.py", line 759, in rmtree
    _rmtree_safe_fd(stack, onexc)
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/shutil.py", line 703, in _rmtree_safe_fd
    onexc(func, path, err)
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/shutil.py", line 750, in onexc
    return onerror(func, path, exc_info)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^

Constituent shapes:
batch[0]: torch.Size([64, 141]), torch.int64

First 1 samples:

Sample 0: 

tensor([1435, 5328, 1440, 6162, 1010, 6078, 1404, 5298, 1256, 5189,  363, 4097,
        1527, 4368,  777, 5360, 1497, 5139, 1529, 4436,  158, 4331, 1064, 4156,
        1170, 4086, 1076, 5541, 1042, 5949, 2089, 4366, 1373, 5443, 2122, 6106,
        1410, 5295, 1380, 5699,  129, 5355, 1214, 4301,  294, 4942, 3567, 4727,
         381, 7009,  392, 4786, 3461, 4168, 3013, 4843, 3460,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0

Traceback (most recent call last):
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/shutil.py", line 759, in rmtree
    _rmtree_safe_fd(stack, onexc)
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/shutil.py", line 703, in _rmtree_safe_fd
    onexc(func, path, err)
  File "/ARCHIVE/Personal/2186474940/.pixi/envs/dev/lib/python3.12/shutil.py", line 750, in onexc
    return onerror(func, path, exc_info)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^

Constituent shapes:
batch[0]: torch.Size([4, 141]), torch.int64

First 1 samples:

Sample 0: 

tensor([1373, 5297, 1010, 4943, 1496, 5422,  366, 4333, 1435, 5360, 1255, 5188,
         777, 4728, 1403, 5236,  128, 4983, 2089, 7563,  307, 4395, 1157, 5292,
        1063, 7006, 1465, 5266, 1341, 5328, 1440, 5355, 1379, 5094,  545, 5391,
         476, 7118, 1382, 7107, 1385, 6164, 3567, 5418,  452, 5414, 1157, 5127,
         520, 5021,  890, 5111,  270, 5382, 1526, 7435, 3011, 7211, 3453, 4097,
         356, 5387,  178, 4842, 3229, 5001,  249, 4835, 3247, 4827, 3297, 4455,
         842, 5260, 3082, 4132, 3416, 4908, 3395, 5443, 1320, 4955,  383, 4884,
        3066, 4946, 3175, 4919,  314, 5439, 1501, 4821, 1324, 5384, 1326, 5381,
         249, 5021, 1407, 4200, 3166, 5112, 1328, 4365,  895, 4275,  838, 5018,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,

In [15]:
# Legal-move rate of each model on the lichess_elite dataset (`elite_dt`).
# Each call prints one line: legal predictions, scored predictions, percentage.
eval_legal_movemaking(rwkv_model, elite_dt)
eval_legal_movemaking(gpt_model, elite_dt)


424015 615250 68.91751320601381
133066 381591 34.87136751128826
