diff --git a/graphdatascience/query_runner/arrow_graph_constructor.py b/graphdatascience/query_runner/arrow_graph_constructor.py index 2997ef8b9..353eaa5c4 100644 --- a/graphdatascience/query_runner/arrow_graph_constructor.py +++ b/graphdatascience/query_runner/arrow_graph_constructor.py @@ -88,13 +88,16 @@ def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> Non writer, _ = self._client.start_put(flight_descriptor, table.schema) - with writer: - # Write table in chunks - for partition in batches: - writer.write_batch(partition) - pbar.update(partition.num_rows) - # Force a refresh to avoid the progress bar getting stuck at 0% - pbar.refresh() + try: + with writer: + # Write table in chunks + for partition in batches: + writer.write_batch(partition) + pbar.update(partition.num_rows) + # Force a refresh to avoid the progress bar getting stuck at 0% + pbar.refresh() + except Exception as e: + GdsArrowClient.handle_flight_error(e) def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None: desc = "Uploading Nodes" if entity_type == "node" else "Uploading Relationships" diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 6fec7973b..6f80f57c1 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -1,9 +1,11 @@ import base64 import json +import re import time import warnings from typing import Any, Dict, Optional, Tuple +from neo4j.exceptions import ClientError from pandas import DataFrame from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight from pyarrow._flight import FlightStreamReader, FlightStreamWriter @@ -128,8 +130,12 @@ def get_property( } ticket = flight.Ticket(json.dumps(payload).encode("utf-8")) - get = self._flight_client.do_get(ticket) - arrow_table = get.read_all() + + try: + get = self._flight_client.do_get(ticket) + arrow_table = get.read_all() + except Exception as e: + self.handle_flight_error(e) if configuration.get("list_node_labels", False): # GDS 2.5 had an inconsistent naming of the node labels column @@ -147,13 +153,17 @@ def get_property( def send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None: action_type = self._versioned_action_type(action_type) - result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) - # Consume result fully to sanity check and avoid cancelled streams - collected_result = list(result) - assert len(collected_result) == 1 + try: + result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))) + + # Consume result fully to sanity check and avoid cancelled streams + collected_result = list(result) + assert len(collected_result) == 1 - json.loads(collected_result[0].body.to_pybytes().decode()) + json.loads(collected_result[0].body.to_pybytes().decode()) + except Exception as e: + self.handle_flight_error(e) def start_put(self, payload: Dict[str, Any], schema: Schema) -> Tuple[FlightStreamWriter, FlightStreamReader]: flight_descriptor = self._versioned_flight_descriptor(payload) @@ -199,6 +209,30 @@ def _sanitize_arrow_table(arrow_table: Table) -> Table: arrow_table = arrow_table.set_column(idx, field.name, decoded_col) return arrow_table + @staticmethod + def handle_flight_error(e: Exception): + if ( + isinstance(e, flight.FlightServerError) + or isinstance(e, flight.FlightInternalError) + or isinstance(e, ClientError) + ): + original_message = e.args[0] + improved_message = original_message.replace( + "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", "" + ) + improved_message = improved_message.replace( + "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", "" + ) + improved_message = improved_message.replace( + "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ", + "", + ) + improved_message = re.sub(r"(\. )?gRPC client debug context: .+$", "", improved_message) + + raise flight.FlightServerError(improved_message) + else: + raise e + class AuthFactory(ClientMiddlewareFactory): # type: ignore def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None: diff --git a/graphdatascience/query_runner/session_query_runner.py b/graphdatascience/query_runner/session_query_runner.py index 7ae94601c..7796ea90b 100644 --- a/graphdatascience/query_runner/session_query_runner.py +++ b/graphdatascience/query_runner/session_query_runner.py @@ -143,9 +143,12 @@ def _remote_projection( versioned_endpoint = self._resolved_protocol_version.versioned_procedure_name(endpoint) - return self._db_query_runner.call_procedure( - versioned_endpoint, remote_project_proc_params, yields, database, logging, False - ) + try: + return self._db_query_runner.call_procedure( + versioned_endpoint, remote_project_proc_params, yields, database, logging, False + ) + except Exception as e: + GdsArrowClient.handle_flight_error(e) @staticmethod def _project_params_v2( diff --git a/graphdatascience/tests/unit/test_gds_arrow_client.py b/graphdatascience/tests/unit/test_gds_arrow_client.py index 5e881179b..e95309d8a 100644 --- a/graphdatascience/tests/unit/test_gds_arrow_client.py +++ b/graphdatascience/tests/unit/test_gds_arrow_client.py @@ -1,6 +1,9 @@ +import re + import pytest +from pyarrow import flight -from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware +from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient def test_auth_middleware() -> None: @@ -27,3 +30,25 @@ def test_auth_middleware_bad_headers() -> None: with pytest.raises(ValueError, match="Incompatible header value received from server: `12342`"): middleware.received_headers({"authorization": [12342]}) + + +def test_handle_flight_error(): + with pytest.raises( + flight.FlightServerError, + match="FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.", + ): + GdsArrowClient.handle_flight_error( + flight.FlightServerError( + 'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. Detail: Internal' + ) + ) + + with pytest.raises( + flight.FlightServerError, + match=re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"), + ): + GdsArrowClient.handle_flight_error( + flight.FlightServerError( + "FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]" + ) + )