In [1]:
import math
from datetime import datetime
from glob import glob
from tqdm import tqdm
import torch
import wandb
import tempfile

from lib.service import SamplesService
from lib.model import NnueModel
from lib.model import decode_int64_bitset
from lib.serialize import NnueWriter
from lib.puzzles import PuzzleAccuracy
from lib.losses import EvalLoss, PQRLoss

In [2]:
# --- Training schedule -------------------------------------------------------
EPOCHS = 100000
BATCHES_PER_EPOCH = 1000  # an "epoch" here is a fixed number of batches
BATCH_SIZE = 4096

# --- Model / feature configuration -------------------------------------------
FEATURE_SET = "half-piece"
NUM_FEATURES = 768 # alternatives noted by the author: 192, 768, 40960 (presumably per feature set - confirm)
NUM_FT = 256
NUM_L1 = 64
NUM_L2 = 32
METHOD = "eval"  # supported values: "pqr", "eval"

# Per-method batch layout, dataset files and loss.
# NUM_FEATURES // 64: presumably features arrive packed 64 per int64 word
# (they are expanded with decode_int64_bitset before the forward pass) - confirm
# against SamplesService.
if METHOD == "pqr":
    X_SHAPE = (BATCH_SIZE, 3, 2, NUM_FEATURES // 64)
    Y_SHAPE = (BATCH_SIZE, 0)  # no target values for this method
    INPUTS = glob("/mnt/d/datasets/pqr-1700/*.csv")
    loss_fn = PQRLoss()
elif METHOD == "eval":
    X_SHAPE = (BATCH_SIZE, 2, NUM_FEATURES // 64)
    Y_SHAPE = (BATCH_SIZE, 1)  # one target value per sample
    INPUTS = glob("/mnt/d/datasets/eval/*.csv")
    loss_fn = EvalLoss()
else:
    # Fail here instead of with a NameError on X_SHAPE/loss_fn further down.
    raise ValueError(f"Unsupported METHOD {METHOD!r}; expected 'pqr' or 'eval'")

# glob() returns an empty list when the dataset directory is missing or empty;
# stop early with a clear message rather than starting a run with no data.
if not INPUTS:
    raise FileNotFoundError(f"No input CSV files found for METHOD {METHOD!r}")

# Name the run after its start time plus the key settings so runs sort
# chronologically and are recognisable in the W&B UI.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = "_".join([timestamp, METHOD, FEATURE_SET, str(BATCH_SIZE)])

# Settings recorded alongside the run.
run_config = {
    "feature_set": FEATURE_SET,
    "method": METHOD,
    "batch_size": BATCH_SIZE,
    "batches_per_epoch": BATCHES_PER_EPOCH,
}

run = wandb.init(project="cs-master-thesis", name=run_name, job_type="train", config=run_config)

# Puzzle benchmark used periodically during training.
puzzles = PuzzleAccuracy('./data/puzzles.csv')

# Source of training batches for the selected method / feature set.
samples_service = SamplesService(
    x_shape=X_SHAPE,
    y_shape=Y_SHAPE,
    inputs=INPUTS,
    feature_set=FEATURE_SET,
    method=METHOD,
)

# Network to train; moved to the GPU before the optimizer is created.
chessmodel = NnueModel(
    num_features=NUM_FEATURES,
    num_ft=NUM_FT,
    num_l1=NUM_L1,
    num_l2=NUM_L2,
)
chessmodel.cuda()

# Note: to benchmark the raw loader throughput, call
# samples_service.next_batch() in a tqdm loop (e.g. 1000000 iterations).

# Adam, with the learning rate multiplied by 0.7 whenever the monitored value
# (the epoch's average training loss) has not improved for 50 scheduler steps.
optimizer = torch.optim.Adam(chessmodel.parameters(), lr=0.0015)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.7,
    patience=50,
    threshold=0.0001,
)

# Decorating this function with @torch.compile was noted to give a ~30% speedup
# (currently disabled).
def train_step(X, y):
    """Run one optimisation step on a single batch.

    Uses the module-level ``chessmodel``, ``loss_fn`` and ``optimizer``.

    Args:
        X: input batch, passed directly to ``chessmodel``.
        y: targets, passed to ``loss_fn`` together with the model output.

    Returns:
        The value returned by ``loss_fn`` for this batch, unchanged.
    """
    # Forward pass and loss
    predictions = chessmodel(X)
    batch_loss = loss_fn(predictions, y)

    # Drop gradients from the previous step, then back-propagate
    optimizer.zero_grad()
    batch_loss.backward()

    # Apply the update, then let the model clip its own weights
    # (presumably to keep them in a quantisation-friendly range - see lib.model)
    optimizer.step()
    chessmodel.clip_weights()

    return batch_loss

# Engine binary that the puzzle benchmark runs against the freshly saved network.
ENGINE_PATH = "/mnt/c/Users/mlomb/Desktop/Tesis/cs-master-thesis/engine/target/release/engine"
# Run the puzzle benchmark once every this many epochs (including epoch 0).
PUZZLES_EVERY = 30

# Put the model in training mode
chessmodel.train()

for epoch in range(EPOCHS):
    avg_loss = 0.0

    for _ in tqdm(range(BATCHES_PER_EPOCH), desc=f'Epoch {epoch}'):
        X, y = samples_service.next_batch()

        # expand bitset
        X = decode_int64_bitset(X)
        X = X.reshape(-1, 2, NUM_FEATURES)

        loss = train_step(X, y)
        avg_loss += loss.item()

        # A NaN batch loss poisons the running sum, so this trips on the first bad batch.
        if math.isnan(avg_loss):
            raise Exception("Loss is NaN, exiting")

    avg_loss /= BATCHES_PER_EPOCH

    # Step the scheduler on the epoch's mean training loss
    scheduler.step(avg_loss)

    # log metrics to W&B
    wandb.log(step=epoch, data={
        "Train/loss": avg_loss,
        # Read the rate from the optimizer instead of the scheduler's private _last_lr.
        "Train/lr": optimizer.param_groups[0]["lr"],

        # .item() logs plain floats rather than graph-attached tensors.
        "Weight/mean-f1": chessmodel.ft.weight.mean().item(),
        "Weight/mean-l1": chessmodel.linear1.weight.mean().item(),
        "Weight/mean-l2": chessmodel.linear2.weight.mean().item(),
        "Weight/mean-out": chessmodel.output.weight.mean().item(),
    })

    # save model
    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(NnueWriter(chessmodel, FEATURE_SET).buf)
        # The file is read back by path below (W&B and the engine process), so
        # the buffered bytes must reach disk first or the network is truncated.
        tmp.flush()

        # store artifact in W&B
        # NOTE(review): the temp file is deleted when this `with` block exits;
        # confirm the installed wandb version copies the file at add_file time,
        # otherwise wait for the upload before leaving the block.
        artifact = wandb.Artifact(run_name, type="model")
        artifact.add_file(tmp.name, name=f"{epoch}.nn")
        wandb.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])

        if epoch % PUZZLES_EVERY == 0:
            # run puzzles with the engine loading the network just written
            puzzles_results, puzzles_accuracy = puzzles.measure([ENGINE_PATH, f"--nn={tmp.name}"])

            # Collect overall and per-category accuracy and log them in one call.
            puzzle_metrics = {"Puzzles/accuracy": puzzles_accuracy}
            for category, accuracy in puzzles_results:
                puzzle_metrics[f"Puzzles/{category}"] = accuracy
            wandb.log(step=epoch, data=puzzle_metrics)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmlomb[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: 100%|██████████| 1000/1000 [00:13<00:00, 73.60it/s]
100%|██████████| 4897/4897 [01:27<00:00, 55.68it/s]
Epoch 1: 100%|██████████| 1000/1000 [00:12<00:00, 81.12it/s]
Epoch 2: 100%|██████████| 1000/1000 [00:11<00:00, 85.12it/s]
Epoch 3: 100%|██████████| 1000/1000 [00:11<00:00, 88.00it/s]
Epoch 4: 100%|██████████| 1000/1000 [00:11<00:00, 89.86it/s]
Epoch 5: 100%|██████████| 1000/1000 [00:11<00:00, 88.32it/s]
Epoch 6: 100%|██████████| 1000/1000 [00:11<00:00, 88.85it/s]
Epoch 7: 100%|██████████| 1000/1000 [00:11<00:00, 84.53it/s]
Epoch 8: 100%|██████████| 1000/1000 [00:12<00:00, 82.38it/s]
Epoch 9: 100%|██████████| 1000/1000 [00:11<00:00, 87.76it/s]
Epoch 10: 100%|██████████| 1000/1000 [00:11<00:00, 87.02it/s]
Epoch 11: 100%|██████████| 1000/1000 [00:11<00:00, 89.08it/s]
Epoch 12: 100%|██████████| 1000/1000 [00:11<00:00, 88.86it/s]
Epoch 13: 100%|██████████| 1000/1000 [00:11<00:00, 90.86it/s]
Epoch 14: 100%|██████████| 1000/1000 [00:11<00:00, 89.06it/s]
Epoch 15: 100%|██████████| 1

KeyboardInterrupt: 

In [None]:
# Close the W&B run opened by wandb.init above; run this cell once training
# has stopped (including after a manual interrupt of the training cell).
wandb.finish()