diff --git a/changelog.md b/changelog.md
index 0bda93aaa..f8ca9d876 100644
--- a/changelog.md
+++ b/changelog.md
@@ -8,7 +8,6 @@
 * Add the new concept of GDS Sessions, used to manage GDS computations in Aura, based on data from an AuraDB instance.
 * Add a new `gds.graph.project` endpoint to project graphs from AuraDB instances to GDS sessions.
-  * `nodePropertySchema` and `relationshipPropertySchema` can be used to optimise remote projections.
 * Add a new top-level class `GdsSessions` to manage GDS sessions in Aura.
   * `GdsSessions` support `get_or_create()`, `list()`, and `delete()`.
 * Creating a new session supports various sizes.
diff --git a/doc/modules/ROOT/pages/tutorials/gds-sessions.adoc b/doc/modules/ROOT/pages/tutorials/gds-sessions.adoc
index e61965b47..c6fdd4c62 100644
--- a/doc/modules/ROOT/pages/tutorials/gds-sessions.adoc
+++ b/doc/modules/ROOT/pages/tutorials/gds-sessions.adoc
@@ -173,8 +173,6 @@ although we do not do that in this notebook.
 [source, python, role=no-test]
 ----
-from graphdatascience.session import GdsPropertyTypes
-
 G, result = gds.graph.project(
     "people-and-fruits",
     """
@@ -201,15 +199,6 @@ G, result = gds.graph.project(
         relationshipType: type(rel)
     })
     """,
-    nodePropertySchema={
-        "age": GdsPropertyTypes.LONG,
-        "experience": GdsPropertyTypes.LONG,
-        "hipster": GdsPropertyTypes.LONG,
-        "tropical": GdsPropertyTypes.LONG,
-        "sourness": GdsPropertyTypes.DOUBLE,
-        "sweetness": GdsPropertyTypes.DOUBLE,
-    },
-    relationshipPropertySchema={},
 )
 
 str(G)
diff --git a/examples/gds-sessions.ipynb b/examples/gds-sessions.ipynb
index fc717ef90..31bccc2f0 100644
--- a/examples/gds-sessions.ipynb
+++ b/examples/gds-sessions.ipynb
@@ -236,8 +236,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from graphdatascience.session import GdsPropertyTypes\n",
-    "\n",
     "G, result = gds.graph.project(\n",
     "    \"people-and-fruits\",\n",
     "    \"\"\"\n",
@@ -264,15 +262,6 @@
     "        relationshipType: type(rel)\n",
     "    })\n",
     "    \"\"\",\n",
-    "    nodePropertySchema={\n",
-    "        \"age\": GdsPropertyTypes.LONG,\n",
-    "        \"experience\": GdsPropertyTypes.LONG,\n",
-    "        \"hipster\": GdsPropertyTypes.LONG,\n",
-    "        \"tropical\": GdsPropertyTypes.LONG,\n",
-    "        \"sourness\": GdsPropertyTypes.DOUBLE,\n",
-    "        \"sweetness\": GdsPropertyTypes.DOUBLE,\n",
-    "    },\n",
-    "    relationshipPropertySchema={},\n",
     ")\n",
     "\n",
     "str(G)"
diff --git a/graphdatascience/graph/graph_entity_ops_runner.py b/graphdatascience/graph/graph_entity_ops_runner.py
index 0c1a9ba5a..fd3a79716 100644
--- a/graphdatascience/graph/graph_entity_ops_runner.py
+++ b/graphdatascience/graph/graph_entity_ops_runner.py
@@ -177,7 +177,7 @@ def add_property(query: str, prop: str) -> str:
         return reduce(add_property, db_node_properties, query_prefix)
 
     @compatible_with("write", min_inclusive=ServerVersion(2, 2, 0))
-    def write(self, G: Graph, node_properties: List[str], node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
+    def write(self, G: Graph, node_properties: Strings, node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
         self._namespace += ".write"
         return self._handle_properties(G, node_properties, node_labels, config).squeeze()  # type: ignore
diff --git a/graphdatascience/graph/graph_remote_proc_runner.py b/graphdatascience/graph/graph_remote_proc_runner.py
index a08a268dd..cdb1e8158 100644
--- a/graphdatascience/graph/graph_remote_proc_runner.py
+++ b/graphdatascience/graph/graph_remote_proc_runner.py
@@ -5,5 +5,4 @@ class GraphRemoteProcRunner(BaseGraphProcRunner):
     @property
     def project(self) -> GraphProjectRemoteRunner:
-        self._namespace += ".project.remoteDb"
         return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
diff --git a/graphdatascience/graph/graph_remote_project_runner.py b/graphdatascience/graph/graph_remote_project_runner.py
index 26fbeb687..0771b3878 100644
--- a/graphdatascience/graph/graph_remote_project_runner.py
+++ b/graphdatascience/graph/graph_remote_project_runner.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
-from typing import Any
+from typing import List, Optional
 
 from ..error.illegal_attr_checker import IllegalAttrChecker
+from ..query_runner.aura_db_query_runner import AuraDbQueryRunner
 from ..server_version.compatible_with import compatible_with
 from .graph_object import Graph
 from graphdatascience.call_parameters import CallParameters
@@ -11,28 +12,30 @@
 
 
 class GraphProjectRemoteRunner(IllegalAttrChecker):
-    _SCHEMA_KEYS = ["nodePropertySchema", "relationshipPropertySchema"]
+    @compatible_with("project", min_inclusive=ServerVersion(2, 7, 0))
+    def __call__(
+        self,
+        graph_name: str,
+        query: str,
+        concurrency: int = 4,
+        undirected_relationship_types: Optional[List[str]] = None,
+        inverse_indexed_relationship_types: Optional[List[str]] = None,
+    ) -> GraphCreateResult:
+        if inverse_indexed_relationship_types is None:
+            inverse_indexed_relationship_types = []
+        if undirected_relationship_types is None:
+            undirected_relationship_types = []
 
-    @compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
-    def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
-        placeholder = "<>"  # host and token will be added by query runner
-        self.map_property_types(config)
         params = CallParameters(
             graph_name=graph_name,
             query=query,
-            token=placeholder,
-            host=placeholder,
-            remote_database=self._query_runner.database(),
-            config=config,
+            concurrency=concurrency,
+            undirected_relationship_types=undirected_relationship_types,
+            inverse_indexed_relationship_types=inverse_indexed_relationship_types,
         )
+
         result = self._query_runner.call_procedure(
-            endpoint=self._namespace,
+            endpoint=AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME,
             params=params,
         ).squeeze()
 
         return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
-
-    @staticmethod
-    def map_property_types(config: dict[str, Any]) -> None:
-        for key in GraphProjectRemoteRunner._SCHEMA_KEYS:
-            if key in config:
-                config[key] = {k: v.value for k, v in config[key].items()}
diff --git a/graphdatascience/query_runner/arrow_graph_constructor.py b/graphdatascience/query_runner/arrow_graph_constructor.py
index f40f66d8e..2997ef8b9 100644
--- a/graphdatascience/query_runner/arrow_graph_constructor.py
+++ b/graphdatascience/query_runner/arrow_graph_constructor.py
@@ -1,19 +1,17 @@
 from __future__ import annotations
 
 import concurrent
-import json
 import math
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, NoReturn, Optional
 
 import numpy
-import pyarrow.flight as flight
 from pandas import DataFrame
 from pyarrow import Table
 from tqdm.auto import tqdm
 
-from .arrow_endpoint_version import ArrowEndpointVersion
+from .gds_arrow_client import GdsArrowClient
 from .graph_constructor import GraphConstructor
 
 
@@ -22,9 +20,8 @@ def __init__(
         self,
         database: str,
         graph_name: str,
-        flight_client: flight.FlightClient,
+        flight_client: GdsArrowClient,
         concurrency: int,
-        arrow_endpoint_version: ArrowEndpointVersion,
         undirected_relationship_types: Optional[List[str]],
         chunk_size: int = 10_000,
     ):
@@ -32,7 +29,6 @@ def __init__(
         self._concurrency = concurrency
         self._graph_name = graph_name
         self._client = flight_client
-        self._arrow_endpoint_version = arrow_endpoint_version
         self._undirected_relationship_types = (
             [] if undirected_relationship_types is None else undirected_relationship_types
         )
@@ -49,20 +45,20 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
             if self._undirected_relationship_types:
                 config["undirected_relationship_types"] = self._undirected_relationship_types
 
-            self._send_action(
+            self._client.send_action(
                 "CREATE_GRAPH",
                 config,
             )
 
             self._send_dfs(node_dfs, "node")
 
-            self._send_action("NODE_LOAD_DONE", {"name": self._graph_name})
+            self._client.send_action("NODE_LOAD_DONE", {"name": self._graph_name})
 
             self._send_dfs(relationship_dfs, "relationship")
 
-            self._send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
+            self._client.send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
         except (Exception, KeyboardInterrupt) as e:
-            self._send_action("ABORT", {"name": self._graph_name})
+            self._client.send_action("ABORT", {"name": self._graph_name})
 
             raise e
@@ -85,25 +81,12 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
 
         return partitioned_dfs
 
-    def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
-        action_type = self._versioned_action_type(action_type)
-        result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
-
-        # Consume result fully to sanity check and avoid cancelled streams
-        collected_result = list(result)
-        assert len(collected_result) == 1
-
-        json.loads(collected_result[0].body.to_pybytes().decode())
-
     def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None:
         table = Table.from_pandas(df)
         batches = table.to_batches(self._chunk_size)
         flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
-        flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
 
-        # Write schema
-        upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
-        writer, _ = self._client.do_put(upload_descriptor, table.schema)
+        writer, _ = self._client.start_put(flight_descriptor, table.schema)
 
         with writer:
             # Write table in chunks
@@ -126,17 +109,3 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
             if not future.exception():
                 continue
             raise future.exception()  # type: ignore
-
-    def _versioned_action_type(self, action_type: str) -> str:
-        return self._arrow_endpoint_version.prefix() + action_type
-
-    def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
-        return (
-            flight_descriptor
-            if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
-            else {
-                "name": "PUT_MESSAGE",
-                "version": ArrowEndpointVersion.V1.version(),
-                "body": flight_descriptor,
-            }
-        )
diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py
index 49c96b447..ad3d53ae3 100644
--- a/graphdatascience/query_runner/arrow_query_runner.py
+++ b/graphdatascience/query_runner/arrow_query_runner.py
@@ -1,21 +1,14 @@
 from __future__ import annotations
 
-import base64
-import json
-import time
 import warnings
 from typing import Any, Dict, List, Optional, Tuple
 
-import pyarrow.flight as flight
 from pandas import DataFrame
-from pyarrow import ChunkedArray, Table, chunked_array
-from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
-from pyarrow.types import is_dictionary  # type: ignore
 
 from ..call_parameters import CallParameters
 from ..server_version.server_version import ServerVersion
-from .arrow_endpoint_version import ArrowEndpointVersion
 from .arrow_graph_constructor import ArrowGraphConstructor
+from .gds_arrow_client import GdsArrowClient
 from .graph_constructor import GraphConstructor
 from .query_runner import QueryRunner
 from graphdatascience.server_version.compatible_with import (
@@ -33,61 +26,29 @@ def create(
         tls_root_certs: Optional[bytes] = None,
         connection_string_override: Optional[str] = None,
     ) -> QueryRunner:
-        arrow_info = (
-            fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
-        )
-        server_version = fallback_query_runner.server_version()
-        connection_string: str
-        if connection_string_override is not None:
-            connection_string = connection_string_override
-        else:
-            connection_string = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
-        arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
-
-        if arrow_info["running"]:
-            return ArrowQueryRunner(
-                connection_string,
-                fallback_query_runner,
-                server_version,
-                auth,
-                encrypted,
-                disable_server_verification,
-                tls_root_certs,
-                arrow_endpoint_version,
-            )
-        else:
+        if not GdsArrowClient.is_arrow_enabled(fallback_query_runner):
             return fallback_query_runner
 
+        gds_arrow_client = GdsArrowClient.create(
+            fallback_query_runner,
+            auth,
+            encrypted,
+            disable_server_verification,
+            tls_root_certs,
+            connection_string_override,
+        )
+
+        return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())
+
     def __init__(
         self,
-        uri: str,
+        gds_arrow_client: GdsArrowClient,
         fallback_query_runner: QueryRunner,
         server_version: ServerVersion,
-        auth: Optional[Tuple[str, str]] = None,
-        encrypted: bool = False,
-        disable_server_verification: bool = False,
-        tls_root_certs: Optional[bytes] = None,
-        arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
     ):
         self._fallback_query_runner = fallback_query_runner
+        self._gds_arrow_client = gds_arrow_client
         self._server_version = server_version
-        self._arrow_endpoint_version = arrow_endpoint_version
-
-        host, port_string = uri.split(":")
-
-        location = (
-            flight.Location.for_grpc_tls(host, int(port_string))
-            if encrypted
-            else flight.Location.for_grpc_tcp(host, int(port_string))
-        )
-
-        client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
-        if auth:
-            client_options["middleware"] = [AuthFactory(auth)]
-        if tls_root_certs:
-            client_options["tls_root_certs"] = tls_root_certs
-
-        self._flight_client = flight.FlightClient(location, **client_options)
 
     def warn_about_deprecation(self, old_endpoint: str, new_endpoint: str) -> None:
         warnings.warn(
@@ -140,7 +101,7 @@ def call_procedure(
                     old_endpoint="gds.graph.streamNodeProperty", new_endpoint="gds.graph.nodeProperty.stream"
                 )
 
-            return self._run_arrow_property_get(graph_name, endpoint, config)
+            return self._gds_arrow_client.get_property(self.database(), graph_name, endpoint, config)
         elif (
             old_endpoint := ("gds.graph.streamNodeProperties" == endpoint)
         ) or "gds.graph.nodeProperties.stream" == endpoint:
@@ -159,7 +120,8 @@ def call_procedure(
                 self.warn_about_deprecation(
                     old_endpoint="gds.graph.streamNodeProperties", new_endpoint="gds.graph.nodeProperties.stream"
                 )
-            return self._run_arrow_property_get(
+            return self._gds_arrow_client.get_property(
+                self.database(),
                 graph_name,
                 endpoint,
                 config,
@@ -180,7 +142,8 @@ def call_procedure(
                     old_endpoint="gds.graph.streamRelationshipProperty",
                     new_endpoint="gds.graph.relationshipProperty.stream",
                 )
-            return self._run_arrow_property_get(
+            return self._gds_arrow_client.get_property(
+                self.database(),
                 graph_name,
                 endpoint,
                 {"relationship_property": property_name, "relationship_types": relationship_types},
@@ -202,7 +165,8 @@ def call_procedure(
                     new_endpoint="gds.graph.relationshipProperties.stream",
                 )
 
-            return self._run_arrow_property_get(
+            return self._gds_arrow_client.get_property(
+                self.database(),
                 graph_name,
                 endpoint,
                 {"relationship_properties": property_names, "relationship_types": relationship_types},
@@ -229,7 +193,9 @@ def call_procedure(
                     new_endpoint="gds.graph.relationships.stream",
                 )
 
-            return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types})
+            return self._gds_arrow_client.get_property(
+                self.database(), graph_name, endpoint, {"relationship_types": relationship_types}
+            )
 
         return self._fallback_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error)
@@ -259,52 +225,11 @@ def last_bookmarks(self) -> Optional[Any]:
 
     def close(self) -> None:
         self._fallback_query_runner.close()
-        # PyArrow 7 did not expose a close method yet
-        if hasattr(self._flight_client, "close"):
-            self._flight_client.close()
+        self._gds_arrow_client.close()
 
     def fallback_query_runner(self) -> QueryRunner:
         return self._fallback_query_runner
 
-    def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configuration: Dict[str, Any]) -> DataFrame:
-        if not self.database():
-            raise ValueError(
-                "For this call you must have explicitly specified a valid Neo4j database to execute on, "
-                "using `GraphDataScience.set_database`."
-            )
-
-        payload = {
-            "database_name": self.database(),
-            "graph_name": graph_name,
-            "procedure_name": procedure_name,
-            "configuration": configuration,
-        }
-
-        if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
-            payload = {
-                "name": "GET_COMMAND",
-                "version": ArrowEndpointVersion.V1.version(),
-                "body": payload,
-            }
-
-        ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
-        get = self._flight_client.do_get(ticket)
-        arrow_table = get.read_all()
-
-        if configuration.get("list_node_labels", False):
-            # GDS 2.5 had an inconsistent naming of the node labels column
-            new_colum_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
-            arrow_table = arrow_table.rename_columns(new_colum_names)
-
-        # Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
-        warnings.filterwarnings(
-            "ignore",
-            category=DeprecationWarning,
-            message=r"Passing a BlockManager to DataFrame is deprecated",
-        )
-
-        return self._sanitize_arrow_table(arrow_table).to_pandas()  # type: ignore
-
     def create_graph_constructor(
         self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
     ) -> GraphConstructor:
@@ -318,86 +243,7 @@ def create_graph_constructor(
         return ArrowGraphConstructor(
             database,
             graph_name,
-            self._flight_client,
+            self._gds_arrow_client,
             concurrency,
-            self._arrow_endpoint_version,
             undirected_relationship_types,
         )
-
-    def _sanitize_arrow_table(self, arrow_table: Table) -> Table:
-        # empty columns cannot be used to build a chunked_array in pyarrow
-        if len(arrow_table) == 0:
-            return arrow_table
-
-        dict_encoded_fields = [
-            (idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type)
-        ]
-        for idx, field in dict_encoded_fields:
-            try:
-                field.type.to_pandas_dtype()
-            except NotImplementedError:
-                # we need to decode the dictionary column before transforming to pandas
-                if isinstance(arrow_table[field.name], ChunkedArray):
-                    decoded_col = chunked_array([chunk.dictionary_decode() for chunk in arrow_table[field.name].chunks])
-                else:
-                    decoded_col = arrow_table[field.name].dictionary_decode()
-                arrow_table = arrow_table.set_column(idx, field.name, decoded_col)
-        return arrow_table
-
-
-class AuthFactory(ClientMiddlewareFactory):  # type: ignore
-    def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        self._auth = auth
-        self._token: Optional[str] = None
-        self._token_timestamp = 0
-
-    def start_call(self, info: Any) -> "AuthMiddleware":
-        return AuthMiddleware(self)
-
-    def token(self) -> Optional[str]:
-        # check whether the token is older than 10 minutes. If so, reset it.
-        if self._token and int(time.time()) - self._token_timestamp > 600:
-            self._token = None
-
-        return self._token
-
-    def set_token(self, token: str) -> None:
-        self._token = token
-        self._token_timestamp = int(time.time())
-
-    @property
-    def auth(self) -> Tuple[str, str]:
-        return self._auth
-
-
-class AuthMiddleware(ClientMiddleware):  # type: ignore
-    def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        self._factory = factory
-
-    def received_headers(self, headers: Dict[str, Any]) -> None:
-        auth_header = headers.get("authorization", None)
-        if not auth_header:
-            return
-
-        # the result is always a list
-        header_value = auth_header[0]
-
-        if not isinstance(header_value, str):
-            raise ValueError(f"Incompatible header value received from server: `{header_value}`")
-
-        auth_type, token = header_value.split(" ", 1)
-        if auth_type == "Bearer":
-            self._factory.set_token(token)
-
-    def sending_headers(self) -> Dict[str, str]:
-        token = self._factory.token()
-        if not token:
-            username, password = self._factory.auth
-            auth_token = f"{username}:{password}"
-            auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
-            # There seems to be a bug, `authorization` must be lower key
-            return {"authorization": auth_token}
-        else:
-            return {"authorization": "Bearer " + token}
diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py
deleted file mode 100644
index c301f10cc..000000000
--- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py
+++ /dev/null
@@ -1,184 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple
-
-from pandas import DataFrame, Series
-from pyarrow import flight
-from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
-
-from ..call_parameters import CallParameters
-from ..session.dbms_connection_info import DbmsConnectionInfo
-from .query_runner import QueryRunner
-from graphdatascience.query_runner.graph_constructor import GraphConstructor
-from graphdatascience.server_version.server_version import ServerVersion
-
-
-class AuraDbArrowQueryRunner(QueryRunner):
-    GDS_REMOTE_PROJECTION_PROC_NAME = "gds.graph.project.remoteDb"
-
-    def __init__(
-        self,
-        gds_query_runner: QueryRunner,
-        db_query_runner: QueryRunner,
-        encrypted: bool,
-        aura_db_connection_info: DbmsConnectionInfo,
-    ):
-        self._gds_query_runner = gds_query_runner
-        self._db_query_runner = db_query_runner
-        self._auth = aura_db_connection_info.auth()
-
-        arrow_info: "Series[Any]" = db_query_runner.call_procedure(
-            endpoint="internal.arrow.status", custom_error=False
-        ).squeeze()
-
-        if not arrow_info.get("running"):
-            raise RuntimeError(f"The Arrow Server is not running at `{aura_db_connection_info.uri}`")
-        listen_address: Optional[str] = arrow_info.get("advertisedListenAddress")  # type: ignore
-        if not listen_address:
-            raise ConnectionError("Did not retrieve connection info from database")
-
-        host, port_string = listen_address.split(":")
-
-        self._auth_pair_middleware = AuthPairInterceptingMiddleware()
-        client_options: Dict[str, Any] = {
-            "middleware": [AuthPairInterceptingMiddlewareFactory(self._auth_pair_middleware)],
-            "disable_server_verification": True,
-        }
-
-        self._encrypted = encrypted
-        location = (
-            flight.Location.for_grpc_tls(host, int(port_string))
-            if self._encrypted
-            else flight.Location.for_grpc_tcp(host, int(port_string))
-        )
-        self._client = flight.FlightClient(location, **client_options)
-
-    def run_cypher(
-        self,
-        query: str,
-        params: Optional[Dict[str, Any]] = None,
-        database: Optional[str] = None,
-        custom_error: bool = True,
-    ) -> DataFrame:
-        return self._db_query_runner.run_cypher(query, params, database, custom_error)
-
-    def call_procedure(
-        self,
-        endpoint: str,
-        params: Optional[CallParameters] = None,
-        yields: Optional[List[str]] = None,
-        database: Optional[str] = None,
-        logging: bool = False,
-        custom_error: bool = True,
-    ) -> DataFrame:
-        if params is None:
-            params = CallParameters()
-
-        if AuraDbArrowQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME == endpoint:
-            token, aura_db_arrow_endpoint = self._get_or_request_auth_pair()
-            params["token"] = token
-            params["host"] = aura_db_arrow_endpoint
-            params["config"]["useEncryption"] = self._encrypted
-
-        elif ".write" in endpoint and self.is_remote_projected_graph(params["graph_name"]):
-            token, aura_db_arrow_endpoint = self._get_or_request_auth_pair()
-            host, port_string = aura_db_arrow_endpoint.split(":")
-            params["config"]["arrowConnectionInfo"] = {
-                "hostname": host,
-                "port": int(port_string),
-                "bearerToken": token,
-                "useEncryption": self._encrypted,
-            }
-
-        return self._gds_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error)
-
-    def is_remote_projected_graph(self, graph_name: str) -> bool:
-        database_location: str = self._gds_query_runner.call_procedure(
-            endpoint="gds.graph.list",
-            yields=["databaseLocation"],
-            params=CallParameters(graph_name=graph_name),
-        ).squeeze()
-        return database_location == "remote"
-
-    def server_version(self) -> ServerVersion:
-        return self._db_query_runner.server_version()
-
-    def driver_config(self) -> Dict[str, Any]:
-        return self._db_query_runner.driver_config()
-
-    def encrypted(self) -> bool:
-        return self._db_query_runner.encrypted()
-
-    def set_database(self, database: str) -> None:
-        self._db_query_runner.set_database(database)
-
-    def set_bookmarks(self, bookmarks: Optional[Any]) -> None:
-        self._db_query_runner.set_bookmarks(bookmarks)
-
-    def bookmarks(self) -> Optional[Any]:
-        return self._db_query_runner.bookmarks()
-
-    def last_bookmarks(self) -> Optional[Any]:
-        return self._db_query_runner.last_bookmarks()
-
-    def database(self) -> Optional[str]:
-        return self._db_query_runner.database()
-
-    def create_graph_constructor(
-        self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
-    ) -> GraphConstructor:
-        return self._gds_query_runner.create_graph_constructor(graph_name, concurrency, undirected_relationship_types)
-
-    def close(self) -> None:
-        self._client.close()
-        self._gds_query_runner.close()
-        self._db_query_runner.close()
-
-    def _get_or_request_auth_pair(self) -> Tuple[str, str]:
-        self._client.authenticate_basic_token(self._auth[0], self._auth[1])
-        return (self._auth_pair_middleware.token(), self._auth_pair_middleware.endpoint())
-
-
-class AuthPairInterceptingMiddlewareFactory(ClientMiddlewareFactory):  # type: ignore
-    def __init__(self, middleware: "AuthPairInterceptingMiddleware", *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        self._middleware = middleware
-
-    def start_call(self, info: Any) -> "AuthPairInterceptingMiddleware":
-        return self._middleware
-
-
-class AuthPairInterceptingMiddleware(ClientMiddleware):  # type: ignore
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-
-    def received_headers(self, headers: Dict[str, Any]) -> None:
-        auth_header = headers.get("authorization")
-        auth_type, token = self._read_auth_header(auth_header)
-        if auth_type == "Bearer":
-            self._token = token
-
-        self._arrow_address = self._read_address_header(headers.get("arrowpluginaddress"))
-
-    def sending_headers(self) -> Dict[str, str]:
-        return {}
-
-    def token(self) -> str:
-        return self._token
-
-    def endpoint(self) -> str:
-        return self._arrow_address
-
-    def _read_auth_header(self, auth_header: Any) -> Tuple[str, str]:
-        if isinstance(auth_header, List):
-            auth_header = auth_header[0]
-        elif not isinstance(auth_header, str):
-            raise ValueError("Incompatible header format '{}'", auth_header)
-
-        auth_type, token = auth_header.split(" ", 1)
-        return (str(auth_type), str(token))
-
-    def _read_address_header(self, address_header: Any) -> str:
-        if isinstance(address_header, List):
-            return str(address_header[0])
-        if isinstance(address_header, str):
-            return address_header
-        raise ValueError("Incompatible header format '{}'", address_header)
diff --git a/graphdatascience/query_runner/aura_db_query_runner.py b/graphdatascience/query_runner/aura_db_query_runner.py
new file mode 100644
index 000000000..9ac2ad7bb
--- /dev/null
+++ b/graphdatascience/query_runner/aura_db_query_runner.py
@@ -0,0 +1,224 @@
+import time
+from typing import Any, Dict, List, Optional
+
+from pandas import DataFrame
+
+from ..call_parameters import CallParameters
+from .gds_arrow_client import GdsArrowClient
+from .query_runner import QueryRunner
+from graphdatascience.query_runner.graph_constructor import GraphConstructor
+from graphdatascience.server_version.server_version import ServerVersion
+
+
+class AuraDbQueryRunner(QueryRunner):
+    GDS_REMOTE_PROJECTION_PROC_NAME = "gds.arrow.project"
+
+    def __init__(
+        self,
+        gds_query_runner: QueryRunner,
+        db_query_runner: QueryRunner,
+        arrow_client: GdsArrowClient,
+        encrypted: bool,
+    ):
+        self._gds_query_runner = gds_query_runner
+        self._db_query_runner = db_query_runner
+        self._gds_arrow_client = arrow_client
+        self._encrypted = encrypted
+
+    def run_cypher(
+        self,
+        query: str,
+        params: Optional[Dict[str, Any]] = None,
+        database: Optional[str] = None,
+        custom_error: bool = True,
+    ) -> DataFrame:
+        return self._db_query_runner.run_cypher(query, params, database, custom_error)
+
+    def call_procedure(
+        self,
+        endpoint: str,
+        params: Optional[CallParameters] = None,
+        yields: Optional[List[str]] = None,
+        database: Optional[str] = None,
+        logging: bool = False,
+        custom_error: bool = True,
+    ) -> DataFrame:
+        if params is None:
+            params = CallParameters()
+
+        if AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME == endpoint:
+            return self._remote_projection(endpoint, params, yields, database, logging, custom_error)
+
+ elif ".write" in endpoint and self.is_remote_projected_graph(params["graph_name"]): + return self._remote_write_back(endpoint, params, yields, database, logging, custom_error) + + return self._gds_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error) + + def is_remote_projected_graph(self, graph_name: str) -> bool: + database_location: str = self._gds_query_runner.call_procedure( + endpoint="gds.graph.list", + yields=["databaseLocation"], + params=CallParameters(graph_name=graph_name), + ).squeeze() + return database_location == "remote" + + def server_version(self) -> ServerVersion: + return self._db_query_runner.server_version() + + def driver_config(self) -> Dict[str, Any]: + return self._db_query_runner.driver_config() + + def encrypted(self) -> bool: + return self._db_query_runner.encrypted() + + def set_database(self, database: str) -> None: + self._db_query_runner.set_database(database) + + def set_bookmarks(self, bookmarks: Optional[Any]) -> None: + self._db_query_runner.set_bookmarks(bookmarks) + + def bookmarks(self) -> Optional[Any]: + return self._db_query_runner.bookmarks() + + def last_bookmarks(self) -> Optional[Any]: + return self._db_query_runner.last_bookmarks() + + def database(self) -> Optional[str]: + return self._db_query_runner.database() + + def create_graph_constructor( + self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]] + ) -> GraphConstructor: + return self._gds_query_runner.create_graph_constructor(graph_name, concurrency, undirected_relationship_types) + + def close(self) -> None: + self._gds_arrow_client.close() + self._gds_query_runner.close() + self._db_query_runner.close() + + def _remote_projection( + self, + endpoint: str, + params: CallParameters, + yields: Optional[List[str]] = None, + database: Optional[str] = None, + logging: bool = False, + custom_error: bool = True, + ) -> DataFrame: + self._inject_connection_parameters(params) + return self._db_query_runner.call_procedure(endpoint, params, yields, database, logging, False) + + def _remote_write_back( + self, + endpoint: str, + params: CallParameters, + yields: Optional[List[str]] = None, + database: Optional[str] = None, + logging: bool = False, + custom_error: bool = True, + ) -> DataFrame: + if params["config"] is None: + params["config"] = {} + + params["config"]["writeToResultStore"] = True # type: ignore + gds_write_result = self._gds_query_runner.call_procedure( + endpoint, params, yields, database, logging, custom_error + ) + + write_params = { + "graphName": params["graph_name"], + "databaseName": self._gds_query_runner.database(), + "writeConfiguration": self._extract_write_back_arguments(endpoint, params), + } + self._inject_connection_parameters(write_params) + + write_back_start = time.time() + database_write_result = self._db_query_runner.call_procedure( + "gds.arrow.write", CallParameters(write_params), yields, None, False, False + ) + write_millis = (time.time() - write_back_start) * 1000 + gds_write_result["writeMillis"] = write_millis + + if "nodePropertiesWritten" in gds_write_result: + gds_write_result["nodePropertiesWritten"] = database_write_result["writtenNodeProperties"] + if "propertiesWritten" in gds_write_result: + gds_write_result["propertiesWritten"] = database_write_result["writtenNodeProperties"] + if "nodeLabelsWritten" in gds_write_result: + gds_write_result["nodeLabelsWritten"] = database_write_result["writtenNodeLabels"] + if "relationshipsWritten" in gds_write_result: + 
gds_write_result["relationshipsWritten"] = database_write_result["writtenRelationships"] + + return gds_write_result + + def _inject_connection_parameters(self, params: Dict[str, Any]) -> None: + host, port = self._gds_arrow_client.connection_info() + token = self._gds_arrow_client.request_token() + if token is None: + token = "IGNORED" + params["arrowConfiguration"] = { + "host": host, + "port": port, + "token": token, + "encrypted": self._encrypted, + } + + @staticmethod + def _extract_write_back_arguments(proc_name: str, params: Dict[str, Any]) -> Dict[str, Any]: + config = params.get("config", {}) + write_config = {} + + if "writeConcurrency" in config: + write_config["concurrency"] = config["writeConcurrency"] + elif "concurrency" in config: + write_config["concurrency"] = config["concurrency"] + + if "gds.shortestPath" in proc_name or "gds.allShortestPaths" in proc_name: + write_config["relationshipType"] = config["writeRelationshipType"] + + write_node_ids = config.get("writeNodeIds") + write_costs = config.get("writeCosts") + + if write_node_ids and write_costs: + write_config["relationshipProperties"] = ["totalCost", "nodeIds", "costs"] + elif write_node_ids: + write_config["relationshipProperties"] = ["totalCost", "nodeIds"] + elif write_costs: + write_config["relationshipProperties"] = ["totalCost", "costs"] + else: + write_config["relationshipProperties"] = ["totalCost"] + + elif "gds.graph." in proc_name: + if "gds.graph.nodeProperties.write" == proc_name: + properties = params["properties"] + write_config["nodeProperties"] = properties if isinstance(properties, list) else [properties] + write_config["nodeLabels"] = params["entities"] + + elif "gds.graph.nodeLabel.write" == proc_name: + write_config["nodeLabels"] = [params["node_label"]] + + elif "gds.graph.relationshipProperties.write" == proc_name: + write_config["relationshipProperties"] = params["relationship_properties"] + write_config["relationshipType"] = params["relationship_type"] + + elif "gds.graph.relationship.write" == proc_name: + if "relationship_property" in params and params["relationship_property"] != "": + write_config["relationshipProperties"] = [params["relationship_property"]] + write_config["relationshipType"] = params["relationship_type"] + + else: + raise ValueError(f"Unsupported procedure name: {proc_name}") + + else: + if "writeRelationshipType" in config: + write_config["relationshipType"] = config["writeRelationshipType"] + if "writeProperty" in config: + write_config["relationshipProperties"] = [config["writeProperty"]] + else: + if "writeProperty" in config: + write_config["nodeProperties"] = [config["writeProperty"]] + if "nodeLabels" in params: + write_config["nodeLabels"] = params["nodeLabels"] + else: + write_config["nodeLabels"] = ["*"] + + return write_config diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py new file mode 100644 index 000000000..b06ead80e --- /dev/null +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -0,0 +1,242 @@ +import base64 +import json +import time +import warnings +from typing import Any, Dict, Optional, Tuple + +from pandas import DataFrame +from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight +from pyarrow._flight import FlightStreamReader, FlightStreamWriter +from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory +from pyarrow.types import is_dictionary + +from ..server_version.server_version import ServerVersion +from .arrow_endpoint_version import 
+from .query_runner import QueryRunner
+
+
+class GdsArrowClient:
+    @staticmethod
+    def is_arrow_enabled(query_runner: QueryRunner) -> bool:
+        arrow_info = query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
+        return bool(arrow_info["running"])
+
+    @staticmethod
+    def create(
+        query_runner: QueryRunner,
+        auth: Optional[Tuple[str, str]] = None,
+        encrypted: bool = False,
+        disable_server_verification: bool = False,
+        tls_root_certs: Optional[bytes] = None,
+        connection_string_override: Optional[str] = None,
+    ) -> "GdsArrowClient":
+        arrow_info = query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
+
+        server_version = query_runner.server_version()
+        connection_string: str
+        if connection_string_override is not None:
+            connection_string = connection_string_override
+        else:
+            connection_string = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
+
+        host, port = connection_string.split(":")
+
+        arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
+
+        return GdsArrowClient(
+            host,
+            int(port),
+            server_version,
+            auth,
+            encrypted,
+            disable_server_verification,
+            tls_root_certs,
+            arrow_endpoint_version,
+        )
+
+    def __init__(
+        self,
+        host: str,
+        port: int,
+        server_version: ServerVersion,
+        auth: Optional[Tuple[str, str]] = None,
+        encrypted: bool = False,
+        disable_server_verification: bool = False,
+        tls_root_certs: Optional[bytes] = None,
+        arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
+    ):
+        self._server_version = server_version
+        self._arrow_endpoint_version = arrow_endpoint_version
+        self._host = host
+        self._port = port
+        self._auth = auth
+
+        location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port)
+
+        client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
+        if auth:
+            self._auth_middleware = AuthMiddleware(auth)
+            client_options["middleware"] = [AuthFactory(self._auth_middleware)]
+        if tls_root_certs:
+            client_options["tls_root_certs"] = tls_root_certs
+
+        self._flight_client = flight.FlightClient(location, **client_options)
+
+    def connection_info(self) -> Tuple[str, int]:
+        return self._host, self._port
+
+    def request_token(self) -> Optional[str]:
+        if self._auth:
+            self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1])
+            return self._auth_middleware.token()
+        else:
+            return "IGNORED"
+
+    def get_property(
+        self, database: Optional[str], graph_name: str, procedure_name: str, configuration: Dict[str, Any]
+    ) -> DataFrame:
+        if not database:
+            raise ValueError(
+                "For this call you must have explicitly specified a valid Neo4j database to execute on, "
+                "using `GraphDataScience.set_database`."
+            )
+
+        payload = {
+            "database_name": database,
+            "graph_name": graph_name,
+            "procedure_name": procedure_name,
+            "configuration": configuration,
+        }
+
+        if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
+            payload = {
+                "name": "GET_COMMAND",
+                "version": ArrowEndpointVersion.V1.version(),
+                "body": payload,
+            }
+
+        ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
+        get = self._flight_client.do_get(ticket)
+        arrow_table = get.read_all()
+
+        if configuration.get("list_node_labels", False):
+            # GDS 2.5 had an inconsistent naming of the node labels column
+            new_column_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
+            arrow_table = arrow_table.rename_columns(new_column_names)
+
+        # Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
+        warnings.filterwarnings(
+            "ignore",
+            category=DeprecationWarning,
+            message=r"Passing a BlockManager to DataFrame is deprecated",
+        )
+
+        return self._sanitize_arrow_table(arrow_table).to_pandas()  # type: ignore
+
+    def send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
+        action_type = self._versioned_action_type(action_type)
+        result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
+
+        # Consume result fully to sanity check and avoid cancelled streams
+        collected_result = list(result)
+        assert len(collected_result) == 1
+
+        json.loads(collected_result[0].body.to_pybytes().decode())
+
+    def start_put(self, payload: Dict[str, Any], schema: Schema) -> Tuple[FlightStreamWriter, FlightStreamReader]:
+        flight_descriptor = self._versioned_flight_descriptor(payload)
+        upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
+        return self._flight_client.do_put(upload_descriptor, schema)  # type: ignore
+
+    def close(self) -> None:
+        self._flight_client.close()
+
+    def _versioned_action_type(self, action_type: str) -> str:
+        return self._arrow_endpoint_version.prefix() + action_type
+
+    def _versioned_flight_descriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
+        return (
+            flight_descriptor
+            if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
+            else {
+                "name": "PUT_MESSAGE",
+                "version": ArrowEndpointVersion.V1.version(),
+                "body": flight_descriptor,
+            }
+        )
+
+    @staticmethod
+    def _sanitize_arrow_table(arrow_table: Table) -> Table:
+        # empty columns cannot be used to build a chunked_array in pyarrow
+        if len(arrow_table) == 0:
+            return arrow_table
+
+        dict_encoded_fields = [
+            (idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type)
+        ]
+
+        for idx, field in dict_encoded_fields:
+            try:
+                field.type.to_pandas_dtype()
+            except NotImplementedError:
+                # we need to decode the dictionary column before transforming to pandas
+                if isinstance(arrow_table[field.name], ChunkedArray):
+                    decoded_col = chunked_array([chunk.dictionary_decode() for chunk in arrow_table[field.name].chunks])
+                else:
+                    decoded_col = arrow_table[field.name].dictionary_decode()
+                arrow_table = arrow_table.set_column(idx, field.name, decoded_col)
+        return arrow_table
+
+
+class AuthFactory(ClientMiddlewareFactory):  # type: ignore
+    def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._middleware = middleware
+
+    def start_call(self, info: Any) -> "AuthMiddleware":
+        return self._middleware
+
+
+class AuthMiddleware(ClientMiddleware):  # type: ignore
+    def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._auth = auth
+        self._token: Optional[str] = None
+        self._token_timestamp = 0
+
+    def token(self) -> Optional[str]:
+        # check whether the token is older than 10 minutes. If so, reset it.
+        if self._token and int(time.time()) - self._token_timestamp > 600:
+            self._token = None
+
+        return self._token
+
+    def _set_token(self, token: str) -> None:
+        self._token = token
+        self._token_timestamp = int(time.time())
+
+    def received_headers(self, headers: Dict[str, Any]) -> None:
+        auth_header = headers.get("authorization", None)
+        if not auth_header:
+            return
+
+        # the result is always a list
+        header_value = auth_header[0]
+
+        if not isinstance(header_value, str):
+            raise ValueError(f"Incompatible header value received from server: `{header_value}`")
+
+        auth_type, token = header_value.split(" ", 1)
+        if auth_type == "Bearer":
+            self._set_token(token)
+
+    def sending_headers(self) -> Dict[str, str]:
+        token = self.token()
+        if not token:
+            username, password = self._auth
+            auth_token = f"{username}:{password}"
+            auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
+            # There seems to be a server bug: the `authorization` header key must be lowercase
+            return {"authorization": auth_token}
+        else:
+            return {"authorization": "Bearer " + token}
diff --git a/graphdatascience/session/aura_api.py b/graphdatascience/session/aura_api.py
index 3f33e2409..4ce34d18e 100644
--- a/graphdatascience/session/aura_api.py
+++ b/graphdatascience/session/aura_api.py
@@ -3,7 +3,7 @@
 import logging
 import os
 import time
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse
 
 import requests as req
@@ -28,7 +28,7 @@ class AuraAuthToken:
     expires_in: int
     token_type: str
 
-    def __init__(self, json: dict[str, Any]) -> None:
+    def __init__(self, json: Dict[str, Any]) -> None:
         self.access_token = json["access_token"]
         expires_in: int = json["expires_in"]
         self.expires_at = int(time.time()) + expires_in
@@ -275,7 +275,7 @@ def tenant_details(self) -> TenantDetails:
             self._tenant_details = TenantDetails.from_json(response.json()["data"])
         return self._tenant_details
 
-    def _build_header(self) -> dict[str, str]:
+    def _build_header(self) -> Dict[str, str]:
         return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"}
 
     def _auth_token(self) -> str:
diff --git a/graphdatascience/session/aura_api_responses.py b/graphdatascience/session/aura_api_responses.py
index 239086959..008a26cb1 100644
--- a/graphdatascience/session/aura_api_responses.py
+++ b/graphdatascience/session/aura_api_responses.py
@@ -5,7 +5,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
 from datetime import datetime, timezone
-from typing import Any, NamedTuple, Optional, Set
+from typing import Any, Dict, NamedTuple, Optional, Set
 
 
 @dataclass(repr=True, frozen=True)
@@ -20,7 +20,7 @@ class SessionDetails:
     created_at: datetime
 
     @classmethod
-    def fromJson(cls, json: dict[str, Any]) -> SessionDetails:
+    def fromJson(cls, json: Dict[str, Any]) -> SessionDetails:
         expiry_date = json.get("expiry_date")
 
         return cls(
@@ -46,7 +46,7 @@ class InstanceDetails:
     cloud_provider: str
 
     @classmethod
-    def fromJson(cls, json: dict[str, Any]) -> InstanceDetails:
+    def fromJson(cls, json: Dict[str, Any]) -> InstanceDetails:
         return cls(
             id=json["id"],
             name=json["name"],
@@ -64,7 +64,7 @@ class InstanceSpecificDetails(InstanceDetails):
     region: str
 
     @classmethod
-    def fromJson(cls, json: dict[str, Any]) -> InstanceSpecificDetails:
+    def fromJson(cls, json: Dict[str, Any]) -> InstanceSpecificDetails:
         return cls(
             id=json["id"],
             name=json["name"],
@@ -86,7 +86,7 @@ class InstanceCreateDetails:
     connection_url: str
 
     @classmethod
-    def from_json(cls, json: dict[str, Any]) -> InstanceCreateDetails:
+    def from_json(cls, json: Dict[str, Any]) -> InstanceCreateDetails:
         fields = dataclasses.fields(cls)
         if any(f.name not in json for f in fields):
             raise RuntimeError(f"Missing required field. Expected `{[f.name for f in fields]}` but got `{json}`")
@@ -101,7 +101,7 @@ class EstimationDetails:
     did_exceed_maximum: bool
 
     @classmethod
-    def from_json(cls, json: dict[str, Any]) -> EstimationDetails:
+    def from_json(cls, json: Dict[str, Any]) -> EstimationDetails:
         fields = dataclasses.fields(cls)
         if any(f.name not in json for f in fields):
             raise RuntimeError(f"Missing required field. Expected `{[f.name for f in fields]}` but got `{json}`")
@@ -126,10 +126,10 @@ def from_connection_url(cls, connection_url: str) -> WaitResult:
 class TenantDetails:
     id: str
     ds_type: str
-    regions_per_provider: dict[str, Set[str]]
+    regions_per_provider: Dict[str, Set[str]]
 
     @classmethod
-    def from_json(cls, json: dict[str, Any]) -> TenantDetails:
+    def from_json(cls, json: Dict[str, Any]) -> TenantDetails:
         regions_per_provider = defaultdict(set)
         instance_types = set()
         ds_type = None
diff --git a/graphdatascience/session/aura_graph_data_science.py b/graphdatascience/session/aura_graph_data_science.py
index 34a748ba9..c932f2899 100644
--- a/graphdatascience/session/aura_graph_data_science.py
+++ b/graphdatascience/session/aura_graph_data_science.py
@@ -7,9 +7,8 @@
 from graphdatascience.error.uncallable_namespace import UncallableNamespace
 from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner
 from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
-from graphdatascience.query_runner.aura_db_arrow_query_runner import (
-    AuraDbArrowQueryRunner,
-)
+from graphdatascience.query_runner.aura_db_query_runner import AuraDbQueryRunner
+from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
 from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
 from graphdatascience.server_version.server_version import ServerVersion
 from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
@@ -63,8 +62,11 @@ def __init__(
             gds_query_runner.set_database("neo4j")
             self._db_query_runner.set_database("neo4j")
 
-        self._query_runner = AuraDbArrowQueryRunner(
-            gds_query_runner, self._db_query_runner, self._db_query_runner.encrypted(), aura_db_connection_info
+        arrow_client = GdsArrowClient.create(
+            gds_query_runner, aura_db_connection_info.auth(), self._db_query_runner.encrypted()
+        )
+        self._query_runner = AuraDbQueryRunner(
+            gds_query_runner, self._db_query_runner, arrow_client, self._db_query_runner.encrypted()
         )
 
         self._delete_fn = delete_fn
diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py
index 7d9da3b7c..3227e4e73 100644
--- a/graphdatascience/tests/integration/test_remote_graph_ops.py
+++ b/graphdatascience/tests/integration/test_remote_graph_ops.py
@@ -14,9 +14,9 @@ def run_around_tests(gds_with_cloud_setup: AuraGraphDataScience) -> Generator[No
     gds_with_cloud_setup.run_cypher(
         """
         CREATE
-        (a: Node {x: 1, y: 2, z: [42], name: "nodeA"}),
-        (b: Node {x: 2, y: 3, z: [1337], name: "nodeB"}),
-        (c: Node {x: 3, y: 4, z: [9], name: "nodeC"}),
+        (a: A:Node {x: 1, y: 2, z: [42], name: "nodeA"}),
+        (b: B:Node {x: 2, y: 3, z: [1337], name: "nodeB"}),
+        (c: C:Node {x: 3, y: 4, z: [9], name: "nodeC"}),
         (a)-[:REL {relX: 4, relY: 5}]->(b),
         (a)-[:REL {relX: 5, relY: 6}]->(c),
         (b)-[:REL {relX: 6, relY: 7}]->(c),
@@ -35,7 +35,7 @@ def run_around_tests(gds_with_cloud_setup: AuraGraphDataScience) -> Generator[No
 
 
 @pytest.mark.cloud_architecture
-@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0))
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
 def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None:
     G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
 
@@ -44,10 +44,125 @@ def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None:
 
 
 @pytest.mark.cloud_architecture
-@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0))
-def test_remote_write_back(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_page_rank(gds_with_cloud_setup: AuraGraphDataScience) -> None:
     G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
 
     result = gds_with_cloud_setup.pageRank.write(G, writeProperty="score")
 
     assert result["nodePropertiesWritten"] == 3
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_node_similarity(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+    G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
+
+    result = gds_with_cloud_setup.nodeSimilarity.write(
+        G, writeRelationshipType="SIMILAR", writeProperty="score", similarityCutoff=0
+    )
+
+    assert result["relationshipsWritten"] == 4
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_node_properties(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+    G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
+    result = gds_with_cloud_setup.pageRank.mutate(G, mutateProperty="score")
+    result = gds_with_cloud_setup.graph.nodeProperties.write(G, node_properties=["score"])
+
+    assert result["propertiesWritten"] == 3
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_node_properties_with_multiple_labels(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+    G, result = gds_with_cloud_setup.graph.project(
+        GRAPH_NAME,
+        "MATCH (n)-->(m) "
+        "RETURN gds.graph.project.remote(n, m, {sourceNodeLabels: labels(n), targetNodeLabels: labels(m)})",
+    )
+    result = gds_with_cloud_setup.pageRank.write(G, writeProperty="score")
+
+    assert result["nodePropertiesWritten"] == 3
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_node_properties_with_select_labels(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+    G, result = gds_with_cloud_setup.graph.project(
+        GRAPH_NAME,
+        "MATCH (n)-->(m) "
+        "RETURN gds.graph.project.remote(n, m, {sourceNodeLabels: labels(n), targetNodeLabels: labels(m)})",
+    )
+    result = gds_with_cloud_setup.pageRank.mutate(G, mutateProperty="score")
+    result = gds_with_cloud_setup.graph.nodeProperties.write(G, "score", ["A"])
+
+    assert result["propertiesWritten"] == 1
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_node_label(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+    G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
+    result = gds_with_cloud_setup.graph.nodeLabel.write(G, "Foo", nodeFilter="*")
+
+    assert result["nodeLabelsWritten"] == 3
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_relationship_topology(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+    G, result = gds_with_cloud_setup.graph.project(
+        GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m, {relationshipType: 'FOO'})"
+    )
+    result = gds_with_cloud_setup.graph.relationship.write(G, "FOO")
+
+    assert result["relationshipsWritten"] == 4
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_relationship_property(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+    G, result = gds_with_cloud_setup.graph.project(
+        GRAPH_NAME,
+        "MATCH (n)-->(m) "
+        "RETURN gds.graph.project.remote(n, m, {relationshipType: 'FOO', relationshipProperties: {bar: 42}})",
+    )
+    result = gds_with_cloud_setup.graph.relationship.write(G, "FOO", "bar")
+
+    assert result["relationshipsWritten"] == 4
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_relationship_properties(gds_with_cloud_setup: AuraGraphDataScience) -> None:
+    G, result = gds_with_cloud_setup.graph.project(
+        GRAPH_NAME,
+        "MATCH (n)-->(m) "
+        "RETURN gds.graph.project.remote("
+        "    n, "
+        "    m, "
+        "    {relationshipType: 'FOO', relationshipProperties: {bar: 42, foo: 1337}}"
+        ")",
+    )
+    result = gds_with_cloud_setup.graph.relationshipProperties.write(G, "FOO", ["bar", "foo"])
+
+    assert result["relationshipsWritten"] == 4
+
+
+@pytest.mark.cloud_architecture
+@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
+def test_remote_write_back_relationship_property_from_pathfinding_algo(
+    gds_with_cloud_setup: AuraGraphDataScience,
+) -> None:
+    G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
+
+    source = gds_with_cloud_setup.find_node_id(properties={"x": 1})
+    target = gds_with_cloud_setup.find_node_id(properties={"x": 2})
+    result = gds_with_cloud_setup.shortestPath.dijkstra.write(
+        G, sourceNode=source, targetNodes=target, writeRelationshipType="PATH", writeCosts=True
+    )
+
+    assert result["relationshipsWritten"] == 1
diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py
index 05709d11e..00f8c5052 100644
--- a/graphdatascience/tests/unit/conftest.py
+++ b/graphdatascience/tests/unit/conftest.py
@@ -115,10 +115,9 @@ def gds(runner: CollectingQueryRunner) -> Generator[GraphDataScience, None, None
 @pytest.fixture
 def aura_gds(runner: CollectingQueryRunner, mocker: MockerFixture) -> Generator[AuraGraphDataScience, None, None]:
     mocker.patch("graphdatascience.query_runner.neo4j_query_runner.Neo4jQueryRunner.create", return_value=runner)
-    mocker.patch(
-        "graphdatascience.query_runner.aura_db_arrow_query_runner.AuraDbArrowQueryRunner.__new__", return_value=runner
-    )
+    mocker.patch("graphdatascience.query_runner.aura_db_query_runner.AuraDbQueryRunner.__new__", return_value=runner)
         return_value=runner)
     mocker.patch("graphdatascience.query_runner.arrow_query_runner.ArrowQueryRunner.create", return_value=runner)
+    mocker.patch("graphdatascience.query_runner.gds_arrow_client.GdsArrowClient.create", return_value=None)
     aura_gds = AuraGraphDataScience(
         gds_session_connection_info=DbmsConnectionInfo("address", "some", "auth"),
         aura_db_connection_info=DbmsConnectionInfo("address", "some", "auth"),
diff --git a/graphdatascience/tests/unit/test_arrow_runner.py b/graphdatascience/tests/unit/test_arrow_runner.py
index af8284724..204e0b175 100644
--- a/graphdatascience/tests/unit/test_arrow_runner.py
+++ b/graphdatascience/tests/unit/test_arrow_runner.py
@@ -3,11 +3,7 @@
 from pyarrow.flight import FlightUnavailableError
 
 from .conftest import CollectingQueryRunner
-from graphdatascience.query_runner.arrow_query_runner import (
-    ArrowQueryRunner,
-    AuthFactory,
-    AuthMiddleware,
-)
+from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
 from graphdatascience.server_version.server_version import ServerVersion
 
 
@@ -20,7 +16,7 @@ def test_create(runner: CollectingQueryRunner) -> None:
     assert isinstance(arrow_runner, ArrowQueryRunner)
 
     with pytest.raises(FlightUnavailableError, match=".+ failed to connect .+ ipv4:127.0.0.1:1234: .+"):
-        arrow_runner._flight_client.list_actions()
+        arrow_runner._gds_arrow_client.send_action("TEST", {})
 
 
 @pytest.mark.parametrize("server_version", [ServerVersion(2, 6, 0)])
@@ -41,32 +37,4 @@ def test_create_with_provided_connection(runner: CollectingQueryRunner) -> None:
     assert isinstance(arrow_runner, ArrowQueryRunner)
 
     with pytest.raises(FlightUnavailableError, match=".+ failed to connect .+ ipv4:127.0.0.1:4321: .+"):
-        arrow_runner._flight_client.list_actions()
-
-
-def test_auth_middleware() -> None:
-    factory = AuthFactory(("user", "password"))
-    middleware = AuthMiddleware(factory)
-
-    first_header = middleware.sending_headers()
-    assert first_header == {"authorization": "Basic dXNlcjpwYXNzd29yZA=="}
-
-    middleware.received_headers({"authorization": ["Bearer token"]})
-    assert factory._token == "token"
-
-    second_header = middleware.sending_headers()
-    assert second_header == {"authorization": "Bearer token"}
-
-    middleware.received_headers({})
-    assert factory._token == "token"
-
-    second_header = middleware.sending_headers()
-    assert second_header == {"authorization": "Bearer token"}
-
-
-def test_auth_middleware_bad_headers() -> None:
-    factory = AuthFactory(("user", "password"))
-    middleware = AuthMiddleware(factory)
-
-    with pytest.raises(ValueError, match="Incompatible header value received from server: `12342`"):
-        middleware.received_headers({"authorization": [12342]})
+        arrow_runner._gds_arrow_client.send_action("TEST", {})
diff --git a/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py b/graphdatascience/tests/unit/test_aura_db_arrow_query_runner.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/tests/unit/test_gds_arrow_client.py b/graphdatascience/tests/unit/test_gds_arrow_client.py
new file mode 100644
index 000000000..5e881179b
--- /dev/null
+++ b/graphdatascience/tests/unit/test_gds_arrow_client.py
@@ -0,0 +1,29 @@
+import pytest
+
+from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware
+
+
+def test_auth_middleware() -> None:
+    middleware = AuthMiddleware(("user", "password"))
+
+    first_header = middleware.sending_headers()
+    assert first_header == {"authorization": "Basic dXNlcjpwYXNzd29yZA=="}
+
+    middleware.received_headers({"authorization": ["Bearer token"]})
+    assert middleware._token == "token"
+
+    second_header = middleware.sending_headers()
+    assert second_header == {"authorization": "Bearer token"}
+
+    middleware.received_headers({})
+    assert middleware._token == "token"
+
+    second_header = middleware.sending_headers()
+    assert second_header == {"authorization": "Bearer token"}
+
+
+def test_auth_middleware_bad_headers() -> None:
+    middleware = AuthMiddleware(("user", "password"))
+
+    with pytest.raises(ValueError, match="Incompatible header value received from server: `12342`"):
+        middleware.received_headers({"authorization": [12342]})
diff --git a/graphdatascience/tests/unit/test_graph_ops.py b/graphdatascience/tests/unit/test_graph_ops.py
index 62f13c031..eaeb7bf57 100644
--- a/graphdatascience/tests/unit/test_graph_ops.py
+++ b/graphdatascience/tests/unit/test_graph_ops.py
@@ -1,7 +1,6 @@
 import pytest
 from pandas import DataFrame
 
-from ...session.schema import GdsPropertyTypes
 from .conftest import CollectingQueryRunner
 from graphdatascience.graph_data_science import GraphDataScience
 from graphdatascience.server_version.server_version import ServerVersion
@@ -91,22 +90,22 @@ def test_project_subgraph(runner: CollectingQueryRunner, gds: GraphDataScience)
     }
 
 
-@pytest.mark.parametrize("server_version", [ServerVersion(2, 6, 0)])
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 7, 0)])
 def test_project_remote(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None:
     aura_gds.graph.project("g", "RETURN gds.graph.project.remote(0, 1, null)")
 
     assert (
         runner.last_query()
-        == "CALL gds.graph.project.remoteDb($graph_name, $query, $token, $host, $remote_database, $config)"
+        == "CALL gds.arrow.project("
+        + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)"
     )
     # injection of token and host into the params is done by the actual query runner
     assert runner.last_params() == {
         "graph_name": "g",
-        "token": "<>",
-        "host": "<>",
+        "concurrency": 4,
+        "inverse_indexed_relationship_types": [],
         "query": "RETURN gds.graph.project.remote(0, 1, null)",
-        "remote_database": "neo4j",
-        "config": {},
+        "undirected_relationship_types": [],
     }
 
 
@@ -703,26 +702,7 @@ def test_graph_relationships_to_undirected(runner: CollectingQueryRunner, gds: G
     }
 
 
-@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0))
-def test_remote_projection_on_specific_database(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None:
-    aura_gds.set_database("bar")
-    G, _ = aura_gds.graph.project("g", "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
-
-    assert (
-        runner.last_query()
-        == "CALL gds.graph.project.remoteDb($graph_name, $query, $token, $host, $remote_database, $config)"
-    )
-    assert runner.last_params() == {
-        "graph_name": "g",
-        "token": "<>",
-        "host": "<>",
-        "query": "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)",
-        "remote_database": "bar",
-        "config": {},
-    }
-
-
-@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 6, 0))
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 7, 0)])
 def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None:
     G, _ = aura_gds.graph.project(
         graph_name="g",
@@ -734,21 +714,20 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura
             relationshipProperties: {y: [1, 2]}
         })
         """,
-        undirectedRelationshipTypes=["R"],
-        inverseIndexedRelationshipTypes=["R"],
-        nodePropertySchema={"x": GdsPropertyTypes.LONG},
-        relationshipPropertySchema={"y": GdsPropertyTypes.LONG_ARRAY},
+        concurrency=8,
+        undirected_relationship_types=["R"],
+        inverse_indexed_relationship_types=["R"],
     )
 
     assert (
         runner.last_query()
-        == "CALL gds.graph.project.remoteDb($graph_name, $query, $token, $host, $remote_database, $config)"
+        == "CALL gds.arrow.project("
+        + "$graph_name, $query, $concurrency, $undirected_relationship_types, $inverse_indexed_relationship_types)"
     )
+
     assert runner.last_params() == {
         "graph_name": "g",
-        "token": "<>",
-        "host": "<>",
-        "remote_database": "neo4j",
+        "concurrency": 8,
         "query": """
         MATCH (n)-->(m)
         RETURN gds.graph.project.remote(n, m, {
             sourceNodeProperties: n {.x},
             targetNodeProperties: m {.y},
             relationshipProperties: {y: [1, 2]}
         })
         """,
@@ -757,10 +736,6 @@ def test_remote_projection_all_configuration(runner: CollectingQueryRunner, aura
-        "config": {
-            "undirectedRelationshipTypes": ["R"],
-            "inverseIndexedRelationshipTypes": ["R"],
-            "nodePropertySchema": {"x": "long"},
-            "relationshipPropertySchema": {"y": "long[]"},
-        },
+        "undirected_relationship_types": ["R"],
+        "inverse_indexed_relationship_types": ["R"],
     }
diff --git a/mypy.ini b/mypy.ini
index e4ddbba40..3800f678f 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -8,6 +8,12 @@ ignore_missing_imports = True
 [mypy-pyarrow.flight]
 ignore_missing_imports = True
 
+[mypy-pyarrow._flight]
+ignore_missing_imports = True
+
+[mypy-pyarrow.types]
+ignore_missing_imports = True
+
 [mypy-textdistance]
 ignore_missing_imports = True
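
Note on the AuthMiddleware behaviour pinned down by the new test_gds_arrow_client.py:
the middleware authenticates with Basic credentials until the server responds with a
bearer token, then reuses that token even when later responses carry no authorization
header. Below is a minimal sketch of that flow, assuming only what the tests assert;
the class name is illustrative and is not the library's implementation:

    import base64
    from typing import Any, Dict, List, Optional, Tuple

    class BasicToBearerMiddleware:
        # Sketch: send Basic auth until the server hands back a bearer token.

        def __init__(self, auth: Tuple[str, str]) -> None:
            self._auth = auth
            self._token: Optional[str] = None

        def sending_headers(self) -> Dict[str, str]:
            # Prefer the bearer token once one has been received.
            if self._token is not None:
                return {"authorization": f"Bearer {self._token}"}
            user, password = self._auth
            credentials = base64.b64encode(f"{user}:{password}".encode()).decode()
            return {"authorization": f"Basic {credentials}"}

        def received_headers(self, headers: Dict[str, List[Any]]) -> None:
            values = headers.get("authorization", [])
            if not values:
                return  # no header in this response: keep the current token
            value = values[0]
            if not isinstance(value, str):
                raise ValueError(f"Incompatible header value received from server: `{value}`")
            if value.startswith("Bearer "):
                self._token = value[len("Bearer ") :]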
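
Note on the remote projection change exercised by test_graph_ops.py: the client now
issues CALL gds.arrow.project(...) with plain snake_case parameters instead of
gds.graph.project.remoteDb(...) with a nested config map, and concurrency defaults
to 4 when omitted. A usage sketch, assuming an already connected AuraGraphDataScience
handle named aura_gds:

    # Keyword arguments mirror test_remote_projection_all_configuration;
    # `aura_gds` is assumed to exist and be connected to an AuraDB instance.
    G, result = aura_gds.graph.project(
        graph_name="g",
        query="""
        MATCH (n)-->(m)
        RETURN gds.graph.project.remote(n, m, {
            sourceNodeProperties: n {.x},
            targetNodeProperties: m {.y},
            relationshipProperties: {y: [1, 2]}
        })
        """,
        concurrency=8,  # optional; the tests show a default of 4
        undirected_relationship_types=["R"],
        inverse_indexed_relationship_types=["R"],
    )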