Skip to content

Commit

Permalink
feat: support group_size for search_group_by(#33544) (#33720)
Browse files Browse the repository at this point in the history
related: #33544

mainly changes in three aspects:

1. enable setting group_size for group by function
2. separate normal reduce and group by reduce
3. eleminate uncessary padding in search result for reducing

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
  • Loading branch information
MrPresent-Han and MrPresent-Han committed Jul 12, 2024
1 parent 5bb0d21 commit f00c529
Show file tree
Hide file tree
Showing 28 changed files with 708 additions and 385 deletions.
5 changes: 3 additions & 2 deletions internal/core/src/common/QueryInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
namespace milvus {

struct SearchInfo {
int64_t topk_;
int64_t round_decimal_;
int64_t topk_{0};
int64_t group_size_{1};
int64_t round_decimal_{0};
FieldId field_id_;
MetricType metric_type_;
knowhere::Json search_params_;
Expand Down
3 changes: 2 additions & 1 deletion internal/core/src/common/QueryResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ struct SearchResult {
std::vector<float> distances_;
std::vector<int64_t> seg_offsets_;
std::optional<std::vector<GroupByValueType>> group_by_values_;
std::optional<int64_t> group_size_;

// first fill data during fillPrimaryKey, and then update data after reducing search results
std::vector<PkType> primary_keys_;
Expand All @@ -209,7 +210,7 @@ struct SearchResult {
std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data_;

// used for reduce, filter invalid pk, get real topks count
std::vector<size_t> topk_per_nq_prefix_sum_;
std::vector<size_t> topk_per_nq_prefix_sum_{};

//Vector iterators, used for group by
std::optional<std::vector<std::shared_ptr<VectorIterator>>>
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ set(MILVUS_QUERY_SRCS
SearchOnIndex.cpp
SearchBruteForce.cpp
SubSearchResult.cpp
GroupByOperator.cpp
groupby/SearchGroupByOperator.cpp
PlanProto.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
Expand Down
6 changes: 6 additions & 0 deletions internal/core/src/query/Plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ CreateSearchPlanByExpr(const Schema& schema,
return ProtoParser(schema).CreatePlan(plan_node);
}

std::unique_ptr<Plan>
CreateSearchPlanFromPlanNode(const Schema& schema,
const proto::plan::PlanNode& plan_node) {
return ProtoParser(schema).CreatePlan(plan_node);
}

std::unique_ptr<RetrievePlan>
CreateRetrievePlanByExpr(const Schema& schema,
const void* serialized_expr_plan,
Expand Down
4 changes: 4 additions & 0 deletions internal/core/src/query/Plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ CreateSearchPlanByExpr(const Schema& schema,
const void* serialized_expr_plan,
const int64_t size);

std::unique_ptr<Plan>
CreateSearchPlanFromPlanNode(const Schema& schema,
const proto::plan::PlanNode& plan_node);

std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan,
const uint8_t* blob,
Expand Down
4 changes: 4 additions & 0 deletions internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,11 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
if (query_info_proto.group_by_field_id() > 0) {
auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
search_info.group_by_field_id_ = group_by_field_id;
search_info.group_size_ = query_info_proto.group_size() > 0
? query_info_proto.group_size()
: 1;
}

auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
if (anns_proto.vector_type() ==
milvus::proto::plan::VectorType::BinaryVector) {
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/SearchOnIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License

#include "SearchOnIndex.h"
#include "query/GroupByOperator.h"
#include "query/groupby/SearchGroupByOperator.h"

namespace milvus::query {
void
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/SearchOnSealed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "query/SearchBruteForce.h"
#include "query/SearchOnSealed.h"
#include "query/helper.h"
#include "query/GroupByOperator.h"
#include "query/groupby/SearchGroupByOperator.h"

namespace milvus::query {

Expand Down
7 changes: 7 additions & 0 deletions internal/core/src/query/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,11 @@ out_of_range(int64_t t) {
return gt_ub<T>(t) || lt_lb<T>(t);
}

inline bool
dis_closer(float dis1, float dis2, const MetricType& metric_type) {
if (PositivelyRelated(metric_type))
return dis1 > dis2;
return dis1 < dis2;
}

} // namespace milvus::query
Original file line number Diff line number Diff line change
Expand Up @@ -13,94 +13,112 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "GroupByOperator.h"
#include "SearchGroupByOperator.h"
#include "common/Consts.h"
#include "segcore/SegmentSealedImpl.h"
#include "Utils.h"
#include "query/Utils.h"

namespace milvus {
namespace query {

void
GroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
const SearchInfo& search_info,
std::vector<GroupByValueType>& group_by_values,
const segcore::SegmentInternalInterface& segment,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances) {
SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
const SearchInfo& search_info,
std::vector<GroupByValueType>& group_by_values,
const segcore::SegmentInternalInterface& segment,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances,
std::vector<size_t>& topk_per_nq_prefix_sum) {
//1. get search meta
FieldId group_by_field_id = search_info.group_by_field_id_.value();
auto data_type = segment.GetFieldDataType(group_by_field_id);

int max_total_size =
search_info.topk_ * search_info.group_size_ * iterators.size();
seg_offsets.reserve(max_total_size);
distances.reserve(max_total_size);
group_by_values.reserve(max_total_size);
topk_per_nq_prefix_sum.reserve(iterators.size() + 1);
switch (data_type) {
case DataType::INT8: {
auto dataGetter = GetDataGetter<int8_t>(segment, group_by_field_id);
GroupIteratorsByType<int8_t>(iterators,
search_info.topk_,
search_info.group_size_,
*dataGetter,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
search_info.metric_type_,
topk_per_nq_prefix_sum);
break;
}
case DataType::INT16: {
auto dataGetter =
GetDataGetter<int16_t>(segment, group_by_field_id);
GroupIteratorsByType<int16_t>(iterators,
search_info.topk_,
search_info.group_size_,
*dataGetter,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
search_info.metric_type_,
topk_per_nq_prefix_sum);
break;
}
case DataType::INT32: {
auto dataGetter =
GetDataGetter<int32_t>(segment, group_by_field_id);
GroupIteratorsByType<int32_t>(iterators,
search_info.topk_,
search_info.group_size_,
*dataGetter,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
search_info.metric_type_,
topk_per_nq_prefix_sum);
break;
}
case DataType::INT64: {
auto dataGetter =
GetDataGetter<int64_t>(segment, group_by_field_id);
GroupIteratorsByType<int64_t>(iterators,
search_info.topk_,
search_info.group_size_,
*dataGetter,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
search_info.metric_type_,
topk_per_nq_prefix_sum);
break;
}
case DataType::BOOL: {
auto dataGetter = GetDataGetter<bool>(segment, group_by_field_id);
GroupIteratorsByType<bool>(iterators,
search_info.topk_,
search_info.group_size_,
*dataGetter,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
search_info.metric_type_,
topk_per_nq_prefix_sum);
break;
}
case DataType::VARCHAR: {
auto dataGetter =
GetDataGetter<std::string>(segment, group_by_field_id);
GroupIteratorsByType<std::string>(iterators,
search_info.topk_,
search_info.group_size_,
*dataGetter,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
search_info.metric_type_,
topk_per_nq_prefix_sum);
break;
}
default: {
Expand All @@ -117,43 +135,45 @@ void
GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK,
int64_t group_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type) {
const knowhere::MetricType& metrics_type,
std::vector<size_t>& topk_per_nq_prefix_sum) {
topk_per_nq_prefix_sum.push_back(0);
for (auto& iterator : iterators) {
GroupIteratorResult<T>(iterator,
topK,
group_size,
data_getter,
group_by_values,
seg_offsets,
distances,
metrics_type);
topk_per_nq_prefix_sum.push_back(seg_offsets.size());
}
}

template <typename T>
void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK,
int64_t group_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type) {
//1.
std::unordered_map<T, std::pair<int64_t, float>> groupMap;
GroupByMap<T> groupMap(topK, group_size);

//2. do iteration until fill the whole map or run out of all data
//note it may enumerate all data inside a segment and can block following
//query and search possibly
auto dis_closer = [&](float l, float r) {
if (PositivelyRelated(metrics_type))
return l > r;
return l < r;
};
while (iterator->HasNext() && groupMap.size() < topK) {
std::vector<std::tuple<int64_t, float, T>> res;
while (iterator->HasNext() && !groupMap.IsGroupResEnough()) {
auto offset_dis_pair = iterator->Next();
AssertInfo(
offset_dis_pair.has_value(),
Expand All @@ -162,38 +182,22 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
auto offset = offset_dis_pair.value().first;
auto dis = offset_dis_pair.value().second;
T row_data = data_getter.Get(offset);
auto it = groupMap.find(row_data);
if (it == groupMap.end()) {
groupMap.emplace(row_data, std::make_pair(offset, dis));
} else if (dis_closer(dis, it->second.second)) {
it->second = {offset, dis};
if (groupMap.Push(row_data)) {
res.emplace_back(offset, dis, row_data);
}
}

//3. sorted based on distances and metrics
std::vector<std::pair<T, std::pair<int64_t, float>>> sortedGroupVals(
groupMap.begin(), groupMap.end());
auto customComparator = [&](const auto& lhs, const auto& rhs) {
return dis_closer(lhs.second.second, rhs.second.second);
return dis_closer(std::get<1>(lhs), std::get<1>(rhs), metrics_type);
};
std::sort(sortedGroupVals.begin(), sortedGroupVals.end(), customComparator);
std::sort(res.begin(), res.end(), customComparator);

//4. save groupBy results
group_by_values.reserve(sortedGroupVals.size());
offsets.reserve(sortedGroupVals.size());
distances.reserve(sortedGroupVals.size());
for (auto iter = sortedGroupVals.cbegin(); iter != sortedGroupVals.cend();
iter++) {
group_by_values.emplace_back(iter->first);
offsets.push_back(iter->second.first);
distances.push_back(iter->second.second);
}

//5. padding topK results, extra memory consumed will be removed when reducing
for (std::size_t idx = groupMap.size(); idx < topK; idx++) {
offsets.push_back(INVALID_SEG_OFFSET);
distances.push_back(0.0);
group_by_values.emplace_back(std::monostate{});
for (auto iter = res.cbegin(); iter != res.cend(); iter++) {
offsets.push_back(std::get<0>(*iter));
distances.push_back(std::get<1>(*iter));
group_by_values.emplace_back(std::move(std::get<2>(*iter)));
}
}

Expand Down
Loading

0 comments on commit f00c529

Please sign in to comment.