In [None]:
import json
import pickle
import random

import chess
import chess.engine
import matplotlib.pyplot as plt
import numpy as np
import zss
from sklearn.cluster import OPTICS, cluster_optics_dbscan
from sklearn.manifold import TSNE
from tqdm import tqdm

from main import ChessTreeNode, expand_tree

In [None]:
# Parse the raw export: one JSON line per puzzle, of the form [puzzle_id, tree].
# tree[0] is unpacked as the ChessTreeNode constructor arguments for the root
# (presumably the same layout expand_tree returns — TODO confirm against main.py).
trees = {}
with open("all_trees", "r") as trees_file:
    # Iterate lazily rather than readlines() so the whole export is never held in memory.
    for line in tqdm(trees_file, desc="parsing trees"):
        if not line.strip():
            # Tolerate blank / trailing lines instead of crashing in json.loads.
            continue
        puzzle_id, tree = json.loads(line)
        if not tree:
            # No tree was produced for this puzzle; skip it.
            continue
        trees[puzzle_id] = ChessTreeNode(*tree[0])

# Cache the parsed trees. This is done after the input file is closed, with a separate
# handle name so the input handle is not shadowed.
# NOTE(review): later cells load "tree_dict_short.pkl", not this file — confirm which is intended.
with open("tree_dict.pkl", "wb") as pickle_file:
    pickle.dump(trees, pickle_file)

In [None]:
def compare(t1, t2):
    """Return the zss tree edit distance between two ChessTreeNode trees.

    Child traversal, node labels and the label cost function are all taken from
    ChessTreeNode's own accessors (get_children / get_label / compare).
    """
    return zss.simple_distance(
        t1,
        t2,
        get_children=ChessTreeNode.get_children,
        get_label=ChessTreeNode.get_label,
        label_dist=ChessTreeNode.compare,
    )

def treeify_fen(fen, engine_path="./stockfish"):
    """Build a ChessTreeNode tree for the position described by ``fen``.

    If it is Black to move, the board is mirrored first (``Board.apply_mirror``),
    so the tree is always built from a White-to-move position. Callers that need
    to map moves back must track whether that flip happened.

    Args:
        fen: FEN string of the position to analyse.
        engine_path: path to the UCI engine binary; defaults to "./stockfish".

    Returns:
        A ChessTreeNode built from the first element returned by ``expand_tree``.
    """
    board = chess.Board(fen)
    if board.turn == chess.BLACK:
        board.apply_mirror()

    engine = chess.engine.SimpleEngine.popen_uci(engine_path, timeout=None)
    try:
        return ChessTreeNode(*expand_tree(board.fen(), chess.Move.null(), engine)[0])
    finally:
        # Always shut down the engine subprocess, even if expand_tree raises.
        engine.quit()

In [None]:
# Reload the cached trees dict (presumably a shortened version of the dict pickled
# by the loading cell — TODO confirm). pickle.load can execute arbitrary code,
# so only load files you produced yourself.
with open("tree_dict_short.pkl", mode="rb") as cache_file:
    trees = pickle.load(cache_file)

In [None]:
# Randomly pick the puzzles to compare against, reproducibly (seed 0).
# - replace=False: sampling with replacement could pick the same id twice, and the
#   duplicates would show up as distinct points at distance 0 in the clustering/embedding.
# - "my_puzzle" is excluded because a previous run of the next cell inserts it into `trees`.
N_SAMPLE = 50
rng = np.random.default_rng(0)
candidate_ids = [puzzle_id for puzzle_id in trees if puzzle_id != "my_puzzle"]
ids = rng.choice(candidate_ids, size=min(N_SAMPLE, len(candidate_ids)), replace=False)

# Add our own puzzle so it is embedded and clustered alongside the sample.
ids = np.append(ids, "my_puzzle")
my_fen = "2r4k/p2Q2pp/1p6/2qp1NP1/4p2P/1P6/PP3P2/1K5R b - - 1 33"

# treeify_fen mirrors Black-to-move positions before building the tree. Remember
# whether that happened so the SAN can be flipped back for display.
flipped = chess.Board(my_fen).turn == chess.BLACK
# Reuse treeify_fen rather than duplicating its logic. It also guarantees the
# engine is shut down if tree expansion fails.
trees["my_puzzle"] = treeify_fen(my_fen)
print(trees["my_puzzle"].flip_san(flipped))

In [None]:
optics = OPTICS(metric="precomputed")
tsne = TSNE(metric="precomputed", init="random", random_state=0)

# Pairwise tree-edit-distance matrix over the selected puzzles. The distance is
# assumed symmetric, so each pair (i > j) is computed once and written to both
# d[i, j] and d[j, i]. The diagonal stays 0.
N = len(ids)
d = np.zeros((N, N))
for i in range(1, N):
    for j in range(i):
        pair_distance = compare(trees[ids[i]], trees[ids[j]])
        d[i, j] = pair_distance
        d[j, i] = pair_distance


In [None]:
# OPTICS.fit returns the fitted estimator itself, so `clust` is the same object as
# `optics`. Per-sample cluster labels are in `clust.labels_`.
optics.fit(d)
clust = optics

In [None]:
# 2-D t-SNE embedding of the precomputed distance matrix (fit stores it in embedding_).
y = tsne.fit(d).embedding_

In [None]:
# Plot the t-SNE embedding. The cluster of each highlighted puzzle is drawn in its own
# colour and every other point in black.
colours = ["r.", "g.", "b."]
highlight = ["my_puzzle"]

# One scalar OPTICS label per highlighted puzzle (ids is a numpy str array, so == is element-wise).
labels = [int(clust.labels_[ids == p][0]) for p in highlight]
# OPTICS labels noise as -1. Noise is not a real cluster, so it is never coloured as
# "the puzzle's cluster".
highlighted_clusters = {c for c in labels if c != -1}

_, ax = plt.subplots()
for c in set(clust.labels_) - highlighted_clusters:
    yc = y[clust.labels_ == c]
    ax.plot(yc[:, 0], yc[:, 1], "k.")

for p, c, col in zip(highlight, labels, colours):
    if c != -1:
        yc = y[clust.labels_ == c]
        ax.plot(yc[:, 0], yc[:, 1], col)
    # Mark the puzzle itself so it stands out within (or apart from) its cluster.
    yp = y[ids == p]
    ax.plot(yp[:, 0], yp[:, 1], col[0] + "*", markersize=12, label=p)

ax.set(
    title="t-SNE of puzzle tree edit distances (OPTICS clusters)",
    xlabel="t-SNE 1",
    ylabel="t-SNE 2",
)
ax.legend();
