Skip to content
Merged
6 changes: 6 additions & 0 deletions graphdatascience/graph/graph_remote_project_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@ def __call__(
concurrency: int = 4,
undirected_relationship_types: Optional[List[str]] = None,
inverse_indexed_relationship_types: Optional[List[str]] = None,
batch_size: Optional[int] = None,
) -> GraphCreateResult:
if inverse_indexed_relationship_types is None:
inverse_indexed_relationship_types = []
if undirected_relationship_types is None:
undirected_relationship_types = []

arrow_configuration = {}
if batch_size is not None:
arrow_configuration["batchSize"] = batch_size

params = CallParameters(
graph_name=graph_name,
query=query,
concurrency=concurrency,
undirected_relationship_types=undirected_relationship_types,
inverse_indexed_relationship_types=inverse_indexed_relationship_types,
arrow_configuration=arrow_configuration,
)

result = self._query_runner.call_procedure(
Expand Down
91 changes: 19 additions & 72 deletions graphdatascience/query_runner/aura_db_query_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from typing import Any, Dict, List, Optional
from uuid import uuid4

from pandas import DataFrame

Expand Down Expand Up @@ -104,7 +105,7 @@ def _remote_projection(
database: Optional[str] = None,
logging: bool = False,
) -> DataFrame:
self._inject_connection_parameters(params)
self._inject_arrow_config(params["arrow_configuration"])
return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, False)

def _remote_write_back(
Expand All @@ -119,21 +120,29 @@ def _remote_write_back(
if params["config"] is None:
params["config"] = {}

# we pop these out so that they are not retained for the GDS proc call
db_arrow_config = params["config"].pop("arrowConfiguration", {}) # type: ignore
self._inject_arrow_config(db_arrow_config)

job_id = params["config"]["jobId"] if "jobId" in params["config"] else str(uuid4()) # type: ignore
params["config"]["jobId"] = job_id # type: ignore

params["config"]["writeToResultStore"] = True # type: ignore

gds_write_result = self._gds_query_runner.call_procedure(
endpoint, params, yields, database, logging, custom_error
)

write_params = {
db_write_proc_params = {
"graphName": params["graph_name"],
"databaseName": self._gds_query_runner.database(),
"writeConfiguration": self._extract_write_back_arguments(endpoint, params),
"jobId": job_id,
"arrowConfiguration": db_arrow_config,
}
self._inject_connection_parameters(write_params)

write_back_start = time.time()
database_write_result = self._db_query_runner.call_procedure(
"gds.arrow.write", CallParameters(write_params), yields, None, False, False
"gds.arrow.write", CallParameters(db_write_proc_params), yields, None, False, False
)
write_millis = (time.time() - write_back_start) * 1000
gds_write_result["writeMillis"] = write_millis
Expand All @@ -149,75 +158,13 @@ def _remote_write_back(

return gds_write_result

def _inject_connection_parameters(self, params: Dict[str, Any]) -> None:
def _inject_arrow_config(self, params: Dict[str, Any]) -> None:
host, port = self._gds_arrow_client.connection_info()
token = self._gds_arrow_client.request_token()
if token is None:
token = "IGNORED"
params["arrowConfiguration"] = {
"host": host,
"port": port,
"token": token,
"encrypted": self._encrypted,
}

@staticmethod
def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
config = params.get("config", {})
write_config = {}

if "writeConcurrency" in config:
write_config["concurrency"] = config["writeConcurrency"]
elif "concurrency" in config:
write_config["concurrency"] = config["concurrency"]

if "gds.shortestPath" in proc_name or "gds.allShortestPaths" in proc_name:
write_config["relationshipType"] = config["writeRelationshipType"]

write_node_ids = config.get("writeNodeIds")
write_costs = config.get("writeCosts")

if write_node_ids and write_costs:
write_config["relationshipProperties"] = ["totalCost", "nodeIds", "costs"]
elif write_node_ids:
write_config["relationshipProperties"] = ["totalCost", "nodeIds"]
elif write_costs:
write_config["relationshipProperties"] = ["totalCost", "costs"]
else:
write_config["relationshipProperties"] = ["totalCost"]

elif "gds.graph." in proc_name:
if "gds.graph.nodeProperties.write" == proc_name:
properties = params["properties"]
write_config["nodeProperties"] = properties if isinstance(properties, list) else [properties]
write_config["nodeLabels"] = params["entities"]

elif "gds.graph.nodeLabel.write" == proc_name:
write_config["nodeLabels"] = [params["node_label"]]

elif "gds.graph.relationshipProperties.write" == proc_name:
write_config["relationshipProperties"] = params["relationship_properties"]
write_config["relationshipType"] = params["relationship_type"]

elif "gds.graph.relationship.write" == proc_name:
if "relationship_property" in params and params["relationship_property"] != "":
write_config["relationshipProperties"] = [params["relationship_property"]]
write_config["relationshipType"] = params["relationship_type"]

else:
raise ValueError(f"Unsupported procedure name: {proc_name}")

else:
if "writeRelationshipType" in config:
write_config["relationshipType"] = config["writeRelationshipType"]
if "writeProperty" in config:
write_config["relationshipProperties"] = [config["writeProperty"]]
else:
if "writeProperty" in config:
write_config["nodeProperties"] = [config["writeProperty"]]
if "nodeLabels" in params:
write_config["nodeLabels"] = params["nodeLabels"]
else:
write_config["nodeLabels"] = ["*"]

return write_config
params["host"] = host
params["port"] = port
params["token"] = token
params["encrypted"] = self._encrypted
11 changes: 11 additions & 0 deletions graphdatascience/tests/integration/test_remote_graph_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None:
assert result["nodeCount"] == 3


@pytest.mark.cloud_architecture
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
def test_remote_projection_with_small_batch_size(gds_with_cloud_setup: AuraGraphDataScience) -> None:
    # Projecting with an explicitly small Arrow batch size must still yield the full graph.
    graph, projection_result = gds_with_cloud_setup.graph.project(
        GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)", batch_size=10
    )

    assert graph.name() == GRAPH_NAME
    assert projection_result["nodeCount"] == 3


@pytest.mark.cloud_architecture
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
def test_remote_write_back_page_rank(gds_with_cloud_setup: AuraGraphDataScience) -> None:
Expand Down
4 changes: 4 additions & 0 deletions graphdatascience/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ def encrypted(self) -> bool:
return False

def last_query(self) -> str:
    """Return the most recently recorded query, or "" when nothing has run yet."""
    return self.queries[-1] if self.queries else ""

def last_params(self) -> Dict[str, Any]:
    """Return the parameters of the most recent query, or {} when nothing has run yet."""
    return self.params[-1] if self.params else {}

def set_database(self, database: str) -> None:
Expand Down
Empty file.
165 changes: 165 additions & 0 deletions graphdatascience/tests/unit/test_aura_db_query_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from typing import Tuple

from pandas import DataFrame

from graphdatascience import ServerVersion
from graphdatascience.call_parameters import CallParameters
from graphdatascience.query_runner.aura_db_query_runner import AuraDbQueryRunner
from graphdatascience.tests.unit.conftest import CollectingQueryRunner


class FakeArrowClient:
    """Minimal stand-in for the GDS Arrow client: serves fixed connection details."""

    def connection_info(self) -> Tuple[str, str]:
        """Return a static (host, port) pair."""
        return "myHost", "1234"

    def request_token(self) -> str:
        """Return a static authentication token."""
        return "myToken"


def test_extracts_parameters_projection() -> None:
    """A remote projection runs only against the DB runner, with Arrow connection info injected."""
    server_version = ServerVersion(2, 7, 0)
    db_runner = CollectingQueryRunner(server_version)
    gds_runner = CollectingQueryRunner(server_version)
    gds_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
    runner = AuraDbQueryRunner(gds_runner, db_runner, FakeArrowClient(), False)  # type: ignore

    projection_params = CallParameters(
        graph_name="g",
        query="RETURN 1",
        concurrency=2,
        undirRels=[],
        inverseRels=[],
        arrow_configuration={"batchSize": 100},
    )
    runner.call_procedure(endpoint="gds.arrow.project", params=projection_params)

    # nothing must be executed on the GDS side
    assert gds_runner.last_query() == ""
    assert gds_runner.last_params() == {}

    expected_query = (
        "CALL gds.arrow.project("
        "$graph_name, $query, $concurrency, $undirRels, $inverseRels, $arrow_configuration)"
    )
    assert db_runner.last_query() == expected_query

    # user-supplied batchSize is merged with the injected Arrow connection details
    expected_arrow_config = {
        "encrypted": False,
        "host": "myHost",
        "port": "1234",
        "token": "myToken",
        "batchSize": 100,
    }
    assert db_runner.last_params() == {
        "graph_name": "g",
        "query": "RETURN 1",
        "concurrency": 2,
        "undirRels": [],
        "inverseRels": [],
        "arrow_configuration": expected_arrow_config,
    }


def test_extracts_parameters_algo_write() -> None:
    """An algo .write call runs on GDS with writeToResultStore, then triggers a DB-side Arrow write."""
    server_version = ServerVersion(2, 7, 0)
    db_runner = CollectingQueryRunner(server_version)
    gds_runner = CollectingQueryRunner(server_version)
    gds_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
    runner = AuraDbQueryRunner(gds_runner, db_runner, FakeArrowClient(), False)  # type: ignore

    runner.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={"jobId": "my-job"}))

    assert gds_runner.last_query() == "CALL gds.degree.write($graph_name, $config)"
    assert gds_runner.last_params() == {
        "graph_name": "g",
        "config": {"jobId": "my-job", "writeToResultStore": True},
    }

    expected_db_query = "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)"
    assert db_runner.last_query() == expected_db_query
    assert db_runner.last_params() == {
        "graphName": "g",
        "databaseName": "dummy",
        "jobId": "my-job",
        "arrowConfiguration": {"encrypted": False, "host": "myHost", "port": "1234", "token": "myToken"},
    }


def test_arrow_and_write_configuration() -> None:
    """arrowConfiguration is stripped from the GDS call and merged into the DB-side Arrow write."""
    server_version = ServerVersion(2, 7, 0)
    db_runner = CollectingQueryRunner(server_version)
    gds_runner = CollectingQueryRunner(server_version)
    gds_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
    runner = AuraDbQueryRunner(gds_runner, db_runner, FakeArrowClient(), False)  # type: ignore

    runner.call_procedure(
        endpoint="gds.degree.write",
        params=CallParameters(
            graph_name="g",
            config={"arrowConfiguration": {"batchSize": 1000}, "jobId": "my-job"},
        ),
    )

    # the GDS procedure must not see the arrowConfiguration entry
    assert gds_runner.last_query() == "CALL gds.degree.write($graph_name, $config)"
    assert gds_runner.last_params() == {
        "graph_name": "g",
        "config": {"writeToResultStore": True, "jobId": "my-job"},
    }

    assert db_runner.last_query() == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)"
    assert db_runner.last_params() == {
        "graphName": "g",
        "databaseName": "dummy",
        "jobId": "my-job",
        "arrowConfiguration": {
            "batchSize": 1000,
            "encrypted": False,
            "host": "myHost",
            "port": "1234",
            "token": "myToken",
        },
    }


def test_arrow_and_write_configuration_graph_write() -> None:
    """Graph-catalog writes also strip arrowConfiguration from the GDS call and forward it to gds.arrow.write."""
    server_version = ServerVersion(2, 7, 0)
    db_runner = CollectingQueryRunner(server_version)
    gds_runner = CollectingQueryRunner(server_version)
    gds_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
    runner = AuraDbQueryRunner(gds_runner, db_runner, FakeArrowClient(), False)  # type: ignore

    runner.call_procedure(
        endpoint="gds.graph.nodeProperties.write",
        params=CallParameters(
            graph_name="g",
            properties=[],
            entities=[],
            config={"arrowConfiguration": {"batchSize": 42}, "jobId": "my-job"},
        ),
    )

    expected_gds_query = "CALL gds.graph.nodeProperties.write($graph_name, $properties, $entities, $config)"
    assert gds_runner.last_query() == expected_gds_query
    assert gds_runner.last_params() == {
        "graph_name": "g",
        "entities": [],
        "properties": [],
        "config": {"writeToResultStore": True, "jobId": "my-job"},
    }

    assert db_runner.last_query() == "CALL gds.arrow.write($graphName, $databaseName, $jobId, $arrowConfiguration)"
    assert db_runner.last_params() == {
        "graphName": "g",
        "databaseName": "dummy",
        "jobId": "my-job",
        "arrowConfiguration": {
            "batchSize": 42,
            "encrypted": False,
            "host": "myHost",
            "port": "1234",
            "token": "myToken",
        },
    }
14 changes: 8 additions & 6 deletions graphdatascience/tests/unit/test_graph_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataSc
aura_gds.graph.project("g", "RETURN gds.graph.project.remote(0, 1, null)")

assert (
runner.last_query()
== "CALL gds.arrow.project("
+ "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)"
runner.last_query() == "CALL gds.arrow.project("
"$graph_name, $query, $concurrency, "
"$undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)"
)
# injection of token and host into the params is done by the actual query runner
assert runner.last_params() == {
Expand All @@ -106,6 +106,7 @@ def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataSc
"inverse_indexed_relationship_types": [],
"query": "RETURN gds.graph.project.remote(0, 1, null)",
"undirected_relationship_types": [],
"arrow_configuration": {},
}


Expand Down Expand Up @@ -720,9 +721,9 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura
)

assert (
runner.last_query()
== "CALL gds.arrow.project("
+ "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)"
runner.last_query() == "CALL gds.arrow.project("
"$graph_name, $query, $concurrency, "
"$undirected_relationship_types, $inverse_indexed_relationship_types, $arrow_configuration)"
)

assert runner.last_params() == {
Expand All @@ -738,4 +739,5 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura
""",
"undirected_relationship_types": ["R"],
"inverse_indexed_relationship_types": ["R"],
"arrow_configuration": {},
}