In [15]:
import os
import sys
import json
import pandas as pd
import numpy as np
import torch
from typing import Literal, TypedDict, Callable, TypeVar, Generic, Iterable

In [2]:
sys.path.append("..")
os.environ["USER_PATH"] = "../userdata/"

In [3]:
from system.namespace.store import get_namespace

In [4]:
namespace = get_namespace("train_db")

In [5]:
from system.embedding.store import get_embed_store
from misc.lru import LRU

In [6]:
embeds = get_embed_store(namespace)

In [7]:
role_c = "child"
role_p = "parent"

In [8]:
total_c = embeds.get_embedding_count(role_c)
total_c

689418

In [9]:
x = np.vstack([
    embed[1].ravel().detach().numpy()
    for embed in embeds.get_all_embeddings(role_c, progress_bar=True)
])
x = x[:100000, :]

  0%|          | 0/689418 [00:00<?, ?it/s]

In [10]:
x.shape

(100000, 768)

In [12]:
x[5, :].shape

(768,)

In [174]:
T = TypeVar('T')


class Node(Generic[T]):
    def __init__(self, dbs: 'DBScANN[T]', embed_ix: int) -> None:
        self._dbs = dbs
        self._embed_ix = embed_ix
        self._dist_range: float = 0.0
        self._count: int = 1
        self._children: list[Node] = []
        
    def get_ix(self) -> int:
        return self._embed_ix
        
    def get_embed(self) -> T:
        return self._dbs.get_embed(self._embed_ix)
    
    def count_descendants(self) -> int:
        return self._count
    
    def count_children(self) -> int:
        return len(self._children)
    
    def get_dist(self, embed: T, cache: dict[int, float]) -> float:
        return self._dbs.dist_embed(embed, self._embed_ix, cache)
    
    def get_dist_max(self, embed: T, cache: dict[int, float]) -> float:
        return self.get_dist(embed, cache) + self._dist_range
    
    def get_dist_min(self, embed: T, cache: dict[int, float]) -> float:
        return max(0, self.get_dist(embed, cache) - self._dist_range)
    
    def get_true_dist_max(self, embed: T, cache: dict[int, float]) -> float:
        init = self.get_dist(embed, cache)
        cur = init
        for child in self._children:
            cur = max(cur, child.get_true_dist_max(embed, cache))
        assert self._dist_range >= cur - init, f"{self._dist_range} >= {cur} - {init}"
        return cur
        
    def add_child(self, child: 'Node[T]') -> None:
        embed = self.get_embed()
        cache = {}
        # c_dist_max = child.get_true_dist_max(embed, cache)
        c_dist_max = child.get_dist_max(embed, cache)
        if len(cache) > 1:
            print(f"comparisons while adding: {len(cache)}")
        if c_dist_max > self._dist_range:
            self._dist_range = c_dist_max
        self._children.append(child)
        self._count += child.count_descendants()
    
    def get_all_descendants(self) -> list['Node[T]']:
        res = [self]
        for child in self._children:
            res += child.get_all_descendants()
        return res
    
    def get_closest(
            self,
            embed: T,
            count: int,
            cache: dict[int, float]) -> list[tuple['Node[T]', float, float]]:
        own_dist = self.get_dist(embed, cache)
        children = sorted([
            (
                child,
                child.count_descendants(),
                child.get_dist_min(embed, cache),
                child.get_dist_max(embed, cache),
            )
            for child in self._children
        ], key=lambda row: (row[2], row[3]))
        
        def compact() -> None:
            nonlocal res
            
            res = sorted(res, key=lambda row: row[1])[:count]
        
        res = [(self, own_dist, own_dist)]
        cur_min = own_dist
        cur_max = own_dist
        for row in children:
            compact()
            child, desc, cmin, cmax = row
            if desc <= count:
                res += [
                    (cd, cmin, cmax)
                    for cd in child.get_all_descendants()
                ]
                if cmin > cur_min:
                    cur_min = cmin
                if cmax < cur_max:
                    cur_max = cmax
            else:
                remain = count - len(res)
                res += child.get_closest(embed, remain, cache)
        compact()
        return res
    
    def debug(self) -> list[str]:
        outs = [
            child.debug()
            for child in self._children
        ]
        own = f"{'┣' if outs else '┗'}({self._embed_ix})"
        start_ix = 0
        out_lens = []
        for cur_out_ix, cur_outs in enumerate(outs):
            out_len = max((len(out) for out in cur_outs))
            is_last = cur_out_ix >= len(outs) - 1
            own = own.ljust(start_ix + 1, "┓" if is_last else "┳")
            pad_chr = " " if is_last else "━"
            own = own.ljust(start_ix + out_len, pad_chr)
            start_ix += out_len
            out_lens.append(start_ix)
        res = [f"{own} "]
        if not outs:
            return res
        for line_ix in range(max((len(cur_outs) for cur_outs in outs))):
            cur_line = ""
            for cur_out_ix, cur_outs in enumerate(outs):
                out = cur_outs[line_ix] if line_ix < len(cur_outs) else ""
                cur_line = f"{cur_line}{out}".ljust(out_lens[cur_out_ix])
            res.append(f"{cur_line}")
        max_len = max((len(line) for line in res))
        return [line.ljust(max_len) for line in res]


class DBScANN(Generic[T]):
    def __init__(
            self,
            get_all_ix: Callable[[], Iterable[int]],
            get_embed: Callable[[int], T],
            get_dist: Callable[[T, T], float]) -> None:
        self._get_all_ix = get_all_ix
        self._get_embed = get_embed
        self._get_dist = get_dist
        self._lru: LRU[int, T] = LRU(1000)
        self._root: Node | None = None
        
    def get_embed(self, ix: int) -> T:
        res = self._lru.get(ix)
        if res is None:
            res = self._get_embed(ix)
            self._lru.set(ix, res)
        return res
    
    def dist_embed(self, embed: T, ix: int, cache: dict[int, float]) -> float:
        res = cache.get(ix)
        if res is not None:
            return res
        res = self._get_dist(embed, self.get_embed(ix))
        assert res >= 0.0
        cache[ix] = res
        return res
    
    def _centroid(self, all_ixs: list[int]) -> int:
        best_dist = 0.0
        best_ix = None
        for ix in all_ixs:
            cur_dist = 0.0
            embed = self.get_embed(ix)
            for oix in all_ixs:
                if ix == oix:
                    continue
                cache = {}  # no cache
                cur_dist += self.dist_embed(embed, oix, cache)
                if best_ix is not None and cur_dist > best_dist:
                    break
            if best_ix is None or cur_dist < best_dist:
                best_ix = ix
                best_dist = cur_dist
        assert best_ix is not None
        return best_ix
    
    def _kmedoid(self, all_ixs: list[int], k_num: int) -> list[tuple[int, list[int]]]:
        assert len(all_ixs) > k_num
        centroids = all_ixs[:k_num]
        rounds = 1000
        while rounds > 0:
            assignments = [[cix] for cix in centroids]
            for ix in all_ixs:
                if ix in centroids:
                    continue
                best_dist = 0.0
                best_cluster_ix = None
                embed = self.get_embed(ix)
                for cluster_ix, cix in enumerate(centroids):
                    cache = {}  # no cache
                    cur_dist = self.dist_embed(embed, cix, cache)
                    if best_cluster_ix is None or cur_dist < best_dist:
                        best_cluster_ix = cluster_ix
                        best_dist = cur_dist
                assert best_cluster_ix is not None
                assignments[best_cluster_ix].append(ix)
            done = True
            for cluster_ix in range(len(centroids)):
                new_c = self._centroid(assignments[cluster_ix])
                if new_c != centroids[cluster_ix]:
                    centroids[cluster_ix] = new_c
                    done = False
            if done:
                break
            rounds -= 1
        if rounds <= 0:
            print("exhausted iteration steps")
        return list(zip(centroids, assignments))
    
    def _remove(self, all_ixs: list[int], remove_ix: int) -> list[int]:
        return [ix for ix in all_ixs if ix != remove_ix]
    
    def build(self, max_node_size: int) -> None:
        all_ixs = [ix for ix in self._get_all_ix()]
        root_ix = self._centroid(all_ixs)
        all_ixs = self._remove(all_ixs, root_ix)
        
        def build_level(cur_root_ix: int, cur_all_ixs: list[int]) -> Node[T]:
            node = Node(self, cur_root_ix)
            if len(cur_all_ixs) <= max_node_size:
                for child_ix in cur_all_ixs:
                    node.add_child(Node(self, child_ix))
                return node
            children = self._kmedoid(cur_all_ixs, max_node_size)
            for row in children:
                centroid_ix, assignments = row
                assignments = self._remove(assignments, centroid_ix)
                cnode = build_level(centroid_ix, assignments)
                node.add_child(cnode)
            return node
        
        self._root = build_level(root_ix, all_ixs)
    
    def get_closest(self, embed: T, count: int) -> list[tuple[int, float]]:
        assert self._root is not None
        cache = {}
        closest = self._root.get_closest(embed, count, cache)
        print(f"get_closest main comparisons: {len(cache)}")
        res = sorted([
            (row[0].get_ix(), row[0].get_dist(embed, cache))
            for row in closest
        ], key=lambda cur: cur[1])
        print(f"get_closest final comparisons: {len(cache)}")
        return res
    
    def debug(self, from_ix: int | None, to_ix: int | None) -> str:
        return "\n".join((line.rstrip()[from_ix:to_ix] for line in self._root.debug()))

In [175]:
test_size = 100

In [176]:
# dbs = DBScANN(lambda: range(0, x.shape[0]), lambda ix: x[ix, :], lambda a, b: np.log1p(np.exp(-np.dot(a, b))))
dbs = DBScANN(lambda: range(0, test_size), lambda ix: x[ix, :], lambda a, b: np.log1p(np.exp(-np.dot(a, b))))
dbs

<__main__.DBScANN at 0x49e204340>

In [177]:
%%time

dbs.build(max_node_size=10)

CPU times: user 43.3 ms, sys: 5.11 ms, total: 48.4 ms
Wall time: 44.5 ms


In [178]:
print(dbs.debug(0, 120))

┣(38)━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┣(40)━┳━━━━━┳━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓      ┣(82)┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━┳━━━━━━┳━━
┗(12) ┗(16) ┗(21) ┣(49)━┳━━━━━┓      ┗(54) ┗(65) ┗(66) ┗(75) ┣(77)  ┗(83)  ┗(0) ┣(1)━━┓      ┣(70)┓      ┗(8) ┣(37)  ┣(6
                  ┗(88) ┗(91) ┗(96)                          ┗(89)              ┗(46) ┗(78)  ┗(3) ┗(58)       ┗(32)  ┗(7


In [180]:
print(dbs.debug(110, 230))

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━
┳━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━┓                   ┣(2)━━┳━━━━━┓      ┣(22)
┣(37)  ┣(63)━┳━━━━━┓      ┣(17)━┳━━━━━┳━━━━━┓      ┣(24)━┳━━━━━┓      ┗(27) ┣(31)━┳━━━━━┓       ┗(30) ┗(35) ┗(93)  ┗(10)
┗(32)  ┗(71) ┗(81) ┗(87)  ┗(15) ┗(57) ┗(62) ┗(76)  ┗(42) ┗(44) ┗(98)        ┗(14) ┗(36) ┗(99)


In [148]:
test_ix = 10000
test_embed = x[test_ix, :]
test_embed.shape

(768,)

In [138]:
test_dists = np.log1p(np.exp(-np.dot(x[:test_size, :], test_embed)))
test_dists.shape

(100,)

In [139]:
np.argsort(test_dists)[:10]

array([82, 73, 12, 17, 22, 60, 38, 75, 70, 51])

In [140]:
test_dists[np.argsort(test_dists)[:10]]

array([2.71954057e-119, 2.80758998e-119, 3.46409662e-119, 4.15256151e-119,
       4.51581357e-119, 4.57619576e-119, 4.83220119e-119, 4.83843032e-119,
       5.04439785e-119, 5.16901212e-119])

In [141]:
dbs.get_closest(test_embed, 10)

get_closest main comparisons: 73
get_closest final comparisons: 77


[(73, 2.807589980862743e-119),
 (60, 4.576195761482374e-119),
 (38, 4.832201191724879e-119),
 (51, 5.169012115447932e-119),
 (47, 5.385610348584448e-119),
 (74, 6.195297811965087e-119),
 (13, 7.903206067254006e-119),
 (6, 8.033364595124181e-119),
 (67, 1.0662295533384053e-118),
 (18, 1.2968034288577648e-118)]