In [1]:
import sys
sys.path.append('../')

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from tensorflow.python.client import device_lib
print([dev.name for dev in device_lib.list_local_devices()])

['/device:CPU:0', '/device:GPU:0']


# Load dataset

In [2]:
import math
from glob import glob
import tensorflow as tf

In [3]:
files = glob("../data/dataset/2018_1600.bin")
record_size = 3 * 12 * 8
batch_size = 4096
dataset_size = sum([os.path.getsize(f) for f in files]) / record_size
batches_per_epoch = math.ceil(dataset_size / batch_size)

dataset = tf.data.FixedLengthRecordDataset(filenames=files, record_bytes=record_size)
dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
dataset = dataset.apply(tf.data.experimental.copy_to_device('/gpu:0'))
dataset = dataset.prefetch(tf.data.AUTOTUNE)

#import tensorflow_datasets as tfds
#tfds.benchmark(dataset, batch_size=batch_size)

dataset_size

1774212.0

# Train

In [4]:
from tqdm import tqdm
from time import time
import keras
from keras.models import Model
from keras.layers import Dense
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint, TensorBoard, CallbackList
from lib.encoding import encode_board, decode_board

In [5]:
def custom_loss(y_pred):
    """
    Compute loss as defined in https://erikbern.com/2014/11/29/deep-learning-for-chess.html
    // sum(p,q,r)logS(f(q)−f(r))+K*log(f(p)+f(q))+K*log(−f(q)−f(p))
    """
    p = y_pred[:,0]
    q = y_pred[:,1]
    r = y_pred[:,2]
    K = 1.0

    rq_diff = r - q
    pq_diff = K * (p + q)

    a = - tf.math.reduce_mean(tf.math.log(tf.math.sigmoid(rq_diff)))
    b = - tf.math.reduce_mean(tf.math.log(tf.math.sigmoid( pq_diff)))
    c = - tf.math.reduce_mean(tf.math.log(tf.math.sigmoid(-pq_diff)))

    return a + b + c

def make_chess_model():
    inp = tf.keras.Input(shape=(12,), dtype=tf.int64)
    x = decode_board(inp) # convert 12 ints to 768 floats
    x = Dense(256, activation="relu")(x)
    x = Dense(256, activation="relu")(x)
    x = Dense(256, activation="relu")(x)
    x = Dense(1)(x)
    return Model(inp, x)

chess_model = make_chess_model()
chess_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 12)]              0         
                                                                 
 tf.expand_dims (TFOpLambda  (None, 12, 1)             0         
 )                                                               
                                                                 
 tf.bitwise.bitwise_and (TF  (None, 12, 64)            0         
 OpLambda)                                                       
                                                                 
 tf.math.not_equal (TFOpLam  (None, 12, 64)            0         
 bda)                                                            
                                                                 
 tf.cast (TFOpLambda)        (None, 12, 64)            0         
                                                             

In [6]:
ts = int(time())
epochs = 9999
#optimizer = SGD(learning_rate=0.03, nesterov=True, momentum=0.9, clipnorm=1)
optimizer = Adam(learning_rate=0.001)
loss_tracker = keras.metrics.Mean(name="loss")
callbacks = CallbackList([
    ModelCheckpoint(f"checkpoints/{ts}" + "/model-{epoch:04d}-{loss:.3f}.keras", monitor="loss", save_best_only=True),
    TensorBoard(log_dir=f"./logs/{ts}", write_graph=False)    
], model=chess_model)

@tf.function
def train_step(batch):
    batch = tf.reshape(tf.io.decode_raw(batch, tf.int64), (-1, 3, 12))

    # Open a GradientTape to record the operations run
    # during the forward pass, which enables auto-differentiation.
    with tf.GradientTape() as tape:
        # Run the forward pass of the layer.
        # The operations that the layer applies
        # to its inputs are going to be recorded
        # on the GradientTape.
        logits = tf.reshape(chess_model(tf.reshape(batch, (-1, 12)), training=True), (-1, 3))  # Logits for this minibatch

        # Compute the loss value for this minibatch.
        loss_value = custom_loss(logits)

    # Use the gradient tape to automatically retrieve
    # the gradients of the trainable variables with respect to the loss.
    grads = tape.gradient(loss_value, chess_model.trainable_weights)

    # Run one step of gradient descent by updating
    # the value of the variables to minimize the loss.
    optimizer.apply_gradients(zip(grads, chess_model.trainable_weights))

    # Update metrics
    loss_tracker.update_state(loss_value)

    return loss_value

callbacks.on_train_begin()
for epoch in range(epochs):
    loss_tracker.reset_states()
    callbacks.on_epoch_begin(epoch)

    batch_i = 0
    batch_i_last = 0
    with tqdm(total=batches_per_epoch, bar_format=f"Epoch {epoch+1}/{epochs}" + " {l_bar}{bar:10}{r_bar}{bar:-10b}") as pbar:
        for batch in dataset:
            loss_value = train_step(batch)

            batch_i += 1
            if batch_i % 10 == 0 or batch_i == batches_per_epoch:
                pbar.set_postfix_str(f"loss={loss_tracker.result():.4f} loss_batch={float(loss_value):.4f}")
                pbar.update(batch_i - batch_i_last)
                batch_i_last = batch_i

    import chess
    good_board = chess.Board("2kr3r/1pp1pp1p/1p6/q4bP1/2B5/4BP2/Pb1NQK1P/R2R4 w - - 1 18")
    bad_board = chess.Board("2kr3r/1pp1pp1p/1p6/q4bP1/2B5/4BP2/Pb1NQK1P/R1R5 w - - 1 18")
    pred = chess_model.predict(tf.concat([encode_board(good_board), encode_board(bad_board)], axis=0), verbose=0)

    callbacks.on_epoch_end(epoch, logs={"loss": loss_tracker.result(), "good": pred[0][0], "bad": pred[1][0]})
callbacks.on_train_end()


I0000 00:00:1706387549.638982   16636 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Epoch 1/9999 100%|██████████| 434/434 [00:08<00:00, 49.24it/s, loss=1.9578 loss_batch=1.9069]
Epoch 2/9999 100%|██████████| 434/434 [00:04<00:00, 93.21it/s, loss=1.8917 loss_batch=1.8517]
Epoch 3/9999 100%|██████████| 434/434 [00:04<00:00, 93.64it/s, loss=1.8681 loss_batch=1.8139]
Epoch 4/9999 100%|██████████| 434/434 [00:04<00:00, 94.37it/s, loss=1.8550 loss_batch=1.7942]
Epoch 5/9999 100%|██████████| 434/434 [00:05<00:00, 84.62it/s, loss=1.8461 loss_batch=1.7723]
Epoch 6/9999 100%|██████████| 434/434 [00:04<00:00, 93.74it/s, loss=1.8410 loss_batch=1.7568]
Epoch 7/9999 100%|██████████| 434/434 [00:04<00:00, 94.39it/s, loss=1.8354 loss_batch=1.7483]
Epoch 8/9999 100%|██████████| 434/434 [00:04<00:00, 94.21it/s, loss=1.8292 loss_batch=1.7420]
Epoch 9/9999 100%|██████████| 434/434 [00:04<00:00, 94.49it/s, loss=1.8292 loss_batch=1.7329

KeyboardInterrupt: 