In [None]:
import json
import pickle
import random

import chess
import chess.engine
import matplotlib.pyplot as plt
import numpy as np
import zss
from sklearn.cluster import OPTICS, cluster_optics_dbscan
from sklearn.manifold import TSNE
from tqdm import tqdm

from main import ChessTreeNode, expand_tree

In [None]:
# Read `all_trees` (one JSON `[puzzle_id, tree]` pair per line), parse each
# non-empty tree into a ChessTreeNode, and cache the dict to `tree_dict.pkl`.
trees = {}
with open("all_trees", "r") as trees_file:
    # Iterate the file lazily instead of materialising it with readlines().
    for line in tqdm(trees_file):
        puzzle_id, tree = json.loads(line)
        if len(tree) == 0:
            # No tree was generated for this puzzle; skip it.
            continue
        # Presumably tree[0] is the root's constructor args (same shape that
        # expand_tree(...)[0] yields) -- confirm against main.ChessTreeNode.
        trees[puzzle_id] = ChessTreeNode(*tree[0])

# Write the cache only after the input file has been fully read and closed.
with open("tree_dict.pkl", "wb") as pickle_file:
    pickle.dump(trees, pickle_file)

In [None]:
def compare(t1, t2):
    """Return the Zhang-Shasha tree edit distance between two ChessTreeNodes.

    Node children and labels are read with ``ChessTreeNode.get_children`` and
    ``ChessTreeNode.get_label``; ``ChessTreeNode.compare`` supplies the cost of
    relabelling one node into another.
    """
    children_of = ChessTreeNode.get_children
    label_of = ChessTreeNode.get_label
    relabel_cost = ChessTreeNode.compare
    return zss.simple_distance(t1, t2, children_of, label_of, relabel_cost)

def treeify_fen(fen, engine_path="./stockfish"):
    """Build a ChessTreeNode for the position ``fen`` using Stockfish.

    If it is Black to move, the board is mirrored with ``Board.apply_mirror``
    first, so the tree is always built for a White-to-move position.

    Args:
        fen: FEN string of the position to analyse.
        engine_path: Path to a UCI engine binary (default ``"./stockfish"``).

    Returns:
        ``(tree, flipped)`` where ``tree`` is the root ChessTreeNode and
        ``flipped`` is True when the board was mirrored. Callers presumably
        pass ``flipped`` to ``ChessTreeNode.flip_san`` to show moves in the
        original orientation.
    """
    board = chess.Board(fen)
    flipped = board.turn == chess.BLACK
    if flipped:
        board.apply_mirror()

    # timeout=None: wait as long as needed for the engine to start up.
    # Requires `import chess.engine`; `import chess` alone does not load it.
    stockfish = chess.engine.SimpleEngine.popen_uci(engine_path, timeout=None)
    try:
        # NOTE(review): assumes expand_tree(...)[0] is the root's constructor
        # args for ChessTreeNode -- confirm against main.expand_tree.
        root_args = expand_tree(board.fen(), chess.Move.null(), stockfish)[0]
        return ChessTreeNode(*root_args), flipped
    finally:
        # Always shut the engine process down, even if tree expansion fails.
        stockfish.quit()

In [None]:
# Load the cached {puzzle_id: ChessTreeNode} dict, replacing any `trees` built above.
# NOTE(review): this reads `tree_dict_short.pkl`, not the `tree_dict.pkl` written by
# the previous cell -- presumably a smaller subset produced elsewhere; confirm.
# pickle.load can execute arbitrary code: only load files you generated yourself.
with open("tree_dict_short.pkl", "rb") as pickle_file:
    trees = pickle.load(pickle_file)

In [None]:
# Hand-picked positions (name -> FEN) to find similar puzzles for.
custom_fens = {
    "smothered M2": "r3r2k/ppp2Qbp/2bp2p1/5PN1/4P3/1B2q2P/P5PK/8 w - - 0 25",
    "backrank M2": "6k1/1rp2pp1/5n1p/2B5/2RP4/4P3/P4PPP/6K1 b - - 0 24",
    "N forks QK": "2kr3r/2p5/ppq2n1p/3Np1p1/P3P3/1P6/Q2b1PPP/3RR1K1 w - - 0 22",
    "pinned pawn defender": "4r1k1/pb3p1p/1bp1q1p1/3pB3/3P4/1NP2PQP/PP4P1/R5K1 b - - 1 20",
}

# Build a tree for each custom position, store it in `trees` under its name,
# and remember whether the board had to be mirrored.
custom_flipped = {}
banner = "=" * 50
for name, fen in custom_fens.items():
    tree, was_flipped = treeify_fen(fen)
    trees[name] = tree
    custom_flipped[name] = was_flipped
    print(banner, name, fen, tree.flip_san(was_flipped), banner, sep="\n")

In [None]:
# Find the TOP_K closest puzzles (by tree edit distance) to each custom FEN.
TOP_K = 5
MAX_CANDIDATES = 10000  # search only this many puzzles to bound run time

# Candidates exclude the custom FENs themselves (they live in `trees` too);
# otherwise a query matches itself at distance 0 and wastes a result slot.
# Built once, outside the loop, since it does not depend on the query.
candidates = [(pid, t) for pid, t in trees.items() if pid not in custom_fens][:MAX_CANDIDATES]

for my_id, my_fen in custom_fens.items():
    print("="*50)
    print(my_id)
    print(my_fen)
    print("="*50)
    my_tree = trees[my_id]
    # (distance, puzzle_id) pairs; ids are unique, so sorting never ties fully.
    distances = [(compare(t, my_tree), pid) for pid, t in tqdm(candidates, leave=False)]
    distances.sort()

    for dist, pid in distances[:TOP_K]:
        print(f"{pid} (DISTANCE {dist})")
        print("-"*50)

In [None]:
# Sample puzzles for clustering, then append the custom FENs.
# Sample *without* replacement: duplicates would add zero-distance pairs to the
# distance matrix and create artificial clusters.
N_SAMPLE = 400
np.random.seed(0)
candidate_ids = [k for k in trees if k not in custom_fens]
ids = np.random.choice(candidate_ids, min(N_SAMPLE, len(candidate_ids)), replace=False)
ids = np.concatenate((ids, list(custom_fens.keys())))

In [None]:
N = len(ids)
# Pairwise tree-edit-distance matrix. Distances are treated as symmetric, so only
# the strict lower triangle (j < i) is computed and then mirrored with d + d.T.
# The diagonal stays 0.
d = np.zeros((N, N))
# Look each tree up once instead of inside the O(N^2) loop.
sample_trees = [trees[pid] for pid in ids]
# One progress bar whose total matches the number of compare() calls.
with tqdm(total=N * (N - 1) // 2) as pbar:
    for i in range(N):
        for j in range(i):
            d[i, j] = compare(sample_trees[i], sample_trees[j])
            pbar.update(1)
d = d + d.T


In [None]:
# Density-based clustering on the precomputed distance matrix `d`, requiring
# at least 5 puzzles per cluster. `fit` returns the estimator itself, so
# `clust` and `optics` are the same fitted object.
optics = OPTICS(min_cluster_size=5, metric="precomputed")
optics.fit(d)
clust = optics

In [None]:
print("Cluster of puzzles from OPTICS")
for name, fen in custom_fens.items():
    print(name, fen)
    # Row of this custom FEN in `ids` / `d`, as a scalar index.
    idx = int(np.flatnonzero(ids == name)[0])
    label = clust.labels_[idx]
    if label == -1:
        # OPTICS labels unclustered points -1; listing all of them would
        # misleadingly present every noise point as this FEN's cluster.
        print("  (noise: not assigned to any cluster)")
        continue
    for close_id in np.flatnonzero(clust.labels_ == label):
        print(f"{ids[close_id]}, {d[idx, close_id]}")


In [None]:
# Embed the precomputed distance matrix in 2-D with t-SNE. init="random" is
# required because PCA initialisation is not allowed with metric="precomputed".
tsne = TSNE(metric="precomputed", init="random", random_state=0, perplexity=4)
y = tsne.fit_transform(d)

# One colour per custom FEN, indexed modulo the list length so that no FEN is
# silently dropped (the old 3-colour list + zip lost the 4th FEN's cluster).
colours = ["r.", "g.", "b.", "m.", "c.", "y."]
# (custom FEN name, OPTICS label of that FEN) as scalars; -1 means noise.
labels = [(p, int(clust.labels_[np.flatnonzero(ids == p)[0]])) for p in custom_fens]
print(labels)

# Clusters drawn in a special colour. Noise (-1) is never highlighted wholesale.
highlighted = [c for _, c in labels if c != -1]

fig, ax = plt.subplots()
# Plot the rest (every non-highlighted cluster, plus noise) in black.
rest = ~np.isin(clust.labels_, highlighted)
ax.plot(y[rest, 0], y[rest, 1], "k.")

# Plot our fens' clusters with special colours; a noise FEN is drawn on its own.
for i, (p, c) in enumerate(labels):
    col = colours[i % len(colours)]
    if c == -1:
        members = ids == p
        legend_text = f"{p} (noise)"
    else:
        members = clust.labels_ == c
        legend_text = f"Similar to {p}"
    ax.plot(y[members, 0], y[members, 1], col, label=legend_text)

ax.set_title("t-SNE of puzzle tree edit distances (OPTICS clusters)")
ax.legend();
