Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions graphdatascience/graph/graph_alpha_project_runner.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from typing import Any
from typing import Any, Tuple

from pandas import Series

from graphdatascience.error.illegal_attr_checker import IllegalAttrChecker
from graphdatascience.graph.graph_object import Graph
from graphdatascience.server_version.compatible_with import compatible_with
from graphdatascience.server_version.server_version import ServerVersion


class GraphAlphaProjectRunner(IllegalAttrChecker):
@compatible_with("remote", min_inclusive=ServerVersion(2, 4, 0))
def remote(self, graph_name: str, query: str, remote_database: str, **config: Any) -> "Series[Any]":
def remote(self, graph_name: str, query: str, remote_database: str, **config: Any) -> Tuple[Graph, "Series[Any]"]:
self._namespace += ".remote"
query = f"CALL {self._namespace}($graph_name, $query, $token, $host, $remote_database, $config)"
procedure_query = f"CALL {self._namespace}($graph_name, $query, $token, $host, $remote_database, $config)"
params = {"graph_name": graph_name, "query": query, "remote_database": remote_database, "config": config}
return self._query_runner.run_query(query, params).squeeze() # type: ignore
result = self._query_runner.run_query(procedure_query, params).squeeze()
return Graph(graph_name, self._query_runner, self._server_version), result
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def run_query(
token, aura_db_arrow_endpoint = self._get_or_request_auth_pair()
params["token"] = token
params["host"] = aura_db_arrow_endpoint
params["config"] = {"useEncryption": False}

return self._fallback_query_runner.run_query(query, params, database, custom_error)

Expand Down
3 changes: 3 additions & 0 deletions graphdatascience/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ def pytest_addoption(parser: Any) -> None:
)
parser.addoption("--target-aura", action="store_true", help="the database targeted is an AuraDS instance")
parser.addoption("--include-ogb", action="store_true", help="include tests requiring the ogb dependency")
parser.addoption(
"--include-cloud-architecture", action="store_true", help="include tests resuiring a cloud architecture setup"
)
36 changes: 35 additions & 1 deletion graphdatascience/tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
from pathlib import Path
from typing import Any, Generator
from typing import Any, Generator, Optional

import pytest
from neo4j import Driver, GraphDatabase

from graphdatascience.graph_data_science import GraphDataScience
from graphdatascience.query_runner.aura_db_arrow_query_runner import (
AuraDbConnectionInfo,
)
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner

URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
Expand All @@ -20,6 +23,9 @@

DB = os.environ.get("NEO4J_DB", "neo4j")

AURA_DB_URI = os.environ.get("NEO4J_AURA_DB_URI", "bolt://localhost:7689")
AURA_DB_AUTH = ("neo4j", "password")


@pytest.fixture(scope="package")
def neo4j_driver() -> Generator[Driver, None, None]:
Expand All @@ -38,6 +44,16 @@ def runner(neo4j_driver: Driver) -> Neo4jQueryRunner:
return _runner


@pytest.fixture(scope="package", autouse=False)
def auradb_runner() -> Neo4jQueryRunner:
driver = GraphDatabase.driver(AURA_DB_URI, auth=AURA_DB_AUTH)

_runner = Neo4jQueryRunner(driver)
_runner.set_database(DB)

return _runner


@pytest.fixture(scope="package")
def gds() -> GraphDataScience:
_gds = GraphDataScience(URI, auth=AUTH)
Expand Down Expand Up @@ -74,6 +90,18 @@ def gds_without_arrow() -> GraphDataScience:
return _gds


@pytest.fixture(scope="package", autouse=False)
def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Optional[GraphDataScience]:
if "cloud_architecture" not in request.keywords:
_gds = GraphDataScience(
URI, auth=AUTH, arrow=True, aura_db_connection_info=AuraDbConnectionInfo(AURA_DB_URI, AURA_DB_AUTH)
)
_gds.set_database(DB)

return _gds
return None


@pytest.fixture(autouse=True)
def clean_up(gds: GraphDataScience) -> Generator[None, None, None]:
yield
Expand Down Expand Up @@ -139,6 +167,12 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None:
if "model_store_location" in item.keywords:
item.add_marker(skip_stored_models)

if not config.getoption("--include-cloud-architecture"):
skip_on_prem = pytest.mark.skip(reason="need --include-cloud-architecture option to run")
for item in items:
if "cloud_architecture" in item.keywords:
item.add_marker(skip_on_prem)

try:
server_version = GraphDataScience(URI, auth=AUTH)._server_version
except Exception as e:
Expand Down
45 changes: 45 additions & 0 deletions graphdatascience/tests/integration/test_remote_graph_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Generator

import pytest

from graphdatascience import GraphDataScience
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
from graphdatascience.server_version.server_version import ServerVersion

GRAPH_NAME = "g"


@pytest.fixture(autouse=True, scope="class")
def run_around_tests(
auradb_runner: Neo4jQueryRunner, gds_with_cloud_setup: GraphDataScience
) -> Generator[None, None, None]:
# Runs before each test
auradb_runner.run_query(
"""
CREATE
(a: Node {x: 1, y: 2, z: [42], name: "nodeA"}),
(b: Node {x: 2, y: 3, z: [1337], name: "nodeB"}),
(c: Node {x: 3, y: 4, z: [9], name: "nodeC"}),
(a)-[:REL {relX: 4, relY: 5}]->(b),
(a)-[:REL {relX: 5, relY: 6}]->(c),
(b)-[:REL {relX: 6, relY: 7}]->(c),
(b)-[:REL2]->(c)
"""
)

yield # Test runs here

# Runs after each test
auradb_runner.run_query("MATCH (n) DETACH DELETE n")
gds_with_cloud_setup._query_runner.run_query(f"CALL gds.graph.drop('{GRAPH_NAME}', false)")


@pytest.mark.cloud_architecture
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 4, 0))
def test_remote_projection(gds_with_cloud_setup: GraphDataScience) -> None:
G, result = gds_with_cloud_setup.alpha.graph.project.remote(
GRAPH_NAME, "MATCH (n)-->(m) RETURN n as sourceNode, m as targetNode", "neo4j"
)

assert G.name() == GRAPH_NAME
assert result["nodeCount"] == 3
1 change: 1 addition & 0 deletions graphdatascience/tests/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ markers =
skip_on_aura: mark a test to not be run when targeting an AuraDS instance
only_on_aura: mark a test to be run only when targeting an AuraDS instance
ogb: mark a test as requiring the ogb dependency
cloud_architecture: mark a test to require a cloud setup like environment
2 changes: 2 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ passenv =
NEO4J_USER
NEO4J_PASSWORD
NEO4J_DB
NEO4J_AURA_DB_URI
allowlist_externals =
ruby
bash
Expand Down Expand Up @@ -55,6 +56,7 @@ commands =
aura: pytest graphdatascience/tests --include-enterprise --target-aura -Werror
ogb: pytest graphdatascience/tests --include-enterprise --include-ogb -Werror
nx: bash -ec 'pytest graphdatascience/tests/*/test_nx_loader.py --include-enterprise -Werror && ruby ./doc/tests/test_docs.rb python3 -n test_networkx'
cloud-architecture: pytest graphdatascience/tests --include-cloud-architecture -Werror
rm -rf {envdir}/lib

[testenv:jupyter-notebook-ci]
Expand Down