In [None]:
from importlib import reload
from sklearn.cluster import OPTICS, cluster_optics_dbscan, DBSCAN
from sklearn.manifold import TSNE
from tqdm import tqdm
import chess
import json
import main
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import random
import zss
reload(main)
from main import ChessTreeNode, expand_tree

In [None]:
# Download and extract the Stockfish engine release archive (Ubuntu x86-64,
# AVX2 build, per the file name). Skipped when a ./stockfish path already
# exists, so re-running the notebook does not re-download.
# NOTE(review): treeify_fen launches "./stockfish/stockfish-ubuntu-x86-64-avx2",
# which assumes the tar unpacks into a top-level "stockfish/" directory — confirm.
if not os.path.exists("stockfish"):
    !wget https://github.com/official-stockfish/Stockfish/releases/latest/download/stockfish-ubuntu-x86-64-avx2.tar
    !tar -xf stockfish-ubuntu-x86-64-avx2.tar

In [None]:
# Read the condor tree dump, parse each record into a ChessTreeNode keyed by
# puzzle id, and cache the resulting dict to tree_dict.pkl so later sessions
# can just unpickle it.
# Each line is expected to be a JSON pair `[puzzle_id, tree]`, where `tree` is
# a list whose first element holds the ChessTreeNode constructor arguments.
trees = {}
skipped_lines = 0
with open("../data/trees/condor/all", "r") as tree_file:
    for line in tqdm(tree_file.readlines()):
        try:
            puzzle_id, tree = json.loads(line)
        except (ValueError, TypeError, RecursionError):
            # Malformed/truncated JSON (JSONDecodeError is a ValueError), a record
            # that is not an [id, tree] pair, or pathologically deep nesting.
            # Skip it, but count it so data loss is visible instead of silent.
            skipped_lines += 1
            continue

        if len(tree) == 0:
            # No expanded tree for this puzzle; nothing to compare against.
            continue
        trees[puzzle_id] = ChessTreeNode(*tree[0])

print(f"Parsed {len(trees)} trees; skipped {skipped_lines} malformed lines")

# Write the cache after the input file is closed, using a distinct handle name
# so the reader's handle is not shadowed.
with open("tree_dict.pkl", "wb") as pickle_file:
    pickle.dump(trees, pickle_file)

In [None]:
def compare(t1, t2):
    """Return the zss (Zhang-Shasha) tree edit distance between two trees.

    Child enumeration, node labels and the per-label distance all come from
    ChessTreeNode's own accessors, so the metric is whatever
    ChessTreeNode.compare defines between labels.
    """
    children_of = ChessTreeNode.get_children
    label_of = ChessTreeNode.get_label
    label_distance = ChessTreeNode.compare
    return zss.simple_distance(t1, t2, children_of, label_of, label_distance)

def treeify_fen(fen, engine_path="./stockfish/stockfish-ubuntu-x86-64-avx2"):
    """Build a ChessTreeNode for the position `fen`, normalised to White to move.

    If Black is to move, the board is mirrored with ``Board.apply_mirror`` so the
    tree is always expanded from White's perspective. The returned flag lets the
    caller map moves back (e.g. ``tree.flip_san(flipped)``).

    Args:
        fen: FEN string of the starting position.
        engine_path: UCI engine executable to launch. Defaults to the Stockfish
            binary unpacked by the download cell.

    Returns:
        (tree, flipped): a ChessTreeNode built from the first element returned
        by ``expand_tree``, and True if the board was mirrored.
    """
    # `import chess` alone does not guarantee the `chess.engine` submodule is
    # loaded; import it explicitly rather than relying on `main` having done so.
    # Keep this first: it binds the local name `chess` used below.
    import chess.engine

    board = chess.Board(fen)
    flipped = board.turn == chess.BLACK
    if flipped:
        board.apply_mirror()

    # One engine process per call; always shut it down, even if expansion fails.
    stockfish = chess.engine.SimpleEngine.popen_uci(engine_path, timeout=None)
    try:
        root_args = expand_tree(board.fen(), chess.Move.null(), stockfish)[0]
        return ChessTreeNode(*root_args), flipped
    finally:
        stockfish.quit()
        
# Lichess puzzle database; lookup_id reads its PuzzleId, GameUrl and Themes columns.
PUZZLE_DB_CSV = "../data/lichess/lichess_db_puzzle.csv"
df = pd.read_csv(PUZZLE_DB_CSV)
def lookup_id(puzzle_id, table=None):
    """Describe a lichess puzzle as "<id> <game url> <themes>".

    Args:
        puzzle_id: PuzzleId to look up.
        table: DataFrame with PuzzleId, GameUrl and Themes columns. Defaults to
            the module-level puzzle table ``df``.

    Returns:
        The formatted description. Ids with no row in the table (e.g. the custom
        FEN probe names mixed into ``ids``) yield "<id> (not in puzzle db)"
        instead of raising, so listing loops over mixed ids do not crash.
    """
    if table is None:
        table = df
    matched = table.loc[table["PuzzleId"] == puzzle_id]
    if matched.empty:
        return f"{puzzle_id} (not in puzzle db)"
    # First matching row; Series.item() would raise on zero or several matches.
    row = matched.iloc[0]
    return f"{puzzle_id} {row['GameUrl']} {row['Themes']}"
    

In [None]:
# Reload the cached {puzzle_id: ChessTreeNode} dict written by the parse cell.
# pickle.load can execute arbitrary code: only load a cache this notebook wrote.
with open("tree_dict.pkl", "rb") as cache_file:
    trees = pickle.load(cache_file)

In [None]:
# Hand-picked positions used as clustering probes. Each one is expanded with
# treeify_fen and stored in `trees` under its descriptive name;
# custom_flipped records whether the position was mirrored to White-to-move.
custom_fens = {
    "smothered M2": "r3r2k/ppp2Qbp/2bp2p1/5PN1/4P3/1B2q2P/P5PK/8 w - - 0 25",
    "backrank M2": "6k1/1rp2pp1/5n1p/2B5/2RP4/4P3/P4PPP/6K1 b - - 0 24",
    "N forks QK": "2kr3r/2p5/ppq2n1p/3Np1p1/P3P3/1P6/Q2b1PPP/3RR1K1 w - - 0 22",
    "pinned pawn defender": "4r1k1/pb3p1p/1bp1q1p1/3pB3/3P4/1NP2PQP/PP4P1/R5K1 b - - 1 20",
}
custom_flipped = {}
separator = "=" * 50
for probe_name, probe_fen in custom_fens.items():
    probe_tree, was_flipped = treeify_fen(probe_fen)
    trees[probe_name] = probe_tree
    custom_flipped[probe_name] = was_flipped
    print(separator)
    print(probe_name)
    print(probe_fen)
    print(probe_tree.flip_san(was_flipped))
    print(separator)

In [None]:
# For each custom position, rank the first CANDIDATE_LIMIT lichess puzzles by
# tree edit distance and show the TOP_K closest ones.
CANDIDATE_LIMIT = 10000  # cap on puzzles scanned; compare() is slow
TOP_K = 5
for my_id, my_fen in custom_fens.items():
    print("="*50)
    print(my_id)
    print(my_fen)
    print("="*50)
    my_tree = trees[my_id]
    # Exclude the custom probes: they would match themselves at distance 0, and
    # they are not lichess puzzle ids, so lookup_id has nothing to describe.
    candidates = [(pid, t) for pid, t in trees.items() if pid not in custom_fens][:CANDIDATE_LIMIT]
    distances = [(compare(t, my_tree), pid) for pid, t in tqdm(candidates, leave=False)]
    # Sorting (distance, puzzle id) pairs is deterministic and never compares trees.
    distances.sort()

    for dist, pid in distances[:TOP_K]:
        print(f"(dist: {dist}) {lookup_id(pid)}")
        print("-"*50)

In [None]:
# Random sample of lichess puzzles to cluster, with the custom probe positions
# appended at the end (so they occupy the last len(custom_fens) entries of ids).
SAMPLE_SIZE = 1000
rng = np.random.default_rng(0)
lichess_ids = [k for k in trees if k not in custom_fens]
# replace=False: np.random.choice samples WITH replacement by default, which can
# pick the same puzzle twice and create zero-distance duplicate rows that
# inflate the local density DBSCAN sees.
ids = rng.choice(lichess_ids, size=min(SAMPLE_SIZE, len(lichess_ids)), replace=False)
ids = np.concatenate((ids, list(custom_fens.keys())))

In [None]:
N = len(ids)
# Pairwise tree-edit-distance matrix. compare() is assumed symmetric, so only
# the strict lower triangle (j < i) is computed and then mirrored; the
# diagonal stays 0.
d = np.zeros((N, N))
id_trees = [trees[pid] for pid in ids]  # hoist dict lookups out of the O(N^2) loop
with tqdm(total=N * (N - 1) // 2) as pbar:  # exact number of (i, j) pairs with j < i
    for i in range(N):
        for j in range(i):
            d[i, j] = compare(id_trees[i], id_trees[j])
        pbar.update(i)  # row i contributed i pairs
d = d + d.T


In [None]:
# Despite the variable name, this is DBSCAN (not OPTICS) on the precomputed
# distance matrix d. fit() returns the estimator itself, so `optics` and
# `clust` are the same fitted object; later cells read clust.labels_
# (label -1 marks noise).
optics = DBSCAN(eps=100, min_samples=4, metric="precomputed")
clust = optics.fit(d)
found_labels = set(clust.labels_)
print(found_labels)

In [None]:
# OPTICS reachability plot of the same distance matrix, coloured by the DBSCAN
# labels in `clust`. A fitted DBSCAN has no reachability_/ordering_ attributes,
# so fit an OPTICS model (same min_samples) just for this plot. Valleys below
# the dashed eps line roughly correspond to DBSCAN's clusters.
reach_model = OPTICS(metric="precomputed", min_samples=clust.min_samples).fit(d)
space = np.arange(len(d))
reachability = reach_model.reachability_[reach_model.ordering_]
labels = clust.labels_[reach_model.ordering_]

_, ax = plt.subplots(figsize=(18, 6))
# Iterate over the labels actually present rather than range(n_labels - 1),
# which assumes the noise label -1 always exists.
for cluster_label in sorted(set(labels) - {-1}):
    mask = labels == cluster_label
    ax.plot(space[mask], reachability[mask], ".", alpha=0.3)
ax.plot(space[labels == -1], reachability[labels == -1], "k.", alpha=0.3)
ax.axhline(clust.eps, color="k", linestyle="--", linewidth=1)
ax.set(title="OPTICS reachability (coloured by DBSCAN cluster)",
       xlabel="OPTICS ordering", ylabel="reachability distance");

    

In [None]:
# For each custom probe, list every sampled puzzle in the same DBSCAN cluster
# with its distance to the probe, closest first.
print("Cluster of puzzles from DBSCAN")
probe_labels = {p: int(clust.labels_[np.where(ids == p)[0][0]]) for p in custom_fens}
print(list(probe_labels.items()))

for p, fen in custom_fens.items():
    print(p, fen)
    idx = np.where(ids == p)[0][0]
    probe_label = probe_labels[p]
    if probe_label == -1:
        # -1 is DBSCAN's noise label, not a cluster: listing every noise point
        # as "similar" would be misleading.
        print("  (noise: not assigned to any cluster)")
        continue
    members = np.nonzero(clust.labels_ == probe_label)[0]
    for close_id in sorted(members, key=lambda m: d[idx, m]):
        print(f"{ids[close_id]}, {d[idx, close_id]}")


In [None]:
# Show up to MAX_SHOWN_PER_CLUSTER sampled members of every DBSCAN cluster
# (label -1 is noise), with their lichess metadata and tree.
MAX_SHOWN_PER_CLUSTER = 7
for c in sorted(set(clust.labels_)):
    print(f"label {c}")
    for pid in ids[np.where(clust.labels_ == c)[0]][:MAX_SHOWN_PER_CLUSTER]:
        if pid in custom_fens:
            # Custom probes are not lichess puzzles and have no row in the
            # puzzle CSV, so describe them by their FEN instead.
            print(f"{pid} (custom) {custom_fens[pid]}")
        else:
            print(lookup_id(pid))
        print(trees[pid])
    print()

In [None]:
# Spot-check the edit distance between two specific lichess puzzles.
first_tree = trees["4Qifg"]
second_tree = trees["3TTaA"]
compare(first_tree, second_tree)

In [None]:
# 2-D t-SNE embedding of the precomputed distance matrix. init="random" because
# PCA initialisation cannot be used with metric="precomputed".
tsne = TSNE(metric="precomputed", init="random", random_state=0, perplexity=20)
y = tsne.fit_transform(d)

colours = ["r.", "g.", "b.", "c."]
# (probe name, DBSCAN label) per custom position, as plain ints so they format
# cleanly in the legend and compare unambiguously.
custom_labels = [(p, int(clust.labels_[np.where(ids == p)[0][0]])) for p in custom_fens]
print(custom_labels)
# zip() stops at the shorter list: only these probe clusters get a special colour.
coloured_probes = list(zip(custom_labels, colours))
highlighted = {c for (_, c), _ in coloured_probes}

_, ax = plt.subplots(figsize=(16, 12))

# Plot -1 (noise) labels with faint crosses
yc = y[clust.labels_ == -1]
ax.plot(yc[:, 0], yc[:, 1], "kx", alpha=0.3, label="noise")

# Plot clusters containing a probe in a special colour, except -1
for (p, c), col in coloured_probes:
    if c == -1:
        continue
    yc = y[clust.labels_ == c]
    ax.plot(yc[:, 0], yc[:, 1], col, label=f"{c}: Similar to {p}")

# Plot the rest (including probe clusters that ran out of special colours)
for c in sorted(set(clust.labels_) - highlighted - {-1}):
    yc = y[clust.labels_ == c]
    ax.plot(yc[:, 0], yc[:, 1], "x", alpha=0.3, label=c)

# Mark the probe positions themselves
for p, _ in custom_labels:
    px, py = y[np.where(ids == p)[0][0]]
    ax.annotate(p, (px, py))

ax.set(title="t-SNE of puzzle trees (coloured by DBSCAN cluster)")
ax.legend();
