diff --git a/graphdatascience/query_runner/arrow_graph_constructor.py b/graphdatascience/query_runner/arrow_graph_constructor.py
index a60b7a635..43641afd1 100644
--- a/graphdatascience/query_runner/arrow_graph_constructor.py
+++ b/graphdatascience/query_runner/arrow_graph_constructor.py
@@ -18,7 +18,6 @@ def __init__(
         query_runner: QueryRunner,
         graph_name: str,
         flight_client: flight.FlightClient,
-        flight_options: flight.FlightCallOptions,
         concurrency: int,
         chunk_size: int = 10_000,
     ):
@@ -26,7 +25,6 @@ def __init__(
         self._concurrency = concurrency
         self._graph_name = graph_name
         self._client = flight_client
-        self._flight_options = flight_options
         self._chunk_size = chunk_size
         self._min_batch_size = chunk_size * 10
 
@@ -58,9 +56,7 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
         return partitioned_dfs
 
     def _send_action(self, action_type: str, meta_data: Dict[str, str]) -> None:
-        result = self._client.do_action(
-            flight.Action(action_type, json.dumps(meta_data).encode("utf-8")), self._flight_options
-        )
+        result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
 
         json.loads(next(result).body.to_pybytes().decode())
 
@@ -71,7 +67,7 @@ def _send_df(self, df: DataFrame, entity_type: str) -> None:
 
         # Write schema
         upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
-        writer, _ = self._client.do_put(upload_descriptor, table.schema, self._flight_options)
+        writer, _ = self._client.do_put(upload_descriptor, table.schema)
 
         with writer:
             # Write table in chunks
diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py
index 6a2d0fd0a..082b71980 100644
--- a/graphdatascience/query_runner/arrow_query_runner.py
+++ b/graphdatascience/query_runner/arrow_query_runner.py
@@ -1,8 +1,11 @@
+import base64
 import json
+import time
 from typing import Any, Dict, Optional, Tuple
 
 import pyarrow.flight as flight
 from pandas.core.frame import DataFrame
+from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
 
 from .arrow_graph_constructor import ArrowGraphConstructor
 from .graph_constructor import GraphConstructor
@@ -28,16 +31,15 @@ def __init__(
             else flight.Location.for_grpc_tcp(host, int(port_string))
         )
 
-        self._flight_client = flight.FlightClient(location, disable_server_verification=disable_server_verification)
-        self._flight_options = flight.FlightCallOptions()
-
+        client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
         if auth:
-            username, password = auth
-            header, token = self._flight_client.authenticate_basic_token(username, password)
-            if header:
-                self._flight_options = flight.FlightCallOptions(headers=[(header, token)])
+            client_options["middleware"] = [AuthFactory(auth)]
+
+        self._flight_client = flight.FlightClient(location, **client_options)
 
-    def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
+    def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
+        if params is None:
+            params = {}
         if "gds.graph.streamNodeProperty" in query:
             graph_name = params["graph_name"]
             property_name = params["properties"]
@@ -57,8 +59,10 @@ def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
 
         return self._fallback_query_runner.run_query(query, params)
 
-    def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
+    def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
         # For now there's no logging support with Arrow queries.
+        if params is None:
+            params = {}
         return self._fallback_query_runner.run_query_with_logging(query, params)
 
     def set_database(self, db: str) -> None:
@@ -79,9 +83,72 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
         }
         ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
 
-        result: DataFrame = self._flight_client.do_get(ticket, self._flight_options).read_pandas()
+        get = self._flight_client.do_get(ticket)
+        result: DataFrame = get.read_pandas()
 
         return result
 
     def create_graph_constructor(self, graph_name: str, concurrency: int) -> GraphConstructor:
-        return ArrowGraphConstructor(self, graph_name, self._flight_client, self._flight_options, concurrency)
+        return ArrowGraphConstructor(self, graph_name, self._flight_client, concurrency)
+
+
+class AuthFactory(ClientMiddlewareFactory):  # type: ignore
+    """Create an AuthMiddleware per call and hold the bearer token those middlewares share."""
+
+    def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._auth = auth
+        self._token: Optional[str] = None
+        self._token_timestamp = 0
+
+    def start_call(self, info: Any) -> "AuthMiddleware":
+        return AuthMiddleware(self)
+
+    def token(self) -> Optional[str]:
+        """Return the cached bearer token, or None when none is cached or it has expired."""
+        # check whether the token is older than 10 minutes. If so, reset it.
+        if self._token and int(time.time()) - self._token_timestamp > 600:
+            self._token = None
+
+        return self._token
+
+    def set_token(self, token: str) -> None:
+        """Cache `token` and remember when it was stored."""
+        self._token = token
+        self._token_timestamp = int(time.time())
+
+    @property
+    def auth(self) -> Tuple[str, str]:
+        return self._auth
+
+
+class AuthMiddleware(ClientMiddleware):  # type: ignore
+    """Send credentials with each call and store bearer tokens found in response headers."""
+
+    def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._factory = factory
+
+    def received_headers(self, headers: Dict[str, Any]) -> None:
+        # pyarrow documents header values as lists, and header names are expected in
+        # lowercase, so match the name case-insensitively and unwrap the first value.
+        auth_header = next((v for k, v in headers.items() if k.lower() == "authorization"), None)
+        if not auth_header:
+            return
+        header_value = auth_header if isinstance(auth_header, str) else auth_header[0]
+        if not isinstance(header_value, str):
+            return
+        auth_type, _, token = header_value.partition(" ")
+        if auth_type == "Bearer" and token:
+            self._factory.set_token(token)
+
+    def sending_headers(self) -> Dict[str, str]:
+        token = self._factory.token()
+        if not token:
+            username, password = self._factory.auth
+            auth_token = f"{username}:{password}"
+            auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
+            # There seems to be a bug, `authorization` must be lowercase
+            return {"authorization": auth_token}
+        else:
+            return {"authorization": "Bearer " + token}
diff --git a/graphdatascience/query_runner/neo4j_query_runner.py b/graphdatascience/query_runner/neo4j_query_runner.py
index 14955b401..1ff65a9df 100644
--- a/graphdatascience/query_runner/neo4j_query_runner.py
+++ b/graphdatascience/query_runner/neo4j_query_runner.py
@@ -31,7 +31,10 @@ def __init__(self, driver: neo4j.Driver, db: Optional[str] = neo4j.DEFAULT_DATAB
         except Exception as e:
             raise UnableToConnectError(e)
 
-    def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
+    def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
+        if params is None:
+            params = {}
+
         with self._driver.session(database=self._db) as session:
             result = session.run(query, params)
 
@@ -44,7 +47,10 @@ def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
 
             return result.to_df()  # type: ignore
 
-    def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
+    def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
+        if params is None:
+            params = {}
+
         if self._server_version < ServerVersion(2, 1, 0):
             return self.run_query(query, params)
 
diff --git a/graphdatascience/query_runner/query_runner.py b/graphdatascience/query_runner/query_runner.py
index 4ea0a0cc8..dd4003a07 100644
--- a/graphdatascience/query_runner/query_runner.py
+++ b/graphdatascience/query_runner/query_runner.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 from pandas.core.frame import DataFrame
 
@@ -9,10 +9,10 @@
 
 class QueryRunner(ABC):
     @abstractmethod
-    def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
+    def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
         pass
 
-    def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
+    def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
         return self.run_query(query, params)
 
     @abstractmethod
diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py
index 4f4e0e89a..5ee773646 100644
--- a/graphdatascience/tests/unit/conftest.py
+++ b/graphdatascience/tests/unit/conftest.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import pandas
 import pytest
@@ -19,7 +19,10 @@ def __init__(self, server_version: Union[str, ServerVersion]) -> None:
         self.params: List[Dict[str, Any]] = []
         self.server_version = server_version
 
-    def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
+    def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
+        if params is None:
+            params = {}
+
         self.queries.append(query)
         self.params.append(params)