In [22]:
import numpy as np
from tqdm import tqdm

In [2]:
# now define a function to read the fvecs file format of Sift1M dataset
def read_fvecs(fp):
    a = np.fromfile(fp, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy().view('float32')

In [40]:
# read in data
# data we will search through

xb = read_fvecs('./sift/sift_base.fvecs')  # 1M samples
# also get some query vectors to search with
xq = read_fvecs('./sift/sift_query.fvecs')
# take just one query (there are many in sift_learn.fvecs)
xq = xq[0].reshape(1, xq.shape[1])

groundtruth = read_fvecs('./sift/sift_groundtruth.ivecs')

In [41]:
groundtruth.shape

(10000, 100)

In [4]:
class Node:
    """
    Node for a navigable small world graph.

    Parameters
    ----------
    idx : int
        For uniquely identifying a node.

    value : 1d np.ndarray
        To access the embedding associated with this node.

    neighborhood : set
        For storing adjacent nodes.

    References
    ----------
    https://book.pythontips.com/en/latest/__slots__magic.html
    https://hynek.me/articles/hashes-and-equality/
    """
    __slots__ = ['idx', 'value', 'neighborhood']

    def __init__(self, idx, value):
        self.idx = idx
        self.value = value
        self.neighborhood = set()

    def __hash__(self):
        return hash(self.idx)

    def __eq__(self, other):
        return (
            self.__class__ == other.__class__ and
            self.idx == other.idx
        )

In [44]:
from scipy.spatial import distance


def build_nsw_graph(index_factors, k):
    n_nodes = index_factors.shape[0]

    graph = []
    for i, value in enumerate(index_factors):
        node = Node(i, value)
        graph.append(node)

    for node in tqdm(graph):
        query_factor = node.value.reshape(1, -1)

        # note that the following implementation is not the actual procedure that's
        # used to find the k closest neighbors, we're just implementing a quick version,
        # will come back to this later

        # https://codereview.stackexchange.com/questions/55717/efficient-numpy-cosine-distance-calculation
        # the smaller the cosine distance the more similar, thus the most
        # similar item will be the first element after performing argsort
        # since argsort by default sorts in ascending order
        dist = distance.cdist(index_factors, query_factor, metric='cosine').ravel()
        neighbors_indices = np.argsort(dist)[:k].tolist()
        
        # insert bi-directional connection
        node.neighborhood.update(neighbors_indices)
        for i in neighbors_indices:
            graph[i].neighborhood.add(node.idx)

    return graph

In [45]:
import heapq
import random
from typing import List, Tuple


def nsw_knn_search(
    graph: List[Node],
    query: np.ndarray,
    k: int=5,
    m: int=50) -> Tuple[List[Tuple[float, int]], float]:
    """
    Performs knn search using the navigable small world graph.

    Parameters
    ----------
    graph :
        Navigable small world graph from build_nsw_graph.

    query : 1d np.ndarray
        Query embedding that we wish to find the nearest neighbors.

    k : int
        Number of nearest neighbors returned.

    m : int
        The recall set will be chosen from m different entry points.

    Returns
    -------
    The list of nearest neighbors (distance, index) tuple.
    and the average number of hops that was made during the search.
    """
    result_queue = []
    visited_set = set()
    
    hops = 0
    for _ in range(m):
        # random entry point from all possible candidates
        entry_node = random.randint(0, len(graph) - 1)
        entry_dist = distance.cosine(query, graph[entry_node].value)
        candidate_queue = []
        heapq.heappush(candidate_queue, (entry_dist, entry_node))

        temp_result_queue = []
        while candidate_queue:
            candidate_dist, candidate_idx = heapq.heappop(candidate_queue)

            if len(result_queue) >= k:
                # if candidate is further than the k-th element from the result,
                # then we would break the repeat loop
                current_k_dist, current_k_idx = heapq.nsmallest(k, result_queue)[-1]
                if candidate_dist > current_k_dist:
                    break

            for friend_node in graph[candidate_idx].neighborhood:
                if friend_node not in visited_set:
                    visited_set.add(friend_node)

                    friend_dist = distance.cosine(query, graph[friend_node].value)
                    heapq.heappush(candidate_queue, (friend_dist, friend_node))
                    heapq.heappush(temp_result_queue, (friend_dist, friend_node))
                    hops += 1

        result_queue = list(heapq.merge(result_queue, temp_result_queue))

    return heapq.nsmallest(k, result_queue), hops / m

In [46]:
k = 10
selected_rows = np.random.choice(xb.shape[0], round(0.001*xb.shape[0]), replace=False)
index_factors = xb[selected_rows]
graph = build_nsw_graph(index_factors, k)
graph[0].neighborhood


100%|██████████| 1000/1000 [00:00<00:00, 3597.02it/s]


{0,
 34,
 51,
 81,
 120,
 144,
 187,
 201,
 276,
 290,
 296,
 530,
 562,
 636,
 763,
 879,
 966}

In [None]:
results = nsw_knn_search(graph, xq, k=5)
results

In [42]:
def build_nsw_graph(index_factors: np.ndarray, k: int) -> List[Node]:
    n_nodes = index_factors.shape[0]

    graph = []
    for i, value in enumerate(index_factors):
        node = Node(i, value)
        if i > k:
            neighbors, hops = nsw_knn_search(graph, node.value, k)
            neighbors_indices = [node_idx for _, node_idx in neighbors]
        else:
            neighbors_indices = list(range(i))

        # insert bi-directional connection
        node.neighborhood.update(neighbors_indices)
        for i in neighbors_indices:
            graph[i].neighborhood.add(node.idx)
        
        graph.append(node)

    return graph

In [43]:
k = 10
selected_rows = np.random.choice(xb.shape[0], round(0.001*xb.shape[0]), replace=False)
index_factors = xb[selected_rows]
graph = build_nsw_graph(index_factors, k)
graph[0].neighborhood

{1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 15,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 28,
 30,
 31,
 32,
 38,
 42,
 44,
 45,
 51,
 59,
 61,
 63,
 65,
 71,
 73,
 74,
 75,
 77,
 85,
 95,
 102,
 105,
 109,
 110,
 158,
 173,
 177,
 178,
 183,
 235,
 244,
 247,
 305,
 342,
 384,
 386,
 419,
 443,
 464,
 513,
 732,
 779,
 809,
 889,
 939,
 941,
 978,
 980}

In [12]:
def test():
    def time(graph):
        start = datetime.now()
        for _ in range(100):
            query_point = np.random.rand(10)
            nearest_neighbor = graph.greedy_search(query_point)
        end = datetime.now()
        print(len(graph.nodes))
        print(end - start)
        print()

    def add(graph, node_count):
        for _ in range(node_count):
            graph.add_node(np.random.rand(10))

    nsw1 = NSWGraph()
    nsw2 = NSWGraph()
    nsw3 = NSWGraph()

    add(nsw1, 1000)
    add(nsw2, 2000)
    add(nsw3, 4000)

    time(nsw1)
    time(nsw2)
    time(nsw3)

test()

1000
0:00:00.108000

2000
0:00:00.214994

4000
0:00:00.432016

