In [95]:
import numpy as np
from tqdm import tqdm
from scipy.spatial import distance

In [85]:
# now define a function to read the fvecs file format of Sift1M dataset
def read_fvecs(fp):
    a = np.fromfile(fp, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy().view('float32')

In [98]:
# read in data
# data we will search through

base = read_fvecs('./siftsmall/siftsmall_base.fvecs')  # 1M samples
# also get some query vectors to search with
query = read_fvecs('./siftsmall/siftsmall_query.fvecs')
# take just one query (there are many in sift_learn.fvecs)
# xq = xq[0].reshape(1, xq.shape[1])

groundtruth = read_fvecs('./siftsmall/siftsmall_groundtruth.ivecs')

In [122]:
groundtruth

array([[3.0492e-42, 5.2577e-42, 1.2359e-42, ..., 4.8765e-43, 4.2642e-42,
        5.1666e-42],
       [3.8970e-42, 1.3416e-41, 3.4920e-42, ..., 5.3936e-42, 4.0708e-42,
        5.7481e-42],
       [3.7933e-42, 1.3926e-41, 3.7807e-42, ..., 1.7530e-42, 1.2001e-41,
        1.1453e-41],
       ...,
       [1.2366e-41, 1.2725e-41, 8.6068e-42, ..., 1.1460e-41, 8.2494e-42,
        6.3969e-42],
       [7.6511e-42, 7.6217e-42, 8.1415e-42, ..., 7.2854e-42, 1.0486e-41,
        7.3316e-42],
       [1.1325e-41, 1.2306e-41, 6.6800e-42, ..., 1.5414e-44, 3.4780e-42,
        5.0881e-42]], dtype=float32)

In [99]:
class Node:
    """
    Node for a navigable small world graph.

    Parameters
    ----------
    idx : int
        For uniquely identifying a node.

    value : 1d np.ndarray
        To access the embedding associated with this node.

    neighborhood : set
        For storing adjacent nodes.

    References
    ----------
    https://book.pythontips.com/en/latest/__slots__magic.html
    https://hynek.me/articles/hashes-and-equality/
    """
    __slots__ = ['idx', 'value', 'neighborhood']

    def __init__(self, idx, value):
        self.idx = idx
        self.value = value
        self.neighborhood = set()

    def __hash__(self):
        return hash(self.idx)

    def __eq__(self, other):
        return (
            self.__class__ == other.__class__ and
            self.idx == other.idx
        )

In [100]:
import heapq
import random
from typing import List, Tuple


def greedy_search(
    graph: List[Node],
    query: np.ndarray,
    k: int=5,
    m: int=50) -> Tuple[List[Tuple[float, int]], float]:
    """
    Performs knn search using the navigable small world graph.

    Parameters
    ----------
    graph :
        Navigable small world graph from build_nsw_graph.

    query : 1d np.ndarray
        Query embedding that we wish to find the nearest neighbors.

    k : int
        Number of nearest neighbors returned.

    m : int
        The recall set will be chosen from m different entry points.

    Returns
    -------
    The list of nearest neighbors (distance, index) tuple.
    and the average number of hops that was made during the search.
    """
    result_queue = []
    visited_set = set()
    
    hops = 0
    for _ in range(m):
        # random entry point from all possible candidates
        entry_node = random.randint(0, len(graph) - 1)
        entry_dist = distance.cosine(query, graph[entry_node].value)
        candidate_queue = []
        heapq.heappush(candidate_queue, (entry_dist, entry_node))

        temp_result_queue = []
        while candidate_queue:
            candidate_dist, candidate_idx = heapq.heappop(candidate_queue)

            if len(result_queue) >= k:
                # if candidate is further than the k-th element from the result,
                # then we would break the repeat loop
                current_k_dist, current_k_idx = heapq.nsmallest(k, result_queue)[-1]
                if candidate_dist > current_k_dist:
                    break

            for friend_node in graph[candidate_idx].neighborhood:
                if friend_node not in visited_set:
                    visited_set.add(friend_node)

                    friend_dist = distance.cosine(query, graph[friend_node].value)
                    heapq.heappush(candidate_queue, (friend_dist, friend_node))
                    heapq.heappush(temp_result_queue, (friend_dist, friend_node))
                    hops += 1

        result_queue = list(heapq.merge(result_queue, temp_result_queue))

    return heapq.nsmallest(k, result_queue), hops / m

In [108]:
def beam_search(
    graph: List[Node],
    query: np.ndarray,
    k: int = 5,
    m: int = 50,
    beam_width: int = 10) -> Tuple[List[Tuple[float, int]], float]:
    """
    Performs knn search using beam search on the navigable small world graph.

    Parameters
    ----------
    graph :
        Navigable small world graph from build_nsw_graph.

    query : 1d np.ndarray
        Query embedding that we wish to find the nearest neighbors.

    k : int
        Number of nearest neighbors returned.

    m : int
        The recall set will be chosen from m different entry points.

    beam_width : int
        Number of nodes to consider at each level of the search.

    Returns
    -------
    The list of nearest neighbors (distance, index) tuple.
    and the average number of hops that was made during the search.
    """
    result_queue = []
    visited_set = set()

    hops = 0
    for _ in range(m):
        entry_node = random.randint(0, len(graph) - 1)
        entry_dist = distance.cosine(query, graph[entry_node].value)
        candidate_queue = []
        heapq.heappush(candidate_queue, (entry_dist, entry_node))

        while candidate_queue:
            temp_result_queue = []
            # Consider up to beam_width best candidates
            for _ in range(min(beam_width, len(candidate_queue))):
                candidate_dist, candidate_idx = heapq.heappop(candidate_queue)

                if len(result_queue) >= k:
                    current_k_dist, _ = heapq.nsmallest(k, result_queue)[-1]
                    if candidate_dist > current_k_dist:
                        break

                for friend_node in graph[candidate_idx].neighborhood:
                    if friend_node not in visited_set:
                        visited_set.add(friend_node)
                        friend_dist = distance.cosine(query, graph[friend_node].value)
                        heapq.heappush(candidate_queue, (friend_dist, friend_node))
                        heapq.heappush(temp_result_queue, (friend_dist, friend_node))
                        hops += 1

            result_queue = list(heapq.merge(result_queue, temp_result_queue))

    return heapq.nsmallest(k, result_queue), hops / m

In [101]:
def build_nsw_graph(index_factors: np.ndarray, k: int) -> List[Node]:
    n_nodes = index_factors.shape[0]
    tqdm_loader = tqdm(index_factors)
    tqdm_loader.set_description("Building Graph")
    graph = []
    for i, value in enumerate(tqdm_loader):
        node = Node(i, value)
        if i > k:
            neighbors, hops = greedy_search(graph, node.value, k)
            neighbors_indices = [node_idx for _, node_idx in neighbors]
        else:
            neighbors_indices = list(range(i))

        # insert bi-directional connection
        node.neighborhood.update(neighbors_indices)
        for i in neighbors_indices:
            graph[i].neighborhood.add(node.idx)
        
        graph.append(node)

    return graph

In [103]:
k = 10
# selected_rows = np.random.choice(xb.shape[0], round(0.001*xb.shape[0]), replace=False)
# index_factors = xb[selected_rows]
graph = build_nsw_graph(base, k)

Building Graph: 100%|██████████| 10000/10000 [40:28<00:00,  4.12it/s]


In [104]:
import pickle
with open("graph.pkl", "wb") as f:
    pickle.dump(graph, f)

In [123]:
with open('graph.pkl', 'rb') as f:
    objects = pickle.load(f)

In [None]:
results1 = greedy_search(objects, query[0], k=5)
results1

In [109]:
results = greedy_search(graph, query[0], k=5)
results

([(0.14819204807281494, 2176),
  (0.1492038369178772, 3752),
  (0.15418195724487305, 882),
  (0.15591514110565186, 4009),
  (0.15687423944473267, 2837)],
 200.0)

In [114]:
groundtruth.shape

(100, 100)

In [116]:
groundtruth[0]

array([3.0492e-42, 5.2577e-42, 1.2359e-42, 5.6178e-42, 3.9755e-42,
       2.6625e-43, 5.0657e-42, 1.1435e-42, 1.4644e-42, 2.6400e-42,
       3.1389e-43, 4.2221e-42, 4.0918e-43, 1.7825e-42, 7.4367e-42,
       6.9196e-42, 1.8147e-42, 6.8944e-43, 1.2907e-41, 5.0797e-42,
       1.7572e-42, 1.8105e-42, 2.2771e-42, 4.9788e-42, 1.6199e-42,
       2.0459e-43, 1.4994e-43, 7.3302e-42, 2.7956e-42, 1.3370e-41,
       4.9648e-42, 1.3674e-41, 1.3741e-41, 1.4910e-42, 1.3594e-41,
       5.6949e-42, 3.4416e-42, 3.8718e-42, 4.5360e-42, 1.8455e-42,
       4.9466e-42, 8.9823e-43, 2.3962e-42, 1.2453e-41, 5.9737e-42,
       2.4607e-42, 8.3798e-43, 5.1848e-43, 3.8900e-42, 1.6956e-43,
       5.6865e-42, 1.0152e-41, 2.6555e-42, 1.7376e-43, 1.2235e-41,
       9.7530e-43, 6.0536e-42, 6.3437e-42, 5.6753e-42, 3.7106e-42,
       2.3570e-42, 3.0184e-42, 2.3668e-42, 3.4136e-42, 2.8096e-42,
       4.4982e-42, 5.6080e-42, 3.8872e-42, 1.2948e-42, 9.2906e-42,
       4.8331e-42, 1.3752e-41, 4.9256e-42, 7.5320e-42, 4.0217e

In [110]:
results = beam_search(graph, query[0], k=5, beam_width = 10)
results

([(0.14819204807281494, 2176),
  (0.15418195724487305, 882),
  (0.15591514110565186, 4009),
  (0.15687423944473267, 2837),
  (0.16574138402938843, 190)],
 4.24)

In [12]:
def test():
    def time(graph):
        start = datetime.now()
        for _ in range(100):
            query_point = np.random.rand(10)
            nearest_neighbor = graph.greedy_search(query_point)
        end = datetime.now()
        print(len(graph.nodes))
        print(end - start)
        print()

    def add(graph, node_count):
        for _ in range(node_count):
            graph.add_node(np.random.rand(10))

    nsw1 = NSWGraph()
    nsw2 = NSWGraph()
    nsw3 = NSWGraph()

    add(nsw1, 1000)
    add(nsw2, 2000)
    add(nsw3, 4000)

    time(nsw1)
    time(nsw2)
    time(nsw3)

test()

1000
0:00:00.108000

2000
0:00:00.214994

4000
0:00:00.432016

