# Approximate k-NN search with locality-sensitive hashing

In [1]:
import numpy as np

In [2]:
# Seeded generator so every run of the notebook draws the same numbers.
rs = np.random.default_rng(seed=0)

In [3]:
m = 1000  # number of data points in the dataset
n = 16    # number of features in each data point

In [4]:
# Synthetic data: m points with n standard-normal features each, plus one
# query vector drawn from the same distribution.
X = rs.normal(loc=0.0, scale=1.0, size=(m, n))
q = rs.normal(loc=0.0, scale=1.0, size=n)

## Vanilla k-NN search

In [5]:
def knn_search(query, data, k=5):
    """Find the ``k`` nearest neighbors of ``query`` in ``data`` by brute force.

    Every data point is compared against the query, so the result is exact.

    Parameters
    ----------
    query : array_like, shape (n,)
        Query vector.
    data : array_like, shape (m, n)
        Dataset with one data point per row.
    k : int, optional
        Number of neighbors to return (default 5). Must satisfy
        ``0 <= k <= len(data)``.

    Returns
    -------
    neighbors : numpy.ndarray, shape (k, n)
        The ``k`` rows of ``data`` closest to ``query``, nearest first.
    dists : numpy.ndarray, shape (k,)
        Euclidean distances from ``query`` to each neighbor, ascending.

    Raises
    ------
    ValueError
        If ``k`` is negative or exceeds the number of data points.
    """
    data = np.asarray(data)
    query = np.asarray(query)
    if not 0 <= k <= len(data):
        # A plain ``assert`` would vanish under ``python -O``, and a negative k
        # would silently slice off the *farthest* points instead of failing.
        raise ValueError(f"k must be between 0 and {len(data)}, got {k}")
    dists = np.linalg.norm(data - query, axis=1)  # Euclidean distance to each row
    # A stable sort breaks distance ties by original row order, so results are
    # reproducible; only the first k (closest) indices are kept.
    inds_k = np.argsort(dists, kind="stable")[:k]
    # NOTE: optionally, if the dataset has a set of labels, the labels of these
    # neighbors can be used to assign a label to the query (i.e., classification).
    return data[inds_k], dists[inds_k]

In [6]:
neighbors, dists = knn_search(q, X)

# Show the query, then each neighbor with its rank and distance, separated by
# blank lines.
print("query =", q, end="\n\n")

for rank, (neighbor, dist) in enumerate(zip(neighbors, dists), start=1):
    print(f"top {rank}:")
    print("neighbor =", neighbor)
    print("dist =", dist, end="\n\n")

query = [ 0.74597801  0.10798961  0.96640122  0.15459683 -0.52637468  1.29406665
  0.15440615 -0.16837241 -1.25924732 -0.29044147 -1.6855752  -1.01036298
 -0.05422822 -0.90894185 -0.83150002 -1.18530497]

top 1:
neighbor = [ 0.79580696 -0.66135363  0.06473675  1.00379453 -0.4779662  -0.08503506
 -0.50306338  0.22048967 -0.2966549   0.16166078 -1.55502652 -0.52165985
 -0.95881421 -0.24767401 -1.38765319  0.44639838]
dist = 3.198502734555067

top 2:
neighbor = [ 0.17507704 -0.46295949  0.21952077 -0.25228969  0.60720812  1.05960923
 -0.65088179  0.6286331  -0.91197297  0.71001593 -0.8652757   0.34580728
 -0.72779903 -1.01855172  0.39106477 -0.23056385]
dist = 3.2410639159872017

top 3:
neighbor = [ 0.62507323  0.65184664  1.06381321 -1.53810049 -0.14029532  1.01364623
  0.18528485  0.0226172   0.11771915  0.08777791 -0.50787174 -1.73716936
 -1.02854233  0.16894379 -0.76239505 -2.26023546]
dist = 3.2674043491121405

top 4:
neighbor = [-0.2143315   1.03290702  0.11692843  0.60510131 -0.284

## Approximate k-NN search

In [7]:
def locality_sensitive_hash(data, hyperplanes):
    """Bucket data points by random-hyperplane locality-sensitive hashing.

    Each hyperplane passes through the origin and is given by its normal
    vector. A point gets bit ``i`` set when its dot product with normal ``i``
    is non-negative (points exactly on a hyperplane count as non-negative).
    The bits are read as an integer with hyperplane 0 as the least-significant
    bit, and points that share a code share a bucket.

    Parameters
    ----------
    data : array_like, shape (m, n) or (n,)
        Data points, one per row; a single 1-D point is treated as ``m = 1``.
    hyperplanes : array_like, shape (b, n) or (n,)
        Normal vectors of the ``b`` hyperplanes, one per row. ``b`` is the
        number of bits per code and is not limited to 63.

    Returns
    -------
    dict[int, list[int]]
        Maps each occurring hash code to the row indices of ``data`` in that
        bucket, in ascending order. Buckets appear in order of first occurrence.
    """
    data = np.atleast_2d(data)
    hyperplanes = np.atleast_2d(hyperplanes)
    bits = (data @ hyperplanes.T) >= 0  # (m, b) booleans, one bit per hyperplane
    # Pack each row's bits into bytes (bit i -> weight 2**i) and decode as a
    # Python int. Unlike a dot product with an int64 vector of powers of two,
    # this cannot overflow with 64+ hyperplanes, and the keys are plain ints.
    packed = np.packbits(bits, axis=1, bitorder="little")
    hash_table = {}
    for i, row in enumerate(packed):
        code = int.from_bytes(row.tobytes(), "little")
        hash_table.setdefault(code, []).append(i)
    return hash_table

In [8]:
# Draw 3 random hyperplanes (stored as their normal vectors), giving 3-bit
# codes and therefore at most 2**3 = 8 buckets.
hyperplanes = rs.normal(loc=0.0, scale=1.0, size=(3, X.shape[1]))
hash_table = locality_sensitive_hash(X, hyperplanes)

# Bucket sizes: each hash code followed by the number of points it holds.
for code, members in hash_table.items():
    print(code, len(members))

5 111
7 143
2 94
1 144
6 137
0 159
4 106
3 106


In [None]:
def approx_knn_search(query, data, k=5):
    hyperplanes = 