In [1]:
import numpy as np

d = 64                           # dimension
nb = 100000                      # database size
nq = 10000                       # nb of queries
np.random.seed(1234)             # make reproducible
xb = np.random.random((nb, d)).astype('float32')
xb[:, 0] += np.arange(nb) / 1000.
xq = np.random.random((nq, d)).astype('float32')
xq[:, 0] += np.arange(nq) / 1000.

import faiss                   # make faiss available
index = faiss.IndexFlatL2(d)   # build the index
print(index.is_trained)
index.add(xb)                  # add vectors to the index
print(index.ntotal)

k = 4                          # we want to see 4 nearest neighbors
D, I = index.search(xb[:5], k) # sanity check
print(I)
print(D)
D, I = index.search(xq, k)     # actual search
print(I[:5])                   # neighbors of the 5 first queries
print(I[-5:])                  # neighbors of the 5 last queries


True
100000
[[  0 393 363  78]
 [  1 555 277 364]
 [  2 304 101  13]
 [  3 173  18 182]
 [  4 288 370 531]]
[[0.        7.1751738 7.20763   7.2511625]
 [0.        6.323565  6.684581  6.799946 ]
 [0.        5.7964087 6.391736  7.2815123]
 [0.        7.2779055 7.527987  7.6628466]
 [0.        6.7638035 7.2951202 7.3688145]]
[[ 381  207  210  477]
 [ 526  911  142   72]
 [ 838  527 1290  425]
 [ 196  184  164  359]
 [ 526  377  120  425]]
[[ 9900 10500  9309  9831]
 [11055 10895 10812 11321]
 [11353 11103 10164  9787]
 [10571 10664 10632  9638]
 [ 9628  9554 10036  9582]]


In [3]:
xl = np.random.randint(0, 10, size=(nb,)).astype('int64')

In [5]:
print(xb.shape, xl.shape)

(100000, 64) (100000,)


In [6]:
import tensorflow as tf
train_ds = tf.data.Dataset.from_tensor_slices(
            (xb, xl)).shuffle(10000).batch(5)


In [None]:
next(iter(train_ds))

In [15]:
def tf_search(x):
    def search(x):
        return index.search(x, 3)
    return tf.numpy_function(search, [x],
        (tf.float32, tf.int64))

In [20]:
ds = train_ds.map(tf_search)

In [21]:
next(iter(ds))

(<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[0.       , 6.2177753, 6.486453 ],
        [0.       , 6.0933847, 6.17522  ],
        [0.       , 5.8112097, 5.8512087],
        [0.       , 5.6316347, 6.290154 ],
        [0.       , 6.317706 , 6.3523116]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=int64, numpy=
 array([[9016, 7887, 8985],
        [2665, 3207, 2300],
        [8466, 8582, 8155],
        [2301, 1475, 1573],
        [6996, 7807, 7870]])>)