diff --git a/graphdatascience/graph/graph_alpha_project_runner.py b/graphdatascience/graph/graph_alpha_project_runner.py index ae2cfce92..c5afd1a05 100644 --- a/graphdatascience/graph/graph_alpha_project_runner.py +++ b/graphdatascience/graph/graph_alpha_project_runner.py @@ -1,16 +1,18 @@ -from typing import Any +from typing import Any, Tuple from pandas import Series from graphdatascience.error.illegal_attr_checker import IllegalAttrChecker +from graphdatascience.graph.graph_object import Graph from graphdatascience.server_version.compatible_with import compatible_with from graphdatascience.server_version.server_version import ServerVersion class GraphAlphaProjectRunner(IllegalAttrChecker): @compatible_with("remote", min_inclusive=ServerVersion(2, 4, 0)) - def remote(self, graph_name: str, query: str, remote_database: str, **config: Any) -> "Series[Any]": + def remote(self, graph_name: str, query: str, remote_database: str, **config: Any) -> Tuple[Graph, "Series[Any]"]: self._namespace += ".remote" - query = f"CALL {self._namespace}($graph_name, $query, $token, $host, $remote_database, $config)" + procedure_query = f"CALL {self._namespace}($graph_name, $query, $token, $host, $remote_database, $config)" params = {"graph_name": graph_name, "query": query, "remote_database": remote_database, "config": config} - return self._query_runner.run_query(query, params).squeeze() # type: ignore + result = self._query_runner.run_query(procedure_query, params).squeeze() + return Graph(graph_name, self._query_runner, self._server_version), result diff --git a/graphdatascience/query_runner/aura_db_arrow_query_runner.py b/graphdatascience/query_runner/aura_db_arrow_query_runner.py index 3947ccfdc..98bd79fb7 100644 --- a/graphdatascience/query_runner/aura_db_arrow_query_runner.py +++ b/graphdatascience/query_runner/aura_db_arrow_query_runner.py @@ -64,6 +64,7 @@ def run_query( token, aura_db_arrow_endpoint = self._get_or_request_auth_pair() params["token"] = token params["host"] = aura_db_arrow_endpoint + params["config"] = {"useEncryption": False} return self._fallback_query_runner.run_query(query, params, database, custom_error) diff --git a/graphdatascience/tests/conftest.py b/graphdatascience/tests/conftest.py index fd5d32482..807be3b2f 100644 --- a/graphdatascience/tests/conftest.py +++ b/graphdatascience/tests/conftest.py @@ -13,3 +13,6 @@ def pytest_addoption(parser: Any) -> None: ) parser.addoption("--target-aura", action="store_true", help="the database targeted is an AuraDS instance") parser.addoption("--include-ogb", action="store_true", help="include tests requiring the ogb dependency") + parser.addoption( + "--include-cloud-architecture", action="store_true", help="include tests resuiring a cloud architecture setup" + ) diff --git a/graphdatascience/tests/integration/conftest.py b/graphdatascience/tests/integration/conftest.py index 899d6d2b3..d176978a2 100644 --- a/graphdatascience/tests/integration/conftest.py +++ b/graphdatascience/tests/integration/conftest.py @@ -1,11 +1,14 @@ import os from pathlib import Path -from typing import Any, Generator +from typing import Any, Generator, Optional import pytest from neo4j import Driver, GraphDatabase from graphdatascience.graph_data_science import GraphDataScience +from graphdatascience.query_runner.aura_db_arrow_query_runner import ( + AuraDbConnectionInfo, +) from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") @@ -20,6 +23,9 @@ DB = os.environ.get("NEO4J_DB", "neo4j") +AURA_DB_URI = os.environ.get("NEO4J_AURA_DB_URI", "bolt://localhost:7689") +AURA_DB_AUTH = ("neo4j", "password") + @pytest.fixture(scope="package") def neo4j_driver() -> Generator[Driver, None, None]: @@ -38,6 +44,16 @@ def runner(neo4j_driver: Driver) -> Neo4jQueryRunner: return _runner +@pytest.fixture(scope="package", autouse=False) +def auradb_runner() -> Neo4jQueryRunner: + driver = GraphDatabase.driver(AURA_DB_URI, auth=AURA_DB_AUTH) + + _runner = Neo4jQueryRunner(driver) + _runner.set_database(DB) + + return _runner + + @pytest.fixture(scope="package") def gds() -> GraphDataScience: _gds = GraphDataScience(URI, auth=AUTH) @@ -74,6 +90,18 @@ def gds_without_arrow() -> GraphDataScience: return _gds +@pytest.fixture(scope="package", autouse=False) +def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Optional[GraphDataScience]: + if "cloud_architecture" not in request.keywords: + _gds = GraphDataScience( + URI, auth=AUTH, arrow=True, aura_db_connection_info=AuraDbConnectionInfo(AURA_DB_URI, AURA_DB_AUTH) + ) + _gds.set_database(DB) + + return _gds + return None + + @pytest.fixture(autouse=True) def clean_up(gds: GraphDataScience) -> Generator[None, None, None]: yield @@ -139,6 +167,12 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None: if "model_store_location" in item.keywords: item.add_marker(skip_stored_models) + if not config.getoption("--include-cloud-architecture"): + skip_on_prem = pytest.mark.skip(reason="need --include-cloud-architecture option to run") + for item in items: + if "cloud_architecture" in item.keywords: + item.add_marker(skip_on_prem) + try: server_version = GraphDataScience(URI, auth=AUTH)._server_version except Exception as e: diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py new file mode 100644 index 000000000..a5df7a49a --- /dev/null +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -0,0 +1,45 @@ +from typing import Generator + +import pytest + +from graphdatascience import GraphDataScience +from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner +from graphdatascience.server_version.server_version import ServerVersion + +GRAPH_NAME = "g" + + +@pytest.fixture(autouse=True, scope="class") +def run_around_tests( + auradb_runner: Neo4jQueryRunner, gds_with_cloud_setup: GraphDataScience +) -> Generator[None, None, None]: + # Runs before each test + auradb_runner.run_query( + """ + CREATE + (a: Node {x: 1, y: 2, z: [42], name: "nodeA"}), + (b: Node {x: 2, y: 3, z: [1337], name: "nodeB"}), + (c: Node {x: 3, y: 4, z: [9], name: "nodeC"}), + (a)-[:REL {relX: 4, relY: 5}]->(b), + (a)-[:REL {relX: 5, relY: 6}]->(c), + (b)-[:REL {relX: 6, relY: 7}]->(c), + (b)-[:REL2]->(c) + """ + ) + + yield # Test runs here + + # Runs after each test + auradb_runner.run_query("MATCH (n) DETACH DELETE n") + gds_with_cloud_setup._query_runner.run_query(f"CALL gds.graph.drop('{GRAPH_NAME}', false)") + + +@pytest.mark.cloud_architecture +@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 4, 0)) +def test_remote_projection(gds_with_cloud_setup: GraphDataScience) -> None: + G, result = gds_with_cloud_setup.alpha.graph.project.remote( + GRAPH_NAME, "MATCH (n)-->(m) RETURN n as sourceNode, m as targetNode", "neo4j" + ) + + assert G.name() == GRAPH_NAME + assert result["nodeCount"] == 3 diff --git a/graphdatascience/tests/pytest.ini b/graphdatascience/tests/pytest.ini index f24f34cc3..59480e7f5 100644 --- a/graphdatascience/tests/pytest.ini +++ b/graphdatascience/tests/pytest.ini @@ -7,3 +7,4 @@ markers = skip_on_aura: mark a test to not be run when targeting an AuraDS instance only_on_aura: mark a test to be run only when targeting an AuraDS instance ogb: mark a test as requiring the ogb dependency + cloud_architecture: mark a test to require a cloud setup like environment diff --git a/tox.ini b/tox.ini index 90dbfa42f..87f15ce5b 100644 --- a/tox.ini +++ b/tox.ini @@ -28,6 +28,7 @@ passenv = NEO4J_USER NEO4J_PASSWORD NEO4J_DB + NEO4J_AURA_DB_URI allowlist_externals = ruby bash @@ -55,6 +56,7 @@ commands = aura: pytest graphdatascience/tests --include-enterprise --target-aura -Werror ogb: pytest graphdatascience/tests --include-enterprise --include-ogb -Werror nx: bash -ec 'pytest graphdatascience/tests/*/test_nx_loader.py --include-enterprise -Werror && ruby ./doc/tests/test_docs.rb python3 -n test_networkx' + cloud-architecture: pytest graphdatascience/tests --include-cloud-architecture -Werror rm -rf {envdir}/lib [testenv:jupyter-notebook-ci]