-
Notifications
You must be signed in to change notification settings - Fork 177
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
### What problem does this PR solve? Add: simple linscan sparse index. Add: unit test and benchmark for linscan index. Issue link:#1239 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Test cases
- Loading branch information
1 parent
005bc86
commit a6d5c6a
Showing
6 changed files
with
651 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,247 @@ | ||
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <algorithm> | ||
#include <iostream> | ||
#include <stdexcept> | ||
|
||
import stl; | ||
import file_system; | ||
import local_file_system; | ||
import file_system_type; | ||
import compilation_config; | ||
import third_party; | ||
import profiler; | ||
|
||
import linscan_alg; | ||
import sparse_iter; | ||
|
||
using namespace infinity; | ||
|
||
// const f32 error_bound = 1e-6; | ||
const int log_interval = 10000; | ||
|
||
/// Loads a CSR-encoded sparse matrix from `data_path`.
///
/// The file is read sequentially as: three i64 header fields (nrow, ncol, nnz), then
/// nrow + 1 i64 row offsets (indptr), then nnz i32 column indices, then nnz f32 values.
///
/// @param data_path Location of the CSR file on the local file system.
/// @return The decoded matrix; ownership of all three buffers is moved into it.
/// @throws std::runtime_error if the file is missing, the header contains a negative
///         count, the row offsets are inconsistent, or a column index is out of range.
SparseMatrix DecodeSparseDataset(const Path &data_path) {
    LocalFileSystem fs;
    if (!fs.Exists(data_path)) {
        throw std::runtime_error(fmt::format("Data path: {} does not exist.", data_path.string()));
    }
    UniquePtr<FileHandler> file_handler = fs.OpenFile(data_path.string(), FileFlags::READ_FLAG, FileLockType::kNoLock);

    i64 nrow = 0;
    i64 ncol = 0;
    i64 nnz = 0;
    file_handler->Read(&nrow, sizeof(nrow));
    file_handler->Read(&ncol, sizeof(ncol));
    file_handler->Read(&nnz, sizeof(nnz));
    // The counts come straight from the file and drive allocation sizes and array
    // indexing below, so reject negative values before using them.
    if (nrow < 0 || ncol < 0 || nnz < 0) {
        throw std::runtime_error("Invalid header.");
    }

    auto indptr = MakeUnique<i64[]>(nrow + 1);
    file_handler->Read(indptr.get(), sizeof(i64) * (nrow + 1));
    // Row offsets must start non-negative, never decrease, and end exactly at nnz;
    // together these keep every row slice inside the indices/data buffers.
    const bool indptr_ok = indptr[0] >= 0 && indptr[nrow] == nnz && std::is_sorted(indptr.get(), indptr.get() + nrow + 1);
    if (!indptr_ok) {
        throw std::runtime_error("Invalid indptr.");
    }

    auto indices = MakeUnique<i32[]>(nnz);
    file_handler->Read(indices.get(), sizeof(i32) * nnz);
    // Every column index must lie in [0, ncol).
    {
        bool check = std::all_of(indices.get(), indices.get() + nnz, [ncol](i32 ele) { return ele >= 0 && ele < ncol; });
        if (!check) {
            throw std::runtime_error("Invalid indices.");
        }
    }

    auto data = MakeUnique<f32[]>(nnz);
    file_handler->Read(data.get(), sizeof(f32) * nnz);
    return {std::move(data), std::move(indices), std::move(indptr), nrow, ncol, nnz};
}
|
||
/// Loads the ground-truth neighbours for a query set.
///
/// The file is read as: u32 query count, u32 top-k, then query_n * top_k u32 document
/// ids, then query_n * top_k f32 scores.
///
/// @param groundtruth_path Location of the ground-truth file.
/// @param top_k            Number of neighbours stored per query; must match the file.
/// @param query_n          Number of queries; must match the file.
/// @return A pair of (document ids, scores), each holding query_n * top_k entries.
/// @throws std::runtime_error if the file is missing, its size does not match the
///         expected layout, or its header disagrees with `top_k` / `query_n`.
Pair<UniquePtr<u32[]>, UniquePtr<f32[]>> DecodeGroundtruth(const Path &groundtruth_path, u32 top_k, u32 query_n) {
    LocalFileSystem fs;
    if (!fs.Exists(groundtruth_path)) {
        throw std::runtime_error(fmt::format("Groundtruth path: {} does not exist.", groundtruth_path.string()));
    }
    UniquePtr<FileHandler> file_handler = fs.OpenFile(groundtruth_path.string(), FileFlags::READ_FLAG, FileLockType::kNoLock);

    // Widen before multiplying: the product of two u32 values can wrap in 32 bits,
    // which would make the size check and the allocations below too small.
    const SizeT entry_n = static_cast<SizeT>(query_n) * top_k;

    SizeT file_size = fs.GetFileSize(*file_handler);
    if (file_size != sizeof(u32) * 2 + (sizeof(u32) + sizeof(float)) * entry_n) {
        throw std::runtime_error("Invalid groundtruth file format");
    }
    {
        u32 ans_n = 0;
        file_handler->Read(&ans_n, sizeof(ans_n));
        u32 top_k1 = 0;
        file_handler->Read(&top_k1, sizeof(top_k1));
        if (ans_n != query_n || top_k1 != top_k) {
            throw std::runtime_error("Invalid groundtruth file format");
        }
    }
    auto indices = MakeUnique<u32[]>(entry_n);
    file_handler->Read(indices.get(), sizeof(u32) * entry_n);
    auto scores = MakeUnique<f32[]>(entry_n);
    file_handler->Read(scores.get(), sizeof(f32) * entry_n);
    return {std::move(indices), std::move(scores)};
}
|
||
/// Decodes the dataset at `data_path` and inserts each of its rows into `index`,
/// using the row id reported by the iterator as the document id.
/// Progress is printed to stdout every `log_interval` documents.
void ImportData(LinScan &index, const Path &data_path) {
    SparseMatrix dataset = DecodeSparseDataset(data_path);
    SparseMatrixIter row_iter(dataset);
    while (row_iter.HasNext()) {
        SparseVecRef row = row_iter.val();
        const u32 doc_id = row_iter.row_id();
        index.Insert(row, doc_id);

        const bool should_log = log_interval != 0 && doc_id % log_interval == 0;
        if (should_log) {
            std::cout << fmt::format("Inserting doc {}\n", doc_id);
        }
        row_iter.Next();
    }
}
|
||
/// Decodes the query set at `query_path` and runs each row as a top-k query against `index`.
/// Progress is printed to stdout every `log_interval` queries.
///
/// @return One (document ids, scores) pair per query, in query order.
Vector<Pair<Vector<u32>, Vector<f32>>> QueryData(const LinScan &index, u32 top_k, const Path &query_path) {
    Vector<Pair<Vector<u32>, Vector<f32>>> answers;
    SparseMatrix queries = DecodeSparseDataset(query_path);
    SparseMatrixIter query_iter(queries);
    while (query_iter.HasNext()) {
        SparseVecRef query = query_iter.val();
        auto hit = index.Query(query, top_k);
        answers.emplace_back(std::move(hit.first), std::move(hit.second));

        const bool should_log = log_interval != 0 && query_iter.row_id() % log_interval == 0;
        if (should_log) {
            std::cout << fmt::format("Querying doc {}\n", query_iter.row_id());
        }
        query_iter.Next();
    }
    return answers;
}
|
||
/// Debug helper: prints the results of one query followed by its ground truth.
///
/// The first output line lists "id score" pairs returned by the index, the second
/// lists the ground-truth pairs.
///
/// @param query_id   Identifier printed in the heading.
/// @param gt_indices Ground-truth document ids; must hold `gt_size` entries.
/// @param gt_scores  Ground-truth scores; must hold `gt_size` entries.
/// @param gt_size    Number of ground-truth entries (and upper bound on printed results).
/// @param indices    Document ids returned by the index.
/// @param scores     Scores returned by the index, parallel to `indices`.
void PrintQuery(u32 query_id, const u32 *gt_indices, const f32 *gt_scores, u32 gt_size, const Vector<u32> &indices, const Vector<f32> &scores) {
    std::cout << fmt::format("Query {}\n", query_id);

    // The index may return fewer than gt_size hits (e.g. when it holds fewer rows than
    // top-k), so never read past the end of the result vectors.
    const SizeT result_n = std::min<SizeT>({static_cast<SizeT>(gt_size), indices.size(), scores.size()});
    for (SizeT i = 0; i < result_n; ++i) {
        std::cout << fmt::format("{} {}, ", indices[i], scores[i]);
    }
    std::cout << "\n";
    for (u32 i = 0; i < gt_size; ++i) {
        std::cout << fmt::format("{} {}, ", gt_indices[i], gt_scores[i]);
    }
    std::cout << "\n";
}
|
||
/// Computes recall@top_k of `results` against the ground truth stored at `groundtruth_path`.
///
/// For every query, counts how many of its `top_k` ground-truth ids appear among the
/// returned ids, then divides the total by query_n * top_k.
///
/// @param groundtruth_path Ground-truth file; must describe results.size() queries with `top_k` each.
/// @param results          Per-query (document ids, scores) as produced by the index.
/// @param top_k            Number of ground-truth neighbours per query.
/// @return Recall in [0, 1]; 0 when there are no queries or `top_k` is 0.
/// @throws std::runtime_error propagated from DecodeGroundtruth on a missing or malformed file.
f32 CheckGroundtruth(const Path &groundtruth_path, const Vector<Pair<Vector<u32>, Vector<f32>>> &results, u32 top_k) {
    u32 query_n = results.size();
    auto [gt_indices_list, gt_score_list] = DecodeGroundtruth(groundtruth_path, top_k, query_n);

    // Widened so the product cannot wrap in 32 bits.
    const SizeT expected_n = static_cast<SizeT>(query_n) * top_k;
    if (expected_n == 0) {
        // Nothing to compare against; avoid a 0/0 division that would yield NaN.
        return 0.0f;
    }

    SizeT recall_n = 0;
    for (u32 i = 0; i < query_n; ++i) {
        const auto &[indices, scores] = results[i];
        const u32 *gt_indices = gt_indices_list.get() + static_cast<SizeT>(i) * top_k;

        // const f32 *gt_score = gt_score_list.get() + i * top_k;
        // PrintQuery(i, gt_indices, gt_score, top_k, indices, scores);
        HashSet<u32> indices_set(indices.begin(), indices.end());
        for (u32 j = 0; j < top_k; ++j) {
            if (indices_set.contains(gt_indices[j])) {
                ++recall_n;
            }
        }
    }
    return static_cast<f32>(recall_n) / static_cast<f32>(expected_n);
}
|
||
int main(int argc, char *argv[]) { | ||
CLI::App app{"sparse_benchmark"}; | ||
|
||
// enum class ModeType : i8 { | ||
// kImport, | ||
// kQuery, | ||
// }; | ||
// Map<String, ModeType> mode_type_map = { | ||
// {"import", ModeType::kImport}, | ||
// {"query", ModeType::kQuery}, | ||
// }; | ||
// ModeType mode_type = ModeType::kImport; | ||
// app.add_option("--mode", mode_type, "Mode type")->required()->transform(CLI::CheckedTransformer(mode_type_map, CLI::ignore_case)); | ||
|
||
enum class DataSetType : u8 { | ||
kSmall, | ||
k1M, | ||
kFull, | ||
}; | ||
Map<String, DataSetType> dataset_type_map = { | ||
{"small", DataSetType::kSmall}, | ||
{"1M", DataSetType::k1M}, | ||
{"full", DataSetType::kFull}, | ||
}; | ||
DataSetType dataset_type = DataSetType::kSmall; | ||
app.add_option("--dataset", dataset_type, "Dataset type")->required()->transform(CLI::CheckedTransformer(dataset_type_map, CLI::ignore_case)); | ||
|
||
try { | ||
app.parse(argc, argv); | ||
} catch (const CLI::ParseError &e) { | ||
return app.exit(e); | ||
} | ||
|
||
Path dataset_dir = Path(test_data_path()) / "benchmark" / "splade"; | ||
Path query_path = dataset_dir / "queries.dev.csr"; | ||
Path data_path = dataset_dir; | ||
Path groundtruth_path = dataset_dir; | ||
switch (dataset_type) { | ||
case DataSetType::kSmall: { | ||
data_path /= "base_small.csr"; | ||
groundtruth_path /= "base_small.dev.gt"; | ||
break; | ||
} | ||
case DataSetType::k1M: { | ||
data_path /= "base_1M.csr"; | ||
groundtruth_path /= "base_1M.dev.gt"; | ||
break; | ||
} | ||
case DataSetType::kFull: { | ||
data_path /= "base_full.csr"; | ||
groundtruth_path /= "base_full.dev.gt"; | ||
break; | ||
} | ||
default: { | ||
throw std::runtime_error(fmt::format("Unsupported dataset type: {}.", static_cast<u8>(dataset_type))); | ||
} | ||
}; | ||
u32 top_k = 10; | ||
|
||
// switch (mode_type) { | ||
// case ModeType::kImport: { | ||
// ImportData(data_path); | ||
// break; | ||
// } | ||
// case ModeType::kQuery: { | ||
// throw std::runtime_error("Not implemented."); | ||
// return 1; | ||
// } | ||
// default: { | ||
// throw std::runtime_error(fmt::format("Unsupported mode type: {}.", static_cast<u8>(mode_type))); | ||
// } | ||
// } | ||
BaseProfiler profiler; | ||
|
||
LinScan index; | ||
|
||
profiler.Begin(); | ||
ImportData(index, data_path); | ||
profiler.End(); | ||
std::cout << fmt::format("Import data time: {}\n", profiler.ElapsedToString(1000)); | ||
|
||
profiler.Begin(); | ||
auto query_result = QueryData(index, top_k, query_path); | ||
profiler.End(); | ||
std::cout << fmt::format("Query data time: {}\n", profiler.ElapsedToString(1000)); | ||
|
||
f32 recall = CheckGroundtruth(groundtruth_path, query_result, top_k); | ||
std::cout << fmt::format("Recall: {}\n", recall); | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
module; | ||
|
||
#include <algorithm> | ||
#include <vector> | ||
|
||
module linscan_alg; | ||
|
||
import stl; | ||
|
||
namespace infinity { | ||
|
||
void LinScan::Insert(const SparseVecRef &vec, u32 doc_id) { | ||
for (i32 i = 0; i < vec.nnz_; ++i) { | ||
u32 indice = vec.indices_[i]; | ||
f32 val = vec.data_[i]; | ||
Posting posting{doc_id, val}; | ||
inverted_idx_[indice].push_back(posting); | ||
} | ||
++row_num_; | ||
} | ||
|
||
/// Scores every indexed document against `query` and returns the best `top_k`.
///
/// A document's score is the sum, over the query's terms, of query value times the
/// document's posting value for that term. Terms absent from the index are skipped.
///
/// @param query Sparse query vector; its first `nnz_` indices/data entries are read.
/// @param top_k Maximum number of hits to return.
/// @return (document ids, scores) ordered by descending score, holding
///         min(top_k, row_num_) entries each.
Pair<Vector<u32>, Vector<f32>> LinScan::Query(const SparseVecRef &query, u32 top_k) const {
    // One accumulator per document id. The accumulation below relies on these starting
    // at zero (presumably MakeUnique value-initializes like std::make_unique - confirm).
    auto scores = MakeUnique<f32[]>(row_num_);
    for (i32 term = 0; term < query.nnz_; ++term) {
        const u32 term_id = query.indices_[term];
        const f32 weight = query.data_[term];

        auto found = inverted_idx_.find(term_id);
        if (found != inverted_idx_.end()) {
            for (const auto &entry : found->second) {
                scores[entry.doc_id_] += weight * entry.val_;
            }
        }
    }

    const u32 keep_n = std::min(top_k, row_num_);
    const auto by_score_desc = [&scores](u32 lhs, u32 rhs) { return scores[lhs] > scores[rhs]; };

    // Rank all document ids, but only order the leading keep_n of them.
    Vector<u32> doc_ids(row_num_);
    std::iota(doc_ids.begin(), doc_ids.end(), 0);
    std::partial_sort(doc_ids.begin(), doc_ids.begin() + keep_n, doc_ids.end(), by_score_desc);
    doc_ids.resize(keep_n);

    Vector<f32> top_scores(keep_n);
    for (u32 rank = 0; rank < keep_n; ++rank) {
        top_scores[rank] = scores[doc_ids[rank]];
    }
    return {std::move(doc_ids), std::move(top_scores)};
}
|
||
} // namespace infinity |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
module; | ||
|
||
export module linscan_alg; | ||
|
||
import stl; | ||
import sparse_iter; | ||
|
||
namespace infinity { | ||
|
||
struct Posting { | ||
u32 doc_id_; | ||
f32 val_; | ||
}; | ||
|
||
export class LinScan { | ||
public: | ||
void Insert(const SparseVecRef &vec, u32 doc_id); | ||
|
||
Pair<Vector<u32>, Vector<f32>> Query(const SparseVecRef &query, u32 top_k) const; | ||
|
||
u32 row_num() const { return row_num_; } | ||
|
||
private: | ||
HashMap<u32, Vector<Posting>> inverted_idx_; | ||
u32 row_num_{}; | ||
}; | ||
|
||
} // namespace infinity |
Oops, something went wrong.