In [1]:
import os
import math
from datetime import datetime
from glob import glob
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from lib.service import SamplesService
from lib.model import NnueModel
from lib.model import decode_int64_bitset
from lib.serialize import NnueWriter

2024-03-17 20:12:08.218832: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-17 20:12:08.218912: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-17 20:12:08.220536: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-17 20:12:08.229350: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class PQRLoss(torch.nn.Module):
    """Pairwise "PQR" ranking loss over (p, q, r) score triples.

    The model output for a batch is viewed as rows of three scores
    ``(p, q, r)`` (via ``reshape(-1, 3)``). The loss is the sum of two terms:

    * ``-mean(log(sigmoid(r - q)))`` -- pushes the score of ``r`` above ``q``;
    * ``mean((p + q) ** 2)``        -- pushes ``q`` towards ``-p``.

    NOTE(review): the meaning of p / q / r (e.g. parent / played child /
    random child) is not visible here -- confirm against the PQR sample
    generator.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target=None):
        """Compute the PQR loss.

        Args:
            pred: tensor whose total number of elements is divisible by 3; it
                is reshaped to ``(-1, 3)`` with columns ``(p, q, r)``.
            target: ignored. Accepted so PQRLoss can be called with the same
                ``loss_fn(outputs, y)`` signature used for ``EvalLoss``.

        Returns:
            Scalar loss tensor.
        """
        triples = pred.reshape(-1, 3)
        p, q, r = triples.unbind(dim=1)

        # logsigmoid(x) == log(sigmoid(x)) mathematically, but stays finite for
        # large negative x where sigmoid(x) rounds to 0 and log() gives -inf.
        ranking = -torch.nn.functional.logsigmoid(r - q).mean()
        consistency = torch.square(p + q).mean()

        return ranking + consistency

class EvalLoss(torch.nn.Module):
    """Loss between model evaluations and centipawn targets, compared in WDL space.

    Both the model output and the target are converted to engine units and
    mapped through ``sigmoid(x / wdl_scale)`` to the ``(0, 1)`` range. The
    loss is ``mean(|wdl_model - wdl_target| ** exponent)``.

    Args:
        output_scale: factor converting raw model output ("nnue units") to
            engine units. Default 600.0.
        wdl_scale: sigmoid temperature in engine units; also the number of
            engine units per 100 centipawns used to convert the target.
            Default 356.0.
        exponent: power applied to the absolute WDL error. Default 2.5.
    """

    def __init__(self, output_scale=600.0, wdl_scale=356.0, exponent=2.5):
        super().__init__()
        self.output_scale = output_scale
        self.wdl_scale = wdl_scale
        self.exponent = exponent

    def forward(self, output, target):
        """Return the mean WDL-space error between ``output`` and ``target``.

        Args:
            output: model output tensor in nnue units.
            target: target tensor in centipawns, broadcast-compatible with
                ``output``.

        Returns:
            Scalar loss tensor.
        """
        # go from nnue units to engine units
        output = output * self.output_scale

        # scale CP score to engine units (100 cp == wdl_scale engine units)
        target = target * self.wdl_scale / 100.0

        # targets are in CP-space; map both sides to WDL-space (0, 1)
        wdl_model = torch.sigmoid(output / self.wdl_scale)
        wdl_target = torch.sigmoid(target / self.wdl_scale)

        loss = torch.pow(torch.abs(wdl_model - wdl_target), self.exponent)

        return loss.mean()

In [3]:
EPOCHS = 100000
BATCHES_PER_EPOCH = 1000
BATCH_SIZE = 4096

FEATURE_SET = "basic"
NUM_FEATURES = 768
METHOD = "eval"  # "eval" (centipawn regression, EvalLoss) or "pqr" (PQRLoss)

# Root of the CSV datasets; override with the NNUE_DATA_DIR environment variable.
DATA_DIR = os.environ.get("NNUE_DATA_DIR", "/mnt/d/datasets")

# Per-method sample layout, dataset and loss.
# NOTE(review): the trailing X dimension is NUM_FEATURES // 64, which suggests the
# features arrive packed as int64 bitsets (expanded by decode_int64_bitset in the
# training loop) -- confirm against SamplesService.
if METHOD == "pqr":
    X_SHAPE = (BATCH_SIZE, 3, 2, NUM_FEATURES // 64)
    Y_SHAPE = (BATCH_SIZE, 0)  # no targets for PQR
    INPUTS = sorted(glob(f"{DATA_DIR}/pqr-1700/*.csv"))
    loss_fn = PQRLoss()
elif METHOD == "eval":
    X_SHAPE = (BATCH_SIZE, 2, NUM_FEATURES // 64)
    Y_SHAPE = (BATCH_SIZE, 1)
    INPUTS = sorted(glob(f"{DATA_DIR}/eval/*.csv"))
    loss_fn = EvalLoss()
else:
    raise ValueError(f"Unknown METHOD {METHOD!r}; expected 'pqr' or 'eval'")

# Fail early with a clear message instead of inside SamplesService.
if not INPUTS:
    raise FileNotFoundError(f"No training CSV files found under {DATA_DIR!r} for METHOD={METHOD!r}")

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
folder = f'runs/{timestamp}_{METHOD}_{FEATURE_SET}_{BATCH_SIZE}'
os.makedirs(f'{folder}/models', exist_ok=True)

samples_service = SamplesService(x_shape=X_SHAPE, y_shape=Y_SHAPE, inputs=INPUTS, feature_set=FEATURE_SET, method=METHOD)
chessmodel = NnueModel(num_features=NUM_FEATURES)
chessmodel.cuda()

optimizer = torch.optim.Adam(chessmodel.parameters(), lr=0.0015)
# Multiply LR by 0.7 after 10 epochs without a (relative 1e-4) improvement in avg loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', threshold=0.0001, factor=0.7, patience=10)
writer = SummaryWriter(folder)

# @torch.compile # 30% speedup
def train_step(X, y):
    """Perform one optimisation step on a single batch.

    Uses the module-level ``chessmodel``, ``optimizer`` and ``loss_fn``. It
    clears gradients, runs the forward pass and back-propagates
    ``loss_fn(prediction, y)``. It then applies the optimizer update and calls
    ``chessmodel.clip_weights()``.

    Args:
        X: model input batch.
        y: targets, passed straight through to ``loss_fn``.

    Returns:
        The loss tensor for this batch (still attached to the autograd graph).
    """
    optimizer.zero_grad()

    prediction = chessmodel(X)
    batch_loss = loss_fn(prediction, y)

    batch_loss.backward()
    optimizer.step()

    # Constrain weights after the update. The exact rule lives in
    # NnueModel.clip_weights (presumably keeps them in a quantisable range -- confirm).
    chessmodel.clip_weights()

    return batch_loss

def _log_epoch(epoch, avg_loss):
    """Write per-epoch scalars and parameter histograms to TensorBoard."""
    writer.add_scalar('Train/loss', avg_loss, epoch)
    # Read the LR from the optimizer (public API) instead of the private
    # scheduler._last_lr; after scheduler.step() both hold the same value.
    writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], epoch)
    # Tag names kept unchanged so curves remain comparable with earlier runs.
    writer.add_scalar('Params/mean-f1', chessmodel.ft.weight.mean().item(), epoch)
    writer.add_scalar('Params/mean-l1', chessmodel.linear1.weight.mean().item(), epoch)
    writer.add_scalar('Params/mean-l2', chessmodel.linear2.weight.mean().item(), epoch)
    writer.add_scalar('Params/mean-out', chessmodel.output.weight.mean().item(), epoch)
    for name, param in chessmodel.named_parameters():
        writer.add_histogram(name, param, epoch)
    writer.flush()


def _save_checkpoint(epoch):
    """Save the model for `epoch` as a PyTorch state dict (.pth) and via NnueWriter (.nn)."""
    model_path = f'{folder}/models/{epoch}'
    torch.save(chessmodel.state_dict(), f'{model_path}.pth')
    nn_writer = NnueWriter(chessmodel, FEATURE_SET)
    with open(f'{model_path}.nn', "wb") as f:
        f.write(nn_writer.buf)


# Put the model in training mode (affects train/eval-dependent layers, if any)
chessmodel.train()

for epoch in range(EPOCHS):
    total_loss = 0.0

    for _ in tqdm(range(BATCHES_PER_EPOCH), desc=f'Epoch {epoch}'):
        X, y = samples_service.next_batch()

        # expand bitset
        X = decode_int64_bitset(X)
        X = X.reshape(-1, 2, NUM_FEATURES)

        batch_loss = train_step(X, y).item()
        # Catch NaN *and* +/-inf on the batch that produced it, not on the running sum.
        if not math.isfinite(batch_loss):
            raise FloatingPointError(f"Loss is {batch_loss} at epoch {epoch}, exiting")
        total_loss += batch_loss

    avg_loss = total_loss / BATCHES_PER_EPOCH

    # Step the scheduler on the epoch's mean training loss
    scheduler.step(avg_loss)

    _log_epoch(epoch, avg_loss)
    _save_checkpoint(epoch)

Epoch 0: 100%|██████████| 1000/1000 [00:11<00:00, 87.85it/s]
Epoch 1: 100%|██████████| 1000/1000 [00:10<00:00, 91.62it/s]
Epoch 2: 100%|██████████| 1000/1000 [00:11<00:00, 90.40it/s]
Epoch 3: 100%|██████████| 1000/1000 [00:12<00:00, 77.30it/s]
Epoch 4: 100%|██████████| 1000/1000 [00:11<00:00, 84.54it/s]
Epoch 5: 100%|██████████| 1000/1000 [00:12<00:00, 79.19it/s]
Epoch 6: 100%|██████████| 1000/1000 [00:11<00:00, 87.09it/s]
Epoch 7: 100%|██████████| 1000/1000 [00:09<00:00, 101.93it/s]
Epoch 8: 100%|██████████| 1000/1000 [00:10<00:00, 96.33it/s]
Epoch 9: 100%|██████████| 1000/1000 [00:11<00:00, 84.93it/s]
Epoch 10: 100%|██████████| 1000/1000 [00:11<00:00, 88.63it/s]
Epoch 11: 100%|██████████| 1000/1000 [00:10<00:00, 91.83it/s]
Epoch 12: 100%|██████████| 1000/1000 [00:11<00:00, 89.34it/s]
Epoch 13: 100%|██████████| 1000/1000 [00:10<00:00, 91.67it/s]
Epoch 14: 100%|██████████| 1000/1000 [00:10<00:00, 93.71it/s]
Epoch 15: 100%|██████████| 1000/1000 [00:10<00:00, 91.77it/s]
Epoch 16: 100%|██

KeyboardInterrupt: 