Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support label creation via property values #1762

Merged
merged 32 commits into from Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f360555
create label via variable
DavIvek Feb 26, 2024
1eb39ec
adjust tests
DavIvek Feb 27, 2024
ef2fcd4
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Feb 27, 2024
502d7a6
make creation via variables work
DavIvek Feb 28, 2024
62c1497
change grammar so it works for multiple labels
DavIvek Feb 28, 2024
003d08c
wip
DavIvek Feb 28, 2024
147a0e5
add support for set and remove labels clauses
DavIvek Feb 29, 2024
92e6c32
wip
DavIvek Feb 29, 2024
0c734c9
fix clang-tidy issue
DavIvek Feb 29, 2024
fc84243
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Feb 29, 2024
112f2e1
fix set and remove clauses
DavIvek Mar 1, 2024
a2e21d0
add tests
DavIvek Mar 1, 2024
ec3d487
minor fixes
DavIvek Mar 1, 2024
f58697d
add more tests
DavIvek Mar 1, 2024
7c829fb
fix labels test
DavIvek Mar 1, 2024
b90b6ec
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Mar 1, 2024
fdd797e
fix error message
DavIvek Mar 4, 2024
329178d
fix grammar
DavIvek Mar 5, 2024
23f3ce0
remove unused constructors
DavIvek Mar 6, 2024
04b06e0
implement suggestions
DavIvek Mar 6, 2024
e191ec2
minor grammar change
DavIvek Mar 6, 2024
1c453cc
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Mar 6, 2024
200eaef
implement suggestions
DavIvek Mar 6, 2024
8cacb78
minor fixes and changed tests
DavIvek Mar 6, 2024
d5e0c58
use QueryLabelType in cypher_main_visitor
DavIvek Mar 6, 2024
ce016e0
minor grammar change
DavIvek Mar 7, 2024
3505082
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Mar 7, 2024
6149c2d
add multiple property lookups test case
DavIvek Mar 7, 2024
ed1fb40
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Mar 7, 2024
36a9f53
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Mar 8, 2024
42fe07b
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Mar 8, 2024
3b7b42b
Merge branch 'master' into support-label-manipulation-via-variables
DavIvek Mar 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 10 additions & 19 deletions src/query/frontend/ast/ast.hpp
Expand Up @@ -1248,6 +1248,8 @@ class AllPropertiesLookup : public memgraph::query::Expression {
friend class AstStorage;
};

using QueryLabelType = std::variant<LabelIx, Expression *>;

class LabelsTest : public memgraph::query::Expression {
public:
static const utils::TypeInfo kType;
Expand Down Expand Up @@ -1280,8 +1282,7 @@ class LabelsTest : public memgraph::query::Expression {

protected:
LabelsTest(Expression *expression, const std::vector<LabelIx> &labels) : expression_(expression), labels_(labels) {}
LabelsTest(Expression *expression, const std::vector<std::variant<LabelIx, Expression *>> &labels)
: expression_(expression) {
LabelsTest(Expression *expression, const std::vector<QueryLabelType> &labels) : expression_(expression) {
labels_.reserve(labels.size());
for (const auto &label : labels) {
if (const auto *label_ix = std::get_if<LabelIx>(&label)) {
Expand Down Expand Up @@ -1781,7 +1782,7 @@ class NodeAtom : public memgraph::query::PatternAtom {
return visitor.PostVisit(*this);
}

std::vector<std::variant<memgraph::query::LabelIx, memgraph::query::Expression *>> labels_;
std::vector<QueryLabelType> labels_;
std::variant<std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>,
memgraph::query::ParameterLookup *>
properties_;
Expand Down Expand Up @@ -2643,7 +2644,7 @@ class SetLabels : public memgraph::query::Clause {
}

memgraph::query::Identifier *identifier_{nullptr};
std::vector<std::variant<memgraph::query::LabelIx, memgraph::query::Expression *>> labels_;
std::vector<QueryLabelType> labels_;

SetLabels *Clone(AstStorage *storage) const override {
SetLabels *object = storage->Create<SetLabels>();
Expand All @@ -2660,13 +2661,8 @@ class SetLabels : public memgraph::query::Clause {
}

protected:
SetLabels(Identifier *identifier, const std::vector<std::variant<LabelIx, Expression *>> &labels)
: identifier_(identifier), labels_(labels) {}
SetLabels(Identifier *identifier, const std::vector<LabelIx> &labels) : identifier_(identifier) {
for (const auto &label : labels) {
labels_.emplace_back(label);
}
}
SetLabels(Identifier *identifier, std::vector<QueryLabelType> labels)
: identifier_(identifier), labels_(std::move(labels)) {}

private:
friend class AstStorage;
Expand Down Expand Up @@ -2716,7 +2712,7 @@ class RemoveLabels : public memgraph::query::Clause {
}

memgraph::query::Identifier *identifier_{nullptr};
std::vector<std::variant<memgraph::query::LabelIx, memgraph::query::Expression *>> labels_;
std::vector<QueryLabelType> labels_;

RemoveLabels *Clone(AstStorage *storage) const override {
RemoveLabels *object = storage->Create<RemoveLabels>();
Expand All @@ -2733,13 +2729,8 @@ class RemoveLabels : public memgraph::query::Clause {
}

protected:
RemoveLabels(Identifier *identifier, const std::vector<std::variant<LabelIx, Expression *>> &labels)
: identifier_(identifier), labels_(labels) {}
RemoveLabels(Identifier *identifier, const std::vector<LabelIx> &labels) : identifier_(identifier) {
for (const auto &label : labels) {
labels_.emplace_back(label);
}
}
RemoveLabels(Identifier *identifier, std::vector<QueryLabelType> labels)
: identifier_(identifier), labels_(std::move(labels)) {}

private:
friend class AstStorage;
Expand Down
42 changes: 23 additions & 19 deletions src/query/frontend/ast/cypher_main_visitor.cpp
Expand Up @@ -1912,7 +1912,7 @@ antlrcpp::Any CypherMainVisitor::visitNodePattern(MemgraphCypher::NodePatternCon
anonymous_identifiers.push_back(&node->identifier_);
}
if (ctx->nodeLabels()) {
node->labels_ = std::any_cast<std::vector<std::variant<LabelIx, Expression *>>>(ctx->nodeLabels()->accept(this));
node->labels_ = std::any_cast<std::vector<QueryLabelType>>(ctx->nodeLabels()->accept(this));
}
if (ctx->properties()) {
// This can return either properties or parameters
Expand All @@ -1926,21 +1926,27 @@ antlrcpp::Any CypherMainVisitor::visitNodePattern(MemgraphCypher::NodePatternCon
}

antlrcpp::Any CypherMainVisitor::visitNodeLabels(MemgraphCypher::NodeLabelsContext *ctx) {
std::vector<std::variant<LabelIx, Expression *>> labels;
std::vector<QueryLabelType> labels;
for (auto *node_label : ctx->nodeLabel()) {
if (node_label->labelName()) {
if (node_label->labelName()->symbolicName()) {
labels.emplace_back(AddLabel(std::any_cast<std::string>(node_label->accept(this))));
} else {
// If we have a parameter, we have to resolve it.
const auto *param_lookup = std::any_cast<ParameterLookup *>(node_label->accept(this));
const auto label_name = parameters_->AtTokenPosition(param_lookup->token_position_).ValueString();
labels.emplace_back(storage_->GetLabelIx(label_name));
query_info_.is_cacheable = false; // We can't cache queries with label parameters.
}
auto *label_name = node_label->labelName();
if (label_name->symbolicName()) {
labels.emplace_back(AddLabel(std::any_cast<std::string>(node_label->accept(this))));
} else if (label_name->parameter()) {
// If we have a parameter, we have to resolve it.
const auto *param_lookup = std::any_cast<ParameterLookup *>(node_label->accept(this));
const auto label_name = parameters_->AtTokenPosition(param_lookup->token_position_).ValueString();
labels.emplace_back(storage_->GetLabelIx(label_name));
query_info_.is_cacheable = false; // We can't cache queries with label parameters.
} else {
// expression
labels.emplace_back(std::any_cast<Expression *>(node_label->accept(this)));
auto variable = std::any_cast<std::string>(label_name->variable()->accept(this));
users_identifiers.insert(variable);
auto *expression = static_cast<Expression *>(storage_->Create<Identifier>(variable));
for (auto *lookup : label_name->propertyLookup()) {
DavIvek marked this conversation as resolved.
Show resolved Hide resolved
auto key = std::any_cast<PropertyIx>(lookup->accept(this));
auto *property_lookup = storage_->Create<PropertyLookup>(expression, key);
expression = property_lookup;
}
labels.emplace_back(expression);
}
}
return labels;
Expand Down Expand Up @@ -2488,7 +2494,7 @@ antlrcpp::Any CypherMainVisitor::visitListIndexingOrSlicing(MemgraphCypher::List
antlrcpp::Any CypherMainVisitor::visitExpression2a(MemgraphCypher::Expression2aContext *ctx) {
auto *expression = std::any_cast<Expression *>(ctx->expression2b()->accept(this));
if (ctx->nodeLabels()) {
auto labels = std::any_cast<std::vector<std::variant<LabelIx, Expression *>>>(ctx->nodeLabels()->accept(this));
auto labels = std::any_cast<std::vector<QueryLabelType>>(ctx->nodeLabels()->accept(this));
expression = storage_->Create<LabelsTest>(expression, labels);
}
return expression;
Expand Down Expand Up @@ -2814,8 +2820,7 @@ antlrcpp::Any CypherMainVisitor::visitSetItem(MemgraphCypher::SetItemContext *ct
// SetLabels
auto *set_labels = storage_->Create<SetLabels>();
set_labels->identifier_ = storage_->Create<Identifier>(std::any_cast<std::string>(ctx->variable()->accept(this)));
set_labels->labels_ =
std::any_cast<std::vector<std::variant<LabelIx, Expression *>>>(ctx->nodeLabels()->accept(this));
set_labels->labels_ = std::any_cast<std::vector<QueryLabelType>>(ctx->nodeLabels()->accept(this));
return static_cast<Clause *>(set_labels);
}

Expand All @@ -2838,8 +2843,7 @@ antlrcpp::Any CypherMainVisitor::visitRemoveItem(MemgraphCypher::RemoveItemConte
// RemoveLabels
auto *remove_labels = storage_->Create<RemoveLabels>();
remove_labels->identifier_ = storage_->Create<Identifier>(std::any_cast<std::string>(ctx->variable()->accept(this)));
remove_labels->labels_ =
std::any_cast<std::vector<std::variant<LabelIx, Expression *>>>(ctx->nodeLabels()->accept(this));
remove_labels->labels_ = std::any_cast<std::vector<QueryLabelType>>(ctx->nodeLabels()->accept(this));
return static_cast<Clause *>(remove_labels);
}

Expand Down
7 changes: 5 additions & 2 deletions src/query/frontend/opencypher/grammar/Cypher.g4
Expand Up @@ -191,9 +191,12 @@ relationshipTypes : ':' relTypeName ( '|' ':'? relTypeName )* ;

nodeLabels : nodeLabel ( nodeLabel )* ;

nodeLabel : ':' (labelName | expression2b);
nodeLabel : ':' labelName ;

labelName : symbolicName | parameter;
labelName : symbolicName
| parameter
| variable ( propertyLookup )*
;
andrejtonev marked this conversation as resolved.
Show resolved Hide resolved

relTypeName : symbolicName ;

Expand Down
81 changes: 23 additions & 58 deletions src/query/plan/operator.cpp
Expand Up @@ -47,6 +47,7 @@
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/procedure/module.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/id_types.hpp"
#include "storage/v2/property_value.hpp"
#include "storage/v2/view.hpp"
#include "utils/algorithm.hpp"
Expand Down Expand Up @@ -177,6 +178,20 @@ inline void AbortCheck(ExecutionContext const &context) {
if (auto const reason = MustAbort(context); reason != AbortReason::NO_ABORT) throw HintedAbortError(reason);
}

std::vector<storage::LabelId> EvaluateLabels(const std::vector<StorageLabelType> &labels,
ExpressionEvaluator &evaluator, DbAccessor *dba) {
std::vector<storage::LabelId> result;
result.reserve(labels.size());
for (const auto &label : labels) {
if (const auto *label_atom = std::get_if<storage::LabelId>(&label)) {
result.emplace_back(*label_atom);
} else {
result.emplace_back(dba->NameToLabel(std::get<Expression *>(label)->Accept(evaluator).ValueString()));
}
}
return result;
}

} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
Expand Down Expand Up @@ -277,15 +292,7 @@ bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context)

if (input_cursor_->Pull(frame, context)) {
// we have to resolve the labels before we can check for permissions
DavIvek marked this conversation as resolved.
Show resolved Hide resolved
std::vector<storage::LabelId> labels;
for (const auto &label : self_.node_info_.labels) {
if (const auto *label_atom = std::get_if<storage::LabelId>(&label)) {
labels.emplace_back(*label_atom);
} else {
labels.emplace_back(
context.db_accessor->NameToLabel(std::get<Expression *>(label)->Accept(evaluator).ValueString()));
}
}
auto labels = EvaluateLabels(self_.node_info_.labels, evaluator, context.db_accessor);

#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
Expand Down Expand Up @@ -380,15 +387,7 @@ bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &cont
if (!input_cursor_->Pull(frame, context)) return false;
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::NEW);
std::vector<storage::LabelId> labels;
for (const auto &label : self_.node_info_.labels) {
if (const auto *label_atom = std::get_if<storage::LabelId>(&label)) {
labels.emplace_back(*label_atom);
} else {
labels.emplace_back(
context.db_accessor->NameToLabel(std::get<Expression *>(label)->Accept(evaluator).ValueString()));
}
}
auto labels = EvaluateLabels(self_.node_info_.labels, evaluator, context.db_accessor);

#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast()) {
Expand Down Expand Up @@ -3142,17 +3141,8 @@ void SetProperties::SetPropertiesCursor::Shutdown() { input_cursor_->Shutdown();
void SetProperties::SetPropertiesCursor::Reset() { input_cursor_->Reset(); }

SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels)
: input_(input), input_symbol_(std::move(input_symbol)), labels_(labels) {}

SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels)
: input_(input), input_symbol_(std::move(input_symbol)) {
labels_.reserve(labels.size());
for (const auto &label : labels) {
labels_.emplace_back(label);
}
}
std::vector<StorageLabelType> labels)
: input_(input), input_symbol_(std::move(input_symbol)), labels_(std::move(labels)) {}

ACCEPT_WITH_INPUT(SetLabels)

Expand All @@ -3175,15 +3165,7 @@ bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) {
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::NEW);
if (!input_cursor_->Pull(frame, context)) return false;
std::vector<storage::LabelId> labels;
for (const auto &label : self_.labels_) {
if (const auto *label_id = std::get_if<storage::LabelId>(&label)) {
labels.emplace_back(*label_id);
} else {
labels.emplace_back(
context.db_accessor->NameToLabel(std::get<query::Expression *>(label)->Accept(evaluator).ValueString()));
}
}
auto labels = EvaluateLabels(self_.labels_, evaluator, context.db_accessor);

#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
Expand Down Expand Up @@ -3321,17 +3303,8 @@ void RemoveProperty::RemovePropertyCursor::Shutdown() { input_cursor_->Shutdown(
void RemoveProperty::RemovePropertyCursor::Reset() { input_cursor_->Reset(); }

RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels)
: input_(input), input_symbol_(std::move(input_symbol)), labels_(labels) {}

RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels)
: input_(input), input_symbol_(std::move(input_symbol)) {
labels_.reserve(labels.size());
for (const auto &label : labels) {
labels_.emplace_back(label);
}
}
std::vector<StorageLabelType> labels)
: input_(input), input_symbol_(std::move(input_symbol)), labels_(std::move(labels)) {}

ACCEPT_WITH_INPUT(RemoveLabels)

Expand All @@ -3354,15 +3327,7 @@ bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &cont
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::NEW);
if (!input_cursor_->Pull(frame, context)) return false;
std::vector<storage::LabelId> labels;
for (const auto &label : self_.labels_) {
if (const auto *label_id = std::get_if<storage::LabelId>(&label)) {
labels.emplace_back(*label_id);
} else {
labels.emplace_back(
context.db_accessor->NameToLabel(std::get<query::Expression *>(label)->Accept(evaluator).ValueString()));
}
}
auto labels = EvaluateLabels(self_.labels_, evaluator, context.db_accessor);

#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
Expand Down
24 changes: 9 additions & 15 deletions src/query/plan/operator.hpp
Expand Up @@ -283,27 +283,26 @@ class Once : public memgraph::query::plan::LogicalOperator {
};

using PropertiesMapList = std::vector<std::pair<storage::PropertyId, Expression *>>;
using StorageLabelType = std::variant<storage::LabelId, Expression *>;

struct NodeCreationInfo {
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const { return kType; }

NodeCreationInfo() = default;

NodeCreationInfo(Symbol symbol, std::vector<std::variant<storage::LabelId, Expression *>> labels,
NodeCreationInfo(Symbol symbol, std::vector<StorageLabelType> labels,
std::variant<PropertiesMapList, ParameterLookup *> properties)
: symbol{std::move(symbol)}, labels{std::move(labels)}, properties{std::move(properties)} {};

NodeCreationInfo(Symbol symbol, std::vector<std::variant<storage::LabelId, Expression *>> labels,
PropertiesMapList properties)
NodeCreationInfo(Symbol symbol, std::vector<StorageLabelType> labels, PropertiesMapList properties)
: symbol{std::move(symbol)}, labels{std::move(labels)}, properties{std::move(properties)} {};

NodeCreationInfo(Symbol symbol, std::vector<std::variant<storage::LabelId, Expression *>> labels,
ParameterLookup *properties)
NodeCreationInfo(Symbol symbol, std::vector<StorageLabelType> labels, ParameterLookup *properties)
: symbol{std::move(symbol)}, labels{std::move(labels)}, properties{properties} {};

Symbol symbol;
std::vector<std::variant<storage::LabelId, Expression *>> labels;
std::vector<StorageLabelType> labels;
std::variant<PropertiesMapList, ParameterLookup *> properties;

NodeCreationInfo Clone(AstStorage *storage) const {
Expand Down Expand Up @@ -1442,10 +1441,7 @@ class SetLabels : public memgraph::query::plan::LogicalOperator {

SetLabels() = default;

SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels);
SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels);
SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, std::vector<StorageLabelType> labels);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
Expand All @@ -1456,7 +1452,7 @@ class SetLabels : public memgraph::query::plan::LogicalOperator {

std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
Symbol input_symbol_;
std::vector<std::variant<storage::LabelId, query::Expression *>> labels_;
std::vector<StorageLabelType> labels_;

std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
auto object = std::make_unique<SetLabels>();
Expand Down Expand Up @@ -1534,9 +1530,7 @@ class RemoveLabels : public memgraph::query::plan::LogicalOperator {
RemoveLabels() = default;

RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels);
RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels);
std::vector<StorageLabelType> labels);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
Expand All @@ -1547,7 +1541,7 @@ class RemoveLabels : public memgraph::query::plan::LogicalOperator {

std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
Symbol input_symbol_;
std::vector<std::variant<storage::LabelId, query::Expression *>> labels_;
std::vector<StorageLabelType> labels_;

std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
auto object = std::make_unique<RemoveLabels>();
Expand Down