# Chess Training

# Server Setup

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
! pip install -q python-chess

In [3]:
! rm -rf  '/content/drive/MyDrive/Colab Notebooks/chess_web/output/'
! mkdir   '/content/drive/MyDrive/Colab Notebooks/chess_web/output/'
output =  '/content/drive/MyDrive/Colab Notebooks/chess_web/output/'

# Parameters

In [4]:
import logging
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(filename="/home/logs.log",
                    filemode="w",
                    format="%(name)s -> %(levelname)s: %(message)s",
                    level=logging.INFO)

logging.info("Hello")

# Data Selection Parameters
data_offset = 200000 # Offset in data file, there are altogether 3.5M samples
min_samp    = 125    # Minimum of either White Win, Black Win or Tie
min_elo     = 2500   # Minimum ELO for game to be considered
remove_ties = 1      # If set to 1, it does not train with ties
logging.info("data_offset={0}, min_samp={1}, min_elo={2}, remove_ties={3}"\
             .format(data_offset, min_samp, min_elo, remove_ties))


# Model Parameters
move_hist = 8
# Neural net grouping options
# "no_group"
# "no_group_b2b_conv2d"
# "group_by_move_hist"
# "group_by_piece"
group = "group_by_piece"
logging.info("history_count = {0}, group = {1}"\
             .format(move_hist, group))

# Training Parameters
learning_rate = 0.0015
weight_decay  = 0.0001
batch_size    = 5000
epochs        = 100
logging.info("learning_rate={0}, weight_decay={1}, batch_size={2}, epochs={3}"\
              .format(learning_rate, weight_decay, batch_size, epochs))

# Dual-Use Function
This function is used for training and also for prediction on the web server

In [5]:
import chess.pgn
import io
import numpy as np

# Read a PGN sequence
# Convert each move to to a board state (each piece has a layer)
# Sequence through all the moves of the PGN
def pgn_to_states(pgn):

    # Initialize
    pgn = io.StringIO(pgn)
    game = chess.pgn.read_game(pgn)
    pgn.close()
    board = game.board()
    board_states = []

    piece_to_num = {
      'P': 0,  # Pawn
      'R': 1,  # Rook
      'N': 2,  # Knight
      'B': 3,  # Bishop
      'Q': 4,  # Queen
      'K': 5}  # King

    for move in game.mainline_moves():

        # There are 6 types of pieces on an 8x8 board
        # Initialize board to all zeroes then add pieces per layer
        # as either +1 or -1
        board_state = np.zeros((6, 8, 8), dtype=np.int8)
        board.push(move)

        for row in range(8):
          for col in range(8):
            piece_type = board.piece_at(chess.square(row, col))
            if piece_type:
              piece = str(piece_type)
              color = int(piece.isupper()) # Upper is White
              layer = piece_to_num[piece.upper()] # Piece defines layer
              board_state[layer, 7-col, row] = color*2-1 # White=+1, Black=-1
        board_states.append(board_state)

    return board_states

# Data Cleaning and Downsampling
Because dataset is big, downsampling happens during cleaning to ensure that cleaning is not wasted

In [6]:
import pandas as pd
from sklearn.utils import resample

'''
=============================
Dataset Description
=============================
[0]  Running index per game
[1]  Date at which the game was played (the format is year.month.day).
[2]  Game result specified inside brackets in the PGN file.
       - The value can be 1-0, 1/2-1/2 or 0-1
       - Corresponding to white win, draw or loose, respectively.
[3]  ELO of white player (an integer number).
[4]  ELO of black player (an integer number).
[5]  Number of moves in the game (for some games it may be zero!)
[6]  date_c = date (in year.month.day) is corrupted or missing?
       - The label should be date_true, meaning the date is corrupted
       - Or date_false, meaning the date is NOT corrupted.
[7]  resu_c = result (1-0, 1/2-1/2, or 0-1) is corrupted or missing?
[8]  welo_c = withe ELO is corrupted or missing?
[9]  belo_c = black ELO is corrupted or missing?
[10] edate_c = event date is corrupted or missing?
[11] setup may be setup_true or setup_false
       - If it is true then the game initial position is specified.
[12] fen may be fen_true and fen_false. It is related to column 12.
[13] In the original file the result is provided in two places.
       - At the end of each sequence of moves and in the attributes part.
       - This flag indicates if the result is (is not) properly provided
          after the sequence of moves
[14] oyrange means out of year range
       - This flag is false only for games with dates in the
         range of years from 1998 to 2007
[15] bad_len (or bad len) flag
       - Indicates, when blen_true, if the length of the game is not good
[###] After the token ###, you can find the sequence of moves.
       - Each move has a number and a letter W (white) or B (black)
         indicating the nth-move of the white or black player, respectively.
'''

file = open("/content/drive/MyDrive/Colab Notebooks/chess_web/data.txt", "r")
data = file.readlines()[5+data_offset:]
print ("Length of file", len(data))

df = pd.DataFrame(columns=['win', 'pgn'])
pd.set_option('max_colwidth', 2000)

def winner_to_number(winner):
  match winner:
    case "1-0"    : return "1"  # White Win
    case "0-1"    : return "-1" # Black Win
    case "1/2-1/2": return "0"  # Tie
    case default  : return "Invalid"

num_wwin  = 0
num_bwin  = 0
num_tie   = 0
num_proc  = 0
num_valid = 0

# processed_cache = ""
for line in data:

  if (num_wwin >= min_samp and \
      num_bwin >= min_samp):
    if (remove_ties==1):
      if (num_tie >= min_samp):
        break
    else:
      break

  num_proc += 1
  # Before the ### is supplemental game information
  # After  the ### is the move sequence (not in PGN notation)
  data_info, data_moves = line.split(" ### ")
  data_info = data_info.strip().split(" ")

  # Remove all invalid entries
  win     = winner_to_number(data_info[2])
  bad_len = data_info[15]
  welo    = data_info[3]
  belo    = data_info[4]
  if (win     =="Invalid")  : continue
  if (welo    =="None")     : continue
  if (belo    =="None")     : continue
  if (bad_len =="blen_true"): continue
  if (remove_ties==1 and win=="0"): continue
  win = int(win) # Need this as a number for training

  # Remove entries where ELO is too low to be meaningful in eval
  welo   = int(welo)
  belo   = int(belo)
  if (welo < min_elo): continue
  if (belo < min_elo): continue

  # All checks passed
  num_valid += 1
  if   (win== 1): num_wwin += 1
  elif (win==-1): num_bwin += 1
  elif (win== 0): num_tie  += 1

  # Clean-up the PGN
  data_moves = data_moves.strip()
  data_moves = data_moves.split(" ")
  pgn_moves = []
  for i in range(0, len(data_moves), 2):
    num_move = i // 2 + 1
    wmove = data_moves[i].split('.')[1]
    bmove = data_moves[i+1].split('.')[1] \
                 if i+1 < len(data_moves) and \
                 '.' in data_moves[i+1] else ''
    pgn_moves.append(f"{num_move}.{wmove} {bmove}")
  pgn = " ".join(pgn_moves)

  # Append to dataframe
  # df = df.append({"win": win, "pgn": pgn}, ignore_index=True)
  #df = pd.DataFrame()
  df_row = pd.DataFrame([[win, pgn]], columns=['win', 'pgn'])
  df = pd.concat([df, df_row], ignore_index=True)

  if (num_proc%1000==0):
    print ("num_proc, num_valid, num_wwin, num_bwin, num_tie = ",
            num_proc, num_valid, num_wwin, num_bwin, num_tie)

print ("num_proc, num_valid, num_wwin, num_bwin, num_tie = ",
        num_proc, num_valid, num_wwin, num_bwin, num_tie)

# Delete data from file from memory to keep RAM usage low
file.close()
del data

print("Before Resampling")
print(df['win'].value_counts())

# Equalize the bins for fair training
df_wwin   = resample(df[df["win"]== 1], replace=False,
                     n_samples=min_samp, random_state=10)
df_bwin   = resample(df[df["win"]==-1], replace=False,
                     n_samples=min_samp, random_state=10)
if (remove_ties==0):
  df_tie  = resample(df[df["win"]== 0], replace=False,
                     n_samples=min_samp, random_state=10)
  df = pd.concat([df_wwin, df_bwin, df_tie])
  del df_wwin, df_bwin, df_tie
else:
  df = pd.concat([df_wwin, df_bwin])
  del df_wwin, df_bwin

print("After Resampling")
print(df['win'].value_counts())
df.to_csv(output+'training_data.csv')
#print(df.head())

Length of file 3361470
num_proc, num_valid, num_wwin, num_bwin, num_tie =  16000 2164 1093 1071 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  20000 2696 1344 1352 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  22000 2954 1475 1479 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  27000 3558 1761 1797 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  40000 5131 2507 2624 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  49000 6282 3061 3221 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  52000 6691 3249 3442 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  59000 7527 3643 3884 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  65000 8183 3951 4232 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  68000 8554 4105 4449 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  72000 9030 4329 4701 0
num_proc, num_valid, num_wwin, num_bwin, num_tie =  3361470 9973 4765 5208 0
Before Resampling
-1    5208
 1    4765
Name: win, dtype: int64
After Resam

In [7]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

train_data = []
true_data  = []
num_proc = 0
done = 0

for idx, row in df.iterrows():
    num_proc += 1

    # Create states ouf of the pgn sequence
    # Each pgn move contains up to two moves
    # The last move can be only one if checkmate (#)
    states = pgn_to_states(row['pgn'])
    winner = row['win']

    # Create an array with "move_hist" states
    # that serves as the basis for future state extensions
    state_hist = [np.zeros((6, 8, 8), dtype=np.int8) for i in range(move_hist)]

    # Parse through all the PGN based board states
    # Create arrays of PGN states
    for state in states:
      true_data.append(winner)
      state_hist.pop(0)
      state_hist.append(state)
      state_save = np.array(state_hist).reshape((6*move_hist, 8, 8))
      if (group=='group_by_piece'):
        state_shuffle = np.zeros((6*move_hist, 8, 8), dtype=np.int8)
        for i in range(6):
          for j in range(move_hist):
            state_shuffle[j+i*move_hist, :, :] = state_save[i+j*6, :, :]
        train_data.append(state_shuffle)
      else:
        train_data.append(state_hist)

    if (num_proc%500==0):
      print ("num_proc = ", num_proc);

del df

train_data = torch.tensor(train_data, dtype=torch.int8)
true_data  = torch.tensor(true_data , dtype=torch.int8)
print(train_data.shape)
print(true_data.shape)

# Split data into training and validation sets
train_input, val_input, train_true, val_true = train_test_split( \
    train_data, true_data, test_size=0.1, random_state=10, shuffle=True)
del train_data, true_data

train_input   = train_input.float()
train_true    = train_true.float().unsqueeze(1)
train_dataset = TensorDataset(train_input, train_true)
train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
del train_input, train_true, train_dataset

val_input     = val_input.float()
val_true      = val_true.float().unsqueeze(1)
val_dataset   = TensorDataset(val_input, val_true)
val_loader    = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
del val_input, val_true, val_dataset


  train_data = torch.tensor(train_data, dtype=torch.int8)


torch.Size([22349, 48, 8, 8])
torch.Size([22349])


# Model Definition

In [8]:
import torch.nn as nn
import torch.nn.functional as F

class ChessNN(nn.Module):
  def __init__(self, move_hist, group): # -> None:
      super().__init__()
      self.move_hist = move_hist
      self.group = group
      if (group=='no_group'):
        # Do not group
        self.convs = nn.Sequential(
          nn.Conv2d(6*move_hist, 6, kernel_size=3, padding=1),
          nn.BatchNorm2d(6),
          nn.Flatten(),
          nn.Linear(6*8*8, 64),
          nn.LeakyReLU(),
          nn.Linear(64, 1),
          nn.Tanh())
      elif (group=='no_group_b2b_conv2d'):
        # Back-to-back Conv2D
        self.convs = nn.Sequential(
          nn.Conv2d(6*move_hist, 6*move_hist, kernel_size=3, padding=1),
          nn.BatchNorm2d(6*move_hist),
          nn.Conv2d(6*move_hist, 6, kernel_size=3, padding=1),
          nn.BatchNorm2d(6),
          nn.Flatten(),
          nn.Linear(6*8*8, 64),
          nn.LeakyReLU(),
          nn.Linear(64, 1),
          nn.Tanh())
      elif (group=='group_by_move_hist'):
        # Group by move history
        # Each group contains all pieces of a move
        # move_hist number of groups
        self.convs = nn.Sequential(
          nn.Conv2d(6*move_hist, move_hist, kernel_size=3, groups=move_hist, padding=1),
          nn.BatchNorm2d(move_hist),
          nn.Flatten(),
          nn.Linear(move_hist*8*8, 64),
          nn.LeakyReLU(),
          nn.Linear(64, 1),
          nn.Tanh())
      elif (group=='group_by_piece'):
        # Group by piece
        # Each group contains move history of a piece
        # 6 groups for 6 piece types
        self.convs = nn.Sequential(
          nn.Conv2d(6*move_hist, 6, kernel_size=3, groups=6, padding=1),
          nn.BatchNorm2d(6),
          nn.Flatten(),
          nn.Linear(6*8*8, 64),
          nn.LeakyReLU(),
          nn.Linear(64, 1),
          nn.Tanh())

  def forward(self, board_state_hist):
      # Remap all piece positions next to each other
      # for use in Conv2d
      board_state_hist_remap = []
      if (self.group=='group_by_piece'):
        for mpiece in range(6):
            board_state_hist_remap.append(\
            board_state_hist\
            [:, mpiece*self.move_hist:(mpiece+1)*self.move_hist, :, :])
        board_state_hist = torch.cat(board_state_hist_remap, dim=1)
      else:
        for mstate in range(self.move_hist):
            board_state_hist_remap.append(\
            board_state_hist\
            [:, mstate*6:(mstate+1)*6, :, :])
        board_state_hist = torch.cat(board_state_hist_remap, dim=1)
      return self.convs(board_state_hist)

  def predict(self, board_state_hist):
      board_state_hist = board_state_hist.reshape(1, 6*self.move_hist, 8, 8)
      return self.forward(board_state_hist).item()

In [9]:
from torchsummary import summary
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

input_shape = (move_hist*6, 8, 8)
summary(ChessNN(move_hist, group).to(device), input_shape)

Using device: cpu
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1              [-1, 6, 8, 8]             438
       BatchNorm2d-2              [-1, 6, 8, 8]              12
           Flatten-3                  [-1, 384]               0
            Linear-4                   [-1, 64]          24,640
         LeakyReLU-5                   [-1, 64]               0
            Linear-6                    [-1, 1]              65
              Tanh-7                    [-1, 1]               0
Total params: 25,155
Trainable params: 25,155
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.01
Params size (MB): 0.10
Estimated Total Size (MB): 0.12
----------------------------------------------------------------


#Training

In [10]:
import matplotlib.pyplot as plt

train_losses   = []
val_losses     = []
val_accuracies = []
model          = ChessNN(move_hist, group).to(device)
loss_fn        = nn.MSELoss().to(device)
optimizer      = torch.optim.AdamW(model.parameters(),
                                   lr=learning_rate,
                                    weight_decay=weight_decay)

for epoch in range(epochs):

  # Training
  train_loss_batch = []
  for batch_idx, (train_input, train_true) in enumerate(train_loader):
    model.train()
    train_input = train_input.to(device, dtype=torch.float32)
    train_true  = train_true.to(device, dtype=torch.float32)
    optimizer.zero_grad()
    train_pred = model(train_input)
    train_loss = loss_fn(train_pred, train_true)
    train_loss.backward()
    optimizer.step()
    train_loss_batch.append(train_loss.item())

  train_losses.extend(train_loss_batch)

  # Validation
  val_accuracy = 0
  val_total = 0

  with torch.no_grad():
    model.eval()
    val_loss_batch = []
    for val_input, val_true in val_loader:
      val_input = val_input.to(device, dtype=torch.float32)
      val_true  = val_true.to(device, dtype=torch.float32)
      val_pred = model(val_input)
      val_loss = loss_fn(val_pred, val_true)
      val_loss_batch.append(val_loss.item())

      val_total += torch.flatten(val_true).shape[0]
      val_accuracy += (torch.flatten(val_pred).round() ==
                       torch.flatten(val_true)).sum().item()

  val_losses.extend(val_loss_batch)
  val_accuracies.append(val_accuracy/val_total)

  print(f'Epoch: {epoch+1}/{epochs}, \
          Train Loss: {np.mean(train_loss_batch):.4f}, \
          Val Loss: {np.mean(val_loss_batch):.4f}, \
          Val Accuracy: {val_accuracy/val_total:.4f}')

  logging.info("Epoch: {0}".format(epoch))
  logging.info("Train Loss: {0}".format(np.mean(train_loss_batch)))
  logging.info("Val Loss: {0}".format(np.mean(val_loss_batch)))
  logging.info("Val Accuracy: {0}".format(val_accuracy/val_total))

  # Loss and Accuracy plot
  plt.figure(figsize=(12, 6))
  plt.plot(train_losses, label='Training Loss')
  val_plot = np.linspace(0, 1, len(val_losses)) * len(train_losses)
  plt.plot(val_plot, val_losses, label='Validation Loss')
  val_plot = np.linspace(0, 1, len(val_accuracies)) * len(train_losses)
  plt.plot(val_plot, val_accuracies, label='Validation Accuracy')
  plt.title('Loss and Accuracy Over Epochs')
  plt.xlabel('Epochs')
  plt.ylabel('Loss and Accuracy')
  plt.legend()
  if (epoch+1==epochs):
    plt.savefig(output+"Results-Epoch" + str(epoch) + ".png")
  plt.show()


Output hidden; open in https://colab.research.google.com to view.

#Model Save

In [11]:
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save('/content/drive/MyDrive/Colab Notebooks/chess_web/output/model_scripted.pt') # Save
! cp -a /home/logs.log '/content/drive/MyDrive/Colab Notebooks/chess_web/output/'