In [None]:
import json
import pickle
import random

import chess
import chess.engine
import matplotlib.pyplot as plt
import numpy as np
import zss
from sklearn.cluster import OPTICS, cluster_optics_dbscan
from sklearn.manifold import TSNE
from tqdm import tqdm

from main import ChessTreeNode, expand_tree

In [None]:
# Parse the JSON-lines dump into {puzzle_id: ChessTreeNode} and cache it as a pickle.
# Each line is expected to be a JSON array [puzzle_id, tree_list].
trees = {}
with open("all_trees", "r") as trees_file:
    # readlines() lets tqdm show a total; the file is read fully into memory.
    for line in tqdm(trees_file.readlines()):
        if not line.strip():
            # Tolerate blank lines instead of failing inside json.loads.
            continue
        puzzle_id, tree = json.loads(line)
        if len(tree) == 0:
            # Entries with an empty tree list are skipped.
            continue
        trees[puzzle_id] = ChessTreeNode(*tree[0])

# Write the cache only after the input file is closed, using a distinct handle name.
with open("tree_dict.pkl", "wb") as pkl_file:
    pickle.dump(trees, pkl_file)

In [None]:
def compare(t1, t2):
    """Return the Zhang-Shasha tree edit distance between two ChessTreeNode trees.

    Children and node labels are obtained through ChessTreeNode.get_children and
    ChessTreeNode.get_label; ChessTreeNode.compare supplies the label-to-label cost.
    """
    children_of = ChessTreeNode.get_children
    label_of = ChessTreeNode.get_label
    label_cost = ChessTreeNode.compare
    return zss.simple_distance(t1, t2, children_of, label_of, label_cost)

def treeify_fen(fen, engine_path="./stockfish"):
    """Build a ChessTreeNode tree for the position described by ``fen``.

    If Black is to move, the board is mirrored first (python-chess
    ``apply_mirror`` flips the board, swaps colours and the side to move), so
    the position handed to ``expand_tree`` always has White to move.

    Args:
        fen: FEN string of the position to expand.
        engine_path: Path to the UCI engine binary. Defaults to "./stockfish",
            the previously hard-coded location.

    Returns:
        A ChessTreeNode built from the first element returned by ``expand_tree``.
    """
    board = chess.Board(fen)
    if board.turn == chess.BLACK:
        board.apply_mirror()

    engine = chess.engine.SimpleEngine.popen_uci(engine_path, timeout=None)
    try:
        return ChessTreeNode(*expand_tree(board.fen(), chess.Move.null(), engine)[0])
    finally:
        # Always shut the engine subprocess down, even if expand_tree raises.
        engine.quit()

In [None]:
# Reload the cached {puzzle_id: ChessTreeNode} mapping (shortened variant).
# NOTE: pickle can execute arbitrary code on load; only use locally produced files.
with open("tree_dict_short.pkl", mode="rb") as pkl_file:
    trees = pickle.load(pkl_file)

In [None]:
# Sample puzzles to compare against, then add hand-picked custom positions.
N_SAMPLED_PUZZLES = 400

custom_fens = {
    "backrank": "2r4k/p2Q2pp/1p6/2qp1NP1/4p2P/1P6/PP3P2/1K5R b - - 1 33"
}

np.random.seed(0)
# Exclude custom ids: this cell inserts them into `trees`, so a re-run could otherwise
# sample them and then append them a second time.
candidate_ids = [k for k in trees.keys() if k not in custom_fens]
# replace=False: numpy's default samples WITH replacement, which can pick the same
# puzzle twice and duplicate rows/columns of the distance matrix.
ids = np.random.choice(
    candidate_ids, min(N_SAMPLED_PUZZLES, len(candidate_ids)), replace=False
)

for my_id, my_fen in custom_fens.items():
    ids = np.append(ids, my_id)

    # treeify_fen mirrors Black-to-move positions; record that so flip_san can
    # presumably map the moves back to the original orientation.
    flipped = chess.Board(my_fen).turn == chess.BLACK
    # treeify_fen also guarantees the engine process is closed if expansion fails.
    trees[my_id] = treeify_fen(my_fen)
    print(my_id)
    print(my_fen)
    print(trees[my_id].flip_san(flipped))

In [None]:
N = len(ids)
# Pairwise tree-edit distances. The distance is assumed symmetric
# (compare(a, b) == compare(b, a)) with a zero diagonal, so only the strict
# lower triangle (j < i) is computed and then mirrored.
d = np.zeros((N, N))
# Exactly N*(N-1)/2 comparisons are made (the diagonal is skipped).
with tqdm(total=N * (N - 1) // 2) as pbar:
    for i in range(N):
        tree_i = trees[ids[i]]
        for j in range(i):
            d[i, j] = compare(tree_i, trees[ids[j]])
            pbar.update(1)
d = d + d.T


In [None]:
# Density-based clustering straight on the tree-edit distance matrix `d`.
optics = OPTICS(
    min_cluster_size=5,
    metric="precomputed",  # d already holds pairwise distances, not feature vectors
)
# fit() returns the fitted estimator itself, so `clust` is the same object as `optics`.
clust = optics.fit(d)

In [None]:
# For each custom position, list its nearest puzzles by raw tree-edit distance,
# then the members of the OPTICS cluster it was assigned to.
N_NEIGHBOURS = 7

print("Most similar puzzles by overall distance")
for p, fen in custom_fens.items():
    print(p, fen)
    # Scalar row index of the custom position (first match in ids).
    idx = int(np.flatnonzero(ids == p)[0])
    # argsort would list the puzzle itself first at distance 0; skip it.
    closest = [j for j in np.argsort(d[idx]) if j != idx][:N_NEIGHBOURS]
    for close_id in closest:
        print(f"{ids[close_id]}, {d[idx, close_id]}")

print()
print("Cluster of puzzles from OPTICS")
for p, fen in custom_fens.items():
    print(p, fen)
    idx = int(np.flatnonzero(ids == p)[0])
    label = clust.labels_[idx]
    if label == -1:
        # OPTICS labels noise points -1; they do not form a cluster together.
        print("(not assigned to any cluster: labelled as noise)")
        continue
    for close_id in np.flatnonzero(clust.labels_ == label):
        print(f"{ids[close_id]}, {d[idx, close_id]}")


In [None]:
# 2-D t-SNE embedding of the precomputed distance matrix
# (scikit-learn rejects init="pca" together with metric="precomputed").
tsne = TSNE(metric="precomputed", init="random", random_state=0, perplexity=4)
y = tsne.fit_transform(d)

colours = ["r.", "g.", "b."]
# (custom id, OPTICS label) pairs, with the label as a scalar rather than a 1-element array.
labels = [(p, clust.labels_[np.flatnonzero(ids == p)[0]]) for p in custom_fens]
print(labels)

_, ax = plt.subplots()
highlighted_labels = []
# Plot each custom position's cluster in its own colour; colours are reused
# cyclically instead of silently dropping positions beyond len(colours).
for k, (p, c) in enumerate(labels):
    col = colours[k % len(colours)]
    if c == -1:
        # Noise (-1) is not a cluster, so highlight only the custom point itself.
        yc = y[ids == p]
    else:
        yc = y[clust.labels_ == c]
        highlighted_labels.append(c)
    ax.plot(yc[:, 0], yc[:, 1], col, label=f"Similar to {p}")

# Plot everything not highlighted above in black.
rest = ~np.isin(clust.labels_, highlighted_labels) & ~np.isin(ids, list(custom_fens))
ax.plot(y[rest, 0], y[rest, 1], "k.")

ax.set(
    title="t-SNE of puzzle tree-edit distances (OPTICS clusters)",
    xlabel="t-SNE 1",
    ylabel="t-SNE 2",
)
ax.legend();
