Skip to content

Commit

Permalink
Support label creation via property values (#1762)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavIvek committed Mar 12, 2024
1 parent a282542 commit de2e204
Show file tree
Hide file tree
Showing 18 changed files with 439 additions and 132 deletions.
42 changes: 34 additions & 8 deletions src/query/frontend/ast/ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,8 @@ class AllPropertiesLookup : public memgraph::query::Expression {
friend class AstStorage;
};

using QueryLabelType = std::variant<LabelIx, Expression *>;

class LabelsTest : public memgraph::query::Expression {
public:
static const utils::TypeInfo kType;
Expand Down Expand Up @@ -1281,6 +1283,16 @@ class LabelsTest : public memgraph::query::Expression {

protected:
LabelsTest(Expression *expression, const std::vector<LabelIx> &labels) : expression_(expression), labels_(labels) {}
LabelsTest(Expression *expression, const std::vector<QueryLabelType> &labels) : expression_(expression) {
labels_.reserve(labels.size());
for (const auto &label : labels) {
if (const auto *label_ix = std::get_if<LabelIx>(&label)) {
labels_.push_back(*label_ix);
} else {
throw SemanticException("You can't use labels in filter expressions.");
}
}
}

private:
friend class AstStorage;
Expand Down Expand Up @@ -1771,7 +1783,7 @@ class NodeAtom : public memgraph::query::PatternAtom {
return visitor.PostVisit(*this);
}

std::vector<memgraph::query::LabelIx> labels_;
std::vector<QueryLabelType> labels_;
std::variant<std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>,
memgraph::query::ParameterLookup *>
properties_;
Expand All @@ -1781,7 +1793,11 @@ class NodeAtom : public memgraph::query::PatternAtom {
object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
object->labels_.resize(labels_.size());
for (auto i = 0; i < object->labels_.size(); ++i) {
object->labels_[i] = storage->GetLabelIx(labels_[i].name);
if (const auto *label = std::get_if<LabelIx>(&labels_[i])) {
object->labels_[i] = storage->GetLabelIx(label->name);
} else {
object->labels_[i] = std::get<Expression *>(labels_[i])->Clone(storage);
}
}
if (const auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&properties_)) {
auto &new_obj_properties = std::get<std::unordered_map<PropertyIx, Expression *>>(object->properties_);
Expand Down Expand Up @@ -2657,20 +2673,25 @@ class SetLabels : public memgraph::query::Clause {
}

memgraph::query::Identifier *identifier_{nullptr};
std::vector<memgraph::query::LabelIx> labels_;
std::vector<QueryLabelType> labels_;

SetLabels *Clone(AstStorage *storage) const override {
SetLabels *object = storage->Create<SetLabels>();
object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
object->labels_.resize(labels_.size());
for (auto i = 0; i < object->labels_.size(); ++i) {
object->labels_[i] = storage->GetLabelIx(labels_[i].name);
if (const auto *label = std::get_if<LabelIx>(&labels_[i])) {
object->labels_[i] = storage->GetLabelIx(label->name);
} else {
object->labels_[i] = std::get<Expression *>(labels_[i])->Clone(storage);
}
}
return object;
}

protected:
SetLabels(Identifier *identifier, const std::vector<LabelIx> &labels) : identifier_(identifier), labels_(labels) {}
SetLabels(Identifier *identifier, std::vector<QueryLabelType> labels)
: identifier_(identifier), labels_(std::move(labels)) {}

private:
friend class AstStorage;
Expand Down Expand Up @@ -2720,20 +2741,25 @@ class RemoveLabels : public memgraph::query::Clause {
}

memgraph::query::Identifier *identifier_{nullptr};
std::vector<memgraph::query::LabelIx> labels_;
std::vector<QueryLabelType> labels_;

RemoveLabels *Clone(AstStorage *storage) const override {
RemoveLabels *object = storage->Create<RemoveLabels>();
object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
object->labels_.resize(labels_.size());
for (auto i = 0; i < object->labels_.size(); ++i) {
object->labels_[i] = storage->GetLabelIx(labels_[i].name);
if (const auto *label = std::get_if<LabelIx>(&labels_[i])) {
object->labels_[i] = storage->GetLabelIx(label->name);
} else {
object->labels_[i] = std::get<Expression *>(labels_[i])->Clone(storage);
}
}
return object;
}

protected:
RemoveLabels(Identifier *identifier, const std::vector<LabelIx> &labels) : identifier_(identifier), labels_(labels) {}
RemoveLabels(Identifier *identifier, std::vector<QueryLabelType> labels)
: identifier_(identifier), labels_(std::move(labels)) {}

private:
friend class AstStorage;
Expand Down
25 changes: 18 additions & 7 deletions src/query/frontend/ast/cypher_main_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1933,7 +1933,7 @@ antlrcpp::Any CypherMainVisitor::visitNodePattern(MemgraphCypher::NodePatternCon
anonymous_identifiers.push_back(&node->identifier_);
}
if (ctx->nodeLabels()) {
node->labels_ = std::any_cast<std::vector<LabelIx>>(ctx->nodeLabels()->accept(this));
node->labels_ = std::any_cast<std::vector<QueryLabelType>>(ctx->nodeLabels()->accept(this));
}
if (ctx->properties()) {
// This can return either properties or parameters
Expand All @@ -1947,16 +1947,27 @@ antlrcpp::Any CypherMainVisitor::visitNodePattern(MemgraphCypher::NodePatternCon
}

antlrcpp::Any CypherMainVisitor::visitNodeLabels(MemgraphCypher::NodeLabelsContext *ctx) {
std::vector<LabelIx> labels;
std::vector<QueryLabelType> labels;
for (auto *node_label : ctx->nodeLabel()) {
if (node_label->labelName()->symbolicName()) {
auto *label_name = node_label->labelName();
if (label_name->symbolicName()) {
labels.emplace_back(AddLabel(std::any_cast<std::string>(node_label->accept(this))));
} else {
} else if (label_name->parameter()) {
// If we have a parameter, we have to resolve it.
const auto *param_lookup = std::any_cast<ParameterLookup *>(node_label->accept(this));
const auto label_name = parameters_->AtTokenPosition(param_lookup->token_position_).ValueString();
labels.emplace_back(storage_->GetLabelIx(label_name));
query_info_.is_cacheable = false; // We can't cache queries with label parameters.
} else {
auto variable = std::any_cast<std::string>(label_name->variable()->accept(this));
users_identifiers.insert(variable);
auto *expression = static_cast<Expression *>(storage_->Create<Identifier>(variable));
for (auto *lookup : label_name->propertyLookup()) {
auto key = std::any_cast<PropertyIx>(lookup->accept(this));
auto *property_lookup = storage_->Create<PropertyLookup>(expression, key);
expression = property_lookup;
}
labels.emplace_back(expression);
}
}
return labels;
Expand Down Expand Up @@ -2504,7 +2515,7 @@ antlrcpp::Any CypherMainVisitor::visitListIndexingOrSlicing(MemgraphCypher::List
antlrcpp::Any CypherMainVisitor::visitExpression2a(MemgraphCypher::Expression2aContext *ctx) {
auto *expression = std::any_cast<Expression *>(ctx->expression2b()->accept(this));
if (ctx->nodeLabels()) {
auto labels = std::any_cast<std::vector<LabelIx>>(ctx->nodeLabels()->accept(this));
auto labels = std::any_cast<std::vector<QueryLabelType>>(ctx->nodeLabels()->accept(this));
expression = storage_->Create<LabelsTest>(expression, labels);
}
return expression;
Expand Down Expand Up @@ -2830,7 +2841,7 @@ antlrcpp::Any CypherMainVisitor::visitSetItem(MemgraphCypher::SetItemContext *ct
// SetLabels
auto *set_labels = storage_->Create<SetLabels>();
set_labels->identifier_ = storage_->Create<Identifier>(std::any_cast<std::string>(ctx->variable()->accept(this)));
set_labels->labels_ = std::any_cast<std::vector<LabelIx>>(ctx->nodeLabels()->accept(this));
set_labels->labels_ = std::any_cast<std::vector<QueryLabelType>>(ctx->nodeLabels()->accept(this));
return static_cast<Clause *>(set_labels);
}

Expand All @@ -2853,7 +2864,7 @@ antlrcpp::Any CypherMainVisitor::visitRemoveItem(MemgraphCypher::RemoveItemConte
// RemoveLabels
auto *remove_labels = storage_->Create<RemoveLabels>();
remove_labels->identifier_ = storage_->Create<Identifier>(std::any_cast<std::string>(ctx->variable()->accept(this)));
remove_labels->labels_ = std::any_cast<std::vector<LabelIx>>(ctx->nodeLabels()->accept(this));
remove_labels->labels_ = std::any_cast<std::vector<QueryLabelType>>(ctx->nodeLabels()->accept(this));
return static_cast<Clause *>(remove_labels);
}

Expand Down
5 changes: 4 additions & 1 deletion src/query/frontend/opencypher/grammar/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,10 @@ nodeLabels : nodeLabel ( nodeLabel )* ;

nodeLabel : ':' labelName ;

labelName : symbolicName | parameter;
labelName : symbolicName
| parameter
| variable ( propertyLookup )+
;

relTypeName : symbolicName ;

Expand Down
47 changes: 47 additions & 0 deletions src/query/frontend/semantic/symbol_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,44 @@ bool SymbolGenerator::PostVisit(SetProperty & /*set_property*/) {
return true;
}

bool SymbolGenerator::PreVisit(SetLabels &set_labels) {
auto &scope = scopes_.back();
scope.in_set_labels = true;
for (auto &label : set_labels.labels_) {
if (auto *expression = std::get_if<Expression *>(&label)) {
(*expression)->Accept(*this);
}
}

return true;
}

bool SymbolGenerator::PostVisit(SetLabels & /*set_labels*/) {
auto &scope = scopes_.back();
scope.in_set_labels = false;

return true;
}

bool SymbolGenerator::PreVisit(RemoveLabels &remove_labels) {
auto &scope = scopes_.back();
scope.in_remove_labels = true;
for (auto &label : remove_labels.labels_) {
if (auto *expression = std::get_if<Expression *>(&label)) {
(*expression)->Accept(*this);
}
}

return true;
}

bool SymbolGenerator::PostVisit(RemoveLabels & /*remove_labels*/) {
auto &scope = scopes_.back();
scope.in_remove_labels = false;

return true;
}

// Pattern and its subparts.

bool SymbolGenerator::PreVisit(Pattern &pattern) {
Expand Down Expand Up @@ -602,6 +640,15 @@ bool SymbolGenerator::PreVisit(NodeAtom &node_atom) {
};

scope.in_node_atom = true;

if (scope.in_create) { // you can use expressions with labels only in create
for (auto &label : node_atom.labels_) {
if (auto *expression = std::get_if<Expression *>(&label)) {
(*expression)->Accept(*this);
}
}
}

if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&node_atom.properties_)) {
bool props_or_labels = !properties->empty() || !node_atom.labels_.empty();

Expand Down
6 changes: 6 additions & 0 deletions src/query/frontend/semantic/symbol_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool PostVisit(Foreach &) override;
bool PreVisit(SetProperty & /*set_property*/) override;
bool PostVisit(SetProperty & /*set_property*/) override;
bool PreVisit(SetLabels &) override;
bool PostVisit(SetLabels & /*set_labels*/) override;
bool PreVisit(RemoveLabels &) override;
bool PostVisit(RemoveLabels & /*remove_labels*/) override;

// Expressions
ReturnType Visit(Identifier &) override;
Expand Down Expand Up @@ -130,6 +134,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool in_set_property{false};
bool in_call_subquery{false};
bool has_return{false};
bool in_set_labels{false};
bool in_remove_labels{false};
// True when visiting a pattern atom (node or edge) identifier, which can be
// reused or created in the pattern itself.
bool in_pattern_atom_identifier{false};
Expand Down
Loading

0 comments on commit de2e204

Please sign in to comment.