In [55]:
import pandas as pd
import pickle
import io
import chess.pgn
from sklearn.metrics import classification_report
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("seaborn-v0_8-deep")

In [19]:
df = pd.read_csv("../data/lichess/lichess_db_puzzle.csv")
df["Themes"] = df["Themes"].str.split()
puzzles = df[["PuzzleId", "Themes", "Rating"]].set_index("PuzzleId").to_dict(orient="index")

In [43]:
my_pgn = io.StringIO("""1. e4 c5 2. c3 d6 3. d4 b6 4. Nf3 Bg4 5. Bb5+ Nd7 6. O-O a6 7. Bc6
Rc8 8. Bd5 b5 9. Bxf7+ Kxf7 10. Ng5+ Ke8 11. Qxg4 Ngf6 12. Qe6 Nb6 13. e5 Nfd5
14. exd6 Qxd6 15. Qxd6 exd6 16. Nd2 c4 17. Nde4 h6 18. Nf3 Kd7 19. b3 Re8 20.
Ng3 g5 21. Ba3 a5 22. bxc4 Nxc4 23. Bc1 Bg7 24. Rb1 Rb8 25. Ne2 b4 26. cxb4
axb4 27. Nd2 Nxd2 28. Bxd2 Rhe8 29. Rfe1 Re7 30. Ng3 Rxe1+ 31. Rxe1 Bxd4 32. h3
Rf8 33. Be3 Nxe3 34. fxe3 Bc5 35. Rb1 Re8 36. Ne4 Rxe4 37. a3 Bxe3+ 38. Kh2
bxa3 39. Ra1 Bc5 40. Ra2 Rb4 41. g4 Rb2+ 42. Rxb2 axb2 0-1""")
game = chess.pgn.read_game(my_pgn)

fens = []
board = game.board()
fens.append(board.fen())
for move in game.mainline_moves():
    board.push(move)
    fens.append(board.fen())

['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1',
 'rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1',
 'rnbqkbnr/pp1ppppp/8/2p5/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2',
 'rnbqkbnr/pp1ppppp/8/2p5/4P3/2P5/PP1P1PPP/RNBQKBNR b KQkq - 0 2',
 'rnbqkbnr/pp2pppp/3p4/2p5/4P3/2P5/PP1P1PPP/RNBQKBNR w KQkq - 0 3',
 'rnbqkbnr/pp2pppp/3p4/2p5/3PP3/2P5/PP3PPP/RNBQKBNR b KQkq - 0 3',
 'rnbqkbnr/p3pppp/1p1p4/2p5/3PP3/2P5/PP3PPP/RNBQKBNR w KQkq - 0 4',
 'rnbqkbnr/p3pppp/1p1p4/2p5/3PP3/2P2N2/PP3PPP/RNBQKB1R b KQkq - 1 4',
 'rn1qkbnr/p3pppp/1p1p4/2p5/3PP1b1/2P2N2/PP3PPP/RNBQKB1R w KQkq - 2 5',
 'rn1qkbnr/p3pppp/1p1p4/1Bp5/3PP1b1/2P2N2/PP3PPP/RNBQK2R b KQkq - 3 5',
 'r2qkbnr/p2npppp/1p1p4/1Bp5/3PP1b1/2P2N2/PP3PPP/RNBQK2R w KQkq - 4 6',
 'r2qkbnr/p2npppp/1p1p4/1Bp5/3PP1b1/2P2N2/PP3PPP/RNBQ1RK1 b kq - 5 6',
 'r2qkbnr/3npppp/pp1p4/1Bp5/3PP1b1/2P2N2/PP3PPP/RNBQ1RK1 w kq - 0 7',
 'r2qkbnr/3npppp/ppBp4/2p5/3PP1b1/2P2N2/PP3PPP/RNBQ1RK1 b kq - 1 7',
 '2rqkbnr/3npppp/ppBp4/2p5/3PP1b1/2P2N2/PP3PP

In [45]:
def label_error(pred, exp):
    pred_set = set(pred) - {"middlegame", "endgame", "short", "crushing", "advantage", "long"}
    exp_set = set(exp) - {"middlegame", "endgame", "short", "crushing", "advantage", "long"}
    return len(pred_set - exp_set) + len(exp_set - pred_set)

def mse_loss(pred, exp):
    return (pred - exp) ** 2

THEMES = ['advancedPawn', 'advantage', 'anastasiaMate', 'arabianMate', 'attackingF2F7', 'attraction', 'backRankMate', 'bishopEndgame', 'bodenMate', 'capturingDefender', 'castling', 'clearance', 'crushing', 'defensiveMove', 'deflection', 'discoveredAttack', 'doubleBishopMate', 'doubleCheck', 'dovetailMate', 'enPassant', 'endgame', 'equality', 'exposedKing', 'fork', 'hangingPiece', 'hookMate', 'interference', 'intermezzo', 'kingsideAttack', 'knightEndgame', 'long', 'master', 'masterVsMaster', 'mate', 'mateIn1', 'mateIn2', 'mateIn3', 'mateIn4', 'mateIn5', 'middlegame', 'oneMove', 'opening', 'pawnEndgame', 'pin', 'promotion', 'queenEndgame', 'queenRookEndgame', 'queensideAttack', 'quietMove', 'rookEndgame', 'sacrifice', 'short', 'skewer', 'smotheredMate', 'superGM', 'trappedPiece', 'underPromotion', 'veryLong', 'xRayAttack', 'zugzwang']

In [77]:
with open("transformer.pkl", "rb") as f:
    tree = pickle.load(f)
tree

{'0009B': ['advantage', 'opening', 'short'],
 '000qP': ['crushing', 'endgame'],
 '001Wz': ['backRankMate', 'endgame', 'mate', 'mateIn2', 'short'],
 '001h8': ['crushing', 'middlegame'],
 '001uD': ['advantage', 'middlegame', 'short'],
 '001xO': ['advantage', 'middlegame', 'short'],
 '001xl': ['crushing', 'endgame', 'short', 'skewer'],
 '002Ua': ['crushing', 'middlegame', 'short'],
 '003UW': ['crushing', 'endgame', 'queenEndgame', 'short'],
 '004sg': ['endgame', 'short'],
 '005qG': ['crushing', 'endgame'],
 '005ws': ['crushing', 'endgame', 'knightEndgame', 'quietMove'],
 '0061g': ['endgame', 'mate'],
 '0066C': ['advantage', 'middlegame'],
 '0068B': ['advantage', 'middlegame'],
 '006om': ['crushing', 'endgame', 'fork', 'long'],
 '007XE': ['backRankMate', 'mate', 'mateIn2', 'middlegame', 'short'],
 '007ku': ['mate', 'mateIn1', 'middlegame', 'oneMove'],
 '008Nz': ['backRankMate', 'mate', 'mateIn2', 'middlegame', 'short'],
 '008P4': ['crushing', 'endgame', 'fork', 'short'],
 '009FS': ['advant

In [78]:
tree_ids = list(tree.keys())
y_true = np.zeros((len(tree_ids), len(THEMES)))
y_pred = np.zeros((len(tree_ids), len(THEMES)))
for i, pid in enumerate(tree_ids):
    for j, t in enumerate(THEMES):
        if t in puzzles[pid]["Themes"]:
            y_true[i,j] = 1
        if t in tree[pid]:
            y_pred[i,j] = 1

print(classification_report(y_true, y_pred, target_names=THEMES, digits=4, zero_division=0))

                   precision    recall  f1-score   support

     advancedPawn     0.7044    0.4462    0.5463     43215
        advantage     0.5966    0.5898    0.5932    235529
    anastasiaMate     0.8554    0.5338    0.6574       798
      arabianMate     0.7858    0.5842    0.6702       760
    attackingF2F7     0.8693    0.7412    0.8002      4657
       attraction     0.6858    0.1767    0.2810     26566
     backRankMate     0.8236    0.8427    0.8330     24332
    bishopEndgame     0.9622    0.9774    0.9697      9476
        bodenMate     0.8045    0.5619    0.6617       315
capturingDefender     0.7120    0.0209    0.0406      6505
         castling     0.5000    0.0071    0.0141       421
        clearance     0.7768    0.0320    0.0616     10641
         crushing     0.7223    0.5972    0.6538    337162
    defensiveMove     0.5340    0.0462    0.0850     47212
       deflection     0.7243    0.0940    0.1664     32867
 discoveredAttack     0.6760    0.2641    0.3799     43