Skip to content

Commit

Permalink
Sparse knn1 (#1283)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

1. Add support for knn search on sparse vector column.
2. No index is implemented yet, so knn search uses brute force.

TODO: 
1. Add more test cases.
2. Add an option to determine whether to sort sparse vectors when storing
them in the database.
3. Add filter support to sparse knn search.

Issue link: #1174

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Test cases
  • Loading branch information
small-turtle-1 committed Jun 3, 2024
1 parent aaffd0b commit 5a52545
Show file tree
Hide file tree
Showing 81 changed files with 3,180 additions and 1,557 deletions.
2 changes: 2 additions & 0 deletions src/executor/fragment_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ void FragmentBuilder::BuildFragments(PhysicalOperator *phys_op, PlanFragment *cu
case PhysicalOperatorType::kMergeTop:
case PhysicalOperatorType::kMergeSort:
case PhysicalOperatorType::kMergeMatchTensor:
case PhysicalOperatorType::kMergeMatchSparse:
case PhysicalOperatorType::kMergeKnn: {
current_fragment_ptr->AddOperator(phys_op);
current_fragment_ptr->SetSourceNode(query_context_ptr_, SourceType::kLocalQueue, phys_op->GetOutputNames(), phys_op->GetOutputTypes());
Expand Down Expand Up @@ -258,6 +259,7 @@ void FragmentBuilder::BuildFragments(PhysicalOperator *phys_op, PlanFragment *cu
LOG_CRITICAL(error_message);
UnrecoverableError(error_message);
}
case PhysicalOperatorType::kMatchSparseScan:
case PhysicalOperatorType::kMatchTensorScan:
case PhysicalOperatorType::kKnnScan: {
if (phys_op->left() != nullptr or phys_op->right() != nullptr) {
Expand Down
3 changes: 1 addition & 2 deletions src/executor/operator/physical_import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,13 @@ void PhysicalImport::ImportCSR(QueryContext *query_context, ImportOperatorState
file_handler->Read(&off, sizeof(i64));
i64 nnz = off - prev_off;
SizeT data_len = sparse_info->DataSize(nnz);
SizeT indice_len = sparse_info->IndiceSize(nnz);
auto tmp_indice_ptr = MakeUnique<char[]>(sizeof(i32) * nnz);
auto data_ptr = MakeUnique<char[]>(data_len);
idx_reader->Read(tmp_indice_ptr.get(), sizeof(i32) * nnz);
data_reader->Read(data_ptr.get(), data_len);
auto indice_ptr = ConvertCSRIndice(std::move(tmp_indice_ptr), sparse_info.get(), nnz);

auto value = Value::MakeSparse(nnz, std::move(indice_ptr), indice_len, std::move(data_ptr), data_len, sparse_info);
auto value = Value::MakeSparse(nnz, std::move(indice_ptr), std::move(data_ptr), sparse_info);
column_vector->AppendValue(value);

block_entry->IncreaseRowCount(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ import buffer_obj;
import create_index_info;
import knn_expr;
import chunk_index_entry;

import internal_types;
import block_entry;
import segment_index_entry;
import segment_entry;
Expand Down Expand Up @@ -179,8 +179,6 @@ TableEntry *PhysicalKnnScan::table_collection_ptr() const { return base_table_re

String PhysicalKnnScan::TableAlias() const { return base_table_ref_->alias_; }

BlockIndex *PhysicalKnnScan::GetBlockIndex() const { return base_table_ref_->block_index_.get(); }

Vector<SizeT> &PhysicalKnnScan::ColumnIDs() const { return base_table_ref_->column_ids_; }

void PhysicalKnnScan::PlanWithIndex(QueryContext *query_context) { // TODO: return base entry vector
Expand Down Expand Up @@ -535,66 +533,15 @@ void PhysicalKnnScan::ExecuteInternal(QueryContext *query_context, KnnScanOperat
merge_heap->End();
i64 result_n = std::min(knn_scan_shared_data->topk_, merge_heap->total_count());

if (!operator_state->data_block_array_.empty()) {
String error_message = "In physical_knn_scan : operator_state->data_block_array_ is not empty.";
LOG_CRITICAL(error_message);
UnrecoverableError(error_message);
SizeT query_n = knn_scan_shared_data->query_count_;
Vector<char *> result_dists_list;
Vector<RowID *> row_ids_list;
for (SizeT query_id = 0; query_id < query_n; ++query_id) {
result_dists_list.emplace_back(reinterpret_cast<char *>(merge_heap->GetDistancesByIdx(query_id)));
row_ids_list.emplace_back(merge_heap->GetIDsByIdx(query_id));
}
{
SizeT total_data_row_count = knn_scan_shared_data->query_count_ * result_n;
SizeT row_idx = 0;
do {
auto data_block = DataBlock::MakeUniquePtr();
data_block->Init(*GetOutputTypes());
operator_state->data_block_array_.emplace_back(std::move(data_block));
row_idx += DEFAULT_BLOCK_CAPACITY;
} while (row_idx < total_data_row_count);
}

SizeT output_block_row_id = 0;
SizeT output_block_idx = 0;
DataBlock *output_block_ptr = operator_state->data_block_array_[output_block_idx].get();
for (u64 query_idx = 0; query_idx < knn_scan_shared_data->query_count_; ++query_idx) {
DataType *result_dists = merge_heap->GetDistancesByIdx(query_idx);
RowID *row_ids = merge_heap->GetIDsByIdx(query_idx);

for (i64 top_idx = 0; top_idx < result_n; ++top_idx) {
SizeT id = query_idx * knn_scan_shared_data->query_count_ + top_idx;

SegmentID segment_id = row_ids[top_idx].segment_id_;
SegmentOffset segment_offset = row_ids[top_idx].segment_offset_;
BlockID block_id = segment_offset / DEFAULT_BLOCK_CAPACITY;
BlockOffset block_offset = segment_offset % DEFAULT_BLOCK_CAPACITY;

BlockEntry *block_entry = block_index->GetBlockEntry(segment_id, block_id);
if (block_entry == nullptr) {
String error_message = fmt::format("Cannot find segment id: {}, block id: {}", segment_id, block_id);
LOG_CRITICAL(error_message);
UnrecoverableError(error_message);
}

if (output_block_row_id == DEFAULT_BLOCK_CAPACITY) {
output_block_ptr->Finalize();
++output_block_idx;
output_block_ptr = operator_state->data_block_array_[output_block_idx].get();
output_block_row_id = 0;
}

SizeT column_n = base_table_ref_->column_ids_.size();
for (SizeT i = 0; i < column_n; ++i) {
SizeT column_id = base_table_ref_->column_ids_[i];
auto *block_column_entry = block_entry->GetColumnBlockEntry(column_id);
ColumnVector &&column_vector = block_column_entry->GetColumnVector(query_context->storage()->buffer_manager());

output_block_ptr->column_vectors[i]->AppendWith(column_vector, block_offset, 1);
}
output_block_ptr->AppendValueByPtr(column_n, (ptr_t)&result_dists[id]);
output_block_ptr->AppendValueByPtr(column_n + 1, (ptr_t)&row_ids[id]);

++output_block_row_id;
}
}
output_block_ptr->Finalize();
this->SetOutput(result_dists_list, row_ids_list, sizeof(DataType), result_n, query_context, operator_state);
operator_state->SetComplete();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ import infinity_exception;
import internal_types;
import data_type;
import common_query_filter;
import physical_scan_base;

namespace infinity {

export class PhysicalKnnScan final : public PhysicalOperator {
export class PhysicalKnnScan final : public PhysicalScanBase {
public:
explicit PhysicalKnnScan(u64 id,
SharedPtr<BaseTableRef> base_table_ref,
Expand All @@ -48,7 +49,7 @@ public:
SharedPtr<Vector<SharedPtr<DataType>>> output_types,
u64 knn_table_index,
SharedPtr<Vector<LoadMeta>> load_metas)
: PhysicalOperator(PhysicalOperatorType::kKnnScan, nullptr, nullptr, id, load_metas), base_table_ref_(std::move(base_table_ref)),
: PhysicalScanBase(id, PhysicalOperatorType::kKnnScan, nullptr, nullptr, base_table_ref, load_metas),
knn_expression_(std::move(knn_expression)), common_query_filter_(common_query_filter), output_names_(std::move(output_names)),
output_types_(std::move(output_types)), knn_table_index_(knn_table_index) {}

Expand All @@ -66,8 +67,6 @@ public:

[[nodiscard]] String TableAlias() const;

BlockIndex *GetBlockIndex() const;

Vector<SizeT> &ColumnIDs() const;

SizeT BlockEntryCount() const;
Expand All @@ -81,8 +80,6 @@ public:
}

public:
SharedPtr<BaseTableRef> base_table_ref_{};

SharedPtr<KnnExpression> knn_expression_{};

SharedPtr<CommonQueryFilter> common_query_filter_;
Expand Down
Loading

0 comments on commit 5a52545

Please sign in to comment.