In [1]:
# install the necessary libraries
# uncomment as needed

#pip install chess
#pip install zstandard
#pip install pydrive

In [2]:
import os
import datetime
import chess
import chess.engine
import random
import numpy as np
from tqdm import tqdm
import io
import json
import zstandard
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

In [31]:
# helper function: download file from Google Drive

def download_from_drive(file_name='lichess_db_eval.jsonl.zst', asset_dir='Mengqi_Input',
                        client_config_file='client_secret_1057507276332-5mk9ac9q22rsmtm1idlqvpraq08ar8p5.apps.googleusercontent.com.json',
                        credentials_file="mycreds.txt"):
    """Download one file from a Google Drive folder into the working directory.

    Parameters
    ----------
    file_name : str
        Title of the file on Drive; also used as the local file name.
    asset_dir : str
        Title of the Drive folder that contains the file.
    client_config_file : str
        Path of the OAuth client configuration JSON handed to PyDrive.
    credentials_file : str
        Path of the cached credentials file (loaded first, rewritten afterwards).

    Returns
    -------
    str or None
        The local path of the downloaded file (equal to ``file_name``), or
        ``None`` when the folder or the file could not be found on Drive.
    """

    def drive_quote(value):
        # Query values are wrapped in single quotes, so backslashes and quotes
        # inside a title must be escaped or the query is malformed.
        return value.replace("\\", "\\\\").replace("'", "\\'")

    gauth = GoogleAuth()
    # NOTE(review): this writes through DEFAULT_SETTINGS exactly as the original
    # code did; it looks like shared state on the PyDrive side - confirm.
    gauth.DEFAULT_SETTINGS['client_config_file'] = client_config_file
    gauth.LoadCredentialsFile(credentials_file)
    if gauth.credentials is None:
        # no cached credentials: run the interactive browser flow
        gauth.LocalWebserverAuth()
    elif gauth.access_token_expired:
        gauth.Refresh()
    else:
        gauth.Authorize()
    gauth.SaveCredentialsFile(credentials_file)
    drive = GoogleDrive(gauth)

    def find_folder_id(folder_name):
        """Return the id of the first non-trashed folder titled ``folder_name``, or None."""
        file_list = drive.ListFile({'q': f"title='{drive_quote(folder_name)}' and mimeType='application/vnd.google-apps.folder' and trashed=false"}).GetList()
        for file in file_list:
            if file['title'] == folder_name:
                return file['id']
        return None

    def download_zst_file_from_drive(file_title, parent_id):
        """Fetch ``file_title`` from the folder ``parent_id``; return its local path or None."""
        query = f"'{drive_quote(parent_id)}' in parents and trashed=false and title='{drive_quote(file_title)}'"
        file_list = drive.ListFile({'q': query}).GetList()
        if not file_list:
            print(f"No file found with title: {file_title}")
            return None
        file = file_list[0]
        print("loading file...") #3min
        file.GetContentFile(file_title)
        return file_title

    asset_folder_id = find_folder_id(asset_dir)
    if asset_folder_id is None:
        print("Asset folder not found.")
        return None

    file_path = download_zst_file_from_drive(file_name, asset_folder_id)
    if file_path is None:
        return None

    print("Downloaded file: {}".format(file_path))
    # hand the local path back so callers can open the file without repeating its name
    return file_path

In [32]:
# the compressed dump is slow to fetch; reuse a copy that is already in the
# working directory (delete the local file to force a fresh download)
if not os.path.exists('lichess_db_eval.jsonl.zst'):
    download_from_drive(file_name='lichess_db_eval.jsonl.zst', asset_dir='Mengqi_Input')

loading file...
Downloaded file: lichess_db_eval.jsonl.zst


In [None]:
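# uncomment the next block of code to load the input data from the original
# ZST file and dump it as a JSON file.
# NOTE: `file_path` must first be set to the local path of the downloaded ZST file.
# a JSON file is saved when the counter reaches a multiple of 1M entries; as
# written, the `break` stops the loop right after the first 1M entries are dumped.
# this takes a while...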

#     fen_data = {}
#     counter = 0
#     with open(file_path, 'rb') as f:
#         decompressor = zstandard.ZstdDecompressor()
#         with decompressor.stream_reader(f) as reader:
#             text_stream = io.TextIOWrapper(reader, encoding='utf-8')
#             for line in text_stream:
#                 counter += 1
#                 data = json.loads(line)
#                 fen_data[counter] = data
#                 if counter % 1000000 == 0:
#                     output_path = f"lichess_db_eval.{counter // 1000000}.json"
#                     with open(output_path, 'w') as output:
#                         json.dump(fen_data, output)
#                     break

#     print(f"Loaded {counter} FEN positions")

In [33]:
# instead of decompressing the original ZST file and dumping the JSON again,
# load a JSON file that was pre-processed with the code block above

LICHESS_JSON = 'lichess_db_eval.1.json'

# the download is slow; reuse a copy that is already in the working directory
# (delete the local file to force a fresh download)
if not os.path.exists(LICHESS_JSON):
    download_from_drive(file_name=LICHESS_JSON, asset_dir='Mengqi_Input')

# explicit encoding so the result does not depend on the platform default
with open(LICHESS_JSON, 'r', encoding='utf-8') as json_file:
    lichessdata = json.load(json_file)

loading file...
Downloaded file: lichess_db_eval.1.json


In [34]:
# sanity check: number of top-level entries in the loaded JSON object
len(lichessdata)

1000000

In [36]:
# download the stockfish engine for evaluation generation
# (skipped when the executable is already in the working directory)

if not os.path.exists('stockfish-windows-x86-64-avx2.exe'):
    download_from_drive(file_name='stockfish-windows-x86-64-avx2.exe', asset_dir='Mengqi_Input')

loading file...
Downloaded file: stockfish-windows-x86-64-avx2.exe


In [37]:
# helper functions

def stockfish(board, depth, engine=None, engine_path="stockfish-windows-x86-64-avx2"):
    """Evaluate ``board`` with a UCI engine and return the score from White's side.

    Parameters
    ----------
    board : chess.Board
        Position to analyse.
    depth : int
        Search depth passed to ``chess.engine.Limit``.
    engine : optional
        An already running engine exposing ``analyse(board, limit)``. When given
        it is used as-is and left open, so one process can serve many positions
        instead of spawning a new one per call.
    engine_path : str
        Executable launched when ``engine`` is not supplied.

    Returns
    -------
    int or None
        Whatever ``result['score'].white().score()`` yields; callers in this
        notebook treat ``None`` as "no usable evaluation".
    """
    limit = chess.engine.Limit(depth=depth)
    if engine is not None:
        # caller owns the engine: do not close it here
        return engine.analyse(board, limit)['score'].white().score()
    # no engine supplied: start one for this single call and shut it down afterwards
    with chess.engine.SimpleEngine.popen_uci(engine_path) as sf:
        result = sf.analyse(board, limit)
        score = result['score'].white().score()
    return score

def board_encoder(board):
    """Turn a board into an 8x8x15 int8 array.

    Rows follow the order of the rank fields in the FEN piece-placement string,
    columns follow the characters within each rank field.

    Planes:
      0-11  one-hot piece planes in the order R, N, B, Q, K, P, r, n, b, q, k, p
      12    destination squares of the legal moves when white is to move
      13    destination squares of the legal moves when black is to move
      14    all ones when the FEN says white is to move, all zeros otherwise

    ``board.turn`` is switched temporarily to enumerate both sides' moves and is
    put back to its initial value before returning.
    """
    plane_of = {symbol: index for index, symbol in enumerate("RNBQKPrnbqkp")}
    tensor = np.zeros((8, 8, 15), dtype=np.int8)

    fields = board.fen().split(' ')
    rank_texts = fields[0].split('/')
    for row in range(8):
        # expand digit run-lengths so each file is represented by one character
        expanded = ''.join('-' * int(ch) if ch.isnumeric() else ch
                           for ch in rank_texts[row])
        for col in range(8):
            symbol = expanded[col]
            if symbol != '-':
                tensor[row, col, plane_of[symbol]] = 1

    # planes 12 and 13: legal-move targets for white, then for black
    original_turn = board.turn
    for plane, side in ((12, chess.WHITE), (13, chess.BLACK)):
        board.turn = side
        for move in board.legal_moves:
            # square index -> (rank, file); the rank is flipped to match the row order above
            rank_index, file_index = divmod(move.to_square, 8)
            tensor[7 - rank_index, file_index, plane] = 1
    board.turn = original_turn

    # plane 14: side to move (the array starts zeroed, so only white needs writing)
    if fields[1] == 'w':
        tensor[:, :, 14] = 1
    return tensor

In [None]:
# generate the input data:
# X: encoded chess board (output of board_encoder)
# y: stockfish evaluation (entries for which stockfish() fails or returns None are skipped)
# save a .npz file every BATCH_SIZE entries, plus one last file for any remainder

BATCH_SIZE = 100000   # number of input entries per output file
SEARCH_DEPTH = 5      # depth handed to stockfish()


def save_batch(boards, values, batch_index):
    """Write one batch to dataset_lichess<batch_index>.Mar21.npz (X: int8, y: int16)."""
    np.savez("dataset_lichess"+str(batch_index)+".Mar21.npz",
             X=np.array(boards, dtype=np.int8),
             y=np.array(values, dtype=np.int16))


dataset_board = []
dataset_v = []
counter = 0
for game in tqdm(lichessdata):
    counter += 1
    try:
        board = chess.Board(fen=lichessdata[game]['fen'])
        v = stockfish(board, SEARCH_DEPTH)
    except Exception:
        # best effort: skip entries that cannot be parsed or evaluated, but let
        # KeyboardInterrupt / SystemExit through so the run can be stopped
        v = None
    if v is not None:
        dataset_board.append(board_encoder(board))
        dataset_v.append(v)
    # checked for every entry (also failed ones) so a batch boundary is never missed
    if counter % BATCH_SIZE == 0:
        save_batch(dataset_board, dataset_v, counter // BATCH_SIZE)
        dataset_board = []
        dataset_v = []

# entries left over when the total is not a multiple of BATCH_SIZE
if dataset_board:
    save_batch(dataset_board, dataset_v, counter // BATCH_SIZE + 1)
    dataset_board = []
    dataset_v = []