diff --git a/src/graphdatascience/arrow_client/v2/gds_arrow_client.py b/src/graphdatascience/arrow_client/v2/gds_arrow_client.py index 1dc7e14cb..9a6974b92 100644 --- a/src/graphdatascience/arrow_client/v2/gds_arrow_client.py +++ b/src/graphdatascience/arrow_client/v2/gds_arrow_client.py @@ -11,6 +11,7 @@ from graphdatascience.arrow_client.arrow_endpoint_version import ArrowEndpointVersion from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo +from graphdatascience.query_runner.termination_flag import TerminationFlag from ...procedure_surface.api.default_values import ALL_TYPES from ...procedure_surface.utils.config_converter import ConfigConverter @@ -328,6 +329,7 @@ def upload_nodes( data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame, batch_size: int = 10000, progress_callback: Callable[[int], None] = lambda x: None, + termination_flag: TerminationFlag | None = None, ) -> None: """ Uploads node data to the server for a given job. @@ -342,8 +344,12 @@ def upload_nodes( The number of rows per batch progress_callback A callback function that is called with the number of rows uploaded after each batch + termination_flag + A termination flag to cancel the upload if requested """ - self._upload_data("graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback) + self._upload_data( + "graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback, termination_flag + ) def upload_relationships( self, @@ -351,6 +357,7 @@ def upload_relationships( data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame, batch_size: int = 10000, progress_callback: Callable[[int], None] = lambda x: None, + termination_flag: TerminationFlag | None = None, ) -> None: """ Uploads relationship data to the server for a given job. 
@@ -365,8 +372,12 @@ def upload_relationships( The number of rows per batch progress_callback A callback function that is called with the number of rows uploaded after each batch + termination_flag + A termination flag to cancel the upload if requested """ - self._upload_data("graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback) + self._upload_data( + "graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback, termination_flag + ) def upload_triplets( self, @@ -374,6 +385,7 @@ def upload_triplets( data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame, batch_size: int = 10000, progress_callback: Callable[[int], None] = lambda x: None, + termination_flag: TerminationFlag | None = None, ) -> None: """ Uploads triplet data to the server for a given job. @@ -388,8 +400,10 @@ def upload_triplets( The number of rows per batch progress_callback A callback function that is called with the number of rows uploaded after each batch + termination_flag + A termination flag to cancel the upload if requested """ - self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback) + self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback, termination_flag) def abort_job(self, job_id: str) -> None: """ @@ -464,6 +478,7 @@ def _upload_data( data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame, batch_size: int = 10000, progress_callback: Callable[[int], None] = lambda x: None, + termination_flag: TerminationFlag | None = None, ) -> None: match data: case pyarrow.Table(): @@ -490,6 +505,11 @@ def upload_batch(p: RecordBatch) -> None: with put_stream: for partition in batches: + if termination_flag is not None and termination_flag.is_set(): + self.abort_job(job_id) + # closing the put_stream should raise an error. this is a safeguard to always signal the termination to the user. 
+ raise RuntimeError(f"Upload for job '{job_id}' was aborted via termination flag.") + upload_batch(partition) ack_stream.read() progress_callback(partition.num_rows) diff --git a/src/graphdatascience/arrow_client/v2/job_client.py b/src/graphdatascience/arrow_client/v2/job_client.py index 7800206fc..13ffd738e 100644 --- a/src/graphdatascience/arrow_client/v2/job_client.py +++ b/src/graphdatascience/arrow_client/v2/job_client.py @@ -39,20 +39,26 @@ def wait_for_job( client: AuthenticatedArrowClient, job_id: str, show_progress: bool, + expected_status: str | None = None, termination_flag: TerminationFlag | None = None, ) -> None: progress_bar: TqdmProgressBar | None = None + def check_expected_status(status: JobStatus) -> bool: + return status.succeeded() if expected_status is None else status.status == expected_status + if termination_flag is None: termination_flag = TerminationFlag.create() - for attempt in Retrying(retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5)): + for attempt in Retrying( + retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5), reraise=True + ) with attempt: termination_flag.assert_running() job_status = self.get_job_status(client, job_id) - if job_status.succeeded() or job_status.aborted(): + if check_expected_status(job_status) or job_status.aborted(): if progress_bar: progress_bar.finish(success=job_status.succeeded()) return diff --git a/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py b/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py index 82111490a..2181b9747 100644 --- a/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py +++ b/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py @@ -4,6 +4,8 @@ from types import TracebackType from typing import NamedTuple, Type +from pandas import DataFrame + from graphdatascience.procedure_surface.api.base_result import BaseResult from 
graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 from graphdatascience.procedure_surface.api.catalog.graph_info import GraphInfo, GraphInfoWithDegrees @@ -14,17 +16,58 @@ class CatalogEndpoints(ABC): + @abstractmethod + def construct( + self, + graph_name: str, + nodes: DataFrame | list[DataFrame], + relationships: DataFrame | list[DataFrame] | None = None, + concurrency: int | None = None, + undirected_relationship_types: list[str] | None = None, + ) -> GraphV2: + """Construct a graph from a list of node and relationship dataframes. + + Parameters + ---------- + graph_name + Name of the graph to construct + nodes + Node dataframes. A dataframe should follow the schema: + + - `nodeId` to identify uniquely the node overall dataframes + - `labels` to specify the labels of the node as a list of strings (optional) + - other columns are treated as node properties + relationships + Relationship dataframes. A dataframe should follow the schema: + + - `sourceNodeId` to identify the start node of the relationship + - `targetNodeId` to identify the end node of the relationship + - `relationshipType` to specify the type of the relationship (optional) + - other columns are treated as relationship properties + concurrency + Number of concurrent threads to use. + undirected_relationship_types + List of relationship types to treat as undirected. + + Returns + ------- + GraphV2 + Constructed graph object. + """ + @abstractmethod def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]: """List graphs in the graph catalog. - Args: - G (GraphV2 | str | None, optional): GraphV2 object or name to filter results. - If None, list all graphs. Defaults to None. + Parameters + ---------- + G + GraphV2 object or name to filter results. If None, list all graphs. - Returns: - list[GraphListResult]: List of graph metadata objects containing information like - graph name, node count, relationship count, etc. 
+ Returns + ------- + list[GraphInfoWithDegrees] + List of graph metadata objects containing information like node count. """ pass @@ -32,13 +75,17 @@ def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]: def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None: """Drop a graph from the graph catalog. - Args: - G (GraphV2 | str): GraphV2 object or name to drop. - fail_if_missing (bool): Whether to fail if the graph is missing. Defaults to True. + Parameters + ---------- + G + Graph to drop, by name or object. + fail_if_missing + Whether to fail if the graph is missing - Returns: - GraphListResult: GraphV2 metadata object containing information like - graph name, node count, relationship count, etc. + Returns + ------- + GraphInfo | None + GraphV2 metadata object containing information like node count. """ @abstractmethod @@ -68,9 +115,10 @@ def filter( job_id Identifier for the computation. - Returns: - GraphWithFilterResult: tuple of the filtered graph object and the information like - graph name, node count, relationship count, etc. + Returns + ------- + GraphWithFilterResult: + tuple of the filtered graph object and the information like graph name, node count, relationship count, etc. 
""" pass diff --git a/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py b/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py index b72e56148..89131d4ca 100644 --- a/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py +++ b/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py @@ -5,7 +5,10 @@ from typing import Any, NamedTuple, Type from uuid import uuid4 +from pandas import DataFrame + from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient from graphdatascience.arrow_client.v2.job_client import JobClient from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient from graphdatascience.procedure_surface.api.base_result import BaseResult @@ -29,6 +32,7 @@ ) from graphdatascience.procedure_surface.arrow.catalog.relationship_arrow_endpoints import RelationshipArrowEndpoints from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter +from graphdatascience.query_runner.progress.progress_bar import NoOpProgressBar, ProgressBar, TqdmProgressBar from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol from graphdatascience.query_runner.query_runner import QueryRunner from graphdatascience.query_runner.termination_flag import TerminationFlag @@ -52,15 +56,6 @@ def __init__( protocol_version = ProtocolVersionResolver(query_runner).resolve() self._project_protocol = ProjectProtocol.select(protocol_version) - def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]: - graph_name: str | None = None - if isinstance(G, GraphV2): - graph_name = G.name() - elif isinstance(G, str): - graph_name = G - - return self._graph_backend.list(graph_name) - def project( self, graph_name: str, @@ -137,6 +132,79 @@ def project( return 
GraphWithProjectResult(get_graph(graph_name, self._arrow_client), job_result) + def construct( + self, + graph_name: str, + nodes: DataFrame | list[DataFrame], + relationships: DataFrame | list[DataFrame] | None = None, + concurrency: int | None = None, + undirected_relationship_types: list[str] | None = None, + ) -> GraphV2: + gds_arrow_client = GdsArrowClient(self._arrow_client) + job_client = JobClient() + termination_flag = TerminationFlag.create() + + if self._show_progress: + progress_bar: ProgressBar = TqdmProgressBar(task_name="Constructing graph", relative_progress=0.0) + else: + progress_bar = NoOpProgressBar() + + with progress_bar: + create_job_id: str = gds_arrow_client.create_graph( + graph_name=graph_name, + undirected_relationship_types=undirected_relationship_types or [], + concurrency=concurrency, + ) + node_count = nodes.shape[0] if isinstance(nodes, DataFrame) else sum(df.shape[0] for df in nodes) + if isinstance(relationships, DataFrame): + rel_count = relationships.shape[0] + elif relationships is None: + rel_count = 0 + relationships = [] + else: + rel_count = sum(df.shape[0] for df in relationships) + total_count = node_count + rel_count + + gds_arrow_client.upload_nodes( + create_job_id, + nodes, + progress_callback=lambda rows_imported: progress_bar.update( + sub_tasks_description="Uploading nodes", progress=rows_imported / total_count, status="Running" + ), + termination_flag=termination_flag, + ) + + gds_arrow_client.node_load_done(create_job_id) + + # skipping progress bar here as we have our own for the overall process + job_client.wait_for_job( + self._arrow_client, + create_job_id, + expected_status="RELATIONSHIP_LOADING", + termination_flag=termination_flag, + show_progress=False, + ) + + if rel_count > 0: + gds_arrow_client.upload_relationships( + create_job_id, + relationships, + progress_callback=lambda rows_imported: progress_bar.update( + sub_tasks_description="Uploading relationships", + progress=rows_imported / total_count, + 
status="Running", + ), + termination_flag=termination_flag, + ) + + gds_arrow_client.relationship_load_done(create_job_id) + + # will produce a second progress bar to show graph construction on the server side + job_client.wait_for_job( + self._arrow_client, create_job_id, termination_flag=termination_flag, show_progress=True + ) + return get_graph(graph_name, self._arrow_client) + def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None: graph_name = G.name() if isinstance(G, GraphV2) else G @@ -212,6 +280,15 @@ def generate( GraphGenerationStats(**JobClient.get_summary(self._arrow_client, job_id)), ) + def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]: + graph_name: str | None = None + if isinstance(G, GraphV2): + graph_name = G.name() + elif isinstance(G, str): + graph_name = G + + return self._graph_backend.list(graph_name) + @property def sample(self) -> GraphSamplingEndpoints: return GraphSamplingArrowEndpoints(self._arrow_client, show_progress=self._show_progress) diff --git a/src/graphdatascience/procedure_surface/cypher/catalog/node_properties_cypher_endpoints.py b/src/graphdatascience/procedure_surface/cypher/catalog/node_properties_cypher_endpoints.py index ae24b6a05..ef0409c0c 100644 --- a/src/graphdatascience/procedure_surface/cypher/catalog/node_properties_cypher_endpoints.py +++ b/src/graphdatascience/procedure_surface/cypher/catalog/node_properties_cypher_endpoints.py @@ -10,6 +10,7 @@ NodePropertySpec, ) from graphdatascience.procedure_surface.api.default_values import ALL_LABELS +from graphdatascience.procedure_surface.cypher.catalog.utils import require_database from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter from graphdatascience.procedure_surface.utils.result_utils import join_db_node_properties, transpose_property_columns from graphdatascience.query_runner.query_runner import QueryRunner @@ -35,9 +36,7 @@ def stream( db_node_properties: list[str] | None 
= None, ) -> DataFrame: if self._gds_arrow_client is not None: - database = self._query_runner.database() - if database is None: - raise ValueError("The database is not set") + database = require_database(self._query_runner) result = self._gds_arrow_client.get_node_properties( G.name(), database, node_properties, node_labels, list_node_labels or False, concurrency diff --git a/src/graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py b/src/graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py index b19a5eca1..ecda2359a 100644 --- a/src/graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py +++ b/src/graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py @@ -13,6 +13,7 @@ RelationshipsWriteResult, ) from graphdatascience.procedure_surface.api.default_values import ALL_TYPES +from graphdatascience.procedure_surface.cypher.catalog.utils import require_database from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter from graphdatascience.query_runner.query_runner import QueryRunner @@ -36,9 +37,7 @@ def stream( effective_rel_types = relationship_types if relationship_types is not None else ["*"] if self._gds_arrow_client is not None: - database = self._query_runner.database() - if database is None: - raise ValueError("The database is not set") + database = require_database(self._query_runner) if relationship_properties: return self._gds_arrow_client.get_relationship_properties( diff --git a/src/graphdatascience/procedure_surface/cypher/catalog/utils.py b/src/graphdatascience/procedure_surface/cypher/catalog/utils.py new file mode 100644 index 000000000..74a9aa613 --- /dev/null +++ b/src/graphdatascience/procedure_surface/cypher/catalog/utils.py @@ -0,0 +1,12 @@ +from graphdatascience.query_runner.query_runner import QueryRunner + + +def require_database(query_runner: QueryRunner) -> str: + database = query_runner.database() + 
if database is None: + raise ValueError( + "For this call you must have explicitly specified a valid Neo4j database to target, " + "using `gds.set_database`." + ) + + return database diff --git a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py index da218de60..0c217192e 100644 --- a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py +++ b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py @@ -4,6 +4,8 @@ from types import TracebackType from typing import Any, NamedTuple, Type +from pandas import DataFrame + from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient from graphdatascience.procedure_surface.api.catalog.catalog_endpoints import ( CatalogEndpoints, @@ -16,7 +18,11 @@ from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 from graphdatascience.procedure_surface.api.catalog.graph_info import GraphInfo, GraphInfoWithDegrees from graphdatascience.procedure_surface.api.catalog.graph_sampling_endpoints import GraphSamplingEndpoints -from graphdatascience.procedure_surface.cypher.catalog.graph_backend_cypher import get_graph +from graphdatascience.procedure_surface.cypher.catalog.graph_backend_cypher import CypherGraphBackend, get_graph +from graphdatascience.procedure_surface.cypher.catalog.utils import require_database +from graphdatascience.query_runner.arrow_graph_constructor import ArrowGraphConstructor +from graphdatascience.query_runner.cypher_graph_constructor import CypherGraphConstructor +from graphdatascience.query_runner.graph_constructor import GraphConstructor from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from ...call_parameters import CallParameters @@ -30,14 +36,51 @@ class CatalogCypherEndpoints(CatalogEndpoints): def __init__(self, cypher_runner: Neo4jQueryRunner, arrow_client: GdsArrowClient | None = None): - 
self.cypher_runner = cypher_runner + self._cypher_runner = cypher_runner self._arrow_client = arrow_client + def construct( + self, + graph_name: str, + nodes: DataFrame | list[DataFrame], + relationships: DataFrame | list[DataFrame] | None = None, + concurrency: int | None = None, + undirected_relationship_types: list[str] | None = None, + ) -> GraphV2: + if isinstance(nodes, DataFrame): + nodes = [nodes] + if relationships is None: + relationships = [] + elif isinstance(relationships, DataFrame): + relationships = [relationships] + + graph_constructor: GraphConstructor + if self._arrow_client is not None: + database = require_database(self._cypher_runner) + + graph_constructor = ArrowGraphConstructor( + database=database, + graph_name=graph_name, + flight_client=self._arrow_client, + concurrency=concurrency, + undirected_relationship_types=undirected_relationship_types, + ) + else: + graph_constructor = CypherGraphConstructor( + query_runner=self._cypher_runner, + graph_name=graph_name, + concurrency=concurrency, + undirected_relationship_types=undirected_relationship_types, + ) + + graph_constructor.run(node_dfs=nodes, relationship_dfs=relationships) + return GraphV2(name=graph_name, backend=CypherGraphBackend(graph_name, self._cypher_runner)) + def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]: graph_name = G if isinstance(G, str) else G.name() if G is not None else None params = CallParameters(graphName=graph_name) if graph_name else CallParameters() - result = self.cypher_runner.call_procedure(endpoint="gds.graph.list", params=params) + result = self._cypher_runner.call_procedure(endpoint="gds.graph.list", params=params) return [GraphInfoWithDegrees(**row.to_dict()) for _, row in result.iterrows()] def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None: @@ -49,7 +92,7 @@ def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | No else CallParameters(graphName=graph_name) ) - result = 
self.cypher_runner.call_procedure(endpoint="gds.graph.drop", params=params) + result = self._cypher_runner.call_procedure(endpoint="gds.graph.drop", params=params) if len(result) > 0: return GraphInfo(**result.iloc[0].to_dict()) else: @@ -85,11 +128,11 @@ def project( ) params.ensure_job_id_in_config() - result = self.cypher_runner.call_procedure( + result = self._cypher_runner.call_procedure( endpoint="gds.graph.project", params=params, logging=log_progress ).squeeze() project_result = GraphProjectResult(**result.to_dict()) - return GraphWithProjectResult(get_graph(project_result.graph_name, self.cypher_runner), project_result) + return GraphWithProjectResult(get_graph(project_result.graph_name, self._cypher_runner), project_result) def filter( self, @@ -115,10 +158,10 @@ def filter( ) params.ensure_job_id_in_config() - result = self.cypher_runner.call_procedure( + result = self._cypher_runner.call_procedure( endpoint="gds.graph.filter", params=params, logging=log_progress ).squeeze() - return GraphWithFilterResult(get_graph(graph_name, self.cypher_runner), GraphFilterResult(**result.to_dict())) + return GraphWithFilterResult(get_graph(graph_name, self._cypher_runner), GraphFilterResult(**result.to_dict())) def generate( self, @@ -159,28 +202,28 @@ def generate( params.ensure_job_id_in_config() - result = self.cypher_runner.call_procedure( + result = self._cypher_runner.call_procedure( endpoint="gds.graph.generate", params=params, logging=log_progress ).squeeze() return GraphWithGenerationStats( - get_graph(graph_name, self.cypher_runner), GraphGenerationStats(**result.to_dict()) + get_graph(graph_name, self._cypher_runner), GraphGenerationStats(**result.to_dict()) ) @property def sample(self) -> GraphSamplingEndpoints: - return GraphSamplingCypherEndpoints(self.cypher_runner) + return GraphSamplingCypherEndpoints(self._cypher_runner) @property def node_labels(self) -> NodeLabelCypherEndpoints: - return NodeLabelCypherEndpoints(self.cypher_runner) + return 
NodeLabelCypherEndpoints(self._cypher_runner) @property def node_properties(self) -> NodePropertiesCypherEndpoints: - return NodePropertiesCypherEndpoints(self.cypher_runner, self._arrow_client) + return NodePropertiesCypherEndpoints(self._cypher_runner, self._arrow_client) @property def relationships(self) -> RelationshipCypherEndpoints: - return RelationshipCypherEndpoints(self.cypher_runner, self._arrow_client) + return RelationshipCypherEndpoints(self._cypher_runner, self._arrow_client) class GraphProjectResult(BaseResult): diff --git a/src/graphdatascience/query_runner/arrow_graph_constructor.py b/src/graphdatascience/query_runner/arrow_graph_constructor.py index 7e49a197d..e71680ac3 100644 --- a/src/graphdatascience/query_runner/arrow_graph_constructor.py +++ b/src/graphdatascience/query_runner/arrow_graph_constructor.py @@ -22,8 +22,8 @@ def __init__( database: str, graph_name: str, flight_client: GdsArrowClient, - concurrency: int, - undirected_relationship_types: list[str] | None, + concurrency: int | None = None, + undirected_relationship_types: list[str] | None = None, chunk_size: int = 10_000, ): self._database = database diff --git a/src/graphdatascience/query_runner/cypher_graph_constructor.py b/src/graphdatascience/query_runner/cypher_graph_constructor.py index d1c8f42bf..12ca8edd7 100644 --- a/src/graphdatascience/query_runner/cypher_graph_constructor.py +++ b/src/graphdatascience/query_runner/cypher_graph_constructor.py @@ -58,14 +58,13 @@ def __init__( self, query_runner: QueryRunner, graph_name: str, - concurrency: int, - undirected_relationship_types: list[str] | None, - server_version: ServerVersion, + concurrency: int | None = None, + undirected_relationship_types: list[str] | None = None, ): self._query_runner = query_runner self._concurrency = concurrency self._graph_name = graph_name - self._server_version = server_version + self._server_version = query_runner.server_version() self._undirected_relationship_types = 
undirected_relationship_types def run(self, node_dfs: list[DataFrame], relationship_dfs: list[DataFrame]) -> None: @@ -81,9 +80,9 @@ def run(self, node_dfs: list[DataFrame], relationship_dfs: list[DataFrame]) -> N self.CypherProjectionRunner( self._query_runner, self._graph_name, + self._server_version, self._concurrency, self._undirected_relationship_types, - self._server_version, ).run(node_dfs, relationship_dfs) else: assert not self._undirected_relationship_types, "This should have been raised earlier." @@ -130,9 +129,9 @@ def __init__( self, query_runner: QueryRunner, graph_name: str, - concurrency: int, - undirected_relationship_types: list[str] | None, server_version: ServerVersion, + concurrency: int | None = None, + undirected_relationship_types: list[str] | None = None, ): self._query_runner = query_runner self._concurrency = concurrency @@ -359,9 +358,9 @@ def rels_config_part(self, rel_cols: list[EntityColumnSchema], rel_properties_ke return rels_config_fields class LegacyCypherProjectionRunner: - def __init__(self, query_runner: QueryRunner, graph_name: str, concurrency: int): + def __init__(self, query_runner: QueryRunner, graph_name: str, concurrency: int | None = None): self._query_runner = query_runner - self._concurrency = concurrency + self._concurrency = concurrency if concurrency is not None else 4 self._graph_name = graph_name def run(self, node_df: DataFrame, relationship_df: DataFrame) -> None: diff --git a/src/graphdatascience/query_runner/neo4j_query_runner.py b/src/graphdatascience/query_runner/neo4j_query_runner.py index 66a6ec1f9..32a93a245 100644 --- a/src/graphdatascience/query_runner/neo4j_query_runner.py +++ b/src/graphdatascience/query_runner/neo4j_query_runner.py @@ -375,9 +375,7 @@ def __del__(self) -> None: def create_graph_constructor( self, graph_name: str, concurrency: int, undirected_relationship_types: list[str] | None ) -> GraphConstructor: - return CypherGraphConstructor( - self, graph_name, concurrency, 
undirected_relationship_types, self.server_version() - ) + return CypherGraphConstructor(self, graph_name, concurrency, undirected_relationship_types) def set_show_progress(self, show_progress: bool) -> None: self._show_progress = show_progress diff --git a/src/graphdatascience/query_runner/progress/progress_bar.py b/src/graphdatascience/query_runner/progress/progress_bar.py index 0af4a8193..b860e6630 100644 --- a/src/graphdatascience/query_runner/progress/progress_bar.py +++ b/src/graphdatascience/query_runner/progress/progress_bar.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod from types import TracebackType from typing import Any, Type @@ -8,7 +9,35 @@ from graphdatascience.query_runner.progress.progress_provider import TaskWithProgress -class TqdmProgressBar: +class ProgressBar(ABC): + @abstractmethod + def __enter__(self) -> ProgressBar: + pass + + @abstractmethod + def __exit__( + self, + exception_type: Type[BaseException] | None, + exception_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + @abstractmethod + def update( + self, + status: str, + progress: float | None, + sub_tasks_description: str | None = None, + ) -> None: + pass + + @abstractmethod + def finish(self, success: bool) -> None: + pass + + +class TqdmProgressBar(ProgressBar): def __init__(self, task_name: str, relative_progress: float | None, bar_options: dict[str, Any] = {}): root_task_name = task_name if relative_progress is None: # Qualitative progress report @@ -68,3 +97,27 @@ def _relative_progress(task: TaskWithProgress) -> float | None: return float(task.progress_percent.removesuffix("%")) except ValueError: return None + + +class NoOpProgressBar(ProgressBar): + def __enter__(self: NoOpProgressBar) -> NoOpProgressBar: + return self + + def __exit__( + self, + exception_type: Type[BaseException] | None, + exception_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def 
update( + self, + status: str, + progress: float | None, + sub_tasks_description: str | None = None, + ) -> None: + pass + + def finish(self, success: bool) -> None: + pass diff --git a/tests/integrationV2/arrow_client/v2/test_gds_arrow_client_v2.py b/tests/integrationV2/arrow_client/v2/test_gds_arrow_client_v2.py index fa487acf7..721560d08 100644 --- a/tests/integrationV2/arrow_client/v2/test_gds_arrow_client_v2.py +++ b/tests/integrationV2/arrow_client/v2/test_gds_arrow_client_v2.py @@ -10,6 +10,8 @@ from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient from graphdatascience.procedure_surface.api.catalog import GraphV2 from graphdatascience.procedure_surface.arrow.catalog import CatalogArrowEndpoints +from graphdatascience.procedure_surface.arrow.catalog.graph_backend_arrow import get_graph +from graphdatascience.query_runner.termination_flag import TerminationFlag from tests.integrationV2.conftest import GdsSessionConnectionInfo, create_arrow_client, start_session from tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph @@ -89,17 +91,32 @@ def test_project_from_triplets(arrow_client: AuthenticatedArrowClient, gds_arrow {"sourceNode": np.array([1, 2, 3], dtype=np.int64), "targetNode": np.array([4, 5, 6], dtype=np.int64)} ) - job_id = gds_arrow_client.create_graph_from_triplets("triplets") + graph_name = "triplets" + + job_id = gds_arrow_client.create_graph_from_triplets(graph_name) gds_arrow_client.upload_triplets(job_id, df) gds_arrow_client.triplet_load_done(job_id) while gds_arrow_client.job_status(job_id).status != "Done": pass - listing = CatalogArrowEndpoints(arrow_client).list("triplets")[0] - assert listing.node_count == 6 - assert listing.relationship_count == 3 - assert listing.graph_name == "triplets" + with get_graph(graph_name, arrow_client) as G: + assert G.node_count() == 6 + assert G.relationship_count() == 3 + assert G.name() == graph_name + + +def 
test_project_from_triplets_interrupted(gds_arrow_client: GdsArrowClient) -> None: + df = pd.DataFrame( + {"sourceNode": np.array([1, 2, 3], dtype=np.int64), "targetNode": np.array([4, 5, 6], dtype=np.int64)} + ) + + termination_flag = TerminationFlag.create() + termination_flag.set() + + job_id = gds_arrow_client.create_graph_from_triplets("triplets") + with pytest.raises(Exception, match=".*was aborted.*"): + gds_arrow_client.upload_triplets(job_id, df, termination_flag=termination_flag) def test_project_from_tables(arrow_client: AuthenticatedArrowClient, gds_arrow_client: GdsArrowClient) -> None: diff --git a/tests/integrationV2/procedure_surface/arrow/catalog/test_catalog_arrow_endpoints.py b/tests/integrationV2/procedure_surface/arrow/catalog/test_catalog_arrow_endpoints.py index 24faa3fdc..23396ab77 100644 --- a/tests/integrationV2/procedure_surface/arrow/catalog/test_catalog_arrow_endpoints.py +++ b/tests/integrationV2/procedure_surface/arrow/catalog/test_catalog_arrow_endpoints.py @@ -2,6 +2,7 @@ from typing import Generator import pytest +from pandas import DataFrame from pyarrow import ArrowKeyError from pyarrow._flight import FlightServerError @@ -98,6 +99,36 @@ def test_projection(arrow_client: AuthenticatedArrowClient, query_runner: QueryR endpoints.drop("g", fail_if_missing=False) +def test_construct(arrow_client: AuthenticatedArrowClient) -> None: + nodes = DataFrame( + { + "nodeId": [0, 1], + "labels": [["A"], ["B"]], + "propA": [1337, 42.1], + } + ) + relationships = DataFrame( + { + "sourceNodeId": [0, 1], + "targetNodeId": [1, 0], + "relationshipType": ["REL", "REL2"], + "relPropA": [1337.2, 42], + } + ) + + endpoints = CatalogArrowEndpoints(arrow_client) + with endpoints.construct( + graph_name="g", + nodes=nodes, + relationships=relationships, + ) as G: + assert G.name() == "g" + assert G.node_count() == 2 + assert G.relationship_count() == 2 + + assert len(endpoints.list("g")) == 1 + + def test_graph_filter(catalog_endpoints: 
CatalogArrowEndpoints, sample_graph: GraphV2) -> None: try: G, result = catalog_endpoints.filter( diff --git a/tests/integrationV2/procedure_surface/cypher/test_catalog_cypher_endpoints.py b/tests/integrationV2/procedure_surface/cypher/test_catalog_cypher_endpoints.py index a45a66e8b..638084643 100644 --- a/tests/integrationV2/procedure_surface/cypher/test_catalog_cypher_endpoints.py +++ b/tests/integrationV2/procedure_surface/cypher/test_catalog_cypher_endpoints.py @@ -2,6 +2,7 @@ from typing import Generator import pytest +from pandas import DataFrame from graphdatascience import QueryRunner from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient @@ -48,6 +49,13 @@ def catalog_endpoints( yield CatalogCypherEndpoints(query_runner, gds_arrow_client) +@pytest.fixture +def catalog_endpoints_arrow( + query_runner: Neo4jQueryRunner, gds_arrow_client: GdsArrowClient +) -> Generator[CatalogCypherEndpoints, None, None]: + yield CatalogCypherEndpoints(query_runner, gds_arrow_client) + + def test_list_with_graph(catalog_endpoints: CatalogCypherEndpoints, sample_graph: GraphV2) -> None: results = catalog_endpoints.list(G=sample_graph) @@ -189,3 +197,62 @@ def test_graph_generate(catalog_endpoints: CatalogCypherEndpoints) -> None: assert result.relationship_property == RelationshipPropertySpec.fixed("weight", 42) assert catalog_endpoints.list("generated") is not None + + +@pytest.mark.filterwarnings("ignore: .*use Apache Arrow.*") +def test_graph_construct_cypher(catalog_endpoints: CatalogCypherEndpoints) -> None: + nodes = DataFrame( + { + "nodeId": [0, 1], + "labels": [["A"], ["B"]], + "propA": [1337, 42.1], + } + ) + relationships = DataFrame( + { + "sourceNodeId": [0, 1], + "targetNodeId": [1, 0], + "relationshipType": ["REL", "REL2"], + "relPropA": [1337.2, 42], + } + ) + + with catalog_endpoints.construct( + graph_name="constructed_graph", + nodes=nodes, + relationships=relationships, + ) as G: + assert G.name() == "constructed_graph" + assert 
G.node_count() == 2 + assert G.relationship_count() == 2 + assert "REL" in G.relationship_types() + + +def test_graph_construct_arrow_v1(catalog_endpoints_arrow: CatalogCypherEndpoints) -> None: + nodes = DataFrame( + { + "nodeId": [0, 1], + "labels": [["A"], ["B"]], + "propA": [1337, 42.1], + } + ) + relationships = DataFrame( + { + "sourceNodeId": [0, 1], + "targetNodeId": [1, 0], + "relationshipType": ["REL", "REL2"], + "relPropA": [1337.2, 42], + } + ) + + catalog_endpoints_arrow._cypher_runner.set_database("neo4j") + + with catalog_endpoints_arrow.construct( + graph_name="constructed_graph", + nodes=nodes, + relationships=relationships, + ) as G: + assert G.name() == "constructed_graph" + assert G.node_count() == 2 + assert G.relationship_count() == 2 + assert "REL" in G.relationship_types() diff --git a/tests/unit/arrow_client/V2/test_job_client.py b/tests/unit/arrow_client/V2/test_job_client.py index c6c5fc296..b5c9ace35 100644 --- a/tests/unit/arrow_client/V2/test_job_client.py +++ b/tests/unit/arrow_client/V2/test_job_client.py @@ -99,6 +99,35 @@ def test_wait_for_job_waits_for_completion(mocker: MockerFixture) -> None: assert mock_client.do_action_with_retry.call_count == 2 +def test_wait_for_job_waits_for_expected_status(mocker: MockerFixture) -> None: + mock_client = mocker.Mock() + job_id = "test-job-waiting" + status_running = JobStatus( + jobId=job_id, + progress=0.5, + status="RUNNING", + description="", + ) + status_done = JobStatus( + jobId=job_id, + progress=1.0, + status="RELATIONSHIP_LOADING", + description="", + ) + + do_action_with_retry = mocker.Mock() + do_action_with_retry.side_effect = [ + iter([ArrowTestResult(status_running.dump_camel())]), + iter([ArrowTestResult(status_done.dump_camel())]), + ] + + mock_client.do_action_with_retry = do_action_with_retry + + JobClient().wait_for_job(mock_client, job_id, show_progress=False, expected_status="RELATIONSHIP_LOADING") + + assert mock_client.do_action_with_retry.call_count == 2 + + def 
test_wait_for_job_waits_for_aborted(mocker: MockerFixture) -> None: mock_client = mocker.Mock() job_id = "test-job-waiting" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index db9726a0a..d9a51d080 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -160,9 +160,7 @@ def set_show_progress(self, show_progress: bool) -> None: def create_graph_constructor( self, graph_name: str, concurrency: int, undirected_relationship_types: list[str] | None ) -> GraphConstructor: - return CypherGraphConstructor( - self, graph_name, concurrency, undirected_relationship_types, self._server_version - ) + return CypherGraphConstructor(self, graph_name, concurrency, undirected_relationship_types) def cloneWithoutRouting(self, host: str, port: int) -> QueryRunner: return self diff --git a/tests/unit/procedure_surface/arrow/__init__.py b/tests/unit/procedure_surface/arrow/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/procedure_surface/arrow/test_catalog_arrow_endpoints.py b/tests/unit/procedure_surface/arrow/test_catalog_arrow_endpoints.py new file mode 100644 index 000000000..19bd03206 --- /dev/null +++ b/tests/unit/procedure_surface/arrow/test_catalog_arrow_endpoints.py @@ -0,0 +1,120 @@ +from contextlib import ExitStack +from unittest import mock + +from pandas import DataFrame +from pytest_mock import MockerFixture + +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.api_types import JobStatus +from graphdatascience.procedure_surface.arrow.catalog.catalog_arrow_endpoints import CatalogArrowEndpoints +from tests.unit.arrow_client.arrow_test_utils import ArrowTestResult + + +def test_construct_with_no_rels(mocker: MockerFixture) -> None: + arrow_client = mocker.Mock(spec=AuthenticatedArrowClient) + job_id = "job-123" + + relationship_loading_done_status = JobStatus( + jobId=job_id, + status="RELATIONSHIP_LOADING", + progress=-1, + 
description="", + ) + construct_done_status = JobStatus( + jobId=job_id, + status="Done", + progress=-1, + description="", + ) + + do_action_with_retry = mocker.Mock() + do_action_with_retry.side_effect = [ + iter([ArrowTestResult(relationship_loading_done_status.dump_camel())]), + iter([ArrowTestResult(construct_done_status.dump_camel())]), + ] + + arrow_client.do_action_with_retry = do_action_with_retry + + endpoints = CatalogArrowEndpoints(arrow_client=arrow_client) + + nodes = DataFrame( + { + "nodeId": [0, 1], + "labels": [["A"], ["B"]], + "propA": [1337, 42.1], + } + ) + with patch_gds_arrow_client(job_id): + G = endpoints.construct(graph_name="g", nodes=nodes, relationships=[]) + assert G.name() == "g" + + +def test_construct_with_df_lists(mocker: MockerFixture) -> None: + arrow_client = mocker.Mock(spec=AuthenticatedArrowClient) + job_id = "foo" + relationship_loading_done_status = JobStatus( + jobId=job_id, + status="RELATIONSHIP_LOADING", + progress=-1, + description="", + ) + construct_done_status = JobStatus( + jobId=job_id, + status="Done", + progress=-1, + description="", + ) + + do_action_with_retry = mocker.Mock() + do_action_with_retry.side_effect = [ + iter([ArrowTestResult(relationship_loading_done_status.dump_camel())]), + iter([ArrowTestResult(construct_done_status.dump_camel())]), + ] + arrow_client.do_action_with_retry = do_action_with_retry + + endpoints = CatalogArrowEndpoints(arrow_client=arrow_client) + + nodes = [ + DataFrame({"nodeId": [0, 1], "labels": ["a", "a"], "property": [6.0, 7.0]}), + DataFrame({"nodeId": [2, 3], "labels": ["b", "b"], "q": [-500, -400]}), + ] + relationships = [ + DataFrame( + {"sourceNodeId": [0, 1], "targetNodeId": [1, 2], "relationshipType": ["A", "A"], "weights": [0.2, 0.3]} + ), + DataFrame({"sourceNodeId": [2, 3], "targetNodeId": [3, 0], "relationshipType": ["B", "B"]}), + ] + with patch_gds_arrow_client(job_id): + G = endpoints.construct(graph_name="g", nodes=nodes, relationships=relationships) + assert 
G.name() == "g" + + +def patch_gds_arrow_client(create_graph_job_id: str) -> ExitStack: + exit_stack = ExitStack() + patches = [ + mock.patch( + "graphdatascience.arrow_client.v2.gds_arrow_client.GdsArrowClient.create_graph", + return_value=create_graph_job_id, + ), + mock.patch( + "graphdatascience.arrow_client.v2.gds_arrow_client.GdsArrowClient.upload_nodes", + return_value=None, + ), + mock.patch( + "graphdatascience.arrow_client.v2.gds_arrow_client.GdsArrowClient.upload_relationships", + return_value=None, + ), + mock.patch( + "graphdatascience.arrow_client.v2.gds_arrow_client.GdsArrowClient.node_load_done", + return_value=None, + ), + mock.patch( + "graphdatascience.arrow_client.v2.gds_arrow_client.GdsArrowClient.relationship_load_done", + return_value=None, + ), + ] + + for p in patches: + exit_stack.enter_context(p) + + return exit_stack