<a href="https://colab.research.google.com/github/nankivel/capstone/blob/main/training_data_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install chess

In [None]:
pip install zstandard

In [None]:
import os
import datetime
import chess
import chess.engine
import random
import numpy as np
import pydot
from tqdm import tqdm
import io
import json
import graphviz
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, SVG

import zstandard as zstd
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

In [None]:
# Mount Google Drive at /content/drive/ so later cells can write the generated
# .npz datasets under /content/drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
!apt-get install -y stockfish

In [None]:
# help functions

def stockfish(board, depth, engine_path="/usr/games/stockfish"):
  """Evaluate ``board`` with Stockfish and return the score for White.

  A fresh engine process is started for every call and is shut down as soon
  as the analysis has finished.

  Args:
    board: ``chess.Board`` position to analyse.
    depth: Search depth handed to ``chess.engine.Limit``.
    engine_path: Path of the Stockfish executable.

  Returns:
    The centipawn score from White's point of view, or ``None`` when the
    engine reports a forced mate (``score()`` is called without a
    ``mate_score``, so mate results have no centipawn value).
  """
  with chess.engine.SimpleEngine.popen_uci(engine_path) as sf:
    result = sf.analyse(board, chess.engine.Limit(depth=depth))
  # The analysis result is a plain mapping, so it can be read after the
  # engine process has been closed.
  return result['score'].white().score()

def board_encoder(board):
  """Encode a position as an 8x8x15 ``int8`` array of binary feature planes.

  Planes (last axis):
    0-5   white R, N, B, Q, K, P
    6-11  black r, n, b, q, k, p
    12    destination squares of White's legal moves
    13    destination squares of Black's legal moves
    14    side to move: all ones for White, all zeros for Black

  Rows follow FEN order (row 0 is the first rank string of the FEN) and
  columns follow the order of squares inside each rank string.

  Args:
    board: ``chess.Board`` to encode. Its ``turn`` is flipped temporarily to
      enumerate the moves of both sides and is always restored afterwards.

  Returns:
    ``numpy.ndarray`` of shape ``(8, 8, 15)`` with dtype ``int8``.
  """
  encoded_board = np.zeros((8, 8, 15), dtype=np.int8)
  piece_dict = {"R":0, "N":1, "B":2, "Q":3, "K":4, "P":5,
                "r":6, "n":7, "b":8, "q":9, "k":10, "p":11
                }
  placement, side_to_move = board.fen().split(' ')[:2]

  # Planes 0-11: walk each FEN rank string; digits are runs of empty squares.
  for rank, rank_str in enumerate(placement.split('/')):
    file = 0
    for c in rank_str:
      if c.isnumeric():
        file += int(c)
      else:
        encoded_board[rank, file, piece_dict[c]] = 1
        file += 1

  # Planes 12 and 13: legal-move targets for White and Black. The side to
  # move is overridden to generate each side's moves; try/finally guarantees
  # the caller's board gets its real turn back even if generation fails.
  original_turn = board.turn
  try:
    for plane, color in ((12, chess.WHITE), (13, chess.BLACK)):
      board.turn = color
      for move in board.legal_moves:
        # Square index -> (rank index, file index); rows are stored top-down.
        square_rank, square_file = divmod(move.to_square, 8)
        encoded_board[7 - square_rank, square_file, plane] = 1
  finally:
    board.turn = original_turn

  # Plane 14: side to move (the array is already zero for Black).
  if side_to_move == 'w':
    encoded_board[:, :, 14] = 1
  return encoded_board

In [None]:
def pos_from_drive(file_name='lichess_db_eval.jsonl.zst', asset_dir='asset',
                   max_positions=1000000):
    """Download the zstd-compressed Lichess evaluation dump and parse it.

    The file is looked up by title inside the Google Drive folder named
    ``asset_dir``, downloaded to the working directory, stream-decompressed
    and parsed one JSON object per line. The parsed records are also written
    to ``lichess_db_eval.1.json``.

    Args:
        file_name: Title of the ``.jsonl.zst`` file on Google Drive; also the
            local file name it is downloaded to.
        asset_dir: Title of the Drive folder that contains the file.
        max_positions: Maximum number of records to read. ``None`` reads the
            whole file.

    Returns:
        A dict mapping the 1-based record number, as a string (``"1"``,
        ``"2"``, ...), to the decoded JSON record, or ``None`` when the
        folder or the file cannot be found on Drive.
    """
    gauth = GoogleAuth()
    gauth.DEFAULT_SETTINGS['client_config_file'] = 'client_secret_1057507276332-5mk9ac9q22rsmtm1idlqvpraq08ar8p5.apps.googleusercontent.com.json'
    # Reuse cached credentials when present; otherwise run the browser flow.
    gauth.LoadCredentialsFile("mycreds.txt")
    if gauth.credentials is None:
        gauth.LocalWebserverAuth()
    elif gauth.access_token_expired:
        gauth.Refresh()
    else:
        gauth.Authorize()
    gauth.SaveCredentialsFile("mycreds.txt")
    drive = GoogleDrive(gauth)

    def find_folder_id(folder_name):
        """Return the id of the first non-trashed folder titled ``folder_name``."""
        file_list = drive.ListFile({'q': f"title='{folder_name}' and mimeType='application/vnd.google-apps.folder' and trashed=false"}).GetList()
        for file in file_list:
            if file['title'] == folder_name:
                return file['id']
        return None

    def download_zst_file_from_drive(file_title, parent_id):
        """Download ``file_title`` from folder ``parent_id``; return its local path."""
        query = f"'{parent_id}' in parents and trashed=false and title='{file_title}'"
        file_list = drive.ListFile({'q': query}).GetList()
        if not file_list:
            print(f"No file found with title: {file_title}")
            return None
        file = file_list[0]
        print("loading zst file...") #3min
        file.GetContentFile(file_title)
        return file_title

    asset_folder_id = find_folder_id(asset_dir)
    if asset_folder_id is None:
        print("Asset folder not found.")
        return None

    file_path = download_zst_file_from_drive(file_name, asset_folder_id)
    if file_path is None:
        return None

    print("Processing zst file...")

    # Keys are strings so the returned dict has the same shape as the JSON
    # file written below (JSON object keys are always strings).
    fen_data = {}
    counter = 0
    with open(file_path, 'rb') as f:
        decompressor = zstd.ZstdDecompressor()
        with decompressor.stream_reader(f) as reader:
            text_stream = io.TextIOWrapper(reader, encoding='utf-8')
            for line in text_stream:
                if max_positions is not None and counter >= max_positions:
                    break
                if not line.strip():
                    # Tolerate blank lines instead of failing in json.loads.
                    continue
                counter += 1
                fen_data[str(counter)] = json.loads(line)

    # Always write the snapshot, even when the input holds fewer records than
    # max_positions; otherwise nothing would be saved for short inputs.
    output_path = "lichess_db_eval.1.json"
    with open(output_path, 'w') as output:
        json.dump(fen_data, output)

    print(f"Loaded {counter} FEN positions")

    return fen_data

In [None]:
# Download and parse the Lichess evaluation dump, then show how many
# positions were loaded.
# NOTE(review): pos_from_drive() returns None when the Drive folder or file is
# not found, in which case len() below raises TypeError.
lichessdata = pos_from_drive()
len(lichessdata)

In [None]:
# Sample FEN strings used to smoke-test board_encoder.
# NOTE(review): the last entry's en-passant field ("f") looks truncated, so it
# is presumably rejected and skipped by the loop below -- confirm.
fens = ['rn1q1rk1/pbp2ppp/1p1bp3/8/3Pp3/1P2PN2/PBPN1PPP/R2Q1RK1 w - -',
        '3R4/p4pkp/3p2p1/2pP4/3brP2/P5PP/P2B4/7K b - -',
        'rn3rk1/1b3ppp/p3pn2/1p6/1P6/1BB1PN2/1P3PPP/3RK2R w K -',
        'r1bqk2r/2ppbpp1/p1n3np/1p6/3PP3/1B3N2/PP3PPP/RNBQ1RK1 w kq -',
        'r4rk1/pp1b1ppp/2n1p3/2qp4/8/2PBP3/PP1N1PPP/R2QK2R w Q f'
        ]

# Encode every FEN that parses; unparsable ones are skipped.
boards = []
for fen in fens:
  try:
    board = chess.Board(fen=fen)
  except ValueError:
    # chess.Board raises ValueError for a FEN it cannot parse. Catching only
    # that keeps KeyboardInterrupt and genuine bugs visible.
    continue
  boards.append(board_encoder(board))

In [None]:
# Number of sample FENs that parsed successfully and were encoded.
len(boards)

In [None]:
# Build the training set: encode every loaded position and label it with a
# depth-5 Stockfish evaluation, writing one .npz file to Google Drive for
# every 100000 positions visited.

def _save_batch(index, boards_batch, values_batch):
  """Write one batch of encoded boards (X) and evaluations (y) to Drive."""
  # NOTE(review): labels are stored as int16; this assumes every centipawn
  # score fits in that range -- confirm.
  np.savez("/content/drive/MyDrive/dataset_lichess"+str(index)+".Mar21.npz",
           X=np.array(boards_batch, dtype=np.int8),
           y=np.array(values_batch, dtype=np.int16))


dataset_board = []
dataset_v = []
counter = 0
skipped = 0
for counter, game in enumerate(tqdm(lichessdata), start=1):
  try:
    board = chess.Board(fen=lichessdata[game]['fen'])
    v = stockfish(board, 5)
  except Exception:
    # Best effort: a bad record or an engine failure skips this position.
    # Unlike a bare except, this does not swallow KeyboardInterrupt.
    skipped += 1
    v = None
  # v is None for failed positions and for forced mates (no centipawn score).
  if v is not None:
    dataset_board.append(board_encoder(board))
    dataset_v.append(v)
  # The checkpoint test runs for every position, including failed ones, so a
  # failure on a batch boundary can no longer skip the save.
  if counter % 100000 == 0:
    _save_batch(counter // 100000, dataset_board, dataset_v)
    dataset_board = []
    dataset_v = []

# Flush the trailing partial batch so no evaluated positions are lost.
if dataset_board:
  _save_batch(counter // 100000 + 1, dataset_board, dataset_v)
  dataset_board = []
  dataset_v = []

print(f"Processed {counter} positions, {skipped} skipped due to errors")