In [15]:
import itertools
import json
import os
import sys
from typing import Literal, TypedDict, Callable, TypeVar, Generic, Iterable

import numpy as np
import pandas as pd
import torch

In [2]:
# Point the user-data location at ../userdata/ and put the parent directory on
# the import path (presumably so the project packages imported below resolve).
os.environ.update(USER_PATH="../userdata/")
sys.path.append("..")

In [3]:
from system.namespace.store import get_namespace

In [4]:
# The embedding store opened below is taken from this namespace.
NAMESPACE_NAME = "train_db"
namespace = get_namespace(NAMESPACE_NAME)

In [5]:
from system.embedding.store import get_embed_store
from misc.lru import LRU

In [6]:
# Embedding store for `namespace`; used below to count and stream embeddings.
embeds = get_embed_store(namespace)

In [7]:
# Embedding roles in the store; only `role_c` is used in the cells shown here.
role_c, role_p = "child", "parent"

In [8]:
# Number of stored child embeddings, bound to `total_c` and shown as the output.
(total_c := embeds.get_embedding_count(role_c))

689418

In [9]:
# Load only the first `sample_size` child embeddings instead of materialising
# all of them and slicing afterwards: the same rows, far less peak memory.
# NOTE(review): assumes each yielded item holds the embedding tensor at [1].
sample_size = 100000
x = np.vstack([
    embed[1].ravel().detach().numpy()
    for embed in itertools.islice(
        embeds.get_all_embeddings(role_c, progress_bar=True), sample_size)
])

  0%|          | 0/689418 [00:00<?, ?it/s]

In [10]:
# (rows sampled, embedding dimension)
np.shape(x)

(100000, 768)

In [12]:
# A single row is a flat embedding vector.
x[5].shape

(768,)

In [599]:
T = TypeVar('T')


class Node(Generic[T]):
    def __init__(self, dbs: 'Fann[T]', embed_ix: int) -> None:
        self._dbs = dbs
        self._embed_ix = embed_ix
        self._radius: float = 0.0
        self._count: int = 1
        self._children: list[tuple[Node, float]] = []
        
    def get_ix(self) -> int:
        return self._embed_ix
        
    def get_embed(self) -> T:
        return self._dbs.get_embed(self._embed_ix)
    
    def count_descendants(self) -> int:
        return self._count
    
    def count_children(self) -> int:
        return len(self._children)
    
    def get_dist(self, embed: T, eix: int | None, cache: dict[int, float]) -> float:
        return self._dbs.dist_embed(embed, eix, self._embed_ix, cache)
    
    def get_radius(self) -> float:
        return self._radius
    
    def get_dist_max(self, embed: T, eix: int | None, cache: dict[int, float]) -> float:
        return self.get_dist(embed, eix, cache) + self._radius
    
    def get_dist_min(self, embed: T, eix: int | None, cache: dict[int, float]) -> float:
        return max(0, self.get_dist(embed, eix, cache) - self._radius)
    
    def get_true_dist_max(self, embed: T, eix: int | None, cache: dict[int, float]) -> float:
        init = self.get_dist(embed, eix, cache)
        cur = init
        for child, _ in self._children:
            cur = max(cur, child.get_true_dist_max(embed, eix, cache))
        if cur - init > self._radius:
            self._radius = cur - init
        return cur
        
    def add_child(self, child: 'Node[T]') -> None:
        embed = self.get_embed()
        eix = self.get_ix()
        cache = {}
        c_dist_max = child.get_true_dist_max(embed, eix, cache)
        # c_dist_max = child.get_dist_max(embed, eix, cache)
        # if len(cache) > 1:
        #     print(f"comparisons while adding: {len(cache)}")
        if c_dist_max > self._radius:
            self._radius = c_dist_max
        self._children.append((child, c_dist_max))
        self._children.sort(key=lambda c: c[1], reverse=True)
        self._count += child.count_descendants()
    
    def get_all_descendants(self) -> list['Node[T]']:
        res = [self]
        for child, _ in self._children:
            res += child.get_all_descendants()
        return res
    
    def get_closest(
            self,
            res: list[list[tuple['Node[T]', float]]],
            res_max: list[float],
            embed: T,
            count: int,
            cache: dict[int, float],
            stats: dict[int, str],
            eps: float) -> None:
        own_ix = self.get_ix()
        eix = None
        # if self.count_descendants() <= count:
        #     stats[own_ix] = "full"
        #     return [
        #         (
        #             desc,
        #             desc.get_dist(embed, eix, cache),
        #         )
        #         for desc in self.get_all_descendants()
        #     ]
        own_dist = self.get_dist(embed, eix, cache)
        radius = self.get_radius()
        
#         def get_max() -> float:
#             if len(res
#             res[0]
        
        def compact() -> None:
            res[0] = sorted(res[0], key=lambda row: row[1])[:count]
            res_max[0] = max((dist for _, dist in res[0]))
        
        if len(res[0]) < count or own_dist - eps <= res_max[0]:
            res[0].append((self, own_dist))
            compact()
        if radius < own_dist:
            stats[own_ix] = "outer"
            for child, c_dist_center in self._children:
                if res_max[0] < own_dist - c_dist_center - eps:
                    continue
                child.get_closest(
                    res,
                    res_max,
                    embed,
                    count,
                    cache,
                    stats,
                    eps)
        else:
            stats[own_ix] = "inner"
            children = sorted([
                (
                    child,
                    child.get_dist_min(embed, eix, cache),
                )
                for child, _ in self._children
            ], key=lambda row: row[1])
            for row in children:
                child, cmin = row
                if cmin - eps > res_max[0]:
                    continue
                child.get_closest(
                    res,
                    res_max,
                    embed,
                    count,
                    cache,
                    stats,
                    eps)
    
    def debug(
            self,
            pad: int,
            show_ixs: dict[int, bool],
            stats: dict[int, str],
            prune: bool) -> list[str]:
        highlight_chr = ":"
        highlight_schr = "*"
        lookup = {
            True: highlight_schr,
            False: highlight_chr,
            None: "",
        }
        highlight = lookup[show_ixs.get(self._embed_ix, None)]
        num = f"{self._embed_ix}".rjust(pad - len(highlight))
        own = f"({highlight}{num})"
        if not self._children:
            return [own]
        if self.count_descendants() == len(self._children) + 1:
            chs = ", ".join((
                f"{lookup[show_ixs.get(child.get_ix(), None)]}{child.get_ix()}"
                for child, _ in self._children
            ))
            if prune and highlight_chr not in chs and highlight_schr not in chs:
                chs = "..."
            return [f"{own}━({chs})"]
        outs = [
            (child.get_ix(), child.debug(pad, show_ixs, stats, prune))
            for child, _ in self._children
        ]
        bar = " " * len(own)
        res = []
        for cix, (child_ix, lines) in enumerate(outs):
            all_lines = "".join(lines)
            if prune and highlight_chr not in all_lines and highlight_schr not in all_lines:
                lines = ["(...)"]
            for lix, line in enumerate(lines):
                if lix == 0 and cix == 0:
                    start = own
                else:
                    start = bar
                if lix == 0:
                    if cix == 0:
                        if len(outs) > 1:
                            mid = "┳"
                        else:
                            mid = "━"
                    else:
                        if cix >= len(outs) - 1:
                            mid = "┗"
                        else:
                            mid = "┣"
                    mid = stats.get(child_ix, mid)
                    if len(mid) > 1 or mid[0].isalnum():
                        mid = mid.upper()[:1]
                else:
                    if cix >= len(outs) - 1:
                        mid = " "
                    else:
                        mid = "┃"
                res.append(f"{start}{mid}{line}")
        return res


class Fann(Generic[T]):
    """Tree-based nearest-neighbour index over embeddings addressed by int index.

    ``build`` creates a tree of ``Node`` objects by recursive k-medoid
    clustering; ``get_closest`` answers queries by branch-and-bound search.
    The pruning bounds rely on the triangle inequality, so results are only
    guaranteed exact when ``get_dist`` is a metric.
    """

    def __init__(
            self,
            get_all_ix: Callable[[], Iterable[int]],
            get_embed: Callable[[int], T],
            get_dist: Callable[[T, T], float],
            embed_cache_size: int = 1000,
            dist_cache_size: int = 100000) -> None:
        """
        Args:
            get_all_ix: returns the indices to index.
            get_embed: maps an index to its embedding.
            get_dist: distance between two embeddings (must be >= 0).
            embed_cache_size: capacity of the embedding LRU cache.
            dist_cache_size: capacity of the pairwise-distance LRU cache.
        """
        self._get_all_ix = get_all_ix
        self._get_embed = get_embed
        self._get_dist = get_dist
        # Embeddings by index, and distances between two indexed embeddings
        # keyed by (smaller ix, larger ix).
        self._lru: LRU[int, T] = LRU(embed_cache_size)
        self._dlru: LRU[tuple[int, int], float] = LRU(dist_cache_size)
        self._root: Node | None = None
        self._high_ix: int | None = None

    def get_embed(self, ix: int) -> T:
        """Return the embedding for `ix`, consulting the embedding cache first."""
        res = self._lru.get(ix)
        if res is None:
            res = self._get_embed(ix)
            self._lru.set(ix, res)
        return res

    def dist_embed(self, embed: T, eix: int | None, ix: int, cache: dict[int, float]) -> float:
        """Distance from `embed` to the embedding at `ix`.

        `eix` is the index of `embed` when it is itself indexed; only then is
        the shared pairwise cache used. `cache` is a caller-owned dict keyed
        by `ix` that always receives the result.
        """
        key = None if eix is None else (min(eix, ix), max(eix, ix))
        if key is not None:
            res = self._dlru.get(key)
            if res is not None:
                cache[ix] = res
                return res
        res = cache.get(ix)
        if res is not None:
            if key is not None:
                self._dlru.set(key, res)
            return res
        res = self._get_dist(embed, self.get_embed(ix))
        assert res >= 0.0
        cache[ix] = res
        if key is not None:
            self._dlru.set(key, res)
        return res

    def _centroid(self, all_ixs: list[int]) -> int:
        """Return the medoid of `all_ixs`: the member with the smallest sum of
        distances to all other members (`all_ixs` must be non-empty)."""
        best_dist = 0.0
        best_ix: int | None = None
        for ix in all_ixs:
            cur_dist = 0.0
            embed = self.get_embed(ix)
            for oix in all_ixs:
                if ix == oix:
                    continue
                cur_dist += self.dist_embed(embed, ix, oix, {})
                # Distances are non-negative, so this candidate can no longer
                # beat the current best once it has exceeded it.
                if best_ix is not None and cur_dist > best_dist:
                    break
            if best_ix is None or cur_dist < best_dist:
                best_ix = ix
                best_dist = cur_dist
        assert best_ix is not None
        return best_ix

    def _kmedoid(
            self,
            all_ixs: list[int],
            k_num: int,
            max_rounds: int = 1000) -> list[tuple[int, list[int]]]:
        """Partition `all_ixs` into `k_num` clusters by alternating assignment
        and medoid update until the medoids stop changing.

        Returns (medoid, members) pairs; each member list contains its medoid.
        The first `k_num` indices are used as initial medoids.
        """
        assert len(all_ixs) > k_num
        centroids = all_ixs[:k_num]
        rounds = max_rounds
        while rounds > 0:
            assignments = [[cix] for cix in centroids]
            centroid_set = set(centroids)
            for ix in all_ixs:
                if ix in centroid_set:
                    continue
                best_dist = 0.0
                best_cluster_ix = None
                embed = self.get_embed(ix)
                for cluster_ix, cix in enumerate(centroids):
                    cur_dist = self.dist_embed(embed, ix, cix, {})
                    if best_cluster_ix is None or cur_dist < best_dist:
                        best_cluster_ix = cluster_ix
                        best_dist = cur_dist
                assert best_cluster_ix is not None
                assignments[best_cluster_ix].append(ix)
            done = True
            for cluster_ix in range(len(centroids)):
                new_c = self._centroid(assignments[cluster_ix])
                if new_c != centroids[cluster_ix]:
                    centroids[cluster_ix] = new_c
                    done = False
            if done:
                break
            rounds -= 1
        if rounds <= 0:
            print("exhausted iteration steps")
        return list(zip(centroids, assignments))

    def _remove(self, all_ixs: list[int], remove_ix: int) -> list[int]:
        """Return `all_ixs` without `remove_ix`."""
        return [ix for ix in all_ixs if ix != remove_ix]

    def build(self, max_node_size: int) -> None:
        """Build the tree; every node gets at most `max_node_size` children.

        Raises:
            ValueError: if `max_node_size` < 1 or there is nothing to index.
        """
        if max_node_size < 1:
            raise ValueError(f"max_node_size must be at least 1, got {max_node_size}")
        all_ixs = list(self._get_all_ix())
        if not all_ixs:
            raise ValueError("cannot build an index over an empty set of embeddings")
        self._high_ix = max(all_ixs)
        root_ix = self._centroid(all_ixs)
        all_ixs = self._remove(all_ixs, root_ix)

        def build_level(cur_root_ix: int, cur_all_ixs: list[int]) -> Node[T]:
            node = Node(self, cur_root_ix)
            if len(cur_all_ixs) <= max_node_size:
                for child_ix in cur_all_ixs:
                    node.add_child(Node(self, child_ix))
                return node
            # Use fewer clusters for small sets so each gets a useful size.
            num_k = max_node_size
            if max_node_size * max_node_size > len(cur_all_ixs):
                num_k = int(np.sqrt(len(cur_all_ixs)))
            children = self._kmedoid(cur_all_ixs, num_k)
            for centroid_ix, assignments in children:
                assignments = self._remove(assignments, centroid_ix)
                cnode = build_level(centroid_ix, assignments)
                node.add_child(cnode)
            return node

        self._root = build_level(root_ix, all_ixs)

    def get_closest(
            self,
            embed: T,
            count: int,
            cache: dict[int, float] | None = None,
            stats: dict[int, str] | None = None,
            eps: float = 0.0) -> list[tuple[int, float]]:
        """Return up to `count` (index, distance) pairs closest to `embed`,
        sorted by ascending distance.

        `cache` receives every distance computed for this query; `stats`
        receives the search mode used at each visited node. A positive `eps`
        loosens pruning. Requires `build` to have been called.
        """
        assert self._root is not None
        if count <= 0:
            return []
        if cache is None:
            cache = {}
        if stats is None:
            stats = {}
        res = [[]]
        res_max = [np.inf]
        self._root.get_closest(
            res,
            res_max,
            embed,
            count,
            cache,
            stats,
            eps)
        return [
            (row[0].get_ix(), row[1])
            for row in res[0]
        ]

    def debug(
            self,
            *,
            show_ixs: set[int] | None = None,
            special_ixs: set[int] | None = None,
            stats: dict[int, str] | None = None,
            prune: bool = False) -> str:
        """Render the tree as text.

        Nodes in `show_ixs` are marked ":", nodes in `special_ixs` "*"
        (special wins). `stats` replaces connectors with search-mode initials.
        `prune` collapses subtrees with no marked node. Any iterable
        collection of ints (e.g. a dict's keys) works for the index sets.
        """
        assert self._root is not None
        high_ix = self._high_ix
        if show_ixs:
            # Extra digit of padding leaves room for the marker character.
            high_ix = max(max(show_ixs) * 10, high_ix or 0)
        if high_ix is None:
            pad = 0
        else:
            pad = len(f"{high_ix}")
        highlight_ixs = {}
        if show_ixs is not None:
            for ix in show_ixs:
                highlight_ixs[ix] = False
        if special_ixs is not None:
            for ix in special_ixs:
                highlight_ixs[ix] = True
        if stats is None:
            stats = {}
        return "\n".join(self._root.debug(pad, highlight_ixs, stats, prune))

In [600]:
# Rows [0, test_size) of `x` are indexed; the remaining rows serve as queries.
test_size = 10_000

In [601]:
# Index the first `test_size` rows of `x` with distance exp(-<a, b>).
# An earlier variant indexed all rows with log1p(exp(-<a, b>)).
# NOTE(review): neither is a true metric (d(a, a) != 0), while the tree's
# pruning bounds assume the triangle inequality.
def _fann_all_ixs():
    return range(0, test_size)


def _fann_get_embed(ix):
    return x[ix, :]


def _fann_dist(a, b):
    return np.exp(-np.dot(a, b))


dbs = Fann(_fann_all_ixs, _fann_get_embed, _fann_dist)
dbs

<__main__.Fann at 0x17b96bbe0>

In [602]:
%%time

# Build the tree: recursive k-medoid clustering, at most 10 children per node.
dbs.build(max_node_size=10)

CPU times: user 1min 11s, sys: 1.14 s, total: 1min 12s
Wall time: 1min 12s


In [609]:
# Query held-out row 10001 and draw the parts of the tree the search touched:
# nodes whose distance was computed are marked ":", the listed indices "*",
# and connectors show the search mode (I/O) used at each child.
# NOTE(review): the listed indices are presumably this query's brute-force
# top-10 (with 2838 the expected hit) — confirm.
c_cache = {}
c_stats = {}
c_found = [row[0] for row in dbs.get_closest(x[10001, :], 10, c_cache, c_stats)]
is_correct = 2838 in c_found
tree_text = dbs.debug(
    special_ixs={2838, 9282, 218, 1055, 6350, 1441, 5042, 1232, 2822, 7923},
    show_ixs=set(c_cache),
    stats=c_stats,
    prune=True,
)
print(tree_text)

(:6829)I(:8396)┳(:2370)┳(...)
       ┃       ┃       ┣(...)
       ┃       ┃       ┣(...)
       ┃       ┃       ┣(...)
       ┃       ┃       ┗(...)
       ┃       I(:5239)┳(: 322)━(...)
       ┃       ┃       ┣(: 134)━(...)
       ┃       ┃       ┣(:9045)━(...)
       ┃       ┃       I(:8504)I(:6281)━(:7844, :3904, :2720, :5190, :1345, :1712)
       ┃       ┃       ┃       ┣(:3641)┳(...)
       ┃       ┃       ┃       ┃       ┣(...)
       ┃       ┃       ┃       ┃       ┣(...)
       ┃       ┃       ┃       ┃       ┗(...)
       ┃       ┃       ┃       ┣(: 438)━(...)
       ┃       ┃       ┃       O(: 993)
       ┃       ┃       ┃       O(: 232)
       ┃       ┃       ┣(: 304)┳(...)
       ┃       ┃       ┃       ┣(...)
       ┃       ┃       ┃       ┣(...)
       ┃       ┃       ┃       ┗(...)
       ┃       ┃       I(:9122)┳(: 971)━(...)
       ┃       ┃       ┃       I(: 520)┳(:1682)━(...)
       ┃       ┃       ┃       ┃       O(:5813)━(:7169, :3673, :7193, :3007, :3040)
       

In [610]:
# Whether index 2838 was among the 10 results returned for x[10001] above.
is_correct

False

In [604]:
# First held-out row (not in the index) is the example query.
test_ix = test_size
test_embed = x[test_ix]
test_embed.shape

(768,)

In [605]:
%%time

test_dists = np.exp(-np.dot(x[:test_size, :], test_embed))
test_dists.shape

CPU times: user 19.1 ms, sys: 442 µs, total: 19.6 ms
Wall time: 3.74 ms


(10000,)

In [606]:
# Ground truth: indices of the 10 nearest indexed rows.
test_dists.argsort()[:10]

array([4711,  218, 4413, 6829, 9231, 5834, 6155,  795, 6574, 9277])

In [607]:
# Ground truth: the 10 smallest distances, ascending.
np.sort(test_dists)[:10]

array([7.93406427e-120, 8.74874356e-120, 8.94165723e-120, 9.21019760e-120,
       9.27423789e-120, 1.00236720e-119, 1.03866893e-119, 1.04421357e-119,
       1.15937662e-119, 1.17943871e-119])

In [608]:
def compute_stats(stats: dict[int, str]) -> dict[str, int]:
    """Count how often each label occurs among the values of `stats`.

    Labels appear in the result in order of first occurrence.
    """
    counts: dict[str, int] = {}
    for label in stats.values():
        counts[label] = counts.get(label, 0) + 1
    return counts

In [593]:
%%time

cache = {}
stats = {}
res = dbs.get_closest(test_embed, 10, cache, stats)
res, len(cache), compute_stats(stats)

CPU times: user 117 ms, sys: 39 ms, total: 156 ms
Wall time: 22.7 ms


([(4711, 7.934064274485374e-120),
  (218, 8.748743564573412e-120),
  (4413, 8.941657231831245e-120),
  (6829, 9.21019760211911e-120),
  (9231, 9.274237887915886e-120),
  (5834, 1.0023671996319938e-119),
  (6155, 1.0386689324780679e-119),
  (795, 1.0442135697163178e-119),
  (6574, 1.1593766167871222e-119),
  (9277, 1.179438714662682e-119)],
 4482,
 {'inner': 765, 'outer': 436})

In [582]:
%%time

orders_base = []
for ref_ix in range(test_ix, x.shape[0]):
    test_dists = np.exp(-np.dot(x[:test_size, :], x[ref_ix, :]))
    orders_base.append(np.argsort(test_dists)[:10].tolist())

CPU times: user 23min 16s, sys: 12.4 s, total: 23min 29s
Wall time: 3min 8s


In [560]:
%%time

orders_fast = []
cmp_nums = []
stats = {"inner": 0, "outer": 0}
for ref_ix in range(test_ix, x.shape[0]):
    cache = {}
    f_stats = {}
    orders_fast.append(dbs.get_closest(x[ref_ix, :], 10, cache, f_stats))
    for key, value in compute_stats(f_stats).items():
        stats[key] += value
    cmp_nums.append(len(cache))

CPU times: user 1h 18min 57s, sys: 33.6 s, total: 1h 19min 31s
Wall time: 7h 34min 25s


In [561]:
# Total nodes searched in each mode over all held-out queries.
stats

{'inner': 33689538, 'outer': 866310462}

In [562]:
# Compare approximate results with brute force. A query counts as correct
# only if the full ordered top-10 list matches; up to 1000 misses are printed.
assert len(orders_base) == len(orders_fast)
correct = incorrect = 0
show_only = 1000
for rix, (base, fast) in enumerate(zip(orders_base, orders_fast)):
    cur_fast = [ix for ix, _ in fast]
    if base != cur_fast:
        incorrect += 1
        if show_only > 0:
            print(f"{test_ix + rix}: {base} != {cur_fast}")
            show_only -= 1
        continue
    correct += 1
correct, incorrect, correct / (incorrect + correct)

(90000, 0, 1.0)

In [563]:
# Distance computations per query: count, std, mean, max, min.
# NOTE(review): in the recorded run every query computed a distance to all
# 10000 indexed nodes (std 0, min == max), i.e. the tree pruned nothing —
# plausibly because exp(-<a, b>) is not a metric; investigate before relying
# on the index for speed.
len(cmp_nums), np.std(cmp_nums), np.mean(cmp_nums), max(cmp_nums), min(cmp_nums)

(90000, 0.0, 10000.0, 10000, 10000)