Skip to content

Commit

Permalink
add --count and --batch args for data_export.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyinzuo committed May 17, 2024
1 parent 19d9f8c commit 3982de0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ann_benchmarks/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def load_all_results(dataset: Optional[str] = None,
Yields:
tuple: A tuple containing properties as a dictionary and an h5py file object.
"""
for root, _, files in os.walk(build_result_filepath(dataset, count)):
for root, _, files in os.walk(build_result_filepath(dataset, count, batch_mode=batch_mode)):
for filename in files:
if os.path.splitext(filename)[-1] != ".hdf5":
continue
Expand All @@ -110,4 +110,4 @@ def get_unique_algorithms() -> Set[str]:
for batch_mode in [False, True]:
for properties, _ in load_all_results(batch_mode=batch_mode):
algorithms.add(properties["algo"])
return algorithms
return algorithms
11 changes: 9 additions & 2 deletions data_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
parser = argparse.ArgumentParser()
parser.add_argument("--output", help="Path to the output file", required=True)
parser.add_argument("--recompute", action="store_true", help="Recompute metrics")
parser.add_argument(
"-k", "--count", default=10, type=int, help="The number of near neighbours to search for"
)
parser.add_argument("--batch", action="store_true", help="Batch mode")
args = parser.parse_args()

datasets = DATASETS.keys()
dfs = []
for dataset_name in datasets:
print("Looking at dataset", dataset_name)
if len(list(load_all_results(dataset_name))) > 0:
results = load_all_results(dataset_name)
if len(list(load_all_results(dataset_name,
count=args.count,
batch_mode=args.batch
))) > 0:
results = load_all_results(dataset_name, count=args.count, batch_mode=args.batch)
dataset, _ = get_dataset(dataset_name)
results = compute_metrics_all_runs(dataset, results, args.recompute)
for res in results:
Expand Down

0 comments on commit 3982de0

Please sign in to comment.