Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

* Add the new concept of GDS Sessions, used to manage GDS computations in Aura, based on data from an AuraDB instance.
* Add a new `gds.graph.project` endpoint to project graphs from AuraDB instances to GDS sessions.
* `nodePropertySchema` and `relationshipPropertySchema` can be used to optimise remote projections.
* Add a new top-level class `GdsSessions` to manage GDS sessions in Aura.
* `GdsSessions` supports `get_or_create()`, `list()`, and `delete()`.
* New sessions can be created in various sizes.
Expand Down
11 changes: 0 additions & 11 deletions doc/modules/ROOT/pages/tutorials/gds-sessions.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,6 @@ although we do not do that in this notebook.

[source, python, role=no-test]
----
from graphdatascience.session import GdsPropertyTypes

G, result = gds.graph.project(
"people-and-fruits",
"""
Expand All @@ -201,15 +199,6 @@ G, result = gds.graph.project(
relationshipType: type(rel)
})
""",
nodePropertySchema={
"age": GdsPropertyTypes.LONG,
"experience": GdsPropertyTypes.LONG,
"hipster": GdsPropertyTypes.LONG,
"tropical": GdsPropertyTypes.LONG,
"sourness": GdsPropertyTypes.DOUBLE,
"sweetness": GdsPropertyTypes.DOUBLE,
},
relationshipPropertySchema={},
)

str(G)
Expand Down
11 changes: 0 additions & 11 deletions examples/gds-sessions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,6 @@
"metadata": {},
"outputs": [],
"source": [
"from graphdatascience.session import GdsPropertyTypes\n",
"\n",
"G, result = gds.graph.project(\n",
" \"people-and-fruits\",\n",
" \"\"\"\n",
Expand All @@ -264,15 +262,6 @@
" relationshipType: type(rel)\n",
" })\n",
" \"\"\",\n",
" nodePropertySchema={\n",
" \"age\": GdsPropertyTypes.LONG,\n",
" \"experience\": GdsPropertyTypes.LONG,\n",
" \"hipster\": GdsPropertyTypes.LONG,\n",
" \"tropical\": GdsPropertyTypes.LONG,\n",
" \"sourness\": GdsPropertyTypes.DOUBLE,\n",
" \"sweetness\": GdsPropertyTypes.DOUBLE,\n",
" },\n",
" relationshipPropertySchema={},\n",
")\n",
"\n",
"str(G)"
Expand Down
2 changes: 1 addition & 1 deletion graphdatascience/graph/graph_entity_ops_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def add_property(query: str, prop: str) -> str:
return reduce(add_property, db_node_properties, query_prefix)

@compatible_with("write", min_inclusive=ServerVersion(2, 2, 0))
def write(self, G: Graph, node_properties: List[str], node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
def write(self, G: Graph, node_properties: Strings, node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
self._namespace += ".write"
return self._handle_properties(G, node_properties, node_labels, config).squeeze() # type: ignore

Expand Down
1 change: 0 additions & 1 deletion graphdatascience/graph/graph_remote_proc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@
class GraphRemoteProcRunner(BaseGraphProcRunner):
@property
def project(self) -> GraphProjectRemoteRunner:
self._namespace += ".project.remoteDb"
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
37 changes: 20 additions & 17 deletions graphdatascience/graph/graph_remote_project_runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import Any
from typing import List, Optional

from ..error.illegal_attr_checker import IllegalAttrChecker
from ..query_runner.aura_db_query_runner import AuraDbQueryRunner
from ..server_version.compatible_with import compatible_with
from .graph_object import Graph
from graphdatascience.call_parameters import CallParameters
Expand All @@ -11,28 +12,30 @@


class GraphProjectRemoteRunner(IllegalAttrChecker):
_SCHEMA_KEYS = ["nodePropertySchema", "relationshipPropertySchema"]
@compatible_with("project", min_inclusive=ServerVersion(2, 7, 0))
def __call__(
self,
graph_name: str,
query: str,
concurrency: int = 4,
undirected_relationship_types: Optional[List[str]] = None,
inverse_indexed_relationship_types: Optional[List[str]] = None,
) -> GraphCreateResult:
if inverse_indexed_relationship_types is None:
inverse_indexed_relationship_types = []
if undirected_relationship_types is None:
undirected_relationship_types = []

@compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
placeholder = "<>" # host and token will be added by query runner
self.map_property_types(config)
params = CallParameters(
graph_name=graph_name,
query=query,
token=placeholder,
host=placeholder,
remote_database=self._query_runner.database(),
config=config,
concurrency=concurrency,
undirected_relationship_types=undirected_relationship_types,
inverse_indexed_relationship_types=inverse_indexed_relationship_types,
)

result = self._query_runner.call_procedure(
endpoint=self._namespace,
endpoint=AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME,
params=params,
).squeeze()
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)

@staticmethod
def map_property_types(config: dict[str, Any]) -> None:
for key in GraphProjectRemoteRunner._SCHEMA_KEYS:
if key in config:
config[key] = {k: v.value for k, v in config[key].items()}
45 changes: 7 additions & 38 deletions graphdatascience/query_runner/arrow_graph_constructor.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from __future__ import annotations

import concurrent
import json
import math
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, NoReturn, Optional

import numpy
import pyarrow.flight as flight
from pandas import DataFrame
from pyarrow import Table
from tqdm.auto import tqdm

from .arrow_endpoint_version import ArrowEndpointVersion
from .gds_arrow_client import GdsArrowClient
from .graph_constructor import GraphConstructor


Expand All @@ -22,17 +20,15 @@ def __init__(
self,
database: str,
graph_name: str,
flight_client: flight.FlightClient,
flight_client: GdsArrowClient,
concurrency: int,
arrow_endpoint_version: ArrowEndpointVersion,
undirected_relationship_types: Optional[List[str]],
chunk_size: int = 10_000,
):
self._database = database
self._concurrency = concurrency
self._graph_name = graph_name
self._client = flight_client
self._arrow_endpoint_version = arrow_endpoint_version
self._undirected_relationship_types = (
[] if undirected_relationship_types is None else undirected_relationship_types
)
Expand All @@ -49,20 +45,20 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
if self._undirected_relationship_types:
config["undirected_relationship_types"] = self._undirected_relationship_types

self._send_action(
self._client.send_action(
"CREATE_GRAPH",
config,
)

self._send_dfs(node_dfs, "node")

self._send_action("NODE_LOAD_DONE", {"name": self._graph_name})
self._client.send_action("NODE_LOAD_DONE", {"name": self._graph_name})

self._send_dfs(relationship_dfs, "relationship")

self._send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
self._client.send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
except (Exception, KeyboardInterrupt) as e:
self._send_action("ABORT", {"name": self._graph_name})
self._client.send_action("ABORT", {"name": self._graph_name})

raise e

Expand All @@ -85,25 +81,12 @@ def _partition_dfs(self, dfs: List[DataFrame]) -> List[DataFrame]:

return partitioned_dfs

def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
action_type = self._versioned_action_type(action_type)
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))

# Consume result fully to sanity check and avoid cancelled streams
collected_result = list(result)
assert len(collected_result) == 1

json.loads(collected_result[0].body.to_pybytes().decode())

def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None:
table = Table.from_pandas(df)
batches = table.to_batches(self._chunk_size)
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)

# Write schema
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
writer, _ = self._client.do_put(upload_descriptor, table.schema)
writer, _ = self._client.start_put(flight_descriptor, table.schema)

with writer:
# Write table in chunks
Expand All @@ -126,17 +109,3 @@ def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
if not future.exception():
continue
raise future.exception() # type: ignore

def _versioned_action_type(self, action_type: str) -> str:
return self._arrow_endpoint_version.prefix() + action_type

def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
return (
flight_descriptor
if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
else {
"name": "PUT_MESSAGE",
"version": ArrowEndpointVersion.V1.version(),
"body": flight_descriptor,
}
)
Loading