In [None]:
import numpy as np
import pandas as pd
import plotly.express as px

from os.path import getsize

In [None]:
def get_next_game(stream=None):
    """Read the raw lines of the next game from a PGN text stream.

    Blank (whitespace-only) lines are skipped. Tag-pair lines (starting with
    "[") are collected until the first non-tag line, which is taken to be the
    game's movetext and returned as the last element. This also covers games
    whose movetext is only a result token (e.g. "0-1"), which the previous
    "starts with '1.'" test would have merged into the following game.

    Args:
        stream: object with a ``readline()`` method returning text lines.
            Defaults to the module-level ``f`` opened in the setup cell, so
            ``get_next_game()`` with no arguments behaves as before.

    Returns:
        list[str]: the game's tag lines followed by its movetext line; lines
        keep their trailing newline.

    Raises:
        EOFError: if the stream is exhausted before a movetext line is read
            (previously this case looped forever, because ``readline()``
            returns "" at end of file).
    """
    if stream is None:
        stream = f  # global file handle opened in the setup cell

    buffer = []
    while True:
        line = stream.readline()

        # readline() returns "" only at end of file.
        if not line:
            raise EOFError("end of PGN input reached before the next game's movetext")

        if not line.strip():
            continue

        buffer.append(line)

        # Tag pairs start with "["; the first other line is the movetext.
        if not line.startswith("["):
            return buffer

In [None]:
def params_to_dict(str_list):
    """Parse PGN tag-pair lines of the form ``[Name "Value"]`` into a dict.

    Each line has its trailing newline and surrounding square brackets
    removed, is split at the first space into name and value, and the value
    has surrounding double quotes stripped. Escaped quotes inside values are
    not unescaped. Later duplicates of a name overwrite earlier ones.

    Raises:
        ValueError: if a line contains no space to split name from value.
    """
    headers = {}
    for raw_line in str_list:
        tag_body = raw_line.strip("\n").strip("[]")
        tag_name, tag_value = tag_body.split(" ", 1)
        headers[tag_name] = tag_value.strip('"')
    return headers

In [None]:
# Characters that mark a movetext token as something other than a move:
# move numbers ("1.", "1..."), results ("1-0", "0-1", "1/2-1/2") and, presumably,
# annotation comments such as "{ [%clk 0:10:00] }" (verify against the data).
_NON_MOVE_CHARS = "{}[]%.:-"


def _extract_moves(movetext):
    """Return the move tokens (SAN) from a single movetext line."""
    return [
        token
        for token in movetext.split()
        # Castling ("O-O", "O-O-O", possibly with "+"/"#") contains "-" and
        # used to be discarded together with the result tokens.
        if token.startswith("O-O") or all(char not in _NON_MOVE_CHARS for char in token)
    ]


def _mean_rating_if_good(headers):
    """Return the floor-mean of WhiteElo/BlackElo if the game passes the filters, else None.

    Filters: base time (TimeControl before "+") of 600 or 900, both
    |RatingDiff| <= 20, Elo gap <= 100, Termination "Normal" or
    "Time forfeit". A missing or non-numeric tag rejects the game.
    """
    try:
        white_elo = int(headers["WhiteElo"])
        black_elo = int(headers["BlackElo"])
        is_good_game = all([
            headers["TimeControl"].split("+")[0] in ["600", "900"],
            abs(float(headers["WhiteRatingDiff"])) <= 20,
            abs(float(headers["BlackRatingDiff"])) <= 20,
            abs(white_elo - black_elo) <= 100,
            headers["Termination"] in ["Normal", "Time forfeit"],
        ])
    except (KeyError, ValueError):
        # Missing tag or unparsable number: treat as not a good game, but let
        # unrelated errors (and KeyboardInterrupt) propagate.
        return None
    if not is_good_game:
        return None
    return (white_elo + black_elo) // 2


def get_next_line_with_rating():
    """Return (moves, mean_rating) for the next game that passes the filters.

    Games are read with ``get_next_game()``; games failing the header filters
    (see ``_mean_rating_if_good``) or having malformed tag lines are skipped.

    Returns:
        tuple[list[str], int]: the game's move tokens in order, and the
        integer floor of the mean of WhiteElo and BlackElo.

    Raises:
        Whatever ``get_next_game()`` raises when the input is exhausted.
    """
    while True:
        game_lines = get_next_game()
        moves = _extract_moves(game_lines[-1])

        # Parse the headers of the SAME game. The original called
        # get_next_game() again here, pairing one game's moves with the next
        # game's headers and silently skipping every other game.
        try:
            headers = params_to_dict(game_lines[:-1])
        except ValueError:
            # Malformed tag line: skip this game instead of aborting the scan.
            continue

        mean_rating = _mean_rating_if_good(headers)
        if mean_rating is not None:
            return moves, mean_rating

In [None]:
PGN_FILE = "pgn/lichess_db_standard_rated_2024-01.pgn"
print(f"PGN file size (bytes): {getsize(PGN_FILE):,}")

# Open with an explicit encoding: the PGN text is expected to be UTF-8, and
# relying on the platform default (e.g. cp1252 on Windows) can mis-decode or
# raise on any non-ASCII tag values.
# `f` is intentionally left open: get_next_game() reads from this global for
# the rest of the notebook.
f = open(PGN_FILE, mode="r", encoding="utf-8")

# Optional: resume part-way through the file.
# NOTE: on a text-mode file, seek() is only guaranteed to work with offsets
# previously returned by f.tell(); an arbitrary byte offset may land inside a
# multi-byte character.
# OFFSET = 0
# f.seek(OFFSET)
# while True:
#     line = f.readline()
#     if line.startswith("1."):
#         break

In [None]:
# For every opening prefix (first 1..MAX_PLIES half-moves) of qualifying games,
# count how many games reached it and sum their mean ratings.
N_GAMES = 10_000_000  # games (already passing get_next_line_with_rating's filters) to examine
MAX_ELO = 1000        # keep only games whose mean rating is below this
MAX_PLIES = 2 * 15    # 15 full moves = 30 half-moves

line_count = {}
line_rating = {}

for game_idx in range(N_GAMES):
    if game_idx % 1000 == 0:
        print(game_idx, end='\r')

    try:
        moves, elo = get_next_line_with_rating()
    except EOFError:
        # Input exhausted before N_GAMES games were read; keep what we have.
        break

    if elo >= MAX_ELO:
        continue

    # Separate loop variable so the outer game counter is not shadowed.
    for n_plies in range(1, min(len(moves), MAX_PLIES) + 1):
        key = " ".join(moves[:n_plies])
        line_count[key] = line_count.get(key, 0) + 1
        line_rating[key] = line_rating.get(key, 0) + elo

In [None]:
# Minimum number of games a line must appear in for its mean rating to be
# kept; averages over fewer games are presumably too noisy to compare.
MIN_GAMES = 100

# Mean rating per line: summed per-game mean ratings divided by game count.
result = {
    line: total_rating / line_count[line]
    for line, total_rating in line_rating.items()
    if line_count[line] >= MIN_GAMES
}

In [None]:
# Per-line mean ratings as a Series, ordered from lowest to highest mean.
series = pd.Series(data=result).sort_values(ascending=True)

In [None]:
# Fold the space-separated move sequences into a nested dict keyed by move:
# tree[m1][m2]...["mean"] holds the mean rating of the line "m1 m2 ...".
tree = {}
for line_key, mean_rating in series.sort_index().to_dict().items():
    node = tree
    for ply in line_key.split(" "):
        node = node.setdefault(ply, {})
    node["mean"] = mean_rating

In [None]:
# Display the nested move tree: one dict level per move, with "mean" holding
# the line's mean rating. NOTE(review): this can be a very large output.
tree