Skip to content

Commit

Permalink
Add original detect filter penalty back
Browse files Browse the repository at this point in the history
  • Loading branch information
Siddharth Gollapudi committed May 2, 2024
1 parent a093ef3 commit 9fa8c42
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
3 changes: 3 additions & 0 deletions include/index.h
Expand Up @@ -104,6 +104,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
DISKANN_DLLEXPORT size_t get_num_points();
DISKANN_DLLEXPORT size_t get_max_points();

// Original filter-penalty check for a single point.
// Returns incoming_labels.size() minus the number of labels in incoming_labels
// whose _labels_to_points entry contains point_id; 0 means every incoming
// label matched. search_invocation is not read by the current definition.
// NOTE(review): declared `inline` here, but the definition lives in
// src/index.cpp -- confirm every caller is in that translation unit (or is
// covered by an explicit instantiation there).
DISKANN_DLLEXPORT inline uint32_t detect_filter_penalty_orig(uint32_t point_id, bool search_invocation,
const std::vector<LabelT> &incoming_labels);

// Returns the size of the common_filters collection built by the definition in
// src/index.cpp.
// NOTE(review): presumably that is the number of labels shared between the
// point's labels and incoming_labels -- confirm against the full definition.
DISKANN_DLLEXPORT uint32_t detect_common_filters(uint32_t point_id, bool search_invocation,
const std::vector<LabelT> &incoming_labels);

Expand Down
54 changes: 51 additions & 3 deletions src/index.cpp
Expand Up @@ -836,6 +836,54 @@ uint32_t Index<T, TagT, LabelT>::detect_common_filters(uint32_t point_id, bool s
return common_filters.size();
}

// Original (membership-lookup based) filter penalty for a single point.
//
// For every label in incoming_labels, asks _labels_to_points[label] whether it
// contains point_id. The result is the number of incoming labels for which it
// does not, so 0 means every incoming label matched.
//
// search_invocation is not read; there is no separate build-time path here
// (handling universal labels across multiple filters is not implemented).
//
// The wall-clock time spent in this call is added, in seconds, to
// time_to_detect_penalty.
// NOTE(review): if searches run on several threads, confirm that
// time_to_detect_penalty is safe to update concurrently.
template <typename T, typename TagT, typename LabelT>
inline uint32_t Index<T, TagT, LabelT>::detect_filter_penalty_orig(uint32_t point_id, bool search_invocation,
                                                                   const std::vector<LabelT> &incoming_labels)
{
    const auto started_at = std::chrono::high_resolution_clock::now();

    // Count the incoming labels that this point carries.
    uint32_t matched = 0;
    for (const auto &label : incoming_labels)
    {
        if (!_labels_to_points[label].contains(point_id))
        {
            continue;
        }
        ++matched;
    }

    // Accumulate elapsed time (duration<double> counts seconds).
    const std::chrono::duration<double> elapsed = std::chrono::high_resolution_clock::now() - started_at;
    time_to_detect_penalty += elapsed.count();

    // Penalty = number of incoming labels the point is missing.
    return static_cast<uint32_t>(incoming_labels.size() - matched);
}

// Find common filter between a node's labels and a given set of labels, while
// taking into account universal label
// TODO: modify for handling universal label
Expand Down Expand Up @@ -2076,8 +2124,8 @@ void Index<T, TagT, LabelT>::parse_label_file_bloom(const std::string &label_fil
std::string low = line.substr(64, 64);
assert((high.size() + low.size()) == 128);

uint128_t high_64 = (uint128_t)std::stoull(high,nullptr,2);
uint128_t low_64 = (uint128_t)std::stoull(low,nullptr,2);
uint128_t high_64 = (uint128_t)std::stoull(high, nullptr, 2);
uint128_t low_64 = (uint128_t)std::stoull(low, nullptr, 2);
_location_to_labels_bloom[line_cnt] = (high_64 << 64) | (low_64 << 0);

line_cnt++;
Expand Down Expand Up @@ -2573,7 +2621,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
for (size_t i = 0; i < best_L_nodes.size(); ++i)
{
if (best_L_nodes[i].id >= _max_points ||
(detect_filter_penalty(best_L_nodes[i].id, true, curr_query_bloom_filter) != 0))
(detect_filter_penalty_orig(best_L_nodes[i].id, true, filter_vec) != 0))
continue;

indices[pos] = (IdType)best_L_nodes[i].id;
Expand Down

0 comments on commit 9fa8c42

Please sign in to comment.