In [1]:
import modules.board_module as bf
import modules.tree_module as tf
import modules.stockfish_module as sf
from ModelSaver import ModelSaver
import random
from dataclasses import dataclass
from collections import namedtuple
import itertools
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from math import floor, ceil

from train_nn_evaluator import EvalDataset

In [2]:
# load in the entire dataset
num_rand = 4096
datapath = "/home/luke/chess/python/gamedata/samples"
eval_file_template = "random_n={0}_sample"
inds = None# list(range(10))
dataset = EvalDataset(datapath, eval_file_template.format(num_rand),
                      indexes=inds, log_level=0)

In [3]:
print(f"Total number of positions = {len(dataset)}")
num_duplicates = dataset.check_duplicates()
num_mates = dataset.check_mate_positions()
print(f"Proportion of duplicates = {(num_duplicates / len(dataset))*100:.1f} %")
print(f"Proportion of mate positions = {(num_mates / len(dataset))*100:.1f} %")

# prepare the dataset
# num_duplicates = dataset.check_duplicates(remove=True)
# num_mates = dataset.check_mate_positions(remove=True)
dataset.board_dtype = torch.float
dataset.to_torch()

Total number of positions = 241664
Proportion of duplicates = 14.6 %
Proportion of mate positions = 0.2 %


In [4]:
class BoardCNN(nn.Module):

  name = "BoardCNN"

  def __init__(self):

    super(BoardCNN, self).__init__()

    self.board_cnn = nn.Sequential(

      # Layer 1
      nn.Conv2d(in_channels=19, out_channels=32, kernel_size=3, padding=1),  # Conv layer
      nn.ReLU(),                                                             # Activation
      nn.MaxPool2d(kernel_size=2),                                           # Pooling (output size: 32 x 4 x 4)

      # Layer 2
      nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2),                                           # Pooling (output size: 64 x 2 x 2)

      # Layer 3
      nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2),                                           # Pooling (output size: 128 x 1 x 1)

      # Flatten layer to transition to fully connected
      nn.Flatten(),

      # two fully connected layers to produce a single output
      nn.Linear(128, 64),
      nn.ReLU(),
      nn.Linear(64, 1),
  )

  def forward(self, board):
    board = board.to(self.device)
    x = self.board_cnn(board)
    return x

In [9]:
def train(net, data_x, data_y, epochs=1, lr=5e-5, device="cuda"):
  """
  Perform a training epoch for a given network based on data inputs
  data_x, and correct outputs data_y
  """

  # move onto the specified device
  net.board_cnn.to(device)
  data_x.to(device)
  data_y.to(device)

  # put the model in training mode
  net.board_cnn.train()

  lossfcn = nn.MSELoss()
  optim = torch.optim.Adam(net.board_cnn.parameters(), lr=lr)

  rand_idx = torch.randperm(data_x.shape[0])
  batch_size = 64
  num_batches = len(data_x) // batch_size

  for i in range(epochs):

    print(f"Starting epoch {i + 1}. There will be {num_batches} batches")

    for n in range(num_batches):

      batch_x = data_x[rand_idx[n * batch_size : (n+1) * batch_size]]
      batch_y = data_y[rand_idx[n * batch_size : (n+1) * batch_size]]

      # use the model for a prediction and calculate loss
      net_y = net.board_cnn(batch_x)
      loss = lossfcn(net_y, data_y)

      # backpropagate
      loss.backward()
      optim.step()
      optim.zero_grad()

      if n % 1000 == 0:
        print(f"Loss is {loss.item():.3f}, epoch {i + 1}, batch {n + 1} / {num_batches}")

  return net

In [None]:
net = BoardCNN()
device = "cpu"

trained_net = train(net, dataset.boards, dataset.evals, device=device)

modelsaver = ModelSaver("/home/luke/chess/python/models/")

modelsaver.save("eval_model", trained_net)

Starting epoch 1. There will be 3776 batches
Loss is 5.150, epoch 1, batch 0 / 3776


  return F.mse_loss(input, target, reduction=self.reduction)


Loss is 5.112, epoch 1, batch 1000 / 3776
