Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rnsg master #2831

Merged
merged 2 commits into from Jul 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -23,6 +23,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#2752 Milvus formats vectors data to double-precision and return to http client
- \#2767 fix a bug of getting wrong nprobe limitation in knowhere on GPU version
- \#2776 Fix too many data copies during creating IVF index
- \#2813 To implemente RNSG IP

## Feature
- \#2319 Redo metadata to support MVCC
Expand Down
13 changes: 12 additions & 1 deletion core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp
Expand Up @@ -140,7 +140,18 @@ NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) {
b_params.search_length = config[IndexParams::search_length];

GETTENSORWITHIDS(dataset_ptr)
index_ = std::make_shared<impl::NsgIndex>(dim, rows, config[Metric::TYPE].get<std::string>());

impl::NsgIndex::Metric_Type metric;
auto metric_str = config[Metric::TYPE].get<std::string>();
if (metric_str == knowhere::Metric::IP) {
metric = impl::NsgIndex::Metric_Type::Metric_Type_IP;
} else if (metric_str == knowhere::Metric::L2) {
metric = impl::NsgIndex::Metric_Type::Metric_Type_L2;
} else {
KNOWHERE_THROW_MSG("Metric is not supported");
}

index_ = std::make_shared<impl::NsgIndex>(dim, rows, metric);
index_->SetKnnGraph(knng);
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
}
Expand Down
Expand Up @@ -237,7 +237,7 @@ DistanceL2::Compare(const float* a, const float* b, unsigned size) const {

float
DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
return faiss::fvec_inner_product(a, b, (size_t)size);
return -(faiss::fvec_inner_product(a, b, (size_t)size));
}

#endif
Expand Down
Expand Up @@ -31,11 +31,11 @@ namespace impl {

unsigned int seed = 100;

NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, std::string metric)
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric)
: dimension(dimension), ntotal(n), metric_type(metric) {
if (metric == knowhere::Metric::L2) {
if (metric == Metric_Type::Metric_Type_L2) {
distance_ = new DistanceL2;
} else if (metric == knowhere::Metric::IP) {
} else if (metric == Metric_Type::Metric_Type_IP) {
distance_ = new DistanceIP;
}
}
Expand Down Expand Up @@ -407,7 +407,6 @@ NsgIndex::GetNeighbors(const float* query, float* data, std::vector<Neighbor>& r
// std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
// nearest_updated_pos << std::endl;
/////

// trick: avoid search query search_length < init_ids.size() ...
if (buffer_size + 1 < resset.size())
++buffer_size;
Expand Down Expand Up @@ -847,14 +846,16 @@ NsgIndex::Search(const float* query, float* data, const unsigned& nq, const unsi
}
}
rc.RecordSection("search");

bool is_ip = (metric_type == Metric_Type::Metric_Type_IP);
for (unsigned int i = 0; i < nq; ++i) {
unsigned int pos = 0;
for (unsigned int j = 0; j < resset[i].size(); ++j) {
if (pos >= k)
break; // already top k
if (!bitset || !bitset->test((faiss::ConcurrentBitset::id_type_t)resset[i][j].id)) {
ids[i * k + pos] = ids_[resset[i][j].id];
dist[i * k + pos] = resset[i][j].distance;
dist[i * k + pos] = is_ip ? -resset[i][j].distance : resset[i][j].distance;
++pos;
}
}
Expand Down
Expand Up @@ -43,9 +43,14 @@ using Graph = std::vector<std::vector<node_t>>;

class NsgIndex {
public:
enum Metric_Type {
Metric_Type_L2,
Metric_Type_IP,
};

size_t dimension;
size_t ntotal; // totabl nb of indexed vectors
std::string metric_type; // L2 | IP
size_t ntotal; // totabl nb of indexed vectors
int32_t metric_type; // enum Metric_Type
Distance* distance_;

// float* ori_data_;
Expand All @@ -65,7 +70,7 @@ class NsgIndex {
size_t out_degree;

public:
explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = knowhere::Metric::L2);
explicit NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric);

NsgIndex() = default;

Expand Down
Expand Up @@ -19,6 +19,7 @@ namespace impl {

void
write_index(NsgIndex* index, MemoryIOWriter& writer) {
writer(&index->metric_type, sizeof(int32_t), 1);
writer(&index->ntotal, sizeof(index->ntotal), 1);
writer(&index->dimension, sizeof(index->dimension), 1);
writer(&index->navigation_point, sizeof(index->navigation_point), 1);
Expand All @@ -36,9 +37,11 @@ NsgIndex*
read_index(MemoryIOReader& reader) {
size_t ntotal;
size_t dimension;
int32_t metric;
reader(&metric, sizeof(int32_t), 1);
reader(&ntotal, sizeof(size_t), 1);
reader(&dimension, sizeof(size_t), 1);
auto index = new NsgIndex(dimension, ntotal);
auto index = new NsgIndex(dimension, ntotal, (impl::NsgIndex::Metric_Type)metric);
reader(&index->navigation_point, sizeof(index->navigation_point), 1);

// index->ori_data_ = new float[index->ntotal * index->dimension];
Expand Down
Expand Up @@ -144,7 +144,15 @@ NSG_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_ids = dataset_ptr->Get<const int64_t*>(meta::IDS);

GETTENSOR(dataset_ptr)
index_ = std::make_shared<impl::NsgIndex>(dim, rows, config[Metric::TYPE].get<std::string>());
impl::NsgIndex::Metric_Type metric_type_nsg;
if (config[Metric::TYPE].get<std::string>() == "IP") {
metric_type_nsg = impl::NsgIndex::Metric_Type::Metric_Type_IP;
} else if (config[Metric::TYPE].get<std::string>() == "L2") {
metric_type_nsg = impl::NsgIndex::Metric_Type::Metric_Type_L2;
} else {
KNOWHERE_THROW_MSG("either IP or L2");
}
index_ = std::make_shared<impl::NsgIndex>(dim, rows, metric_type_nsg);
index_->SetKnnGraph(knng);
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
}
Expand Down