Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions graphdatascience/query_runner/gds_arrow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Callable, Iterable, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, Optional, Type, Union

import pyarrow
from neo4j.exceptions import ClientError
Expand Down Expand Up @@ -89,19 +89,35 @@ def __init__(
self._host = host
self._port = port
self._auth = auth
self._encrypted = encrypted
self._disable_server_verification = disable_server_verification
self._tls_root_certs = tls_root_certs
self._user_agent = user_agent

location = flight.Location.for_grpc_tls(host, port) if encrypted else flight.Location.for_grpc_tcp(host, port)

client_options: dict[str, Any] = {"disable_server_verification": disable_server_verification}
if auth:
self._auth_middleware = AuthMiddleware(auth)
if not user_agent:
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)]
if tls_root_certs:
client_options["tls_root_certs"] = tls_root_certs

self._flight_client = flight.FlightClient(location, **client_options)
self._flight_client = self._instantiate_flight_client()

def _instantiate_flight_client(self) -> flight.FlightClient:
location = (
flight.Location.for_grpc_tls(self._host, self._port)
if self._encrypted
else flight.Location.for_grpc_tcp(self._host, self._port)
)
client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}
if self._auth:
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
if self._user_agent:
user_agent = self._user_agent
Comment on lines +110 to +112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it could be

Suggested change
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
if self._user_agent:
user_agent = self._user_agent
user_agent = (
self._user_agent
if self._user_agent
else f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
)

but it's not necessary. also maybe this will trip up the typer -- not sure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah looks much better this way


client_options["middleware"] = [
AuthFactory(self._auth_middleware),
UserAgentFactory(useragent=user_agent),
]
if self._tls_root_certs:
client_options["tls_root_certs"] = self._tls_root_certs
return flight.FlightClient(location, **client_options)

def connection_info(self) -> tuple[str, int]:
"""
Expand Down Expand Up @@ -537,11 +553,28 @@ def upload_triplets(
"""
self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
# Remove the FlightClient as it isn't serializable
if "_flight_client" in state:
del state["_flight_client"]
return state

def _client(self) -> flight.FlightClient:
"""
Lazy client construction to help pickle this class because a PyArrow
FlightClient is not serializable.
Comment on lines +565 to +566
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, but I am curious; in which situations do you need to serialise this class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure at this point but we might have needed it this way on BigQuery Connector

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to know why, to stop future developers from removing this feature.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes sure, I will let you know why when I make sure about the need on the connector side

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the BigQuery connector, we use Spark jobs to process data, and Spark requires a serialized version of this class to distribute the job across different workers. Since FlightClient is not inherently serializable, we needed this lazy initialization.

"""
if not hasattr(self, "_flight_client") or not self._flight_client:
self._flight_client = self._instantiate_flight_client()
return self._flight_client

def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
action_type = self._versioned_action_type(action_type)

try:
result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
client = self._client()
result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))

# Consume result fully to sanity check and avoid cancelled streams
collected_result = list(result)
Expand Down Expand Up @@ -569,7 +602,9 @@ def _upload_data(

flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type})
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
put_stream, ack_stream = self._flight_client.do_put(upload_descriptor, batches[0].schema)

client = self._client()
put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema)

@retry(
stop=(stop_after_delay(10) | stop_after_attempt(5)),
Expand Down