In [12]:
from dataset_management import DatasetManager, SimpleEncoder, EncodingType
from typing import List

In [13]:
import chess
import chess.pgn
import sys

In [14]:
import tqdm
import random
import itertools
import os

In [15]:
def read_games(path : str, encoding : str = 'utf-8') -> List[chess.pgn.Game]:
    """Load every game from a PGN file into memory.

    A progress counter is printed (overwriting itself via ``\\r``) every 100
    games, followed by a summary line once the file is exhausted.

    Args:
        path: path to the PGN file.
        encoding: text encoding used to open the file. python-chess parses text
            and leaves decoding to the caller; Lichess PGN exports are UTF-8, and
            relying on the platform default (e.g. cp1252 on Windows) can garble
            player names or raise ``UnicodeDecodeError``.

    Returns:
        The parsed games, in file order.
    """
    print(f"Processing file {path}")
    games = []
    with open(path, 'r', encoding=encoding) as file:
        # chess.pgn.read_game returns None once the end of the file is reached.
        while (game := chess.pgn.read_game(file)) is not None:
            if len(games) % 100 == 0:
                print(f"{len(games)} games processed so far...", end='\r')
            games.append(game)
    # Blank out the progress line before printing the summary.
    print(100*' ', end='\r')
    print("Processing finished")
    print(f"{len(games)} games have been loaded from the file {path}")
    return games

In [16]:
def game_equality(g1 : chess.pgn.Game, g2 : chess.pgn.Game, s : str) -> None:
    """Assert that two games have the same mainline, move by move.

    Moves are compared by their UCI strings. ``itertools.zip_longest`` is used
    instead of ``zip`` so that a game that is a strict prefix of the other (for
    example a decoded game that lost its final moves) is reported as a mismatch
    instead of silently passing; the missing side appears as ``None`` in the
    failure message.

    Args:
        g1: first game (in this notebook, the original game).
        g2: second game (the decoded / reloaded game being checked).
        s: context prefix for the failure message, e.g. ``'game number: 3'``.

    Raises:
        AssertionError: at the first move index where the games differ.
    """
    moves1 = (m.uci() for m in g1.mainline_moves())
    moves2 = (m.uci() for m in g2.mainline_moves())
    for i, (u1, u2) in enumerate(itertools.zip_longest(moves1, moves2)):
        assert u1 == u2, \
            f'{s}, move number: {i}, first game: {u1}, second game: {u2}'

def test_encoder(games : List[chess.pgn.Game], encoder : EncodingType):
    """Round-trip every game through the encoder and check the moves survive.

    Each game is encoded with ``encoder.encode_pgn`` and decoded back with
    ``encoder.decode_to_pgn``; the result is compared against the original
    with ``game_equality``.

    Args:
        games: games to round-trip.
        encoder: encoder under test; must provide ``encode_pgn`` and
            ``decode_to_pgn``.

    Raises:
        AssertionError: for the first game whose round trip changes its moves.
    """
    # Wrap the list itself rather than `enumerate(...)`: tqdm needs len() to
    # show a total, percentage and ETA, and an enumerate object has none.
    for j, game in enumerate(tqdm.tqdm(games)):
        decoded = encoder.decode_to_pgn(encoder.encode_pgn(game))
        game_equality(game, decoded, f'game number: {j}')

In [17]:
def test_game_generator_correctness(games : List[chess.pgn.Game], dataset : DatasetManager, n : int):
    """Check ``dataset.games(l, r)`` against the in-memory ``games`` list.

    Runs ``n`` trials. Each trial draws a random interval ``[l, r]`` of game
    indices (``l <= r``, both ends inclusive), iterates the dataset's generator
    over it and compares every yielded game with ``games[l + i]``.

    Args:
        games: ground-truth games; expected to be in the same order as the
            games stored in ``dataset``.
        dataset: dataset under test.
        n: number of random intervals to check.

    Raises:
        AssertionError: if a yielded game differs from the ground truth, or if
            the generator yields a number of games other than ``r - l + 1``.
    """
    print(f'Running {n} tests')
    total_games_tested = 0
    for _ in tqdm.tqdm(range(n)):
        l, r = random.randint(0, len(games)-1), random.randint(0, len(games)-1)
        if r < l:
            l, r = r, l
        yielded = 0
        for i, game in enumerate(dataset.games(l, r)):
            original_game = games[l+i]
            game_equality(original_game, game, f'interval: [{l}, {r}], game number: {l+i}')
            yielded += 1
        # Without this check an empty or truncated generator would pass
        # silently while the counter below still claimed r-l+1 games verified.
        expected = r - l + 1
        assert yielded == expected, \
            f'interval: [{l}, {r}], expected {expected} games, generator yielded {yielded}'
        total_games_tested += yielded
        tqdm.tqdm.write(f'{total_games_tested} games checked so far')
    print(f'Finished, {n} tests passed, {total_games_tested} games checked in total')
            

In [18]:
def test_position_selector_correctness(games : List[chess.pgn.Game], dataset : DatasetManager, n : int):
    """Check that global position ids map back to the right game and move.

    A position id is the number of mainline moves in all preceding games plus
    the move's index inside its own game. For ``n`` randomly chosen games,
    every position id belonging to that game is looked up with
    ``dataset.game_by_position_id``, which must return the same move index and
    a game equal to the original.

    Args:
        games: ground-truth games, in dataset order.
        dataset: dataset under test.
        n: number of random games to check.

    Raises:
        AssertionError: on a wrong move index or a mismatching game.
    """
    print(f'Running {n} tests')
    # prefix_moves[k] == total number of mainline moves in games[0:k]
    move_counts = (len(list(g.mainline_moves())) for g in games)
    prefix_moves = list(itertools.accumulate(move_counts, initial=0))
    checked_positions = 0
    for test_idx in tqdm.tqdm(range(n)):
        game_id = random.randint(0, len(games)-1)
        target = games[game_id]
        first_position = prefix_moves[game_id]
        for move_idx, _ in enumerate(target.mainline_moves()):
            position_id = first_position + move_idx
            found_move_idx, found_game = dataset.game_by_position_id(position_id)
            assert move_idx == found_move_idx, "wrong move index"
            game_equality(target, found_game, f'position id: {position_id}')
            checked_positions += 1
        # Report progress only every 10th trial to keep the output short.
        if test_idx % 10 == 0:
            tqdm.tqdm.write(f'{checked_positions} positions checked so far')
    print(f'Finished, {n} tests passed, {checked_positions} positions checked in total')


In [19]:
# Create the dataset 'my_dataset' with create_empty=True.
# NOTE(review): presumably this discards any existing contents of 'my_dataset' — confirm in DatasetManager.
dataset = DatasetManager('my_dataset', create_empty=True)

In [20]:
# Encoder used by the round-trip test below.
encoder = SimpleEncoder()

In [21]:
# Ingest the January 2016 Lichess Elite PGN file into the dataset.
dataset.add_pgn_file(os.path.join('Lichess Elite Database', 'lichess_elite_2016-01.pgn'))

max_number_of_games: 1000


In [None]:
# Independently parse the same PGN file; these games are the ground truth for the checks below.
# NOTE(review): the add_pgn_file output above reports max_number_of_games: 1000. If the dataset caps
# the number of stored games, `games` (the whole file) could be longer than the dataset — confirm.
games = read_games(os.path.join('Lichess Elite Database', 'lichess_elite_2016-01.pgn'))

In [None]:
# Round-trip every loaded game through the encoder.
test_encoder(games, encoder)

In [None]:
# Check 1000 randomly chosen games' position ids against the dataset.
test_position_selector_correctness(games, dataset, 1000)

In [None]:
# Check 10 random game intervals produced by the dataset's game generator.
test_game_generator_correctness(games, dataset, 10)

In [None]:
# Test loading an existing dataset (create_empty=False) and re-check it against the same games.
dataset = DatasetManager('my_dataset', create_empty=False)

In [None]:
test_game_generator_correctness(games, dataset, 10)

In [None]:
# Append a second PGN file (November 2015) to the dataset.
# NOTE(review): not idempotent — re-running this cell ingests the file again.
dataset.add_pgn_file(os.path.join('Lichess Elite Database', 'lichess_elite_2015-11.pgn'))

In [None]:
# Extend the ground truth with the same file so that, presumably, games[i] still lines up with the
# dataset's i-th game. NOTE(review): mutates `games` in place; re-running duplicates the games.
games.extend(read_games(os.path.join('Lichess Elite Database', 'lichess_elite_2015-11.pgn')))

In [None]:
# Re-run the interval checks over the combined dataset.
test_game_generator_correctness(games, dataset, 10)