From c774985d5b364b8547cac5f71ae2c909ddffd5b9 Mon Sep 17 00:00:00 2001 From: dhrudevalia Date: Thu, 2 Jan 2025 13:12:53 +0000 Subject: [PATCH 1/2] fix: use serialisation helper function --- graphdatascience/query_runner/gds_arrow_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 0562cb2cf..e57ea6934 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -654,8 +654,9 @@ def _do_get( ticket = flight.Ticket(json.dumps(payload).encode("utf-8")) + client = self._client() try: - get = self._flight_client.do_get(ticket) + get = client.do_get(ticket) arrow_table = get.read_all() except Exception as e: self.handle_flight_error(e) From 613f44b866a84e297c9bae08634cc04dd8744d79 Mon Sep 17 00:00:00 2001 From: dhrudevalia Date: Thu, 2 Jan 2025 14:13:17 +0000 Subject: [PATCH 2/2] fix: use helper to fetch token, only close if flight client exists --- graphdatascience/query_runner/gds_arrow_client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index e57ea6934..9d42c3e1b 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -140,7 +140,8 @@ def request_token(self) -> Optional[str]: a token from the server and returns it. """ if self._auth: - self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) + client = self._client() + client.authenticate_basic_token(self._auth[0], self._auth[1]) return self._auth_middleware.token() else: return "IGNORED" @@ -684,10 +685,11 @@ def __exit__( exception_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - self._flight_client.close() + self.close() def close(self) -> None: - self._flight_client.close() + if self._flight_client: + self._flight_client.close() def _versioned_action_type(self, action_type: str) -> str: return self._arrow_endpoint_version.prefix() + action_type