In [41]:
import pandas as pd
import pickle
import io
import chess.pgn
import matplotlib.pyplot as plt
# Select the "seaborn-v0_8-deep" matplotlib style sheet for figures drawn after this point.
plt.style.use("seaborn-v0_8-deep")

In [19]:
# Read the Lichess puzzle CSV; the path is relative to the notebook's working directory.
df = pd.read_csv("../data/lichess/lichess_db_puzzle.csv")

# Turn the whitespace-separated theme string of each puzzle into a list of tags.
df["Themes"] = df["Themes"].str.split()

# Lookup table keyed by puzzle id: {PuzzleId: {"Themes": [...], "Rating": ...}}.
puzzles = (
    df.loc[:, ["PuzzleId", "Themes", "Rating"]]
    .set_index("PuzzleId")
    .to_dict(orient="index")
)

In [43]:
# Wrap a hard-coded PGN move list in an in-memory text stream so that
# chess.pgn.read_game can consume it like a file handle.
my_pgn = io.StringIO("""1. e4 c5 2. c3 d6 3. d4 b6 4. Nf3 Bg4 5. Bb5+ Nd7 6. O-O a6 7. Bc6
Rc8 8. Bd5 b5 9. Bxf7+ Kxf7 10. Ng5+ Ke8 11. Qxg4 Ngf6 12. Qe6 Nb6 13. e5 Nfd5
14. exd6 Qxd6 15. Qxd6 exd6 16. Nd2 c4 17. Nde4 h6 18. Nf3 Kd7 19. b3 Re8 20.
Ng3 g5 21. Ba3 a5 22. bxc4 Nxc4 23. Bc1 Bg7 24. Rb1 Rb8 25. Ne2 b4 26. cxb4
axb4 27. Nd2 Nxd2 28. Bxd2 Rhe8 29. Rfe1 Re7 30. Ng3 Rxe1+ 31. Rxe1 Bxd4 32. h3
Rf8 33. Be3 Nxe3 34. fxe3 Bc5 35. Rb1 Re8 36. Ne4 Rxe4 37. a3 Bxe3+ 38. Kh2
bxa3 39. Ra1 Bc5 40. Ra2 Rb4 41. g4 Rb2+ 42. Rxb2 axb2 0-1""")
# Parse the first (and only) game from the stream.
game = chess.pgn.read_game(my_pgn)

# Replay the main line on a fresh board, recording the FEN of the starting
# position followed by the FEN reached after each move.
board = game.board()
fens = [board.fen()]
for move in game.mainline_moves():
    board.push(move)
    fens.append(board.fen())

['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1',
 'rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1',
 'rnbqkbnr/pp1ppppp/8/2p5/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2',
 'rnbqkbnr/pp1ppppp/8/2p5/4P3/2P5/PP1P1PPP/RNBQKBNR b KQkq - 0 2',
 'rnbqkbnr/pp2pppp/3p4/2p5/4P3/2P5/PP1P1PPP/RNBQKBNR w KQkq - 0 3',
 'rnbqkbnr/pp2pppp/3p4/2p5/3PP3/2P5/PP3PPP/RNBQKBNR b KQkq - 0 3',
 'rnbqkbnr/p3pppp/1p1p4/2p5/3PP3/2P5/PP3PPP/RNBQKBNR w KQkq - 0 4',
 'rnbqkbnr/p3pppp/1p1p4/2p5/3PP3/2P2N2/PP3PPP/RNBQKB1R b KQkq - 1 4',
 'rn1qkbnr/p3pppp/1p1p4/2p5/3PP1b1/2P2N2/PP3PPP/RNBQKB1R w KQkq - 2 5',
 'rn1qkbnr/p3pppp/1p1p4/1Bp5/3PP1b1/2P2N2/PP3PPP/RNBQK2R b KQkq - 3 5',
 'r2qkbnr/p2npppp/1p1p4/1Bp5/3PP1b1/2P2N2/PP3PPP/RNBQK2R w KQkq - 4 6',
 'r2qkbnr/p2npppp/1p1p4/1Bp5/3PP1b1/2P2N2/PP3PPP/RNBQ1RK1 b kq - 5 6',
 'r2qkbnr/3npppp/pp1p4/1Bp5/3PP1b1/2P2N2/PP3PPP/RNBQ1RK1 w kq - 0 7',
 'r2qkbnr/3npppp/ppBp4/2p5/3PP1b1/2P2N2/PP3PPP/RNBQ1RK1 b kq - 1 7',
 '2rqkbnr/3npppp/ppBp4/2p5/3PP1b1/2P2N2/PP3PP

In [32]:
# Theme labels left out of the comparison by default. They look like
# phase/length/evaluation tags rather than tactical motifs -- TODO confirm intent.
_IGNORED_THEMES = frozenset({"middlegame", "endgame", "short", "crushing", "advantage", "long"})


def label_error(pred, exp, ignored=_IGNORED_THEMES):
    """Count the theme labels on which a prediction and the ground truth disagree.

    Both inputs are reduced to sets (duplicates collapse), the labels in
    ``ignored`` are removed from each, and the size of the symmetric
    difference is returned: labels predicted but not expected plus labels
    expected but not predicted.

    Args:
        pred: Iterable of predicted theme labels.
        exp: Iterable of expected theme labels.
        ignored: Iterable of labels to exclude from both sides before
            comparing. Defaults to the notebook's original exclusion list.

    Returns:
        int: Number of mismatched labels; 0 means the filtered sets are equal.
    """
    # Accept any iterable (list, tuple, set) for the exclusion list.
    ignored = frozenset(ignored)
    pred_set = set(pred) - ignored
    exp_set = set(exp) - ignored
    # Symmetric difference == (pred - exp) union (exp - pred).
    return len(pred_set ^ exp_set)

def mse_loss(pred, exp):
    """Return the squared difference between a prediction and its target.

    No averaging happens here; this is the per-sample squared error, and
    any mean must be taken by the caller.
    """
    residual = pred - exp
    return residual ** 2

In [7]:
# Load previously saved results from disk. The evaluation loop in this notebook
# iterates it as a mapping of puzzle id -> (themes, rating prediction).
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open files produced by a trusted source.
with open("tree_results.pkl", "rb") as f:
    tree = pickle.load(f)

In [33]:
# Accumulate theme-label mismatches and squared rating errors over every
# puzzle present in the loaded results.
total_theme_err = 0
total_diff_loss = 0
for pid, (themes, diff) in tree.items():
    truth = puzzles[pid]  # ground-truth record for this puzzle
    total_theme_err += label_error(themes, truth["Themes"])
    total_diff_loss += mse_loss(diff, truth["Rating"])

n_puzzles = len(tree)
# Average number of mismatched theme labels per puzzle.
print(total_theme_err / n_puzzles)
# Root of the mean squared rating error (RMSE).
print((total_diff_loss / n_puzzles) ** 0.5)

1.039
404.47368080450127
