Skip to content

Commit

Permalink
additional stats retrieval from python interface
Browse files Browse the repository at this point in the history
  • Loading branch information
harsha-simhadri committed Oct 7, 2021
1 parent ab50064 commit 4c7e960
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions python/src/diskann_bindings.cpp
Expand Up @@ -51,27 +51,24 @@ struct DiskANNIndex {
<< " nodes based on BFS." << std::endl;
}


void cache_sample_paths(size_t num_nodes_to_cache,
const std::string &warmup_query_file,
uint32_t num_threads) {

if (!file_exists(warmup_query_file)) {
std::cout << "No warm up query file exists." << std::endl;
return;
}

std::vector<uint32_t> node_list;
pq_flash_index->generate_cache_list_from_sample_queries(
warmup_query_file, 15, 4, num_nodes_to_cache, num_threads,
node_list);
warmup_query_file, 15, 4, num_nodes_to_cache, num_threads, node_list);
pq_flash_index->load_cache_list(node_list);
std::cout << "loaded index, cached " << node_list.size()
<< " nodes based on sample search paths." << std::endl;
}

int load_index(const std::string &index_path_prefix, const int num_threads,
const size_t num_nodes_to_cache, int cache_mechanism) {
const size_t num_nodes_to_cache, int cache_mechanism) {
const std::string pq_path = index_path_prefix + std::string("_pq");
const std::string index_path =
index_path_prefix + std::string("_disk.index");
Expand All @@ -84,7 +81,8 @@ struct DiskANNIndex {
if (cache_mechanism == 0) {
// Nothing to do
} else if (cache_mechanism == 1) {
std::string sample_file = index_path_prefix + std::string("_sample_data.bin");
std::string sample_file =
index_path_prefix + std::string("_sample_data.bin");
cache_sample_paths(num_nodes_to_cache, sample_file, num_threads);
} else if (cache_mechanism == 2) {
cache_bfs_levels(num_nodes_to_cache);
Expand Down Expand Up @@ -161,24 +159,39 @@ struct DiskANNIndex {
py::array_t<unsigned> ids({num_queries, knn});
py::array_t<float> dists({num_queries, knn});

std::vector<_u64> u64_ids(knn * num_queries);
std::vector<_u64> u64_ids(knn * num_queries);
diskann::QueryStats *stats = new diskann::QueryStats[num_queries];

#pragma omp parallel for schedule(dynamic, 1)
for (_u64 i = 0; i < num_queries; i++) {
pq_flash_index->cached_beam_search(queries.data(i), knn, l_search,
u64_ids.data() + i * knn,
dists.mutable_data(i), beam_width);
pq_flash_index->cached_beam_search(
queries.data(i), knn, l_search, u64_ids.data() + i * knn,
dists.mutable_data(i), beam_width, stats + i);
}

auto r = ids.mutable_unchecked();
for (_u64 i = 0; i < num_queries; ++i)
for (_u64 j = 0; j < knn; ++j)
r(i, j) = (unsigned) u64_ids[i * knn + j];

return std::make_pair(ids, dists);
std::unordered_map<std::string, double> collective_stats;
collective_stats["mean_latency"] = diskann::get_mean_stats(
stats, num_queries,
[](const diskann::QueryStats &stats) { return stats.total_us; });
collective_stats["latency_999"] = diskann::get_percentile_stats(
stats, num_queries, 0.999,
[](const diskann::QueryStats &stats) { return stats.total_us; });
collective_stats["mean_ssd_ios"] = diskann::get_mean_stats(
stats, num_queries,
[](const diskann::QueryStats &stats) { return stats.n_ios; });
collective_stats["mean_dist_comps"] = diskann::get_mean_stats(
stats, num_queries,
[](const diskann::QueryStats &stats) { return stats.n_cmps; });
delete[] stats;
return std::make_pair(std::make_pair(ids, dists), collective_stats);
}

auto batch_range_search_numpy_input(
auto batch_range_search_numpy_input(
py::array_t<T, py::array::c_style | py::array::forcecast> &queries,
const _u64 dim, const _u64 num_queries, const double range,
const _u64 min_list_size, const _u64 max_list_size, const _u64 beam_width,
Expand All @@ -191,11 +204,13 @@ auto batch_range_search_numpy_input(
auto offsets_mutable = offsets.mutable_unchecked();
offsets_mutable(0) = 0;

diskann::QueryStats *stats = new diskann::QueryStats[num_queries];

#pragma omp parallel for schedule(dynamic, 1)
for (_u64 i = 0; i < num_queries; i++) {
_u32 res_count = pq_flash_index->range_search(
queries.data(i), range, min_list_size, max_list_size, u64_ids[i],
dists[i], beam_width);
dists[i], beam_width, stats + i);
offsets_mutable(i + 1) = res_count;
}

Expand All @@ -217,7 +232,23 @@ auto batch_range_search_numpy_input(
}
offsets_mutable(i + 1) = offsets_mutable(i) + offsets_mutable(i + 1);
}
return std::make_pair(offsets, std::make_pair(ids, res_dists));

std::unordered_map<std::string, double> collective_stats;
collective_stats["mean_latency"] = diskann::get_mean_stats(
stats, num_queries,
[](const diskann::QueryStats &stats) { return stats.total_us; });
collective_stats["latency_999"] = diskann::get_percentile_stats(
stats, num_queries, 0.999,
[](const diskann::QueryStats &stats) { return stats.total_us; });
collective_stats["mean_ssd_ios"] = diskann::get_mean_stats(
stats, num_queries,
[](const diskann::QueryStats &stats) { return stats.n_ios; });
collective_stats["mean_dist_comps"] = diskann::get_mean_stats(
stats, num_queries,
[](const diskann::QueryStats &stats) { return stats.n_cmps; });
delete[] stats;
return std::make_pair(std::make_pair(offsets, std::make_pair(ids, res_dists)),
collective_stats);
}
};

Expand Down

0 comments on commit 4c7e960

Please sign in to comment.