Skip to content

Commit

Permalink
feat: search json support + client tests
Browse files Browse the repository at this point in the history
Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
  • Loading branch information
dranikpg committed May 13, 2023
1 parent b053741 commit 7f04d63
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 32 deletions.
18 changes: 9 additions & 9 deletions src/core/search/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,30 @@

namespace dfly::search {

// Interface for accessing hashset values with different data structures underneath.
struct HSetAccessor {
// Interface for accessing document values with different data structures underneath.
struct DocumentAccessor {
// Callback that's supplied with field values.
using FieldConsumer = std::function<bool(std::string_view)>;

virtual bool Check(FieldConsumer f, std::string_view active_field) const = 0;
};

// Wrapper around hashset accessor and optional active field.
// Wrapper around document accessor and optional active field.
struct SearchInput {
SearchInput(const HSetAccessor* hset, std::string_view active_field = {})
: hset_{hset}, active_field_{active_field} {
SearchInput(const DocumentAccessor* doc, std::string_view active_field = {})
: doc_{doc}, active_field_{active_field} {
}

SearchInput(const SearchInput& base, std::string_view active_field)
: hset_{base.hset_}, active_field_{active_field} {
: doc_{base.doc_}, active_field_{active_field} {
}

bool Check(HSetAccessor::FieldConsumer f) {
return hset_->Check(move(f), active_field_);
bool Check(DocumentAccessor::FieldConsumer f) {
return doc_->Check(move(f), active_field_);
}

private:
const HSetAccessor* hset_;
const DocumentAccessor* doc_;
std::string_view active_field_;
};

Expand Down
4 changes: 2 additions & 2 deletions src/core/search/search_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ class SearchParserTest : public ::testing::Test {
QueryDriver query_driver_;
};

class MockedHSetAccessor : public HSetAccessor {
class MockedHSetAccessor : public DocumentAccessor {
public:
using Map = std::unordered_map<std::string, std::string>;

MockedHSetAccessor() = default;
MockedHSetAccessor(std::string test_field) : hset_{{"field", test_field}} {
}

bool Check(HSetAccessor::FieldConsumer f, string_view active_field) const override {
bool Check(DocumentAccessor::FieldConsumer f, string_view active_field) const override {
if (!active_field.empty()) {
auto it = hset_.find(string{active_field});
return f(it != hset_.end() ? it->second : "");
Expand Down
112 changes: 92 additions & 20 deletions src/server/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@

#include "server/search_family.h"

#include <jsoncons/json.hpp>
#include <jsoncons_ext/jsonpatch/jsonpatch.hpp>
#include <jsoncons_ext/jsonpath/jsonpath.hpp>
#include <jsoncons_ext/jsonpointer/jsonpointer.hpp>
#include <variant>
#include <vector>

#include "base/logging.h"
#include "core/json_object.h"
#include "core/search/search.h"
#include "facade/error.h"
#include "facade/reply_builder.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
Expand Down Expand Up @@ -36,12 +42,14 @@ using DocumentData = absl::flat_hash_map<std::string, std::string>;
using SerializedDocument = pair<std::string /*key*/, DocumentData>;
using Query = search::AstExpr;

struct BaseAccessor : public search::HSetAccessor {
using FieldConsumer = search::HSetAccessor::FieldConsumer;
// Base class for document accessors
struct BaseAccessor : public search::DocumentAccessor {
using FieldConsumer = search::DocumentAccessor::FieldConsumer;

virtual DocumentData Serialize() const = 0;
};

// Accessor for hashes stored with listpack
struct ListPackAccessor : public BaseAccessor {
using LpPtr = uint8_t*;

Expand Down Expand Up @@ -79,7 +87,7 @@ struct ListPackAccessor : public BaseAccessor {

while (fptr) {
string_view k = container_utils::LpGetView(fptr, intbuf[0].data());
fptr = lpNext(lp_, fptr); // skip key
fptr = lpNext(lp_, fptr);
string_view v = container_utils::LpGetView(fptr, intbuf[1].data());
fptr = lpNext(lp_, fptr);

Expand All @@ -93,6 +101,7 @@ struct ListPackAccessor : public BaseAccessor {
LpPtr lp_;
};

// Accessor for hashes stored with StringMap
struct StringMapAccessor : public BaseAccessor {
StringMapAccessor(StringMap* hset) : hset_{hset} {
}
Expand Down Expand Up @@ -121,7 +130,42 @@ struct StringMapAccessor : public BaseAccessor {
StringMap* hset_;
};

// Accessor for json values
struct JsonAccessor : public BaseAccessor {
  JsonAccessor(JsonType* json) : json_{json} {
  }

  // With an active field, consume only that member's value (empty string if
  // the member is absent); otherwise feed every member value to the consumer
  // until one is accepted.
  bool Check(FieldConsumer f, string_view active_field) const override {
    if (active_field.empty()) {
      for (const auto& member : json_->object_range()) {
        if (f(member.value().as_string()))
          return true;
      }
      return false;
    }
    return f(json_->get_value_or<string>(active_field, string{}));
  }

  // Flatten the json object into a field -> string map.
  // NOTE(review): as_string() presumes member values are string-convertible —
  // confirm behavior for non-string json values.
  DocumentData Serialize() const override {
    DocumentData out{};
    for (const auto& member : json_->object_range())
      out.emplace(member.key(), member.value().as_string());
    return out;
  }

 private:
  JsonType* json_;  // borrowed, owned by the prime value
};

unique_ptr<BaseAccessor> GetAccessor(const OpArgs& op_args, const PrimeValue& pv) {
DCHECK(pv.ObjType() == OBJ_HASH || pv.ObjType() == OBJ_JSON);

if (pv.ObjType() == OBJ_JSON) {
DCHECK(pv.GetJson());
return make_unique<JsonAccessor>(pv.GetJson());
}

if (pv.Encoding() == kEncodingListPack) {
auto ptr = reinterpret_cast<ListPackAccessor::LpPtr>(pv.RObjPtr());
return make_unique<ListPackAccessor>(ptr);
Expand All @@ -133,7 +177,7 @@ unique_ptr<BaseAccessor> GetAccessor(const OpArgs& op_args, const PrimeValue& pv

// Perform brute force search for all hashes in shard with specific prefix
// that match the query
void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
void OpSearch(const OpArgs& op_args, const SearchFamily::IndexData& index, const Query& query,
vector<SerializedDocument>* shard_out) {
auto& db_slice = op_args.shard->db_slice();
DCHECK(db_slice.IsDbValid(op_args.db_cntx.db_index));
Expand All @@ -143,12 +187,12 @@ void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,
auto cb = [&](PrimeTable::iterator it) {
// Check entry is hash
const PrimeValue& pv = it->second;
if (pv.ObjType() != OBJ_HASH)
if (pv.ObjType() != index.GetObjCode())
return;

// Check key starts with prefix
string_view key = it->first.GetSlice(&scratch);
if (key.rfind(prefix, 0) != 0)
if (key.rfind(index.prefix, 0) != 0)
return;

// Check entry matches filter
Expand All @@ -167,35 +211,59 @@ void OpSearch(const OpArgs& op_args, string_view prefix, const Query& query,

void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
string_view idx = ArgS(args, 0);
string prefix;

if (args.size() > 1 && ArgS(args, 1) == "ON") {
if (ArgS(args, 2) != "HASH" || ArgS(args, 3) != "PREFIX" || ArgS(args, 4) != "1") {
(*cntx)->SendError("Only simplest config supported");
return;
IndexData index{};

for (size_t i = 1; i < args.size(); i++) {
// [ON HASH | JSON]
if (ArgS(args, i) == "ON") {
if (++i >= args.size())
return (*cntx)->SendError(kSyntaxErr);

ToUpper(&args[i]);
string_view type = ArgS(args, i);
if (type == "HASH")
index.type = IndexData::HASH;
else if (type == "JSON")
index.type = IndexData::JSON;
else
return (*cntx)->SendError("Invalid rule type: " + string{type});
continue;
}

// [PREFIX count prefix [prefix ...]]
if (ArgS(args, i) == "PREFIX") {
if (i + 2 >= args.size())
return (*cntx)->SendError(kSyntaxErr);

if (ArgS(args, ++i) != "1")
return (*cntx)->SendError("Multiple prefixes are not supported");

index.prefix = ArgS(args, ++i);
continue;
}
prefix = ArgS(args, 5);
}

{
lock_guard lk{indices_mu_};
indices_[idx] = prefix;
indices_[idx] = move(index);
}
(*cntx)->SendOk();
}

void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
string_view index = ArgS(args, 0);
string_view index_name = ArgS(args, 0);
string_view query_str = ArgS(args, 1);

string prefix;
IndexData index;
{
lock_guard lk{indices_mu_};
auto it = indices_.find(index);
auto it = indices_.find(index_name);
if (it == indices_.end()) {
(*cntx)->SendError(string{index} + ": no such index");
(*cntx)->SendError(string{index_name} + ": no such index");
return;
}
prefix = it->second;
index = it->second;
}

Query query = search::ParseQuery(query_str);
Expand All @@ -206,7 +274,7 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {

vector<vector<SerializedDocument>> docs(shard_set->size());
cntx->transaction->ScheduleSingleHop([&](Transaction* t, EngineShard* shard) {
OpSearch(t->GetOpArgs(shard), prefix, query, &docs[shard->shard_id()]);
OpSearch(t->GetOpArgs(shard), index, query, &docs[shard->shard_id()]);
return OpStatus::OK;
});

Expand All @@ -228,6 +296,10 @@ void SearchFamily::FtSearch(CmdArgList args, ConnectionContext* cntx) {
}
}

// Map the index data type to the matching OBJ_ type code used by prime values.
uint8_t SearchFamily::IndexData::GetObjCode() const {
  if (type == JSON)
    return OBJ_JSON;
  return OBJ_HASH;
}

#define HFUNC(x) SetHandler(&SearchFamily::x)

void SearchFamily::Register(CommandRegistry* registry) {
Expand All @@ -238,6 +310,6 @@ void SearchFamily::Register(CommandRegistry* registry) {
}

Mutex SearchFamily::indices_mu_{};
absl::flat_hash_map<std::string, std::string> SearchFamily::indices_{};
absl::flat_hash_map<std::string, SearchFamily::IndexData> SearchFamily::indices_{};

} // namespace dfly
12 changes: 11 additions & 1 deletion src/server/search_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,19 @@ class SearchFamily {
public:
static void Register(CommandRegistry* registry);

// Per-index configuration: which object type the index covers and which
// key prefix its documents must share.
struct IndexData {
  // Underlying storage type of the indexed documents.
  enum DataType { HASH, JSON };

  // Get numeric OBJ_ code corresponding to `type`.
  uint8_t GetObjCode() const;

  // Only keys starting with this prefix belong to the index
  // (empty prefix matches all keys).
  std::string prefix{};
  DataType type{HASH};
};

private:
static Mutex indices_mu_;
static absl::flat_hash_map<std::string, std::string> indices_;
static absl::flat_hash_map<std::string, IndexData> indices_;
};

} // namespace dfly
27 changes: 27 additions & 0 deletions src/server/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,31 @@ TEST_F(SearchFamilyTest, NoPrefix) {
EXPECT_THAT(Run({"ft.search", "i1", "one | three"}), ArrLen(1 + 2 * 2));
}

TEST_F(SearchFamilyTest, Json) {
  // Create a json index with no prefix restriction and add three documents.
  EXPECT_EQ(Run({"ft.create", "i1", "on", "json"}), "OK");
  Run({"json.set", "k1", R"({"a": "small test", "b": "some details"})"});
  Run({"json.set", "k2", R"({"a": "another test", "b": "more details"})"});
  Run({"json.set", "k3", R"({"a": "last test", "b": "secret details"})"});

  // Single match: reply holds a header entry plus key and field array,
  // so 1 + 2 entries in total.
  {
    auto resp = Run({"ft.search", "i1", "more"});
    EXPECT_THAT(resp, ArrLen(1 + 2));

    auto doc = resp.GetVec();
    EXPECT_THAT(doc[0], IntArg(0));
    EXPECT_EQ(doc[1], "k2");
    EXPECT_THAT(doc[2], ArrLen(2));
  }

  // Union terms match across all fields: 1 header + 2 entries per hit.
  EXPECT_THAT(Run({"ft.search", "i1", "some|more"}), ArrLen(1 + 2 * 2));
  EXPECT_THAT(Run({"ft.search", "i1", "some|more|secret"}), ArrLen(1 + 3 * 2));

  // Field-restricted queries: terms only match inside the named field.
  EXPECT_THAT(Run({"ft.search", "i1", "@a:last @b:details"}), ArrLen(1 + 1 * 2));
  EXPECT_THAT(Run({"ft.search", "i1", "@a:(another|small)"}), ArrLen(1 + 2 * 2));
  EXPECT_THAT(Run({"ft.search", "i1", "@a:(another|small|secret)"}), ArrLen(1 + 2 * 2));

  // No matches: reply contains only the header entry.
  EXPECT_THAT(Run({"ft.search", "i1", "none"}), ArrLen(1));
  EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), ArrLen(1));
}

} // namespace dfly
64 changes: 64 additions & 0 deletions tests/dragonfly/search_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Test compatibility with the redis-py client search module.
Search correctness should be ensured with unit tests.
"""
import pytest
from redis import asyncio as aioredis
from .utility import *

from redis.commands.search.query import Query
from redis.commands.search.field import TextField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

TEST_DATA = [
{"title": "First article", "content": "Long description"},
{"title": "Second article", "content": "Small text"},
{"title": "Third piece", "content": "Brief description"},
{"title": "Last piece", "content": "Interesting text"},
]

TEST_DATA_SCHEMA = [TextField("title"), TextField("content")]


async def index_td(async_client: aioredis.Redis, itype: IndexType, prefix=""):
    """Store all TEST_DATA entries under keys prefix+0..N-1.

    Entries are written as hashes for IndexType.HASH and as json documents
    otherwise.
    """
    for doc_id, entry in enumerate(TEST_DATA):
        key = prefix + str(doc_id)
        if itype == IndexType.HASH:
            await async_client.hset(key, mapping=entry)
        else:
            await async_client.json().set(key, "$", entry)


def check_contains_td(docs, td_indices):
    """Return True iff `docs` contains every TEST_DATA entry listed in td_indices.

    Documents are compared by their "title//content" fingerprint, so document
    order and extra unrelated documents in `docs` are ignored.
    """
    # Fingerprint set of the returned documents.
    docset = {f"{doc.title}//{doc.content}" for doc in docs}
    # All requested TEST_DATA entries must appear among the fingerprints.
    return all(
        f"{TEST_DATA[i]['title']}//{TEST_DATA[i]['content']}" in docset
        for i in td_indices
    )


@pytest.mark.parametrize("index_type", [IndexType.HASH, IndexType.JSON])
async def test_basic(async_client, index_type):
    # Build an index over all keys (no prefix) and load TEST_DATA,
    # stored either as hashes or as json documents.
    i1 = async_client.ft("i1")
    await i1.create_index(TEST_DATA_SCHEMA, definition=IndexDefinition(index_type=index_type))
    await index_td(async_client, index_type)

    # "article" appears in the titles of entries 0 and 1.
    res = await i1.search("article")
    assert res.total == 2
    assert check_contains_td(res.docs, [0, 1])

    # "text" appears in the content of entries 1 and 3.
    res = await i1.search("text")
    assert res.total == 2
    assert check_contains_td(res.docs, [1, 3])

    # Both terms must match; only entry 2 has "brief" and "piece".
    res = await i1.search("brief piece")
    assert res.total == 1
    assert check_contains_td(res.docs, [2])

    # Field-scoped query: title matches article|last AND content contains text.
    res = await i1.search("@title:(article|last) @content:text")
    assert res.total == 2
    assert check_contains_td(res.docs, [1, 3])

0 comments on commit 7f04d63

Please sign in to comment.