Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(search): Query parameters #1768

Merged
merged 4 commits into from
Sep 3, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/core/search/base.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <absl/container/flat_hash_map.h>

#include <cstdint>
#include <string>
#include <string_view>
Expand All @@ -15,9 +17,23 @@ using DocId = uint32_t;
using FtVector = std::vector<float>;

// Query params represent named parameters for queries supplied via PARAMS.
// Every value is kept as a raw string keyed by the parameter's name.
struct QueryParams {
  // Number of parameters currently stored.
  size_t Size() const {
    return params.size();
  }

  // Read-only lookup. Returns a view of the value stored under `name`, or an empty
  // view when no such parameter exists. A missing parameter and a parameter whose
  // value is the empty string are therefore indistinguishable to the caller.
  std::string_view operator[](std::string_view name) const {
    if (auto it = params.find(name); it != params.end())
      return it->second;
    return "";
  }

  // Mutable access. Forwards to the underlying map's operator[] so callers can
  // assign with `params["name"] = value`.
  decltype(auto) operator[](std::string_view k) {
    return params[k];
  }

 private:
  absl::flat_hash_map<std::string, std::string> params;
};

// Interface for accessing document values with different data structures underneath.
Expand Down
10 changes: 4 additions & 6 deletions src/core/search/lexer.lex
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,15 @@ term_char [_]|\w

{dq}{str_char}*{dq} return make_StringLit(matched_view(1, 1), loc());

"$"{term_char}+ return Parser::make_PARAM(str(), loc());
"$"{term_char}+ return ParseParam(str(), loc());
"@"{term_char}+ return Parser::make_FIELD(str(), loc());

{term_char}+ return Parser::make_TERM(str(), loc());

<<EOF>> return Parser::make_YYEOF(loc());
%%

Parser::symbol_type
make_INT64 (string_view str, const Parser::location_type& loc)
{
Parser::symbol_type make_INT64 (string_view str, const Parser::location_type& loc) {
int64_t val = 0;
if (!absl::SimpleAtoi(str, &val))
throw Parser::syntax_error (loc, "not an integer or out of range: " + string(str));
Expand All @@ -87,8 +85,8 @@ make_INT64 (string_view str, const Parser::location_type& loc)

// Builds a TERM token from the text of a double-quoted literal.
// `src` is the literal's content as handed over by the lexer rule (the match with one
// character skipped on each side); C-style escape sequences in it are decoded first.
// Throws Parser::syntax_error at `loc` when the escape sequences cannot be decoded.
Parser::symbol_type make_StringLit(string_view src, const Parser::location_type& loc) {
  string res;
  if (!absl::CUnescape(src, &res))
    throw Parser::syntax_error (loc, "bad escaped string: " + string(src));

  return Parser::make_TERM(res, loc);
}
3 changes: 2 additions & 1 deletion src/core/search/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
// Added to cc file
%code {
#include "core/search/query_driver.h"
#include "core/search/vector.h"

// Have to disable because GCC doesn't understand `symbol_type`'s union
// implementation
Expand Down Expand Up @@ -80,7 +81,7 @@ final_query:
filter
{ driver->Set(move($1)); }
| filter ARROW LBRACKET KNN INT64 FIELD TERM RBRACKET
{ driver->Set(AstKnnNode(move($1), $5, $6, driver->GetParams().knn_vec)); }
{ driver->Set(AstKnnNode(move($1), $5, $6, BytesToFtVector($7))); }

filter:
search_expr { $$ = move($1); }
Expand Down
2 changes: 2 additions & 0 deletions src/core/search/query_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ QueryDriver::~QueryDriver() {

// Swaps in a freshly constructed scanner and passes it the driver's current
// parameter pointer, so the new scanner resolves `$name` tokens against the same
// parameters as the driver.
// NOTE(review): params_ is forwarded as-is; confirm it is initialized before the
// first call, since callers may reset the scanner before calling SetParams().
void QueryDriver::ResetScanner() {
  auto fresh_scanner = std::make_unique<Scanner>();
  fresh_scanner->SetParams(params_);
  scanner_ = std::move(fresh_scanner);
}

} // namespace search

} // namespace dfly
19 changes: 10 additions & 9 deletions src/core/search/query_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,14 @@ class QueryDriver {
QueryDriver();
~QueryDriver();

Scanner* scanner() {
return scanner_.get();
}

void SetInput(std::string str) {
cur_str_ = std::move(str);
scanner()->in(cur_str_);
}

void SetParams(QueryParams params) {
params_ = std::move(params);
void SetParams(const QueryParams* params) {
params_ = params;
scanner_->SetParams(params);
}

Parser::symbol_type Lex() {
Expand All @@ -47,15 +44,19 @@ class QueryDriver {
return std::move(expr_);
}

// Returns the parameters installed through SetParams().
// Precondition: a non-null pointer has been installed; it is dereferenced unchecked.
const QueryParams& GetParams() const {
  return *params_;
}

// Returns a non-owning pointer to the scanner currently held by the driver.
Scanner* scanner() {
return scanner_.get();
}

public:
Parser::location_type location;

private:
// Non-owning pointer to the caller's query parameters. Defaults to null so that
// reading it before SetParams() has been called is well-defined.
const QueryParams* params_ = nullptr;
AstExpr expr_;

std::string cur_str_;
Expand Down
28 changes: 27 additions & 1 deletion src/core/search/scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,24 @@
#include "core/search/lexer.h"
#endif

#include <absl/strings/str_cat.h>

#include "base/logging.h"

namespace dfly {
namespace search {

class Scanner : public Lexer {
public:
Scanner() {
Scanner() : params_{nullptr} {
}

Parser::symbol_type Lex();

// Stores the pointer used to resolve `$name` tokens. The parameters are not copied,
// only the pointer is kept.
void SetParams(const QueryParams* params) {
params_ = params;
}

private:
std::string_view matched_view(size_t skip_left = 0, size_t skip_right = 0) const {
std::string_view res(matcher().begin() + skip_left, matcher().size() - skip_left - skip_right);
Expand All @@ -29,6 +37,24 @@ class Scanner : public Lexer {
// Shorthand used by the lexer rules: returns the value of location().
dfly::search::location loc() {
return location();
}

// Resolves a `$name` token against the query parameters and converts the stored
// value into a token.
//
// `name` is the matched text including its leading '$', which is stripped before the
// lookup. A value that parses as a 64-bit integer yields an INT64 token; any other
// value yields a TERM token carrying the raw value. Note that this applies to every
// value, so one that merely looks like an integer is also emitted as INT64.
//
// Throws std::runtime_error when no parameters are attached, or when the parameter
// is missing or empty.
Parser::symbol_type ParseParam(std::string_view name, const Parser::location_type& loc) {
  if (!name.empty())
    name.remove_prefix(1);

  // params_ is null until SetParams() is called; report that the same way as a
  // missing parameter instead of dereferencing a null pointer.
  std::string_view str = params_ ? (*params_)[name] : std::string_view{};
  if (str.empty())
    throw std::runtime_error(absl::StrCat("Query parameter ", name, " not found"));

  int64_t val = 0;
  if (!absl::SimpleAtoi(str, &val))
    return Parser::make_TERM(std::string{str}, loc);

  return Parser::make_INT64(val, loc);
}

private:
const QueryParams* params_;
};

} // namespace search
Expand Down
4 changes: 2 additions & 2 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace dfly::search {

namespace {

AstExpr ParseQuery(std::string_view query, const QueryParams& params) {
AstExpr ParseQuery(std::string_view query, const QueryParams* params) {
QueryDriver driver{};
driver.ResetScanner();
driver.SetParams(params);
Expand Down Expand Up @@ -396,7 +396,7 @@ const vector<DocId>& FieldIndices::GetAllDocs() const {
SearchAlgorithm::SearchAlgorithm() = default;
SearchAlgorithm::~SearchAlgorithm() = default;

bool SearchAlgorithm::Init(string_view query, const QueryParams& params) {
bool SearchAlgorithm::Init(string_view query, const QueryParams* params) {
try {
query_ = make_unique<AstExpr>(ParseQuery(query, params));
return !holds_alternative<monostate>(*query_);
Expand Down
2 changes: 1 addition & 1 deletion src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class SearchAlgorithm {
~SearchAlgorithm();

// Init with query and return true if successful.
bool Init(std::string_view query, const QueryParams& params);
bool Init(std::string_view query, const QueryParams* params);

SearchResult Search(const FieldIndices* index) const;

Expand Down
18 changes: 16 additions & 2 deletions src/core/search/search_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class SearchParserTest : public ::testing::Test {
return Parser(&query_driver_)();
}

// Forwards the parameters to the fixture's query driver.
void SetParams(const QueryParams* params) {
query_driver_.SetParams(params);
}

QueryDriver query_driver_;
};

Expand Down Expand Up @@ -79,8 +83,7 @@ TEST_F(SearchParserTest, Scanner) {
SetInput(R"( "hello\"world" )");
NEXT_EQ(TOK_TERM, string, R"(hello"world)");

SetInput(" $param @field:hello");
NEXT_EQ(TOK_PARAM, string, "$param");
SetInput("@field:hello");
NEXT_EQ(TOK_FIELD, string, "@field");
NEXT_TOK(TOK_COLON);
NEXT_EQ(TOK_TERM, string, "hello");
Expand Down Expand Up @@ -111,4 +114,15 @@ TEST_F(SearchParserTest, Parse) {
EXPECT_EQ(1, Parse(" @foo: "));
}

// Verifies that `$name` tokens are substituted while scanning: a non-numeric value
// comes back as a TERM token and a numeric value as an INT64 token.
TEST_F(SearchParserTest, ParseParams) {
  QueryParams params;
  params["name"] = "alex";
  params["k"] = "10";
  SetParams(&params);

  SetInput("$name $k");
  NEXT_EQ(TOK_TERM, string, "alex");
  NEXT_EQ(TOK_INT64, int64_t, 10);
}

} // namespace dfly::search
44 changes: 29 additions & 15 deletions src/core/search/search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <absl/cleanup/cleanup.h>
#include <absl/container/flat_hash_map.h>
#include <absl/strings/escaping.h>
#include <absl/strings/numbers.h>
#include <absl/strings/str_split.h>
#include <gmock/gmock.h>
Expand Down Expand Up @@ -86,10 +87,6 @@ class SearchParserTest : public ::testing::Test {
EXPECT_EQ(entries_.size(), 0u) << "Missing check";
}

void SetKnnVec(FtVector vec) {
params_.knn_vec = vec;
}

void PrepareSchema(initializer_list<pair<string_view, SchemaField::FieldType>> ilist) {
schema_ = MakeSimpleSchema(ilist);
}
Expand Down Expand Up @@ -334,6 +331,10 @@ TEST_F(SearchParserTest, IntegerTerms) {
EXPECT_TRUE(Check()) << GetError();
}

// Returns the raw in-memory bytes of the vector's floats as a string, in element
// order and in the host's native float representation. The result holds
// sizeof(float) * vec.size() bytes; an empty vector yields an empty string.
std::string FtVectorToBytes(const FtVector& vec) {
  // data() may be null for an empty vector, so do not build a string from it.
  if (vec.empty())
    return {};
  return std::string(reinterpret_cast<const char*>(vec.data()), sizeof(float) * vec.size());
}

TEST_F(SearchParserTest, SimpleKnn) {
auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
FieldIndices indices{schema};
Expand All @@ -346,34 +347,40 @@ TEST_F(SearchParserTest, SimpleKnn) {
}

SearchAlgorithm algo{};
QueryParams params;

// Five closest to 50
{
algo.Init("*=>[KNN 5 @pos VEC]", QueryParams{FtVector{50.0}});
params["vec"] = FtVectorToBytes(FtVector{50.0});
algo.Init("*=>[KNN 5 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(48, 49, 50, 51, 52));
}

// Five closest to 0
{
algo.Init("*=>[KNN 5 @pos VEC]", QueryParams{FtVector{0.0}});
params["vec"] = FtVectorToBytes(FtVector{0.0});
algo.Init("*=>[KNN 5 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
}

// Five closest to 20, all even
{
algo.Init("@even:{yes} =>[KNN 5 @pos VEC]", QueryParams{FtVector{20.0}});
params["vec"] = FtVectorToBytes(FtVector{20.0});
algo.Init("@even:{yes} =>[KNN 5 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(16, 18, 20, 22, 24));
}

// Three closest to 31, all odd
{
algo.Init("@even:{no} =>[KNN 3 @pos VEC]", QueryParams{FtVector{31.0}});
params["vec"] = FtVectorToBytes(FtVector{31.0});
algo.Init("@even:{no} =>[KNN 3 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(29, 31, 33));
}

// Two closest to 70.5
{
algo.Init("* =>[KNN 2 @pos VEC]", QueryParams{FtVector{70.5}});
params["vec"] = FtVectorToBytes(FtVector{70.5});
algo.Init("* =>[KNN 2 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(70, 71));
}
}
Expand All @@ -396,40 +403,47 @@ TEST_F(SearchParserTest, Simple2dKnn) {
}

SearchAlgorithm algo{};
QueryParams params;

// Single center
{
algo.Init("* =>[KNN 1 @pos VEC]", QueryParams{FtVector{0.5, 0.5}});
params["vec"] = FtVectorToBytes(FtVector{0.5, 0.5});
algo.Init("* =>[KNN 1 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(4));
}

// Lower left
{
algo.Init("* =>[KNN 4 @pos VEC]", QueryParams{FtVector{0, 0}});
params["vec"] = FtVectorToBytes(FtVector{0, 0});
algo.Init("* =>[KNN 4 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 3, 4));
}

// Upper right
{
algo.Init("* =>[KNN 4 @pos VEC]", QueryParams{FtVector{1, 1}});
params["vec"] = FtVectorToBytes(FtVector{1, 1});
algo.Init("* =>[KNN 4 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 2, 3, 4));
}

// Request more than there is
{
algo.Init("* => [KNN 10 @pos VEC]", QueryParams{FtVector{0, 0}});
params["vec"] = FtVectorToBytes(FtVector{0, 0});
algo.Init("* => [KNN 10 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
}

// Test correct order: (0.7, 0.15)
{
algo.Init("* => [KNN 10 @pos VEC]", QueryParams{FtVector{0.7, 0.15}});
params["vec"] = FtVectorToBytes(FtVector{0.7, 0.15});
algo.Init("* => [KNN 10 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(1, 4, 0, 2, 3));
}

// Test correct order: (0.8, 0.9)
{
algo.Init("* => [KNN 10 @pos VEC]", QueryParams{FtVector{0.8, 0.9}});
params["vec"] = FtVectorToBytes(FtVector{0.8, 0.9});
algo.Init("* => [KNN 10 @pos $vec]", params);
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(2, 4, 3, 1, 0));
}
}
Expand Down
Loading
Loading