Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions graphdatascience/query_runner/arrow_graph_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,16 @@ def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> Non

writer, _ = self._client.start_put(flight_descriptor, table.schema)

with writer:
# Write table in chunks
for partition in batches:
writer.write_batch(partition)
pbar.update(partition.num_rows)
# Force a refresh to avoid the progress bar getting stuck at 0%
pbar.refresh()
try:
with writer:
# Write table in chunks
for partition in batches:
writer.write_batch(partition)
pbar.update(partition.num_rows)
# Force a refresh to avoid the progress bar getting stuck at 0%
pbar.refresh()
except Exception as e:
GdsArrowClient.handle_flight_error(e)

def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
desc = "Uploading Nodes" if entity_type == "node" else "Uploading Relationships"
Expand Down
48 changes: 41 additions & 7 deletions graphdatascience/query_runner/gds_arrow_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import json
import re
import time
import warnings
from typing import Any, Dict, Optional, Tuple

from neo4j.exceptions import ClientError
from pandas import DataFrame
from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight
from pyarrow._flight import FlightStreamReader, FlightStreamWriter
Expand Down Expand Up @@ -128,8 +130,12 @@ def get_property(
}

ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
get = self._flight_client.do_get(ticket)
arrow_table = get.read_all()

try:
get = self._flight_client.do_get(ticket)
arrow_table = get.read_all()
except Exception as e:
self.handle_flight_error(e)

if configuration.get("list_node_labels", False):
# GDS 2.5 had an inconsistent naming of the node labels column
Expand All @@ -147,13 +153,17 @@ def get_property(

def send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
action_type = self._versioned_action_type(action_type)
result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))

# Consume result fully to sanity check and avoid cancelled streams
collected_result = list(result)
assert len(collected_result) == 1
try:
result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))

# Consume result fully to sanity check and avoid cancelled streams
collected_result = list(result)
assert len(collected_result) == 1

json.loads(collected_result[0].body.to_pybytes().decode())
json.loads(collected_result[0].body.to_pybytes().decode())
except Exception as e:
self.handle_flight_error(e)

def start_put(self, payload: Dict[str, Any], schema: Schema) -> Tuple[FlightStreamWriter, FlightStreamReader]:
flight_descriptor = self._versioned_flight_descriptor(payload)
Expand Down Expand Up @@ -199,6 +209,30 @@ def _sanitize_arrow_table(arrow_table: Table) -> Table:
arrow_table = arrow_table.set_column(idx, field.name, decoded_col)
return arrow_table

@staticmethod
def handle_flight_error(e: Exception):
if (
isinstance(e, flight.FlightServerError)
or isinstance(e, flight.FlightInternalError)
or isinstance(e, ClientError)
):
original_message = e.args[0]
improved_message = original_message.replace(
"Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", ""
)
improved_message = improved_message.replace(
"Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", ""
)
improved_message = improved_message.replace(
"Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ",
"",
)
improved_message = re.sub(r"(\. )?gRPC client debug context: .+$", "", improved_message)

raise flight.FlightServerError(improved_message)
else:
raise e


class AuthFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None:
Expand Down
9 changes: 6 additions & 3 deletions graphdatascience/query_runner/session_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,12 @@ def _remote_projection(

versioned_endpoint = self._resolved_protocol_version.versioned_procedure_name(endpoint)

return self._db_query_runner.call_procedure(
versioned_endpoint, remote_project_proc_params, yields, database, logging, False
)
try:
return self._db_query_runner.call_procedure(
versioned_endpoint, remote_project_proc_params, yields, database, logging, False
)
except Exception as e:
GdsArrowClient.handle_flight_error(e)

@staticmethod
def _project_params_v2(
Expand Down
27 changes: 26 additions & 1 deletion graphdatascience/tests/unit/test_gds_arrow_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import re

import pytest
from pyarrow import flight

from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware
from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient


def test_auth_middleware() -> None:
Expand All @@ -27,3 +30,25 @@ def test_auth_middleware_bad_headers() -> None:

with pytest.raises(ValueError, match="Incompatible header value received from server: `12342`"):
middleware.received_headers({"authorization": [12342]})


def test_handle_flight_error():
with pytest.raises(
flight.FlightServerError,
match="FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.",
):
GdsArrowClient.handle_flight_error(
flight.FlightServerError(
'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. Detail: Internal'
)
)

with pytest.raises(
flight.FlightServerError,
match=re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"),
):
GdsArrowClient.handle_flight_error(
flight.FlightServerError(
"FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"
)
)