In [2]:
import numpy as np
import heapq

In [139]:
d = 5
data = np.random.normal(0, 20, (100, d))
np.random.shuffle(data)
q = np.random.normal(0, 20, d)
k = 3

In [102]:
def distance(d, q):
    return np.add.reduce((d-q)**2, axis=1)

### baseline KNN

In [115]:
data1 = distance(data,q)
heap = []
for i, d in enumerate(data1.tolist()):
    heapq.heappush(heap, (-d, i))
    if len(heap) > k:
        heapq.heappop(heap)
ind = np.array(heap)[:,1].astype(int) # 가장 가까운 원소 k개의 index를 반환

In [128]:
print(f'데이터 q:\t {q}')
for i, d in zip(ind, data[ind]):
    print(f'데이터 {i}:\t {d}, 거리: {np.sqrt(np.sum((d-q)**2))}')

데이터 q:	 [ 32.48752858   1.09566161   0.06163878 -12.25777447  -8.52109092]
데이터 7:	 [ 13.65714253   5.80998071  -1.10041929 -12.2630388  -26.14216378], 거리: 26.242348578844304
데이터 75:	 [  9.97696162  -6.00910521   0.75921347 -10.34061601  -6.03531261], 거리: 23.823193257580552
데이터 92:	 [ 20.39456905   5.03509057 -16.70913761 -14.53920643  -1.50416001], 거리: 22.303810498177864


### KNN with KD-tree

In [134]:
def dist(p: np.ndarray, b: np.ndarray) -> np.ndarray:
    '''
    p: (k)-dim np.ndarray
    b: (k,1) or (k,2)-dim np.ndarray
    out: (1)-dim np.ndarray
    '''
    if b.ndim > 1: # len(b.shape)
        temp = ((b-p)>0)
        mask1 = np.all(temp, axis=1) # b>p
        mask2 = np.all(~temp, axis=1) # b<p
        blp = b[mask1][:,0]-p[mask1]
        bgp = b[mask2][:,1]-p[mask2]
        return np.sum(blp**2)+np.sum(bgp**2)
    else:
        dist_sq = np.sum((p - b.ravel())**2)
        return np.sqrt(dist_sq)

In [188]:
ls = 2 # leaf size
def c_kdt(d: np.ndarray, depth: int):
    size = d.shape[0] # array length
    if size <= ls:
        return d
    dim = d.shape[1] # point dimension
    r = depth % dim # selected axis
    mi = np.argpartition(d[:,r], kth=size//2)[size//2] # TODO: 이거 좀 더 이해하기. 대체 뭐임??
    p = d[mi]
    rd = np.delete(d, mi, axis=0)
    ml = rd[:,r] <= p[r]
    return (p, c_kdt(rd[ml], depth+1), c_kdt(rd[~ml], depth+1))
kdtree = c_kdt(data, 0)

In [196]:
def is_leaf(tree):
    return not isinstance(tree, tuple)

#### One Queue

In [201]:
def kdt_search(kdtree: tuple, query, k):
    pq = [(0, 'node', kdtree)]
    heapq.heapify(pq)
    result = [] # Contain (dist, point) in ascending order
    while pq and len(result) < k:
        key, kind, tr = heapq.heappop(pq)
        if kind == 'point':
            result.append((key, tr))
            continue

        if is_leaf(tr):
            for point in tr:
                d = dist(query, point)
                heapq.heappush(pq, (d, 'point', point))
        else:
            lbl = dist(tr[1][0], query)
            lbr = dist(tr[2][0], query)
            heapq.heappush(pq, (lbl, 'node', tr[1]))
            heapq.heappush(pq, (lbr, 'node', tr[2]))
    return result

In [202]:
kdt_search(kdtree, q, k)

[(np.float64(38.832389101279894),
  array([ -3.17110616, -23.88504447,  11.84150113,  36.6289676 ,
         -31.41702688])),
 (np.float64(40.79060663402517),
  array([-12.9246372 , -11.12465565,   6.0563385 ,  18.68290921,
          -7.84197982])),
 (np.float64(41.2315872117607),
  array([-12.69452964, -30.86323103,  12.48987864,  -6.60664743,
         -23.95672778]))]

#### Two Queue

In [220]:
def kdt_search2(kdtree: tuple, query, k):
    pq_nodes = [(0, kdtree)]
    pq_nbrs = []
    tau = np.inf
    while pq_nodes and pq_nodes[0][0] < tau:
        _, node = heapq.heappop(pq_nodes)
        if is_leaf(node):
            for point in node:
                d = dist(query, point)
                heapq.heappush(pq_nbrs, (-d, point))
                if len(pq_nbrs) > k:
                    heapq.heappop(pq_nbrs)
                tau = -pq_nbrs[0][0]
        else:
            lbl = dist(node[1][0], query)
            lbr = dist(node[2][0], query)
            # 분할점 누락
            if lbl < tau:
                heapq.heappush(pq_nodes, (lbl, node[1])) # box를 가지고 거리를 계산해야 되는데, 내부점을 가지고 거리를 계산하고 앉았음...
            if lbr < tau:
                heapq.heappush(pq_nodes, (lbr, node[2]))
    return sorted(pq_nbrs, key=lambda t: -t[0])

In [219]:
kdt_search2(kdtree, q, k) # 지금은 수정하기 어려운 버그가 있음...

[(np.float64(-38.832389101279894),
  array([ -3.17110616, -23.88504447,  11.84150113,  36.6289676 ,
         -31.41702688])),
 (np.float64(-40.79060663402517),
  array([-12.9246372 , -11.12465565,   6.0563385 ,  18.68290921,
          -7.84197982]))]

### ANN (DiskANN)