Mount Google Drive (the dataset files and the saved model live there)

In [1]:
# Colab-only step: mount the user's Google Drive at /content/drive so that
# files stored on Drive can be accessed through ordinary filesystem paths.
# Requires the `google.colab` package, i.e. this cell only runs inside Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


- Define the dataset class
- Define the model architecture (CNN)
- Run the training loop and save the best model

In [2]:
# import libs
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import time

In [3]:
# config
# Folder on the mounted Drive that holds the dataset files and the checkpoint.
DRIVE_BASE_PATH = "/content/drive/MyDrive/chess_games/"

# Training files: one "stockfished_dataset_<name>.pt" file per name below.
TRAIN_FILES = [
    os.path.join(DRIVE_BASE_PATH, f"stockfished_dataset_{name}.pt")
    for name in ("carlsen", "caruana", "firouzja", "karpov")
]

# A separate file is held out for validation.
VALID_FILE = os.path.join(DRIVE_BASE_PATH, "stockfished_dataset_kasparov.pt")

# Destination of the saved model weights.
MODEL_SAVE_PATH = os.path.join(DRIVE_BASE_PATH, "sf_chess_model.pth")

# hyperparameters
EPOCHS = 50
BATCH_SIZE = 128
LEARNING_RATE = 1e-4

# device: use the GPU when one is visible to torch, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"using {device}")

using cuda


In [4]:
# dataset class
class ChessDataset(Dataset):
  """In-memory dataset of (board, label) pairs collected from several files.

  Every path in ``file_paths`` is read with ``torch.load``. The loaded object
  is indexed as a pair: element 0 is the sequence of boards, element 1 the
  sequence of labels. Both sequences are appended item by item to ``self.X``
  and ``self.y``.

  Loading is best-effort: a file that is missing, cannot be read, or whose
  board and label counts differ is reported on stdout and skipped, so the
  resulting dataset may be smaller than expected (or empty). Callers should
  check ``len(dataset)`` before training.
  """

  def __init__(self, file_paths):
    """Load every file in ``file_paths`` (an iterable of paths) into memory."""
    print("loading dataset... takes time")

    self.X = [] # board tensor list
    self.y = [] # label tensor list
    total_positions = 0
    for file_path in file_paths:
      try:
        print(f"loading file {file_path}")
        # NOTE(review): torch.load unpickles the file, so only trusted files
        # should be passed in. map_location="cpu" keeps the samples in host
        # memory even if the file was written from a GPU session.
        data = torch.load(file_path, map_location="cpu")
        x_data = data[0]
        y_data = data[1]

        # Appending sequences of different lengths would silently shift every
        # later (board, label) pair, so such a file is rejected as a whole.
        if len(x_data) != len(y_data):
          print(f"ERROR! {file_path}: {len(x_data)} boards but {len(y_data)} labels, file skipped")
          continue

        self.X.extend(x_data)
        self.y.extend(y_data)

        num_loaded = len(y_data)
        total_positions += num_loaded
        print(f"dataset loaded, {num_loaded} positions.")

        # drop the references so the raw file contents can be freed before
        # the next file is read
        del data
        del x_data
        del y_data
      except FileNotFoundError:
        print(f"ERROR! no file {file_path}")
      except Exception as e:
        print(f"ERROR! loading dataset: {e}")
        print("check for out of memory")

    print(f"total positions loaded: {total_positions}")

  def __len__(self):
    """Return the number of loaded positions."""
    return len(self.y)

  def __getitem__(self, idx): # one (board, label) pair; batching is left to the DataLoader
    return self.X[idx], self.y[idx]


In [5]:
# model architecture (cnn)
class ChessCNN(nn.Module):
  """Convolutional regressor that maps a board tensor to one score in [-1, 1].

  Input:  (batch, 12, 8, 8)
  Output: (batch, 1), squashed by tanh

  Three 3x3 conv + batch-norm + ReLU stages (12 -> 32 -> 64 -> 128 channels,
  spatial size kept at 8x8 by padding=1) feed a two-layer fully connected head
  with dropout in between.
  """

  def __init__(self):
    super().__init__()

    # conv stage 1: (batch, 12, 8, 8) -> (batch, 32, 8, 8)
    self.conv1 = nn.Conv2d(in_channels=12, out_channels=32, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm2d(num_features=32)

    # conv stage 2: -> (batch, 64, 8, 8)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(num_features=64)

    # conv stage 3: -> (batch, 128, 8, 8)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
    self.bn3 = nn.BatchNorm2d(num_features=128)

    # fully connected head: 128 channels * 8 * 8 squares = 8192 features
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(in_features=128 * 8 * 8, out_features=512)
    self.dropout = nn.Dropout(p=.5)
    self.fc2 = nn.Linear(in_features=512, out_features=1) # single score

    # shared activations
    self.relu = nn.ReLU()
    self.tanh = nn.Tanh()

  def forward(self, x):
    """Run the network on `x` of shape (batch, 12, 8, 8); returns (batch, 1)."""
    conv_stages = (
      (self.conv1, self.bn1),
      (self.conv2, self.bn2),
      (self.conv3, self.bn3),
    )
    features = x
    for conv, norm in conv_stages:
      features = self.relu(norm(conv(features)))

    # fully connected head on the flattened feature map
    hidden = self.relu(self.fc1(self.flatten(features)))
    hidden = self.dropout(hidden)
    score = self.fc2(hidden)

    # tanh bounds the score to [-1, 1]
    return self.tanh(score)

Loading Dataset & Training

In [6]:
def _match_target(labels, outputs):
  """Return `labels` cast to the dtype of `outputs` and reshaped to its shape.

  The model emits (batch, 1). If the collated labels arrive as (batch,),
  nn.MSELoss broadcasts the pair to (batch, batch) and optimises the wrong
  quantity without raising. Reshaping to the output shape prevents that and
  raises if the element counts do not match.
  """
  return labels.to(outputs.dtype).reshape(outputs.shape)


# Load the data, build the model and run the training loop. On every epoch the
# model is evaluated on the validation set; the weights with the lowest
# validation loss are written to MODEL_SAVE_PATH, and training stops early
# after `early_stopping_patience` epochs without improvement.
try:
  print("loading train data...")
  train_dataset = ChessDataset(TRAIN_FILES)

  print("\nloading valid data...")
  valid_dataset = ChessDataset([VALID_FILE])

  if len(train_dataset) > 0 and len(valid_dataset) > 0:
    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=0, # 2
                              pin_memory=True) # true for faster gpu transfer
    valid_loader = DataLoader(valid_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=0,
                              pin_memory=True)

    print("\ndataset ready to train")

    # init model, loss and optimizer
    model = ChessCNN().to(device)
    criterion = nn.MSELoss() # mean squared error (regression target)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min', # reduce lr when the metric stops decreasing
                                                     factor=.1, # reduce lr by a factor of 10
                                                     patience=3) # wait 3 epochs with no improvement before reducing

    print("\n===== TRAINING STARTS")
    start_time = time.time()

    best_valid_loss = float("inf")
    epochs_no_improve = 0
    early_stopping_patience = 10

    for epoch in range(EPOCHS):
      epoch_start_time = time.time()
      model.train()

      running_train_loss = 0.0
      num_train_batches = 0


      for i, (boards, labels) in enumerate(train_loader):
        # move data to the device; non_blocking lets the copy from pinned
        # memory overlap with compute (plain copy when running on cpu)
        boards = boards.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # forward pass
        optimizer.zero_grad()
        outputs = model(boards)

        # calculate loss against a target of exactly the output shape
        loss = criterion(outputs, _match_target(labels, outputs))

        # backward pass
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()
        num_train_batches += 1

        if (i + 1) % 1000 == 0:
          # running mean of the batch losses seen so far in this epoch
          avg_train_loss_so_far = running_train_loss / (i + 1)
          print(f"[epoch {epoch + 1}, batch {i + 1:5d}] loss {avg_train_loss_so_far:.4f}")

      avg_train_loss = running_train_loss / num_train_batches

      # validation pass: eval mode (dropout off, batch-norm uses running
      # statistics) and no gradient tracking
      model.eval()
      running_valid_loss = 0.0
      num_valid_batches = 0

      with torch.no_grad():
        for boards, labels in valid_loader:
          boards = boards.to(device, non_blocking=True)
          labels = labels.to(device, non_blocking=True)

          outputs = model(boards)
          loss = criterion(outputs, _match_target(labels, outputs))

          running_valid_loss += loss.item()
          num_valid_batches += 1

      avg_valid_loss = running_valid_loss / num_valid_batches

      epoch_time = time.time() - epoch_start_time
      print(f">>> epoch {epoch + 1} finished in {epoch_time:.2f} seconds...")
      print(f"\navg training loss: {avg_train_loss:.4f}")
      print(f"\navg valid loss: {avg_valid_loss:.4f}")

      scheduler.step(avg_valid_loss) # update learning rate
      # read the rate from the optimizer itself: get_last_lr() is not
      # available on ReduceLROnPlateau in every PyTorch release
      current_lrs = [group["lr"] for group in optimizer.param_groups]
      print(f"learning rate: {current_lrs}")

      if avg_valid_loss < best_valid_loss:
        best_valid_loss = avg_valid_loss
        epochs_no_improve = 0
        print(f"valid loss improved, saving model {MODEL_SAVE_PATH}")
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
      else:
        epochs_no_improve += 1
        print(f"valid loss did not improved for {epochs_no_improve} epochs...")
        if epochs_no_improve >= early_stopping_patience:
          print(f"early stop triggered, epoch: {epoch + 1}")
          break

    total_time = time.time() - start_time
    print(f"\n===== TRAINING COMPLETE")
    print(f"total train time: {total_time:.2f} seconds")
    print(f"best val loss: {best_valid_loss:.4f}")
  else:
    # ChessDataset only prints load errors, so an empty dataset would
    # otherwise end this cell without any explanation
    print("\nERROR! train or valid dataset is empty, training skipped")

except Exception as e:
  print(f"\n ERROR during training {e}")
  import traceback
  traceback.print_exc()



loading train data...
loading dataset... takes time
loading file /content/drive/MyDrive/chess_games/stockfished_dataset_carlsen.pt
dataset loaded, 653423 positions.
loading file /content/drive/MyDrive/chess_games/stockfished_dataset_caruana.pt
dataset loaded, 550019 positions.
loading file /content/drive/MyDrive/chess_games/stockfished_dataset_firouzja.pt
dataset loaded, 436867 positions.
loading file /content/drive/MyDrive/chess_games/stockfished_dataset_karpov.pt
dataset loaded, 294121 positions.

loading valid data...
loading dataset... takes time
loading file /content/drive/MyDrive/chess_games/stockfished_dataset_kasparov.pt
dataset loaded, 161951 positions.

dataset ready to train

===== TRAINING STARTS
[epoch 1, batch  1000] loss 0.1225
[epoch 1, batch  2000] loss 0.1218
[epoch 1, batch  3000] loss 0.1212
[epoch 1, batch  4000] loss 0.1208
[epoch 1, batch  5000] loss 0.1208
[epoch 1, batch  6000] loss 0.1209
[epoch 1, batch  7000] loss 0.1208
[epoch 1, batch  8000] loss 0.1208
[e