In [1]:
import os
from os.path import join
import chess.pgn
import numpy as np
import json
import warnings
from joblib import Parallel, delayed
from tqdm import tqdm
from maia_chess import load_maia_network
from irl_chess.chess_utils.sunfish_utils import board2sunfish, sunfish_move_to_str, str_to_sunfish_move, get_new_pst
from irl_chess.chess_utils.sunfish import piece, pst
from irl_chess.models.sunfish_GRW import sunfish_move
from irl_chess.misc_utils.load_save_utils import get_board_after_n, is_valid_game

In [9]:
# Collect mid-game positions, and the move actually played from each, from the
# Chess960 PGN file(s) below, keeping only games accepted by `is_valid_game`.
# NOTE(review): the exact filtering semantics (including what 'n_endgame'
# controls) live in irl_chess.misc_utils.load_save_utils.is_valid_game — confirm there.
config_lower = {'min_elo':1400, 'max_elo':1600, 'n_endgame':20}
n_boards = 100   # number of positions to collect
n_midgame = 8    # passed to get_board_after_n; plies vs. full moves is not visible here
# Files are scanned in order until n_boards positions are found. Add more paths
# here rather than re-reading the same file, which would produce duplicate positions.
pgn_paths = ['data/960/lichess_db_chess960_rated_2024-04.pgn']

chess_boards, moves = [], []
for i, pgn_path in enumerate(pgn_paths):
    if len(chess_boards) >= n_boards:
        break
    size = os.path.getsize(pgn_path)
    progress = 0
    with open(pgn_path) as pgn, tqdm(total=size, desc=f'Looking through file {i}') as pbar:
        while len(chess_boards) < n_boards:
            game = chess.pgn.read_game(pgn)
            if game is None:  # read_game returns None once the file is exhausted
                break
            if is_valid_game(game, config_data=config_lower):
                board_midgame, move_midgame = get_board_after_n(game, n_midgame)
                chess_boards.append(board_midgame)
                moves.append(sunfish_move_to_str(move_midgame))
            # tell() on a text stream is an opaque position cookie; it is only
            # used here as an approximate byte-progress indicator.
            position = pgn.tell()
            pbar.update(position - progress)
            progress = position

if len(chess_boards) < n_boards:
    # Previously the same file was re-read, silently duplicating positions (or
    # looping forever if none matched); now we stop and report the shortfall.
    warnings.warn(f'Only found {len(chess_boards)} of {n_boards} requested positions; '
                  'downstream cells will use however many were collected.')

Looking through file 0:   1%|          | 3875086/746967475 [00:07<22:56, 539650.16it/s]


In [10]:
# Measure how often Maia (1500 network) and Sunfish (with a custom piece-square
# table) reproduce the move that was actually played in each collected position.
# NOTE(review): this assumes getTopMovesCP and sunfish_move_to_str emit moves in
# the same string notation (presumably UCI), including Chess960 castling;
# otherwise matches are silently undercounted. Confirm against both helpers.
position_pairs = list(zip(chess_boards, moves))
n_evaluated = len(position_pairs)
if n_evaluated == 0:
    raise ValueError('No positions to evaluate; run the collection cell first.')

maia_high = load_maia_network(1500, parent='maia_chess/')
# NOTE(review): presumably per-piece values (6 entries, e.g. P, N, B, R, Q, K)
# consumed by get_new_pst — confirm the ordering in sunfish_utils.
R_high = np.array([100,280,320,420,928,65000])
pst_high = get_new_pst(R_high)
sunfish_time_limit = 0.5  # forwarded as sunfish_move(time_limit=...)

maia_hits = 0
sf_hits = 0
for board, move in tqdm(position_pairs):
    # Top-1 entry; presumably a (move, centipawn) pair, of which we take the move.
    maia_move = maia_high.getTopMovesCP(board, 1)[0][0]
    sf_move = sunfish_move(board2sunfish(board, 0), pst_high, time_limit=sunfish_time_limit, move_only=True)
    sf_move = sunfish_move_to_str(sf_move)
    maia_hits += maia_move == move
    sf_hits += sf_move == move

# Normalise by the number of positions actually evaluated, not the requested
# n_boards, so the accuracies stay correct if fewer positions were collected.
maia_acc, sf_acc = maia_hits / n_evaluated, sf_hits / n_evaluated
print(f'Sunfish accuracy: {sf_acc}, Maia accuracy: {maia_acc}')

Model loaded successfully!


100%|██████████| 100/100 [01:34<00:00,  1.06it/s]

Sunfish accuracy: 0.24, Maia accuracy: 0.38



