Skip to content

Commit

Permalink
add id(a) support to front-end
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed May 2, 2021
1 parent 3d738a4 commit 1efbfc1
Show file tree
Hide file tree
Showing 24 changed files with 218 additions and 152 deletions.
19 changes: 19 additions & 0 deletions src/common/expression_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ bool isExpressionLeafVariable(ExpressionType type) {
return PROPERTY == type;
}

// Translates a generic comparison expression type into the matching node-ID comparison type.
// Any type other than the six comparison types handled below raises invalid_argument.
ExpressionType comparisonToIDComparison(ExpressionType type) {
    if (EQUALS == type) {
        return EQUALS_NODE_ID;
    }
    if (NOT_EQUALS == type) {
        return NOT_EQUALS_NODE_ID;
    }
    if (GREATER_THAN == type) {
        return GREATER_THAN_NODE_ID;
    }
    if (GREATER_THAN_EQUALS == type) {
        return GREATER_THAN_EQUALS_NODE_ID;
    }
    if (LESS_THAN == type) {
        return LESS_THAN_NODE_ID;
    }
    if (LESS_THAN_EQUALS == type) {
        return LESS_THAN_EQUALS_NODE_ID;
    }
    // Nothing matched: report the offending type by name.
    throw invalid_argument("Cannot map " + expressionTypeToString(type) + " to ID comparison.");
}

string expressionTypeToString(ExpressionType type) {
switch (type) {
case OR:
Expand Down
1 change: 1 addition & 0 deletions src/common/include/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ bool isExpressionNullComparison(ExpressionType type);
bool isExpressionLeafLiteral(ExpressionType type);
bool isExpressionLeafVariable(ExpressionType type);

ExpressionType comparisonToIDComparison(ExpressionType type);
string expressionTypeToString(ExpressionType type);

} // namespace common
Expand Down
5 changes: 3 additions & 2 deletions src/common/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ enum DataType : uint8_t {
INT64 = 5,
DOUBLE = 6,
STRING = 7,
UNSTRUCTURED = 8
NODE_ID = 8,
UNSTRUCTURED = 9
};

// Display names for data types; presumably indexed by DataType value — TODO confirm.
// NOTE(review): the DataType enum above declares UNSTRUCTURED = 9, yet the final name in the
// longer list below is "UNKNOWN" — confirm whether that mismatch is intended.
const string DataTypeNames[] = {
"REL", "NODE", "LABEL", "BOOL", "INT32", "INT64", "DOUBLE", "STRING", "UNSTRUCTURED"};
"REL", "NODE", "LABEL", "BOOL", "INT32", "INT64", "DOUBLE", "STRING", "NODE_ID", "UNKNOWN"};

int32_t convertToInt32(char* data);

Expand Down
7 changes: 4 additions & 3 deletions src/expression/include/logical/logical_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ using namespace std;
namespace graphflow {
namespace expression {

// replace this with function enum once we have multiple default functions
const string COUNT_STAR = "COUNT_STAR";
// replace this with function enum once we have more functions
const string FUNCTION_COUNT_STAR = "COUNT_STAR";
const string FUNCTION_ID = "ID";

class LogicalExpression {

Expand All @@ -40,7 +41,7 @@ class LogicalExpression {
return alias.empty() ? rawExpression : alias;
}

unordered_set<string> getIncludedVariables() const;
virtual unordered_set<string> getIncludedVariables() const;

unordered_set<string> getIncludedProperties() const;

Expand Down
4 changes: 4 additions & 0 deletions src/expression/include/logical/logical_node_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ class LogicalNodeExpression : public LogicalExpression {
// Builds a node expression: passes VARIABLE and NODE to the LogicalExpression base
// constructor, moves `name` into the member of the same name and stores `label`.
LogicalNodeExpression(string name, label_t label)
: LogicalExpression{VARIABLE, NODE}, name{move(name)}, label{label} {}

// A node expression includes exactly one variable: its own name.
unordered_set<string> getIncludedVariables() const override {
    unordered_set<string> variables;
    variables.insert(name);
    return variables;
}

public:
string name;
label_t label;
Expand Down
4 changes: 4 additions & 0 deletions src/expression/include/logical/logical_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class LogicalRelExpression : public LogicalExpression {

// Returns the name stored on dstNode (a copy, since the return type is string by value).
inline string getDstNodeName() const { return dstNode->name; }

unordered_set<string> getIncludedVariables() const override {
return unordered_set<string>{name};
}

public:
string name;
label_t label;
Expand Down
3 changes: 1 addition & 2 deletions src/expression/logical/logical_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ unordered_set<string> LogicalExpression::getIncludedVariables() const {
return result;
}
if (VARIABLE == expressionType) {
result.insert(variableName);
return result;
return getIncludedVariables();
}
if (PROPERTY == expressionType) {
result.insert(variableName.substr(0, variableName.find('.')));
Expand Down
4 changes: 2 additions & 2 deletions src/planner/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ static void validateProjectionColumnNamesAreUnique(

static void validateOnlyFunctionIsCountStar(vector<shared_ptr<LogicalExpression>>& expressions) {
for (auto& expression : expressions) {
if (FUNCTION == expression->expressionType && COUNT_STAR == expression->variableName &&
1 != expressions.size()) {
if (FUNCTION == expression->expressionType &&
FUNCTION_COUNT_STAR == expression->variableName && expressions.size() != 1) {
throw invalid_argument("The only function in the return clause should be COUNT(*).");
}
}
Expand Down
39 changes: 25 additions & 14 deletions src/planner/enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ static vector<shared_ptr<LogicalExpression>> splitExpressionOnAND(

static pair<string, string> splitVariableAndPropertyName(const string& name);

static string variableNameToID(const string& name);

Enumerator::Enumerator(const Catalog& catalog, const BoundSingleQuery& boundSingleQuery)
: catalog{catalog}, boundSingleQuery{boundSingleQuery} {
subgraphPlanTable = make_unique<SubgraphPlanTable>(boundSingleQuery.getNumQueryRels());
Expand Down Expand Up @@ -221,8 +223,9 @@ void Enumerator::appendLogicalScan(uint32_t queryNodePos, LogicalPlan& plan) {
if (ANY_LABEL == queryNode.label) {
throw invalid_argument("Match any label is not yet supported in LogicalScanNodeID.");
}
auto scan = make_shared<LogicalScanNodeID>(queryNode.name, queryNode.label);
plan.schema->nameOperatorMap.insert({queryNode.name, scan.get()});
auto nodeID = variableNameToID(queryNode.name);
auto scan = make_shared<LogicalScanNodeID>(nodeID, queryNode.label);
plan.schema->nameOperatorMap.insert({nodeID, scan.get()});
plan.appendOperator(scan);
}

Expand All @@ -232,18 +235,22 @@ void Enumerator::appendLogicalExtend(uint32_t queryRelPos, Direction direction,
ANY_LABEL == queryRel.label) {
throw invalid_argument("Match any label is not yet supported in LogicalExtend");
}
auto extend = make_shared<LogicalExtend>(queryRel, direction, plan.lastOperator);
auto nbrNodeName = FWD == direction ? queryRel.getSrcNodeName() : queryRel.getDstNodeName();
plan.schema->addOperator(nbrNodeName, extend.get());
auto boundNode = FWD == direction ? queryRel.srcNode : queryRel.dstNode;
auto nbrNode = FWD == direction ? queryRel.dstNode : queryRel.srcNode;
auto boundNodeID = variableNameToID(boundNode->name);
auto nbrNodeID = variableNameToID(nbrNode->name);
auto extend = make_shared<LogicalExtend>(boundNodeID, boundNode->label, nbrNodeID,
nbrNode->label, queryRel.label, direction, plan.lastOperator);
plan.schema->addOperator(nbrNodeID, extend.get());
plan.schema->addOperator(queryRel.name, extend.get());
plan.appendOperator(extend);
}

void Enumerator::appendLogicalHashJoin(
uint32_t joinNodePos, const LogicalPlan& planToJoin, LogicalPlan& plan) {
auto joinNodeName = mergedQueryGraph->queryNodes[joinNodePos]->name;
auto joinNodeID = variableNameToID(mergedQueryGraph->queryNodes[joinNodePos]->name);
auto hashJoin =
make_shared<LogicalHashJoin>(joinNodeName, plan.lastOperator, planToJoin.lastOperator);
make_shared<LogicalHashJoin>(joinNodeID, plan.lastOperator, planToJoin.lastOperator);
for (auto& [name, op] : planToJoin.schema->nameOperatorMap) {
if (!plan.schema->containsName(name)) {
plan.schema->addOperator(name, op);
Expand All @@ -262,7 +269,7 @@ void Enumerator::appendProjection(
const vector<shared_ptr<LogicalExpression>>& returnOrWithClause, LogicalPlan& plan) {
// Do not append projection in case of RETURN COUNT(*)
if (1 == returnOrWithClause.size() && FUNCTION == returnOrWithClause[0]->expressionType &&
COUNT_STAR == returnOrWithClause[0]->variableName) {
FUNCTION_COUNT_STAR == returnOrWithClause[0]->variableName) {
return;
}
for (auto& expression : returnOrWithClause) {
Expand Down Expand Up @@ -304,18 +311,18 @@ void Enumerator::appendNecessaryScans(shared_ptr<LogicalExpression> expression,
// Appends a LogicalScanNodeProperty operator to `plan` for the property `propertyName` of the
// query node looked up by `nodeName` in mergedQueryGraph. The new operator is chained onto
// plan.lastOperator and registered in plan.schema under "<node name>.<property name>".
void Enumerator::appendScanNodeProperty(
const string& nodeName, const string& propertyName, LogicalPlan& plan) {
auto queryNode = mergedQueryGraph->getQueryNode(nodeName);
auto scanProperty = make_shared<LogicalScanNodeProperty>(
queryNode->name, queryNode->label, propertyName, plan.lastOperator);
plan.schema->addOperator(nodeName + "." + propertyName, scanProperty.get());
auto scanProperty = make_shared<LogicalScanNodeProperty>(variableNameToID(queryNode->name),
queryNode->label, queryNode->name, propertyName, plan.lastOperator);
plan.schema->addOperator(queryNode->name + "." + propertyName, scanProperty.get());
plan.appendOperator(scanProperty);
}

// Appends a LogicalScanRelProperty operator to `plan` for the property `propertyName` of the
// relationship `relName`, and registers it in plan.schema under "<relName>.<propertyName>".
void Enumerator::appendScanRelProperty(
const string& relName, const string& propertyName, LogicalPlan& plan) {
// NOTE(review): unchecked C-style cast; assumes the operator registered under relName is the
// LogicalExtend that appendLogicalExtend adds for queryRel.name — confirm.
auto extend = (LogicalExtend*)plan.schema->getOperator(relName);
auto queryRel = mergedQueryGraph->getQueryRel(relName);
auto scanProperty = make_shared<LogicalScanRelProperty>(
*queryRel, extend->direction, propertyName, plan.lastOperator);
auto scanProperty = make_shared<LogicalScanRelProperty>(extend->boundNodeID,
extend->boundNodeLabel, extend->nbrNodeID, extend->nbrNodeLabel, relName, extend->relLabel,
extend->direction, propertyName, plan.lastOperator);
plan.schema->addOperator(relName + "." + propertyName, scanProperty.get());
plan.appendOperator(scanProperty);
}
Expand Down Expand Up @@ -369,5 +376,9 @@ pair<string, string> splitVariableAndPropertyName(const string& name) {
return make_pair(name.substr(0, splitPos), name.substr(splitPos + 1));
}

// Builds the identifier used for a variable's node ID by appending the "._id" suffix.
string variableNameToID(const string& name) {
    string nodeID = name;
    nodeID += "._id";
    return nodeID;
}

} // namespace planner
} // namespace graphflow
29 changes: 24 additions & 5 deletions src/planner/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ shared_ptr<LogicalExpression> ExpressionBinder::bindComparisonExpression(
} else if (left->dataType != right->dataType) {
return make_shared<LogicalLiteralExpression>(LITERAL_BOOLEAN, BOOL, Value(NULL_BOOL));
}
return make_shared<LogicalExpression>(parsedExpression.type, BOOL, move(left), move(right));
return make_shared<LogicalExpression>(NODE_ID == left->dataType ?
comparisonToIDComparison(parsedExpression.type) :
parsedExpression.type,
BOOL, move(left), move(right));
}

shared_ptr<LogicalExpression> ExpressionBinder::bindBinaryArithmeticExpression(
Expand Down Expand Up @@ -211,11 +214,27 @@ shared_ptr<LogicalExpression> ExpressionBinder::bindPropertyExpression(
dataTypeToString(childExpression->dataType) + ".");
}

// COUNT(*) is the only function expression supported
// only support COUNT(*) or ID(nodeVariable)
// Binds a parsed function call to a logical expression.
//  - COUNT_STAR yields a FUNCTION expression of type INT64.
//  - ID requires exactly one child that binds to a NODE-typed expression and yields a PROPERTY
//    expression of type NODE_ID named "<node name>._id".
//  - Any other function name throws invalid_argument, as do the ID arity/type violations.
shared_ptr<LogicalExpression> ExpressionBinder::bindFunctionExpression(
const ParsedExpression& parsedExpression) {
assert("COUNT_STAR" == parsedExpression.text);
return make_shared<LogicalExpression>(FUNCTION, INT64, COUNT_STAR);
auto functionName = parsedExpression.text;
// Upper-case the name so matching is case-insensitive.
// NOTE(review): ::toupper on plain char presumes ASCII function names — confirm.
transform(begin(functionName), end(functionName), begin(functionName), ::toupper);
if (FUNCTION_COUNT_STAR == functionName) {
return make_shared<LogicalExpression>(FUNCTION, INT64, FUNCTION_COUNT_STAR);
} else if (FUNCTION_ID == functionName) {
if (1 != parsedExpression.children.size()) {
throw invalid_argument(functionName + " takes exactly one parameter.");
}
auto child = bindExpression(*parsedExpression.children[0]);
if (NODE != child->dataType) {
throw invalid_argument("Expect " + child->rawExpression + " to be a node, but it was " +
dataTypeToString(child->dataType));
}
// The NODE data type check above is what justifies this unchecked downcast.
auto nodeName = static_pointer_cast<LogicalNodeExpression>(child)->name;
return make_shared<LogicalExpression>(PROPERTY, NODE_ID, nodeName + "._id");
} else {
throw invalid_argument(functionName + " is not supported.");
}
}

shared_ptr<LogicalExpression> ExpressionBinder::bindLiteralExpression(
Expand Down Expand Up @@ -260,7 +279,7 @@ void validateNoNullLiteralChildren(const ParsedExpression& parsedExpression) {
void validateExpectedType(const LogicalExpression& logicalExpression, DataType expectedType) {
auto dataType = logicalExpression.dataType;
if (expectedType != dataType) {
throw invalid_argument(logicalExpression.rawExpression + " is of data type " +
throw invalid_argument(logicalExpression.rawExpression + " has data type " +
dataTypeToString(dataType) + ". " + dataTypeToString(expectedType) +
" was expected.");
}
Expand Down
32 changes: 10 additions & 22 deletions src/planner/include/logical_plan/operator/extend/logical_extend.h
Original file line number Diff line number Diff line change
@@ -1,45 +1,33 @@
#pragma once

#include <string>

#include "src/common/include/types.h"
#include "src/expression/include/logical/logical_rel_expression.h"
#include "src/planner/include/logical_plan/operator/logical_operator.h"

using namespace graphflow::expression;
using namespace graphflow::common;
using namespace std;

namespace graphflow {
namespace planner {

// Logical plan operator for an extend step from a bound node to a neighbour node. It records
// both endpoints with their labels, the relationship label and the direction, and passes
// prevOperator to the LogicalOperator base.
class LogicalExtend : public LogicalOperator {

public:
LogicalExtend(const LogicalRelExpression& queryRel, const Direction& direction,
LogicalExtend(string boundNodeID, label_t boundNodeLabel, string nbrNodeID,
label_t nbrNodeLabel, label_t relLabel, Direction direction,
shared_ptr<LogicalOperator> prevOperator)
: LogicalOperator{prevOperator}, direction{direction} {
// For FWD the bound side is the rel's source node and the neighbour is its destination;
// otherwise the two roles are swapped.
auto isFwd = FWD == direction;
boundNodeVarName = isFwd ? queryRel.getSrcNodeName() : queryRel.getDstNodeName();
boundNodeVarLabel = isFwd ? queryRel.srcNode->label : queryRel.dstNode->label;
nbrNodeVarName = isFwd ? queryRel.getDstNodeName() : queryRel.getSrcNodeName();
nbrNodeVarLabel = isFwd ? queryRel.dstNode->label : queryRel.srcNode->label;
relLabel = queryRel.label;
}
: LogicalOperator{prevOperator}, boundNodeID{move(boundNodeID)},
boundNodeLabel{boundNodeLabel}, nbrNodeID{move(nbrNodeID)},
nbrNodeLabel{nbrNodeLabel}, relLabel{relLabel}, direction{direction} {}

// Identifies this operator as LOGICAL_EXTEND.
LogicalOperatorType getLogicalOperatorType() const override {
return LogicalOperatorType::LOGICAL_EXTEND;
}

// Renders "<bound>-><nbr>" when the direction is FWD and "<bound><-<nbr>" otherwise.
string getOperatorInformation() const override {
return boundNodeVarName + (direction == Direction::FWD ? "->" : "<-") + nbrNodeVarName;
return boundNodeID + (direction == Direction::FWD ? "->" : "<-") + nbrNodeID;
}

public:
string boundNodeVarName;
label_t boundNodeVarLabel;
string nbrNodeVarName;
label_t nbrNodeVarLabel;
string boundNodeID;
label_t boundNodeLabel;
string nbrNodeID;
label_t nbrNodeLabel;
label_t relLabel;
Direction direction;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
#pragma once

#include "src/common/include/types.h"
#include "src/expression/include/logical/logical_expression.h"
#include "src/planner/include/logical_plan/operator/logical_operator.h"

using namespace graphflow::expression;
using namespace graphflow::common;
using namespace std;

namespace graphflow {
namespace planner {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
#pragma once

#include <string>
#include <utility>

#include "src/common/include/types.h"
#include "src/planner/include/logical_plan/operator/logical_operator.h"

using namespace graphflow::common;
using namespace std;

namespace graphflow {
namespace planner {

class LogicalHashJoin : public LogicalOperator {

public:
LogicalHashJoin(string joinNodeVarName, shared_ptr<LogicalOperator> buildSidePrevOperator,
LogicalHashJoin(string joinNodeID, shared_ptr<LogicalOperator> buildSidePrevOperator,
shared_ptr<LogicalOperator> probeSidePrevOperator)
: LogicalOperator(move(probeSidePrevOperator)), joinNodeVarName(move(joinNodeVarName)),
: LogicalOperator(move(probeSidePrevOperator)), joinNodeID(move(joinNodeID)),
buildSidePrevOperator(move(buildSidePrevOperator)) {}

LogicalOperatorType getLogicalOperatorType() const override {
Expand All @@ -34,10 +27,10 @@ class LogicalHashJoin : public LogicalOperator {
return result;
}

string getOperatorInformation() const override { return joinNodeVarName; }
string getOperatorInformation() const override { return joinNodeID; }

public:
const string joinNodeVarName;
const string joinNodeID;
const shared_ptr<LogicalOperator> buildSidePrevOperator;
};
} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
#pragma once

#include <string>
#include <vector>

#include "src/planner/include/logical_plan/operator/logical_operator.h"

using namespace std;

namespace graphflow {
namespace planner {

// Logical plan operator that scans node IDs for one label; it stores the identifier it was
// constructed with and the label, both as const members.
class LogicalScanNodeID : public LogicalOperator {

public:
LogicalScanNodeID(const string& variableName, label_t label)
: nodeVarName{variableName}, label{label} {}
LogicalScanNodeID(string nodeID, label_t label) : nodeID{move(nodeID)}, label{label} {}

// Identifies this operator as LOGICAL_SCAN_NODE_ID.
// NOTE(review): unlike getOperatorInformation below, this is not marked `override` — confirm
// whether the specifier was omitted by accident.
LogicalOperatorType getLogicalOperatorType() const {
return LogicalOperatorType::LOGICAL_SCAN_NODE_ID;
}

// Returns the stored identifier as the operator's description.
string getOperatorInformation() const override { return nodeVarName; }
string getOperatorInformation() const override { return nodeID; }

public:
const string nodeVarName;
const string nodeID;
const label_t label;
};

Expand Down
Loading

0 comments on commit 1efbfc1

Please sign in to comment.