In [1]:
cd ../..

c:\Users\Josu\Documents\Workspace\Human-Benchmark\src\server\model


In [2]:
import asyncio
import glob
import os
import pathlib
import random
import threading
from typing import Any, Awaitable, Coroutine, Optional, Union

import chess
import numpy as np
from chess import engine, pgn

from scripts import bitboard, moves, synchronize, tanh

In [None]:
def _npz_target(name: Union[str, bytes, int], /) -> str:
    """Return the ``data/npz/<stem>`` path that samples from PGN source ``name`` are saved to.

    ``open()`` accepts str/bytes paths and raw file descriptors, but
    ``pathlib.PurePath`` rejects bytes and ints, so bytes (and path-likes) are
    decoded with ``os.fsdecode`` and a descriptor falls back to its number.
    """
    if isinstance(name, int):
        return f"data/npz/{name}"
    return f"data/npz/{pathlib.PurePath(os.fsdecode(name)).stem}"


def _oriented_move(move: "chess.Move", turn: bool, /) -> "chess.Move":
    """Return ``move`` unchanged when White is to move; otherwise map both squares to ``63 - square``.

    The promotion piece is preserved. A ``None`` move (the engine reported no
    move) is returned as-is for White and raises ``AttributeError`` for Black;
    ``fetch`` skips the sample in both cases.

    NOTE(review): ``63 - square`` rotates the board 180 degrees rather than
    only flipping ranks like ``chess.square_mirror``; presumably this matches
    how ``bitboard`` encodes Black-to-move positions -- confirm.
    """
    if turn:
        return move
    last = len(chess.SQUARES) - 1
    return chess.Move(last - move.from_square, last - move.to_square, promotion=move.promotion)


async def fetch(
    file: Union[str, bytes, int],
    /,
    *,
    checkpoint: Optional[int] = None,
    engine_path: str = "lib/stockfish/stockfish",
    time_limit: float = .1,
) -> None:
    """Label one random position per game of a PGN file with a UCI engine and save the samples.

    For every game in ``file``, a random mainline position is sent to the
    engine. Three values are appended to the ``X``, ``y_1`` and ``y_2`` lists:
    the ``bitboard`` encoding of the position, the index in ``moves`` of the
    engine's best move (re-oriented for Black), and the ``tanh``-squashed
    side-to-move score. The lists are written with ``np.savez`` to
    ``data/npz/<stem>.npz``.

    Args:
        file: Path (str/bytes) or file descriptor of the PGN file.
        checkpoint: If truthy, also save each time the sample count reaches a
            multiple of this value. Every save rewrites all samples so far, so a
            crash loses at most ``checkpoint`` samples.
        engine_path: UCI engine executable to launch.
        time_limit: Engine thinking time per position, in seconds.

    Samples whose labelling raises ``AttributeError``, ``IndexError``,
    ``KeyError`` or ``ValueError`` are skipped (e.g. a game without moves, an
    engine result without a move or score). The engine process is always shut
    down, even if an exception escapes.
    """
    samples = {key: [] for key in ("X", "y_1", "y_2")}
    target = _npz_target(file)
    limit = engine.Limit(time=time_limit)

    def save() -> None:
        # Rewrites the whole archive with every sample collected so far.
        np.savez(target, **samples)

    _, uci_protocol = await engine.popen_uci(engine_path)
    try:
        # PGN is normally UTF-8. Being explicit avoids the platform default
        # (e.g. cp1252 on Windows), and errors="replace" keeps a stray byte in
        # a header tag from derailing a game; only the move text matters here.
        with open(file, encoding="utf-8", errors="replace") as f:
            while (game := pgn.read_game(f)) is not None:  # None means end of file
                try:
                    # random.choice raises IndexError for a game without moves.
                    board = random.choice(tuple(game.mainline())).board()
                    play_result = await uci_protocol.play(board, limit=limit, info=engine.INFO_SCORE)
                    # All three labels are computed before anything is appended,
                    # so a failure cannot leave the lists with different lengths.
                    sample = (
                        bitboard(board, dtype=int),
                        moves.index(_oriented_move(play_result.move, board.turn).uci()),
                        # mate_score turns mate scores into finite numbers instead of None.
                        tanh(play_result.info["score"].relative.score(mate_score=7625), k=.0025),
                    )
                except (AttributeError, IndexError, KeyError, ValueError):
                    continue
                for column, value in zip(samples.values(), sample):
                    column.append(value)
                if checkpoint and not len(samples["X"]) % checkpoint:
                    save()
        save()
    finally:
        # Stop the engine process even if reading, labelling or saving failed.
        await uci_protocol.quit()

In [None]:
async def main(
    *,
    pattern: str = "data/*.pgn",
    max_concurrency: int = 3,
    checkpoint: Optional[int] = 10000,
) -> None:
    """Run ``fetch`` over every PGN file matching ``pattern``, throttled by a shared semaphore.

    Args:
        pattern: ``glob`` pattern selecting the PGN files to process.
        max_concurrency: Initial value of the semaphore handed to
            ``synchronize``. Presumably this caps how many files (and engine
            processes) are handled at once; confirm against ``synchronize``.
        checkpoint: Forwarded to ``fetch``: save after every ``checkpoint`` samples.

    The first exception raised by any file propagates out of ``asyncio.gather``.
    """
    # Created inside the coroutine so it belongs to the loop running main().
    semaphore = asyncio.Semaphore(value=max_concurrency)
    # Sorted so files are scheduled in a reproducible order (glob order is arbitrary).
    await asyncio.gather(*(
        synchronize(semaphore)(fetch)(path, checkpoint=checkpoint)
        for path in sorted(glob.glob(pattern))
    ))

In [None]:
def run(coro: Coroutine[Any, Any, Any], /, *, debug: bool = False) -> None:
    """Run ``coro`` to completion on a fresh event loop under python-chess's event loop policy.

    ``chess.engine.EventLoopPolicy`` is installed (a process-wide setting) so
    that the loop can spawn and watch the engine subprocess even when this is
    called off the main thread, as the launcher cell does.

    Args:
        coro: Coroutine object to execute. ``asyncio.run`` expects a coroutine,
            not an arbitrary awaitable. Positional-only.
        debug: Forwarded to ``asyncio.run`` to enable asyncio debug mode.
    """
    asyncio.set_event_loop_policy(engine.EventLoopPolicy())
    asyncio.run(coro, debug=debug)

In [None]:
# asyncio.run() refuses to start while an event loop is already running in the
# calling thread (presumably the case inside the notebook kernel), so the
# pipeline gets its own thread. The handle is kept so the notebook can poll
# fetch_thread.is_alive() or call fetch_thread.join().
fetch_thread = threading.Thread(target=run, args=(main(),), name="fetch-pgn")
fetch_thread.start()