In [1]:
import json
import pickle

import chess
import chess.engine
import zss
from tqdm import tqdm

from main import ChessTreeNode, expand_tree

In [2]:
# Parse every line of `all_trees` (each a JSON `[puzzle_id, tree]` pair) into a
# {puzzle_id: ChessTreeNode} dict, then cache that dict as `tree_dict.pkl`.
# The full parse took ~1h40m on the last run, so later cells should load the pickle.
trees = {}
with open("all_trees", "r") as trees_in:
    for line in tqdm(trees_in.readlines()):
        if not line.strip():
            continue  # json.loads would raise on a blank line
        puzzle_id, tree = json.loads(line)
        if len(tree) == 0:
            continue  # no tree recorded for this puzzle
        # tree[0] holds the constructor arguments of the root node
        trees[puzzle_id] = ChessTreeNode(*tree[0])

# Write the cache only after the input file is closed, and use a distinct handle
# name so the input handle is not shadowed.
with open("tree_dict.pkl", "wb") as trees_out:
    pickle.dump(trees, trees_out)

100%|██████████| 3774123/3774123 [1:41:23<00:00, 620.35it/s]   


In [2]:
def compare(t1, t2):
    """Return the zss tree edit distance between two ChessTreeNode trees.

    Child lookup, node labels and the label-to-label cost are all delegated to
    ChessTreeNode's own helpers (get_children, get_label, compare).
    """
    children_of = ChessTreeNode.get_children
    label_of = ChessTreeNode.get_label
    label_cost = ChessTreeNode.compare
    return zss.simple_distance(t1, t2, children_of, label_of, label_cost)

def treeify_fen(fen, engine=None, engine_path="./stockfish"):
    """Expand the position given by `fen` into a ChessTreeNode tree.

    If Black is to move, the board is mirrored first (python-chess
    `apply_mirror` flips the board vertically and swaps colors/turn), so the
    tree is always expanded with White to move.

    Args:
        fen: FEN string of the starting position.
        engine: optional already-running UCI engine (e.g. a
            `chess.engine.SimpleEngine`) to reuse across many calls. When
            given, it is left running for the caller to close.
        engine_path: command used to start Stockfish when `engine` is None.
            That engine is started and shut down inside this call.

    Returns:
        The root ChessTreeNode built from the first element of expand_tree's result.
    """
    board = chess.Board(fen)
    if board.turn == chess.BLACK:
        board.apply_mirror()

    owns_engine = engine is None
    if owns_engine:
        # timeout=None disables python-chess's handshake timeout on startup
        engine = chess.engine.SimpleEngine.popen_uci(engine_path, timeout=None)
    try:
        # NOTE(review): the meaning of the trailing True flag is defined by
        # main.expand_tree — confirm. The all_trees loader skips empty trees,
        # so an empty result here would raise IndexError.
        expanded = expand_tree(board.fen(), chess.Move.null(), engine, True)
        return ChessTreeNode(*expanded[0])
    finally:
        if owns_engine:
            engine.quit()

In [3]:
example_fen = "3r4/1p2ppkp/p5p1/4P3/4NP2/1P2n2P/P1P2RP1/7K b - - 1 27"
example_tree = treeify_fen(example_fen)
# Black is to move here, so treeify_fen mirrors the board. flip_san() presumably
# renders the moves back in the original orientation (e.g. Rd1+ below).
# print() keeps the rendering's newlines instead of showing an escaped repr.
print(example_tree.flip_san())

Rd1+ 1/0/0/1
  Kh7 0/2/0/0
    Nf8+ 1/0/1/1
      Rxf8 1/1/1/0
        Rxf8 1/0/0/0
      Kg8 1/2/0/0
        Ng6+ 1/0/1/0
          Kh7 1/2/0/0
        Nd7+ 2/0/2/1
          Kh7 0/2/0/0
        Ne6+ 2/1/0/0
          Kh7 0/2/0/0
      Kh8 0/1/0/0
        Nd7+ 2/0/2/1
          Kh7 0/2/0/0
        Ng6+ 2/0/1/0
          Kh7 1/2/0/0

