Skip to content

Commit

Permalink
Fix rnsg master (#2831)
Browse files Browse the repository at this point in the history
* fix rnsg ip

Signed-off-by: cqy <yaya645@126.com>

* fix rnsg search in master

Signed-off-by: cqy <yaya645@126.com>
  • Loading branch information
cqy123456 committed Jul 14, 2020
1 parent e79cbf0 commit 777c76c
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#2752 Milvus formats vectors data to double-precision and return to http client
- \#2767 fix a bug of getting wrong nprobe limitation in knowhere on GPU version
- \#2776 Fix too many data copies during creating IVF index
- \#2813 To implement RNSG IP

## Feature
- \#2319 Redo metadata to support MVCC
Expand Down
13 changes: 12 additions & 1 deletion core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,18 @@ NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) {
b_params.search_length = config[IndexParams::search_length];

GETTENSORWITHIDS(dataset_ptr)
index_ = std::make_shared<impl::NsgIndex>(dim, rows, config[Metric::TYPE].get<std::string>());

impl::NsgIndex::Metric_Type metric;
auto metric_str = config[Metric::TYPE].get<std::string>();
if (metric_str == knowhere::Metric::IP) {
metric = impl::NsgIndex::Metric_Type::Metric_Type_IP;
} else if (metric_str == knowhere::Metric::L2) {
metric = impl::NsgIndex::Metric_Type::Metric_Type_L2;
} else {
KNOWHERE_THROW_MSG("Metric is not supported");
}

index_ = std::make_shared<impl::NsgIndex>(dim, rows, metric);
index_->SetKnnGraph(knng);
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ DistanceL2::Compare(const float* a, const float* b, unsigned size) const {

float
DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
return faiss::fvec_inner_product(a, b, (size_t)size);
return -(faiss::fvec_inner_product(a, b, (size_t)size));
}

#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ namespace impl {

unsigned int seed = 100;

NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, std::string metric)
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric)
: dimension(dimension), ntotal(n), metric_type(metric) {
if (metric == knowhere::Metric::L2) {
if (metric == Metric_Type::Metric_Type_L2) {
distance_ = new DistanceL2;
} else if (metric == knowhere::Metric::IP) {
} else if (metric == Metric_Type::Metric_Type_IP) {
distance_ = new DistanceIP;
}
}
Expand Down Expand Up @@ -407,7 +407,6 @@ NsgIndex::GetNeighbors(const float* query, float* data, std::vector<Neighbor>& r
// std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
// nearest_updated_pos << std::endl;
/////

// trick: avoid search query search_length < init_ids.size() ...
if (buffer_size + 1 < resset.size())
++buffer_size;
Expand Down Expand Up @@ -847,14 +846,16 @@ NsgIndex::Search(const float* query, float* data, const unsigned& nq, const unsi
}
}
rc.RecordSection("search");

bool is_ip = (metric_type == Metric_Type::Metric_Type_IP);
for (unsigned int i = 0; i < nq; ++i) {
unsigned int pos = 0;
for (unsigned int j = 0; j < resset[i].size(); ++j) {
if (pos >= k)
break; // already top k
if (!bitset || !bitset->test((faiss::ConcurrentBitset::id_type_t)resset[i][j].id)) {
ids[i * k + pos] = ids_[resset[i][j].id];
dist[i * k + pos] = resset[i][j].distance;
dist[i * k + pos] = is_ip ? -resset[i][j].distance : resset[i][j].distance;
++pos;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,14 @@ using Graph = std::vector<std::vector<node_t>>;

class NsgIndex {
public:
// Selects which Distance implementation an NsgIndex uses
// (DistanceL2 for Metric_Type_L2, DistanceIP for Metric_Type_IP).
// The value is kept in NsgIndex::metric_type as an int32_t and is written
// to / read back from the serialized index, so the numeric values are
// spelled out explicitly here.
enum Metric_Type {
    Metric_Type_L2 = 0,
    Metric_Type_IP = 1,
};

size_t dimension;
size_t ntotal; // totabl nb of indexed vectors
std::string metric_type; // L2 | IP
size_t ntotal; // totabl nb of indexed vectors
int32_t metric_type; // enum Metric_Type
Distance* distance_;

// float* ori_data_;
Expand All @@ -65,7 +70,7 @@ class NsgIndex {
size_t out_degree;

public:
explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = knowhere::Metric::L2);
explicit NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric);

NsgIndex() = default;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace impl {

void
write_index(NsgIndex* index, MemoryIOWriter& writer) {
writer(&index->metric_type, sizeof(int32_t), 1);
writer(&index->ntotal, sizeof(index->ntotal), 1);
writer(&index->dimension, sizeof(index->dimension), 1);
writer(&index->navigation_point, sizeof(index->navigation_point), 1);
Expand All @@ -36,9 +37,11 @@ NsgIndex*
read_index(MemoryIOReader& reader) {
size_t ntotal;
size_t dimension;
int32_t metric;
reader(&metric, sizeof(int32_t), 1);
reader(&ntotal, sizeof(size_t), 1);
reader(&dimension, sizeof(size_t), 1);
auto index = new NsgIndex(dimension, ntotal);
auto index = new NsgIndex(dimension, ntotal, (impl::NsgIndex::Metric_Type)metric);
reader(&index->navigation_point, sizeof(index->navigation_point), 1);

// index->ori_data_ = new float[index->ntotal * index->dimension];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,15 @@ NSG_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_ids = dataset_ptr->Get<const int64_t*>(meta::IDS);

GETTENSOR(dataset_ptr)
index_ = std::make_shared<impl::NsgIndex>(dim, rows, config[Metric::TYPE].get<std::string>());
impl::NsgIndex::Metric_Type metric_type_nsg;
if (config[Metric::TYPE].get<std::string>() == "IP") {
metric_type_nsg = impl::NsgIndex::Metric_Type::Metric_Type_IP;
} else if (config[Metric::TYPE].get<std::string>() == "L2") {
metric_type_nsg = impl::NsgIndex::Metric_Type::Metric_Type_L2;
} else {
KNOWHERE_THROW_MSG("either IP or L2");
}
index_ = std::make_shared<impl::NsgIndex>(dim, rows, metric_type_nsg);
index_->SetKnnGraph(knng);
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
}
Expand Down

0 comments on commit 777c76c

Please sign in to comment.