In [None]:
import itertools
import json
import os
import pickle
import random
from importlib import reload

import chess
import chess.engine
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import zss
from edist.multiprocess import pairwise_distances_symmetric
from edist.uted import uted
from sklearn.cluster import OPTICS, cluster_optics_dbscan, DBSCAN, HDBSCAN, AgglomerativeClustering
from sklearn.manifold import TSNE, MDS
from tqdm import tqdm

import main
reload(main)
from main import ChessTreeNode, expand_tree

In [None]:
# Download and extract Stockfish (treeify_fen launches the binary from ./stockfish/).
# The download is skipped when a `stockfish` path already exists, so re-running
# this cell is cheap.
# NOTE(review): presumably the archive unpacks into a `stockfish/` directory that
# matches the binary path used by treeify_fen — confirm. The downloaded .tar is
# not deleted afterwards, and the AVX2 build presumably needs an AVX2-capable CPU.
if not os.path.exists("stockfish"):
    !wget https://github.com/official-stockfish/Stockfish/releases/latest/download/stockfish-ubuntu-x86-64-avx2.tar
    !tar -xf stockfish-ubuntu-x86-64-avx2.tar

In [None]:
def convert_to_adj_node_lists(t):
    """Flatten a tree into parallel (node-label, adjacency) lists for edist.

    Nodes are numbered in pre-order: a node comes before its descendants, and
    children are visited left to right.

    Args:
        t: root node; every node exposes `children` (ordered) and
            `move_attributes`.

    Returns:
        (nodes, adj) where nodes[i] is the `move_attributes` of node i and
        adj[i] is the list of indices of node i's children, in child order.
    """
    # Indices are keyed by object identity, so this assumes each node object
    # occurs only once in the tree.
    order = []
    index_of = {}
    pending = [t]
    while pending:
        node = pending.pop()
        order.append(node)
        index_of[id(node)] = len(order) - 1
        # Push children in reverse so the leftmost child is popped first.
        pending.extend(reversed(node.children))

    labels = [None] * len(index_of)
    adjacency = [None] * len(index_of)
    for node in order:
        position = index_of[id(node)]
        labels[position] = node.move_attributes
        adjacency[position] = [index_of[id(child)] for child in node.children]

    return labels, adjacency

In [None]:
# Read `all_trees`, parse to ChessTreeNode, and load to memory.
# Each line is expected to be a JSON `[puzzle_id, tree]` pair. Malformed lines
# and empty trees are skipped, but counted and reported instead of silently dropped.
trees = {}
n_malformed = 0
n_empty = 0
with open("../data/trees/condor/all", "r") as trees_file:
    for line in tqdm(trees_file.readlines()):
        try:
            puzzle_id, tree = json.loads(line)
        except (ValueError, TypeError):
            # ValueError: invalid JSON (JSONDecodeError) or wrong number of
            # elements to unpack; TypeError: a non-iterable JSON value.
            n_malformed += 1
            continue
        if not tree:
            # Empty list (or JSON null): nothing to build a tree from.
            n_empty += 1
            continue

        t = ChessTreeNode(*tree[0])
        trees[puzzle_id] = (t, convert_to_adj_node_lists(t))

print(f"Loaded {len(trees)} trees; skipped {n_malformed} malformed lines and {n_empty} empty trees.")

# Cache the parsed trees so a later session can reload them without re-parsing.
with open("tree_dict.pkl", "wb") as cache_file:
    pickle.dump(trees, cache_file)

In [None]:
def compare(t1, t2):
    """Tree edit distance between two `trees` entries, via edist's `uted`.

    Each argument is a `(tree, (nodes, adj))` pair as stored in `trees`; only
    the flattened node/adjacency lists are used. ChessTreeNode.compare is
    passed to `uted` as the node cost function.
    """
    (_, (x_nodes, x_adj)), (_, (y_nodes, y_adj)) = t1, t2
    return uted(x_nodes, x_adj, y_nodes, y_adj, ChessTreeNode.compare)

def treeify_fen(fen, engine_path="./stockfish/stockfish-ubuntu-x86-64-avx2"):
    """Expand a FEN position into a ChessTreeNode tree using a UCI engine.

    If Black is to move, the board is mirrored (`board.apply_mirror()`), which
    also swaps colours so White is to move. The returned flag records whether
    this happened so callers (e.g. `flip_san`) can undo it.

    Args:
        fen: position in FEN notation.
        engine_path: path of the UCI engine binary; defaults to the Stockfish
            build downloaded earlier in this notebook.

    Returns:
        (tree, flipped): the root ChessTreeNode and whether the board was mirrored.
    """
    board = chess.Board(fen)
    flipped = board.turn == chess.BLACK
    if flipped:
        board.apply_mirror()

    # A fresh engine process per call; always shut it down, even if expansion fails.
    stockfish = chess.engine.SimpleEngine.popen_uci(engine_path, timeout=None)
    try:
        return ChessTreeNode(*expand_tree(board.fen(), chess.Move.null(), stockfish)[0]), flipped
    finally:
        stockfish.quit()
        
# Lichess puzzle database; lookup_id reads its PuzzleId, GameUrl and Themes
# columns to describe puzzles by id.
df = pd.read_csv("../data/lichess/lichess_db_puzzle.csv")
def lookup_id(puzzle_id, puzzles=None):
    """Describe a Lichess puzzle as "<puzzle_id> <GameUrl> <Themes>".

    Args:
        puzzle_id: value to match in the `PuzzleId` column.
        puzzles: table with PuzzleId/GameUrl/Themes columns; defaults to the
            global `df` loaded from the Lichess puzzle CSV.

    Raises:
        ValueError: if puzzle_id matches no row or more than one row.
    """
    if puzzles is None:
        puzzles = df
    matched = puzzles.loc[puzzles["PuzzleId"] == puzzle_id]
    # Check explicitly so a missing/duplicate id gives a readable error instead
    # of .item()'s bare "can only convert an array of size 1".
    if len(matched) != 1:
        raise ValueError(f"expected exactly one puzzle with id {puzzle_id!r}, found {len(matched)}")
    return f"{puzzle_id} {matched['GameUrl'].item()} {matched['Themes'].item()}"
    

In [None]:
# Reload the parsed-tree cache written by the loading cell. Unpickling can run
# arbitrary code, so only load a tree_dict.pkl this notebook produced itself.
with open("tree_dict.pkl", "rb") as cache_file:
    trees = pickle.load(cache_file)

In [None]:
# Hand-picked positions to compare against the puzzle database. Each one is
# expanded with the engine and added to `trees` under its descriptive name;
# custom_flipped records whether treeify_fen mirrored the board.
custom_fens = {
    "Rook Backrank M1": "6k1/5ppp/8/8/8/8/r4PPP/1R4K1 w - - 0 1",
    "N forks RK": "8/1N6/1K6/4k1p1/2P1Pp1p/4n2P/3R2P1/8 b - - 0 49",
    "Greek gift": "r1bq1rk1/pp2nppp/1bn1p3/1N1pP3/1P6/P2B1N2/2P2PPP/R1BQK2R w KQ - 3 11",
    "Rook sacrifice, Queen+Bishop M3": "4r1k1/1b3pp1/4p3/p2r4/7R/2B1Q1PP/P1P1RP1K/1q6 w - - 0 1",
}
banner = "=" * 50
custom_flipped = {}
for my_id, my_fen in custom_fens.items():
    t, flip = treeify_fen(my_fen)
    trees[my_id] = (t, convert_to_adj_node_lists(t))
    custom_flipped[my_id] = flip
    print(banner)
    print(my_id)
    print(my_fen)
    print(t.flip_san(flip))
    print(banner)

In [None]:
# Find the closest puzzles for these fens: compare each custom position with
# up to N_CANDIDATES database trees and print the N_NEAREST closest ones.
N_CANDIDATES = 1000  # database trees scanned per position (one compare() call each)
N_NEAREST = 100      # closest puzzles printed per position

# Database entries only (custom positions excluded), built once so every
# custom position is compared against the same candidate set.
candidates = list(itertools.islice(
    ((pid, t) for pid, t in trees.items() if pid not in custom_fens),
    N_CANDIDATES,
))

for my_id, my_fen in custom_fens.items():
    print("="*50)
    print(my_id)
    print(my_fen)
    print("="*50)
    my_tree = trees[my_id]
    distances = [(compare(t, my_tree), pid, t) for pid, t in tqdm(candidates, leave=False)]
    # Order by (distance, puzzle id) only, so ties never fall through to
    # comparing the tree objects themselves.
    distances.sort(key=lambda entry: (entry[0], entry[1]))

    for dist, pid, t in distances[:N_NEAREST]:
        print(f"(dist: {dist}) {lookup_id(pid)}")
        print("-"*50)


In [None]:
# Random sample of database puzzles to cluster, with the custom positions
# appended at the end. Sample WITHOUT replacement: duplicate ids would create
# zero-distance pairs in D and spurious dense clusters.
N_SAMPLE = 20000
np.random.seed(0)
candidate_ids = [k for k in trees if k not in custom_fens]
ids = np.random.choice(candidate_ids, min(N_SAMPLE, len(candidate_ids)), replace=False)
ids = np.concatenate((ids, list(custom_fens.keys())))
assert len(set(ids)) == len(ids), "sampled ids must be unique"
trees_short = [trees[k] for k in ids]

In [None]:
# Pairwise tree-edit distances between all sampled trees, using `compare`.
# This is the expensive step: the number of compare() calls grows
# quadratically with len(trees_short). The cells below treat D as a
# precomputed distance matrix (metric="precomputed").
# NOTE(review): assumes D is a square array aligned with `ids` — confirm.
D = pairwise_distances_symmetric(trees_short, compare)


In [None]:
# HDBSCAN on the precomputed edit-distance matrix. `clust` is an HDBSCAN
# result (labels_ etc.), not an OPTICS one.
hdbscan_model = HDBSCAN(metric="precomputed", min_cluster_size=7, cluster_selection_epsilon=100)
clust = hdbscan_model.fit(D)
distinct_labels = set(clust.labels_)
# Count includes -1 (noise) when present.
print(len(distinct_labels))
print(distinct_labels)

In [None]:
def analyse_clustering_info(cl, sample_ids=None, preset_ids=None):
    """Print a one-line, space-separated summary of a fitted clustering.

    Fields, in order:
      * number of clusters (labels other than -1, the noise label);
      * for each preset puzzle: True if it was assigned to a cluster, else False;
      * fraction of all points labelled noise, rounded to 4 places;
      * mean cluster size, rounded to 2 places;
      * cluster-size percentiles 0/25/50/75/100, truncated to int.
    If no clusters were found, the mean and percentiles are printed as "-".

    Args:
        cl: fitted estimator exposing `labels_`, one label per row of D.
        sample_ids: ids aligned with `cl.labels_`; defaults to the global `ids`.
        preset_ids: ids whose assignment is reported; defaults to the keys of
            the global `custom_fens`.
    """
    if sample_ids is None:
        sample_ids = ids
    if preset_ids is None:
        preset_ids = custom_fens.keys()
    labels = np.asarray(cl.labels_)
    sample_ids = np.asarray(sample_ids)

    # Sizes of all real clusters in one pass (noise excluded).
    _, sizes = np.unique(labels[labels != -1], return_counts=True)

    fields = [len(sizes)]
    for k in preset_ids:
        fields.append(labels[np.where(sample_ids == k)[0]].item() != -1)
    fields.append(round((labels == -1).sum() / len(sample_ids), 4))
    if len(sizes):
        fields.append(round(sizes.mean(), 2))
        fields.extend(int(np.percentile(sizes, perc)) for perc in (0, 25, 50, 75, 100))
    else:
        # No clusters: mean/percentiles are undefined (np.percentile fails on an
        # empty array), so print placeholders instead of aborting the sweep.
        fields.extend(["-"] * 6)
    print(*fields)

# Hyper-parameter sweeps over three clustering algorithms. Each configuration
# prints its name and parameters, followed by analyse_clustering_info's summary.
for linkage, dist in itertools.product(["average", "complete", "single"], [100, 250, 500]):
    print(f"AgglomerativeClustering {linkage} {dist}", end=" ")
    model = AgglomerativeClustering(metric="precomputed", n_clusters=None, linkage=linkage, distance_threshold=dist)
    analyse_clustering_info(model.fit(D))

for eps, min_samples in itertools.product([50, 100, 250, 500], [3, 5, 7]):
    print(f"DBSCAN {eps} {min_samples}", end=" ")
    model = DBSCAN(metric="precomputed", eps=eps, min_samples=min_samples)
    analyse_clustering_info(model.fit(D))

for eps, size in itertools.product([0, 25, 50, 100], [3, 5, 7]):
    print(f"HDBSCAN {eps} {size}", end=" ")
    model = HDBSCAN(metric="precomputed", cluster_selection_epsilon=eps, min_cluster_size=size)
    analyse_clustering_info(model.fit(D))

In [None]:
# OPTICS reachability plot. `clust` at this point is an HDBSCAN result, which
# has no reachability_/ordering_ attributes, so OPTICS is fitted explicitly here.
OPTICS_MIN_SAMPLES = 7  # TODO(review): tune; mirrors min_cluster_size=7 of the HDBSCAN fit above
optics_clust = OPTICS(metric="precomputed", min_samples=OPTICS_MIN_SAMPLES).fit(D)

space = np.arange(len(D))
reachability = optics_clust.reachability_[optics_clust.ordering_]
labels = optics_clust.labels_[optics_clust.ordering_]

fig, ax = plt.subplots(figsize=(18, 6))
# One colour per cluster. Iterate the actual labels so the last cluster is not
# dropped when no point is labelled noise (-1).
for cluster_label in sorted(set(labels) - {-1}):
    mask = labels == cluster_label
    ax.plot(space[mask], reachability[mask], ".", alpha=0.3)

# Noise points in black.
ax.plot(space[labels == -1], reachability[labels == -1], "k.", alpha=0.3)
ax.set(title="OPTICS reachability", xlabel="OPTICS ordering", ylabel="Reachability distance")
plt.show()

In [None]:
# For each preset puzzle, list the members of its cluster in `clust` (the
# HDBSCAN fit above) together with their distance to the preset.
print("Cluster of preset puzzles")
print([(p, clust.labels_[np.where(ids==p)[0]]) for p in custom_fens])

for p, f in custom_fens.items():
    print(p, f)
    idx = np.where(ids == p)[0].item()
    label = clust.labels_[idx]
    if label == -1:
        # Noise is not a cluster; listing every noise point would be misleading.
        print("  (noise: not assigned to any cluster)")
        continue
    for close_id in np.nonzero(clust.labels_ == label)[0]:
        print(f"{ids[close_id]}, {D[idx, close_id]}")


In [None]:
# Show up to seven example puzzles for every label (including -1, noise).
# Preset puzzles are skipped, so fewer than seven lines may be printed.
for c in set(clust.labels_):
    print(f"label {c}")
    examples = ids[clust.labels_ == c][:7]
    for pid in examples:
        if pid not in custom_fens:
            print(lookup_id(pid))
    print()

In [None]:
# t-SNE embeddings of the precomputed distance matrix at several perplexities,
# all with the same random_state so they are comparable.
y5, y10, y25, y50 = (
    TSNE(metric="precomputed", init="random", random_state=0, perplexity=perplexity).fit_transform(D)
    for perplexity in (5, 10, 25, 50)
)

In [None]:
# Re-fit the clustering visualised below (this rebinds `clust`) and plot it on
# a t-SNE embedding, highlighting the clusters that contain preset puzzles.
clust = HDBSCAN(metric="precomputed", cluster_selection_epsilon=0, min_cluster_size=3).fit(D)
y = y5
name = "HDBSCAN, cluster selection epsilon 0, minimum samples 3 (TSNE perplexity 5)"
filename = "hdbscan_1_5.png"
out_dir = "tsne"

colours = ["r.", "g.", "b.", "c."]
# Cluster label of each preset puzzle, as a plain int.
labels = [(p, int(clust.labels_[np.where(ids == p)[0]].item())) for p in custom_fens]
print(labels)
preset_labels = {c for _, c in labels}

fig, ax = plt.subplots(figsize=(16, 12))

# Noise points (label -1) as faint black dots.
yc = y[clust.labels_ == -1]
ax.plot(yc[:, 0], yc[:, 1], "k.", alpha=0.1)

# Clusters containing a preset puzzle, each in its own colour (noise skipped).
for (p, c), col in zip(labels, colours):
    if c == -1:
        continue
    yc = y[clust.labels_ == c]
    ax.plot(yc[:, 0], yc[:, 1], col, label=f"{c}: Cluster of {p}")

# Every other cluster as faint crosses in the default colour cycle.
for c in set(clust.labels_):
    if c in preset_labels or c == -1:
        continue
    yc = y[clust.labels_ == c]
    ax.plot(yc[:, 0], yc[:, 1], "x", alpha=0.3)

ax.legend()
ax.set_title(name)
# savefig does not create missing directories.
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, filename), dpi=250)
plt.show()
