In [1]:
import modules.board_module as bf
import modules.tree_module as tf
import modules.stockfish_module as sf
from ModelSaver import ModelSaver
import random
from dataclasses import dataclass
from collections import namedtuple
import itertools
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from math import floor, ceil

from train_nn_evaluator import EvalDataset

In [2]:
# Build the evaluation dataset from the sample files found under `datapath`.
num_rand = 4096
datapath = "/home/luke/chess/python/gamedata/samples"
eval_file_template = "random_n={0}_sample"

# None is passed straight through as `indexes`; presumably that loads every
# sample file, while e.g. list(range(10)) would restrict it - TODO confirm
inds = None

# name of the sample set for the chosen value of num_rand
sample_name = eval_file_template.format(num_rand)
dataset = EvalDataset(datapath, sample_name, indexes=inds, log_level=0)

In [22]:
def _percent_of_dataset(count):
  """Return count as a percentage of the current number of positions in the dataset."""
  return (count / len(dataset)) * 100

# report how much of the raw dataset is duplicated or is a mate position
print(f"Total number of positions = {len(dataset)}")
num_duplicates = dataset.check_duplicates()
num_mates = dataset.check_mate_positions()
print(f"Proportion of duplicates = {_percent_of_dataset(num_duplicates):.1f} %")
print(f"Proportion of mate positions = {_percent_of_dataset(num_mates):.1f} %")

# prepare the dataset: repeat both checks with remove=True, then convert to
# torch with the boards stored as floats
num_duplicates = dataset.check_duplicates(remove=True)
num_mates = dataset.check_mate_positions(remove=True)
dataset.board_dtype = torch.float
dataset.to_torch()

Total number of positions = 245760
Proportion of duplicates = 14.6 %
Proportion of mate positions = 0.2 %


In [4]:
class BoardCNN(nn.Module):
  """
  Convolutional network that maps a board tensor to a single scalar output.

  The input must have 19 channels. The size comments on the pooling layers
  assume an 8 x 8 board, which three 2 x 2 max-pools reduce to 1 x 1 so that
  the flattened feature vector has the 128 elements the first linear layer
  expects.
  """

  name = "BoardCNN"

  def __init__(self):

    super(BoardCNN, self).__init__()

    self.board_cnn = nn.Sequential(

      # Layer 1
      nn.Conv2d(in_channels=19, out_channels=32, kernel_size=3, padding=1),  # Conv layer
      nn.ReLU(),                                                             # Activation
      nn.MaxPool2d(kernel_size=2),                                           # Pooling (output size: 32 x 4 x 4)

      # Layer 2
      nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2),                                           # Pooling (output size: 64 x 2 x 2)

      # Layer 3
      nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2),                                           # Pooling (output size: 128 x 1 x 1)

      # Flatten layer to transition to fully connected
      nn.Flatten(),

      # two fully connected layers to produce a single output
      nn.Linear(128, 64),
      nn.ReLU(),
      nn.Linear(64, 1),
    )

  def forward(self, board):
    """
    Run the network on a batch of boards and return a (batch, 1) tensor.

    The input is first moved onto whichever device the network weights
    currently live on, so callers do not need to move it themselves.
    """
    # nn.Module provides no `device` attribute, so read the device from the weights
    board = board.to(next(self.parameters()).device)
    return self.board_cnn(board)

In [35]:
def train(net, data_x, data_y, epochs=1, lr=5e-5, device="cuda", batch_size=64):
  """
  Train a network for the given number of epochs on inputs data_x against
  the correct outputs data_y, using Adam and a mean squared error loss.

  Parameters
  ----------
  net : object exposing the model to optimise as `net.board_cnn`
  data_x : tensor of inputs, indexed by sample along the first dimension
  data_y : tensor of targets, one per sample; each batch is reshaped to the
      shape of the network output before the loss is taken
  epochs : number of passes over the data
  lr : learning rate given to the Adam optimiser
  device : device the model and the whole of data_x / data_y are moved to
  batch_size : number of samples per optimisation step

  Returns
  -------
  The same `net` object, trained in place.

  Raises
  ------
  ValueError if batch_size is not positive, if data_x is empty, or if
  data_x and data_y hold a different number of samples.

  Each epoch uses a fresh random permutation of the samples. Samples left
  over after the last full batch are skipped for that epoch; if there are
  fewer samples than batch_size they are all used as a single batch.
  """

  if batch_size < 1:
    raise ValueError(f"batch_size must be at least 1, got {batch_size}")

  num_samples = len(data_x)
  if num_samples == 0:
    raise ValueError("cannot train on an empty dataset")
  if len(data_y) != num_samples:
    raise ValueError(
      f"data_x has {num_samples} samples but data_y has {len(data_y)}"
    )

  # move onto the specified device
  net.board_cnn.to(device)
  data_x = data_x.to(device)
  data_y = data_y.to(device)

  # put the model in training mode
  net.board_cnn.train()

  lossfcn = nn.MSELoss()
  optim = torch.optim.Adam(net.board_cnn.parameters(), lr=lr)

  # always run at least one batch so the average loss below is well defined
  num_batches = max(1, num_samples // batch_size)

  for i in range(epochs):

    print(f"Starting epoch {i + 1}. There will be {num_batches} batches")

    rand_idx = torch.randperm(num_samples)
    total_loss = 0.0

    for n in range(num_batches):

      batch_idx = rand_idx[n * batch_size : (n + 1) * batch_size]
      batch_x = data_x[batch_idx]
      batch_y = data_y[batch_idx]

      # use the model for a prediction and calculate loss. The target is
      # given the same shape as the prediction because a (batch,) target
      # against a (batch, 1) prediction would broadcast to (batch, batch)
      # and pull every prediction towards the batch mean
      net_y = net.board_cnn(batch_x)
      loss = lossfcn(net_y, batch_y.reshape(net_y.shape))

      # backpropagate
      loss.backward()
      optim.step()
      optim.zero_grad()

      total_loss += loss.item()

    print(f"Loss is {(total_loss / num_batches) * 1000:.3f}, at end of epoch {i + 1}")

  return net

In [38]:
# training configuration
device = "cuda"
epochs = 10
lr = 1e-7

# # normalise the evaluations (currently disabled)
# print(torch.min(dataset.evals))
# print(torch.max(dataset.evals))
# dataset.evals /= torch.max(dataset.evals)
# print(torch.min(dataset.evals))
# print(torch.max(dataset.evals))

# build a fresh network and train it on the prepared dataset
net = BoardCNN()
trained_net = train(
  net,
  dataset.boards,
  dataset.evals,
  epochs=epochs,
  lr=lr,
  device=device,
)

# save the trained network; the value returned by save() is the cell output
modelsaver = ModelSaver("/home/luke/chess/python/models/")
modelsaver.save("eval_model", trained_net)

Starting epoch 1. There will be 3272 batches
Loss is 20.681, at end of epoch 1
Starting epoch 2. There will be 3272 batches
Loss is 18.147, at end of epoch 2
Starting epoch 3. There will be 3272 batches
Loss is 15.684, at end of epoch 3
Starting epoch 4. There will be 3272 batches
Loss is 13.315, at end of epoch 4
Starting epoch 5. There will be 3272 batches
Loss is 11.077, at end of epoch 5
Starting epoch 6. There will be 3272 batches
Loss is 9.085, at end of epoch 6
Starting epoch 7. There will be 3272 batches
Loss is 7.371, at end of epoch 7
Starting epoch 8. There will be 3272 batches
Loss is 6.059, at end of epoch 8
Starting epoch 9. There will be 3272 batches
Loss is 5.177, at end of epoch 9
Starting epoch 10. There will be 3272 batches
Loss is 4.653, at end of epoch 10
Saving file /home/luke/chess/python/models/eval_model_003.lz4 with pickle ... finished


'/home/luke/chess/python/models/eval_model_003.lz4'

In [46]:
# Spot-check the trained network against the stored evaluation of the first
# ten positions in the dataset.

# NOTE(review): presumably this matches the scaling applied to dataset.evals
# when the dataset is converted to torch - confirm against EvalDataset
sf_eval_scale = 1e-3

# inference only, so do not record an autograd graph for these forward passes
with torch.no_grad():
  for i in range(10):
    sf_eval = dataset.positions[i].eval
    sf_eval_scaled = sf_eval * sf_eval_scale

    # the network expects a batch dimension, so add one for the single board
    net_eval = trained_net.board_cnn(dataset.boards[i].to(device).unsqueeze(dim=0))
    net_eval = net_eval.to("cpu").item()

    print(f"sf_eval = {sf_eval_scaled:.3f}, net_eval = {net_eval:.3f}, difference is {sf_eval_scaled - net_eval:.3f}")

sf_eval = 0.000, net_eval = 0.006, difference is -0.006
sf_eval = -0.692, net_eval = -0.006, difference is -0.686
sf_eval = -0.605, net_eval = -0.032, difference is -0.573
sf_eval = 0.003, net_eval = -0.034, difference is 0.037
sf_eval = -0.778, net_eval = 0.001, difference is -0.779
sf_eval = 0.358, net_eval = -0.033, difference is 0.391
sf_eval = 0.010, net_eval = -0.037, difference is 0.047
sf_eval = -0.404, net_eval = 0.036, difference is -0.440
sf_eval = 0.022, net_eval = -0.042, difference is 0.064
sf_eval = -0.488, net_eval = 0.015, difference is -0.503
