From 85550a9cae647b4dd2ac9331f4f6d5c8e9183144 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Fri, 17 May 2024 15:09:02 +0200 Subject: [PATCH 01/12] Allow specifying arrow and writing configuration through proc call signatures this allows for example: gds.degree.write(G, arrowConfiguration={batchSize: 1}) gds.graph.nodeProperties.write(G, "prop", [], arrowConfiguration={batchSize: 1}) --- .../query_runner/aura_db_query_runner.py | 56 +++--- graphdatascience/tests/unit/conftest.py | 6 +- .../unit/test_aura_db_arrow_query_runner.py | 170 ++++++++++++++++++ 3 files changed, 204 insertions(+), 28 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_query_runner.py b/graphdatascience/query_runner/aura_db_query_runner.py index 3cbeb65de..94e232993 100644 --- a/graphdatascience/query_runner/aura_db_query_runner.py +++ b/graphdatascience/query_runner/aura_db_query_runner.py @@ -104,7 +104,7 @@ def _remote_projection( database: Optional[str] = None, logging: bool = False, ) -> DataFrame: - self._inject_connection_parameters(params) + self._inject_arrow_config(params["arrow_configuration"]) return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, False) def _remote_write_back( @@ -119,21 +119,27 @@ def _remote_write_back( if params["config"] is None: params["config"] = {} + # we pop these out so that they are not retained for the GDS proc call + db_write_config = params["config"].pop("writeConfiguration", {}) # type: ignore + db_arrow_config = params["config"].pop("arrowConfiguration", {}) # type: ignore + self._inject_write_config(endpoint, params, db_write_config) + self._inject_arrow_config(db_arrow_config) + params["config"]["writeToResultStore"] = True # type: ignore gds_write_result = self._gds_query_runner.call_procedure( endpoint, params, yields, database, logging, custom_error ) - write_params = { + db_write_proc_params = { "graphName": params["graph_name"], "databaseName": self._gds_query_runner.database(), - "writeConfiguration": self._extract_write_back_arguments(endpoint, params), + "writeConfiguration": db_write_config, + "arrowConfiguration": db_arrow_config, } - self._inject_connection_parameters(write_params) write_back_start = time.time() database_write_result = self._db_query_runner.call_procedure( - "gds.arrow.write", CallParameters(write_params), yields, None, False, False + "gds.arrow.write", CallParameters(db_write_proc_params), yields, None, False, False ) write_millis = (time.time() - write_back_start) * 1000 gds_write_result["writeMillis"] = write_millis @@ -149,22 +155,20 @@ def _remote_write_back( return gds_write_result - def _inject_connection_parameters(self, params: Dict[str, Any]) -> None: + def _inject_arrow_config(self, params: Dict[str, Any]) -> None: host, port = self._gds_arrow_client.connection_info() token = self._gds_arrow_client.request_token() if token is None: token = "IGNORED" - params["arrowConfiguration"] = { - "host": host, - "port": port, - "token": token, - "encrypted": self._encrypted, - } + + params["host"] = host + params["port"] = port + params["token"] = token + params["encrypted"] = self._encrypted @staticmethod - def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dict[str, Any]: - config = params.get("config", {}) - write_config = {} + def _inject_write_config(proc_name: str, proc_params: Dict[str, Any], write_config: Dict[str, Any]) -> None: + config = proc_params.get("config", {}) if "writeConcurrency" in config: write_config["concurrency"] = config["writeConcurrency"] @@ -188,21 +192,21 @@ def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dic elif "gds.graph." in proc_name: if "gds.graph.nodeProperties.write" == proc_name: - properties = params["properties"] + properties = proc_params["properties"] write_config["nodeProperties"] = properties if isinstance(properties, list) else [properties] - write_config["nodeLabels"] = params["entities"] + write_config["nodeLabels"] = proc_params["entities"] elif "gds.graph.nodeLabel.write" == proc_name: - write_config["nodeLabels"] = [params["node_label"]] + write_config["nodeLabels"] = [proc_params["node_label"]] elif "gds.graph.relationshipProperties.write" == proc_name: - write_config["relationshipProperties"] = params["relationship_properties"] - write_config["relationshipType"] = params["relationship_type"] + write_config["relationshipProperties"] = proc_params["relationship_properties"] + write_config["relationshipType"] = proc_params["relationship_type"] elif "gds.graph.relationship.write" == proc_name: - if "relationship_property" in params and params["relationship_property"] != "": - write_config["relationshipProperties"] = [params["relationship_property"]] - write_config["relationshipType"] = params["relationship_type"] + if "relationship_property" in proc_params and proc_params["relationship_property"] != "": + write_config["relationshipProperties"] = [proc_params["relationship_property"]] + write_config["relationshipType"] = proc_params["relationship_type"] else: raise ValueError(f"Unsupported procedure name: {proc_name}") @@ -215,9 +219,7 @@ def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dic else: if "writeProperty" in config: write_config["nodeProperties"] = [config["writeProperty"]] - if "nodeLabels" in params: - write_config["nodeLabels"] = params["nodeLabels"] + if "nodeLabels" in proc_params: + write_config["nodeLabels"] = proc_params["nodeLabels"] else: write_config["nodeLabels"] = ["*"] - - return write_config diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 00f8c5052..89f0415c8 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -68,9 +68,13 @@ def encrypted(self) -> bool: return False def last_query(self) -> str: + if len(self.queries) == 0: + return "" return self.queries[-1] - def last_params(self) -> Dict[str, Any]: + def last_params(self) -> dict[str, Any]: + if len(self.queries) == 0: + return {} return self.params[-1] def set_database(self, database: str) -> None: diff --git a/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py b/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py index e69de29bb..901501e8a 100644 --- a/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py +++ b/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py @@ -0,0 +1,170 @@ +from typing import Tuple + +from pandas import DataFrame + +from graphdatascience import ServerVersion +from graphdatascience.call_parameters import CallParameters +from graphdatascience.query_runner.aura_db_arrow_query_runner import ( + AuraDbArrowQueryRunner, +) +from graphdatascience.tests.unit.conftest import CollectingQueryRunner + + +class FakeArrowClient: + + def connection_info(self) -> Tuple[str, str]: + return "myHost", "1234" + + def request_token(self) -> str: + return "myToken" + + +def test_extracts_parameters_projection() -> None: + version = ServerVersion(2, 7, 0) + db_query_runner = CollectingQueryRunner(version) + gds_query_runner = CollectingQueryRunner(version) + gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) + qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore + + qr.call_procedure( + endpoint="gds.arrow.project", + params=CallParameters( + graph_name="g", + query="RETURN 1", + concurrency=2, + undirRels=[], + inverseRels=[], + arrow_configuration={"batchSize": 100}, + ), + ) + + # doesn't run anything on GDS + assert gds_query_runner.last_query() == "" + assert gds_query_runner.last_params() == {} + assert ( + db_query_runner.last_query() + == "CALL gds.arrow.project($graph_name, $query, $concurrency, $undirRels, $inverseRels, $arrow_configuration)" + ) + assert db_query_runner.last_params() == { + "graph_name": "g", + "query": "RETURN 1", + "concurrency": 2, + "undirRels": [], + "inverseRels": [], + "arrow_configuration": { + "encrypted": False, + "host": "myHost", + "port": "1234", + "token": "myToken", + "batchSize": 100, + }, + } + + +def test_extracts_parameters_algo_write() -> None: + version = ServerVersion(2, 7, 0) + db_query_runner = CollectingQueryRunner(version) + gds_query_runner = CollectingQueryRunner(version) + gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) + qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore + + qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={})) + + assert gds_query_runner.last_query() == "CALL gds.degree.write($graph_name, $config)" + assert gds_query_runner.last_params() == { + "graph_name": "g", + "config": {"writeToResultStore": True}, + } + assert ( + db_query_runner.last_query() + == "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)" + ) + assert db_query_runner.last_params() == { + "graphName": "g", + "databaseName": "dummy", + "writeConfiguration": {"nodeLabels": ["*"]}, + "arrowConfiguration": {"encrypted": False, "host": "myHost", "port": "1234", "token": "myToken"}, + } + + +def test_arrow_and_write_configuration() -> None: + version = ServerVersion(2, 7, 0) + db_query_runner = CollectingQueryRunner(version) + gds_query_runner = CollectingQueryRunner(version) + gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) + qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore + + qr.call_procedure( + endpoint="gds.degree.write", + params=CallParameters( + graph_name="g", + config={"arrowConfiguration": {"batchSize": 1000}, "writeConfiguration": {"writeMode": "FOOBAR"}}, + ), + ) + + assert gds_query_runner.last_query() == "CALL gds.degree.write($graph_name, $config)" + assert gds_query_runner.last_params() == { + "graph_name": "g", + "config": {"writeToResultStore": True}, + } + assert ( + db_query_runner.last_query() + == "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)" + ) + assert db_query_runner.last_params() == { + "graphName": "g", + "databaseName": "dummy", + "writeConfiguration": {"nodeLabels": ["*"], "writeMode": "FOOBAR"}, + "arrowConfiguration": { + "encrypted": False, + "host": "myHost", + "port": "1234", + "token": "myToken", + "batchSize": 1000, + }, + } + + +def test_arrow_and_write_configuration_graph_write() -> None: + version = ServerVersion(2, 7, 0) + db_query_runner = CollectingQueryRunner(version) + gds_query_runner = CollectingQueryRunner(version) + gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) + qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore + + qr.call_procedure( + endpoint="gds.graph.nodeProperties.write", + params=CallParameters( + graph_name="g", + properties=[], + entities=[], + config={"arrowConfiguration": {"batchSize": 42}, "writeConfiguration": {"writeMode": "FOOBAR"}}, + ), + ) + + assert ( + gds_query_runner.last_query() + == "CALL gds.graph.nodeProperties.write($graph_name, $properties, $entities, $config)" + ) + assert gds_query_runner.last_params() == { + "graph_name": "g", + "entities": [], + "properties": [], + "config": {"writeToResultStore": True}, + } + assert ( + db_query_runner.last_query() + == "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)" + ) + assert db_query_runner.last_params() == { + "graphName": "g", + "databaseName": "dummy", + "writeConfiguration": {"nodeLabels": [], "nodeProperties": [], "writeMode": "FOOBAR"}, + "arrowConfiguration": { + "encrypted": False, + "host": "myHost", + "port": "1234", + "token": "myToken", + "batchSize": 42, + }, + } From fe6e2f3d9e3e0209d404495c2152bcf6fde9ae30 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Fri, 17 May 2024 15:21:47 +0200 Subject: [PATCH 02/12] Expose arrow configuration as a remote projection parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Max Kießling --- graphdatascience/graph/graph_remote_project_runner.py | 2 ++ .../tests/integration/test_remote_graph_ops.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/graphdatascience/graph/graph_remote_project_runner.py b/graphdatascience/graph/graph_remote_project_runner.py index 0771b3878..a812dc496 100644 --- a/graphdatascience/graph/graph_remote_project_runner.py +++ b/graphdatascience/graph/graph_remote_project_runner.py @@ -20,6 +20,7 @@ def __call__( concurrency: int = 4, undirected_relationship_types: Optional[List[str]] = None, inverse_indexed_relationship_types: Optional[List[str]] = None, + batch_size: Optional[int] = None, ) -> GraphCreateResult: if inverse_indexed_relationship_types is None: inverse_indexed_relationship_types = [] @@ -32,6 +33,7 @@ def __call__( concurrency=concurrency, undirected_relationship_types=undirected_relationship_types, inverse_indexed_relationship_types=inverse_indexed_relationship_types, + arrow_configuration={"batchSize": batch_size}, ) result = self._query_runner.call_procedure( diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py index 3227e4e73..c133121ec 100644 --- a/graphdatascience/tests/integration/test_remote_graph_ops.py +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -43,6 +43,17 @@ def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None: assert result["nodeCount"] == 3 +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_projection_with_small_batch_size(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project( + GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)", batch_size=10 + ) + + assert G.name() == GRAPH_NAME + assert result["nodeCount"] == 3 + + @pytest.mark.cloud_architecture @pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) def test_remote_write_back_page_rank(gds_with_cloud_setup: AuraGraphDataScience) -> None: From 26cc360e01c9cb18e9f61c5311d27058d532f468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Mon, 27 May 2024 16:42:42 +0200 Subject: [PATCH 03/12] Do not set the batch size if it is not defined by the user --- graphdatascience/graph/graph_remote_project_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/graphdatascience/graph/graph_remote_project_runner.py b/graphdatascience/graph/graph_remote_project_runner.py index a812dc496..2eadae242 100644 --- a/graphdatascience/graph/graph_remote_project_runner.py +++ b/graphdatascience/graph/graph_remote_project_runner.py @@ -27,13 +27,17 @@ def __call__( if undirected_relationship_types is None: undirected_relationship_types = [] + arrow_configuration = {} + if batch_size is not None: + arrow_configuration["batchSize"] = batch_size + params = CallParameters( graph_name=graph_name, query=query, concurrency=concurrency, undirected_relationship_types=undirected_relationship_types, inverse_indexed_relationship_types=inverse_indexed_relationship_types, - arrow_configuration={"batchSize": batch_size}, + arrow_configuration=arrow_configuration, ) result = self._query_runner.call_procedure( From dfe3d2e424fae63d100c2cdd9e39294c0ae988d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Tue, 28 May 2024 10:15:16 +0200 Subject: [PATCH 04/12] Enable jobId based write back --- .../query_runner/aura_db_query_runner.py | 67 ++----------------- 1 file changed, 6 insertions(+), 61 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_query_runner.py b/graphdatascience/query_runner/aura_db_query_runner.py index 94e232993..67ef97cd7 100644 --- a/graphdatascience/query_runner/aura_db_query_runner.py +++ b/graphdatascience/query_runner/aura_db_query_runner.py @@ -1,5 +1,6 @@ import time from typing import Any, Dict, List, Optional +from uuid import uuid4 from pandas import DataFrame @@ -120,12 +121,14 @@ def _remote_write_back( params["config"] = {} # we pop these out so that they are not retained for the GDS proc call - db_write_config = params["config"].pop("writeConfiguration", {}) # type: ignore db_arrow_config = params["config"].pop("arrowConfiguration", {}) # type: ignore - self._inject_write_config(endpoint, params, db_write_config) self._inject_arrow_config(db_arrow_config) + job_id = params["config"]["jobId"] if "jobId" in params["config"] else str(uuid4()) + params["config"]["jobId"] = job_id + params["config"]["writeToResultStore"] = True # type: ignore + gds_write_result = self._gds_query_runner.call_procedure( endpoint, params, yields, database, logging, custom_error ) @@ -133,7 +136,7 @@ def _remote_write_back( db_write_proc_params = { "graphName": params["graph_name"], "databaseName": self._gds_query_runner.database(), - "writeConfiguration": db_write_config, + "writeConfiguration": {"jobId": job_id}, "arrowConfiguration": db_arrow_config, } @@ -165,61 +168,3 @@ def _inject_arrow_config(self, params: Dict[str, Any]) -> None: params["port"] = port params["token"] = token params["encrypted"] = self._encrypted - - @staticmethod - def _inject_write_config(proc_name: str, proc_params: Dict[str, Any], write_config: Dict[str, Any]) -> None: - config = proc_params.get("config", {}) - - if "writeConcurrency" in config: - write_config["concurrency"] = config["writeConcurrency"] - elif "concurrency" in config: - write_config["concurrency"] = config["concurrency"] - - if "gds.shortestPath" in proc_name or "gds.allShortestPaths" in proc_name: - write_config["relationshipType"] = config["writeRelationshipType"] - - write_node_ids = config.get("writeNodeIds") - write_costs = config.get("writeCosts") - - if write_node_ids and write_costs: - write_config["relationshipProperties"] = ["totalCost", "nodeIds", "costs"] - elif write_node_ids: - write_config["relationshipProperties"] = ["totalCost", "nodeIds"] - elif write_costs: - write_config["relationshipProperties"] = ["totalCost", "costs"] - else: - write_config["relationshipProperties"] = ["totalCost"] - - elif "gds.graph." in proc_name: - if "gds.graph.nodeProperties.write" == proc_name: - properties = proc_params["properties"] - write_config["nodeProperties"] = properties if isinstance(properties, list) else [properties] - write_config["nodeLabels"] = proc_params["entities"] - - elif "gds.graph.nodeLabel.write" == proc_name: - write_config["nodeLabels"] = [proc_params["node_label"]] - - elif "gds.graph.relationshipProperties.write" == proc_name: - write_config["relationshipProperties"] = proc_params["relationship_properties"] - write_config["relationshipType"] = proc_params["relationship_type"] - - elif "gds.graph.relationship.write" == proc_name: - if "relationship_property" in proc_params and proc_params["relationship_property"] != "": - write_config["relationshipProperties"] = [proc_params["relationship_property"]] - write_config["relationshipType"] = proc_params["relationship_type"] - - else: - raise ValueError(f"Unsupported procedure name: {proc_name}") - - else: - if "writeRelationshipType" in config: - write_config["relationshipType"] = config["writeRelationshipType"] - if "writeProperty" in config: - write_config["relationshipProperties"] = [config["writeProperty"]] - else: - if "writeProperty" in config: - write_config["nodeProperties"] = [config["writeProperty"]] - if "nodeLabels" in proc_params: - write_config["nodeLabels"] = proc_params["nodeLabels"] - else: - write_config["nodeLabels"] = ["*"] From c2443d2f1776b5897b9016138ac498a6afd05160 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Thu, 30 May 2024 10:24:17 +0200 Subject: [PATCH 05/12] Set jobId as a procedure parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sören Reichardt --- graphdatascience/query_runner/aura_db_query_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/query_runner/aura_db_query_runner.py b/graphdatascience/query_runner/aura_db_query_runner.py index 67ef97cd7..5546df0ac 100644 --- a/graphdatascience/query_runner/aura_db_query_runner.py +++ b/graphdatascience/query_runner/aura_db_query_runner.py @@ -136,7 +136,7 @@ def _remote_write_back( db_write_proc_params = { "graphName": params["graph_name"], "databaseName": self._gds_query_runner.database(), - "writeConfiguration": {"jobId": job_id}, + "jobId": job_id, "arrowConfiguration": db_arrow_config, } From 6b90b08b94b8ec9b2f94a3e570bf1a406ed17b2e Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 30 May 2024 16:07:18 +0200 Subject: [PATCH 06/12] Ignore some untyped dict operations --- graphdatascience/query_runner/aura_db_query_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_query_runner.py b/graphdatascience/query_runner/aura_db_query_runner.py index 5546df0ac..8630a2871 100644 --- a/graphdatascience/query_runner/aura_db_query_runner.py +++ b/graphdatascience/query_runner/aura_db_query_runner.py @@ -124,8 +124,8 @@ def _remote_write_back( db_arrow_config = params["config"].pop("arrowConfiguration", {}) # type: ignore self._inject_arrow_config(db_arrow_config) - job_id = params["config"]["jobId"] if "jobId" in params["config"] else str(uuid4()) - params["config"]["jobId"] = job_id + job_id = params["config"]["jobId"] if "jobId" in params["config"] else str(uuid4()) # type: ignore + params["config"]["jobId"] = job_id # type: ignore params["config"]["writeToResultStore"] = True # type: ignore From c11128b82614b72c6c94fbfdebd8685bfb81c879 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 4 Jun 2024 17:25:08 +0200 Subject: [PATCH 07/12] Rename class and test file after rebase --- ..._query_runner.py => test_aura_db_query_runner.py} | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) rename graphdatascience/tests/unit/{test_aura_db_arrow_query_runner.py => test_aura_db_query_runner.py} (91%) diff --git a/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py b/graphdatascience/tests/unit/test_aura_db_query_runner.py similarity index 91% rename from graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py rename to graphdatascience/tests/unit/test_aura_db_query_runner.py index 901501e8a..a4b2e8a5f 100644 --- a/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py +++ b/graphdatascience/tests/unit/test_aura_db_query_runner.py @@ -4,9 +4,7 @@ from graphdatascience import ServerVersion from graphdatascience.call_parameters import CallParameters -from graphdatascience.query_runner.aura_db_arrow_query_runner import ( - AuraDbArrowQueryRunner, -) +from graphdatascience.query_runner.aura_db_query_runner import AuraDbQueryRunner from graphdatascience.tests.unit.conftest import CollectingQueryRunner @@ -24,7 +22,7 @@ def test_extracts_parameters_projection() -> None: db_query_runner = CollectingQueryRunner(version) gds_query_runner = CollectingQueryRunner(version) gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) - qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore + qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore qr.call_procedure( endpoint="gds.arrow.project", @@ -66,7 +64,7 @@ def test_extracts_parameters_algo_write() -> None: db_query_runner = CollectingQueryRunner(version) gds_query_runner = CollectingQueryRunner(version) gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) - qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore + qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={})) @@ -92,7 +90,7 @@ def test_arrow_and_write_configuration() -> None: db_query_runner = CollectingQueryRunner(version) gds_query_runner = CollectingQueryRunner(version) gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) - qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore + qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore qr.call_procedure( endpoint="gds.degree.write", @@ -130,7 +128,7 @@ def test_arrow_and_write_configuration_graph_write() -> None: db_query_runner = CollectingQueryRunner(version) gds_query_runner = CollectingQueryRunner(version) gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) - qr = AuraDbArrowQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore + qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore qr.call_procedure( endpoint="gds.graph.nodeProperties.write", From bc63ed5e1add10f83ba02071057f113a0e4dd685 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 4 Jun 2024 17:40:47 +0200 Subject: [PATCH 08/12] Replace writeConfiguration with jobId In remote write unit tests --- .../tests/unit/test_aura_db_query_runner.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/graphdatascience/tests/unit/test_aura_db_query_runner.py b/graphdatascience/tests/unit/test_aura_db_query_runner.py index a4b2e8a5f..2fb6a8e13 100644 --- a/graphdatascience/tests/unit/test_aura_db_query_runner.py +++ b/graphdatascience/tests/unit/test_aura_db_query_runner.py @@ -66,21 +66,21 @@ def test_extracts_parameters_algo_write() -> None: gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}])) qr = AuraDbQueryRunner(gds_query_runner, db_query_runner, FakeArrowClient(), False) # type: ignore - qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={})) + qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={"jobId": "my-job"})) assert gds_query_runner.last_query() == "CALL gds.degree.write($graph_name, $config)" assert gds_query_runner.last_params() == { "graph_name": "g", - "config": {"writeToResultStore": True}, + "config": {"jobId": "my-job", "writeToResultStore": True}, } assert ( db_query_runner.last_query() - == "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)" + == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" ) assert db_query_runner.last_params() == { "graphName": "g", "databaseName": "dummy", - "writeConfiguration": {"nodeLabels": ["*"]}, + "jobId": "my-job", "arrowConfiguration": {"encrypted": False, "host": "myHost", "port": "1234", "token": "myToken"}, } @@ -96,23 +96,23 @@ def test_arrow_and_write_configuration() -> None: endpoint="gds.degree.write", params=CallParameters( graph_name="g", - config={"arrowConfiguration": {"batchSize": 1000}, "writeConfiguration": {"writeMode": "FOOBAR"}}, + config={"arrowConfiguration": {"batchSize": 1000}, "jobId": "my-job"}, ), ) assert gds_query_runner.last_query() == "CALL gds.degree.write($graph_name, $config)" assert gds_query_runner.last_params() == { "graph_name": "g", - "config": {"writeToResultStore": True}, + "config": {"writeToResultStore": True, "jobId": "my-job"}, } assert ( db_query_runner.last_query() - == "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)" + == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" ) assert db_query_runner.last_params() == { "graphName": "g", "databaseName": "dummy", - "writeConfiguration": {"nodeLabels": ["*"], "writeMode": "FOOBAR"}, + "jobId": "my-job", "arrowConfiguration": { "encrypted": False, "host": "myHost", @@ -136,7 +136,7 @@ def test_arrow_and_write_configuration_graph_write() -> None: graph_name="g", properties=[], entities=[], - config={"arrowConfiguration": {"batchSize": 42}, "writeConfiguration": {"writeMode": "FOOBAR"}}, + config={"arrowConfiguration": {"batchSize": 42}, "jobId": "my-job"}, ), ) @@ -148,16 +148,16 @@ def test_arrow_and_write_configuration_graph_write() -> None: "graph_name": "g", "entities": [], "properties": [], - "config": {"writeToResultStore": True}, + "config": {"writeToResultStore": True, "jobId": "my-job"}, } assert ( db_query_runner.last_query() - == "CALL gds.arrow.write($graphName, $databaseName, $writeConfiguration, $arrowConfiguration)" + == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" ) assert db_query_runner.last_params() == { "graphName": "g", "databaseName": "dummy", - "writeConfiguration": {"nodeLabels": [], "nodeProperties": [], "writeMode": "FOOBAR"}, + "jobId": "my-job", "arrowConfiguration": { "encrypted": False, "host": "myHost", From 1dbdaa343d9d7096adbf3ceaf62831a650c784bc Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 4 Jun 2024 17:41:13 +0200 Subject: [PATCH 09/12] Add arrow_configuration to remote projection unit tests --- graphdatascience/tests/unit/test_graph_ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/graphdatascience/tests/unit/test_graph_ops.py b/graphdatascience/tests/unit/test_graph_ops.py index 62a033a59..79eb6a980 100644 --- a/graphdatascience/tests/unit/test_graph_ops.py +++ b/graphdatascience/tests/unit/test_graph_ops.py @@ -97,7 +97,7 @@ def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataSc assert ( runner.last_query() == "CALL gds.arrow.project(" - + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)" + + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)" ) # injection of token and host into the params is done by the actual query runner assert runner.last_params() == { @@ -106,6 +106,7 @@ def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataSc "inverse_indexed_relationship_types": [], "query": "RETURN gds.graph.project.remote(0, 1, null)", "undirected_relationship_types": [], + "arrow_configuration": {}, } @@ -722,7 +723,7 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura assert ( runner.last_query() == "CALL gds.arrow.project(" - + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)" + + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)" ) assert runner.last_params() == { @@ -738,4 +739,5 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura """, "undirected_relationship_types": ["R"], "inverse_indexed_relationship_types": ["R"], + "arrow_configuration": {}, } From b6e71f983701c4bb6a38f0a3712ee7be0a9908e3 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 4 Jun 2024 17:46:27 +0200 Subject: [PATCH 10/12] Shorten long lines --- .../tests/unit/test_aura_db_query_runner.py | 9 +++------ graphdatascience/tests/unit/test_graph_ops.py | 12 ++++++------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/graphdatascience/tests/unit/test_aura_db_query_runner.py b/graphdatascience/tests/unit/test_aura_db_query_runner.py index 2fb6a8e13..deeebc0e8 100644 --- a/graphdatascience/tests/unit/test_aura_db_query_runner.py +++ b/graphdatascience/tests/unit/test_aura_db_query_runner.py @@ -74,8 +74,7 @@ def test_extracts_parameters_algo_write() -> None: "config": {"jobId": "my-job", "writeToResultStore": True}, } assert ( - db_query_runner.last_query() - == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" + db_query_runner.last_query() == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" ) assert db_query_runner.last_params() == { "graphName": "g", @@ -106,8 +105,7 @@ def test_arrow_and_write_configuration() -> None: "config": {"writeToResultStore": True, "jobId": "my-job"}, } assert ( - db_query_runner.last_query() - == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" + db_query_runner.last_query() == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" ) assert db_query_runner.last_params() == { "graphName": "g", @@ -151,8 +149,7 @@ def test_arrow_and_write_configuration_graph_write() -> None: "config": {"writeToResultStore": True, "jobId": "my-job"}, } assert ( - db_query_runner.last_query() - == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" + db_query_runner.last_query() == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)" ) assert db_query_runner.last_params() == { "graphName": "g", diff --git a/graphdatascience/tests/unit/test_graph_ops.py b/graphdatascience/tests/unit/test_graph_ops.py index 79eb6a980..3f43a4d8c 100644 --- a/graphdatascience/tests/unit/test_graph_ops.py +++ b/graphdatascience/tests/unit/test_graph_ops.py @@ -95,9 +95,9 @@ def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataSc aura_gds.graph.project("g", "RETURN gds.graph.project.remote(0, 1, null)") assert ( - runner.last_query() - == "CALL gds.arrow.project(" - + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)" + runner.last_query() == "CALL gds.arrow.project(" + "$graph_name, $query, $concurrency, " + "$undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)" ) # injection of token and host into the params is done by the actual query runner assert runner.last_params() == { @@ -721,9 +721,9 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura ) assert ( - runner.last_query() - == "CALL gds.arrow.project(" - + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)" + runner.last_query() == "CALL gds.arrow.project(" + "$graph_name, $query, $concurrency, " + "$undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)" ) assert runner.last_params() == { From fe1f71d211696b998076b858c3f7f20c73cfe111 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Reichardt?= Date: Mon, 1 Jul 2024 14:01:56 +0200 Subject: [PATCH 11/12] Fix length check for params MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Florentin Dörre --- graphdatascience/tests/unit/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 89f0415c8..4435a488f 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -73,7 +73,7 @@ def last_query(self) -> str: return self.queries[-1] def last_params(self) -> dict[str, Any]: - if len(self.queries) == 0: + if len(self.params) == 0: return {} return self.params[-1] From 277a87567d7d9785727a2cba4e8503fdfa6bba5a Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 2 Jul 2024 15:27:23 +0200 Subject: [PATCH 12/12] Use `Dict` instead of `dict` --- graphdatascience/tests/unit/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 4435a488f..a8a336963 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -72,7 +72,7 @@ def last_query(self) -> str: return "" return self.queries[-1] - def last_params(self) -> dict[str, Any]: + def last_params(self) -> Dict[str, Any]: if len(self.params) == 0: return {} return self.params[-1]