From 16fc227445840b4304ab9fbb413d031fbade84cf Mon Sep 17 00:00:00 2001 From: dhrudevalia Date: Mon, 9 Dec 2024 16:18:28 +0000 Subject: [PATCH 1/5] fix: remove flight client from serialisation --- .../query_runner/gds_arrow_client.py | 81 ++++++++++++++++--- 1 file changed, 68 insertions(+), 13 deletions(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index ae148e530..9f9baa10f 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -7,7 +7,7 @@ import warnings from dataclasses import dataclass from types import TracebackType -from typing import Any, Callable, Iterable, Optional, Type, Union +from typing import Any, Callable, Iterable, Optional, Type, Union, Dict import pyarrow from neo4j.exceptions import ClientError @@ -89,19 +89,32 @@ def __init__( self._host = host self._port = port self._auth = auth + self._encrypted = encrypted + self._disable_server_verification = disable_server_verification + self._tls_root_certs = tls_root_certs + self._user_agent = user_agent - location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port) - - client_options: dict[str, Any] = {"disable_server_verification": disable_server_verification} if auth: self._auth_middleware = AuthMiddleware(auth) - if not user_agent: - user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" - client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)] - if tls_root_certs: - client_options["tls_root_certs"] = tls_root_certs + if not self._user_agent: + self._user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" + + self._flight_client = self._instantiate_flight_client() - self._flight_client = flight.FlightClient(location, **client_options) + def _instantiate_flight_client(self) -> flight.FlightClient: + print("inside instantiate_flight_client") + location = flight.Location.for_grpc_tls(self._host, self._port) if self._encrypted else flight.Location.for_grpc_tcp(self._host, self._port) + print("after flight location") + client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification} + print("after client options init") + if self._auth: + client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=self._user_agent)] + print("self. auth middleware") + if self._tls_root_certs: + client_options["tls_root_certs"] = self._tls_root_certs + print("self.tls roots") + print("returning flight client") + return flight.FlightClient(location, **client_options) def connection_info(self) -> tuple[str, int]: """ @@ -536,12 +549,46 @@ def upload_triplets( A callback function that is called with the number of rows uploaded after each batch """ self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback) + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + # Remove the FlightClient as it isn't serializable + if "_flight_client" in state: + del state["_flight_client"] + return state + + # def copy(self) -> "GDSArrowClient": + # client = GdsArrowClient( + # self._host, + # self._port, + # self._auth, + # self._encrypted, + # self._disable_server_verification, + # self._tls_root_certs, + # self._arrow_endpoint_version, + # self._user_agent + # ) + # client.state = self.state + # return client + + + def _client(self) -> flight.FlightClient: + """ + Lazy client construction to help pickle this class because a PyArrow + FlightClient is not serializable. + """ + print("checking flight client") + if not hasattr(self, "_flight_client") or not self._flight_client: + print("instantiating flight client") + self._flight_client = self._instantiate_flight_client() + return self._flight_client def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]: action_type = self._versioned_action_type(action_type) try: - result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) + client = self._client() + result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) # Consume result fully to sanity check and avoid cancelled streams collected_result = list(result) @@ -569,7 +616,13 @@ def _upload_data( flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type}) upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8")) - put_stream, ack_stream = self._flight_client.do_put(upload_descriptor, batches[0].schema) + + print(f"before do_put data: {data}") + client = self._client() + print("After flight client instantiation") + put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema) + + print(f"after do_put data: {data}") @retry( stop=(stop_after_delay(10) | stop_after_attempt(5)), @@ -585,10 +638,12 @@ def upload_batch(p: RecordBatch) -> None: try: with put_stream: - for partition in batches: + for idx, partition in enumerate(batches): + print(f"batch {idx}") upload_batch(partition) ack_stream.read() progress_callback(partition.num_rows) + print(f"upload completed for {data}") except Exception as e: GdsArrowClient.handle_flight_error(e) From 587893c512e0ac7570e3394c840de8743a7fc7df Mon Sep 17 00:00:00 2001 From: emrehizal Date: Wed, 18 Dec 2024 11:39:34 +0100 Subject: [PATCH 2/5] chore: polish the changes --- .../query_runner/gds_arrow_client.py | 33 ++----------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 9f9baa10f..c09358cf6 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -102,18 +102,12 @@ def __init__( self._flight_client = self._instantiate_flight_client() def _instantiate_flight_client(self) -> flight.FlightClient: - print("inside instantiate_flight_client") location = flight.Location.for_grpc_tls(self._host, self._port) if self._encrypted else flight.Location.for_grpc_tcp(self._host, self._port) - print("after flight location") client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification} - print("after client options init") if self._auth: client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=self._user_agent)] - print("self. auth middleware") if self._tls_root_certs: client_options["tls_root_certs"] = self._tls_root_certs - print("self.tls roots") - print("returning flight client") return flight.FlightClient(location, **client_options) def connection_info(self) -> tuple[str, int]: @@ -556,30 +550,13 @@ def __getstate__(self) -> Dict[str, Any]: if "_flight_client" in state: del state["_flight_client"] return state - - # def copy(self) -> "GDSArrowClient": - # client = GdsArrowClient( - # self._host, - # self._port, - # self._auth, - # self._encrypted, - # self._disable_server_verification, - # self._tls_root_certs, - # self._arrow_endpoint_version, - # self._user_agent - # ) - # client.state = self.state - # return client - - + def _client(self) -> flight.FlightClient: """ Lazy client construction to help pickle this class because a PyArrow FlightClient is not serializable. """ - print("checking flight client") if not hasattr(self, "_flight_client") or not self._flight_client: - print("instantiating flight client") self._flight_client = self._instantiate_flight_client() return self._flight_client @@ -617,13 +594,9 @@ def _upload_data( flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type}) upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8")) - print(f"before do_put data: {data}") client = self._client() - print("After flight client instantiation") put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema) - print(f"after do_put data: {data}") - @retry( stop=(stop_after_delay(10) | stop_after_attempt(5)), wait=wait_exponential(multiplier=1, min=1, max=10), @@ -638,12 +611,10 @@ def upload_batch(p: RecordBatch) -> None: try: with put_stream: - for idx, partition in enumerate(batches): - print(f"batch {idx}") + for partition in batches: upload_batch(partition) ack_stream.read() progress_callback(partition.num_rows) - print(f"upload completed for {data}") except Exception as e: GdsArrowClient.handle_flight_error(e) From 58d0c5541ed1d56186aaaeae3cf4a2569742312e Mon Sep 17 00:00:00 2001 From: emrehizal Date: Wed, 18 Dec 2024 14:30:19 +0100 Subject: [PATCH 3/5] style: fix import sort --- graphdatascience/query_runner/gds_arrow_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index c09358cf6..8e5301977 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -7,7 +7,7 @@ import warnings from dataclasses import dataclass from types import TracebackType -from typing import Any, Callable, Iterable, Optional, Type, Union, Dict +from typing import Any, Callable, Dict, Iterable, Optional, Type, Union import pyarrow from neo4j.exceptions import ClientError From 2fcfe7b21a87208a1748770662d53bdaf35abb15 Mon Sep 17 00:00:00 2001 From: emrehizal Date: Wed, 18 Dec 2024 14:46:09 +0100 Subject: [PATCH 4/5] style: fix the code format --- graphdatascience/query_runner/gds_arrow_client.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 8e5301977..c548c8fbf 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -102,10 +102,17 @@ def __init__( self._flight_client = self._instantiate_flight_client() def _instantiate_flight_client(self) -> flight.FlightClient: - location = flight.Location.for_grpc_tls(self._host, self._port) if self._encrypted else flight.Location.for_grpc_tcp(self._host, self._port) + location = ( + flight.Location.for_grpc_tls(self._host, self._port) + if self._encrypted + else flight.Location.for_grpc_tcp(self._host, self._port) + ) client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification} if self._auth: - client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=self._user_agent)] + client_options["middleware"] = [ + AuthFactory(self._auth_middleware), + UserAgentFactory(useragent=self._user_agent), + ] if self._tls_root_certs: client_options["tls_root_certs"] = self._tls_root_certs return flight.FlightClient(location, **client_options) @@ -543,7 +550,7 @@ def upload_triplets( A callback function that is called with the number of rows uploaded after each batch """ self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback) - + def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() # Remove the FlightClient as it isn't serializable From f46e36fe093a752ac5c710c61ee83ffab4442384 Mon Sep 17 00:00:00 2001 From: emrehizal Date: Thu, 19 Dec 2024 11:29:54 +0100 Subject: [PATCH 5/5] style: fix the code format --- graphdatascience/query_runner/gds_arrow_client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index c548c8fbf..0562cb2cf 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -96,8 +96,6 @@ def __init__( if auth: self._auth_middleware = AuthMiddleware(auth) - if not self._user_agent: - self._user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" self._flight_client = self._instantiate_flight_client() @@ -109,9 +107,13 @@ def _instantiate_flight_client(self) -> flight.FlightClient: ) client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification} if self._auth: + user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" + if self._user_agent: + user_agent = self._user_agent + client_options["middleware"] = [ AuthFactory(self._auth_middleware), - UserAgentFactory(useragent=self._user_agent), + UserAgentFactory(useragent=user_agent), ] if self._tls_root_certs: client_options["tls_root_certs"] = self._tls_root_certs