In [2]:
import numpy as np
import faiss
from faiss.contrib.datasets import SyntheticDataset, DatasetSIFT1M
from faiss.contrib.inspect_tools import get_NSG_neighbors
import bz2, io

In [72]:
# ds = SyntheticDataset(128, 0, 1_000_000, 10_000)
ds = DatasetSIFT1M()
print(ds)

dataset in dimension 128, with metric L2, size: Q 10000 B 1000000 T 100000


# Make an NSG index

NSG is a graph based indexing method. It is less popular than HSNW but it is simpler because it's a single graph rather than a layered set of graphs.

In [80]:
# create the C++ index: a NSG graph with 16 neighbors per node
index = faiss.index_factory(ds.d, "NSG64")

In [81]:
index.add(ds.get_database())

In [82]:
graph = get_NSG_neighbors(index.nsg)

In [83]:
graph.shape

(1000000, 64)

# Search

In [78]:
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
import heapq

type NDfloat = NDArray[np.floating]
type NDint = NDArray[np.integer]

def compute_distance[T: np.floating](x: NDArray[T], y: NDArray[T]) -> NDArray[T]:
    '''Last axis must be vector dimensions.'''
    return np.square(x - y).sum(axis=-1)

def graph_search(graph: NDint, entrypoint: np.integer, query: np.floating, max_steps: int) -> tuple[NDint, NDfloat]:
    visited_ids: set[np.integer] = set()
    query = query[None, :] # shape=(edges, d)
    path_ids: NDint = -1*np.ones(max_steps, dtype=np.int32)
    path_distances: NDfloat = np.full(max_steps, np.nan, dtype=np.float32)
    min_distance: float = float('inf')

    for i in range(max_steps):
        endpoint_ids = graph[entrypoint] # shape=(edges, d)
        visited_ids.add(entrypoint)

        # drop invalid ids
        endpoint_ids = endpoint_ids[endpoint_ids != -1]
        endpoint_values = index.reconstruct_batch(endpoint_ids) # shape=(edges, d)
        distances = compute_distance(endpoint_values, query)  # shape=(edges,)

        for j, rank_id in enumerate(np.argsort(distances)):
            if endpoint_ids[rank_id] not in visited_ids and distances[rank_id] < min_distance:
                path_ids[i] = endpoint_ids[rank_id]
                path_distances[i] = min_distance = distances[rank_id]
                entrypoint = endpoint_ids[rank_id]
                break
            elif distances[rank_id] >= min_distance:
                break

        if j == len(distances) - 1:
            print(visited_ids)
            print(path_ids)
            print(entrypoint)
            raise Exception

    return path_ids, path_distances


def greedy_routing(graph: NDint, entrypoint: np.integer, query: np.floating, max_steps: int, heap_size: int) -> tuple[NDint, NDfloat]:
    heap: list[tuple[(np.float32, np.int32)]] = list()
    visited_ids: set[np.integer] = set()
    query = query[None, :] # shape=(edges, d)
    path_ids: NDint = -1*np.ones(max_steps, dtype=np.int32)
    path_distances: NDfloat = np.full(max_steps, np.nan, dtype=np.float32)

    heapq.heappush(heap, (compute_distance(entrypoint, query), entrypoint))

    for i in range(max_steps):
        dist, v_id = heapq.heappop(heap)
        endpoint_ids = graph[v_id] # shape=(edges, d)
        visited_ids.add(v_id)

        # drop invalid or visited ids
        endpoint_ids = endpoint_ids[endpoint_ids != -1]
        endpoint_ids = np.array([id for id in endpoint_ids if id not in visited_ids])

        # compute distances
        endpoint_values = index.reconstruct_batch(endpoint_ids) # shape=(edges, d)
        distances = compute_distance(endpoint_values, query)  # shape=(edges,)

        for j in range(len(distances)):
            heapq.heappush(heap, (distances[j], endpoint_ids[j]))

        path_ids[i] = endpoint_ids[j]

    return path_ids, path_distances

queries = ds.get_queries() # shape=(batch, d)
max_steps = 100
path_ids = -1*np.ones((queries.shape[0], max_steps), dtype=np.int32)
for i, q in enumerate(queries):
    path_ids[i], path_distances = graph_search(graph, np.int32(23), q, max_steps)
    # plt.plot(path_distances)
# plt.show()

p_quantized = np.ones(index.ntotal)
path_ids = path_ids.flatten()
path_ids = path_ids[path_ids != -1]
ids, freqs = np.unique(path_ids, return_counts=True)
p_quantized[ids] = freqs
p = p_quantized/p_quantized.sum()
print('entropy:', -np.sum(p*np.log2(p)))
print('log2(#ids):', np.log2(index.ntotal))

entropy: 19.764543051025118
log2(#ids): 19.931568569324174


50k nodes, 16 outgoing edges per node

In [54]:
queries = ds.get_queries()
index.reconstruct_batch(graph[0]).shape

(16, 2)

In [47]:
graph[123]

array([36029, 49886, 34188, 28030, 45562, 15327,  8257, 21581,  8152,
       13419, 10381, 49588, 34097, 27447,  6180, 13718], dtype=int32)

edges of node #123 -- they are ordered arbitrarily

In [48]:
# number of -1s 
(graph == -1).sum() / graph.size

0.09029

9% invalid edges (represented as -1). 

In [49]:
queries = ds.get_queries()
print(queries.shape)
# reference results
Dref, Iref = index.search(queries, 5)

(4, 16)


In [50]:
Iref  # show ids of the 5 nearest neighbors of each query

array([[ 6708, 16841, 13264, 42596, 21075],
       [21695, 15810, 33083, 12244, 48081],
       [ 1126, 26665, 17933, 49309, 39121],
       [12169, 27365, 16199, 21937,  3238]])

# Order invariance

In [51]:
# sort all edges in graph (ignoring -1s)
for row in graph: 
    npos = (row >= 0).sum()
    row[:npos].sort()

In [52]:
def set_NSG_neighbors(nsg, neighbors): 
    graph = nsg.get_final_graph()
    assert neighbors.shape == (graph.N, graph.K)
    assert neighbors.dtype == np.int32
    faiss.memcpy(
        graph.data,
        faiss.swig_ptr(neighbors),
        neighbors.nbytes
    )

In [53]:
set_NSG_neighbors(index.nsg, graph)

In [54]:
queries = ds.get_queries()
Dnew, Inew = index.search(queries, 5)
Inew

array([[ 6708, 16841, 13264, 42596, 21075],
       [21695, 15810, 33083, 12244, 48081],
       [ 1126, 26665, 17933, 49309, 39121],
       [12169, 27365, 16199, 21937,  3238]])

In [55]:
np.all(Inew == Iref)

True

# Naive compression of the graph

In [56]:
# raw size in memory
graph.nbytes

3200000

In [57]:
# if we use only the necessary number of bits 
bits_per_vector = np.ceil(np.log2(index.ntotal + 1))
graph.size * bits_per_vector / 8

1600000.0

In [58]:
# if we don't store the -1s (but we need to store the nb valid entries per vector)
bits_per_vector = np.ceil(np.log2(index.ntotal))
(graph != -1).sum() * bits_per_vector / 8 + graph.shape[0]

1505536.0

In [60]:
# compress with generic compressor 
buf = io.BytesIO()
with bz2.open(buf, mode="wb") as f: 
    f.write(graph.tobytes())
len(buf.getbuffer())

1336870

# Do indices correlate with a set?

In [216]:
# ds = SyntheticDataset(128, 0, 1_000_000, 10_000)
ds = DatasetSIFT1M()
print(ds)

# create the C++ index: a NSG graph with 16 neighbors per node
index = faiss.index_factory(ds.d, "NSG64")
index.add(ds.get_database())
graph = get_NSG_neighbors(index.nsg)

dataset in dimension 128, with metric L2, size: Q 10000 B 1000000 T 100000


In [38]:
import faiss
import os
import numpy as np
from numpy.typing import NDArray
from faiss.contrib import datasets
from faiss.contrib.inspect_tools import get_NSG_neighbors
from joblib import Parallel, delayed

def average_gap_value(arr: NDArray[np.integer]) -> NDArray[np.integer]:
    '''Any negative value is considered invalid. `arr` should have shape (batch, sequence).'''

    has_at_least_2_valid_entries = (arr >= 0).sum(axis=1) >= 2
    arr = arr[has_at_least_2_valid_entries, :]
    gaps = np.diff(arr, axis=1)
    gap_sum = np.where(arr[:, :-1] >= 0, gaps, 0).sum(axis=1)
    valid_entries = (arr >= 0).sum(axis=1)
    return gap_sum/(valid_entries - 1)

def test_average_gap_value():
    arr = np.arange(100).reshape(10, 10)
    assert average_gap_value(arr).mean() == 1.0
    assert average_gap_value(23*arr).mean() == 23

    arr = np.arange(10).reshape(2, 5)
    arr[0, :2] = -1 # -1, -1, 2, 3, 4 -> 1, 1 -> 1
    arr[1, :3] = -1 # -1, -1, -1, 8, 9 -> 1, 1 -> 1
    assert average_gap_value(arr).mean() == 1.0

    # -1, -1, 6, 9, 12 -> 3, 3 -> 3
    # -1, -1, -1, 24, 27 -> 3 -> 3
    assert average_gap_value(3*arr).mean() == 3.0

def test_average_gap_value_uniform():
    np.random.seed(0)
    max_value = (1 << 32) - 1
    arr = np.sort(np.random.randint(0, max_value + 1, size=(5*1000)).reshape((5, 1000)), axis=1)

    avg_num_gaps = ((arr >= 0).sum(axis=1) - 1).mean()
    assert np.all(avg_num_gaps == 999)

    uniform_gap_value_from_bound = max_value/avg_num_gaps
    uniform_gap_value_from_max = arr.max()/avg_num_gaps 
    assert 0.98 < uniform_gap_value_from_bound/uniform_gap_value_from_max < 1.02

    avg_gap_value = average_gap_value(arr).mean()
    assert 0.98 < avg_gap_value/uniform_gap_value_from_max < 1.02

test_average_gap_value()
test_average_gap_value_uniform()

In [None]:
def compute_gap_statistics(ds, index_str='NSG64'):
    index = faiss.index_factory(ds.d, index_str)
    index.add(ds.get_database())

    graph = np.sort(get_NSG_neighbors(index.nsg), axis=1)
    avg_num_gaps = ((graph >= 0).sum(axis=1) - 1).mean()

    avg_gap_value = average_gap_value(graph).mean()
    uniform_gap_value = graph.max()/avg_num_gaps

    print('index_str:', index_str, 'Dataset:', ds.__class__.__name__, 'Average gap value of SIFT1M:', avg_gap_value, 'max/avg_of_gaps:', uniform_gap_value, 'gaps ratio': avg_gap_value/uniform_gap_value)

    return avg_gap_value, uniform_gap_value

datasets_space = [
    datasets.SyntheticDataset(128, 0, 1_000_000, 10_000),
    datasets.DatasetSIFT1M(),
    datasets.DatasetGIST1M()
]

index_str_space = [
    'NSG64', 'NSG32', 'NSG16', 'NSG8'
]

# results = Parallel(n_jobs=os.cpu_count())([
#     delayed(compute_gap_statistics)(ds, index_str)
#     for ds in datasets_space
#     for index_str in index_str_space
# ])

results = [
    compute_gap_statistics(ds, index_str)
    for ds in datasets_space
    for index_str in index_str_space
]