In [1]:
cd ../..

c:\Users\Josu\Documents\Workspace\Human-Benchmark\src\server\model


In [5]:
import asyncio
import glob
import pathlib
import threading
from typing import Any, Awaitable, Coroutine

import aiofiles
import numpy as np
from chess import engine, pgn
from scripts.utils import bitboard, tanh

In [3]:
async def worker(
    path: pathlib.PurePath,
    semaphore: asyncio.Semaphore,
    *,
    checkpoint: int = 10000,
    engine_path: str = "lib/stockfish-11-win/stockfish-11-win/Windows/stockfish_20011801_x64",
    analysis_time: float = 0.1,
    out_dir: str = "data/npz",
) -> None:
    """Sample one position per game from a PGN file and label it with a UCI engine.

    For every game with at least one move, a random mainline position is
    encoded with ``bitboard`` (collected under ``"X"``) and analysed by the
    engine. The side-to-move score (mates scored via ``mate_score=32768``) is
    passed through ``tanh`` (collected under ``"y"``). Both lists are written
    with ``np.savez`` to ``{out_dir}/{path.stem}.npz`` every ``checkpoint``
    positions and once more after the file is exhausted.

    Args:
        path: PGN file to read.
        semaphore: Bounds how many workers (and engine processes) run at once.
        checkpoint: Save every this many positions; ``<= 0`` saves only at the end.
        engine_path: UCI engine executable to launch.
        analysis_time: Engine time per position, in seconds.
        out_dir: Output directory for the ``.npz`` file; created if missing.
    """
    async with semaphore:
        kwds = {"X": [], "y": []}
        out_file = f"{out_dir}/{path.stem}"
        # Create the directory up front so the first checkpoint cannot fail
        # after a long stretch of engine work.
        pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)

        def save_checkpoint() -> None:
            # np.savez appends ".npz" and converts each list to an array.
            np.savez(out_file, **kwds)

        _, simple_engine = await engine.popen_uci(engine_path)
        try:
            # pgn.read_game needs a synchronous text stream; reads briefly
            # block the event loop.
            with open(path) as f:
                n_positions = 0
                while True:
                    game = pgn.read_game(f)
                    if game is None:  # end of file
                        break
                    nodes = list(game.mainline())
                    if not nodes:  # game without any moves
                        continue
                    # Uniform pick via numpy's global RNG (as np.random.choice
                    # does) without converting the nodes into an object array.
                    board = nodes[np.random.randint(len(nodes))].board()
                    info = await simple_engine.analyse(
                        board, engine.Limit(time=analysis_time)
                    )
                    # Append X and y together, only after analysis succeeded,
                    # so the two arrays always stay aligned.
                    kwds["X"].append(bitboard(board))
                    kwds["y"].append(
                        tanh(info["score"].relative.score(mate_score=32768))
                    )
                    n_positions += 1
                    if checkpoint > 0 and n_positions % checkpoint == 0:
                        save_checkpoint()
        finally:
            # Always stop the engine subprocess, even if reading or analysis fails.
            await simple_engine.quit()
        save_checkpoint()

In [4]:
async def main(*, pattern: str = "data/*.pgn", max_workers: int = 3) -> None:
    """Convert every PGN file matching ``pattern`` into ``.npz`` training data.

    Args:
        pattern: Glob pattern selecting the input PGN files.
        max_workers: Maximum number of concurrently running workers; each
            running worker holds its own engine process.
    """
    semaphore = asyncio.Semaphore(max_workers)
    # gather() wraps bare coroutines in tasks itself, so no ensure_future is
    # needed; sorting makes the processing order reproducible.
    await asyncio.gather(
        *(
            worker(pathlib.PurePath(file), semaphore)
            for file in sorted(glob.glob(pattern))
        )
    )

In [5]:
def run(main: Coroutine[Any, Any, Any], *, debug: bool = False) -> None:
    """Run ``main`` to completion on a fresh event loop in the calling thread.

    First installs python-chess's ``engine.EventLoopPolicy`` process-wide.
    Per the python-chess docs, this lets event loops spawn engine subprocesses
    even when not running on the main thread.

    Args:
        main: Coroutine to run. ``asyncio.run`` rejects other awaitables
            (futures, tasks) with ``ValueError``, hence the narrower type.
        debug: Enable asyncio debug mode.
    """
    asyncio.set_event_loop_policy(engine.EventLoopPolicy())
    asyncio.run(main, debug=debug)

In [6]:
# asyncio.run() refuses to start while another event loop is running in the
# same thread (presumably the kernel's own loop here), so the pipeline runs on
# a separate thread. Keep a handle so it can be monitored with
# pipeline_thread.is_alive() or waited on with pipeline_thread.join().
pipeline_thread = threading.Thread(
    target=run,
    args=(main(),),
    name="pgn-to-npz",
)
pipeline_thread.start()