In [1]:
import nanopq
import numpy as np
import networkx as nx
from collections import defaultdict
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors
from sklearn.datasets import make_blobs
import sys
import faiss
from numba import njit

In [2]:
N, D = 1000000, 128

def ivecs_read(fname):
    a = np.fromfile(fname, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy()


def fvecs_read(fname):
    return ivecs_read(fname).view('float32')

def load_sift():
    print("Loading sift...", end='', file=sys.stderr)
    xb = fvecs_read("sift/sift_base.fvecs")
    xq = fvecs_read("sift/sift_query.fvecs")
    gt = ivecs_read("sift/sift_groundtruth.ivecs")
    print("done", file=sys.stderr)

    return xb, xq, gt

def load_gist():
    print("Loading gist...", end='', file=sys.stderr)
    xb = fvecs_read("gist/gist_base.fvecs")
    xq = fvecs_read("gist/gist_query.fvecs")
    gt = ivecs_read("gist/gist_groundtruth.ivecs")
    print("done", file=sys.stderr)

    return xb, xq, gt


vectors_base, queries, gt = load_sift()

Loading sift...done


In [7]:
   
def generate_graph(vectors, k_nearest):
    index = faiss.IndexFlatL2(vectors.shape[1])  # длина вектора
    index.add(vectors.astype('float32'))
    _, indices = index.search(vectors.astype('float32'), k_nearest)
    G = nx.Graph()
    for i in tqdm(range(len(vectors)), total=len(vectors)):
        for index in indices[i]:
            if index != i:
                G.add_edge(i, index)
    return G


k = 35    # количество ближайших соседей для связывания вершин


# генерируем граф
G = generate_graph(vectors_base, k)


100%|██████████| 1000000/1000000 [01:00<00:00, 16397.20it/s]


In [26]:
G_nodes = G.nodes()
G_edges = {}
for node in G_nodes:
    G_edges[node] = list(G.edges(node))


In [27]:
pq = nanopq.PQ(M=2, Ks=128, verbose=False)

pq.fit(vectors_base)
X_code = pq.encode(vectors_base)

In [28]:
indexes_map = defaultdict(list)

for i in range(len(X_code)):
    indexes_map[(X_code[i][0], X_code[i][1])].append(i)

In [29]:
def bfs(query, start, G_edges, vectors):
    queue = []
    queue.append([start, 0])
    best_node = start
    best_dist = np.linalg.norm(vectors[best_node] - query)
    was = set()
    was.add(best_node)
    while len(queue) > 0:
        node = queue.pop(0)
        if node[1] > 0:
            continue
        for edge in G_edges[node[0]]:
            dst = edge[1]
            if dst in was:
                continue
            was.add(dst)
            queue.append([dst, node[1] + 1])
            dist = np.linalg.norm(vectors[dst] - query)
            if dist < best_dist:
                best_node = dst
                best_dist = dist
    return best_node, best_dist


def find_nearest(query, query_mi, G_edges=G_edges, G_nodes=G_nodes, indexes_map=indexes_map, vectors=vectors_base):
    query_mi_neighbours = indexes_map[(query_mi[0][0], query_mi[0][1])]
    if len(query_mi_neighbours) > 0:
        best_node = query_mi_neighbours[0]
    else:
        best_node = np.random.choice(G_nodes)
    best_dist = np.linalg.norm(vectors[best_node] - query)
    queue = []
    queue.append(best_node)
    was = set()
    was.add(best_node)
    while len(queue) > 0:
        node = queue.pop(0)
        for edge in G_edges[node]:
            dst = edge[1]
            if dst in was:
                continue
            was.add(dst)
            dist = np.linalg.norm(vectors[dst] - query)
            if dist < best_dist:
                queue.append(dst)
                best_node = dst
                best_dist = dist
    return best_node, best_dist
    return bfs(query, best_node, G_edges, vectors)


In [20]:
index = faiss.IndexHNSWFlat(D, 32)
index.hnsw.efConstruction = 40

index.train(vectors_base)
print(index.ntotal)   # 0
index.add(vectors_base)
print(index.ntotal)   # 1000000

0
1000000


In [31]:
import time

dist_better_my = 0
dist_better_hnsw = 0
dist_equal = 0

gt_good_my = 0
gt_good_hnsw = 0

result_my = []
result_hnsw = []

start = time.time()

for test_number in tqdm(range(len(queries))):
    query = np.array([queries[test_number]])
    query_mi = pq.encode(query).astype(np.int32)
    d1, i1 = find_nearest(query, query_mi, G_edges, G_nodes, indexes_map, vectors_base)
    result_my.append([d1, i1])

end = time.time()

print("my time:", end-start)
start = time.time()
for test_number in range(len(queries)):
    d2, i2 = index.search(np.array([queries[test_number]]), 1)
    d2 = np.linalg.norm(vectors_base[i2[0][0]] - queries[test_number])
    result_hnsw.append([d2, i2[0][0]])
end = time.time()
print("hnsw time:", end-start)

for test_number in range(len(queries)):
    d1, i1 = result_my[test_number][0], result_my[test_number][1]
    gt_good_my += gt[test_number][0] == i1
    
    d2, i2 = result_hnsw[test_number][0], result_hnsw[test_number][1]
    gt_good_hnsw += gt[test_number][0] == i2
    
    dist_better_my += d1 < d2
    dist_equal += d1 == d2
    dist_better_hnsw += d1 > d2
print(dist_better_my, dist_equal, dist_better_hnsw)
print(gt_good_my, gt_good_hnsw)

100%|██████████| 10000/10000 [00:15<00:00, 658.05it/s]


my time: 15.19741702079773
hnsw time: 2.7055742740631104
10 0 9990
0 9080


fi

my time: 5.14903450012207

hnsw time: 2.6712119579315186

157 4900 4943

4839 9092

base

my time: 18.093042612075806

hnsw time: 2.7155497074127197

353 6644 3003

6700 9092


v2 

my time: 14.344702243804932

hnsw time: 2.6752655506134033

273 6232 3495

6215 9092

In [164]:
T = 21
print(find_nearest(G, np.array([queries[T]])))
print(index.search(np.array([queries[T]]), 1)[1][0][0])
print(gt[T])

(173.01733, 337194)
4490
[  4490   4457 440896 105025  30046 214449 214475 234950 337194 566609
 554601 559990 497354 337027 306259  59773 365266 776961 261517  30205
 498443 104882 293784 787655 187642 161744  94526 290612 190535 497211
 104971 554692 321611 554617 523702  99066 151645 501915 710368 525244
 261426 570662  94528 114300 278717 190529 366228 337168 365353 312454
 776967 825736 563905 547063 357803  51303 122700 805687 776965 152427
 498491 869260 124074 106313 534926  42882 337012 122254 494271 825636
  94532 151174 428246 790898  51399 365529 674780 365285 560050  52746
 526223 295102 122033 292386 151148 337470 689876 776959 385048 302199
 525423 160586 427618 852710 152142 905096  90283 571084 305080 114777]


In [165]:
pq.encode(np.array([queries[T]]))


array([[85,  9]], dtype=uint8)

In [166]:
pq.encode(np.array([vectors_base[337194]]))


array([[85,  9]], dtype=uint8)

In [167]:
pq.encode(np.array([vectors_base[4490]]))


array([[24,  9]], dtype=uint8)