Skip to content

Commit

Permalink
Merge branch 'rakri/multifilter_with_query_planning' of https://github.com/microsoft/DiskANN into rakri/multifilter_with_query_planning
Browse files (browse the repository at this point in the history)
  • Loading branch information
ravishankar committed May 1, 2024
2 parents 8237560 + baaddce commit 6581827
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 40 deletions.
7 changes: 3 additions & 4 deletions include/index.h
Expand Up @@ -413,10 +413,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
tsl::robin_set<LabelT> _labels;
std::string _labels_file;
std::unordered_map<LabelT, uint32_t> _label_to_start_id;
// std::unordered_map<LabelT, roaring::Roaring> _labels_to_points;
std::vector<roaring::Roaring> _labels_to_points;
std::unordered_map<LabelT, roaring::Roaring> _labels_to_points_samples;
std::vector<std::unordered_map<LabelT, roaring::Roaring>> _clusters_to_labels_to_points;
std::vector<roaring::Roaring> _labels_to_points;
std::vector<roaring::Roaring> _labels_to_points_samples;
std::vector<std::vector<roaring::Roaring>> _clusters_to_labels_to_points;
std::unordered_map<uint32_t, uint32_t> _medoid_counts;
diskann::InMemClusterStore<T> *_ivf_clusters = nullptr;

Expand Down
2 changes: 1 addition & 1 deletion include/scratch.h
Expand Up @@ -97,10 +97,10 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>

// Empties _last_intersection and returns a reference to it.
// The bitmap is cleared by deleting the closed range [minimum(), maximum()],
// i.e. everything between its smallest and largest stored values inclusive.
// NOTE(review): on an already-empty bitmap CRoaring documents minimum() as
// UINT32_MAX and maximum() as 0, so the range is inverted; presumably the
// removal is then a no-op — confirm against the CRoaring version in use.
inline roaring::Roaring &get_valid_bitmap()
{
    const auto smallest = _last_intersection.minimum();
    const auto largest = _last_intersection.maximum();
    _last_intersection.removeRangeClosed(smallest, largest);
    return _last_intersection;
}


private:
uint32_t _L;
uint32_t _R;
Expand Down
85 changes: 50 additions & 35 deletions src/index.cpp
Expand Up @@ -25,7 +25,7 @@

#include "index.h"

#define MAX_POINTS_FOR_USING_BITSET 10000000
#define MAX_POINTS_FOR_USING_BITSET 40000000

namespace diskann
{
Expand Down Expand Up @@ -646,12 +646,14 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui
{
for (auto const &filter : _location_to_labels[*j])
{
_clusters_to_labels_to_points[i].reserve(_labels.size());
_clusters_to_labels_to_points[i][filter].add(*j);
}
}
}

std::srand(time(NULL));
_labels_to_points_samples.reserve(_labels.size());
for (auto const &label : _labels)
{
/* if (_labels_to_points[label].cardinality() > _bruteforce_threshold) */
Expand Down Expand Up @@ -913,9 +915,16 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::closest_cluster_filters(co
tmp &= _clusters_to_labels_to_points[cluster_id][filter_vec[k]];
}

roaring::Roaring::const_iterator x = tmp.begin();
x++;
for (roaring::Roaring::const_iterator i = tmp.begin(); i != tmp.end(); i++)
{
float distance = _data_store->get_distance(aligned_query, *i);
if (x != tmp.end())
{
_data_store->prefetch_vector(*x);
x++;
}
Neighbor nn = Neighbor(*i, distance);
best_L_nodes.insert(nn);
cmps++;
Expand Down Expand Up @@ -966,15 +975,17 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::brute_force_filters(const
#ifdef INSTRUMENT
auto s = std::chrono::high_resolution_clock::now();
#endif
// for (roaring::Roaring::const_iterator i = init_ids.begin(); i != init_ids.end(); i++)
for (uint32_t value : init_ids)
{
// float distance = _data_store->get_distance(aligned_query, *i);
float distance = _data_store->get_distance(aligned_query, value);
// if ((i+1) != init_ids.end())
// _data_store->prefetch_vector(*(i+1));
// Neighbor nn = Neighbor(*i, distance);
Neighbor nn = Neighbor(value, distance);
roaring::Roaring::const_iterator x = init_ids.begin();
x++;
for (roaring::Roaring::const_iterator i = init_ids.begin(); i != init_ids.end(); i++)
{
float distance = _data_store->get_distance(aligned_query, *i);
if (x != init_ids.end())
{
_data_store->prefetch_vector(*x);
x++;
}
Neighbor nn = Neighbor(*i, distance);
best_L_nodes.insert(nn);
cmps++;
}
Expand All @@ -994,7 +1005,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes();
best_L_nodes.reserve(Lsize);
tsl::robin_set<uint32_t> &inserted_into_pool_rs = scratch->inserted_into_pool_rs();
boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs();
roaring::Roaring &inserted_into_pool_bs = scratch->get_valid_bitmap();
std::vector<uint32_t> &id_scratch = scratch->id_scratch();
std::vector<float> &dist_scratch = scratch->dist_scratch();
assert(id_scratch.size() == 0);
Expand All @@ -1014,20 +1025,9 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
auto total_num_points = _max_points + _num_frozen_pts;
bool fast_iterate = total_num_points <= MAX_POINTS_FOR_USING_BITSET;

if (fast_iterate)
{
if (inserted_into_pool_bs.size() < total_num_points)
{
// hopefully using 2X will reduce the number of allocations.
auto resize_size =
2 * total_num_points > MAX_POINTS_FOR_USING_BITSET ? MAX_POINTS_FOR_USING_BITSET : 2 * total_num_points;
inserted_into_pool_bs.resize(resize_size);
}
}

// Lambda to determine if a node has been visited
auto is_not_visited = [this, fast_iterate, &inserted_into_pool_bs, &inserted_into_pool_rs](const uint32_t id) {
return fast_iterate ? inserted_into_pool_bs[id] == 0
return fast_iterate ? !inserted_into_pool_bs.contains(id)
: inserted_into_pool_rs.find(id) == inserted_into_pool_rs.end();
};

Expand Down Expand Up @@ -1064,7 +1064,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
{
if (fast_iterate)
{
inserted_into_pool_bs[id] = 1;
inserted_into_pool_bs.add(id);
}
else
{
Expand Down Expand Up @@ -1123,7 +1123,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(

if (fast_iterate)
{
inserted_into_pool_bs[id] = 1;
inserted_into_pool_bs.add(id);
}
else
{
Expand Down Expand Up @@ -1151,9 +1151,17 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
}
else
{
_locks[n].lock();
auto nbrs = _graph_store->get_neighbours(n);
_locks[n].unlock();
std::vector<location_t> nbrs;
if (search_invocation)
{
nbrs = _graph_store->get_neighbours(n);
}
else
{
_locks[n].lock();
nbrs = _graph_store->get_neighbours(n);
_locks[n].unlock();
}
for (auto id : nbrs)
{
assert(id < _max_points + _num_frozen_pts);
Expand All @@ -1163,7 +1171,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(

if (fast_iterate)
{
inserted_into_pool_bs[id] = 1;
inserted_into_pool_bs.add(id);
}
else
{
Expand Down Expand Up @@ -2048,13 +2056,12 @@ void Index<T, TagT, LabelT>::parse_label_file(const std::string &label_file, siz
line_cnt++;
}
_location_to_labels.resize(line_cnt, std::vector<LabelT>());
_location_to_labels_robin.resize(line_cnt, tsl::robin_set<LabelT>());
_location_to_labels_bitmap.resize(line_cnt, roaring::Roaring());
/* _location_to_labels_robin.resize(line_cnt, tsl::robin_set<LabelT>()); */
/* _location_to_labels_bitmap.resize(line_cnt, roaring::Roaring()); */

infile.clear();
infile.seekg(0, std::ios::beg);
line_cnt = 0;
_labels_to_points.resize(5000); // TODO TODO TODO
while (std::getline(infile, line))
{
std::istringstream iss(line);
Expand All @@ -2067,10 +2074,18 @@ void Index<T, TagT, LabelT>::parse_label_file(const std::string &label_file, siz
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = (LabelT)std::stoul(token);
lbls.push_back(token_as_num);
_location_to_labels_bitmap[line_cnt].add(token_as_num);
_location_to_labels_robin[line_cnt].insert(token_as_num);
_labels_to_points[token_as_num].add(line_cnt);
//_location_to_labels_bitmap[line_cnt].add(token_as_num);
//_location_to_labels_robin[line_cnt].insert(token_as_num);
_labels.insert(token_as_num);
try
{
_labels_to_points.at(token_as_num).add(line_cnt);
}
catch (const std::out_of_range &oor)
{
_labels_to_points.resize(token_as_num + 1);
_labels_to_points.at(token_as_num).add(line_cnt);
}
}

std::sort(lbls.begin(), lbls.end());
Expand Down

0 comments on commit 6581827

Please sign in to comment.