Skip to content

Commit 3671c61

Browse files
limqiyingfacebook-github-bot
authored andcommitted
more formatting (#4568)
Summary: Pull Request resolved: #4568 just formatting lints on Faiss Reviewed By: junjieqi Differential Revision: D81286229 fbshipit-source-id: 208efcf07b92a6dec9f78056a07f0c5bf7199aea
1 parent 1ed2611 commit 3671c61

File tree

9 files changed

+220
-131
lines changed

9 files changed

+220
-131
lines changed

benchs/bench_all_ivf/bench_all_ivf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def run_search(args, ds, index, res):
400400
# Driver function
401401
######################################################
402402

403+
403404
def main():
404405

405406
parser = argparse.ArgumentParser()

benchs/bench_all_ivf/cmp_with_scann.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def aa(*args, **kwargs):
5555
aa('--download', default=False, action="store_true")
5656
aa('--lib', default='faiss', help='library to use (faiss or scann)')
5757
aa('--thenscann', default=False, action="store_true")
58-
aa('--base_dir', default='/checkpoint/matthijs/faiss_improvements/cmp_ivf_scan_2')
58+
aa('--base_dir',
59+
default='/checkpoint/matthijs/faiss_improvements/cmp_ivf_scan_2')
5960

6061
group = parser.add_argument_group('searching')
6162
aa('--k', default=10, type=int, help='nb of nearest neighbors')
@@ -83,11 +84,11 @@ def aa(*args, **kwargs):
8384
print(ds)
8485
# store for SCANN
8586
os.system(f"rm -rf {cache_dir}; mkdir -p {cache_dir}")
86-
tosave = dict(
87-
xb = ds.get_database(),
88-
xq = ds.get_queries(),
89-
gt = ds.get_groundtruth()
90-
)
87+
tosave = {
88+
"xb": ds.get_database(),
89+
"xq": ds.get_queries(),
90+
"gt": ds.get_groundtruth(),
91+
}
9192
for name, v in tosave.items():
9293
fname = cache_dir + "/" + name + ".npy"
9394
print("save", fname)
@@ -137,13 +138,15 @@ def aa(*args, **kwargs):
137138
if os.path.exists(scann_dir + "/scann_config.pb"):
138139
searcher = scann_ops_pybind.load_searcher(scann_dir)
139140
else:
140-
searcher = scann_make_index(xb, name1_to_name2[distance_measure], scann_dir, 0)
141+
searcher = scann_make_index(
142+
xb, name1_to_name2[distance_measure], scann_dir, 0)
141143

142144
scann_dir = cache_dir + "/scann1.1.1_serialized_reorder"
143145
if os.path.exists(scann_dir + "/scann_config.pb"):
144146
searcher_reo = scann_ops_pybind.load_searcher(scann_dir)
145147
else:
146-
searcher_reo = scann_make_index(xb, name1_to_name2[distance_measure], scann_dir, 100)
148+
searcher_reo = scann_make_index(
149+
xb, name1_to_name2[distance_measure], scann_dir, 100)
147150

148151
scann_eval_search(
149152
searcher, searcher_reo,
@@ -180,8 +183,10 @@ def scann_make_index(xb, distance_measure, scann_dir, reorder_k):
180183
else:
181184
thr = 0
182185
k = 10
183-
sb = scann.scann_ops_pybind.builder(xb, k, distance_measure)
184-
sb = sb.tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
186+
sb = scann.scann_ops_pybind.builder(
187+
xb, k, distance_measure)
188+
sb = sb.tree(
189+
num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
185190
sb = sb.score_ah(2, anisotropic_quantization_threshold=thr)
186191

187192
if reorder_k > 0:
@@ -198,6 +203,7 @@ def scann_make_index(xb, distance_measure, scann_dir, reorder_k):
198203
searcher.serialize(scann_dir)
199204
return searcher
200205

206+
201207
def scann_eval_search(
202208
searcher, searcher_reo,
203209
xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
@@ -235,8 +241,6 @@ def scann_eval_search(
235241
eval_inters(header, I, gt, times)
236242

237243

238-
239-
240244
###############################################################
241245
# Faiss
242246
###############################################################
@@ -261,6 +265,7 @@ def faiss_make_index(xb, metric_type, fname):
261265

262266
return index
263267

268+
264269
def faiss_eval_search(
265270
index, xq, xb, nprobe_tab, pre_reorder_k_tab,
266271
k, gt, nrun, measure

benchs/bench_gpu_1bn.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -90,22 +90,38 @@ def usage():
9090

9191
while args:
9292
a = args.pop(0)
93-
if a == '-h': usage()
94-
elif a == '-ngpu': ngpu = int(args.pop(0))
95-
elif a == '-R': replicas = int(args.pop(0))
96-
elif a == '-noptables': use_precomputed_tables = False
97-
elif a == '-abs': add_batch_size = int(args.pop(0))
98-
elif a == '-qbs': query_batch_size = int(args.pop(0))
99-
elif a == '-nnn': nnn = int(args.pop(0))
100-
elif a == '-tempmem': tempmem = int(args.pop(0))
101-
elif a == '-nocache': use_cache = False
102-
elif a == '-knngraph': knngraph = True
103-
elif a == '-altadd': altadd = True
104-
elif a == '-float16': use_float16 = True
105-
elif a == '-nprobe': nprobes = [int(x) for x in args.pop(0).split(',')]
106-
elif a == '-max_add': max_add = int(args.pop(0))
107-
elif not dbname: dbname = a
108-
elif not index_key: index_key = a
93+
if a == '-h':
94+
usage()
95+
elif a == '-ngpu':
96+
ngpu = int(args.pop(0))
97+
elif a == '-R':
98+
replicas = int(args.pop(0))
99+
elif a == '-noptables':
100+
use_precomputed_tables = False
101+
elif a == '-abs':
102+
add_batch_size = int(args.pop(0))
103+
elif a == '-qbs':
104+
query_batch_size = int(args.pop(0))
105+
elif a == '-nnn':
106+
nnn = int(args.pop(0))
107+
elif a == '-tempmem':
108+
tempmem = int(args.pop(0))
109+
elif a == '-nocache':
110+
use_cache = False
111+
elif a == '-knngraph':
112+
knngraph = True
113+
elif a == '-altadd':
114+
altadd = True
115+
elif a == '-float16':
116+
use_float16 = True
117+
elif a == '-nprobe':
118+
nprobes = [int(x) for x in args.pop(0).split(',')]
119+
elif a == '-max_add':
120+
max_add = int(args.pop(0))
121+
elif not dbname:
122+
dbname = a
123+
elif not index_key:
124+
index_key = a
109125
else:
110126
print("argument %s unknown" % a, file=sys.stderr)
111127
sys.exit(1)
@@ -239,7 +255,6 @@ def eval_intersection_measure(gt_I, I):
239255
gt_I.shape if gt_I is not None else None))
240256

241257

242-
243258
#################################################################
244259
# Parse index_key and set cache files
245260
#
@@ -305,7 +320,7 @@ def eval_intersection_measure(gt_I, I):
305320

306321
gpu_resources = []
307322

308-
for i in range(ngpu):
323+
for _ in range(ngpu):
309324
res = faiss.StandardGpuResources()
310325
if tempmem >= 0:
311326
res.setTempMemory(tempmem)
@@ -608,7 +623,6 @@ def quantize(args):
608623
return None, indexall
609624

610625

611-
612626
def get_populated_index(preproc):
613627

614628
if not index_cachefile or not os.path.exists(index_cachefile):
@@ -644,7 +658,7 @@ def get_populated_index(preproc):
644658
index = gpu_index
645659

646660
else:
647-
del gpu_index # We override the GPU index
661+
del gpu_index # We override the GPU index
648662

649663
print("Copy CPU index to %d sharded GPU indexes" % replicas)
650664

@@ -719,8 +733,9 @@ def eval_dataset(index, preproc):
719733
print(" probe=%-3d: %.3f s" % (nprobe, t1 - t0), end=' ')
720734
gtc = gt_I[:, :1]
721735
nq = xq.shape[0]
722-
for rank in 1, 10, 100:
723-
if rank > nnn: continue
736+
for rank in (1, 10, 100):
737+
if rank > nnn:
738+
continue
724739
nok = (I[:, :rank] == gtc).sum()
725740
print("1-R@%d: %.4f" % (rank, nok / float(nq)), end=' ')
726741
print()

benchs/distributed_ondisk/search_server.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
import combined_index
1111
import argparse
1212

13+
from multiprocessing.pool import ThreadPool
14+
import faiss
15+
import numpy as np
1316

1417

1518
############################################################
1619
# Server implementation
1720
############################################################
1821

22+
1923
class MyServer(rpc.Server):
2024
""" Assign version that can be exposed via RPC """
2125
def __init__(self, s, index):
@@ -56,6 +60,7 @@ def aa(*args, **kwargs):
5660
args.port, report_to_file=when_ready,
5761
v6=not args.ipv4)
5862

63+
5964
if __name__ == '__main__':
6065
main()
6166

@@ -64,11 +69,6 @@ def aa(*args, **kwargs):
6469
# Client implementation
6570
############################################################
6671

67-
from multiprocessing.pool import ThreadPool
68-
import faiss
69-
import numpy as np
70-
71-
7272

7373
class ResultHeap:
7474
""" Combine query results from a sliced dataset (for k-nn search) """
@@ -111,7 +111,6 @@ def distribute_weights(weights, nbin):
111111
return bins, assign
112112

113113

114-
115114
class SplitPerListIndex:
116115
"""manages a local index, that does the coarse quantization and a set
117116
of sub_indexes. The sub_indexes search a subset of the inverted
@@ -177,15 +176,15 @@ def do_query(i):
177176
t0 = time.time()
178177
Di, Ii = sub_index.ivf_search_preassigned(
179178
xqo, list_nos_i, coarse_dis, k)
180-
#print(list_nos_i, Ii)
179+
# print(list_nos_i, Ii)
181180
if self.verbose:
182181
print('client %d: %.3f s' % (i, time.time() - t0))
183182
return Di, Ii
184183

185184
rh = ResultHeap(x.shape[0], k)
186185

187186
for Di, Ii in self.pool.imap(do_query, range(self.ni)):
188-
#print("ADD", Ii, rh.I)
187+
# print("ADD", Ii, rh.I)
189188
rh.add_batch_result(Di, Ii, 0)
190189
rh.finalize()
191190
return rh.D, rh.I

0 commit comments

Comments
 (0)