In [1]:
import torch
torch.cuda.is_available()

True

In [2]:
# !pip uninstall torch
import json
import os
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from random import randrange
from collections import OrderedDict
import time
from torch.optim.lr_scheduler import StepLR

def fen_to_features(fen):
    piece_to_index = {'r': 0, 'n': 1, 'b': 2, 'q': 3, 'k': 4, 'p': 5,
                      'P': 6, 'K': 7, 'Q': 8, 'B': 9, 'N': 10, 'R': 11}

    one_hot_board = np.zeros((13, 8, 8), dtype=np.float32)
    additional_features = np.zeros(13, dtype=np.float32)
    # additional_features = np.zeros(1, dtype=np.float32)

    fen_rows = fen.split()[0].split('/')
    for row_idx, row in enumerate(fen_rows):
        col_idx = 0
        for char in row:
            if char.isdigit():
                col_idx += int(char)
            elif char in piece_to_index:
                piece_idx = piece_to_index[char]
                one_hot_board[piece_idx, row_idx, col_idx] = 1
                col_idx += 1
    turn_index = 12
    one_hot_board[turn_index, :, :].fill(1)
    if fen.split()[1] == 'b':
#         print(f'Reached 1 {fen}')
        one_hot_board[turn_index, :, :].fill(0)
#         one_hot_board = np.rot90(one_hot_board, k = 2)
#     additional_features[0] = 1 if fen[1] == 'w' else 0
    additional_features[0:4] = [int(right in fen.split()[2]) for right in ['K', 'Q', 'k', 'q']]
    if fen.split()[3] != '-':
        en_passant_row = ord(fen.split()[3][0]) - ord('a')
        additional_features[5 + en_passant_row] = 1
    return np.concatenate([one_hot_board.flatten(), additional_features])

class EvaluationDataset(IterableDataset):
  def __init__(self, count, mode, split_ratio = 0.8):
#     self.dataset = []
        self.count = count
        self.mode = mode
        self.split_ratio = split_ratio
        self.train_count = int(self.count * self.split_ratio)
        self.validation_count = self.count - self.train_count
  def __iter__(self):
        with open('lichess_db_eval.jsonl', 'r') as file:
            if self.mode == 'train':
                limit = self.train_count
            else:
                for _ in range(self.train_count):
                    next(file)
                limit = self.validation_count

            for _ in range(limit):
                line = file.readline()
                if not line:
                    break
                yield self.process_json_line(line)

  def process_json_line(self, line):
        json_object = json.loads(line)
        fen = json_object['fen']
        pv = json_object.get('evals', [{}])[0].get('pvs', [{}])[0]
        evaluation = pv.get('cp', 0) or 0

        if 'mate' in pv and pv['mate'] is not None:
            evaluation = 50000 if pv['mate'] > 0 else -50000

        evaluation = max(min(evaluation, 2000), -2000)
        evaluation = evaluation/100
        if fen.split()[1] == 'b':
            evaluation = evaluation*-1
#         if fen:
#         print('Reached 2')
        fen_encoded = fen_to_features(fen)
        return {'fen_encoded': fen_encoded, 'eval': evaluation}

#   def __next__(self):
#     idx = randrange(self.count)
#     return self[idx]
  def __len__(self):
    return self.train_count if self.mode == 'train' else self.validation_count

In [1]:
class EvaluationModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3, batch_size=1024, input_dim=845, layer_count=8):
        super().__init__()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        layers = [(f"linear-0", nn.Linear(input_dim, input_dim)), (f"relu-0", nn.ReLU())]
        for i in range(1, layer_count - 1):
#             input_dim = input_dim//2
            layers.append((f"linear-{i}", nn.Linear(input_dim, input_dim)))
            layers.append((f"relu-{i}", nn.ReLU()))
        layers.append((f"linear-{layer_count - 1}", nn.Linear(input_dim, 1)))
        self.seq = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.seq(x)

    def training_step(self, batch, batch_idx):
        x, y_eval = batch['fen_encoded'], batch['eval']
        x = x.float()
        y_eval = y_eval.float()
        output = self(x)
        y_hat_eval = output[:, 0]
        y_hat_eval = y_hat_eval.squeeze()
        y_hat_eval = y_hat_eval.float()
#         print(type(y_hat_eval))
        eval_loss = F.l1_loss(y_hat_eval, y_eval)
        loss = eval_loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_eval = batch['fen_encoded'], batch['eval']
        x = x.float()
        y_eval = y_eval.float()
        output = self(x)
        y_hat_eval = output[:, 0]
        y_hat_eval = y_hat_eval.squeeze()
        y_hat_eval = y_hat_eval.float()
#         print(type(y_hat_eval))
        eval_loss = F.l1_loss(y_hat_eval, y_eval)
        loss = eval_loss
        self.log("validation_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        # return {
        #     'optimizer': optimizer,
        #     'scheduler': scheduler,
        # }

    def train_dataloader(self):
      # dataset = EvaluationDataset(count=3072000)
      dataset = EvaluationDataset(count=5072000, mode='train', split_ratio=0.8)
      return DataLoader(dataset, batch_size=self.batch_size, pin_memory=True)

    def val_dataloader(self):
      # dataset = EvaluationDataset(count=3072000)
#       dataset = EvaluationDataset(count=5072000, mode='validation', split_ratio=0.8)
      dataset = EvaluationDataset(count=5072000, mode='validation', split_ratio=0.8)
      return DataLoader(dataset, batch_size=self.batch_size, pin_memory=True)

configs = [
          {"layer_count": 6, "batch_size": 512},
           ]
for config in configs:
#   print("Reached 1")
  version_name = f'{int(time.time())}-batch_size-{config["batch_size"]}-layer_count-{config["layer_count"]}'
#   print("Reached 2")
  logger = pl.loggers.TensorBoardLogger("lightning_logs", name="chessml", version=version_name)
#   print("Reached 3")
  early_stop_callback = EarlyStopping(monitor='validation_loss', patience=3, verbose=True, mode='min')
#   print("Reached 4")
  checkpoint_callback = ModelCheckpoint(monitor='validation_loss', save_top_k=1, mode='min')
  lr_monitor = LearningRateMonitor(logging_interval='step')
#   print("Reached 5")
  trainer = pl.Trainer(accelerator='gpu',precision=16,max_epochs=20,callbacks=[early_stop_callback, checkpoint_callback], profiler="simple", logger=logger, log_every_n_steps=10)
  model = EvaluationModel(layer_count=config["layer_count"],batch_size=config["batch_size"],learning_rate=1e-3)
#   print("Reached 6")
  trainer.fit(model)
  break

NameError: name 'pl' is not defined

In [9]:
%reload_ext tensorboard
%tensorboard --port 6007 --logdir lightning_logs/
# !kill 6004

In [16]:
from IPython.display import display, SVG
import random

SVG_BASE_URL = "https://us-central1-spearsx.cloudfunctions.net/chesspic-fen-image/"

def svg_url(fen):
  fen_board = fen.split()[0]
  return SVG_BASE_URL + fen_board

def show_index(idx):
  count = 0
  dataset =[]
  evals=[]
  with open('lichess_db_eval.jsonl', 'r') as file:
      for i in range(idx+1):
          line = file.readline()
          if not line:
              break
      json_object = json.loads(line)
      dataset.append(json_object)

  fen = dataset[0]['fen']
  pv = dataset[0]['evals'][0]['pvs'][0]
  if (('cp' in pv and pv['cp'] is None) or ('cp' not in pv)):
      eval = 5000*pv['mate']
  else:
      eval = dataset[0]['evals'][0]['pvs'][0]['cp']

  x = torch.tensor(fen_to_features(fen))
  eval = torch.tensor([eval], dtype=torch.float32)/100
  y_hat_eval= model(x).squeeze()
  loss_eval = F.l1_loss(y_hat_eval.unsqueeze(0), eval)
  loss_eval_per = (loss_eval/eval)*100
  print(f'Loss Eval {loss_eval:.2f}')
#   print(f'LossEval% {loss_eval_per.item():.2f}%')
  print(f'eval {eval.item():.2f}')
  print(f'y_eval_hat {y_hat_eval:.2f}')
  print(f'FEN {fen}')
  if not(y_hat_eval > -1 and y_hat_eval < 1 and eval > -1 and eval < 1) and y_hat_eval*eval < 0 :
    count = count+1
  return count
#   display(SVG(url=svg_url(fen)))

for i in range(100):
  idx = random.randint(1000000, 1500000)
  count_wrong_pred = show_index(idx)/100
  print(f'Accuracy% = {count_wrong_pred}')

Loss Eval 0.07
eval 0.05
y_eval_hat -0.02
FEN r1bqkbnr/1p3ppp/p1n1p3/2pp4/B3P3/2P2N2/PP1P1PPP/RNBQK2R w KQkq -
Accuracy% = 0.0
Loss Eval 0.84
eval 0.48
y_eval_hat -0.36
FEN r1b1qrk1/ppp2pbp/n2p2p1/3Pp1B1/2P1P1n1/2N2N2/PP2BPPP/R2Q1RK1 b - -
Accuracy% = 0.0
Loss Eval 2.41
eval 3.16
y_eval_hat 0.75
FEN rnbqkb1r/pp2p2p/2p2p2/3p2pn/3P3B/4PN2/PPP2PPP/RN1QKB1R w KQkq -
Accuracy% = 0.0
Loss Eval 0.41
eval -0.22
y_eval_hat 0.19
FEN r1bqkbnr/ppp2ppp/8/4n3/4N3/5N2/PPPP2PP/R1BQKB1R b KQkq -
Accuracy% = 0.0
Loss Eval 180.10
eval 200.00
y_eval_hat 19.90
FEN 8/5k2/6RR/8/3K4/8/8/8 w - -
Accuracy% = 0.0
Loss Eval 0.13
eval 0.00
y_eval_hat -0.13
FEN r1bqk2r/pppnppbp/3p1np1/8/3PP3/3B1N2/PPPN1PPP/R1BQK2R b KQkq -
Accuracy% = 0.0
Loss Eval 1169.76
eval 1150.00
y_eval_hat -19.76
FEN 8/2R5/3K4/8/8/3k4/8/8 b - -
Accuracy% = 0.01
Loss Eval 0.37
eval -0.21
y_eval_hat 0.16
FEN rn1qkb1r/pp2pppp/2p2n2/3p1b2/2PP1B2/6P1/PP2PPBP/RN1QK1NR b KQkq -
Accuracy% = 0.0
Loss Eval 0.52
eval -0.31
y_eval_hat 0.21
FEN r1bqkbnr/

Loss Eval 1.36
eval 1.48
y_eval_hat 0.12
FEN 1rbq1rk1/3nbppp/p1n1p3/1p1pP3/2pP1P2/P1N1BN2/1PPQB1PP/R4RK1 w - -
Accuracy% = 0.0
Loss Eval 2.12
eval -1.11
y_eval_hat 1.01
FEN rnbqkb1r/pp1ppppp/5n2/8/2B1P3/2N5/PB3PPP/R2QK1NR b KQkq -
Accuracy% = 0.01
Loss Eval 1.26
eval 0.56
y_eval_hat -0.70
FEN rnbqkbnr/pp4pp/4pp2/3p4/3P1B2/4PN2/PP3PPP/RN1QKB1R b KQkq -
Accuracy% = 0.0
Loss Eval 0.20
eval 0.33
y_eval_hat 0.13
FEN rnbqkb1r/ppp2ppp/3p1n2/8/3Q4/1P6/PBP1PPPP/RN2KBNR w KQkq -
Accuracy% = 0.0
Loss Eval 0.25
eval 0.30
y_eval_hat 0.05
FEN rnbqk2r/ppp1bppp/3p1n2/4p3/8/1PN1P3/PBPP1PPP/R2QKBNR w KQkq -
Accuracy% = 0.0
Loss Eval 0.84
eval 0.85
y_eval_hat 0.01
FEN rnbqk1nr/pppp1ppp/4p3/2b5/4P2P/8/PPPP1PP1/RNBQKBNR w KQkq -
Accuracy% = 0.0
Loss Eval 0.34
eval -0.21
y_eval_hat 0.13
FEN rn1qkb1r/pp3ppp/2p1pn2/3P1b2/3P4/2N1P2P/PP3PP1/R1BQKBNR b KQkq -
Accuracy% = 0.0
Loss Eval 0.09
eval 0.20
y_eval_hat 0.29
FEN r1bqk2r/pp1pbppp/2n1pn2/2p3B1/3P4/4PN2/PPPN1PPP/R2QKB1R w KQkq -
Accuracy% = 0.0
Loss Eval 630

In [6]:
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save('model_is_it_magnoos_level.pt')

In [17]:
model_path = 'scripted.pt'
model_scripted = torch.jit.load(model_path)

In [28]:
from IPython.display import display, SVG
import random

SVG_BASE_URL = "https://us-central1-spearsx.cloudfunctions.net/chesspic-fen-image/"

def svg_url(fen):
  fen_board = fen.split()[0]
  return SVG_BASE_URL + fen_board

def show_index(idx):
  count = 0
  dataset =[]
  evals=[]
  with open('lichess_db_eval.jsonl', 'r') as file:
      for i in range(idx+1):
          line = file.readline()
          if not line:
              break
      json_object = json.loads(line)
      dataset.append(json_object)

  fen = dataset[0]['fen']
  pv = dataset[0]['evals'][0]['pvs'][0]
  if (('cp' in pv and pv['cp'] is None) or ('cp' not in pv)):
      eval = 5000*pv['mate']
  else:
      eval = dataset[0]['evals'][0]['pvs'][0]['cp']

  x = torch.tensor(fen_to_features(fen))
  eval = torch.tensor([eval], dtype=torch.float32)/100
  y_hat_eval= model_scripted(x).squeeze()
  loss_eval = F.l1_loss(y_hat_eval.unsqueeze(0), eval)
  loss_eval_per = (loss_eval/eval)*100
#   print(f'Loss Eval {loss_eval:.2f}')
#   print(f'LossEval% {loss_eval_per.item():.2f}%')
#   print(f'eval {eval.item():.2f}')
#   print(f'y_eval_hat {y_hat_eval:.2f}')
#   print(f'FEN {fen}')
  if not(y_hat_eval > -1 and y_hat_eval < 1 and eval > -1 and eval < 1) and y_hat_eval*eval < 0 :
    count = count+1
  return count
#   display(SVG(url=svg_url(fen)))
count_wrong_pred = 0
max_count = 10
for i in range(max_count):
  idx = random.randint(5072000, 5472000)
  count_wrong_pred += show_index(idx)
#   print(count_wrong_pred)
print(f'Accuracy% = {((max_count - count_wrong_pred)/max_count)*100}')

Accuracy% = 100.0
