Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace faiss::MetricType with knowhere::MetricType #17891

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion internal/core/src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ milvus_add_pkg_config("milvus_common")

set(COMMON_SRC
Schema.cpp
Types.cpp
SystemProperty.cpp
vector_index_c.cpp
memory_c.cpp
Expand Down
7 changes: 4 additions & 3 deletions internal/core/src/common/FieldMeta.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ class FieldMeta {
Assert(is_string());
}

FieldMeta(const FieldName& name, FieldId id, DataType type, int64_t dim, std::optional<MetricType> metric_type)
FieldMeta(
const FieldName& name, FieldId id, DataType type, int64_t dim, std::optional<knowhere::MetricType> metric_type)
: name_(name), id_(id), type_(type), vector_info_(VectorInfo{dim, metric_type}) {
Assert(is_vector());
}
Expand Down Expand Up @@ -177,7 +178,7 @@ class FieldMeta {
return string_info_->max_length;
}

std::optional<MetricType>
std::optional<knowhere::MetricType>
get_metric_type() const {
Assert(is_vector());
Assert(vector_info_.has_value());
Expand Down Expand Up @@ -213,7 +214,7 @@ class FieldMeta {
private:
struct VectorInfo {
int64_t dim_;
std::optional<MetricType> metric_type_;
std::optional<knowhere::MetricType> metric_type_;
};
struct StringInfo {
int64_t max_length;
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/common/Schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) {
if (!index_map.count("metric_type")) {
schema->AddField(name, field_id, data_type, dim, std::nullopt);
} else {
auto metric_type = GetMetricType(index_map.at("metric_type"));
auto metric_type = index_map.at("metric_type");
schema->AddField(name, field_id, data_type, dim, metric_type);
}
} else if (datatype_is_string(data_type)) {
Expand Down
7 changes: 5 additions & 2 deletions internal/core/src/common/Schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ class Schema {

// auto gen field_id for convenience
FieldId
AddDebugField(const std::string& name, DataType data_type, int64_t dim, std::optional<MetricType> metric_type) {
AddDebugField(const std::string& name,
DataType data_type,
int64_t dim,
std::optional<knowhere::MetricType> metric_type) {
auto field_id = FieldId(debug_id);
debug_id++;
auto field_meta = FieldMeta(FieldName(name), field_id, data_type, dim, metric_type);
Expand Down Expand Up @@ -71,7 +74,7 @@ class Schema {
const FieldId id,
DataType data_type,
int64_t dim,
std::optional<MetricType> metric_type) {
std::optional<knowhere::MetricType> metric_type) {
auto field_meta = FieldMeta(name, id, data_type, dim, metric_type);
this->AddField(std::move(field_meta));
}
Expand Down
64 changes: 0 additions & 64 deletions internal/core/src/common/Types.cpp

This file was deleted.

15 changes: 5 additions & 10 deletions internal/core/src/common/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <NamedType/named_type.hpp>
#include <variant>

#include "knowhere/common/MetricType.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "pb/schema.pb.h"
#include "pb/segcore.pb.h"
#include "pb/plan.pb.h"
Expand Down Expand Up @@ -68,19 +68,14 @@ using ScalarArray = proto::schema::ScalarField;
using DataArray = proto::schema::FieldData;
using VectorArray = proto::schema::VectorField;
using IdArray = proto::schema::IDs;
using MetricType = faiss::MetricType;
using InsertData = proto::segcore::InsertRecord;
using PkType = std::variant<std::monostate, int64_t, std::string>;
using Pk2OffsetType = tbb::concurrent_unordered_multimap<PkType, int64_t, std::hash<PkType>>;

MetricType
GetMetricType(const std::string& type);

std::string
MetricTypeToName(MetricType metric_type);

bool
IsPrimaryKeyDataType(DataType data_type);
// Reports whether `data_type` is one of the two data types this helper
// accepts as a primary key: DataType::INT64 or DataType::VARCHAR.
// Returns true for those two values and false for every other value.
inline bool
IsPrimaryKeyDataType(DataType data_type) {
    // Integer keys are checked first; anything else must be VARCHAR to qualify.
    if (data_type == DataType::INT64) {
        return true;
    }
    return data_type == DataType::VARCHAR;
}

// NOTE: dependent type
// used in template metaprogramming
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/indexbuilder/VecIndexCreator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ VecIndexCreator::check_parameter(knowhere::Config& conf,

template <typename T>
std::optional<T>
VecIndexCreator::get_config_by_name(std::string_view name) {
VecIndexCreator::get_config_by_name(const std::string& name) {
if (config_.contains(name)) {
return knowhere::GetValueFromConfig<T>(config_, name);
}
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/indexbuilder/VecIndexCreator.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class VecIndexCreator : public IndexCreatorBase {

template <typename T>
std::optional<T>
get_config_by_name(std::string_view name);
get_config_by_name(const std::string& name);

void
StoreRawData(const knowhere::DatasetPtr& dataset);
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ Parser::ParseVecNode(const Json& out_body) {
}
}();
vec_node->search_info_.topk_ = topk;
vec_node->search_info_.metric_type_ = GetMetricType(vec_info.at("metric_type"));
vec_node->search_info_.metric_type_ = vec_info.at("metric_type");
vec_node->search_info_.search_params_ = vec_info.at("params");
vec_node->search_info_.field_id_ = field_id;
vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal");
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct SearchInfo {
int64_t topk_;
int64_t round_decimal_;
FieldId field_id_;
MetricType metric_type_;
knowhere::MetricType metric_type_;
knowhere::Config search_params_;
};

Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
auto field_id = FieldId(anns_proto.field_id());
search_info.field_id_ = field_id;

search_info.metric_type_ = GetMetricType(query_info_proto.metric_type());
search_info.metric_type_ = query_info_proto.metric_type();
search_info.topk_ = query_info_proto.topk();
search_info.round_decimal_ = query_info_proto.round_decimal();
search_info.search_params_ = json::parse(query_info_proto.search_params());
Expand Down
18 changes: 9 additions & 9 deletions internal/core/src/query/SearchBruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace milvus::query {
// copy from faiss/IndexBinaryFlat.cpp::IndexBinaryFlat::search()
// disable lint to make further migration easier
static void
binary_search(MetricType metric_type,
binary_search(const knowhere::MetricType& metric_type,
const uint8_t* xb,
int64_t ntotal,
int code_size,
Expand All @@ -36,28 +36,28 @@ binary_search(MetricType metric_type,
idx_t* labels,
const BitsetView bitset) {
using namespace faiss; // NOLINT
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
if (metric_type == knowhere::metric::JACCARD || metric_type == knowhere::metric::TANIMOTO) {
float_maxheap_array_t res = {size_t(n), size_t(k), labels, D};
binary_distance_knn_hc(METRIC_Jaccard, &res, x, xb, ntotal, code_size, bitset);

if (metric_type == METRIC_Tanimoto) {
if (metric_type == knowhere::metric::TANIMOTO) {
for (int i = 0; i < k * n; i++) {
D[i] = Jaccard_2_Tanimoto(D[i]);
}
}
} else if (metric_type == METRIC_Hamming) {
} else if (metric_type == knowhere::metric::HAMMING) {
std::vector<int32_t> int_distances(n * k);
int_maxheap_array_t res = {size_t(n), size_t(k), labels, int_distances.data()};
binary_distance_knn_hc(METRIC_Hamming, &res, x, xb, ntotal, code_size, bitset);
for (int i = 0; i < n * k; ++i) {
D[i] = int_distances[i];
}
} else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) {
} else if (metric_type == knowhere::metric::SUBSTRUCTURE || metric_type == knowhere::metric::SUPERSTRUCTURE) {
// only matched ids will be chosen, not to use heap
binary_distance_knn_mc(metric_type, x, xb, n, ntotal, k, code_size, D, labels, bitset);
auto faiss_metric_type = knowhere::GetFaissMetricType(metric_type);
binary_distance_knn_mc(faiss_metric_type, x, xb, n, ntotal, k, code_size, D, labels, bitset);
} else {
std::string msg =
std::string("binary search not support metric type: ") + segcore::MetricTypeToString(metric_type);
std::string msg = "binary search not support metric type: " + metric_type;
PanicInfo(msg);
}
}
Expand Down Expand Up @@ -97,7 +97,7 @@ FloatSearchBruteForce(const dataset::SearchDataset& dataset,
SubSearchResult sub_qr(num_queries, topk, metric_type, round_decimal);
auto query_data = reinterpret_cast<const float*>(dataset.query_data);
auto chunk_data = reinterpret_cast<const float*>(chunk_data_raw);
if (metric_type == MetricType::METRIC_L2) {
if (metric_type == knowhere::metric::L2) {
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(),
sub_qr.get_distances()};
faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, nullptr, bitset);
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/SearchOnSealed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ SearchOnSealed(const Schema& schema,

auto conf = search_info.search_params_;
knowhere::SetMetaTopk(conf, search_info.topk_);
knowhere::SetMetaMetricType(conf, MetricTypeToName(field_indexing->metric_type_));
knowhere::SetMetaMetricType(conf, field_indexing->metric_type_);
auto index_type = field_indexing->indexing_->index_type();
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type);
try {
Expand Down
14 changes: 7 additions & 7 deletions internal/core/src/query/SubSearchResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace milvus::query {

class SubSearchResult {
public:
SubSearchResult(int64_t num_queries, int64_t topk, MetricType metric_type, int64_t round_decimal)
SubSearchResult(int64_t num_queries, int64_t topk, const knowhere::MetricType& metric_type, int64_t round_decimal)
: metric_type_(metric_type),
num_queries_(num_queries),
topk_(topk),
Expand All @@ -29,15 +29,15 @@ class SubSearchResult {
}

public:
static constexpr float
init_value(MetricType metric_type) {
// Returns the initial distance value for `metric_type`:
// -std::numeric_limits<float>::max() when is_descending(metric_type) is true,
// +std::numeric_limits<float>::max() otherwise.
static float
init_value(const knowhere::MetricType& metric_type) {
    const float largest = std::numeric_limits<float>::max();
    if (is_descending(metric_type)) {
        return -largest;
    }
    return largest;
}

static constexpr bool
is_descending(MetricType metric_type) {
static bool
is_descending(const knowhere::MetricType& metric_type) {
// TODO(dog): more types
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
if (metric_type == knowhere::metric::IP) {
return true;
} else {
return false;
Expand Down Expand Up @@ -103,7 +103,7 @@ class SubSearchResult {
int64_t num_queries_;
int64_t topk_;
int64_t round_decimal_;
MetricType metric_type_;
knowhere::MetricType metric_type_;
std::vector<int64_t> seg_offsets_;
std::vector<float> distances_;
};
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/query/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace milvus::query {
namespace dataset {

struct SearchDataset {
MetricType metric_type;
knowhere::MetricType metric_type;
int64_t num_queries;
int64_t topk;
int64_t round_decimal;
Expand Down
24 changes: 12 additions & 12 deletions internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
assert(!ret_);
auto& info = node.search_info_;
Json json_body{
{"node_type", "FloatVectorANNS"}, //
{"metric_type", MetricTypeToName(info.metric_type_)}, //
{"field_id_", info.field_id_.get()}, //
{"topk", info.topk_}, //
{"search_params", info.search_params_}, //
{"placeholder_tag", node.placeholder_tag_}, //
{"node_type", "FloatVectorANNS"}, //
{"metric_type", info.metric_type_}, //
{"field_id_", info.field_id_.get()}, //
{"topk", info.topk_}, //
{"search_params", info.search_params_}, //
{"placeholder_tag", node.placeholder_tag_}, //
};
if (node.predicate_.has_value()) {
ShowExprVisitor expr_show;
Expand All @@ -75,12 +75,12 @@ ShowPlanNodeVisitor::visit(BinaryVectorANNS& node) {
assert(!ret_);
auto& info = node.search_info_;
Json json_body{
{"node_type", "BinaryVectorANNS"}, //
{"metric_type", MetricTypeToName(info.metric_type_)}, //
{"field_id_", info.field_id_.get()}, //
{"topk", info.topk_}, //
{"search_params", info.search_params_}, //
{"placeholder_tag", node.placeholder_tag_}, //
{"node_type", "BinaryVectorANNS"}, //
{"metric_type", info.metric_type_}, //
{"field_id_", info.field_id_.get()}, //
{"topk", info.topk_}, //
{"search_params", info.search_params_}, //
{"placeholder_tag", node.placeholder_tag_}, //
};
if (node.predicate_.has_value()) {
ShowExprVisitor expr_show;
Expand Down
10 changes: 4 additions & 6 deletions internal/core/src/segcore/FieldIndexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,13 @@ VectorFieldIndexing::get_build_params() const {
// TODO
auto type_opt = field_meta_.get_metric_type();
AssertInfo(type_opt.has_value(), "Metric type of field meta doesn't have value");
auto metric_type = type_opt.value();
auto type_name = MetricTypeToName(metric_type);
auto& metric_type = type_opt.value();
auto& config = segcore_config_.at(metric_type);
auto base_params = config.build_params;

AssertInfo(base_params.count("nlist"), "Can't get nlist from index params");
knowhere::SetMetaDim(base_params, field_meta_.get_dim());
knowhere::SetMetaMetricType(base_params, type_name);
knowhere::SetMetaMetricType(base_params, metric_type);

return base_params;
}
Expand All @@ -65,14 +64,13 @@ VectorFieldIndexing::get_search_params(int top_K) const {
// TODO
auto type_opt = field_meta_.get_metric_type();
AssertInfo(type_opt.has_value(), "Metric type of field meta doesn't have value");
auto metric_type = type_opt.value();
auto type_name = MetricTypeToName(metric_type);
auto& metric_type = type_opt.value();
auto& config = segcore_config_.at(metric_type);

auto base_params = config.search_params;
AssertInfo(base_params.count("nprobe"), "Can't get nprobe from base params");
knowhere::SetMetaTopk(base_params, top_K);
knowhere::SetMetaMetricType(base_params, type_name);
knowhere::SetMetaMetricType(base_params, metric_type);

return base_params;
}
Expand Down