Skip to content

Commit

Permalink
Add small wrapper over run_cypher
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwalker committed Jun 16, 2023
1 parent 65b620a commit 02d89af
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
66 changes: 66 additions & 0 deletions graphdatascience/graph/graph_cypher_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pandas import Series

from ..error.illegal_attr_checker import IllegalAttrChecker
from ..query_runner.arrow_query_runner import ArrowQueryRunner
from ..query_runner.query_runner import QueryRunner
from ..server_version.server_version import ServerVersion
from .graph_object import Graph
Expand Down Expand Up @@ -232,6 +233,71 @@ def project(

return Graph(graph_name, self._query_runner, self._server_version), result # type: ignore

def run_project(
self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None
) -> Tuple[Graph, "Series[Any]"]:
"""
Run a Cypher projection.
The provided query must end with a `RETURN gds.graph.project(...)` call.
Parameters
----------
query: str
the Cypher projection query
params: Dict[str, Any]
parameters to the query
database: str
the database on which to run the query
Returns
-------
A tuple of the projected graph and statistics about the projection
"""

return_clause = f"RETURN {self._namespace}"

return_index = query.rfind(return_clause)
if return_index == -1:
raise ValueError(f"Invalid query, the query must end with a `{return_clause}` clause: {query}")

return_index += len(return_clause)
return_part = query[return_index:]

# Remove surrounding parentheses and whitespace
right_paren = return_part.rfind(")") + 1
return_part = return_part[:right_paren].strip("() \n\t")

graph_name = return_part.split(",", maxsplit=1)[0]
graph_name = graph_name.strip()

if graph_name.startswith("$"):
if params is None:
raise ValueError(
f"Invalid query, the query references parameter `{graph_name}` but no params were given"
)

graph_name = graph_name[1:]
graph_name = params[graph_name]
else:
# remove the quotes
graph_name = graph_name.strip("'\"")

# remove possible `AS graph` from the end of the query
end_of_query = return_index + right_paren
query = query[:end_of_query]

# run_cypher
qr = self._query_runner

# The Arrow query runner should not be used to execute arbitrary Cypher
if isinstance(qr, ArrowQueryRunner):
qr = qr.fallback_query_runner()

result = qr.run_query(query, params, database, False)
result = result.squeeze()

return Graph(graph_name, self._query_runner, self._server_version), result # type: ignore

def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
if spec is None or spec is False:
return []
Expand Down
32 changes: 32 additions & 0 deletions graphdatascience/tests/unit/test_graph_cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,38 @@
from graphdatascience.server_version.server_version import ServerVersion


@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
def test_run_project(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)")

assert G.name() == "gg"
assert runner.last_params() == {}

assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"


@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
def test_run_project_with_return_as(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t) AS graph")

assert G.name() == "gg"
assert runner.last_params() == {}

assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"


@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
def test_run_project_with_graph_name_parameter(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
G, _ = gds.graph.cypher.run_project(
"MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)", params={"graph_name": "gg"}
)

assert G.name() == "gg"
assert runner.last_params() == {"graph_name": "gg"}

assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)"


@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
def test_all(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
G, _ = gds.graph.cypher.project("g")
Expand Down

0 comments on commit 02d89af

Please sign in to comment.