Skip to content

Commit

Permalink
Merge branch 'fix-update-hnswlib' of https://github.com/jina-ai/annlite
Browse files Browse the repository at this point in the history
… into feat-build-wheels
  • Loading branch information
OrangeSodahub committed Oct 13, 2022
2 parents 995d34c + 81778cc commit 14afbd7
Show file tree
Hide file tree
Showing 6 changed files with 2,141 additions and 2,004 deletions.
281 changes: 144 additions & 137 deletions include/hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,141 +5,148 @@
#include <unordered_map>

namespace hnswlib {
template <typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
public:
/**
 * Constructs an empty index that owns no storage.
 *
 * All members are zero-initialised so that the destructor's free(data_)
 * receives a null pointer instead of an indeterminate one.
 *
 * @param s Space descriptor; not used by this constructor.
 */
BruteforceSearch(SpaceInterface<dist_t> *s)
    : data_(nullptr),
      maxelements_(0),
      cur_element_count(0),
      size_per_element_(0),
      data_size_(0),
      fstdistfunc_(),
      dist_func_param_(nullptr) {}
/**
 * Constructs an index by loading it from a file.
 *
 * Members are zero-initialised first so they hold defined values even if
 * loadIndex() fails to read the header from the stream.
 *
 * @param s        Space descriptor supplying data size and distance function.
 * @param location Path of the file previously written by saveIndex().
 */
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
    : data_(nullptr),
      maxelements_(0),
      cur_element_count(0),
      size_per_element_(0),
      data_size_(0),
      fstdistfunc_(),
      dist_func_param_(nullptr) {
    loadIndex(location, s);
}

/**
 * Constructs an empty index with room for maxElements elements.
 *
 * Each slot of the buffer holds data_size_ bytes of vector data followed by
 * the element's label.
 *
 * @param s           Space descriptor supplying data size and distance function.
 * @param maxElements Maximum number of elements the index can hold.
 * @throws std::runtime_error if the storage buffer cannot be allocated.
 */
BruteforceSearch(SpaceInterface<dist_t> *s, size_t maxElements) {
    maxelements_ = maxElements;
    data_size_ = s->get_data_size();
    fstdistfunc_ = s->get_dist_func();
    dist_func_param_ = s->get_dist_func_param();
    size_per_element_ = data_size_ + sizeof(labeltype);
    cur_element_count = 0;

    const size_t total_bytes = maxElements * size_per_element_;
    data_ = (char *)malloc(total_bytes);
    // malloc(0) may legitimately return null, so only a failed non-empty
    // request is treated as an error. The original built the exception
    // object without throwing it.
    if (data_ == nullptr && total_bytes != 0)
        throw std::runtime_error(
            "Not enough memory: BruteforceSearch failed to allocate data");
}

/// Releases the element buffer that was obtained from malloc.
~BruteforceSearch() {
    free(data_);
}

// Element buffer allocated with malloc and released with free in the
// destructor. Slot i starts at data_ + i * size_per_element_ and holds the
// vector bytes followed by the element's label.
char *data_;
// Capacity in elements; addPoint throws once cur_element_count reaches it.
size_t maxelements_;
// Number of slots currently in use.
size_t cur_element_count;
// Bytes per slot: data_size_ + sizeof(labeltype).
size_t size_per_element_;

// Bytes of vector data per element, as reported by the space.
size_t data_size_;
// Distance function obtained from the space.
DISTFUNC<dist_t> fstdistfunc_;
// Opaque parameter forwarded to fstdistfunc_ on every call.
void *dist_func_param_;
// Held by addPoint while it looks up or assigns a slot for a label.
std::mutex index_lock;

// Maps an external label to the index of its slot in data_.
std::unordered_map<labeltype, size_t> dict_external_to_internal;

void addPoint(const void *datapoint, labeltype label) {

int idx;
{
std::unique_lock<std::mutex> lock(index_lock);

auto search = dict_external_to_internal.find(label);
if (search != dict_external_to_internal.end()) {
idx = search->second;
} else {
if (cur_element_count >= maxelements_) {
throw std::runtime_error(
"The number of elements exceeds the specified limit\n");
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
public:
/**
 * Constructs an empty index that owns no storage.
 *
 * All members are zero-initialised so that the destructor's free(data_)
 * receives a null pointer instead of an indeterminate one.
 *
 * @param s Space descriptor; not used by this constructor.
 */
BruteforceSearch(SpaceInterface <dist_t> *s)
    : data_(nullptr),
      maxelements_(0),
      cur_element_count(0),
      size_per_element_(0),
      data_size_(0),
      fstdistfunc_(),
      dist_func_param_(nullptr) {
}
/**
 * Constructs an index by loading it from a file.
 *
 * Members are zero-initialised first so they hold defined values even if
 * loadIndex() fails to read the header from the stream.
 *
 * @param s        Space descriptor supplying data size and distance function.
 * @param location Path of the file previously written by saveIndex().
 */
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
    : data_(nullptr),
      maxelements_(0),
      cur_element_count(0),
      size_per_element_(0),
      data_size_(0),
      fstdistfunc_(),
      dist_func_param_(nullptr) {
    loadIndex(location, s);
}

/**
 * Constructs an empty index with room for maxElements elements.
 *
 * Each slot of the buffer holds data_size_ bytes of vector data followed by
 * the element's label.
 *
 * @param s           Space descriptor supplying data size and distance function.
 * @param maxElements Maximum number of elements the index can hold.
 * @throws std::runtime_error if the storage buffer cannot be allocated.
 */
BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
    maxelements_ = maxElements;
    data_size_ = s->get_data_size();
    fstdistfunc_ = s->get_dist_func();
    dist_func_param_ = s->get_dist_func_param();
    size_per_element_ = data_size_ + sizeof(labeltype);
    cur_element_count = 0;

    const size_t total_bytes = maxElements * size_per_element_;
    data_ = (char *) malloc(total_bytes);
    // malloc(0) may legitimately return null, so only a failed non-empty
    // request is treated as an error. The original built the exception
    // object without throwing it.
    if (data_ == nullptr && total_bytes != 0)
        throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
}

/// Hands the malloc'd element buffer back to the allocator.
~BruteforceSearch() { free(data_); }

// Element buffer allocated with malloc and released with free in the
// destructor. Slot i starts at data_ + i * size_per_element_ and holds the
// vector bytes followed by the element's label.
char *data_;
// Capacity in elements; addPoint throws once cur_element_count reaches it.
size_t maxelements_;
// Number of slots currently in use.
size_t cur_element_count;
// Bytes per slot: data_size_ + sizeof(labeltype).
size_t size_per_element_;

// Bytes of vector data per element, as reported by the space.
size_t data_size_;
// Distance function obtained from the space.
DISTFUNC <dist_t> fstdistfunc_;
// Opaque parameter forwarded to fstdistfunc_ on every call.
void *dist_func_param_;
// Held by addPoint while it looks up or assigns a slot for a label.
std::mutex index_lock;

// Maps an external label to the index of its slot in data_.
std::unordered_map<labeltype,size_t > dict_external_to_internal;

/**
 * Stores a vector under the given label, overwriting the existing entry if
 * the label is already present.
 *
 * @param datapoint Pointer to data_size_ bytes of vector data.
 * @param label     External label of the element.
 * @throws std::runtime_error if the label is new and the index is full.
 */
void addPoint(const void *datapoint, labeltype label) {
    // size_t rather than int: slot indices are size_t everywhere else and
    // must not be narrowed.
    size_t idx;
    {
        // The lock covers only slot lookup/assignment; the copy below runs
        // after it is released.
        std::unique_lock<std::mutex> lock(index_lock);

        auto search = dict_external_to_internal.find(label);
        if (search != dict_external_to_internal.end()) {
            idx = search->second;
        } else {
            if (cur_element_count >= maxelements_) {
                throw std::runtime_error("The number of elements exceeds the specified limit\n");
            }
            idx = cur_element_count;
            dict_external_to_internal[label] = idx;
            cur_element_count++;
        }
    }
    // Slot layout: vector bytes first, label immediately after.
    memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
    memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
}

/**
 * Removes the element with the given label, if present.
 *
 * The freed slot is filled by moving the last stored element into it, so
 * the used slots stay contiguous. Unknown labels are ignored.
 *
 * @param cur_external External label of the element to remove.
 */
void removePoint(labeltype cur_external) {
    // The label map and element count are guarded by index_lock in addPoint;
    // take the same lock here.
    std::unique_lock<std::mutex> lock(index_lock);

    // find() instead of operator[]: the latter would insert the unknown
    // label with slot 0 and then remove whatever element lives there.
    auto found = dict_external_to_internal.find(cur_external);
    if (found == dict_external_to_internal.end())
        return;

    const size_t removed_slot = found->second;
    dict_external_to_internal.erase(found);

    const size_t last_slot = cur_element_count - 1;
    // When the removed element is the last one there is nothing to move;
    // moving would re-register the label that was just erased.
    if (removed_slot != last_slot) {
        char *last = data_ + size_per_element_ * last_slot;
        labeltype moved_label;
        memcpy(&moved_label, last + data_size_, sizeof(labeltype));
        dict_external_to_internal[moved_label] = removed_slot;
        memcpy(data_ + size_per_element_ * removed_slot,
               last,
               data_size_ + sizeof(labeltype));
    }
    cur_element_count--;
}


/**
 * Exhaustively scans all stored elements and returns the k closest to the
 * query.
 *
 * @param query_data  Pointer to the query vector.
 * @param k           Number of neighbours requested.
 * @param batch_index Not used by this implementation.
 * @return Priority queue of (distance, label) pairs holding at most
 *         min(k, cur_element_count) entries; top() is the farthest of them.
 *         Empty when the index is empty or k is 0.
 */
std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void *query_data, size_t k, size_t batch_index) const {
    std::priority_queue<std::pair<dist_t, labeltype>> topResults;
    // k == 0 must return early: top() below requires a non-empty queue.
    if (cur_element_count == 0 || k == 0) return topResults;

    // Seed the queue with the first k elements, but never read beyond the
    // slots that are actually populated.
    const size_t seed_count = k < cur_element_count ? k : cur_element_count;
    for (size_t i = 0; i < seed_count; i++) {
        dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_, nullptr);
        topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *)(data_ + size_per_element_ * i +
                                                                           data_size_))));
    }

    // Scan the rest, keeping only candidates no farther than the current
    // worst kept result.
    dist_t lastdist = topResults.top().first;
    for (size_t i = seed_count; i < cur_element_count; i++) {
        dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_, nullptr);
        if (dist <= lastdist) {
            topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
                                                                                data_size_))));
            if (topResults.size() > k)
                topResults.pop();
            lastdist = topResults.top().first;
        }
    }
    return topResults;
}

/**
 * Writes the index to a binary file.
 *
 * Layout: maxelements_, size_per_element_, cur_element_count, followed by
 * the whole element buffer (maxelements_ * size_per_element_ bytes).
 *
 * @param location Path of the file to write.
 * @throws std::runtime_error if the file cannot be opened for writing.
 */
void saveIndex(const std::string &location) {
    std::ofstream output(location, std::ios::binary);
    // Fail loudly instead of silently writing nothing.
    if (!output.is_open())
        throw std::runtime_error("Cannot open file");

    writeBinaryPOD(output, maxelements_);
    writeBinaryPOD(output, size_per_element_);
    writeBinaryPOD(output, cur_element_count);

    output.write(data_, maxelements_ * size_per_element_);

    output.close();
}

void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {


std::ifstream input(location, std::ios::binary);
std::streampos position;

readBinaryPOD(input, maxelements_);
readBinaryPOD(input, size_per_element_);
readBinaryPOD(input, cur_element_count);

data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxelements_ * size_per_element_);
if (data_ == nullptr)
std::runtime_error("Not enough memory: loadIndex failed to allocate data");

input.read(data_, maxelements_ * size_per_element_);

input.close();

}
idx = cur_element_count;
dict_external_to_internal[label] = idx;
cur_element_count++;
}
}
memcpy(data_ + size_per_element_ * idx + data_size_, &label,
sizeof(labeltype));
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
};

/**
 * Removes the element with the given label, if present.
 *
 * The freed slot is filled by moving the last stored element into it, so
 * the used slots stay contiguous. Unknown labels are ignored.
 *
 * @param cur_external External label of the element to remove.
 */
void removePoint(labeltype cur_external) {
    // The label map and element count are guarded by index_lock in addPoint;
    // take the same lock here.
    std::unique_lock<std::mutex> lock(index_lock);

    // find() instead of operator[]: the latter would insert the unknown
    // label with slot 0 and then remove whatever element lives there.
    auto found = dict_external_to_internal.find(cur_external);
    if (found == dict_external_to_internal.end())
        return;

    const size_t removed_slot = found->second;
    dict_external_to_internal.erase(found);

    const size_t last_slot = cur_element_count - 1;
    // When the removed element is the last one there is nothing to move;
    // moving would re-register the label that was just erased.
    if (removed_slot != last_slot) {
        char *last = data_ + size_per_element_ * last_slot;
        labeltype moved_label;
        memcpy(&moved_label, last + data_size_, sizeof(labeltype));
        dict_external_to_internal[moved_label] = removed_slot;
        memcpy(data_ + size_per_element_ * removed_slot, last,
               data_size_ + sizeof(labeltype));
    }
    cur_element_count--;
}

/**
 * Exhaustively scans all stored elements and returns the k closest to the
 * query.
 *
 * @param query_data  Pointer to the query vector.
 * @param k           Number of neighbours requested.
 * @param batch_index Not used by this implementation.
 * @return Priority queue of (distance, label) pairs holding at most
 *         min(k, cur_element_count) entries; top() is the farthest of them.
 *         Empty when the index is empty or k is 0.
 */
std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void *query_data, size_t k, size_t batch_index) const {
    std::priority_queue<std::pair<dist_t, labeltype>> topResults;
    // k == 0 must return early: top() below requires a non-empty queue.
    if (cur_element_count == 0 || k == 0)
        return topResults;

    // Seed the queue with the first k elements, but never read beyond the
    // slots that are actually populated.
    const size_t seed_count = k < cur_element_count ? k : cur_element_count;
    for (size_t i = 0; i < seed_count; i++) {
        dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i,
                                   dist_func_param_, nullptr);
        topResults.push(std::pair<dist_t, labeltype>(
            dist, *((labeltype *)(data_ + size_per_element_ * i + data_size_))));
    }

    // Scan the rest, keeping only candidates no farther than the current
    // worst kept result.
    dist_t lastdist = topResults.top().first;
    for (size_t i = seed_count; i < cur_element_count; i++) {
        dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i,
                                   dist_func_param_, nullptr);
        if (dist <= lastdist) {
            topResults.push(std::pair<dist_t, labeltype>(
                dist,
                *((labeltype *)(data_ + size_per_element_ * i + data_size_))));
            if (topResults.size() > k)
                topResults.pop();
            lastdist = topResults.top().first;
        }
    }
    return topResults;
}

/**
 * Writes the index to a binary file.
 *
 * Layout: maxelements_, size_per_element_, cur_element_count, followed by
 * the whole element buffer (maxelements_ * size_per_element_ bytes).
 *
 * @param location Path of the file to write.
 * @throws std::runtime_error if the file cannot be opened for writing.
 */
void saveIndex(const std::string &location) {
    std::ofstream output(location, std::ios::binary);
    // Fail loudly instead of silently writing nothing.
    if (!output.is_open())
        throw std::runtime_error("Cannot open file");

    writeBinaryPOD(output, maxelements_);
    writeBinaryPOD(output, size_per_element_);
    writeBinaryPOD(output, cur_element_count);

    output.write(data_, maxelements_ * size_per_element_);

    output.close();
}

void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {

std::ifstream input(location, std::ios::binary);
std::streampos position;

readBinaryPOD(input, maxelements_);
readBinaryPOD(input, size_per_element_);
readBinaryPOD(input, cur_element_count);

data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *)malloc(maxelements_ * size_per_element_);
if (data_ == nullptr)
std::runtime_error(
"Not enough memory: loadIndex failed to allocate data");

input.read(data_, maxelements_ * size_per_element_);

input.close();
}
};
} // namespace hnswlib

};
}
Loading

0 comments on commit 14afbd7

Please sign in to comment.