In [None]:
import itertools
import json
import os
import pickle
import random
from importlib import reload

import chess
import chess.engine
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import zss
from edist.multiprocess import pairwise_distances_symmetric
from edist.uted import uted
from sklearn.cluster import OPTICS, cluster_optics_dbscan, DBSCAN, HDBSCAN
from sklearn.manifold import TSNE
from tqdm import tqdm

import main
reload(main)
from main import ChessTreeNode, expand_tree

In [None]:
# Download and unpack the Stockfish engine binary (Ubuntu x86-64, AVX2 build)
# that `treeify_fen` launches. The whole step is skipped when a `stockfish`
# directory already exists; presumably that is the directory the tarball
# unpacks into, since treeify_fen runs ./stockfish/stockfish-ubuntu-x86-64-avx2
# — TODO confirm. The downloaded .tar is left on disk.
if not os.path.exists("stockfish"):
    !wget https://github.com/official-stockfish/Stockfish/releases/latest/download/stockfish-ubuntu-x86-64-avx2.tar
    !tar -xf stockfish-ubuntu-x86-64-avx2.tar

In [None]:
def convert_to_adj_node_lists(t):
    """Flatten a tree into parallel node-label and adjacency lists.

    Nodes are numbered in pre-order (a node before its children, children
    left to right). This is the (nodes, adj) form passed to `uted` in
    `compare`.

    Args:
        t: root node; every node exposes `.children` (ordered) and
            `.move_attributes` (its label).

    Returns:
        (nodes, adj) where nodes[i] is the `move_attributes` of node i and
        adj[i] lists the indices of node i's children in order.
    """
    # Explicit-stack pre-order walk; children are pushed in reverse so the
    # leftmost child is popped (visited) first.
    order = []
    pending = [t]
    while pending:
        current = pending.pop()
        order.append(current)
        pending.extend(reversed(current.children))

    # A node's index is its position in the visit order. Keyed by id()
    # because the node objects themselves need not be hashable.
    index_of = {id(node): position for position, node in enumerate(order)}
    size = len(index_of)
    labels = [None] * size
    adjacency = [None] * size
    for node in order:
        slot = index_of[id(node)]
        labels[slot] = node.move_attributes
        adjacency[slot] = [index_of[id(child)] for child in node.children]

    return labels, adjacency

In [None]:
# Read the `all` trees file, parse each entry into a ChessTreeNode, keep it
# alongside its flattened (nodes, adj) form, and cache the dict to disk.
# Each line is expected to be a JSON pair `[puzzle_id, tree]`, where
# `tree[0]` holds the ChessTreeNode constructor arguments for the root.
# Lines that are not valid JSON pairs are skipped, but counted so that data
# problems stay visible instead of being silently swallowed.
trees = {}
skipped_lines = 0
with open("../data/trees/condor/all", "r") as trees_file:
    for line in tqdm(trees_file.readlines()):
        try:
            puzzle_id, tree = json.loads(line)
        except (ValueError, TypeError):
            # ValueError: malformed JSON (JSONDecodeError) or wrong pair length;
            # TypeError: decoded value is not unpackable (e.g. a number).
            skipped_lines += 1
            continue
        if len(tree) == 0:
            continue

        t = ChessTreeNode(*tree[0])
        trees[puzzle_id] = (t, convert_to_adj_node_lists(t))

if skipped_lines:
    print(f"Skipped {skipped_lines} unparseable line(s)")

# Cache the parsed trees; the input file is already closed at this point.
with open("tree_dict.pkl", "wb") as cache_file:
    pickle.dump(trees, cache_file)

In [None]:
def compare(t1, t2):
    """Unordered tree edit distance (edist `uted`) between two `trees` entries.

    Each argument is a `(ChessTreeNode, (nodes, adj))` pair as stored in
    `trees`. Only the flattened `(nodes, adj)` half is used.
    `ChessTreeNode.compare` is passed to `uted` as its cost function;
    presumably it scores node-label substitutions, insertions and deletions
    — TODO confirm against main.py.
    """
    _root_a, (nodes_a, adj_a) = t1
    _root_b, (nodes_b, adj_b) = t2
    distance = uted(nodes_a, adj_a, nodes_b, adj_b, ChessTreeNode.compare)
    return distance

def treeify_fen(fen, engine_path="./stockfish/stockfish-ubuntu-x86-64-avx2"):
    """Expand a FEN position into a ChessTreeNode move tree using a UCI engine.

    If Black is to move, the board is mirrored first (python-chess
    `apply_mirror`), so the tree is always expanded from White's side.

    Args:
        fen: FEN string of the starting position.
        engine_path: path to the UCI engine binary. The default is the
            Stockfish binary unpacked by the download cell.

    Returns:
        (tree, flipped): the root ChessTreeNode built from the first element
        returned by `expand_tree`, and whether the board was mirrored.
    """
    board = chess.Board(fen)
    # Callers use `flipped` to render moves back in the original orientation
    # (e.g. `t.flip_san(flip)` in the custom-FEN cell).
    flipped = board.turn == chess.BLACK
    if flipped:
        board.apply_mirror()

    # One engine process per call; always shut it down, even if expansion fails.
    # NOTE: `chess.engine` must be imported explicitly (`import chess.engine`);
    # a bare `import chess` does not load that submodule.
    stockfish = chess.engine.SimpleEngine.popen_uci(engine_path, timeout=None)
    try:
        root_args = expand_tree(board.fen(), chess.Move.null(), stockfish)[0]
        return ChessTreeNode(*root_args), flipped
    finally:
        stockfish.quit()
        
df = pd.read_csv("../data/lichess/lichess_db_puzzle.csv")


def lookup_id(puzzle_id):
    """Describe a Lichess puzzle as "<id> <game url> <themes>".

    Looks `puzzle_id` up in the `PuzzleId` column of the module-level `df`.
    Ids with no matching row (e.g. the hand-written keys of `custom_fens`)
    are reported as not found. Previously they failed with a cryptic
    "can only convert an array of size 1" error.

    Raises:
        ValueError: if the id matches more than one row.
    """
    matched = df.loc[df["PuzzleId"] == puzzle_id]
    if matched.empty:
        return f"{puzzle_id} <not found in puzzle database>"
    return f"{puzzle_id} {matched['GameUrl'].item()} {matched['Themes'].item()}"
    

In [None]:
# Reload the cached {puzzle_id: (ChessTreeNode, (nodes, adj))} dict written by
# the parsing cell, so later cells can run without re-parsing.
# NOTE: pickle can execute arbitrary code on load — only load a cache file
# that this notebook produced.
with open("tree_dict.pkl", "rb") as cache_file:
    trees = pickle.load(cache_file)

In [26]:
# Hand-picked query positions (lichess puzzles plus hand-made examples). Each
# FEN is expanded with the engine into a move tree and added to `trees` under
# its descriptive key; `custom_flipped` records whether the board was mirrored.
# NOTE(review): example4's castling field "q" looks inconsistent with the
# black king on g8 — confirm the FEN.
custom_fens = {
    "backrank M1 N5uON https://lichess.org/UVbVId2k#52": "2krr3/1p3ppp/RP3n2/2NP1b2/1RP2B2/3p4/5PPP/6K1 b - - 0 27",
    "N fork RK 4BY3Z https://lichess.org/ed0T8UXV#96": "8/1N6/1K6/4k1p1/2P1Pp1p/4n2P/3R2P1/8 b - - 0 49",
    "Greek gift yv6lb https://lichess.org/utt3XQm1/black#19": "r1bq1rk1/pp2nppp/1bn1p3/1N1pP3/1P6/P2B1N2/2P2PPP/R1BQK2R w KQ - 3 11",
    "example1": "6k1/5ppp/8/8/8/8/r4PPP/1R4K1 w - - 0 1",
    "example2": "6k1/5ppp/1p1Q4/p3p1B1/Pn4P1/1q6/1Pr4P/K6R w - - 1 2",
    "example3": "4r1k1/1b3pp1/4p3/p2r4/7R/2B1Q1PP/P1P1RP1K/1q6 w - - 0 1",
    "example4": "r5k1/5pp1/8/3p3R/2q4P/PbB2P2/1P1Q2P1/K7 w q - 0 1",
    "example5": "2R5/4bppk/1p1p4/5R1P/4PQ2/5P2/r4q1P/7K w - - 5 50",
    "example6": "5R2/bp4pk/2n3p1/P7/P1q3bP/6P1/3Q3K/1R6 w - - 1 32",
}
custom_flipped = {}
for my_id, my_fen in custom_fens.items():
    t, flip = treeify_fen(my_fen)
    trees[my_id] = (t, convert_to_adj_node_lists(t))
    custom_flipped[my_id] = flip
    print("=" * 50)
    print(my_id, my_fen, sep="\n")
    print(t.flip_san(flip))
    print("=" * 50)

backrank M1 N5uON https://lichess.org/UVbVId2k#52
2krr3/1p3ppp/RP3n2/2NP1b2/1RP2B2/3p4/5PPP/6K1 b - - 0 27
Re1# 140262941889680 1/0/0/0

N fork RK 4BY3Z https://lichess.org/ed0T8UXV#96
8/1N6/1K6/4k1p1/2P1Pp1p/4n2P/3R2P1/8 b - - 0 49
Nxc4+ 140263026994464 2/1/0/0
  Kc3 140264212933728 0/1/0/0
    Nxd7 140263112060624 1/0/0/0
  Kc2 140264212938576 0/1/0/0
    Nxd7 140263112059712 1/0/0/0
  Kb4 140264212938192 1/0/0/0
    Nxd7 140263112062064 1/0/0/0
  Kc4 140264212929216 1/0/0/1
    Nxd7 140263112059184 1/0/0/0

Greek gift yv6lb https://lichess.org/utt3XQm1/black#19
r1bq1rk1/pp2nppp/1bn1p3/1N1pP3/1P6/P2B1N2/2P2PPP/R1BQK2R w KQ - 3 11
Bxh7+ 140262941895536 1/1/1/0
  Kxh7 140261897071088 0/1/0/0
    Ng5+ 140261897067728 3/0/0/1
  Kh8 140261897067776 1/1/0/1
    Ng5 140261675560704 2/1/0/1

example1
6k1/5ppp/8/8/8/8/r4PPP/1R4K1 w - - 0 1
Rb8# 140260199521728 1/0/0/0

example2
6k1/5ppp/1p1Q4/p3p1B1/Pn4P1/1q6/1Pr4P/K6R w - - 1 2
Qd8# 140263070113648 2/1/0/1

example3
4r1k1/1b3pp1/4p3/p2r4/7R/

In [30]:
# Pairwise tree-edit distances between the six hand-made example positions
# (row/column k-1 corresponds to "example{k}").
example_keys = [f"example{k}" for k in range(1, 7)]
h = np.array(
    [[compare(trees[row], trees[col]) for col in example_keys] for row in example_keys],
    dtype=float,
)

print(h)

[[   0.    232.    978.5   982.5  1422.   1714.  ]
 [ 232.      0.   1086.5  1082.5  1278.   1586.  ]
 [ 978.5  1086.5     0.     46.5   789.   1033.15]
 [ 982.5  1082.5    46.5     0.    795.   1035.15]
 [1422.   1278.    789.    795.      0.    417.  ]
 [1714.   1586.   1033.15 1035.15  417.      0.  ]]


In [None]:
# Find the TOP_K nearest puzzles (by tree edit distance) to each custom
# position. Only the first N_CANDIDATES entries of `trees` are searched, and
# the custom positions themselves are excluded.
N_CANDIDATES = 10000
TOP_K = 5

# Loop-invariant: build the candidate list once, instead of materialising all
# of `trees.items()` again for every custom position.
candidates = [
    (pid, t)
    for pid, t in itertools.islice(trees.items(), N_CANDIDATES)
    if pid not in custom_fens
]

for my_id, my_fen in custom_fens.items():
    print("=" * 50)
    print(my_id)
    print(my_fen)
    print("=" * 50)
    my_tree = trees[my_id]
    # Puzzle ids are unique, so sorting on (dist, pid) never ties.
    distances = sorted(
        (compare(t, my_tree), pid) for pid, t in tqdm(candidates, leave=False)
    )

    for dist, pid in distances[:TOP_K]:
        print(f"(dist: {dist}) {lookup_id(pid)}")
        print("-" * 50)

In [None]:
# Draw a reproducible sample of puzzle ids to cluster, then append the custom
# positions (last) so they also get a row/column in the distance matrix.
N_SAMPLE = 10000
np.random.seed(0)
puzzle_ids = [k for k in trees if k not in custom_fens]
# Sample WITHOUT replacement: np.random.choice defaults to replace=True, which
# would put zero-distance duplicates of the same puzzle into D and distort the
# density-based clustering. The sample size is capped at the population size.
ids = np.random.choice(puzzle_ids, min(N_SAMPLE, len(puzzle_ids)), replace=False)
ids = np.concatenate((ids, list(custom_fens.keys())))
trees_short = [trees[k] for k in ids]

In [None]:
# Pairwise distance matrix over the sampled trees (custom positions are the
# last rows/columns), using `compare` (uted) as the distance. This is the
# expensive step of the notebook. The name suggests only one triangle is
# computed and then mirrored, presumably across worker processes
# (edist.multiprocess) — TODO confirm. Later cells index D as a 2-D array.
D = pairwise_distances_symmetric(trees_short, compare)

In [None]:
# Density-based clustering on the precomputed tree-edit-distance matrix.
# (DBSCAN(metric="precomputed", eps=250, min_samples=20) was also tried.)
clusterer = HDBSCAN(metric="precomputed")
clust = clusterer.fit(D)
cluster_labels = set(clust.labels_)
# Label -1 is HDBSCAN's noise marker, not a real cluster, so don't count it.
n_clusters = len(cluster_labels - {-1})
n_noise = int(np.sum(clust.labels_ == -1))
print(f"{n_clusters} clusters (+ {n_noise} noise points)")
print(cluster_labels)

In [None]:
# OPTICS reachability plot. `clust` is an HDBSCAN model, which has no
# `reachability_`/`ordering_` attributes (accessing them raised AttributeError),
# so fit OPTICS on the same precomputed distances to obtain them. The points
# are still coloured by the HDBSCAN cluster labels from `clust`.
reach_model = OPTICS(metric="precomputed").fit(D)
space = np.arange(len(D))
reachability = reach_model.reachability_[reach_model.ordering_]
labels = clust.labels_[reach_model.ordering_]

_, ax = plt.subplots(figsize=(18, 6))
# Iterate over the actual labels: `range(len(set) - 1)` silently dropped the
# last cluster whenever there were no noise points.
for c in sorted(set(clust.labels_) - {-1}):
    ax.plot(space[labels == c], reachability[labels == c], ".", alpha=0.3)

ax.plot(space[labels == -1], reachability[labels == -1], "k.", alpha=0.3)
ax.set(
    title="OPTICS reachability (coloured by HDBSCAN cluster)",
    xlabel="OPTICS ordering",
    ylabel="reachability distance",
);

    

In [None]:
# For each custom position, show its HDBSCAN cluster label and list the
# sampled puzzles that share that cluster, with their distance to it.
print("Cluster of puzzles from HDBSCAN")
print([(p, int(clust.labels_[np.flatnonzero(ids == p)[0]])) for p in custom_fens])

for p, fen in custom_fens.items():
    print(p, fen)
    idx = np.flatnonzero(ids == p)[0]
    label = clust.labels_[idx]
    if label == -1:
        # -1 marks noise, not a cluster; listing every noise point would flood the output.
        print("  (noise: not assigned to any cluster)")
        continue
    for close_id in np.flatnonzero(clust.labels_ == label):
        print(f"{ids[close_id]}, {D[idx, close_id]}")


In [None]:
# Print up to MAX_SHOWN_PER_CLUSTER representative lichess puzzles per cluster
# (label -1 is noise), in a deterministic label order.
MAX_SHOWN_PER_CLUSTER = 7
for c in sorted(set(clust.labels_)):
    print(f"label {c}")
    # Drop custom positions BEFORE truncating, so each cluster still shows up
    # to MAX_SHOWN_PER_CLUSTER real puzzles (customs are not in the puzzle CSV).
    members = [pid for pid in ids[clust.labels_ == c] if pid not in custom_fens]
    for pid in members[:MAX_SHOWN_PER_CLUSTER]:
        print(lookup_id(pid))
    print()

In [None]:
# 2-D t-SNE embedding of the precomputed distance matrix, coloured by HDBSCAN
# cluster. Clusters containing a custom position are highlighted and named in
# the legend. Noise (-1) is drawn as faint black dots and all other clusters as
# translucent crosses.
tsne = TSNE(metric="precomputed", init="random", random_state=0, perplexity=20)
y = tsne.fit_transform(D)

# (custom position, HDBSCAN label) for every custom position.
custom_labels = [(p, int(clust.labels_[np.flatnonzero(ids == p)[0]])) for p in custom_fens]
print(custom_labels)

_, ax = plt.subplots(figsize=(16, 12))

# Noise points (label -1): faint black dots.
yc = y[clust.labels_ == -1]
ax.plot(yc[:, 0], yc[:, 1], "k.", alpha=0.1)

# Group custom positions by their (non-noise) cluster so a shared cluster is
# drawn once. The old zip(labels, colours) silently dropped every custom
# position past the 4th, and those clusters were then never plotted at all.
highlighted = {}
for p, c in custom_labels:
    if c != -1:
        highlighted.setdefault(c, []).append(p)

colours = ["r.", "g.", "b.", "c.", "m.", "y."]
for k, (c, members) in enumerate(highlighted.items()):
    yc = y[clust.labels_ == c]
    ax.plot(
        yc[:, 0], yc[:, 1], colours[k % len(colours)],
        label=f"{c}: Similar to {', '.join(members)}",
    )

# Every remaining cluster: translucent crosses.
for c in set(clust.labels_) - set(highlighted) - {-1}:
    yc = y[clust.labels_ == c]
    ax.plot(yc[:, 0], yc[:, 1], "x", alpha=0.3)

ax.set(title="t-SNE of tree edit distances (HDBSCAN clusters)")
if highlighted:
    ax.legend();
