diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py
index ae148e530..0562cb2cf 100644
--- a/graphdatascience/query_runner/gds_arrow_client.py
+++ b/graphdatascience/query_runner/gds_arrow_client.py
@@ -7,7 +7,7 @@
 import warnings
 from dataclasses import dataclass
 from types import TracebackType
-from typing import Any, Callable, Iterable, Optional, Type, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Type, Union
 
 import pyarrow
 from neo4j.exceptions import ClientError
@@ -89,19 +89,35 @@ def __init__(
         self._host = host
         self._port = port
         self._auth = auth
+        self._encrypted = encrypted
+        self._disable_server_verification = disable_server_verification
+        self._tls_root_certs = tls_root_certs
+        self._user_agent = user_agent
 
-        location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port)
-
-        client_options: dict[str, Any] = {"disable_server_verification": disable_server_verification}
         if auth:
             self._auth_middleware = AuthMiddleware(auth)
-            if not user_agent:
-                user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
-            client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)]
-        if tls_root_certs:
-            client_options["tls_root_certs"] = tls_root_certs
 
-        self._flight_client = flight.FlightClient(location, **client_options)
+        self._flight_client = self._instantiate_flight_client()
+
+    def _instantiate_flight_client(self) -> flight.FlightClient:
+        location = (
+            flight.Location.for_grpc_tls(self._host, self._port)
+            if self._encrypted
+            else flight.Location.for_grpc_tcp(self._host, self._port)
+        )
+        client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}
+        if self._auth:
+            user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
+            if self._user_agent:
+                user_agent = self._user_agent
+
+            client_options["middleware"] = [
+                AuthFactory(self._auth_middleware),
+                UserAgentFactory(useragent=user_agent),
+            ]
+        if self._tls_root_certs:
+            client_options["tls_root_certs"] = self._tls_root_certs
+        return flight.FlightClient(location, **client_options)
 
     def connection_info(self) -> tuple[str, int]:
         """
@@ -537,11 +553,28 @@ def upload_triplets(
         """
         self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)
 
+    def __getstate__(self) -> Dict[str, Any]:
+        state = self.__dict__.copy()
+        # Remove the FlightClient as it isn't serializable
+        if "_flight_client" in state:
+            del state["_flight_client"]
+        return state
+
+    def _client(self) -> flight.FlightClient:
+        """
+        Lazy client construction to help pickle this class because a PyArrow
+        FlightClient is not serializable.
+        """
+        if not hasattr(self, "_flight_client") or not self._flight_client:
+            self._flight_client = self._instantiate_flight_client()
+        return self._flight_client
+
     def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
         action_type = self._versioned_action_type(action_type)
 
         try:
-            result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
+            client = self._client()
+            result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
 
             # Consume result fully to sanity check and avoid cancelled streams
             collected_result = list(result)
@@ -569,7 +602,9 @@ def _upload_data(
 
         flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type})
         upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
-        put_stream, ack_stream = self._flight_client.do_put(upload_descriptor, batches[0].schema)
+
+        client = self._client()
+        put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema)
 
         @retry(
             stop=(stop_after_delay(10) | stop_after_attempt(5)),