-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
demo_auto_tune.py
executable file
·170 lines (119 loc) · 4.24 KB
/
demo_auto_tune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#!/usr/bin/env python2
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function
import os
import time
import numpy as np
# matplotlib is optional: without it the script still runs the
# auto-tuning and simply skips producing the plot.
graphical_output = False
try:
    import matplotlib
    # Agg renders to files, so no display is required.
    matplotlib.use('Agg')
    from matplotlib import pyplot
except ImportError:
    pass
else:
    graphical_output = True
import faiss
#################################################################
# Small I/O functions
#################################################################
def ivecs_read(fname):
    """Read an .ivecs file into a 2-D int32 array.

    Each vector is stored as an int32 dimension ``d`` followed by ``d``
    int32 components, so the file is a sequence of records of ``d + 1``
    int32 values that all repeat the same ``d``.

    Parameters
    ----------
    fname : str
        Path of the file to read.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, d)`` and dtype int32, one row per vector. It is a
        contiguous copy, so it does not keep the raw buffer (including the
        per-record dimension headers) alive.

    Raises
    ------
    ValueError
        If the file is empty, declares a non-positive dimension, has a size
        that is not a multiple of ``d + 1`` int32 values, or contains a record
        whose dimension header differs from the first one.
    """
    a = np.fromfile(fname, dtype="int32")
    if a.size == 0:
        raise ValueError("%s: empty vector file" % fname)
    d = int(a[0])
    if d <= 0:
        raise ValueError("%s: invalid vector dimension %d" % (fname, d))
    if a.size % (d + 1) != 0:
        raise ValueError(
            "%s: size %d is not a multiple of the record size %d"
            % (fname, a.size, d + 1))
    records = a.reshape(-1, d + 1)
    # reshape only checks the total size; a corrupt file (or one that is not
    # an .ivecs/.fvecs file at all) can still divide evenly, so check that
    # every record declares the same dimension.
    if not (records[:, 0] == d).all():
        raise ValueError("%s: inconsistent vector dimensions" % fname)
    # Drop the header column; copy() makes the result contiguous.
    return records[:, 1:].copy()
def fvecs_read(fname):
    """Read an .fvecs file into a 2-D float32 array.

    .fvecs uses the same record layout as .ivecs (an int32 dimension
    followed by the components), but the components are float32. The file is
    therefore parsed with ``ivecs_read``, and the component bits are then
    reinterpreted as float32.
    """
    int_rows = ivecs_read(fname)
    # Same itemsize (4 bytes), so this is a reinterpreting view, not a copy.
    return int_rows.view('float32')
def plot_OperatingPoints(ops, nq, **kwargs):
    """Draw the optimal points of ``ops`` as a staircase curve.

    For each point in ``ops.optimal_pts``, the x coordinate is ``perf``.
    The y coordinate is ``t / nq * 1000``, i.e. per-query time in
    milliseconds, assuming ``t`` is the total time in seconds for ``nq``
    queries. Consecutive points are joined by a vertical segment followed by
    a horizontal one. Extra keyword arguments are forwarded to
    ``pyplot.plot``.
    """
    pts = ops.optimal_pts
    count = pts.size()
    xs = []
    ys = []
    # 2*count - 1 vertices: each point, plus one corner between neighbours.
    for step in range(2 * count - 1):
        xs.append(pts.at(step // 2).perf)
        ys.append(pts.at((step + 1) // 2).t / nq * 1000)
    pyplot.plot(xs, ys, **kwargs)
#################################################################
# prepare common data for all indexes
#################################################################
# Reference time for the "[... s]" progress messages printed below.
t0 = time.time()

# Training, database and query vectors (paths relative to the working dir).
print("load data")
xt, xb, xq = map(fvecs_read, ("sift1M/sift_learn.fvecs",
                              "sift1M/sift_base.fvecs",
                              "sift1M/sift_query.fvecs"))
d = xt.shape[1]  # vector dimensionality

# Ground-truth neighbour ids, converted to int64 for set_groundtruth below.
print("load GT")
gt = ivecs_read("sift1M/sift_groundtruth.ivecs").astype('int64')
k = gt.shape[1]  # ground-truth neighbours stored per query

print("prepare criterion")
# criterion = 1-recall at 1
crit = faiss.OneRecallAtRCriterion(len(xq), 1)
crit.set_groundtruth(None, gt)
# presumably the number of neighbours requested per query when the
# criterion evaluates a search -- confirm against the faiss docs
crit.nnn = k
# Candidate index_factory strings, grouped by memory budget.

# No limit on memory usage.
unlimited_mem_keys = [
    "IMI2x10,Flat",
    "IMI2x11,Flat",
    "IVF4096,Flat",
    "IVF16384,Flat",
    "PCA64,IMI2x10,Flat",
]

# At most 16 bytes per vector.
keys_mem_16 = [
    'IMI2x10,PQ16',
    'IVF4096,PQ16',
    'IMI2x10,PQ8+8',
    'OPQ16_64,IMI2x10,PQ16',
]

# At most 32 bytes per vector.
keys_mem_32 = [
    'IMI2x10,PQ32',
    'IVF4096,PQ32',
    'IVF16384,PQ32',
    'IMI2x10,PQ16+16',
    'OPQ32,IVF4096,PQ32',
    'IVF4096,PQ16+16',
    'OPQ16,IMI2x10,PQ16+16',
]

# Index types that can run on the GPU.
keys_gpu = [
    "PCA64,IVF4096,Flat",
    "PCA64,Flat",
    "Flat",
    "IVF4096,Flat",
    "IVF16384,Flat",
    "IVF4096,PQ32",
]

# The group explored by this run (same list object, not a copy).
keys_to_test = unlimited_mem_keys
# Set to True to move each index to the GPU before exploring it.
use_gpu = False
if use_gpu:
    # A faiss build without GPU support has no StandardGpuResources
    # attribute. Accessing it directly would raise a bare AttributeError
    # instead of the explanatory message below, and an `assert` would be
    # skipped entirely under `python -O`, so check explicitly.
    if not hasattr(faiss, "StandardGpuResources"):
        raise RuntimeError(
            "FAISS was not compiled with GPU support, or loading "
            "_swigfaiss_gpu.so failed")
    res = faiss.StandardGpuResources()
    dev_no = 0  # id of the GPU the indexes are transferred to
# (index_key, operating points) for every key explored so far, in order.
op_per_key = []
# Optimal operating points seen so far, across all index types.
op = faiss.OperatingPoints()

for index_key in keys_to_test:
    print("============ key", index_key)

    # Build the index described by the key; it is trained and filled below.
    index = faiss.index_factory(d, index_key)
    if not use_gpu:
        params = faiss.ParameterSpace()
    else:
        # transfer to GPU (may be partial)
        index = faiss.index_cpu_to_gpu(res, dev_no, index)
        params = faiss.GpuParameterSpace()
    params.initialize(index)

    print("[%.3f s] train & add" % (time.time() - t0))
    index.train(xt)
    index.add(xb)

    print("[%.3f s] explore op points" % (time.time() - t0))
    # Operating points of this index, scored by `crit` on the queries xq.
    opi = params.explore(index, xq, crit)

    print("[%.3f s] result operating points:" % (time.time() - t0))
    opi.display()

    # Merge into the running best set; "index_key " is the label prefix.
    op.merge_with(opi, index_key + " ")
    op_per_key.append((index_key, opi))
if graphical_output:
    # Plot every key's optimal operating points on one figure, saved to
    # tmp/demo_auto_tune.png.
    fig = pyplot.figure(figsize=(12, 9))
    pyplot.xlabel("1-recall at 1")
    pyplot.ylabel("search time (ms/query, %d threads)" % faiss.omp_get_max_threads())
    pyplot.gca().set_yscale('log')
    pyplot.grid()
    for key, key_ops in op_per_key:
        plot_OperatingPoints(key_ops, crit.nq, label=key, marker='o')
    # To also overlay the combined optimum over all keys, uncomment:
    # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r')
    pyplot.legend(loc=2)
    # savefig does not create missing directories. Without this, a missing
    # tmp/ would crash the script after the whole exploration has run.
    # (os.makedirs(..., exist_ok=True) is avoided to stay Python 2 compatible.)
    if not os.path.isdir('tmp'):
        os.makedirs('tmp')
    fig.savefig('tmp/demo_auto_tune.png')
    pyplot.close(fig)  # drop pyplot's reference to the finished figure

print("[%.3f s] final result:" % (time.time() - t0))
op.display()