diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py
index f40a43eef..34dcf0a3f 100644
--- a/graphdatascience/graph/graph_cypher_runner.py
+++ b/graphdatascience/graph/graph_cypher_runner.py
@@ -4,6 +4,7 @@
 from pandas import Series
 
 from ..error.illegal_attr_checker import IllegalAttrChecker
+from ..query_runner.arrow_query_runner import ArrowQueryRunner
 from ..query_runner.query_runner import QueryRunner
 from ..server_version.server_version import ServerVersion
 from .graph_object import Graph
@@ -232,6 +233,87 @@ def project(
 
         return Graph(graph_name, self._query_runner, self._server_version), result  # type: ignore
 
+    def run_project(
+        self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None
+    ) -> Tuple[Graph, "Series[Any]"]:
+        """
+        Run a Cypher projection.
+        The provided query must end with a `RETURN gds.graph.project(...)` call.
+        Text following that call, such as `AS graph`, is dropped before the query is run.
+
+        Parameters
+        ----------
+        query: str
+            the Cypher projection query
+        params: Optional[Dict[str, Any]]
+            parameters to the query, including the graph name if the query refers to it as `$parameter`
+        database: Optional[str]
+            the database on which to run the query
+
+        Returns
+        -------
+        A tuple of the projected graph and statistics about the projection
+
+        Raises
+        ------
+        ValueError
+            if the query lacks the expected `RETURN` clause or its closing parenthesis,
+            or if the graph name is given as a parameter but `params` is None
+        KeyError
+            if the graph name is given as a parameter that is missing from `params`
+        """
+
+        return_clause = f"RETURN {self._namespace}"
+
+        # Search from the right: the projection call has to be the final RETURN of the query
+        return_index = query.rfind(return_clause)
+        if return_index == -1:
+            raise ValueError(f"Invalid query, the query must end with a `{return_clause}` clause: {query}")
+
+        # Everything that follows the return clause: the call arguments and an optional alias
+        return_index += len(return_clause)
+        return_part = query[return_index:]
+
+        # `rfind` yields -1 when there is no closing parenthesis; fail loudly instead of
+        # running a truncated query and returning a graph with an empty name
+        right_paren = return_part.rfind(")") + 1
+        if right_paren == 0:
+            raise ValueError(f"Invalid query, the `{return_clause}` call is missing its closing parenthesis: {query}")
+
+        # Remove surrounding parentheses and whitespace
+        return_part = return_part[:right_paren].strip("() \n\t")
+
+        # The graph name is the first argument of the projection call
+        graph_name = return_part.split(",", maxsplit=1)[0]
+        graph_name = graph_name.strip()
+
+        if graph_name.startswith("$"):
+            if params is None:
+                raise ValueError(
+                    f"Invalid query, the query references parameter `{graph_name}` but no params were given"
+                )
+
+            # Look the name up in the query parameters, without the leading `$`
+            graph_name = params[graph_name[1:]]
+        else:
+            # remove the quotes
+            graph_name = graph_name.strip("'\"")
+
+        # remove possible `AS graph` from the end of the query
+        end_of_query = return_index + right_paren
+        query = query[:end_of_query]
+
+        qr = self._query_runner
+
+        # The Arrow query runner should not be used to execute arbitrary Cypher
+        if isinstance(qr, ArrowQueryRunner):
+            qr = qr.fallback_query_runner()
+
+        result = qr.run_query(query, params, database, False)
+        result = result.squeeze()
+
+        return Graph(graph_name, self._query_runner, self._server_version), result  # type: ignore
+
     def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
         if spec is None or spec is False:
             return []
diff --git a/graphdatascience/tests/unit/test_graph_cypher.py b/graphdatascience/tests/unit/test_graph_cypher.py
index a16d231f1..532690ebb 100644
--- a/graphdatascience/tests/unit/test_graph_cypher.py
+++ b/graphdatascience/tests/unit/test_graph_cypher.py
@@ -5,6 +5,44 @@
 from graphdatascience.server_version.server_version import ServerVersion
 
 
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_run_project(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)")
+
+    assert G.name() == "gg"
+    assert runner.last_params() == {}
+
+    assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_run_project_with_return_as(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t) AS graph")
+
+    assert G.name() == "gg"
+    assert runner.last_params() == {}
+
+    assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_run_project_with_graph_name_parameter(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.run_project(
+        "MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)", params={"graph_name": "gg"}
+    )
+
+    assert G.name() == "gg"
+    assert runner.last_params() == {"graph_name": "gg"}
+
+    assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)"
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_run_project_without_closing_parenthesis(gds: GraphDataScience) -> None:
+    with pytest.raises(ValueError, match="closing parenthesis"):
+        gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t")
+
+
 @pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
 def test_all(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
     G, _ = gds.graph.cypher.project("g")