8 changes: 2 additions & 6 deletions graphdatascience/query_runner/arrow_graph_constructor.py
@@ -18,15 +18,13 @@ def __init__(
query_runner: QueryRunner,
graph_name: str,
flight_client: flight.FlightClient,
flight_options: flight.FlightCallOptions,
concurrency: int,
chunk_size: int = 10_000,
):
self._query_runner = query_runner
self._concurrency = concurrency
self._graph_name = graph_name
self._client = flight_client
self._flight_options = flight_options
self._chunk_size = chunk_size
self._min_batch_size = chunk_size * 10

@@ -58,9 +56,7 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:
return partitioned_dfs

def _send_action(self, action_type: str, meta_data: Dict[str, str]) -> None:
result = self._client.do_action(
flight.Action(action_type, json.dumps(meta_data).encode("utf-8")), self._flight_options
)
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))

json.loads(next(result).body.to_pybytes().decode())

@@ -71,7 +67,7 @@ def _send_df(self, df: DataFrame, entity_type: str) -> None:
# Write schema
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))

writer, _ = self._client.do_put(upload_descriptor, table.schema, self._flight_options)
writer, _ = self._client.do_put(upload_descriptor, table.schema)

with writer:
# Write table in chunks
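With authentication moved into client middleware (see the arrow_query_runner.py changes below), the constructor no longer threads flight.FlightCallOptions through every call; do_action and do_put can be invoked without per-call options. A minimal sketch of the simplified call pattern, assuming a reachable Flight server; the action name, metadata and port are hypothetical, not taken from this PR:

import json

import pyarrow.flight as flight

# Auth headers, if any, are injected by middleware registered on the client itself,
# so no FlightCallOptions are passed per call.
client = flight.FlightClient(flight.Location.for_grpc_tcp("localhost", 8491))

# The action's metadata travels in the Action body.
action = flight.Action("CREATE_GRAPH", json.dumps({"name": "my-graph"}).encode("utf-8"))
result = client.do_action(action)
print(json.loads(next(result).body.to_pybytes().decode()))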
78 changes: 67 additions & 11 deletions graphdatascience/query_runner/arrow_query_runner.py
@@ -1,8 +1,11 @@
import base64
import json
import time
from typing import Any, Dict, Optional, Tuple

import pyarrow.flight as flight
from pandas.core.frame import DataFrame
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory

from .arrow_graph_constructor import ArrowGraphConstructor
from .graph_constructor import GraphConstructor
@@ -28,16 +31,15 @@ def __init__(
else flight.Location.for_grpc_tcp(host, int(port_string))
)

self._flight_client = flight.FlightClient(location, disable_server_verification=disable_server_verification)
self._flight_options = flight.FlightCallOptions()

client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
if auth:
username, password = auth
header, token = self._flight_client.authenticate_basic_token(username, password)
if header:
self._flight_options = flight.FlightCallOptions(headers=[(header, token)])
client_options["middleware"] = [AuthFactory(auth)]

self._flight_client = flight.FlightClient(location, **client_options)

def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
if params is None:
params = {}
if "gds.graph.streamNodeProperty" in query:
graph_name = params["graph_name"]
property_name = params["properties"]
@@ -57,8 +59,10 @@ def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:

return self._fallback_query_runner.run_query(query, params)

def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
# For now there's no logging support with Arrow queries.
if params is None:
params = {}
return self._fallback_query_runner.run_query_with_logging(query, params)

def set_database(self, db: str) -> None:
@@ -79,9 +83,61 @@ def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configur
}
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))

result: DataFrame = self._flight_client.do_get(ticket, self._flight_options).read_pandas()
get = self._flight_client.do_get(ticket)
result: DataFrame = get.read_pandas()

return result

def create_graph_constructor(self, graph_name: str, concurrency: int) -> GraphConstructor:
return ArrowGraphConstructor(self, graph_name, self._flight_client, self._flight_options, concurrency)
return ArrowGraphConstructor(self, graph_name, self._flight_client, concurrency)


class AuthFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._auth = auth
self._token: Optional[str] = None
self._token_timestamp = 0

def start_call(self, info: Any) -> "AuthMiddleware":
return AuthMiddleware(self)

def token(self) -> Optional[str]:
# check whether the token is older than 10 minutes. If so, reset it.
if self._token and int(time.time()) - self._token_timestamp > 600:
self._token = None

return self._token

def set_token(self, token: str) -> None:
self._token = token
self._token_timestamp = int(time.time())

@property
def auth(self) -> Tuple[str, str]:
return self._auth


class AuthMiddleware(ClientMiddleware): # type: ignore
def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._factory = factory

def received_headers(self, headers: Dict[str, Any]) -> None:
auth_header: str = headers.get("Authorization", None)
if not auth_header:
return
[auth_type, token] = auth_header.split(" ", 1)
if auth_type == "Bearer":
self._factory.set_token(token)

def sending_headers(self) -> Dict[str, str]:
token = self._factory.token()
if not token:
username, password = self._factory.auth
auth_token = f"{username}:{password}"
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
# There seems to be a bug: the `authorization` header key must be lowercase
return {"authorization": auth_token}
else:
return {"authorization": "Bearer " + token}
10 changes: 8 additions & 2 deletions graphdatascience/query_runner/neo4j_query_runner.py
@@ -31,7 +31,10 @@ def __init__(self, driver: neo4j.Driver, db: Optional[str] = neo4j.DEFAULT_DATAB
except Exception as e:
raise UnableToConnectError(e)

def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
if params is None:
params = {}

with self._driver.session(database=self._db) as session:
result = session.run(query, params)

@@ -44,7 +47,10 @@ def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:

return result.to_df() # type: ignore

def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
if params is None:
params = {}

if self._server_version < ServerVersion(2, 1, 0):
return self.run_query(query, params)

6 changes: 3 additions & 3 deletions graphdatascience/query_runner/query_runner.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, Optional

from pandas.core.frame import DataFrame

@@ -9,10 +9,10 @@

class QueryRunner(ABC):
@abstractmethod
def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
pass

def run_query_with_logging(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
def run_query_with_logging(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
return self.run_query(query, params)

@abstractmethod
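The signature change repeated throughout this PR (params: Dict[str, Any] = {} becoming Optional[Dict[str, Any]] = None plus an explicit None check) removes Python's mutable-default-argument pitfall: a {} default is created once, at function definition time, and shared by every call, so any mutation leaks into later invocations. A minimal standalone illustration of the pitfall and the fix, not library code:

from typing import Any, Dict, Optional


def risky(key: str, value: Any, params: Dict[str, Any] = {}) -> Dict[str, Any]:
    params[key] = value  # mutates the single dict shared by every call
    return params


def safe(key: str, value: Any, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    if params is None:
        params = {}  # fresh dict per call, the pattern adopted in this PR
    params[key] = value
    return params


print(risky("a", 1))  # {'a': 1}
print(risky("b", 2))  # {'a': 1, 'b': 2}  <- state leaked from the first call
print(safe("a", 1))   # {'a': 1}
print(safe("b", 2))   # {'b': 2}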
7 changes: 5 additions & 2 deletions graphdatascience/tests/unit/conftest.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import pandas
import pytest
@@ -19,7 +19,10 @@ def __init__(self, server_version: Union[str, ServerVersion]) -> None:
self.params: List[Dict[str, Any]] = []
self.server_version = server_version

def run_query(self, query: str, params: Dict[str, Any] = {}) -> DataFrame:
def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> DataFrame:
if params is None:
params = {}

self.queries.append(query)
self.params.append(params)
