import base64
import json
import time
import warnings
from abc import ABC
from typing import Any, Dict, Optional, Tuple

from pandas import DataFrame
from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
from pyarrow.types import is_dictionary

from .arrow_endpoint_version import ArrowEndpointVersion
from .query_runner import QueryRunner
from ..server_version.server_version import ServerVersion


class GdsArrowClient(ABC):
    """Thin client around the GDS Arrow Flight endpoint.

    Wraps a `pyarrow.flight.FlightClient` and knows how to version action
    types and flight descriptors for the alpha vs. v1 Arrow endpoint.
    NOTE(review): the `ABC` base carries no abstract methods and the class is
    instantiated directly in `create` — confirm whether it is intentional.
    """

    @staticmethod
    def create(
        query_runner: QueryRunner,
        auth: Optional[Tuple[str, str]] = None,
        encrypted: bool = False,
        disable_server_verification: bool = False,
        tls_root_certs: Optional[bytes] = None,
        connection_string_override: Optional[str] = None,
    ) -> "Optional[GdsArrowClient]":
        """Create a client if the server reports a running Arrow endpoint.

        Queries `gds.debug.arrow` via the given query runner and returns
        `None` when the Arrow server is not running. The connection string is
        taken from `connection_string_override` if given, otherwise from the
        advertised (or plain) listen address reported by the server.
        """
        arrow_info = query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()

        if not arrow_info["running"]:
            return None

        server_version = query_runner.server_version()

        connection_string: str
        if connection_string_override is not None:
            connection_string = connection_string_override
        else:
            connection_string = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])

        host, port = connection_string.split(":")

        arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))

        return GdsArrowClient(
            host,
            int(port),
            server_version,
            auth,
            encrypted,
            disable_server_verification,
            tls_root_certs,
            arrow_endpoint_version,
        )

    def __init__(
        self,
        host: str,
        port: int,
        server_version: ServerVersion,
        auth: Optional[Tuple[str, str]] = None,
        encrypted: bool = False,
        disable_server_verification: bool = False,
        tls_root_certs: Optional[bytes] = None,
        arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
    ):
        """Construct a client talking to the Flight server at host:port.

        When `auth` is given, an `AuthMiddleware` is installed that performs
        basic auth once and then re-sends the bearer token it receives back
        from the server.
        """
        self._server_version = server_version
        self._arrow_endpoint_version = arrow_endpoint_version
        self._host = host
        self._port = port
        self._auth = auth

        location = (
            flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port)
        )

        client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
        if auth:
            self._auth_middleware = AuthMiddleware(auth)
            client_options["middleware"] = [AuthFactory(self._auth_middleware)]
        if tls_root_certs:
            client_options["tls_root_certs"] = tls_root_certs

        self._flight_client = flight.FlightClient(location, **client_options)

    def connection_info(self) -> Tuple[str, int]:
        """Return the (host, port) this client connects to."""
        return self._host, self._port

    def get_or_request_token(self) -> str:
        """Return the current bearer token, authenticating first if needed.

        Returns the placeholder "IGNORED" when the client was created without
        auth (the server ignores the token in that case).
        """
        if self._auth:
            # The basic-auth handshake makes the server send back a bearer
            # token, which the middleware captures.
            self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1])
            return self._auth_middleware.token()
        else:
            return "IGNORED"

    def get_property(
        self, database: str, graph_name: str, procedure_name: str, configuration: Dict[str, Any]
    ) -> DataFrame:
        """Stream a graph property from the server and return it as a DataFrame.

        Raises:
            ValueError: if no database is given — the server needs one to
                resolve the graph.
        """
        # Preserve the validation the pre-refactoring code performed before
        # issuing the Flight GET.
        if not database:
            raise ValueError(
                "For this call you must have explicitly specified a valid Neo4j database to execute on, "
                "using `GraphDataScience.set_database`."
            )

        payload: Dict[str, Any] = {
            "database_name": database,
            "graph_name": graph_name,
            "procedure_name": procedure_name,
            "configuration": configuration,
        }

        if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
            # Fix: the v1 endpoint expects "GET_COMMAND" (as the previous
            # ArrowQueryRunner implementation sent), not "GET_MESSAGE".
            payload = {
                "name": "GET_COMMAND",
                "version": ArrowEndpointVersion.V1.version(),
                "body": payload,
            }

        ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
        get = self._flight_client.do_get(ticket)
        arrow_table = get.read_all()

        if configuration.get("list_node_labels", False):
            # GDS 2.5 had an inconsistent naming of the node labels column
            new_column_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
            arrow_table = arrow_table.rename_columns(new_column_names)

        # Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
        warnings.filterwarnings(
            "ignore",
            category=DeprecationWarning,
            message=r"Passing a BlockManager to DataFrame is deprecated",
        )

        return self._sanitize_arrow_table(arrow_table).to_pandas()  # type: ignore

    def send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
        """Send a (versioned) Flight action and fully consume its result."""
        action_type = self._versioned_action_type(action_type)
        result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))

        # Consume result fully to sanity check and avoid cancelled streams
        collected_result = list(result)
        assert len(collected_result) == 1

        json.loads(collected_result[0].body.to_pybytes().decode())

    def start_put(self, payload: Dict[str, Any], schema: Schema) -> Tuple[Any, Any]:
        """Begin a DoPut upload for the given descriptor payload and schema."""
        flight_descriptor = self._versioned_flight_descriptor(payload)
        upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
        return self._flight_client.do_put(upload_descriptor, schema)

    def close(self) -> None:
        """Close the underlying Flight client."""
        self._flight_client.close()

    def _versioned_action_type(self, action_type: str) -> str:
        # v1 actions carry an endpoint-version prefix; alpha's prefix is empty.
        return self._arrow_endpoint_version.prefix() + action_type

    def _versioned_flight_descriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
        # v1 wraps the descriptor in a PUT_MESSAGE envelope; alpha sends it raw.
        return (
            flight_descriptor
            if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
            else {
                "name": "PUT_MESSAGE",
                "version": ArrowEndpointVersion.V1.version(),
                "body": flight_descriptor,
            }
        )

    @staticmethod
    def _sanitize_arrow_table(arrow_table: Table) -> Table:
        """Decode dictionary-encoded columns pandas cannot convert directly."""
        # empty columns cannot be used to build a chunked_array in pyarrow
        if len(arrow_table) == 0:
            return arrow_table

        dict_encoded_fields = [
            (idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type)
        ]

        for idx, field in dict_encoded_fields:
            try:
                field.type.to_pandas_dtype()
            except NotImplementedError:
                # we need to decode the dictionary column before transforming to pandas
                if isinstance(arrow_table[field.name], ChunkedArray):
                    decoded_col = chunked_array(
                        [chunk.dictionary_decode() for chunk in arrow_table[field.name].chunks]
                    )
                else:
                    decoded_col = arrow_table[field.name].dictionary_decode()
                arrow_table = arrow_table.set_column(idx, field.name, decoded_col)
        return arrow_table


class AuthFactory(ClientMiddlewareFactory):  # type: ignore
    """Factory that hands the same AuthMiddleware instance to every call,
    so the bearer token captured on one call is reused on the next."""

    def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._middleware = middleware

    def start_call(self, info: Any) -> "AuthMiddleware":
        return self._middleware


class AuthMiddleware(ClientMiddleware):  # type: ignore
    """Client middleware doing basic auth once, then bearer-token auth.

    The first request sends `Basic user:password`; the server's `Bearer`
    response header is captured and re-sent on subsequent requests until it
    is considered stale (10 minutes), after which basic auth runs again.
    """

    def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._auth = auth
        self._token: Optional[str] = None
        self._token_timestamp = 0

    def token(self) -> Optional[str]:
        # check whether the token is older than 10 minutes. If so, reset it.
        if self._token and int(time.time()) - self._token_timestamp > 600:
            self._token = None

        return self._token

    def _set_token(self, token: str) -> None:
        self._token = token
        self._token_timestamp = int(time.time())

    def received_headers(self, headers: Dict[str, Any]) -> None:
        # Fix: gRPC/Flight exposes header keys lowercased, and values arrive
        # as lists — the previous `headers.get("Authorization")` + str.split
        # never captured the token.
        auth_header = headers.get("authorization", None)
        if not auth_header:
            return

        # the result is always a list
        header_value = auth_header[0]

        if not isinstance(header_value, str):
            raise ValueError(f"Incompatible header value received from server: `{header_value}`")

        auth_type, token = header_value.split(" ", 1)
        if auth_type == "Bearer":
            self._set_token(token)

    def sending_headers(self) -> Dict[str, str]:
        token = self.token()
        if not token:
            # Fix: this class stores the credentials itself (`self._auth`);
            # `self._factory` does not exist and raised AttributeError here.
            username, password = self._auth
            auth_token = f"{username}:{password}"
            auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
            # There seems to be a bug, `authorization` must be lower key
            return {"authorization": auth_token}
        else:
            return {"authorization": "Bearer " + token}
NoReturn, Optional import numpy -import pyarrow.flight as flight from pandas import DataFrame from pyarrow import Table from tqdm.auto import tqdm -from .arrow_endpoint_version import ArrowEndpointVersion +from .gds_arrow_client import GdsArrowClient from .graph_constructor import GraphConstructor @@ -22,9 +20,8 @@ def __init__( self, database: str, graph_name: str, - flight_client: flight.FlightClient, + flight_client: GdsArrowClient, concurrency: int, - arrow_endpoint_version: ArrowEndpointVersion, undirected_relationship_types: Optional[List[str]], chunk_size: int = 10_000, ): @@ -32,7 +29,6 @@ def __init__( self._concurrency = concurrency self._graph_name = graph_name self._client = flight_client - self._arrow_endpoint_version = arrow_endpoint_version self._undirected_relationship_types = ( [] if undirected_relationship_types is None else undirected_relationship_types ) @@ -49,20 +45,20 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N if self._undirected_relationship_types: config["undirected_relationship_types"] = self._undirected_relationship_types - self._send_action( + self._client.send_action( "CREATE_GRAPH", config, ) self._send_dfs(node_dfs, "node") - self._send_action("NODE_LOAD_DONE", {"name": self._graph_name}) + self._client.send_action("NODE_LOAD_DONE", {"name": self._graph_name}) self._send_dfs(relationship_dfs, "relationship") - self._send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name}) + self._client.send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name}) except (Exception, KeyboardInterrupt) as e: - self._send_action("ABORT", {"name": self._graph_name}) + self._client.send_action("ABORT", {"name": self._graph_name}) raise e @@ -85,25 +81,12 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]: return partitioned_dfs - def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None: - action_type = self._versioned_action_type(action_type) - result = 
self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) - - # Consume result fully to sanity check and avoid cancelled streams - collected_result = list(result) - assert len(collected_result) == 1 - - json.loads(collected_result[0].body.to_pybytes().decode()) - def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None: table = Table.from_pandas(df) batches = table.to_batches(self._chunk_size) flight_descriptor = {"name": self._graph_name, "entity_type": entity_type} - flight_descriptor = self._versioned_flight_desriptor(flight_descriptor) - # Write schema - upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8")) - writer, _ = self._client.do_put(upload_descriptor, table.schema) + writer, _ = self._client.start_put(flight_descriptor, table.schema) with writer: # Write table in chunks @@ -126,17 +109,3 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None: if not future.exception(): continue raise future.exception() # type: ignore - - def _versioned_action_type(self, action_type: str) -> str: - return self._arrow_endpoint_version.prefix() + action_type - - def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]: - return ( - flight_descriptor - if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA - else { - "name": "PUT_MESSAGE", - "version": ArrowEndpointVersion.V1.version(), - "body": flight_descriptor, - } - ) diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index 49c96b447..6e0aba51b 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -1,26 +1,20 @@ from __future__ import annotations -import base64 -import json -import time import warnings from typing import Any, Dict, List, Optional, Tuple -import pyarrow.flight as flight from pandas import DataFrame -from 
pyarrow import ChunkedArray, Table, chunked_array -from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory from pyarrow.types import is_dictionary # type: ignore -from ..call_parameters import CallParameters -from ..server_version.server_version import ServerVersion -from .arrow_endpoint_version import ArrowEndpointVersion -from .arrow_graph_constructor import ArrowGraphConstructor -from .graph_constructor import GraphConstructor -from .query_runner import QueryRunner from graphdatascience.server_version.compatible_with import ( IncompatibleServerVersionError, ) +from .arrow_graph_constructor import ArrowGraphConstructor +from .gds_arrow_client import GdsArrowClient +from .graph_constructor import GraphConstructor +from .query_runner import QueryRunner +from ..call_parameters import CallParameters +from ..server_version.server_version import ServerVersion class ArrowQueryRunner(QueryRunner): @@ -33,61 +27,33 @@ def create( tls_root_certs: Optional[bytes] = None, connection_string_override: Optional[str] = None, ) -> QueryRunner: - arrow_info = ( - fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict() + gds_arrow_client = GdsArrowClient.create( + fallback_query_runner, + auth, + encrypted, + disable_server_verification, + tls_root_certs, + connection_string_override, ) - server_version = fallback_query_runner.server_version() - connection_string: str - if connection_string_override is not None: - connection_string = connection_string_override - else: - connection_string = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) - arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", [])) - if arrow_info["running"]: + if gds_arrow_client is not None: return ArrowQueryRunner( - connection_string, + gds_arrow_client, fallback_query_runner, - server_version, - auth, - encrypted, - disable_server_verification, - tls_root_certs, - arrow_endpoint_version, + 
fallback_query_runner.server_version() ) else: return fallback_query_runner def __init__( self, - uri: str, + gds_arrow_client: GdsArrowClient, fallback_query_runner: QueryRunner, server_version: ServerVersion, - auth: Optional[Tuple[str, str]] = None, - encrypted: bool = False, - disable_server_verification: bool = False, - tls_root_certs: Optional[bytes] = None, - arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA, ): self._fallback_query_runner = fallback_query_runner + self._gds_arrow_client = gds_arrow_client self._server_version = server_version - self._arrow_endpoint_version = arrow_endpoint_version - - host, port_string = uri.split(":") - - location = ( - flight.Location.for_grpc_tls(host, int(port_string)) - if encrypted - else flight.Location.for_grpc_tcp(host, int(port_string)) - ) - - client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification} - if auth: - client_options["middleware"] = [AuthFactory(auth)] - if tls_root_certs: - client_options["tls_root_certs"] = tls_root_certs - - self._flight_client = flight.FlightClient(location, **client_options) def warn_about_deprecation(self, old_endpoint: str, new_endpoint: str) -> None: warnings.warn( @@ -140,7 +106,7 @@ def call_procedure( old_endpoint="gds.graph.streamNodeProperty", new_endpoint="gds.graph.nodeProperty.stream" ) - return self._run_arrow_property_get(graph_name, endpoint, config) + return self._gds_arrow_client.get_property(self.database(), graph_name, endpoint, config) elif ( old_endpoint := ("gds.graph.streamNodeProperties" == endpoint) ) or "gds.graph.nodeProperties.stream" == endpoint: @@ -159,7 +125,8 @@ def call_procedure( self.warn_about_deprecation( old_endpoint="gds.graph.streamNodeProperties", new_endpoint="gds.graph.nodeProperties.stream" ) - return self._run_arrow_property_get( + return self._gds_arrow_client.get_property( + self.database(), graph_name, endpoint, config, @@ -180,7 +147,8 @@ def call_procedure( 
old_endpoint="gds.graph.streamRelationshipProperty", new_endpoint="gds.graph.relationshipProperty.stream", ) - return self._run_arrow_property_get( + return self._gds_arrow_client.get_property( + self.database(), graph_name, endpoint, {"relationship_property": property_name, "relationship_types": relationship_types}, @@ -202,7 +170,8 @@ def call_procedure( new_endpoint="gds.graph.relationshipProperties.stream", ) - return self._run_arrow_property_get( + return self._gds_arrow_client.get_property( + self.database(), graph_name, endpoint, {"relationship_properties": property_names, "relationship_types": relationship_types}, @@ -229,7 +198,7 @@ def call_procedure( new_endpoint="gds.graph.relationships.stream", ) - return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types}) + return self._gds_arrow_client.get_property(self.database(), graph_name, endpoint, {"relationship_types": relationship_types}) return self._fallback_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) @@ -259,52 +228,11 @@ def last_bookmarks(self) -> Optional[Any]: def close(self) -> None: self._fallback_query_runner.close() - # PyArrow 7 did not expose a close method yet - if hasattr(self._flight_client, "close"): - self._flight_client.close() + self._gds_arrow_client.close() def fallback_query_runner(self) -> QueryRunner: return self._fallback_query_runner - def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configuration: Dict[str, Any]) -> DataFrame: - if not self.database(): - raise ValueError( - "For this call you must have explicitly specified a valid Neo4j database to execute on, " - "using `GraphDataScience.set_database`." 
- ) - - payload = { - "database_name": self.database(), - "graph_name": graph_name, - "procedure_name": procedure_name, - "configuration": configuration, - } - - if self._arrow_endpoint_version == ArrowEndpointVersion.V1: - payload = { - "name": "GET_COMMAND", - "version": ArrowEndpointVersion.V1.version(), - "body": payload, - } - - ticket = flight.Ticket(json.dumps(payload).encode("utf-8")) - get = self._flight_client.do_get(ticket) - arrow_table = get.read_all() - - if configuration.get("list_node_labels", False): - # GDS 2.5 had an inconsistent naming of the node labels column - new_colum_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names] - arrow_table = arrow_table.rename_columns(new_colum_names) - - # Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - message=r"Passing a BlockManager to DataFrame is deprecated", - ) - - return self._sanitize_arrow_table(arrow_table).to_pandas() # type: ignore - def create_graph_constructor( self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]] ) -> GraphConstructor: @@ -318,86 +246,8 @@ def create_graph_constructor( return ArrowGraphConstructor( database, graph_name, - self._flight_client, + self._gds_arrow_client, concurrency, - self._arrow_endpoint_version, undirected_relationship_types, ) - def _sanitize_arrow_table(self, arrow_table: Table) -> Table: - # empty columns cannot be used to build a chunked_array in pyarrow - if len(arrow_table) == 0: - return arrow_table - - dict_encoded_fields = [ - (idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type) - ] - for idx, field in dict_encoded_fields: - try: - field.type.to_pandas_dtype() - except NotImplementedError: - # we need to decode the dictionary column before transforming to pandas - if isinstance(arrow_table[field.name], ChunkedArray): - decoded_col = 
chunked_array([chunk.dictionary_decode() for chunk in arrow_table[field.name].chunks]) - else: - decoded_col = arrow_table[field.name].dictionary_decode() - arrow_table = arrow_table.set_column(idx, field.name, decoded_col) - return arrow_table - - -class AuthFactory(ClientMiddlewareFactory): # type: ignore - def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._auth = auth - self._token: Optional[str] = None - self._token_timestamp = 0 - - def start_call(self, info: Any) -> "AuthMiddleware": - return AuthMiddleware(self) - - def token(self) -> Optional[str]: - # check whether the token is older than 10 minutes. If so, reset it. - if self._token and int(time.time()) - self._token_timestamp > 600: - self._token = None - - return self._token - - def set_token(self, token: str) -> None: - self._token = token - self._token_timestamp = int(time.time()) - - @property - def auth(self) -> Tuple[str, str]: - return self._auth - - -class AuthMiddleware(ClientMiddleware): # type: ignore - def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._factory = factory - - def received_headers(self, headers: Dict[str, Any]) -> None: - auth_header = headers.get("authorization", None) - if not auth_header: - return - - # the result is always a list - header_value = auth_header[0] - - if not isinstance(header_value, str): - raise ValueError(f"Incompatible header value received from server: `{header_value}`") - - auth_type, token = header_value.split(" ", 1) - if auth_type == "Bearer": - self._factory.set_token(token) - - def sending_headers(self) -> Dict[str, str]: - token = self._factory.token() - if not token: - username, password = self._factory.auth - auth_token = f"{username}:{password}" - auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII") - # There seems to be a bug, `authorization` must be lower key - return 
{"authorization": auth_token} - else: - return {"authorization": "Bearer " + token} diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index c56149600..c6697f2c3 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -7,12 +7,12 @@ from pandas import DataFrame from pyarrow import flight, Table, ChunkedArray, chunked_array, Schema -from pyarrow.types import is_dictionary from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory +from pyarrow.types import is_dictionary +from .arrow_endpoint_version import ArrowEndpointVersion from .query_runner import QueryRunner from ..server_version.server_version import ServerVersion -from .arrow_endpoint_version import ArrowEndpointVersion class GdsArrowClient(ABC): @@ -107,7 +107,7 @@ def get_property(self, database: str, graph_name: str, procedure_name: str, conf if self._arrow_endpoint_version == ArrowEndpointVersion.V1: payload = { - "name": "GET_MESSAGE", + "name": "GET_COMMAND", "version": ArrowEndpointVersion.V1.version(), "body": payload, } @@ -164,6 +164,10 @@ def _versioned_flight_descriptor(self, flight_descriptor: Dict[str, Any]) -> Dic @staticmethod def _sanitize_arrow_table(arrow_table: Table) -> Table: + # empty columns cannot be used to build a chunked_array in pyarrow + if len(arrow_table) == 0: + return arrow_table + dict_encoded_fields = [ (idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type) ] @@ -209,10 +213,17 @@ def _set_token(self, token: str) -> None: self._token_timestamp = int(time.time()) def received_headers(self, headers: Dict[str, Any]) -> None: - auth_header: str = headers.get("Authorization", None) + auth_header = headers.get("authorization", None) if not auth_header: return - [auth_type, token] = auth_header.split(" ", 1) + + # the result is always a list + header_value = auth_header[0] + + if not isinstance(header_value, 
str): + raise ValueError(f"Incompatible header value received from server: `{header_value}`") + + auth_type, token = header_value.split(" ", 1) if auth_type == "Bearer": self._set_token(token) diff --git a/graphdatascience/tests/unit/test_arrow_runner.py b/graphdatascience/tests/unit/test_arrow_runner.py index af8284724..d8ed51125 100644 --- a/graphdatascience/tests/unit/test_arrow_runner.py +++ b/graphdatascience/tests/unit/test_arrow_runner.py @@ -42,31 +42,3 @@ def test_create_with_provided_connection(runner: CollectingQueryRunner) -> None: with pytest.raises(FlightUnavailableError, match=".+ failed to connect .+ ipv4:127.0.0.1:4321: .+"): arrow_runner._flight_client.list_actions() - - -def test_auth_middleware() -> None: - factory = AuthFactory(("user", "password")) - middleware = AuthMiddleware(factory) - - first_header = middleware.sending_headers() - assert first_header == {"authorization": "Basic dXNlcjpwYXNzd29yZA=="} - - middleware.received_headers({"authorization": ["Bearer token"]}) - assert factory._token == "token" - - second_header = middleware.sending_headers() - assert second_header == {"authorization": "Bearer token"} - - middleware.received_headers({}) - assert factory._token == "token" - - second_header = middleware.sending_headers() - assert second_header == {"authorization": "Bearer token"} - - -def test_auth_middleware_bad_headers() -> None: - factory = AuthFactory(("user", "password")) - middleware = AuthMiddleware(factory) - - with pytest.raises(ValueError, match="Incompatible header value received from server: `12342`"): - middleware.received_headers({"authorization": [12342]}) diff --git a/graphdatascience/tests/unit/test_gds_arrow_client.py b/graphdatascience/tests/unit/test_gds_arrow_client.py new file mode 100644 index 000000000..5e881179b --- /dev/null +++ b/graphdatascience/tests/unit/test_gds_arrow_client.py @@ -0,0 +1,29 @@ +import pytest + +from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware + + +def 
test_auth_middleware() -> None: + middleware = AuthMiddleware(("user", "password")) + + first_header = middleware.sending_headers() + assert first_header == {"authorization": "Basic dXNlcjpwYXNzd29yZA=="} + + middleware.received_headers({"authorization": ["Bearer token"]}) + assert middleware._token == "token" + + second_header = middleware.sending_headers() + assert second_header == {"authorization": "Bearer token"} + + middleware.received_headers({}) + assert middleware._token == "token" + + second_header = middleware.sending_headers() + assert second_header == {"authorization": "Bearer token"} + + +def test_auth_middleware_bad_headers() -> None: + middleware = AuthMiddleware(("user", "password")) + + with pytest.raises(ValueError, match="Incompatible header value received from server: `12342`"): + middleware.received_headers({"authorization": [12342]}) From fa3474c0d621b8b05449eb647a93d010c846f8c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Mar 2024 17:22:52 +0100 Subject: [PATCH 03/22] Adapt client to push based remote projection --- .../graph/graph_remote_proc_runner.py | 1 - .../graph/graph_remote_project_runner.py | 30 +++--- .../aura_db_arrow_query_runner.py | 101 ++++-------------- .../tests/unit/test_arrow_runner.py | 4 +- graphdatascience/tests/unit/test_graph_ops.py | 54 +++------- 5 files changed, 48 insertions(+), 142 deletions(-) diff --git a/graphdatascience/graph/graph_remote_proc_runner.py b/graphdatascience/graph/graph_remote_proc_runner.py index a08a268dd..cdb1e8158 100644 --- a/graphdatascience/graph/graph_remote_proc_runner.py +++ b/graphdatascience/graph/graph_remote_proc_runner.py @@ -5,5 +5,4 @@ class GraphRemoteProcRunner(BaseGraphProcRunner): @property def project(self) -> GraphProjectRemoteRunner: - self._namespace += ".project.remoteDb" return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version) diff --git a/graphdatascience/graph/graph_remote_project_runner.py 
b/graphdatascience/graph/graph_remote_project_runner.py index 26fbeb687..59fe45387 100644 --- a/graphdatascience/graph/graph_remote_project_runner.py +++ b/graphdatascience/graph/graph_remote_project_runner.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Any +from typing import Any, List, Optional from ..error.illegal_attr_checker import IllegalAttrChecker +from ..query_runner.aura_db_arrow_query_runner import AuraDbArrowQueryRunner from ..server_version.compatible_with import compatible_with from .graph_object import Graph from graphdatascience.call_parameters import CallParameters @@ -11,28 +12,23 @@ class GraphProjectRemoteRunner(IllegalAttrChecker): - _SCHEMA_KEYS = ["nodePropertySchema", "relationshipPropertySchema"] + @compatible_with("project", min_inclusive=ServerVersion(2, 7, 0)) + def __call__(self, graph_name: str, query: str, concurrency: int = 4, undirected_relationship_types: Optional[List[str]]=None, inverse_indexed_relationship_types: Optional[List[str]]=None) -> GraphCreateResult: + if inverse_indexed_relationship_types is None: + inverse_indexed_relationship_types = [] + if undirected_relationship_types is None: + undirected_relationship_types = [] - @compatible_with("project", min_inclusive=ServerVersion(2, 6, 0)) - def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult: - placeholder = "<>" # host and token will be added by query runner - self.map_property_types(config) params = CallParameters( graph_name=graph_name, query=query, - token=placeholder, - host=placeholder, - remote_database=self._query_runner.database(), - config=config, + concurrency=concurrency, + undirected_relationship_types=undirected_relationship_types, + inverse_indexed_relationship_types=inverse_indexed_relationship_types, ) + result = self._query_runner.call_procedure( - endpoint=self._namespace, + endpoint=AuraDbArrowQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME, params=params, ).squeeze() return 
GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result) - - @staticmethod - def map_property_types(config: dict[str, Any]) -> None: - for key in GraphProjectRemoteRunner._SCHEMA_KEYS: - if key in config: - config[key] = {k: v.value for k, v in config[key].items()} diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index c301f10cc..c1cf90677 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -4,6 +4,7 @@ from pyarrow import flight from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory +from .gds_arrow_client import GdsArrowClient from ..call_parameters import CallParameters from ..session.dbms_connection_info import DbmsConnectionInfo from .query_runner import QueryRunner @@ -12,44 +13,24 @@ class AuraDbArrowQueryRunner(QueryRunner): - GDS_REMOTE_PROJECTION_PROC_NAME = "gds.graph.project.remoteDb" + GDS_REMOTE_PROJECTION_PROC_NAME = "gds.arrow.project" def __init__( self, gds_query_runner: QueryRunner, db_query_runner: QueryRunner, encrypted: bool, - aura_db_connection_info: DbmsConnectionInfo, + gds_connection_info: DbmsConnectionInfo, ): self._gds_query_runner = gds_query_runner self._db_query_runner = db_query_runner - self._auth = aura_db_connection_info.auth() + self._gds_connection_info = gds_connection_info - arrow_info: "Series[Any]" = db_query_runner.call_procedure( - endpoint="internal.arrow.status", custom_error=False - ).squeeze() - - if not arrow_info.get("running"): - raise RuntimeError(f"The Arrow Server is not running at `{aura_db_connection_info.uri}`") - listen_address: Optional[str] = arrow_info.get("advertisedListenAddress") # type: ignore - if not listen_address: - raise ConnectionError("Did not retrieve connection info from database") - - host, port_string = listen_address.split(":") - - self._auth_pair_middleware = 
AuthPairInterceptingMiddleware() - client_options: Dict[str, Any] = { - "middleware": [AuthPairInterceptingMiddlewareFactory(self._auth_pair_middleware)], - "disable_server_verification": True, - } - - self._encrypted = encrypted - location = ( - flight.Location.for_grpc_tls(host, int(port_string)) - if self._encrypted - else flight.Location.for_grpc_tcp(host, int(port_string)) + self._gds_arrow_client = GdsArrowClient.create( + gds_query_runner, + auth=self._gds_connection_info.auth(), + encrypted=encrypted ) - self._client = flight.FlightClient(location, **client_options) def run_cypher( self, @@ -73,12 +54,18 @@ def call_procedure( params = CallParameters() if AuraDbArrowQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME == endpoint: - token, aura_db_arrow_endpoint = self._get_or_request_auth_pair() - params["token"] = token - params["host"] = aura_db_arrow_endpoint - params["config"]["useEncryption"] = self._encrypted + host, port = self._gds_arrow_client.connection_info() + token = self._gds_arrow_client.get_or_request_token() + params["arrowConfiguration"] = { + "host": host, + "port": port, + "token": token, + } + + return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) elif ".write" in endpoint and self.is_remote_projected_graph(params["graph_name"]): + raise "todo" token, aura_db_arrow_endpoint = self._get_or_request_auth_pair() host, port_string = aura_db_arrow_endpoint.split(":") params["config"]["arrowConnectionInfo"] = { @@ -128,57 +115,7 @@ def create_graph_constructor( return self._gds_query_runner.create_graph_constructor(graph_name, concurrency, undirected_relationship_types) def close(self) -> None: - self._client.close() + self._gds_arrow_client.close() self._gds_query_runner.close() self._db_query_runner.close() - def _get_or_request_auth_pair(self) -> Tuple[str, str]: - self._client.authenticate_basic_token(self._auth[0], self._auth[1]) - return (self._auth_pair_middleware.token(), 
self._auth_pair_middleware.endpoint()) - - -class AuthPairInterceptingMiddlewareFactory(ClientMiddlewareFactory): # type: ignore - def __init__(self, middleware: "AuthPairInterceptingMiddleware", *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._middleware = middleware - - def start_call(self, info: Any) -> "AuthPairInterceptingMiddleware": - return self._middleware - - -class AuthPairInterceptingMiddleware(ClientMiddleware): # type: ignore - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - - def received_headers(self, headers: Dict[str, Any]) -> None: - auth_header = headers.get("authorization") - auth_type, token = self._read_auth_header(auth_header) - if auth_type == "Bearer": - self._token = token - - self._arrow_address = self._read_address_header(headers.get("arrowpluginaddress")) - - def sending_headers(self) -> Dict[str, str]: - return {} - - def token(self) -> str: - return self._token - - def endpoint(self) -> str: - return self._arrow_address - - def _read_auth_header(self, auth_header: Any) -> Tuple[str, str]: - if isinstance(auth_header, List): - auth_header = auth_header[0] - elif not isinstance(auth_header, str): - raise ValueError("Incompatible header format '{}'", auth_header) - - auth_type, token = auth_header.split(" ", 1) - return (str(auth_type), str(token)) - - def _read_address_header(self, address_header: Any) -> str: - if isinstance(address_header, List): - return str(address_header[0]) - if isinstance(address_header, str): - return address_header - raise ValueError("Incompatible header format '{}'", address_header) diff --git a/graphdatascience/tests/unit/test_arrow_runner.py b/graphdatascience/tests/unit/test_arrow_runner.py index d8ed51125..e92b198ad 100644 --- a/graphdatascience/tests/unit/test_arrow_runner.py +++ b/graphdatascience/tests/unit/test_arrow_runner.py @@ -20,7 +20,7 @@ def test_create(runner: CollectingQueryRunner) -> None: assert 
isinstance(arrow_runner, ArrowQueryRunner) with pytest.raises(FlightUnavailableError, match=".+ failed to connect .+ ipv4:127.0.0.1:1234: .+"): - arrow_runner._flight_client.list_actions() + arrow_runner._gds_arrow_client.send_action("TEST", {}) @pytest.mark.parametrize("server_version", [ServerVersion(2, 6, 0)]) @@ -41,4 +41,4 @@ def test_create_with_provided_connection(runner: CollectingQueryRunner) -> None: assert isinstance(arrow_runner, ArrowQueryRunner) with pytest.raises(FlightUnavailableError, match=".+ failed to connect .+ ipv4:127.0.0.1:4321: .+"): - arrow_runner._flight_client.list_actions() + arrow_runner._gds_arrow_client.send_action("TEST", {}) diff --git a/graphdatascience/tests/unit/test_graph_ops.py b/graphdatascience/tests/unit/test_graph_ops.py index 62f13c031..9f11a8785 100644 --- a/graphdatascience/tests/unit/test_graph_ops.py +++ b/graphdatascience/tests/unit/test_graph_ops.py @@ -91,22 +91,21 @@ def test_project_subgraph(runner: CollectingQueryRunner, gds: GraphDataScience) } -@pytest.mark.parametrize("server_version", [ServerVersion(2, 6, 0)]) +@pytest.mark.parametrize("server_version", [ServerVersion(2, 7, 0)]) def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None: aura_gds.graph.project("g", "RETURN gds.graph.project.remote(0, 1, null)") assert ( runner.last_query() - == "CALL gds.graph.project.remoteDb($graph_name, $query, $token, $host, $remote_database, $config)" + == "CALL gds.arrow.project($graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)" ) # injection of token and host into the params is done by the actual query runner assert runner.last_params() == { "graph_name": "g", - "token": "<>", - "host": "<>", + "concurrency": 4, + "inverse_indexed_relationship_types": [], "query": "RETURN gds.graph.project.remote(0, 1, null)", - "remote_database": "neo4j", - "config": {}, + "undirected_relationship_types": [], } @@ -703,26 +702,7 @@ def 
test_graph_relationships_to_undirected(runner: CollectingQueryRunner, gds: G } -@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0)) -def test_remote_projection_on_specific_database(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None: - aura_gds.set_database("bar") - G, _ = aura_gds.graph.project("g", "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)") - - assert ( - runner.last_query() - == "CALL gds.graph.project.remoteDb($graph_name, $query, $token, $host, $remote_database, $config)" - ) - assert runner.last_params() == { - "graph_name": "g", - "token": "<>", - "host": "<>", - "query": "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)", - "remote_database": "bar", - "config": {}, - } - - -@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0)) +@pytest.mark.parametrize("server_version", [ServerVersion(2, 7, 0)]) def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None: G, _ = aura_gds.graph.project( graph_name="g", @@ -734,21 +714,19 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura relationshipProperties: {y: [1, 2]} }) """, - undirectedRelationshipTypes=["R"], - inverseIndexedRelationshipTypes=["R"], - nodePropertySchema={"x": GdsPropertyTypes.LONG}, - relationshipPropertySchema={"y": GdsPropertyTypes.LONG_ARRAY}, + concurrency=8, + undirected_relationship_types=["R"], + inverse_indexed_relationship_types=["R"], ) assert ( runner.last_query() - == "CALL gds.graph.project.remoteDb($graph_name, $query, $token, $host, $remote_database, $config)" + == "CALL gds.arrow.project($graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)" ) + assert runner.last_params() == { "graph_name": "g", - "token": "<>", - "host": "<>", - "remote_database": "neo4j", + "concurrency": 8, "query": """ MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m, { @@ -757,10 +735,6 @@ def 
test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura relationshipProperties: {y: [1, 2]} }) """, - "config": { - "undirectedRelationshipTypes": ["R"], - "inverseIndexedRelationshipTypes": ["R"], - "nodePropertySchema": {"x": "long"}, - "relationshipPropertySchema": {"y": "long[]"}, - }, + "undirected_relationship_types": ["R"], + "inverse_indexed_relationship_types": ["R"], } From 722e39793a65c087f6c6910091e716369258bd20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 27 Mar 2024 17:30:59 +0100 Subject: [PATCH 04/22] Fix most style and typing issues Co-authored-by: Mats Rydberg --- .../graph/graph_remote_project_runner.py | 11 ++++- .../query_runner/arrow_query_runner.py | 28 +++++------- .../aura_db_arrow_query_runner.py | 15 ++----- .../query_runner/gds_arrow_client.py | 44 +++++++++---------- graphdatascience/tests/unit/test_graph_ops.py | 7 +-- mypy.ini | 6 +++ 6 files changed, 56 insertions(+), 55 deletions(-) diff --git a/graphdatascience/graph/graph_remote_project_runner.py b/graphdatascience/graph/graph_remote_project_runner.py index 59fe45387..1eae82c90 100644 --- a/graphdatascience/graph/graph_remote_project_runner.py +++ b/graphdatascience/graph/graph_remote_project_runner.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional +from typing import List, Optional from ..error.illegal_attr_checker import IllegalAttrChecker from ..query_runner.aura_db_arrow_query_runner import AuraDbArrowQueryRunner @@ -13,7 +13,14 @@ class GraphProjectRemoteRunner(IllegalAttrChecker): @compatible_with("project", min_inclusive=ServerVersion(2, 7, 0)) - def __call__(self, graph_name: str, query: str, concurrency: int = 4, undirected_relationship_types: Optional[List[str]]=None, inverse_indexed_relationship_types: Optional[List[str]]=None) -> GraphCreateResult: + def __call__( + self, + graph_name: str, + query: str, + concurrency: int = 4, + undirected_relationship_types: 
Optional[List[str]] = None, + inverse_indexed_relationship_types: Optional[List[str]] = None, + ) -> GraphCreateResult: if inverse_indexed_relationship_types is None: inverse_indexed_relationship_types = [] if undirected_relationship_types is None: diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index 6e0aba51b..ad3d53ae3 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -4,17 +4,16 @@ from typing import Any, Dict, List, Optional, Tuple from pandas import DataFrame -from pyarrow.types import is_dictionary # type: ignore -from graphdatascience.server_version.compatible_with import ( - IncompatibleServerVersionError, -) +from ..call_parameters import CallParameters +from ..server_version.server_version import ServerVersion from .arrow_graph_constructor import ArrowGraphConstructor from .gds_arrow_client import GdsArrowClient from .graph_constructor import GraphConstructor from .query_runner import QueryRunner -from ..call_parameters import CallParameters -from ..server_version.server_version import ServerVersion +from graphdatascience.server_version.compatible_with import ( + IncompatibleServerVersionError, +) class ArrowQueryRunner(QueryRunner): @@ -27,6 +26,9 @@ def create( tls_root_certs: Optional[bytes] = None, connection_string_override: Optional[str] = None, ) -> QueryRunner: + if not GdsArrowClient.is_arrow_enabled(fallback_query_runner): + return fallback_query_runner + gds_arrow_client = GdsArrowClient.create( fallback_query_runner, auth, @@ -36,14 +38,7 @@ def create( connection_string_override, ) - if gds_arrow_client is not None: - return ArrowQueryRunner( - gds_arrow_client, - fallback_query_runner, - fallback_query_runner.server_version() - ) - else: - return fallback_query_runner + return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version()) def __init__( self, @@ 
-198,7 +193,9 @@ def call_procedure( new_endpoint="gds.graph.relationships.stream", ) - return self._gds_arrow_client.get_property(self.database(), graph_name, endpoint, {"relationship_types": relationship_types}) + return self._gds_arrow_client.get_property( + self.database(), graph_name, endpoint, {"relationship_types": relationship_types} + ) return self._fallback_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) @@ -250,4 +247,3 @@ def create_graph_constructor( concurrency, undirected_relationship_types, ) - diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index c1cf90677..49348a4a2 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -1,12 +1,10 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional -from pandas import DataFrame, Series -from pyarrow import flight -from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory +from pandas import DataFrame -from .gds_arrow_client import GdsArrowClient from ..call_parameters import CallParameters from ..session.dbms_connection_info import DbmsConnectionInfo +from .gds_arrow_client import GdsArrowClient from .query_runner import QueryRunner from graphdatascience.query_runner.graph_constructor import GraphConstructor from graphdatascience.server_version.server_version import ServerVersion @@ -25,11 +23,8 @@ def __init__( self._gds_query_runner = gds_query_runner self._db_query_runner = db_query_runner self._gds_connection_info = gds_connection_info - self._gds_arrow_client = GdsArrowClient.create( - gds_query_runner, - auth=self._gds_connection_info.auth(), - encrypted=encrypted + gds_query_runner, auth=self._gds_connection_info.auth(), encrypted=encrypted ) def run_cypher( @@ -65,7 +60,6 @@ def call_procedure( return 
self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) elif ".write" in endpoint and self.is_remote_projected_graph(params["graph_name"]): - raise "todo" token, aura_db_arrow_endpoint = self._get_or_request_auth_pair() host, port_string = aura_db_arrow_endpoint.split(":") params["config"]["arrowConnectionInfo"] = { @@ -118,4 +112,3 @@ def close(self) -> None: self._gds_arrow_client.close() self._gds_query_runner.close() self._db_query_runner.close() - diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index c6697f2c3..49fc87506 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -3,10 +3,11 @@ import time import warnings from abc import ABC -from typing import Optional, Tuple, Any, Dict +from typing import Any, Dict, Optional, Tuple from pandas import DataFrame from pyarrow import flight, Table, ChunkedArray, chunked_array, Schema +from pyarrow._flight import FlightStreamReader, FlightStreamWriter from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory from pyarrow.types import is_dictionary @@ -16,22 +17,21 @@ class GdsArrowClient(ABC): + @staticmethod + def is_arrow_enabled(query_runner: QueryRunner) -> bool: + arrow_info = query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict() + return not not arrow_info["running"] @staticmethod def create( - query_runner: QueryRunner, - auth: Optional[Tuple[str, str]] = None, - encrypted: bool = False, - disable_server_verification: bool = False, - tls_root_certs: Optional[bytes] = None, - connection_string_override: Optional[str] = None, - ) -> "Optional[GdsArrowClient]": - arrow_info = ( - query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict() - ) - - if not arrow_info["running"]: - return None + query_runner: QueryRunner, + auth: Optional[Tuple[str, str]] = 
None, + encrypted: bool = False, + disable_server_verification: bool = False, + tls_root_certs: Optional[bytes] = None, + connection_string_override: Optional[str] = None, + ) -> "GdsArrowClient": + arrow_info = query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict() server_version = query_runner.server_version() connection_string: str @@ -72,11 +72,7 @@ def __init__( self._port = port self._auth = auth - location = ( - flight.Location.for_grpc_tls(host, port) - if encrypted - else flight.Location.for_grpc_tcp(host, port) - ) + location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port) client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification} if auth: @@ -90,14 +86,16 @@ def __init__( def connection_info(self) -> tuple[str, int]: return self._host, self._port - def get_or_request_token(self) -> str: + def get_or_request_token(self) -> Optional[str]: if self._auth: self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) return self._auth_middleware.token() else: return "IGNORED" - def get_property(self, database: str, graph_name: str, procedure_name: str, configuration: Dict[str, Any]) -> DataFrame: + def get_property( + self, database: Optional[str], graph_name: str, procedure_name: str, configuration: Dict[str, Any] + ) -> DataFrame: payload = { "database_name": database, "graph_name": graph_name, @@ -143,7 +141,7 @@ def send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None: def start_put(self, payload: dict[str, Any], schema: Schema) -> Tuple["FlightStreamWriter", "FlightStreamReader"]: flight_descriptor = self._versioned_flight_descriptor(payload) upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8")) - return self._flight_client.do_put(upload_descriptor, schema) + return self._flight_client.do_put(upload_descriptor, schema) # type: ignore def 
close(self) -> None: self._flight_client.close() @@ -195,7 +193,7 @@ def start_call(self, info: Any) -> "AuthMiddleware": class AuthMiddleware(ClientMiddleware): # type: ignore - def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None: + def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._auth = auth self._token: Optional[str] = None diff --git a/graphdatascience/tests/unit/test_graph_ops.py b/graphdatascience/tests/unit/test_graph_ops.py index 9f11a8785..eaeb7bf57 100644 --- a/graphdatascience/tests/unit/test_graph_ops.py +++ b/graphdatascience/tests/unit/test_graph_ops.py @@ -1,7 +1,6 @@ import pytest from pandas import DataFrame -from ...session.schema import GdsPropertyTypes from .conftest import CollectingQueryRunner from graphdatascience.graph_data_science import GraphDataScience from graphdatascience.server_version.server_version import ServerVersion @@ -97,7 +96,8 @@ def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataSc assert ( runner.last_query() - == "CALL gds.arrow.project($graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)" + == "CALL gds.arrow.project(" + + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)" ) # injection of token and host into the params is done by the actual query runner assert runner.last_params() == { @@ -721,7 +721,8 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura assert ( runner.last_query() - == "CALL gds.arrow.project($graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)" + == "CALL gds.arrow.project(" + + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)" ) assert runner.last_params() == { diff --git a/mypy.ini b/mypy.ini index e4ddbba40..3800f678f 100644 --- a/mypy.ini 
+++ b/mypy.ini @@ -8,6 +8,12 @@ ignore_missing_imports = True [mypy-pyarrow.flight] ignore_missing_imports = True +[mypy-pyarrow._flight] +ignore_missing_imports = True + +[mypy-pyarrow.types] +ignore_missing_imports = True + [mypy-textdistance] ignore_missing_imports = True From 06b7455a2555aed25783a376a5302927fd354a51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 3 Apr 2024 15:01:46 +0200 Subject: [PATCH 05/22] Fix test_remote_graph_ops tests Co-authored-by: Mats Rydberg --- graphdatascience/query_runner/aura_db_arrow_query_runner.py | 3 +++ graphdatascience/query_runner/gds_arrow_client.py | 2 +- graphdatascience/tests/integration/test_remote_graph_ops.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index 49348a4a2..7fba4fa80 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -51,6 +51,9 @@ def call_procedure( if AuraDbArrowQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME == endpoint: host, port = self._gds_arrow_client.connection_info() token = self._gds_arrow_client.get_or_request_token() + if token is None: + token = "IGNORED" + params["arrowConfiguration"] = { "host": host, "port": port, diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 49fc87506..e27842756 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -228,7 +228,7 @@ def received_headers(self, headers: Dict[str, Any]) -> None: def sending_headers(self) -> Dict[str, str]: token = self.token() if not token: - username, password = self._factory.auth + username, password = self._auth auth_token = f"{username}:{password}" auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII") # 
There seems to be a bug, `authorization` must be lower key diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py index 7d9da3b7c..e51dfaa48 100644 --- a/graphdatascience/tests/integration/test_remote_graph_ops.py +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -35,7 +35,7 @@ def run_around_tests(gds_with_cloud_setup: AuraGraphDataScience) -> Generator[No @pytest.mark.cloud_architecture -@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0)) +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None: G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)") From 45655b59dd58860e28fa1cdeb325c3fe1497de82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Sun, 14 Apr 2024 10:14:22 +0200 Subject: [PATCH 06/22] Support remote writeback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sören Reichardt --- .../graph/base_graph_proc_runner.py | 3 +- .../aura_db_arrow_query_runner.py | 151 +++++++++++++++--- .../query_runner/gds_arrow_client.py | 4 +- .../integration/test_remote_graph_ops.py | 92 ++++++++++- .../tests/unit/test_arrow_runner.py | 6 +- 5 files changed, 226 insertions(+), 30 deletions(-) diff --git a/graphdatascience/graph/base_graph_proc_runner.py b/graphdatascience/graph/base_graph_proc_runner.py index 6d5531580..280e2a16b 100644 --- a/graphdatascience/graph/base_graph_proc_runner.py +++ b/graphdatascience/graph/base_graph_proc_runner.py @@ -517,7 +517,8 @@ def writeRelationship( ).squeeze() @multimethod - def removeNodeProperties(self) -> None: ... + def removeNodeProperties(self) -> None: + ... 
@removeNodeProperties.register @graph_type_check diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index 7fba4fa80..d0fe42b7a 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -1,3 +1,4 @@ +import datetime from typing import Any, Dict, List, Optional from pandas import DataFrame @@ -26,6 +27,7 @@ def __init__( self._gds_arrow_client = GdsArrowClient.create( gds_query_runner, auth=self._gds_connection_info.auth(), encrypted=encrypted ) + self._encrypted = encrypted def run_cypher( self, @@ -49,28 +51,10 @@ def call_procedure( params = CallParameters() if AuraDbArrowQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME == endpoint: - host, port = self._gds_arrow_client.connection_info() - token = self._gds_arrow_client.get_or_request_token() - if token is None: - token = "IGNORED" - - params["arrowConfiguration"] = { - "host": host, - "port": port, - "token": token, - } - - return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) + return self._remote_projection(endpoint, params, yields, database, logging, custom_error) elif ".write" in endpoint and self.is_remote_projected_graph(params["graph_name"]): - token, aura_db_arrow_endpoint = self._get_or_request_auth_pair() - host, port_string = aura_db_arrow_endpoint.split(":") - params["config"]["arrowConnectionInfo"] = { - "hostname": host, - "port": int(port_string), - "bearerToken": token, - "useEncryption": self._encrypted, - } + return self._remote_write_back(endpoint, params, yields, database, logging, custom_error) return self._gds_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) @@ -115,3 +99,130 @@ def close(self) -> None: self._gds_arrow_client.close() self._gds_query_runner.close() self._db_query_runner.close() + + def _remote_projection( + self, + endpoint: str, 
+ params: CallParameters, + yields: Optional[List[str]] = None, + database: Optional[str] = None, + logging: bool = False, + custom_error: bool = True, + ) -> DataFrame: + host, port = self._gds_arrow_client.connection_info() + token = self._gds_arrow_client.get_or_request_token() + if token is None: + token = "IGNORED" + + params["arrowConfiguration"] = { + "host": host, + "port": port, + "token": token, + "encrypted": self._encrypted, + } + return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) + + def _remote_write_back( + self, + endpoint: str, + params: CallParameters, + yields: Optional[List[str]] = None, + database: Optional[str] = None, + logging: bool = False, + custom_error: bool = True, + ) -> DataFrame: + if params["config"] is None: + params["config"] = {} + + params["config"]["writeToResultStore"] = True # type: ignore + gds_write_result = self._gds_query_runner.call_procedure( + endpoint, params, yields, database, logging, custom_error + ) + + token = self._gds_arrow_client.get_or_request_token() + host, port = self._gds_arrow_client.connection_info() + write_params = { + "graphName": params["graph_name"], + "databaseName": self._gds_query_runner.database(), + "writeConfiguration": self._extract_write_back_arguments(endpoint, params), + "arrowConfiguration": { + "host": host, + "port": port, + "token": token, + "encrypted": self._encrypted, + }, + } + + write_back_start = datetime.datetime.now() + database_write_result = self._db_query_runner.call_procedure( + "gds.arrow.write", CallParameters(write_params), yields, None, False, False + ) + write_millis = (datetime.datetime.now() - write_back_start).microseconds / 100 + gds_write_result["writeMillis"] = write_millis + + if "nodePropertiesWritten" in gds_write_result: + gds_write_result["nodePropertiesWritten"] = database_write_result["writtenNodeProperties"] + if "propertiesWritten" in gds_write_result: + gds_write_result["propertiesWritten"] = 
database_write_result["writtenNodeProperties"] + if "nodeLabelsWritten" in gds_write_result: + gds_write_result["nodeLabelsWritten"] = database_write_result["writtenNodeLabels"] + if "relationshipsWritten" in gds_write_result: + gds_write_result["relationshipsWritten"] = database_write_result["writtenRelationships"] + + return gds_write_result + + @staticmethod + def _extract_write_back_arguments(proc_name: str, params: dict[str, Any]) -> dict[str, Any]: + config = params.get("config", {}) + write_config = {} + + if "concurrency" in config: + write_config["concurrency"] = config["concurrency"] + + if "gds.shortestPath" in proc_name: + write_config["relationshipType"] = config["writeRelationshipType"] + + write_node_ids = config.get("writeNodeIds") + write_costs = config.get("writeCosts") + + if write_node_ids and write_costs: + write_config["relationshipProperties"] = ["totalCost", "nodeIds", "costs"] + elif write_node_ids: + write_config["relationshipProperties"] = ["totalCost", "nodeIds"] + elif write_costs: + write_config["relationshipProperties"] = ["totalCost", "costs"] + + elif "gds.graph." 
in proc_name: + if "gds.graph.nodeProperties.write" == proc_name: + write_config["nodeProperties"] = params["properties"] + write_config["nodeLabels"] = params["entities"] + + elif "gds.graph.nodeLabel.write" == proc_name: + write_config["nodeLabels"] = [params["node_label"]] + + elif "gds.graph.relationshipProperties.write" == proc_name: + write_config["relationshipProperties"] = params["relationship_properties"] + write_config["relationshipType"] = params["relationship_type"] + + elif "gds.graph.relationship.write" == proc_name: + if "relationship_property" in params and params["relationship_property"] != "": + write_config["relationshipProperties"] = [params["relationship_property"]] + write_config["relationshipType"] = params["relationship_type"] + + else: + raise ValueError(f"Unsupported procedure name: {proc_name}") + + else: + if "writeRelationshipType" in config: + write_config["relationshipType"] = config["writeRelationshipType"] + if "writeProperty" in config: + write_config["relationshipProperties"] = [config["writeProperty"]] + else: + if "writeProperty" in config: + write_config["nodeProperties"] = [config["writeProperty"]] + if "nodeLabels" in params: + write_config["nodeLabels"] = params["nodeLabels"] + else: + write_config["nodeLabels"] = ["*"] + + return write_config diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index e27842756..97c10bccb 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -6,14 +6,14 @@ from typing import Any, Dict, Optional, Tuple from pandas import DataFrame -from pyarrow import flight, Table, ChunkedArray, chunked_array, Schema +from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight from pyarrow._flight import FlightStreamReader, FlightStreamWriter from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory from pyarrow.types import is_dictionary +from 
..server_version.server_version import ServerVersion from .arrow_endpoint_version import ArrowEndpointVersion from .query_runner import QueryRunner -from ..server_version.server_version import ServerVersion class GdsArrowClient(ABC): diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py index e51dfaa48..02aae161d 100644 --- a/graphdatascience/tests/integration/test_remote_graph_ops.py +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -44,10 +44,98 @@ def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None: @pytest.mark.cloud_architecture -@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0)) -def test_remote_write_back(gds_with_cloud_setup: AuraGraphDataScience) -> None: +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_page_rank(gds_with_cloud_setup: AuraGraphDataScience) -> None: G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)") result = gds_with_cloud_setup.pageRank.write(G, writeProperty="score") assert result["nodePropertiesWritten"] == 3 + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_node_similarity(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)") + + result = gds_with_cloud_setup.nodeSimilarity.write( + G, writeRelationshipType="SIMILAR", writeProperty="score", similarityCutoff=0 + ) + + assert result["relationshipsWritten"] == 4 + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_node_properties(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN 
gds.graph.project.remote(n, m)") + result = gds_with_cloud_setup.pageRank.mutate(G, mutateProperty="score") + result = gds_with_cloud_setup.graph.nodeProperties.write(G, node_properties=["score"]) + + assert result["propertiesWritten"] == 3 + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_node_label(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)") + result = gds_with_cloud_setup.graph.nodeLabel.write(G, "Foo", nodeFilter="*") + + assert result["nodeLabelsWritten"] == 3 + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_relationship_topology(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project( + GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m, {relationshipType: 'FOO'})" + ) + result = gds_with_cloud_setup.graph.relationship.write(G, "FOO") + + assert result["relationshipsWritten"] == 4 + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_relationship_property(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project( + GRAPH_NAME, + "MATCH (n)-->(m) " + "RETURN gds.graph.project.remote(n, m, {relationshipType: 'FOO', relationshipProperties: {bar: 42}})", + ) + result = gds_with_cloud_setup.graph.relationship.write(G, "FOO", "bar") + + assert result["relationshipsWritten"] == 4 + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_relationship_properties(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project( + GRAPH_NAME, + "MATCH (n)-->(m) " + "RETURN 
gds.graph.project.remote(" + " n, " + " m, " + " {relationshipType: 'FOO', relationshipProperties: {bar: 42, foo: 1337}}" + ")", + ) + result = gds_with_cloud_setup.graph.relationshipProperties.write(G, "FOO", ["bar", "foo"]) + + assert result["relationshipsWritten"] == 4 + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_relationship_property_from_pathfinding_algo( + gds_with_cloud_setup: AuraGraphDataScience, +) -> None: + G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)") + + source = gds_with_cloud_setup.find_node_id(properties={"x": 1}) + target = gds_with_cloud_setup.find_node_id(properties={"x": 2}) + result = gds_with_cloud_setup.shortestPath.dijkstra.write( + G, sourceNode=source, targetNodes=target, writeRelationshipType="PATH", writeCosts=True + ) + + assert result["relationshipsWritten"] == 1 diff --git a/graphdatascience/tests/unit/test_arrow_runner.py b/graphdatascience/tests/unit/test_arrow_runner.py index e92b198ad..204e0b175 100644 --- a/graphdatascience/tests/unit/test_arrow_runner.py +++ b/graphdatascience/tests/unit/test_arrow_runner.py @@ -3,11 +3,7 @@ from pyarrow.flight import FlightUnavailableError from .conftest import CollectingQueryRunner -from graphdatascience.query_runner.arrow_query_runner import ( - ArrowQueryRunner, - AuthFactory, - AuthMiddleware, -) +from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner from graphdatascience.server_version.server_version import ServerVersion From 9a4dee95b04fb39ae38d2833b8bb960319d3a5e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 24 Apr 2024 11:20:31 +0200 Subject: [PATCH 07/22] Apply suggestions from code review --- .../graph/base_graph_proc_runner.py | 3 +- .../aura_db_arrow_query_runner.py | 39 +++++++++---------- .../query_runner/gds_arrow_client.py | 3 +- 3 files changed, 20 insertions(+), 25 
deletions(-) diff --git a/graphdatascience/graph/base_graph_proc_runner.py b/graphdatascience/graph/base_graph_proc_runner.py index 280e2a16b..6d5531580 100644 --- a/graphdatascience/graph/base_graph_proc_runner.py +++ b/graphdatascience/graph/base_graph_proc_runner.py @@ -517,8 +517,7 @@ def writeRelationship( ).squeeze() @multimethod - def removeNodeProperties(self) -> None: - ... + def removeNodeProperties(self) -> None: ... @removeNodeProperties.register @graph_type_check diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index d0fe42b7a..e5075cd10 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -109,17 +109,7 @@ def _remote_projection( logging: bool = False, custom_error: bool = True, ) -> DataFrame: - host, port = self._gds_arrow_client.connection_info() - token = self._gds_arrow_client.get_or_request_token() - if token is None: - token = "IGNORED" - - params["arrowConfiguration"] = { - "host": host, - "port": port, - "token": token, - "encrypted": self._encrypted, - } + self._inject_connection_parameters(params) return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) def _remote_write_back( @@ -139,19 +129,12 @@ def _remote_write_back( endpoint, params, yields, database, logging, custom_error ) - token = self._gds_arrow_client.get_or_request_token() - host, port = self._gds_arrow_client.connection_info() write_params = { "graphName": params["graph_name"], "databaseName": self._gds_query_runner.database(), "writeConfiguration": self._extract_write_back_arguments(endpoint, params), - "arrowConfiguration": { - "host": host, - "port": port, - "token": token, - "encrypted": self._encrypted, - }, } + self._inject_connection_parameters(write_params) write_back_start = datetime.datetime.now() database_write_result = 
self._db_query_runner.call_procedure( @@ -171,15 +154,29 @@ def _remote_write_back( return gds_write_result + def _inject_connection_parameters(self, params: dict[str, Any]) -> None: + host, port = self._gds_arrow_client.connection_info() + token = self._gds_arrow_client.get_or_request_token() + if token is None: + token = "IGNORED" + params["arrowConfiguration"] = { + "host": host, + "port": port, + "token": token, + "encrypted": self._encrypted, + } + @staticmethod def _extract_write_back_arguments(proc_name: str, params: dict[str, Any]) -> dict[str, Any]: config = params.get("config", {}) write_config = {} - if "concurrency" in config: + if "writeConcurrency" in config: + write_config["concurrency"] = config["writeConcurrency"] + elif "concurrency" in config: write_config["concurrency"] = config["concurrency"] - if "gds.shortestPath" in proc_name: + if "gds.shortestPath" in proc_name or "gds.allShortestPaths" in proc_name: write_config["relationshipType"] = config["writeRelationshipType"] write_node_ids = config.get("writeNodeIds") diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 97c10bccb..d6264505e 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -2,7 +2,6 @@ import json import time import warnings -from abc import ABC from typing import Any, Dict, Optional, Tuple from pandas import DataFrame @@ -16,7 +15,7 @@ from .query_runner import QueryRunner -class GdsArrowClient(ABC): +class GdsArrowClient: @staticmethod def is_arrow_enabled(query_runner: QueryRunner) -> bool: arrow_info = query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict() From ff7fc447b0e742d0192e5d5a887b1c91b275a596 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 24 Apr 2024 11:51:14 +0200 Subject: [PATCH 08/22] Check that database is set in GdsArrowClient * also fix type info for Python 
3.8 --- graphdatascience/query_runner/gds_arrow_client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index d6264505e..b23d1b388 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -82,7 +82,7 @@ def __init__( self._flight_client = flight.FlightClient(location, **client_options) - def connection_info(self) -> tuple[str, int]: + def connection_info(self) -> Tuple[str, int]: return self._host, self._port def get_or_request_token(self) -> Optional[str]: @@ -95,6 +95,12 @@ def get_or_request_token(self) -> Optional[str]: def get_property( self, database: Optional[str], graph_name: str, procedure_name: str, configuration: Dict[str, Any] ) -> DataFrame: + if not database: + raise ValueError( + "For this call you must have explicitly specified a valid Neo4j database to execute on, " + "using `GraphDataScience.set_database`." 
+ ) + payload = { "database_name": database, "graph_name": graph_name, From c5408d4c254d437d1b57706ede16a57ebdd11080 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Mon, 29 Apr 2024 11:57:57 +0200 Subject: [PATCH 09/22] Fix bug related to push-based writeback of path finding algos MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Max Kießling --- graphdatascience/graph/graph_entity_ops_runner.py | 2 +- graphdatascience/query_runner/aura_db_arrow_query_runner.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/graphdatascience/graph/graph_entity_ops_runner.py b/graphdatascience/graph/graph_entity_ops_runner.py index 0c1a9ba5a..fd3a79716 100644 --- a/graphdatascience/graph/graph_entity_ops_runner.py +++ b/graphdatascience/graph/graph_entity_ops_runner.py @@ -177,7 +177,7 @@ def add_property(query: str, prop: str) -> str: return reduce(add_property, db_node_properties, query_prefix) @compatible_with("write", min_inclusive=ServerVersion(2, 2, 0)) - def write(self, G: Graph, node_properties: List[str], node_labels: Strings = ["*"], **config: Any) -> "Series[Any]": + def write(self, G: Graph, node_properties: Strings, node_labels: Strings = ["*"], **config: Any) -> "Series[Any]": self._namespace += ".write" return self._handle_properties(G, node_properties, node_labels, config).squeeze() # type: ignore diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index e5075cd10..a65f6336c 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -188,6 +188,8 @@ def _extract_write_back_arguments(proc_name: str, params: dict[str, Any]) -> dic write_config["relationshipProperties"] = ["totalCost", "nodeIds"] elif write_costs: write_config["relationshipProperties"] = ["totalCost", "costs"] + else: + write_config["relationshipProperties"] = 
["totalCost"] elif "gds.graph." in proc_name: if "gds.graph.nodeProperties.write" == proc_name: From d741de664563d6e24a8d9c48bd453a49c065e80e Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Fri, 17 May 2024 14:54:58 +0200 Subject: [PATCH 10/22] Create GdsArrowClient outside AuraDbArrowQueryRunner --- .../query_runner/aura_db_arrow_query_runner.py | 8 ++------ graphdatascience/session/aura_graph_data_science.py | 6 +++++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index a65f6336c..a67d51b4b 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -4,7 +4,6 @@ from pandas import DataFrame from ..call_parameters import CallParameters -from ..session.dbms_connection_info import DbmsConnectionInfo from .gds_arrow_client import GdsArrowClient from .query_runner import QueryRunner from graphdatascience.query_runner.graph_constructor import GraphConstructor @@ -18,15 +17,12 @@ def __init__( self, gds_query_runner: QueryRunner, db_query_runner: QueryRunner, + arrow_client: GdsArrowClient, encrypted: bool, - gds_connection_info: DbmsConnectionInfo, ): self._gds_query_runner = gds_query_runner self._db_query_runner = db_query_runner - self._gds_connection_info = gds_connection_info - self._gds_arrow_client = GdsArrowClient.create( - gds_query_runner, auth=self._gds_connection_info.auth(), encrypted=encrypted - ) + self._gds_arrow_client = arrow_client self._encrypted = encrypted def run_cypher( diff --git a/graphdatascience/session/aura_graph_data_science.py b/graphdatascience/session/aura_graph_data_science.py index 34a748ba9..184c29c72 100644 --- a/graphdatascience/session/aura_graph_data_science.py +++ b/graphdatascience/session/aura_graph_data_science.py @@ -10,6 +10,7 @@ from graphdatascience.query_runner.aura_db_arrow_query_runner import ( 
AuraDbArrowQueryRunner, ) +from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from graphdatascience.server_version.server_version import ServerVersion from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo @@ -63,8 +64,11 @@ def __init__( gds_query_runner.set_database("neo4j") self._db_query_runner.set_database("neo4j") + arrow_client = GdsArrowClient.create( + gds_query_runner, aura_db_connection_info.auth(), self._db_query_runner.encrypted() + ) self._query_runner = AuraDbArrowQueryRunner( - gds_query_runner, self._db_query_runner, self._db_query_runner.encrypted(), aura_db_connection_info + gds_query_runner, self._db_query_runner, arrow_client, self._db_query_runner.encrypted() ) self._delete_fn = delete_fn From 532f8572329ce642bec949d7c6ba9922add59f96 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Fri, 17 May 2024 14:56:07 +0200 Subject: [PATCH 11/22] Rename token method it doesn't get, it always requests --- graphdatascience/query_runner/aura_db_arrow_query_runner.py | 2 +- graphdatascience/query_runner/gds_arrow_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index a67d51b4b..5bd33e65a 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -152,7 +152,7 @@ def _remote_write_back( def _inject_connection_parameters(self, params: dict[str, Any]) -> None: host, port = self._gds_arrow_client.connection_info() - token = self._gds_arrow_client.get_or_request_token() + token = self._gds_arrow_client.request_token() if token is None: token = "IGNORED" params["arrowConfiguration"] = { diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 
b23d1b388..c387cace1 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -85,7 +85,7 @@ def __init__( def connection_info(self) -> Tuple[str, int]: return self._host, self._port - def get_or_request_token(self) -> Optional[str]: + def request_token(self) -> Optional[str]: if self._auth: self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) return self._auth_middleware.token() From 9487fbc3b6fa8e451ea3c6d99bfdc203d1e90ba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Mon, 27 May 2024 16:41:15 +0200 Subject: [PATCH 12/22] Fix write back millis for remote write back --- graphdatascience/query_runner/aura_db_arrow_query_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index 5bd33e65a..814ae4145 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -1,4 +1,4 @@ -import datetime +import time from typing import Any, Dict, List, Optional from pandas import DataFrame @@ -132,11 +132,11 @@ def _remote_write_back( } self._inject_connection_parameters(write_params) - write_back_start = datetime.datetime.now() + write_back_start = time.time() database_write_result = self._db_query_runner.call_procedure( "gds.arrow.write", CallParameters(write_params), yields, None, False, False ) - write_millis = (datetime.datetime.now() - write_back_start).microseconds / 100 + write_millis = (time.time() - write_back_start) * 1000 gds_write_result["writeMillis"] = write_millis if "nodePropertiesWritten" in gds_write_result: From 0a859e2fb5daf5b788680aad965f2cfba2378d6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Mon, 27 May 2024 16:42:27 +0200 Subject: [PATCH 13/22] Fix write back of mutated node properties --- 
.../aura_db_arrow_query_runner.py | 3 +- .../integration/test_remote_graph_ops.py | 37 ++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index 814ae4145..55045ca8a 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -189,7 +189,8 @@ def _extract_write_back_arguments(proc_name: str, params: dict[str, Any]) -> dic elif "gds.graph." in proc_name: if "gds.graph.nodeProperties.write" == proc_name: - write_config["nodeProperties"] = params["properties"] + properties = params["properties"] + write_config["nodeProperties"] = properties if isinstance(properties, list) else [properties] write_config["nodeLabels"] = params["entities"] elif "gds.graph.nodeLabel.write" == proc_name: diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py index 02aae161d..d4f45af73 100644 --- a/graphdatascience/tests/integration/test_remote_graph_ops.py +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -14,9 +14,9 @@ def run_around_tests(gds_with_cloud_setup: AuraGraphDataScience) -> Generator[No gds_with_cloud_setup.run_cypher( """ CREATE - (a: Node {x: 1, y: 2, z: [42], name: "nodeA"}), - (b: Node {x: 2, y: 3, z: [1337], name: "nodeB"}), - (c: Node {x: 3, y: 4, z: [9], name: "nodeC"}), + (a: A:Node {x: 1, y: 2, z: [42], name: "nodeA"}), + (b: B:Node {x: 2, y: 3, z: [1337], name: "nodeB"}), + (c: C:Node {x: 3, y: 4, z: [9], name: "nodeC"}), (a)-[:REL {relX: 4, relY: 5}]->(b), (a)-[:REL {relX: 5, relY: 6}]->(c), (b)-[:REL {relX: 6, relY: 7}]->(c), @@ -75,6 +75,33 @@ def test_remote_write_back_node_properties(gds_with_cloud_setup: AuraGraphDataSc assert result["propertiesWritten"] == 3 +@pytest.mark.cloud_architecture 
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_node_properties_with_multiple_labels(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project( + GRAPH_NAME, + "MATCH (n)-->(m) " + "RETURN gds.graph.project.remote(n, m, {sourceNodeLabels: labels(n), targetNodeLabels: labels(m)})", + ) + result = gds_with_cloud_setup.pageRank.write(G, writeProperty="score") + + assert result["nodePropertiesWritten"] == 3 + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) +def test_remote_write_back_node_properties_with_select_labels(gds_with_cloud_setup: AuraGraphDataScience) -> None: + G, result = gds_with_cloud_setup.graph.project( + GRAPH_NAME, + "MATCH (n)-->(m) " + "RETURN gds.graph.project.remote(n, m, {sourceNodeLabels: labels(n), targetNodeLabels: labels(m)})", + ) + result = gds_with_cloud_setup.pageRank.mutate(G, mutateProperty="score") + result = gds_with_cloud_setup.graph.nodeProperties.write(G, "score", ["A"]) + + assert result["propertiesWritten"] == 1 + + @pytest.mark.cloud_architecture @pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) def test_remote_write_back_node_label(gds_with_cloud_setup: AuraGraphDataScience) -> None: @@ -108,7 +135,7 @@ def test_remote_write_back_relationship_property(gds_with_cloud_setup: AuraGraph assert result["relationshipsWritten"] == 4 -@pytest.mark.cloud_architecture +# @pytest.mark.cloud_architecture @pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) def test_remote_write_back_relationship_properties(gds_with_cloud_setup: AuraGraphDataScience) -> None: G, result = gds_with_cloud_setup.graph.project( @@ -125,7 +152,7 @@ def test_remote_write_back_relationship_properties(gds_with_cloud_setup: AuraGra assert result["relationshipsWritten"] == 4 -@pytest.mark.cloud_architecture +# @pytest.mark.cloud_architecture 
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) def test_remote_write_back_relationship_property_from_pathfinding_algo( gds_with_cloud_setup: AuraGraphDataScience, From e36aa6f471618c59ae27b43127256b740d515448 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 30 May 2024 16:08:29 +0200 Subject: [PATCH 14/22] Rename class It doesn't use Arrow on the DBMS anymore --- graphdatascience/graph/graph_remote_project_runner.py | 4 ++-- ...ura_db_arrow_query_runner.py => aura_db_query_runner.py} | 4 ++-- graphdatascience/session/aura_graph_data_science.py | 6 ++---- graphdatascience/tests/unit/conftest.py | 2 +- .../tests/unit/test_aura_db_arrow_query_runner.py | 0 5 files changed, 7 insertions(+), 9 deletions(-) rename graphdatascience/query_runner/{aura_db_arrow_query_runner.py => aura_db_query_runner.py} (98%) create mode 100644 graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py diff --git a/graphdatascience/graph/graph_remote_project_runner.py b/graphdatascience/graph/graph_remote_project_runner.py index 1eae82c90..0771b3878 100644 --- a/graphdatascience/graph/graph_remote_project_runner.py +++ b/graphdatascience/graph/graph_remote_project_runner.py @@ -3,7 +3,7 @@ from typing import List, Optional from ..error.illegal_attr_checker import IllegalAttrChecker -from ..query_runner.aura_db_arrow_query_runner import AuraDbArrowQueryRunner +from ..query_runner.aura_db_query_runner import AuraDbQueryRunner from ..server_version.compatible_with import compatible_with from .graph_object import Graph from graphdatascience.call_parameters import CallParameters @@ -35,7 +35,7 @@ def __call__( ) result = self._query_runner.call_procedure( - endpoint=AuraDbArrowQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME, + endpoint=AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME, params=params, ).squeeze() return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result) diff --git 
a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_query_runner.py similarity index 98% rename from graphdatascience/query_runner/aura_db_arrow_query_runner.py rename to graphdatascience/query_runner/aura_db_query_runner.py index 55045ca8a..c6b2f0371 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_query_runner.py @@ -10,7 +10,7 @@ from graphdatascience.server_version.server_version import ServerVersion -class AuraDbArrowQueryRunner(QueryRunner): +class AuraDbQueryRunner(QueryRunner): GDS_REMOTE_PROJECTION_PROC_NAME = "gds.arrow.project" def __init__( @@ -46,7 +46,7 @@ def call_procedure( if params is None: params = CallParameters() - if AuraDbArrowQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME == endpoint: + if AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME == endpoint: return self._remote_projection(endpoint, params, yields, database, logging, custom_error) elif ".write" in endpoint and self.is_remote_projected_graph(params["graph_name"]): diff --git a/graphdatascience/session/aura_graph_data_science.py b/graphdatascience/session/aura_graph_data_science.py index 184c29c72..c932f2899 100644 --- a/graphdatascience/session/aura_graph_data_science.py +++ b/graphdatascience/session/aura_graph_data_science.py @@ -7,9 +7,7 @@ from graphdatascience.error.uncallable_namespace import UncallableNamespace from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner -from graphdatascience.query_runner.aura_db_arrow_query_runner import ( - AuraDbArrowQueryRunner, -) +from graphdatascience.query_runner.aura_db_query_runner import AuraDbQueryRunner from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from graphdatascience.server_version.server_version import 
ServerVersion @@ -67,7 +65,7 @@ def __init__( arrow_client = GdsArrowClient.create( gds_query_runner, aura_db_connection_info.auth(), self._db_query_runner.encrypted() ) - self._query_runner = AuraDbArrowQueryRunner( + self._query_runner = AuraDbQueryRunner( gds_query_runner, self._db_query_runner, arrow_client, self._db_query_runner.encrypted() ) diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 05709d11e..391dade2f 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -116,7 +116,7 @@ def gds(runner: CollectingQueryRunner) -> Generator[GraphDataScience, None, None def aura_gds(runner: CollectingQueryRunner, mocker: MockerFixture) -> Generator[AuraGraphDataScience, None, None]: mocker.patch("graphdatascience.query_runner.neo4j_query_runner.Neo4jQueryRunner.create", return_value=runner) mocker.patch( - "graphdatascience.query_runner.aura_db_arrow_query_runner.AuraDbArrowQueryRunner.__new__", return_value=runner + "graphdatascience.query_runner.aura_db_query_runner.AuraArrowQueryRunner.__new__", return_value=runner ) mocker.patch("graphdatascience.query_runner.arrow_query_runner.ArrowQueryRunner.create", return_value=runner) aura_gds = AuraGraphDataScience( diff --git a/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py b/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py new file mode 100644 index 000000000..e69de29bb From 13a629995d678b74b1e15849856e504baf7e7203 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 30 May 2024 16:44:39 +0200 Subject: [PATCH 15/22] Do not use custom errors for DBMS queries because there's no `gds.list` proc there --- graphdatascience/query_runner/aura_db_query_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/query_runner/aura_db_query_runner.py b/graphdatascience/query_runner/aura_db_query_runner.py index c6b2f0371..733aca9aa 100644 --- 
a/graphdatascience/query_runner/aura_db_query_runner.py +++ b/graphdatascience/query_runner/aura_db_query_runner.py @@ -106,7 +106,7 @@ def _remote_projection( custom_error: bool = True, ) -> DataFrame: self._inject_connection_parameters(params) - return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) + return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, False) def _remote_write_back( self, From 38d26add32382bfac139cfd57fd9c46a14534e0a Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 30 May 2024 16:58:11 +0200 Subject: [PATCH 16/22] Remove schema parameters from projection example These are not necessary in the push-based remote projection --- doc/modules/ROOT/pages/tutorials/gds-sessions.adoc | 11 ----------- examples/gds-sessions.ipynb | 11 ----------- 2 files changed, 22 deletions(-) diff --git a/doc/modules/ROOT/pages/tutorials/gds-sessions.adoc b/doc/modules/ROOT/pages/tutorials/gds-sessions.adoc index e61965b47..c6fdd4c62 100644 --- a/doc/modules/ROOT/pages/tutorials/gds-sessions.adoc +++ b/doc/modules/ROOT/pages/tutorials/gds-sessions.adoc @@ -173,8 +173,6 @@ although we do not do that in this notebook. 
[source, python, role=no-test] ---- -from graphdatascience.session import GdsPropertyTypes - G, result = gds.graph.project( "people-and-fruits", """ @@ -201,15 +199,6 @@ G, result = gds.graph.project( relationshipType: type(rel) }) """, - nodePropertySchema={ - "age": GdsPropertyTypes.LONG, - "experience": GdsPropertyTypes.LONG, - "hipster": GdsPropertyTypes.LONG, - "tropical": GdsPropertyTypes.LONG, - "sourness": GdsPropertyTypes.DOUBLE, - "sweetness": GdsPropertyTypes.DOUBLE, - }, - relationshipPropertySchema={}, ) str(G) diff --git a/examples/gds-sessions.ipynb b/examples/gds-sessions.ipynb index fc717ef90..31bccc2f0 100644 --- a/examples/gds-sessions.ipynb +++ b/examples/gds-sessions.ipynb @@ -236,8 +236,6 @@ "metadata": {}, "outputs": [], "source": [ - "from graphdatascience.session import GdsPropertyTypes\n", - "\n", "G, result = gds.graph.project(\n", " \"people-and-fruits\",\n", " \"\"\"\n", @@ -264,15 +262,6 @@ " relationshipType: type(rel)\n", " })\n", " \"\"\",\n", - " nodePropertySchema={\n", - " \"age\": GdsPropertyTypes.LONG,\n", - " \"experience\": GdsPropertyTypes.LONG,\n", - " \"hipster\": GdsPropertyTypes.LONG,\n", - " \"tropical\": GdsPropertyTypes.LONG,\n", - " \"sourness\": GdsPropertyTypes.DOUBLE,\n", - " \"sweetness\": GdsPropertyTypes.DOUBLE,\n", - " },\n", - " relationshipPropertySchema={},\n", ")\n", "\n", "str(G)" From 2179f35208c594e51f7b60d88ff1136934866964 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 30 May 2024 17:23:32 +0200 Subject: [PATCH 17/22] Remove schema parameters from changelog --- changelog.md | 1 - 1 file changed, 1 deletion(-) diff --git a/changelog.md b/changelog.md index 0bda93aaa..f8ca9d876 100644 --- a/changelog.md +++ b/changelog.md @@ -8,7 +8,6 @@ * Add the new concept of GDS Sessions, used to manage GDS computations in Aura, based on data from an AuraDB instance. * Add a new `gds.graph.project` endpoint to project graphs from AuraDB instances to GDS sessions. 
- * `nodePropertySchema` and `relationshipPropertySchema` can be used to optimise remote projections. * Add a new top-level class `GdsSessions` to manage GDS sessions in Aura. * `GdsSessions` support `get_or_create()`, `list()`, and `delete()`. * Creating a new session supports various sizes. From 26f2c23176a712fa09d80bf66a551efdd20bf043 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Fri, 31 May 2024 11:42:50 +0200 Subject: [PATCH 18/22] Use typing.Dict for Python 3.8 compatibility --- .../query_runner/aura_db_query_runner.py | 4 ++-- .../query_runner/gds_arrow_client.py | 2 +- graphdatascience/session/aura_api.py | 6 +++--- graphdatascience/session/aura_api_responses.py | 16 ++++++++-------- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/graphdatascience/query_runner/aura_db_query_runner.py b/graphdatascience/query_runner/aura_db_query_runner.py index 733aca9aa..9ac2ad7bb 100644 --- a/graphdatascience/query_runner/aura_db_query_runner.py +++ b/graphdatascience/query_runner/aura_db_query_runner.py @@ -150,7 +150,7 @@ def _remote_write_back( return gds_write_result - def _inject_connection_parameters(self, params: dict[str, Any]) -> None: + def _inject_connection_parameters(self, params: Dict[str, Any]) -> None: host, port = self._gds_arrow_client.connection_info() token = self._gds_arrow_client.request_token() if token is None: @@ -163,7 +163,7 @@ def _inject_connection_parameters(self, params: dict[str, Any]) -> None: } @staticmethod - def _extract_write_back_arguments(proc_name: str, params: dict[str, Any]) -> dict[str, Any]: + def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dict[str, Any]: config = params.get("config", {}) write_config = {} diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index c387cace1..378dd76df 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -143,7 +143,7 
@@ def send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None: json.loads(collected_result[0].body.to_pybytes().decode()) - def start_put(self, payload: dict[str, Any], schema: Schema) -> Tuple["FlightStreamWriter", "FlightStreamReader"]: + def start_put(self, payload: Dict[str, Any], schema: Schema) -> Tuple["FlightStreamWriter", "FlightStreamReader"]: flight_descriptor = self._versioned_flight_descriptor(payload) upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8")) return self._flight_client.do_put(upload_descriptor, schema) # type: ignore diff --git a/graphdatascience/session/aura_api.py b/graphdatascience/session/aura_api.py index 3f33e2409..4ce34d18e 100644 --- a/graphdatascience/session/aura_api.py +++ b/graphdatascience/session/aura_api.py @@ -3,7 +3,7 @@ import logging import os import time -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from urllib.parse import urlparse import requests as req @@ -28,7 +28,7 @@ class AuraAuthToken: expires_in: int token_type: str - def __init__(self, json: dict[str, Any]) -> None: + def __init__(self, json: Dict[str, Any]) -> None: self.access_token = json["access_token"] expires_in: int = json["expires_in"] self.expires_at = int(time.time()) + expires_in @@ -275,7 +275,7 @@ def tenant_details(self) -> TenantDetails: self._tenant_details = TenantDetails.from_json(response.json()["data"]) return self._tenant_details - def _build_header(self) -> dict[str, str]: + def _build_header(self) -> Dict[str, str]: return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"} def _auth_token(self) -> str: diff --git a/graphdatascience/session/aura_api_responses.py b/graphdatascience/session/aura_api_responses.py index 239086959..008a26cb1 100644 --- a/graphdatascience/session/aura_api_responses.py +++ b/graphdatascience/session/aura_api_responses.py @@ -5,7 +5,7 @@ from collections 
import defaultdict from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any, NamedTuple, Optional, Set +from typing import Any, Dict, NamedTuple, Optional, Set @dataclass(repr=True, frozen=True) @@ -20,7 +20,7 @@ class SessionDetails: created_at: datetime @classmethod - def fromJson(cls, json: dict[str, Any]) -> SessionDetails: + def fromJson(cls, json: Dict[str, Any]) -> SessionDetails: expiry_date = json.get("expiry_date") return cls( @@ -46,7 +46,7 @@ class InstanceDetails: cloud_provider: str @classmethod - def fromJson(cls, json: dict[str, Any]) -> InstanceDetails: + def fromJson(cls, json: Dict[str, Any]) -> InstanceDetails: return cls( id=json["id"], name=json["name"], @@ -64,7 +64,7 @@ class InstanceSpecificDetails(InstanceDetails): region: str @classmethod - def fromJson(cls, json: dict[str, Any]) -> InstanceSpecificDetails: + def fromJson(cls, json: Dict[str, Any]) -> InstanceSpecificDetails: return cls( id=json["id"], name=json["name"], @@ -86,7 +86,7 @@ class InstanceCreateDetails: connection_url: str @classmethod - def from_json(cls, json: dict[str, Any]) -> InstanceCreateDetails: + def from_json(cls, json: Dict[str, Any]) -> InstanceCreateDetails: fields = dataclasses.fields(cls) if any(f.name not in json for f in fields): raise RuntimeError(f"Missing required field. Expected `{[f.name for f in fields]}` but got `{json}`") @@ -101,7 +101,7 @@ class EstimationDetails: did_exceed_maximum: bool @classmethod - def from_json(cls, json: dict[str, Any]) -> EstimationDetails: + def from_json(cls, json: Dict[str, Any]) -> EstimationDetails: fields = dataclasses.fields(cls) if any(f.name not in json for f in fields): raise RuntimeError(f"Missing required field. 
Expected `{[f.name for f in fields]}` but got `{json}`") @@ -126,10 +126,10 @@ def from_connection_url(cls, connection_url: str) -> WaitResult: class TenantDetails: id: str ds_type: str - regions_per_provider: dict[str, Set[str]] + regions_per_provider: Dict[str, Set[str]] @classmethod - def from_json(cls, json: dict[str, Any]) -> TenantDetails: + def from_json(cls, json: Dict[str, Any]) -> TenantDetails: regions_per_provider = defaultdict(set) instance_types = set() ds_type = None From 6d63f9136478d779209141bf87975409e2d25849 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Fri, 31 May 2024 11:50:27 +0200 Subject: [PATCH 19/22] Subscript Tuple type with non-string arguments --- graphdatascience/query_runner/gds_arrow_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 378dd76df..b06ead80e 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -143,7 +143,7 @@ def send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None: json.loads(collected_result[0].body.to_pybytes().decode()) - def start_put(self, payload: Dict[str, Any], schema: Schema) -> Tuple["FlightStreamWriter", "FlightStreamReader"]: + def start_put(self, payload: Dict[str, Any], schema: Schema) -> Tuple[FlightStreamWriter, FlightStreamReader]: flight_descriptor = self._versioned_flight_descriptor(payload) upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8")) return self._flight_client.do_put(upload_descriptor, schema) # type: ignore From cdb4637428e922a98a7b2b7c7e167a882717b1df Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Mon, 3 Jun 2024 10:15:09 +0200 Subject: [PATCH 20/22] checkstyle --- graphdatascience/tests/unit/conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/graphdatascience/tests/unit/conftest.py 
b/graphdatascience/tests/unit/conftest.py index 391dade2f..329f521ff 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -115,9 +115,7 @@ def gds(runner: CollectingQueryRunner) -> Generator[GraphDataScience, None, None @pytest.fixture def aura_gds(runner: CollectingQueryRunner, mocker: MockerFixture) -> Generator[AuraGraphDataScience, None, None]: mocker.patch("graphdatascience.query_runner.neo4j_query_runner.Neo4jQueryRunner.create", return_value=runner) - mocker.patch( - "graphdatascience.query_runner.aura_db_query_runner.AuraArrowQueryRunner.__new__", return_value=runner - ) + mocker.patch("graphdatascience.query_runner.aura_db_query_runner.AuraArrowQueryRunner.__new__", return_value=runner) mocker.patch("graphdatascience.query_runner.arrow_query_runner.ArrowQueryRunner.create", return_value=runner) aura_gds = AuraGraphDataScience( gds_session_connection_info=DbmsConnectionInfo("address", "some", "auth"), From 415c7b26bbcade25acc72d3a0c592b2051f26430 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 4 Jun 2024 11:36:17 +0200 Subject: [PATCH 21/22] Fix monkey patching for unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Max Kießling --- graphdatascience/tests/unit/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 329f521ff..00f8c5052 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -115,8 +115,9 @@ def gds(runner: CollectingQueryRunner) -> Generator[GraphDataScience, None, None @pytest.fixture def aura_gds(runner: CollectingQueryRunner, mocker: MockerFixture) -> Generator[AuraGraphDataScience, None, None]: mocker.patch("graphdatascience.query_runner.neo4j_query_runner.Neo4jQueryRunner.create", return_value=runner) - 
mocker.patch("graphdatascience.query_runner.aura_db_query_runner.AuraArrowQueryRunner.__new__", return_value=runner) + mocker.patch("graphdatascience.query_runner.aura_db_query_runner.AuraDbQueryRunner.__new__", return_value=runner) mocker.patch("graphdatascience.query_runner.arrow_query_runner.ArrowQueryRunner.create", return_value=runner) + mocker.patch("graphdatascience.query_runner.gds_arrow_client.GdsArrowClient.create", return_value=None) aura_gds = AuraGraphDataScience( gds_session_connection_info=DbmsConnectionInfo("address", "some", "auth"), aura_db_connection_info=DbmsConnectionInfo("address", "some", "auth"), From 62eea860c0848422f5e50062e182c6e06aeb3898 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 4 Jun 2024 13:32:10 +0200 Subject: [PATCH 22/22] Uncomment test annotation for remote tests --- graphdatascience/tests/integration/test_remote_graph_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py index d4f45af73..3227e4e73 100644 --- a/graphdatascience/tests/integration/test_remote_graph_ops.py +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -135,7 +135,7 @@ def test_remote_write_back_relationship_property(gds_with_cloud_setup: AuraGraph assert result["relationshipsWritten"] == 4 -# @pytest.mark.cloud_architecture +@pytest.mark.cloud_architecture @pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) def test_remote_write_back_relationship_properties(gds_with_cloud_setup: AuraGraphDataScience) -> None: G, result = gds_with_cloud_setup.graph.project( @@ -152,7 +152,7 @@ def test_remote_write_back_relationship_properties(gds_with_cloud_setup: AuraGra assert result["relationshipsWritten"] == 4 -# @pytest.mark.cloud_architecture +@pytest.mark.cloud_architecture @pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0)) def 
test_remote_write_back_relationship_property_from_pathfinding_algo( gds_with_cloud_setup: AuraGraphDataScience,