In [1]:
# install library as needed
#pip install chess

In [2]:
import os
import datetime
import chess
import chess.engine
import random
import numpy as np
import math
from tqdm import tqdm
import io
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, SVG
from sklearn.preprocessing import MinMaxScaler
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

In [3]:
# helper function: download file from Google Drive

def download_from_drive(file_name='lichess_db_eval.jsonl.zst', asset_dir='Mengqi_Input'):
    """Download one file from a named Google Drive folder into the working directory.

    Authenticates with PyDrive (credentials cached in ``mycreds.txt``), looks up
    the folder titled ``asset_dir``, then fetches the first file inside it whose
    title equals ``file_name``.

    Args:
        file_name: Title of the file on Drive; also used as the local file name.
        asset_dir: Title of the Drive folder that contains the file.

    Returns:
        The local path (equal to ``file_name``) when the download succeeded,
        otherwise ``None`` (folder or file not found).
    """
    gauth = GoogleAuth()
    # NOTE(review): the OAuth client-secret file name is hard-coded; it must be
    # present next to the notebook for authentication to work.
    gauth.DEFAULT_SETTINGS['client_config_file'] = 'client_secret_1057507276332-5mk9ac9q22rsmtm1idlqvpraq08ar8p5.apps.googleusercontent.com.json'
    gauth.LoadCredentialsFile("mycreds.txt")
    if gauth.credentials is None:
        # No cached credentials: run the interactive browser flow.
        gauth.LocalWebserverAuth()
    elif gauth.access_token_expired:
        gauth.Refresh()
    else:
        gauth.Authorize()
    gauth.SaveCredentialsFile("mycreds.txt")
    drive = GoogleDrive(gauth)

    def quote_literal(text):
        # Drive query strings are single-quoted; escape backslashes and quotes
        # so a title containing them cannot break the query.
        return str(text).replace("\\", "\\\\").replace("'", "\\'")

    def find_folder_id(folder_name):
        """Return the id of the first non-trashed folder titled ``folder_name``."""
        query = (f"title='{quote_literal(folder_name)}' and "
                 "mimeType='application/vnd.google-apps.folder' and trashed=false")
        for candidate in drive.ListFile({'q': query}).GetList():
            if candidate['title'] == folder_name:
                return candidate['id']
        return None

    def download_zst_file_from_drive(file_title, parent_id):
        """Download ``file_title`` from folder ``parent_id``; return its local name."""
        query = (f"'{quote_literal(parent_id)}' in parents and trashed=false "
                 f"and title='{quote_literal(file_title)}'")
        file_list = drive.ListFile({'q': query}).GetList()
        if not file_list:
            print(f"No file found with title: {file_title}")
            return None
        remote_file = file_list[0]
        print("loading file...") #3min
        remote_file.GetContentFile(file_title)
        return file_title

    asset_folder_id = find_folder_id(asset_dir)
    if asset_folder_id is None:
        print("Asset folder not found.")
        return None

    file_path = download_zst_file_from_drive(file_name, asset_folder_id)
    if file_path is None:
        return None

    print("Downloaded file: {}".format(file_path))
    # Report success to the caller; previously every path returned None.
    return file_path

In [4]:
# Set the random seeds. Weight initialisation and DataLoader shuffling draw
# from torch's RNG, so seeding NumPy alone does not make training repeatable.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

In [6]:
# load conditioned training and validation data
download_from_drive(file_name='training_set_conditioned.npz', asset_dir='Mengqi_Input')
download_from_drive(file_name='validation_set_conditioned.npz', asset_dir='Mengqi_Input')

# download_from_drive only prints on failure, so fail early with a clear
# message instead of a less specific error from np.load.
for required_file in ('training_set_conditioned.npz', 'validation_set_conditioned.npz'):
    if not os.path.exists(required_file):
        raise FileNotFoundError(f"{required_file} was not downloaded from Google Drive")

# NpzFile keeps the archive open until closed; the context manager releases
# the handle once the arrays have been read into memory.
with np.load('training_set_conditioned.npz') as data:
    inputdata = {'X': data['X'], 'y': data['y']}
with np.load('validation_set_conditioned.npz') as data:
    validation_set = {'X': data['X'], 'y': data['y']}

loading file...
Downloaded file: training_set_conditioned.npz
loading file...
Downloaded file: validation_set_conditioned.npz


In [7]:
# model components

class ConvBlock(nn.Module):
    """Input stem: reshape to an 8x8 board, then Conv2d -> BatchNorm -> ReLU.

    Args:
        channel_in: Number of input planes per board position.
        channel_out: Number of feature maps produced.
        kernel_size: Convolution kernel size (odd values keep the 8x8 size
            when ``stride`` is 1).
        stride: Convolution stride.
    """

    def __init__(self, channel_in=13, channel_out=256, kernel_size=3, stride=1):
        super(ConvBlock, self).__init__()
        self.channel_in = channel_in
        # "Same" padding derived from the kernel size (matches ResBlock);
        # a fixed padding of 1 is only correct for kernel_size=3.
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(channel_in, channel_out, kernel_size=kernel_size,
                               stride=stride, padding=padding)
        self.bn1 = nn.BatchNorm2d(channel_out)

    def forward(self, s):
        """Return the activated feature maps for a batch of board tensors."""
        if s.dtype != torch.float32:
            s = s.float()
        # batch_size x channels x board_x x board_y; reshape (unlike view) also
        # accepts non-contiguous input.
        s = s.reshape(-1, self.channel_in, 8, 8)
        s = self.conv1(s)
        s = self.bn1(s)
        s = F.relu(s)
        return s

class ResBlock(nn.Module):
    """Residual block: two Conv2d+BatchNorm layers with a skip connection.

    Args:
        channel_in: Channels of the incoming feature map.
        channel_out: Channels produced by both convolutions.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of the first convolution only.
        downsample: Optional module applied to the input to build the shortcut
            when its shape differs from the main path (e.g. ``stride > 1`` or
            ``channel_in != channel_out``).
    """

    def __init__(self, channel_in=256, channel_out=256, kernel_size=3, stride=1,
                 downsample=None):
        super(ResBlock, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(channel_in, channel_out, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm2d(channel_out)
        # Only the first convolution carries the stride. Striding the second
        # one as well would downsample the main path twice, so it could no
        # longer be added to a shortcut that is downsampled once.
        self.conv2 = nn.Conv2d(channel_out, channel_out, kernel_size=kernel_size,
                               stride=1, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm2d(channel_out)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))``."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = F.relu(out)
        return out

class OutBlock(nn.Module):
    """Value head: 1x1 conv to a single plane, two linear layers, tanh output.

    Args:
        channel_in: Channels of the incoming feature map (default 256, the
            width used by the residual tower in this notebook).
    """

    def __init__(self, channel_in=256):
        super(OutBlock, self).__init__()
        self.conv = nn.Conv2d(channel_in, 1, kernel_size=1, stride=1)
        self.bn = nn.BatchNorm2d(1)
        self.fc1 = nn.Linear(8*8, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self,s):
        """Map feature maps of 8x8 boards to one value per board in [-1, 1]."""
        v = self.conv(s)
        v = self.bn(v)
        v = F.relu(v)
        v = v.view(-1, 8*8)  # flatten the single 8x8 plane: batch_size x 64
        v = self.fc1(v)
        v = F.relu(v)
        v = self.fc2(v)
        v = torch.tanh(v)  # squash to [-1, 1]

        return v

In [8]:
# 13 layers of residual blocks

class ResNet13(nn.Module):
    """Value network: a ConvBlock stem, 13 ResBlocks in sequence, then an OutBlock."""

    def __init__(self):
        super().__init__()
        self.conv = ConvBlock()
        tower = [ResBlock() for _ in range(13)]
        self.resblocks = nn.ModuleList(tower)
        self.outblock = OutBlock()

    def forward(self, s):
        """Run the stem, every residual block in order, and the output head."""
        features = self.conv(s)
        for block in self.resblocks:
            features = block(features)
        return self.outblock(features)

In [9]:
# Scaler that maps evaluation scores linearly onto [-1, 1]; it is fitted on
# the training targets only.
normalizer = MinMaxScaler(feature_range=(-1, 1))
normalizer.fit(inputdata['y'].reshape(-1, 1))

In [10]:
class data_prep():
    """Map-style dataset pairing board tensors with normalised evaluation scores.

    Args:
        dataset: Mapping with array entries ``'X'`` (one sample per row) and
            ``'y'`` (one score per sample).
        scaler: Object with a ``transform`` method applied to the scores as a
            column vector. Defaults to the module-level ``normalizer``.
    """

    def __init__(self, dataset, scaler=None):
        if scaler is None:
            # Fall back to the notebook-wide scaler fitted on the training targets.
            scaler = normalizer
        self.X = dataset['X']
        # transform() receives a column vector; store the scores flat again.
        self.y = scaler.transform(dataset['y'].reshape(-1, 1)).reshape(-1,)

    def __len__(self):
        return len(self.X)

    def __getitem__(self,idx):
        # Move the last axis first; presumably (8, 8, planes) -> (planes, 8, 8)
        # to match the channels-first layout the network expects -- TODO confirm.
        return self.X[idx].transpose(2,0,1), self.y[idx]

In [11]:
# validation function

def validate(net, val_loader, criterion):
    """Return the mean of the per-batch losses of ``net`` over ``val_loader``.

    Args:
        net: Model to evaluate; its output is indexed as ``[:, 0]``.
        val_loader: Iterable yielding ``(state, value)`` tensor batches.
        criterion: Loss called as ``criterion(prediction, value)``.

    Returns:
        Average of the batch losses as a Python float.

    Raises:
        ValueError: If ``val_loader`` yields no batches.

    Note:
        The network is left in evaluation mode; callers that continue training
        must call ``net.train()`` afterwards.
    """
    net.eval()  # Switch model to evaluation mode
    # Follow the device the model lives on rather than a notebook-level
    # ``cuda`` flag that may not exist when this function is called.
    try:
        device = next(net.parameters()).device
    except StopIteration:
        device = torch.device("cpu")
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():  # Disable gradient calculation
        for state, value in val_loader:
            state = state.to(device).float()
            value = value.to(device).float()
            value_pred = net(state)  # value_pred = torch.Size([batch, 1])
            loss = criterion(value_pred[:, 0], value)
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("val_loader produced no batches")
    return total_loss / num_batches

In [12]:
def _save_training_checkpoint(net, epoch_number, train_losses, val_losses, learning_rates):
    """Write the loss/LR history and the model weights for ``epoch_number`` under model/."""
    os.makedirs("model", exist_ok=True)  # the save calls do not create the directory
    np.savez("model/ResNet13_Mar28_losses_per_epoch_epoch{}.npz".format(epoch_number),
             train_loss=train_losses,
             val_loss=val_losses,
             lr=learning_rates
             )
    torch.save({'state_dict': net.state_dict()},
               "model/ResNet13_Mar28_training_checkpoint_epoch{}.ckpt".format(epoch_number))


def train(net, train_data, val_data, batch_size=100, epoch_start=0, epoch_stop=20, checkpoint_interval=10, early_stop_patience=5):
    """Train ``net`` with Adam + MSE loss, validating after every epoch.

    Args:
        net: Model to optimise in place; its output is indexed as ``[:, 0]``.
        train_data: Mapping with ``'X'`` and ``'y'`` arrays for training.
        val_data: Mapping with ``'X'`` and ``'y'`` arrays for validation.
        batch_size: Mini-batch size for both loaders.
        epoch_start: First epoch index (0-based).
        epoch_stop: Epoch index at which training stops (exclusive).
        checkpoint_interval: Save losses and weights every this many epochs.
        early_stop_patience: Number of non-improving epochs tolerated before
            training stops early.

    Side effects:
        Prints progress, and writes ``.npz`` loss histories and ``.ckpt``
        weights into ``model/``. Each ``.npz`` holds the epochs since the
        previous checkpoint.
    """
    # Use the device the model already lives on so inputs always match it.
    try:
        device = next(net.parameters()).device
    except StopIteration:
        device = torch.device("cpu")
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.01)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.2)
    train_set = data_prep(train_data)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=False)
    val_set = data_prep(val_data)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=False)
    # Per-checkpoint segments (reset after each save, as the saved files expect).
    losses_per_epoch = []
    val_losses_per_epoch = []
    lr = []
    # Uninterrupted validation history for early stopping; the segment list
    # above is emptied at checkpoints and would blind the 6-epoch window.
    val_history = []
    patience_count = 0
    for epoch in range(epoch_start, epoch_stop):
        # validate() leaves the network in eval mode, so training mode must be
        # re-enabled every epoch or BatchNorm stops updating after epoch one.
        net.train()
        total_loss = 0.0
        epoch_loss_sum = 0.0
        num_batches = 0
        losses_per_batch = []
        for i, data in enumerate(train_loader):
            state, value = data
            state = state.to(device).float()
            value = value.to(device).float()
            optimizer.zero_grad()
            value_pred = net(state) # value_pred = torch.Size([batch, 1])
            loss = criterion(value_pred[:,0], value)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            total_loss += batch_loss
            epoch_loss_sum += batch_loss
            num_batches += 1
            if i % 10 == 9:    # print every 10 mini-batches of size = batch_size
                print('Process ID: %d [Epoch: %d, %5d/ %d points] total loss per batch: %.5f' %
                      (os.getpid(), epoch + 1, (i + 1)*batch_size, len(train_set), total_loss/10))
                print("Value:",value[0].item(),value_pred[0,0].item())
                losses_per_batch.append(total_loss/10)
                total_loss = 0.0
        if losses_per_batch:
            losses_per_epoch.append(sum(losses_per_batch)/len(losses_per_batch))
        elif num_batches:
            # Fewer than 10 batches: no 10-batch window was completed, so
            # average the raw batch losses instead of dividing by zero.
            losses_per_epoch.append(epoch_loss_sum / num_batches)
        else:
            losses_per_epoch.append(float('nan'))
        # Validation
        val_loss = validate(net, val_loader, criterion)
        print(f'Validation MSE Loss: {val_loss:.5f}')
        val_losses_per_epoch.append(val_loss)
        val_history.append(val_loss)
        scheduler.step(val_loss)
        # Print learning rate
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current learning rate: {current_lr}")
        lr.append(current_lr)
        # Early stopping: an epoch counts as an improvement only when its
        # validation loss is at least 5% below the mean of the previous five.
        if len(val_history) >= 6:
            loss_avg = np.average(val_history[-6:-1])
            if val_loss <= 0.95 * loss_avg:
                patience_count = 0  # Reset patience count
            else:
                patience_count += 1
        if patience_count >= early_stop_patience:
            _save_training_checkpoint(net, epoch + 1, losses_per_epoch,
                                      val_losses_per_epoch, lr)
            print(f'Early stopping! No improvement in validation loss for {early_stop_patience} epochs.')
            break
        # Save checkpoints every `checkpoint_interval` epochs (the previous
        # `== 9` test was only correct for an interval of 10).
        if checkpoint_interval > 0 and (epoch + 1) % checkpoint_interval == 0:
            _save_training_checkpoint(net, epoch + 1, losses_per_epoch,
                                      val_losses_per_epoch, lr)
            losses_per_epoch = []
            val_losses_per_epoch = []
            lr = []

In [None]:
# train net
net = ResNet13()
# Kept as a notebook-level flag because other cells may read it.
cuda = torch.cuda.is_available()
if cuda:
    net.cuda()
# Checkpoints and the final weights are written into model/; neither
# np.savez nor torch.save creates the directory.
os.makedirs("model", exist_ok=True)
train(net,inputdata, validation_set, epoch_stop=50)
# save results
torch.save({'state_dict': net.state_dict()}, "model/ResNet13_March28_trained.pth.tar")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Value: -0.1434878557920456 -0.13221146166324615
Process ID: 1076 [Epoch: 17, 49000/ 848835 points] total loss per batch: 0.01085
Value: 0.1743929386138916 0.17017802596092224
Process ID: 1076 [Epoch: 17, 50000/ 848835 points] total loss per batch: 0.01083
Value: 0.3421633541584015 0.10878203064203262
Process ID: 1076 [Epoch: 17, 51000/ 848835 points] total loss per batch: 0.00983
Value: -0.22516556084156036 -0.21436172723770142
Process ID: 1076 [Epoch: 17, 52000/ 848835 points] total loss per batch: 0.00934
Value: -0.14790287613868713 -0.14734572172164917
Process ID: 1076 [Epoch: 17, 53000/ 848835 points] total loss per batch: 0.01054
Value: 0.8035320043563843 0.8312468528747559
Process ID: 1076 [Epoch: 17, 54000/ 848835 points] total loss per batch: 0.01010
Value: -0.07726269215345383 0.03640192374587059
Process ID: 1076 [Epoch: 17, 55000/ 848835 points] total loss per batch: 0.01001
Value: -0.057395145297050476 -0.03977