26 changes: 23 additions & 3 deletions src/graphdatascience/arrow_client/v2/gds_arrow_client.py
@@ -11,6 +11,7 @@

from graphdatascience.arrow_client.arrow_endpoint_version import ArrowEndpointVersion
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo
from graphdatascience.query_runner.termination_flag import TerminationFlag

from ...procedure_surface.api.default_values import ALL_TYPES
from ...procedure_surface.utils.config_converter import ConfigConverter
@@ -328,6 +329,7 @@ def upload_nodes(
data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
batch_size: int = 10000,
progress_callback: Callable[[int], None] = lambda x: None,
termination_flag: TerminationFlag | None = None,
) -> None:
"""
Uploads node data to the server for a given job.
@@ -342,15 +344,20 @@
The number of rows per batch
progress_callback
A callback function that is called with the number of rows uploaded after each batch
termination_flag
A termination flag to cancel the upload if requested
"""
self._upload_data("graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback)
self._upload_data(
"graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback, termination_flag
)
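Since progress_callback receives the per-batch row count, it maps directly onto an increment-based bar such as tqdm. A minimal caller-side sketch; `client`, `job_id`, and `node_table` are placeholders, not part of this diff:

from tqdm import tqdm

# Hypothetical usage: `client` is a GdsArrowClient, `node_table` a pyarrow.Table.
with tqdm(total=node_table.num_rows, desc="Uploading nodes") as bar:
    client.upload_nodes(
        job_id,
        node_table,
        batch_size=10_000,
        progress_callback=bar.update,  # tqdm.update takes an increment, i.e. the rows per batch
    )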

def upload_relationships(
self,
job_id: str,
data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
batch_size: int = 10000,
progress_callback: Callable[[int], None] = lambda x: None,
termination_flag: TerminationFlag | None = None,
) -> None:
"""
Uploads relationship data to the server for a given job.
@@ -365,15 +372,20 @@
The number of rows per batch
progress_callback
A callback function that is called with the number of rows uploaded after each batch
termination_flag
A termination flag to cancel the upload if requested
"""
self._upload_data("graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback)
self._upload_data(
"graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback, termination_flag
)

def upload_triplets(
self,
job_id: str,
data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
batch_size: int = 10000,
progress_callback: Callable[[int], None] = lambda x: None,
termination_flag: TerminationFlag | None = None,
) -> None:
"""
Uploads triplet data to the server for a given job.
@@ -388,8 +400,10 @@
The number of rows per batch
progress_callback
A callback function that is called with the number of rows uploaded after each batch
termination_flag
A termination flag to cancel the upload if requested
"""
self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback)
self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback, termination_flag)

def abort_job(self, job_id: str) -> None:
"""
@@ -464,6 +478,7 @@ def _upload_data(
data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
batch_size: int = 10000,
progress_callback: Callable[[int], None] = lambda x: None,
termination_flag: TerminationFlag | None = None,
) -> None:
match data:
case pyarrow.Table():
@@ -490,6 +505,11 @@ def upload_batch(p: RecordBatch) -> None:

with put_stream:
for partition in batches:
if termination_flag is not None and termination_flag.is_set():
self.abort_job(job_id)
# Closing the put_stream should raise an error; this is a safeguard to always signal the termination to the user.
raise RuntimeError(f"Upload for job '{job_id}' was aborted via termination flag.")

upload_batch(partition)
ack_stream.read()
progress_callback(partition.num_rows)
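The check above implements cooperative cancellation: the flag is polled between batches, the server-side job is aborted, and the RuntimeError makes the cancellation visible even if closing the stream does not raise on its own. A sketch of the same pattern with a plain threading.Event standing in for TerminationFlag (whose setter API is not shown in this diff), and send() as a placeholder for the batch upload and ack handling:

import threading

cancel = threading.Event()

def upload(batches):
    for batch in batches:
        if cancel.is_set():
            # abort the server-side job here, then fail loudly
            raise RuntimeError("Upload was aborted via termination flag.")
        send(batch)  # placeholder for upload_batch + ack_stream.read()

# From another thread (e.g. a signal handler or UI callback): cancel.set()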
10 changes: 8 additions & 2 deletions src/graphdatascience/arrow_client/v2/job_client.py
@@ -39,20 +39,26 @@ def wait_for_job(
client: AuthenticatedArrowClient,
job_id: str,
show_progress: bool,
expected_status: str | None = None,
termination_flag: TerminationFlag | None = None,
) -> None:
progress_bar: TqdmProgressBar | None = None

def check_expected_status(status: JobStatus) -> bool:
return status.succeeded() if expected_status is None else status.status == expected_status

if termination_flag is None:
termination_flag = TerminationFlag.create()

for attempt in Retrying(retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5)):
for attempt in Retrying(
retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5), reraise=True
):
with attempt:
termination_flag.assert_running()

job_status = self.get_job_status(client, job_id)

if job_status.succeeded() or job_status.aborted():
if check_expected_status(job_status) or job_status.aborted():
if progress_bar:
progress_bar.finish(success=job_status.succeeded())
return
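The loop above leans on tenacity's iterator API: retry_if_result(lambda _: True) retries unconditionally, since the attempt's context manager records a None result on a clean exit, so the only way out is the explicit return; reraise=True makes exceptions from assert_running or the status call propagate directly instead of being wrapped in a RetryError. A condensed sketch of the same polling pattern:

from tenacity import Retrying, retry_if_result, wait_exponential

def wait_until(poll, is_done):
    # Polls forever with exponential backoff (0.1s to 5s); exits via `return`.
    for attempt in Retrying(
        retry=retry_if_result(lambda _: True),
        wait=wait_exponential(min=0.1, max=5),
        reraise=True,  # surface exceptions from poll() unchanged
    ):
        with attempt:
            status = poll()
            if is_done(status):
                return status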
@@ -4,6 +4,8 @@
from types import TracebackType
from typing import NamedTuple, Type

from pandas import DataFrame

from graphdatascience.procedure_surface.api.base_result import BaseResult
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
from graphdatascience.procedure_surface.api.catalog.graph_info import GraphInfo, GraphInfoWithDegrees
@@ -14,31 +16,76 @@


class CatalogEndpoints(ABC):
@abstractmethod
def construct(
self,
graph_name: str,
nodes: DataFrame | list[DataFrame],
relationships: DataFrame | list[DataFrame] | None = None,
concurrency: int | None = None,
undirected_relationship_types: list[str] | None = None,
) -> GraphV2:
"""Construct a graph from a list of node and relationship dataframes.

Parameters
----------
graph_name
Name of the graph to construct
nodes
Node dataframes. A dataframe should follow the schema:

- `nodeId` to uniquely identify the node across all dataframes
- `labels` to specify the labels of the node as a list of strings (optional)
- other columns are treated as node properties
relationships
Relationship dataframes. A dataframe should follow the schema:

- `sourceNodeId` to identify the start node of the relationship
- `targetNodeId` to identify the end node of the relationship
- `relationshipType` to specify the type of the relationship (optional)
- other columns are treated as relationship properties
concurrency
Number of concurrent threads to use.
undirected_relationship_types
List of relationship types to treat as undirected.

Returns
-------
GraphV2
Constructed graph object.
"""

@abstractmethod
def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
"""List graphs in the graph catalog.

Args:
G (GraphV2 | str | None, optional): GraphV2 object or name to filter results.
If None, list all graphs. Defaults to None.
Parameters
----------
G
GraphV2 object or name to filter results. If None, list all graphs.

Returns:
list[GraphListResult]: List of graph metadata objects containing information like
graph name, node count, relationship count, etc.
Returns
-------
list[GraphInfoWithDegrees]
List of graph metadata objects containing information like node count.
"""
pass

@abstractmethod
def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
"""Drop a graph from the graph catalog.

Args:
G (GraphV2 | str): GraphV2 object or name to drop.
fail_if_missing (bool): Whether to fail if the graph is missing. Defaults to True.
Parameters
----------
G
Graph to drop, given by name or object.
fail_if_missing
Whether to fail if the graph is missing

Returns:
GraphListResult: GraphV2 metadata object containing information like
graph name, node count, relationship count, etc.
Returns
-------
GraphInfo | None
Graph metadata object containing information like node count; None if the graph is missing and fail_if_missing is False.
"""

@abstractmethod
@@ -68,9 +115,10 @@ def filter(
job_id
Identifier for the computation.

Returns:
GraphWithFilterResult: tuple of the filtered graph object and the information like
graph name, node count, relationship count, etc.
Returns
-------
GraphWithFilterResult
Tuple of the filtered graph object and metadata such as graph name, node count, and relationship count.
"""
pass

@@ -5,7 +5,10 @@
from typing import Any, NamedTuple, Type
from uuid import uuid4

from pandas import DataFrame

from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient
from graphdatascience.arrow_client.v2.job_client import JobClient
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
from graphdatascience.procedure_surface.api.base_result import BaseResult
@@ -29,6 +32,7 @@
)
from graphdatascience.procedure_surface.arrow.catalog.relationship_arrow_endpoints import RelationshipArrowEndpoints
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
from graphdatascience.query_runner.progress.progress_bar import NoOpProgressBar, ProgressBar, TqdmProgressBar
from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol
from graphdatascience.query_runner.query_runner import QueryRunner
from graphdatascience.query_runner.termination_flag import TerminationFlag
@@ -52,15 +56,6 @@ def __init__(
protocol_version = ProtocolVersionResolver(query_runner).resolve()
self._project_protocol = ProjectProtocol.select(protocol_version)

def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
graph_name: str | None = None
if isinstance(G, GraphV2):
graph_name = G.name()
elif isinstance(G, str):
graph_name = G

return self._graph_backend.list(graph_name)

def project(
self,
graph_name: str,
@@ -137,6 +132,79 @@ def project(

return GraphWithProjectResult(get_graph(graph_name, self._arrow_client), job_result)

def construct(
self,
graph_name: str,
nodes: DataFrame | list[DataFrame],
relationships: DataFrame | list[DataFrame] | None = None,
concurrency: int | None = None,
undirected_relationship_types: list[str] | None = None,
) -> GraphV2:
gds_arrow_client = GdsArrowClient(self._arrow_client)
job_client = JobClient()
termination_flag = TerminationFlag.create()

if self._show_progress:
progress_bar: ProgressBar = TqdmProgressBar(task_name="Constructing graph", relative_progress=0.0)
else:
progress_bar = NoOpProgressBar()

with progress_bar:
create_job_id: str = gds_arrow_client.create_graph(
graph_name=graph_name,
undirected_relationship_types=undirected_relationship_types or [],
concurrency=concurrency,
)
node_count = nodes.shape[0] if isinstance(nodes, DataFrame) else sum(df.shape[0] for df in nodes)
if isinstance(relationships, DataFrame):
rel_count = relationships.shape[0]
elif relationships is None:
rel_count = 0
relationships = []
else:
rel_count = sum(df.shape[0] for df in relationships)
total_count = node_count + rel_count

gds_arrow_client.upload_nodes(
create_job_id,
nodes,
progress_callback=lambda rows_imported: progress_bar.update(
sub_tasks_description="Uploading nodes", progress=rows_imported / total_count, status="Running"
),
termination_flag=termination_flag,
)

gds_arrow_client.node_load_done(create_job_id)

# Skip the per-job progress bar here as we have our own for the overall process
job_client.wait_for_job(
self._arrow_client,
create_job_id,
expected_status="RELATIONSHIP_LOADING",
termination_flag=termination_flag,
show_progress=False,
)

if rel_count > 0:
gds_arrow_client.upload_relationships(
create_job_id,
relationships,
progress_callback=lambda rows_imported: progress_bar.update(
sub_tasks_description="Uploading relationships",
progress=rows_imported / total_count,
status="Running",
),
termination_flag=termination_flag,
)

gds_arrow_client.relationship_load_done(create_job_id)

# This will produce a second progress bar showing graph construction on the server side
job_client.wait_for_job(
self._arrow_client, create_job_id, termination_flag=termination_flag, show_progress=True
)
return get_graph(graph_name, self._arrow_client)

def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
graph_name = G.name() if isinstance(G, GraphV2) else G

@@ -212,6 +280,15 @@ def generate(
GraphGenerationStats(**JobClient.get_summary(self._arrow_client, job_id)),
)

def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
graph_name: str | None = None
if isinstance(G, GraphV2):
graph_name = G.name()
elif isinstance(G, str):
graph_name = G

return self._graph_backend.list(graph_name)

@property
def sample(self) -> GraphSamplingEndpoints:
return GraphSamplingArrowEndpoints(self._arrow_client, show_progress=self._show_progress)
@@ -10,6 +10,7 @@
NodePropertySpec,
)
from graphdatascience.procedure_surface.api.default_values import ALL_LABELS
from graphdatascience.procedure_surface.cypher.catalog.utils import require_database
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
from graphdatascience.procedure_surface.utils.result_utils import join_db_node_properties, transpose_property_columns
from graphdatascience.query_runner.query_runner import QueryRunner
@@ -35,9 +36,7 @@ def stream(
db_node_properties: list[str] | None = None,
) -> DataFrame:
if self._gds_arrow_client is not None:
database = self._query_runner.database()
if database is None:
raise ValueError("The database is not set")
database = require_database(self._query_runner)

result = self._gds_arrow_client.get_node_properties(
G.name(), database, node_properties, node_labels, list_node_labels or False, concurrency
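The inline check replaced above suggests require_database simply centralizes the "database must be set" guard. A sketch of what such a helper presumably looks like (the real implementation in procedure_surface/cypher/catalog/utils.py may differ):

from graphdatascience.query_runner.query_runner import QueryRunner

def require_database(query_runner: QueryRunner) -> str:
    # Mirrors the inline check this diff removes: return the configured
    # database name, or raise if none is set.
    database = query_runner.database()
    if database is None:
        raise ValueError("The database is not set")
    return database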