Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: search json support + client tests #1210

Merged
merged 3 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/core/search/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,30 @@

namespace dfly::search {

// Interface for accessing hashset values with different data structures underneath.
struct HSetAccessor {
// Interface for accessing document values with different data structures underneath.
struct DocumentAccessor {
// Callback that's supplied with field values.
using FieldConsumer = std::function<bool(std::string_view)>;

virtual bool Check(FieldConsumer f, std::string_view active_field) const = 0;
};

// Wrapper around hashset accessor and optional active field.
// Wrapper around document accessor and optional active field.
struct SearchInput {
SearchInput(const HSetAccessor* hset, std::string_view active_field = {})
: hset_{hset}, active_field_{active_field} {
SearchInput(const DocumentAccessor* doc, std::string_view active_field = {})
: doc_{doc}, active_field_{active_field} {
}

SearchInput(const SearchInput& base, std::string_view active_field)
: hset_{base.hset_}, active_field_{active_field} {
: doc_{base.doc_}, active_field_{active_field} {
}

bool Check(HSetAccessor::FieldConsumer f) {
return hset_->Check(move(f), active_field_);
bool Check(DocumentAccessor::FieldConsumer f) {
return doc_->Check(move(f), active_field_);
}

private:
const HSetAccessor* hset_;
const DocumentAccessor* doc_;
std::string_view active_field_;
};

Expand Down
18 changes: 9 additions & 9 deletions src/core/search/search_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ class SearchParserTest : public ::testing::Test {
QueryDriver query_driver_;
};

class MockedHSetAccessor : public HSetAccessor {
class MockedDocument : public DocumentAccessor {
public:
using Map = std::unordered_map<std::string, std::string>;

MockedHSetAccessor() = default;
MockedHSetAccessor(std::string test_field) : hset_{{"field", test_field}} {
MockedDocument() = default;
MockedDocument(std::string test_field) : hset_{{"field", test_field}} {
}

bool Check(HSetAccessor::FieldConsumer f, string_view active_field) const override {
bool Check(DocumentAccessor::FieldConsumer f, string_view active_field) const override {
if (!active_field.empty()) {
auto it = hset_.find(string{active_field});
return f(it != hset_.end() ? it->second : "");
Expand Down Expand Up @@ -108,15 +108,15 @@ class MockedHSetAccessor : public HSetAccessor {
#define CHECK_ALL(...) \
{ \
for (auto str : {__VA_ARGS__}) { \
MockedHSetAccessor hset{str}; \
MockedDocument hset{str}; \
EXPECT_TRUE(Check(SearchInput{&hset})) << str << " failed on " << DebugExpr(); \
} \
}

#define CHECK_NONE(...) \
{ \
for (auto str : {__VA_ARGS__}) { \
MockedHSetAccessor hset{str}; \
MockedDocument hset{str}; \
EXPECT_FALSE(Check(SearchInput{&hset})) << str << " failed on " << DebugExpr(); \
} \
}
Expand Down Expand Up @@ -238,7 +238,7 @@ TEST_F(SearchParserTest, CheckParenthesisPriority) {
TEST_F(SearchParserTest, MatchField) {
ParseExpr("@f1:foo @f2:bar @f3:baz");

MockedHSetAccessor hset{};
MockedDocument hset{};
SearchInput input{&hset};

hset.Set({{"f1", "foo"}, {"f2", "bar"}, {"f3", "baz"}});
Expand All @@ -260,7 +260,7 @@ TEST_F(SearchParserTest, MatchField) {
TEST_F(SearchParserTest, MatchRange) {
ParseExpr("@f1:[1 10] @f2:[50 100]");

MockedHSetAccessor hset{};
MockedDocument hset{};
SearchInput input{&hset};

hset.Set({{"f1", "5"}, {"f2", "50"}});
Expand All @@ -282,7 +282,7 @@ TEST_F(SearchParserTest, MatchRange) {
TEST_F(SearchParserTest, CheckExprInField) {
ParseExpr("@f1:(a|b) @f2:(c d) @f3:-e");

MockedHSetAccessor hset{};
MockedDocument hset{};
SearchInput input{&hset};

hset.Set({{"f1", "a"}, {"f2", "c and d"}, {"f3", "right"}});
Expand Down
115 changes: 93 additions & 22 deletions src/server/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

#include "server/search_family.h"

#include <jsoncons/json.hpp>
#include <variant>
#include <vector>

#include "base/logging.h"
#include "core/json_object.h"
#include "core/search/search.h"
#include "facade/error.h"
#include "facade/reply_builder.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
Expand Down Expand Up @@ -36,16 +39,18 @@ using DocumentData = absl::flat_hash_map<std::string, std::string>;
using SerializedDocument = pair<std::string /*key*/, DocumentData>;
using Query = search::AstExpr;

struct BaseAccessor : public search::HSetAccessor {
using FieldConsumer = search::HSetAccessor::FieldConsumer;
// Base class for document accessors
struct BaseAccessor : public search::DocumentAccessor {
using FieldConsumer = search::DocumentAccessor::FieldConsumer;

virtual DocumentData Serialize() const = 0;
};

// Accessor for hashes stored with listpack
struct ListPackAccessor : public BaseAccessor {
using LpPtr = uint8_t*;

ListPackAccessor(LpPtr ptr) : lp_{ptr} {
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
}

bool Check(FieldConsumer f, string_view active_field) const override {
Expand Down Expand Up @@ -79,7 +84,7 @@ struct ListPackAccessor : public BaseAccessor {

while (fptr) {
string_view k = container_utils::LpGetView(fptr, intbuf[0].data());
fptr = lpNext(lp_, fptr); // skip key
fptr = lpNext(lp_, fptr);
string_view v = container_utils::LpGetView(fptr, intbuf[1].data());
fptr = lpNext(lp_, fptr);

Expand All @@ -93,8 +98,9 @@ struct ListPackAccessor : public BaseAccessor {
LpPtr lp_;
};

// Accessor for hashes stored with StringMap
struct StringMapAccessor : public BaseAccessor {
StringMapAccessor(StringMap* hset) : hset_{hset} {
explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
}

bool Check(FieldConsumer f, string_view active_field) const override {
Expand All @@ -121,7 +127,42 @@ struct StringMapAccessor : public BaseAccessor {
StringMap* hset_;
};

// Accessor for json values
struct JsonAccessor : public BaseAccessor {
explicit JsonAccessor(JsonType* json) : json_{json} {
}

bool Check(FieldConsumer f, string_view active_field) const override {
if (!active_field.empty()) {
return f(json_->get_value_or<string>(active_field, string{}));
}
for (const auto& member : json_->object_range()) {
if (f(member.value().as_string()))
return true;
}
return false;
}

DocumentData Serialize() const override {
DocumentData out{};
for (const auto& member : json_->object_range()) {
out[member.key()] = member.value().as_string();
}
return out;
}

private:
JsonType* json_;
};

unique_ptr<BaseAccessor> GetAccessor(const OpArgs& op_args, const PrimeValue& pv) {
DCHECK(pv.ObjType() == OBJ_HASH || pv.ObjType() == OBJ_JSON);

if (pv.ObjType() == OBJ_JSON) {
DCHECK(pv.GetJson());
return make_unique<JsonAccessor>(pv.GetJson());
}

if (pv.Encoding() == kEncodingListPack) {
auto ptr = reinterpret_cast<ListPackAccessor::LpPtr>(pv.RObjPtr());
return make_unique<ListPackAccessor>(ptr);
Expand All @@ -133,7 +174,7 @@ unique_ptr<BaseAccessor> GetAccessor(const OpArgs& op_args, const PrimeValue& pv

// Perform brute force search for all hashes in shard with specific prefix
// that match the query
void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
void OpSearch(const OpArgs& op_args, const SearchFamily::IndexData& index, const Query& query,
vector<SerializedDocument>* shard_out) {
auto& db_slice = op_args.shard->db_slice();
DCHECK(db_slice.IsDbValid(op_args.db_cntx.db_index));
Expand All @@ -143,12 +184,12 @@ void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
auto cb = [&](PrimeTable::iterator it) {
// Check entry is hash
const PrimeValue& pv = it->second;
if (pv.ObjType() != OBJ_HASH)
if (pv.ObjType() != index.GetObjCode())
return;

// Check key starts with prefix
string_view key = it->first.GetSlice(&scratch);
if (key.rfind(prefix, 0) != 0)
if (key.rfind(index.prefix, 0) != 0)
return;

// Check entry matches filter
Expand All @@ -167,35 +208,61 @@ void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,

void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
string_view idx = ArgS(args, 0);
string prefix;

if (args.size() > 1 && ArgS(args, 1) == "ON") {
if (ArgS(args, 2) != "HASH" || ArgS(args, 3) != "PREFIX" || ArgS(args, 4) != "1") {
(*cntx)->SendError("Only simplest config supported");
return;
IndexData index{};

for (size_t i = 1; i < args.size(); i++) {
ToUpper(&args[i]);

// [ON HASH | JSON]
if (ArgS(args, i) == "ON") {
if (++i >= args.size())
return (*cntx)->SendError(kSyntaxErr);

ToUpper(&args[i]);
string_view type = ArgS(args, i);
if (type == "HASH")
index.type = IndexData::HASH;
else if (type == "JSON")
index.type = IndexData::JSON;
else
return (*cntx)->SendError("Invalid rule type: " + string{type});
continue;
}

// [PREFIX count prefix [prefix ...]]
if (ArgS(args, i) == "PREFIX") {
if (i + 2 >= args.size())
return (*cntx)->SendError(kSyntaxErr);

if (ArgS(args, ++i) != "1")
return (*cntx)->SendError("Multiple prefixes are not supported");

index.prefix = ArgS(args, ++i);
continue;
}
prefix = ArgS(args, 5);
}

{
lock_guard lk{indices_mu_};
indices_[idx] = prefix;
indices_[idx] = move(index);
}
(*cntx)->SendOk();
}

void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
string_view index = ArgS(args, 0);
string_view index_name = ArgS(args, 0);
string_view query_str = ArgS(args, 1);

string prefix;
IndexData index;
{
lock_guard lk{indices_mu_};
auto it = indices_.find(index);
auto it = indices_.find(index_name);
if (it == indices_.end()) {
(*cntx)->SendError(string{index} + ": no such index");
(*cntx)->SendError(string{index_name} + ": no such index");
return;
}
prefix = it->second;
index = it->second;
}

Query query = search::ParseQuery(query_str);
Expand All @@ -206,7 +273,7 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {

vector<vector<SerializedDocument>> docs(shard_set->size());
cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) {
OpSearch(t->GetOpArgs(shard), prefix, query, &docs[shard->shard_id()]);
OpSearch(t->GetOpArgs(shard), index, query, &docs[shard->shard_id()]);
return OpStatus::OK;
});

Expand All @@ -228,6 +295,10 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
}
}

uint8_t SearchFamily::IndexData::GetObjCode() const {
return type == JSON ? OBJ_JSON : OBJ_HASH;
}

#define HFUNC(x) SetHandler(&SearchFamily::x)

void SearchFamily::Register(CommandRegistry* registry) {
Expand All @@ -238,6 +309,6 @@ void SearchFamily::Register(CommandRegistry* registry) {
}

Mutex SearchFamily::indices_mu_{};
absl::flat_hash_map<std::string, std::string> SearchFamily::indices_{};
absl::flat_hash_map<std::string, SearchFamily::IndexData> SearchFamily::indices_{};

} // namespace dfly
12 changes: 11 additions & 1 deletion src/server/search_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,19 @@ class SearchFamily {
public:
static void Register(CommandRegistry* registry);

struct IndexData {
enum DataType { HASH, JSON };

// Get numeric OBJ_ code
uint8_t GetObjCode() const;
Comment on lines +25 to +29
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just didn't want to include redis headers here, so that's why I didn't set enum values

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that these types are not even enums in our redis headers anyway :)


std::string prefix{};
DataType type{HASH};
};

private:
static Mutex indices_mu_;
static absl::flat_hash_map<std::string, std::string> indices_;
static absl::flat_hash_map<std::string, IndexData> indices_;
};

} // namespace dfly
29 changes: 29 additions & 0 deletions src/server/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,33 @@ TEST_F(SearchFamilyTest, NoPrefix) {
EXPECT_THAT(Run({"ft.search", "i1", "one | three"}), ArrLen(1 + 2 * 2));
}

TEST_F(SearchFamilyTest, Json) {
EXPECT_EQ(Run({"ft.create", "i1", "on", "json"}), "OK");
Run({"json.set", "k1", ".", R"({"a": "small test", "b": "some details"})"});
Run({"json.set", "k2", ".", R"({"a": "another test", "b": "more details"})"});
Run({"json.set", "k3", ".", R"({"a": "last test", "b": "secret details"})"});

VLOG(0) << Run({"json.get", "k2", "$"});

{
auto resp = Run({"ft.search", "i1", "more"});
EXPECT_THAT(resp, ArrLen(1 + 2));

auto doc = resp.GetVec();
EXPECT_THAT(doc[0], IntArg(1));
EXPECT_EQ(doc[1], "k2");
EXPECT_THAT(doc[2], ArrLen(4));
}

EXPECT_THAT(Run({"ft.search", "i1", "some|more"}), ArrLen(1 + 2 * 2));
EXPECT_THAT(Run({"ft.search", "i1", "some|more|secret"}), ArrLen(1 + 3 * 2));

EXPECT_THAT(Run({"ft.search", "i1", "@a:last @b:details"}), ArrLen(1 + 1 * 2));
EXPECT_THAT(Run({"ft.search", "i1", "@a:(another|small)"}), ArrLen(1 + 2 * 2));
EXPECT_THAT(Run({"ft.search", "i1", "@a:(another|small|secret)"}), ArrLen(1 + 2 * 2));

EXPECT_THAT(Run({"ft.search", "i1", "none"}), kNoResults);
EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults);
}

} // namespace dfly