-
Notifications
You must be signed in to change notification settings - Fork 54
fix: remove flight client from serialization #804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
16fc227
587893c
58d0c55
2fcfe7b
f46e36f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,7 +7,7 @@ | |||||||||||||||||
| import warnings | ||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||
| from types import TracebackType | ||||||||||||||||||
| from typing import Any, Callable, Iterable, Optional, Type, Union | ||||||||||||||||||
| from typing import Any, Callable, Dict, Iterable, Optional, Type, Union | ||||||||||||||||||
|
|
||||||||||||||||||
| import pyarrow | ||||||||||||||||||
| from neo4j.exceptions import ClientError | ||||||||||||||||||
|
|
@@ -89,19 +89,35 @@ def __init__( | |||||||||||||||||
| self._host = host | ||||||||||||||||||
| self._port = port | ||||||||||||||||||
| self._auth = auth | ||||||||||||||||||
| self._encrypted = encrypted | ||||||||||||||||||
| self._disable_server_verification = disable_server_verification | ||||||||||||||||||
| self._tls_root_certs = tls_root_certs | ||||||||||||||||||
| self._user_agent = user_agent | ||||||||||||||||||
|
|
||||||||||||||||||
| location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port) | ||||||||||||||||||
|
|
||||||||||||||||||
| client_options: dict[str, Any] = {"disable_server_verification": disable_server_verification} | ||||||||||||||||||
| if auth: | ||||||||||||||||||
| self._auth_middleware = AuthMiddleware(auth) | ||||||||||||||||||
| if not user_agent: | ||||||||||||||||||
| user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" | ||||||||||||||||||
| client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)] | ||||||||||||||||||
| if tls_root_certs: | ||||||||||||||||||
| client_options["tls_root_certs"] = tls_root_certs | ||||||||||||||||||
|
|
||||||||||||||||||
| self._flight_client = flight.FlightClient(location, **client_options) | ||||||||||||||||||
| self._flight_client = self._instantiate_flight_client() | ||||||||||||||||||
|
|
||||||||||||||||||
| def _instantiate_flight_client(self) -> flight.FlightClient: | ||||||||||||||||||
| location = ( | ||||||||||||||||||
| flight.Location.for_grpc_tls(self._host, self._port) | ||||||||||||||||||
| if self._encrypted | ||||||||||||||||||
| else flight.Location.for_grpc_tcp(self._host, self._port) | ||||||||||||||||||
| ) | ||||||||||||||||||
| client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification} | ||||||||||||||||||
| if self._auth: | ||||||||||||||||||
| user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" | ||||||||||||||||||
| if self._user_agent: | ||||||||||||||||||
| user_agent = self._user_agent | ||||||||||||||||||
|
Comment on lines
+110
to
+112
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so it could be
Suggested change
but it's not necessary. also maybe this will trip up the typer -- not sure.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah looks much better this way |
||||||||||||||||||
|
|
||||||||||||||||||
| client_options["middleware"] = [ | ||||||||||||||||||
| AuthFactory(self._auth_middleware), | ||||||||||||||||||
| UserAgentFactory(useragent=user_agent), | ||||||||||||||||||
| ] | ||||||||||||||||||
| if self._tls_root_certs: | ||||||||||||||||||
| client_options["tls_root_certs"] = self._tls_root_certs | ||||||||||||||||||
| return flight.FlightClient(location, **client_options) | ||||||||||||||||||
|
|
||||||||||||||||||
| def connection_info(self) -> tuple[str, int]: | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
@@ -537,11 +553,28 @@ def upload_triplets( | |||||||||||||||||
| """ | ||||||||||||||||||
| self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback) | ||||||||||||||||||
|
|
||||||||||||||||||
| def __getstate__(self) -> Dict[str, Any]: | ||||||||||||||||||
| state = self.__dict__.copy() | ||||||||||||||||||
| # Remove the FlightClient as it isn't serializable | ||||||||||||||||||
| if "_flight_client" in state: | ||||||||||||||||||
| del state["_flight_client"] | ||||||||||||||||||
| return state | ||||||||||||||||||
|
|
||||||||||||||||||
| def _client(self) -> flight.FlightClient: | ||||||||||||||||||
| """ | ||||||||||||||||||
| Lazy client construction to help pickle this class because a PyArrow | ||||||||||||||||||
| FlightClient is not serializable. | ||||||||||||||||||
|
Comment on lines
+565
to
+566
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, but I am curious; in which situations do you need to serialise this class?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure at this point but we might have needed it this way on BigQuery Connector
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be good to know why, to stop future developers from removing this feature.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes sure, I will let you know why when I make sure about the need on the connector side
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the BigQuery connector, we use Spark jobs to process data, and Spark requires a serialized version of this class to distribute the job across different workers. Since FlightClient is not inherently serializable, we needed this lazy initialization. |
||||||||||||||||||
| """ | ||||||||||||||||||
| if not hasattr(self, "_flight_client") or not self._flight_client: | ||||||||||||||||||
| self._flight_client = self._instantiate_flight_client() | ||||||||||||||||||
| return self._flight_client | ||||||||||||||||||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]: | ||||||||||||||||||
| action_type = self._versioned_action_type(action_type) | ||||||||||||||||||
|
|
||||||||||||||||||
| try: | ||||||||||||||||||
| result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) | ||||||||||||||||||
| client = self._client() | ||||||||||||||||||
Mats-SX marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||
| result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Consume result fully to sanity check and avoid cancelled streams | ||||||||||||||||||
| collected_result = list(result) | ||||||||||||||||||
|
|
@@ -569,7 +602,9 @@ def _upload_data( | |||||||||||||||||
|
|
||||||||||||||||||
| flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type}) | ||||||||||||||||||
| upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8")) | ||||||||||||||||||
| put_stream, ack_stream = self._flight_client.do_put(upload_descriptor, batches[0].schema) | ||||||||||||||||||
|
|
||||||||||||||||||
| client = self._client() | ||||||||||||||||||
| put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema) | ||||||||||||||||||
|
|
||||||||||||||||||
| @retry( | ||||||||||||||||||
| stop=(stop_after_delay(10) | stop_after_attempt(5)), | ||||||||||||||||||
|
|
||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.