In [6]:
import shutil
import urllib.request as request
from contextlib import closing

# first we download the Sift1M dataset
with closing(request.urlopen('ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz')) as r:
    with open('sift.tar.gz', 'wb') as f:
        shutil.copyfileobj(r, f)

In [7]:
import tarfile

# the download leaves us with a tar.gz file, we unzip it
tar = tarfile.open('sift.tar.gz', "r:gz")
tar.extractall()

In [8]:
import numpy as np

# now define a function to read the fvecs file format of Sift1M dataset
def read_fvecs(fp):
    a = np.fromfile(fp, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy().view('float32')

In [9]:
# data we will search through
wb = read_fvecs('./sift/sift_base.fvecs')  # 1M samples
# also get some query vectors to search with
xq = read_fvecs('./sift/sift_query.fvecs')
# take just one query (there are many in sift_learn.fvecs)
xq = xq[0].reshape(1, xq.shape[1])


In [10]:
xq.shape


(1, 128)

In [11]:
wb.shape

(1000000, 128)

In [37]:
q = wb[0].reshape(1, xq.shape[1])
q

array([[  0.,  16.,  35.,   5.,  32.,  31.,  14.,  10.,  11.,  78.,  55.,
         10.,  45.,  83.,  11.,   6.,  14.,  57., 102.,  75.,  20.,   8.,
          3.,   5.,  67.,  17.,  19.,  26.,   5.,   0.,   1.,  22.,  60.,
         26.,   7.,   1.,  18.,  22.,  84.,  53.,  85., 119., 119.,   4.,
         24.,  18.,   7.,   7.,   1.,  81., 106., 102.,  72.,  30.,   6.,
          0.,   9.,   1.,   9., 119.,  72.,   1.,   4.,  33., 119.,  29.,
          6.,   1.,   0.,   1.,  14.,  52., 119.,  30.,   3.,   0.,   0.,
         55.,  92., 111.,   2.,   5.,   4.,   9.,  22.,  89.,  96.,  14.,
          1.,   0.,   1.,  82.,  59.,  16.,  20.,   5.,  25.,  14.,  11.,
          4.,   0.,   0.,   1.,  26.,  47.,  23.,   4.,   0.,   0.,   4.,
         38.,  83.,  30.,  14.,   9.,   4.,   9.,  17.,  23.,  41.,   0.,
          0.,   2.,   8.,  19.,  25.,  23.,   1.]], dtype=float32)

In [38]:
import faiss

D = xq.shape[1]
m = 8
assert D % m == 0
nbits = 8  # number of bits per subquantizer, k* = 2**nbits
index = faiss.IndexPQ(D, m, nbits)

In [39]:
index.is_trained

False

In [40]:
index.train(wb) 

In [41]:
index.is_trained

True

In [42]:
k = 3
dist, I = index.search(q, k)
dist

array([[3.4028235e+38, 3.4028235e+38, 3.4028235e+38]], dtype=float32)

In [45]:
%%timeit
index.search(q, k)

49.7 µs ± 3.06 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [46]:
l2_index = faiss.IndexFlatL2(D)
l2_index.add(wb)

0


In [47]:
print(index.ntotal)

0


In [24]:
%%time
l2_dist, l2_I = l2_index.search(xq, k)

CPU times: user 81.4 ms, sys: 10 µs, total: 81.4 ms
Wall time: 80.2 ms


In [25]:
sum([1 for i in I[0] if i in l2_I])

0