In [95]:
import numpy as np
from tqdm import tqdm
from scipy.spatial import distance

In [130]:
# now define a function to read the fvecs file format of Sift1M dataset
def read_fvecs(fp):
    a = np.fromfile(fp, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy().view('float32')

def read_ivecs(fname):
    a = np.fromfile(fname, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy()

In [203]:
def calculate_recall(predicted_neighbors, actual_neighbors):
    total_recall = 0
    
    for pred, actual in zip(predicted_neighbors, actual_neighbors):
        true_positives = len(set(pred) & set(actual))
        possible_positives = len(set(actual))

        recall = true_positives / possible_positives if possible_positives else 0

        total_recall += recall

    average_recall = total_recall / len(actual_neighbors)

    return average_recall

In [127]:
# read in data
# data we will search through

base = read_fvecs('./siftsmall/siftsmall_base.fvecs')  # 1M samples
# also get some query vectors to search with
query = read_fvecs('./siftsmall/siftsmall_query.fvecs')
# take just one query (there are many in sift_learn.fvecs)
# xq = xq[0].reshape(1, xq.shape[1])

groundtruth = read_ivecs('./siftsmall/siftsmall_groundtruth.ivecs')

In [99]:
class Node:
    """
    Node for a navigable small world graph.

    Parameters
    ----------
    idx : int
        For uniquely identifying a node.

    value : 1d np.ndarray
        To access the embedding associated with this node.

    neighborhood : set
        For storing adjacent nodes.

    References
    ----------
    https://book.pythontips.com/en/latest/__slots__magic.html
    https://hynek.me/articles/hashes-and-equality/
    """
    __slots__ = ['idx', 'value', 'neighborhood']

    def __init__(self, idx, value):
        self.idx = idx
        self.value = value
        self.neighborhood = set()

    def __hash__(self):
        return hash(self.idx)

    def __eq__(self, other):
        return (
            self.__class__ == other.__class__ and
            self.idx == other.idx
        )

In [100]:
import heapq
import random
from typing import List, Tuple


def greedy_search(
    graph: List[Node],
    query: np.ndarray,
    k: int=5,
    m: int=50) -> Tuple[List[Tuple[float, int]], float]:
    """
    Performs knn search using the navigable small world graph.

    Parameters
    ----------
    graph :
        Navigable small world graph from build_nsw_graph.

    query : 1d np.ndarray
        Query embedding that we wish to find the nearest neighbors.

    k : int
        Number of nearest neighbors returned.

    m : int
        The recall set will be chosen from m different entry points.

    Returns
    -------
    The list of nearest neighbors (distance, index) tuple.
    and the average number of hops that was made during the search.
    """
    result_queue = []
    visited_set = set()
    
    hops = 0
    for _ in range(m):
        # random entry point from all possible candidates
        entry_node = random.randint(0, len(graph) - 1)
        entry_dist = distance.cosine(query, graph[entry_node].value)
        candidate_queue = []
        heapq.heappush(candidate_queue, (entry_dist, entry_node))

        temp_result_queue = []
        while candidate_queue:
            candidate_dist, candidate_idx = heapq.heappop(candidate_queue)

            if len(result_queue) >= k:
                # if candidate is further than the k-th element from the result,
                # then we would break the repeat loop
                current_k_dist, current_k_idx = heapq.nsmallest(k, result_queue)[-1]
                if candidate_dist > current_k_dist:
                    break

            for friend_node in graph[candidate_idx].neighborhood:
                if friend_node not in visited_set:
                    visited_set.add(friend_node)

                    friend_dist = distance.cosine(query, graph[friend_node].value)
                    heapq.heappush(candidate_queue, (friend_dist, friend_node))
                    heapq.heappush(temp_result_queue, (friend_dist, friend_node))
                    hops += 1

        result_queue = list(heapq.merge(result_queue, temp_result_queue))

    return heapq.nsmallest(k, result_queue), hops / m

In [108]:
def beam_search(
    graph: List[Node],
    query: np.ndarray,
    k: int = 5,
    m: int = 50,
    beam_width: int = 10) -> Tuple[List[Tuple[float, int]], float]:
    """
    Performs knn search using beam search on the navigable small world graph.

    Parameters
    ----------
    graph :
        Navigable small world graph from build_nsw_graph.

    query : 1d np.ndarray
        Query embedding that we wish to find the nearest neighbors.

    k : int
        Number of nearest neighbors returned.

    m : int
        The recall set will be chosen from m different entry points.

    beam_width : int
        Number of nodes to consider at each level of the search.

    Returns
    -------
    The list of nearest neighbors (distance, index) tuple.
    and the average number of hops that was made during the search.
    """
    result_queue = []
    visited_set = set()

    hops = 0
    for _ in range(m):
        entry_node = random.randint(0, len(graph) - 1)
        entry_dist = distance.cosine(query, graph[entry_node].value)
        candidate_queue = []
        heapq.heappush(candidate_queue, (entry_dist, entry_node))

        while candidate_queue:
            temp_result_queue = []
            # Consider up to beam_width best candidates
            for _ in range(min(beam_width, len(candidate_queue))):
                candidate_dist, candidate_idx = heapq.heappop(candidate_queue)

                if len(result_queue) >= k:
                    current_k_dist, _ = heapq.nsmallest(k, result_queue)[-1]
                    if candidate_dist > current_k_dist:
                        break

                for friend_node in graph[candidate_idx].neighborhood:
                    if friend_node not in visited_set:
                        visited_set.add(friend_node)
                        friend_dist = distance.cosine(query, graph[friend_node].value)
                        heapq.heappush(candidate_queue, (friend_dist, friend_node))
                        heapq.heappush(temp_result_queue, (friend_dist, friend_node))
                        hops += 1

            result_queue = list(heapq.merge(result_queue, temp_result_queue))

    return heapq.nsmallest(k, result_queue), hops / m

In [101]:
def build_nsw_graph(index_factors: np.ndarray, k: int) -> List[Node]:
    n_nodes = index_factors.shape[0]
    tqdm_loader = tqdm(index_factors)
    tqdm_loader.set_description("Building Graph")
    graph = []
    for i, value in enumerate(tqdm_loader):
        node = Node(i, value)
        if i > k:
            neighbors, hops = greedy_search(graph, node.value, k)
            neighbors_indices = [node_idx for _, node_idx in neighbors]
        else:
            neighbors_indices = list(range(i))

        # insert bi-directional connection
        node.neighborhood.update(neighbors_indices)
        for i in neighbors_indices:
            graph[i].neighborhood.add(node.idx)
        
        graph.append(node)

    return graph

In [103]:
k = 10
# selected_rows = np.random.choice(xb.shape[0], round(0.001*xb.shape[0]), replace=False)
# index_factors = xb[selected_rows]
graph = build_nsw_graph(base, k)

Building Graph: 100%|██████████| 10000/10000 [40:28<00:00,  4.12it/s]


In [104]:
import pickle
with open("graph.pkl", "wb") as f:
    pickle.dump(graph, f)

In [123]:
with open('graph.pkl', 'rb') as f:
    objects = pickle.load(f)

In [154]:
len(query)

100

In [193]:
k = 5
results_greedy = []
results_beam = []
for q in tqdm(query):
  g = [r[1] for r in greedy_search(graph, q, k=k)[0]]
  b = [r[1] for r in beam_search(graph, q, k=k)[0]]
  results_greedy.append(g)
  results_beam.append(b)

100%|██████████| 100/100 [00:51<00:00,  1.93it/s]


In [194]:
true = groundtruth[:, :k]

In [204]:
average_recall = calculate_recall(results_greedy, true)
print(average_recall)

0.9919999999999999


In [206]:
average_recall = calculate_recall(results_beam, true)
print(average_recall)

0.9059999999999995


In [12]:
def test():
    def time(graph):
        start = datetime.now()
        for _ in range(100):
            query_point = np.random.rand(10)
            nearest_neighbor = graph.greedy_search(query_point)
        end = datetime.now()
        print(len(graph.nodes))
        print(end - start)
        print()

    def add(graph, node_count):
        for _ in range(node_count):
            graph.add_node(np.random.rand(10))

    nsw1 = NSWGraph()
    nsw2 = NSWGraph()
    nsw3 = NSWGraph()

    add(nsw1, 1000)
    add(nsw2, 2000)
    add(nsw3, 4000)

    time(nsw1)
    time(nsw2)
    time(nsw3)

test()

1000
0:00:00.108000

2000
0:00:00.214994

4000
0:00:00.432016

