In [15]:
import os
import sys
import json
import pandas as pd
import numpy as np
import torch
from typing import Literal, TypedDict, Callable, TypeVar, Generic, Iterable

In [2]:
sys.path.append("..")
os.environ["USER_PATH"] = "../userdata/"

In [3]:
from system.namespace.store import get_namespace

In [4]:
namespace = get_namespace("train_db")

In [5]:
from system.embedding.store import get_embed_store
from misc.lru import LRU

In [6]:
embeds = get_embed_store(namespace)

In [7]:
role_c = "child"
role_p = "parent"

In [8]:
total_c = embeds.get_embedding_count(role_c)
total_c

689418

In [9]:
x = np.vstack([
    embed[1].ravel().detach().numpy()
    for embed in embeds.get_all_embeddings(role_c, progress_bar=True)
])
x = x[:100000, :]

  0%|          | 0/689418 [00:00<?, ?it/s]

In [10]:
x.shape

(100000, 768)

In [12]:
x[5, :].shape

(768,)

In [248]:
T = TypeVar('T')


class Node(Generic[T]):
    def __init__(self, dbs: 'DBScANN[T]', embed_ix: int) -> None:
        self._dbs = dbs
        self._embed_ix = embed_ix
        self._dist_range: float = 0.0
        self._count: int = 1
        self._children: list[Node] = []
        
    def get_ix(self) -> int:
        return self._embed_ix
        
    def get_embed(self) -> T:
        return self._dbs.get_embed(self._embed_ix)
    
    def count_descendants(self) -> int:
        return self._count
    
    def count_children(self) -> int:
        return len(self._children)
    
    def get_dist(self, embed: T, cache: dict[int, float]) -> float:
        return self._dbs.dist_embed(embed, self._embed_ix, cache)
    
    def get_dist_max(self, embed: T, cache: dict[int, float]) -> float:
        return self.get_dist(embed, cache) + self._dist_range
    
    def get_dist_min(self, embed: T, cache: dict[int, float]) -> float:
        return max(0, self.get_dist(embed, cache) - self._dist_range)
    
    def get_true_dist_max(self, embed: T, cache: dict[int, float]) -> float:
        init = self.get_dist(embed, cache)
        cur = init
        for child in self._children:
            cur = max(cur, child.get_true_dist_max(embed, cache))
        if cur - init > self._dist_range:
            self._dist_range = cur - init
        return cur
        
    def add_child(self, child: 'Node[T]') -> None:
        embed = self.get_embed()
        cache = {}
        c_dist_max = child.get_true_dist_max(embed, cache)
        # c_dist_max = child.get_dist_max(embed, cache)
        # if len(cache) > 1:
        #     print(f"comparisons while adding: {len(cache)}")
        if c_dist_max > self._dist_range:
            self._dist_range = c_dist_max
        self._children.append(child)
        self._count += child.count_descendants()
    
    def get_all_descendants(self) -> list['Node[T]']:
        res = [self]
        for child in self._children:
            res += child.get_all_descendants()
        return res
    
    def get_closest(
            self,
            embed: T,
            count: int,
            cache: dict[int, float]) -> list[tuple['Node[T]', float]]:
        if self.count_descendants() <= count:
            return [
                (
                    desc,
                    desc.get_dist(embed, cache),
                )
                for desc in self.get_all_descendants()
            ]
        own_dist = self.get_dist(embed, cache)
        children = sorted([
            (
                child,
                child.get_dist_min(embed, cache),
            )
            for child in self._children
        ], key=lambda row: row[1])
        
        def compact() -> None:
            nonlocal res, cur_max
            
            res = sorted(res, key=lambda row: row[1])[:count]
            cur_max = max((dist for _, dist in res))
            
        res = [(self, own_dist)]
        cur_max = own_dist
        for row in children:
            compact()
            child, cmin = row
            if cmin > cur_max:
                continue
            res.extend(child.get_closest(embed, count, cache))
        compact()
        return res
    
    def debug(self, pad: int) -> list[str]:
        outs = [
            child.debug(pad)
            for child in self._children
        ]
        num = f"{self._embed_ix}".rjust(pad)
        own = f"({num})"
        if not outs:
            return [own]
        bar = " " * len(own)
        res = []
        for cix, lines in enumerate(outs):
            for lix, line in enumerate(lines):
                if lix == 0 and cix == 0:
                    start = own
                else:
                    start = bar
                if lix == 0:
                    if cix == 0:
                        if len(outs) > 1:
                            mid = "┳"
                        else:
                            mid = "━"
                    else:
                        if cix >= len(outs) - 1:
                            mid = "┗"
                        else:
                            mid = "┣"
                else:
                    if cix >= len(outs) - 1:
                        mid = " "
                    else:
                        mid = "┃"
                res.append(f"{start}{mid}{line}")
        return res


class DBScANN(Generic[T]):
    def __init__(
            self,
            get_all_ix: Callable[[], Iterable[int]],
            get_embed: Callable[[int], T],
            get_dist: Callable[[T, T], float]) -> None:
        self._get_all_ix = get_all_ix
        self._get_embed = get_embed
        self._get_dist = get_dist
        self._lru: LRU[int, T] = LRU(1000)
        self._root: Node | None = None
        self._high_ix: int | None = None
        
    def get_embed(self, ix: int) -> T:
        res = self._lru.get(ix)
        if res is None:
            res = self._get_embed(ix)
            self._lru.set(ix, res)
        return res
    
    def dist_embed(self, embed: T, ix: int, cache: dict[int, float]) -> float:
        res = cache.get(ix)
        if res is not None:
            return res
        res = self._get_dist(embed, self.get_embed(ix))
        assert res >= 0.0
        cache[ix] = res
        return res
    
    def _centroid(self, all_ixs: list[int]) -> int:
        best_dist = 0.0
        best_ix = None
        for ix in all_ixs:
            cur_dist = 0.0
            embed = self.get_embed(ix)
            for oix in all_ixs:
                if ix == oix:
                    continue
                cache = {}  # no cache
                cur_dist += self.dist_embed(embed, oix, cache)
                if best_ix is not None and cur_dist > best_dist:
                    break
            if best_ix is None or cur_dist < best_dist:
                best_ix = ix
                best_dist = cur_dist
        assert best_ix is not None
        return best_ix
    
    def _kmedoid(self, all_ixs: list[int], k_num: int) -> list[tuple[int, list[int]]]:
        assert len(all_ixs) > k_num
        centroids = all_ixs[:k_num]
        rounds = 1000
        while rounds > 0:
            assignments = [[cix] for cix in centroids]
            for ix in all_ixs:
                if ix in centroids:
                    continue
                best_dist = 0.0
                best_cluster_ix = None
                embed = self.get_embed(ix)
                for cluster_ix, cix in enumerate(centroids):
                    cache = {}  # no cache
                    cur_dist = self.dist_embed(embed, cix, cache)
                    if best_cluster_ix is None or cur_dist < best_dist:
                        best_cluster_ix = cluster_ix
                        best_dist = cur_dist
                assert best_cluster_ix is not None
                assignments[best_cluster_ix].append(ix)
            done = True
            for cluster_ix in range(len(centroids)):
                new_c = self._centroid(assignments[cluster_ix])
                if new_c != centroids[cluster_ix]:
                    centroids[cluster_ix] = new_c
                    done = False
            if done:
                break
            rounds -= 1
        if rounds <= 0:
            print("exhausted iteration steps")
        return list(zip(centroids, assignments))
    
    def _remove(self, all_ixs: list[int], remove_ix: int) -> list[int]:
        return [ix for ix in all_ixs if ix != remove_ix]
    
    def build(self, max_node_size: int) -> None:
        all_ixs = [ix for ix in self._get_all_ix()]
        self._high_ix = max(all_ixs)
        root_ix = self._centroid(all_ixs)
        all_ixs = self._remove(all_ixs, root_ix)
        
        def build_level(cur_root_ix: int, cur_all_ixs: list[int]) -> Node[T]:
            node = Node(self, cur_root_ix)
            if len(cur_all_ixs) <= max_node_size:
                for child_ix in cur_all_ixs:
                    node.add_child(Node(self, child_ix))
                return node
            children = self._kmedoid(cur_all_ixs, max_node_size)
            for row in children:
                centroid_ix, assignments = row
                assignments = self._remove(assignments, centroid_ix)
                cnode = build_level(centroid_ix, assignments)
                node.add_child(cnode)
            return node
        
        self._root = build_level(root_ix, all_ixs)
    
    def get_closest(self, embed: T, count: int) -> list[tuple[int, float]]:
        assert self._root is not None
        cache = {}
        closest = self._root.get_closest(embed, count, cache)
        print(f"get_closest main comparisons: {len(cache)}")
        res = sorted([
            (row[0].get_ix(), row[0].get_dist(embed, cache))
            for row in closest
        ], key=lambda cur: cur[1])
        print(f"get_closest final comparisons: {len(cache)}")
        return res
    
    def debug(self) -> str:
        if self._high_ix is None:
            pad = 0
        else:
            pad = len(f"{self._high_ix}")
        return "\n".join(self._root.debug(pad))

In [249]:
test_size = 10000

In [250]:
# dbs = DBScANN(lambda: range(0, x.shape[0]), lambda ix: x[ix, :], lambda a, b: np.log1p(np.exp(-np.dot(a, b))))
dbs = DBScANN(lambda: range(0, test_size), lambda ix: x[ix, :], lambda a, b: np.log1p(np.exp(-np.dot(a, b))))
dbs

<__main__.DBScANN at 0x48511fd30>

In [251]:
%%time

dbs.build(max_node_size=10)

CPU times: user 40.6 s, sys: 99.7 ms, total: 40.7 s
Wall time: 40.7 s


In [252]:
print(dbs.debug())

(6829)┳( 795)┳(6328)┳( 772)
      ┃      ┃      ┣(1029)
      ┃      ┃      ┣(2696)
      ┃      ┃      ┣(3238)
      ┃      ┃      ┣(3299)
      ┃      ┃      ┣(3360)
      ┃      ┃      ┣(4074)
      ┃      ┃      ┣(5417)
      ┃      ┃      ┣(7265)━(9128)
      ┃      ┃      ┗(7659)┳(8078)
      ┃      ┃             ┣(8282)
      ┃      ┃             ┗(9572)
      ┃      ┣(8672)┳( 211)┳(1436)━(9708)
      ┃      ┃      ┃      ┣(2322)
      ┃      ┃      ┃      ┣(2725)━(8878)
      ┃      ┃      ┃      ┣(3500)
      ┃      ┃      ┃      ┣(3583)
      ┃      ┃      ┃      ┣(3749)
      ┃      ┃      ┃      ┣(4349)
      ┃      ┃      ┃      ┣(4416)
      ┃      ┃      ┃      ┣(5293)
      ┃      ┃      ┃      ┗(8378)━(9442)
      ┃      ┃      ┣(1555)┳( 946)
      ┃      ┃      ┃      ┣(2129)
      ┃      ┃      ┃      ┣(2666)
      ┃      ┃      ┃      ┣(3913)
      ┃      ┃      ┃      ┣(5308)
      ┃      ┃      ┃      ┣(7661)
      ┃      ┃      ┃      ┣(8257)
      ┃      ┃      

In [253]:
test_ix = test_size
test_embed = x[test_ix, :]
test_embed.shape

(768,)

In [254]:
test_dists = np.log1p(np.exp(-np.dot(x[:test_size, :], test_embed)))
test_dists.shape

(10000,)

In [255]:
np.argsort(test_dists)[:10]

array([4711,  218, 4413, 6829, 9231, 5834, 6155,  795, 6574, 9277])

In [256]:
test_dists[np.argsort(test_dists)[:10]]

array([7.93406427e-120, 8.74874356e-120, 8.94165723e-120, 9.21019760e-120,
       9.27423789e-120, 1.00236720e-119, 1.03866893e-119, 1.04421357e-119,
       1.15937662e-119, 1.17943871e-119])

In [257]:
dbs.get_closest(test_embed, 10)

get_closest main comparisons: 7454
get_closest final comparisons: 7454


[(4711, 7.934064274485374e-120),
 (218, 8.748743564573412e-120),
 (4413, 8.941657231831245e-120),
 (6829, 9.21019760211911e-120),
 (9231, 9.274237887915886e-120),
 (5834, 1.0023671996319938e-119),
 (6155, 1.0386689324780679e-119),
 (795, 1.0442135697163178e-119),
 (6574, 1.1593766167871222e-119),
 (9277, 1.179438714662682e-119)]