In [1]:
import heapq
import numpy as np
import random
import tqdm
import matplotlib.pyplot as plt

### Settings

In [2]:
K = 20    # neighbours kept per point
N = 1000  # number of points in the dataset
D = 50    # dimensionality of each point


def sim(x, y):
    """Return the Euclidean distance between `x` and `y`.

    Despite the name, this is a distance: smaller values mean the two
    points are closer together.
    """
    diff = x - y
    return np.sqrt(np.sum(diff ** 2))

### Dataset

In [3]:
# Seed both the stdlib and the NumPy global generators so that re-running
# the notebook from a fresh kernel reproduces the same data.
random.seed(1234)
np.random.seed(1234)
# N rows, each holding D independent standard-normal samples, drawn row by row.
V = np.array([np.random.randn(D) for _ in range(N)])
print(V.shape)

(1000, 50)


### Exact

In [4]:
# Brute-force ground truth: for every point, the indices of its K+1 closest
# points in ascending (distance, index) order.  The point itself (distance 0)
# is normally the first entry, which is why K+1 entries are kept.
answer = []
for i in tqdm.tqdm(range(N)):
    scored = ((sim(V[i], V[j]), j) for j in range(N))
    # nsmallest keeps only K+1 candidates at a time instead of heapifying all
    # N of them, and simply returns fewer items when N < K+1 rather than
    # raising on an empty heap.
    nearest = heapq.nsmallest(K + 1, scored)
    answer.append([j for _, j in nearest])

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 98.25it/s]


### Utils

In [5]:
def recall(exact, result, k=None):
    """Fraction of the true nearest neighbours that `result` recovered.

    Parameters
    ----------
    exact : sequence of sequences
        Ground-truth neighbour indices per query.  Element 0 of every row is
        skipped (it is expected to be the query point itself) and the next
        `k` entries are treated as the true neighbours.
    result : sequence of sequences
        Candidate neighbour indices per query, aligned with `exact`.
    k : int, optional
        Number of true neighbours scored per query.  Defaults to the
        notebook-level constant ``K``.

    Returns
    -------
    float
        Hits divided by ``k * number_of_queries``; 0.0 when there is nothing
        to score (no queries, or ``k == 0``) instead of dividing by zero.
    """
    if k is None:
        k = K

    hits = 0
    possible = 0
    # zip stops at the shorter input, so only aligned query pairs are scored.
    for truth, found in zip(exact, result):
        hits += len(set(truth[1:k + 1]) & set(found))
        possible += k

    if possible == 0:
        return 0.0
    return hits / possible

In [6]:
def update_pq(pqs, k=None):
    """Trim every heap in `pqs` to its `k` closest distinct neighbours, in place.

    Parameters
    ----------
    pqs : list of list
        One heap per point; each heap holds ``(distance, node)`` tuples and
        may contain the same node several times.
    k : int, optional
        Maximum number of entries kept per heap.  Defaults to the
        notebook-level constant ``K``.

    Returns
    -------
    None
        `pqs` itself is modified: each inner list is replaced by a new list
        holding at most `k` entries with unique nodes.
    """
    limit = K if k is None else k

    trimmed = []
    for pq in pqs:
        kept = []
        seen = set()
        while pq and len(kept) < limit:
            dist, node = heapq.heappop(pq)
            # Skip repeats of a node even if they are not adjacent in pop order.
            if node in seen:
                continue
            seen.add(node)
            # Items are popped in ascending order, so appending them yields a
            # sorted list, which already satisfies the heap invariant.
            kept.append((dist, node))
        trimmed.append(kept)

    # Slice assignment mutates the caller's list; a plain `pqs = trimmed`
    # would only rebind the local name and silently discard the result.
    pqs[:] = trimmed

In [7]:
def knn(nodes, pqs):
    """Push every other member of `nodes` onto each member's heap.

    For each ordered pair ``(a, b)`` of different node ids in `nodes`, the
    tuple ``(sim(V[a], V[b]), b)`` is pushed onto ``pqs[a]``.  Nothing is
    removed from the heaps and nothing is returned.
    """
    for a in nodes:
        heap = pqs[a]
        for b in nodes:
            if a != b:
                heapq.heappush(heap, (sim(V[a], V[b]), b))

In [8]:
def divide(nodes, pqs, depth):
    """Recursively partition `nodes` around random pivots, then run `knn`.

    Groups of at most `threshold` nodes are handed to `knn` directly.
    Larger groups are split into `div` buckets: `div` pivots are sampled
    from `nodes` and every node goes to the bucket of its closest pivot.
    `depth` is only passed along (incremented) to the recursive calls.
    """
    # Small enough: compare all pairs directly and stop recursing.
    if len(nodes) <= threshold:
        knn(nodes, pqs)
        return

    pivots = random.sample(nodes, div)
    buckets = [[] for _ in range(div)]
    for node in nodes:
        # Tuples compare by distance first and bucket index second, so an
        # exact distance tie goes to the lower-numbered bucket.
        _, slot = min((sim(V[node], V[pivot]), s) for s, pivot in enumerate(pivots))
        buckets[slot].append(node)

    for bucket in buckets:
        divide(bucket, pqs, depth + 1)

### Algorithm

In [9]:
# Groups of at most this many nodes are compared exhaustively by knn().
threshold = 150
# Number of pivots (and therefore buckets) used at every split in divide().
div = 2
# Number of independent random partitioning rounds accumulated into the heaps.
EPOCH = 20

In [10]:
# Re-seed so the pivot choices inside divide() are repeatable.
random.seed(1234)
pqs = [[] for _ in range(N)]
for _ in tqdm.tqdm(range(EPOCH)):
    # Every round re-partitions the full dataset and adds candidates to pqs.
    divide(list(range(N)), pqs, 0)

# Drain each heap into its K closest candidates, dropping a node that repeats
# the one popped immediately before it.
result = []
for idx, pq in enumerate(pqs):
    neighbours = []
    previous = -1
    while pq and len(neighbours) < K:
        _, node = heapq.heappop(pq)
        if node != previous:
            neighbours.append(node)
            previous = node
    result.append(neighbours)
    # Sanity checks: fewer than K neighbours found / a node listed twice.
    if len(neighbours) < K:
        print("부족!", idx)
    if len(set(neighbours)) != len(neighbours):
        print("중복!", idx)

print(recall(answer, result))

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:22<00:00,  1.11s/it]

0.98035



