/
eval_annoy.py
executable file
·114 lines (92 loc) · 3.36 KB
/
eval_annoy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
# coding: utf-8
"""
File Name: eval_annoy.py
Author: Wan Ji
E-mail: wanji@live.com
Created on: Mon Aug 10 21:40:32 2015 CST
"""
# Shown as the argparse description in the --help output (see getargs).
DESCRIPTION = """
Evaluate the performance of Annoy.
"""
import os
import argparse
import logging
import time
import numpy as np
from annoy import AnnoyIndex
from hdidx.util import tic, toc
from eval_indexer import Dataset, compute_stats
def runcmd(cmd):
    """ Log a shell command, then execute it.

    :param cmd: command line handed verbatim to ``os.system``; it is
        interpreted by the system shell, so it must come from a trusted
        source.
    :returns: the value returned by ``os.system`` (0 on success). Callers
        that ignore the return value behave exactly as before.
    """
    # Lazy %-args: the message is only formatted when INFO is enabled.
    logging.info("%s", cmd)
    return os.system(cmd)
def getargs(argv=None):
    """ Parse program arguments.

    :param argv: optional list of argument strings. When None (the default),
        argparse reads ``sys.argv[1:]``, as the original no-argument call did.
    :returns: ``argparse.Namespace`` with ``dataset`` (str), ``exp_dir``
        (str), ``ntrees`` (list of int), ``topk`` (int) and ``log`` (str).
    """
    parser = argparse.ArgumentParser(
        description=DESCRIPTION,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('dataset', type=str,
                        help='path of the dataset')
    # main() joins every output path (index cache, report) onto exp_dir.
    # Without a value the run would only fail later with a TypeError from
    # os.path.join(None, ...), so reject the omission at parse time instead.
    parser.add_argument('--exp_dir', type=str, required=True,
                        help='directory for saving experimental results')
    # One evaluation pass is run per value given here.
    parser.add_argument("--ntrees", type=int, nargs='+', default=[16],
                        help="number of trees")
    parser.add_argument("--topk", type=int, default=100,
                        help="retrieval `topk` nearest neighbors")
    parser.add_argument("--log", type=str, default="INFO",
                        help="log level")
    return parser.parse_args(argv)
def _load_or_build_index(data, dim, ntrees, idxpath):
    """ Return an Annoy index over ``data.base``, cached at ``idxpath``.

    If ``idxpath`` exists, the saved index is loaded. Otherwise every base
    vector is added, the index is built with ``ntrees`` trees and saved to
    ``idxpath``.
    """
    # 'angular' is the metric Annoy used implicitly when none was given, so
    # passing it explicitly keeps the original results while avoiding the
    # warning newer Annoy releases emit for an omitted metric.
    # NOTE(review): SIFT groundtruth is usually Euclidean -- confirm whether
    # 'euclidean' is the intended metric for this benchmark.
    index = AnnoyIndex(dim, metric='angular')
    if os.path.exists(idxpath):
        logging.info("Loading indexes ...")
        index.load(idxpath)
        logging.info("\tDone!")
        return index

    logging.info("Adding items ...")
    # NOTE(review): ``nbae`` is the Dataset attribute read here; presumably
    # the number of base vectors -- confirm against eval_indexer.Dataset.
    for i in range(data.nbae):
        index.add_item(i, data.base[i])
        if i % 100000 == 0:
            logging.info("\t%d/%d", i, data.nbae)
    logging.info("\tDone!")
    logging.info("Building indexes ...")
    index.build(ntrees)
    logging.info("\tDone!")
    index.save(idxpath)
    return index


def _search(index, queries, nqry, topk):
    """ Query ``index`` for ``topk`` neighbours of each of ``queries[:nqry]``.

    :returns: ``(ids, elapsed)``. ``ids`` is an ``(nqry, topk)`` int array;
        slots the index could not fill hold -1. ``elapsed`` is the value
        returned by ``toc()``; the report treats it as seconds.
    """
    # get_nns_by_vector may return fewer than ``topk`` items (its search is
    # bounded by search_k). Pad with -1 rather than 0, because 0 is a valid
    # item id and would be counted as a real result.
    ids = np.full((nqry, topk), -1, dtype=int)
    logging.info("Searching ...")
    tic()
    for i in range(nqry):
        found = index.get_nns_by_vector(queries[i], topk)
        ids[i, :len(found)] = found
    elapsed = toc()
    logging.info("\tDone!")
    return ids, elapsed


def _write_report(report, ntrees, topk, r_at_k, ms_per_query):
    """ Append one timestamped result section for ``ntrees`` to ``report``. """
    with open(report, "a") as rptf:
        rptf.write("*" * 64 + "\n")
        rptf.write("* %s\n" % time.asctime())
        rptf.write("*" * 64 + "\n")
        rptf.write("=" * 64 + "\n")
        rptf.write("index_%s-ntrees_%s\n" % ("Annoy", ntrees))
        rptf.write("-" * 64 + "\n")
        rptf.write("recall@%-8d%.4f\n" % (topk, r_at_k))
        rptf.write("time cost (ms): %.3f\n" % ms_per_query)


def main(args):
    """ Main entry.

    For every tree count in ``args.ntrees``:
    - build an Annoy index, or load a cached one from ``args.exp_dir``;
    - search every query for ``args.topk`` neighbours;
    - append the recall value and the mean per-query search time in ms to
      ``<exp_dir>/report.txt``.
    """
    data = Dataset(args.dataset)
    dim = data.base.shape[1]
    # Index files and the report are written into exp_dir, so make sure it
    # exists before save()/open() are called on a fresh directory.
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)
    report = os.path.join(args.exp_dir, "report.txt")
    for ntrees in args.ntrees:
        # NOTE(review): the cache name is fixed to "sift_" regardless of
        # args.dataset, so evaluating another dataset in the same exp_dir
        # would silently reuse a stale index.
        idxpath = os.path.join(args.exp_dir,
                               'sift_annoy_ntrees%d.idx' % ntrees)
        index = _load_or_build_index(data, dim, ntrees, idxpath)
        ids, time_costs = _search(index, data.query, data.nqry, args.topk)
        r_at_k = compute_stats(data.groundtruth, ids, args.topk)[-1][-1]
        _write_report(report, ntrees, args.topk, r_at_k,
                      time_costs * 1000 / data.nqry)
if __name__ == '__main__':
    args = getargs()
    # Resolve the --log name (case-insensitive) to a logging level constant;
    # anything that does not map to an int attribute of `logging` is rejected.
    level_value = getattr(logging, args.log.upper(), None)
    if isinstance(level_value, int):
        logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s",
                            level=level_value)
        main(args)
    else:
        raise ValueError("Invalid log level: " + args.log)