From 88a334b2297c8cb9d1f01819f778275a10cc2f25 Mon Sep 17 00:00:00 2001 From: Adam Schill Collberg Date: Fri, 21 Jan 2022 13:45:39 +0100 Subject: [PATCH 1/2] Add consistent error handling for client side only methods --- .../error/client_only_endpoint.py | 22 +++++++++++++++++++ graphdatascience/graph/graph_proc_runner.py | 5 ++--- graphdatascience/model/model_proc_runner.py | 5 ++--- .../tests/unit/test_error_handling.py | 5 +++++ graphdatascience/utils/util_endpoints.py | 2 ++ 5 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 graphdatascience/error/client_only_endpoint.py diff --git a/graphdatascience/error/client_only_endpoint.py b/graphdatascience/error/client_only_endpoint.py new file mode 100644 index 000000000..f93cfa05a --- /dev/null +++ b/graphdatascience/error/client_only_endpoint.py @@ -0,0 +1,22 @@ +from typing import Any, Callable, Protocol, TypeVar, cast + +F = TypeVar("F", bound=Callable[..., Any]) + + +class WithNamespace(Protocol): + _namespace: str + + +def client_only_endpoint(expected_namespace_prefix: str) -> Callable[[F], F]: + def decorator(func: F) -> F: + def wrapper(self: WithNamespace, *args: Any, **kwargs: Any) -> Any: + if self._namespace != expected_namespace_prefix: + raise SyntaxError( + f"There is no '{self._namespace}.{func.__name__}' to call" + ) + + return func(self, *args, **kwargs) + + return cast(F, wrapper) + + return decorator diff --git a/graphdatascience/graph/graph_proc_runner.py b/graphdatascience/graph/graph_proc_runner.py index e765d796e..2adf5526a 100644 --- a/graphdatascience/graph/graph_proc_runner.py +++ b/graphdatascience/graph/graph_proc_runner.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional, Union +from ..error.client_only_endpoint import client_only_endpoint from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace from ..query_runner.query_runner import QueryResult, QueryRunner @@ -67,10 +68,8 @@ def list(self, G: Optional[Graph] = None) -> QueryResult: return self._query_runner.run_query(query, params) + @client_only_endpoint("gds.graph") def get(self, graph_name: str) -> Graph: - if self._namespace != "gds.graph": - raise SyntaxError(f"There is no {self._namespace + '.get'} to call") - if not self.exists(graph_name)[0]["exists"]: raise ValueError(f"No projected graph named '{graph_name}' exists") diff --git a/graphdatascience/model/model_proc_runner.py b/graphdatascience/model/model_proc_runner.py index d611ae7a0..46173951f 100644 --- a/graphdatascience/model/model_proc_runner.py +++ b/graphdatascience/model/model_proc_runner.py @@ -1,5 +1,6 @@ from typing import Optional, Union +from ..error.client_only_endpoint import client_only_endpoint from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace from ..pipeline.lp_prediction_pipeline import LPPredictionPipeline @@ -101,10 +102,8 @@ def delete(self, model_id: ModelId) -> QueryResult: return self._query_runner.run_query(query, params) + @client_only_endpoint("gds.model") def get(self, model_name: str) -> Model: - if self._namespace != "gds.model": - raise SyntaxError(f"There is no {self._namespace + '.get'} to call") - self._namespace = "gds.beta.model" result = self.list(model_name) if len(result) == 0: diff --git a/graphdatascience/tests/unit/test_error_handling.py b/graphdatascience/tests/unit/test_error_handling.py index d7d84cd7f..36ab621fc 100644 --- a/graphdatascience/tests/unit/test_error_handling.py +++ b/graphdatascience/tests/unit/test_error_handling.py @@ -120,3 +120,8 @@ def test_nonexisting_similarity_endpoint(gds: GraphDataScience) -> None: SyntaxError, match="There is no 'gds.alpha.similarity.pearson.bogus' to call" ): gds.alpha.similarity.pearson.bogus() # type: ignore + + +def test_wrong_client_only_prefix(gds: GraphDataScience) -> None: + with pytest.raises(SyntaxError, match="There is no 'gds.beta.model.get' to call"): + gds.beta.model.get("model") diff --git a/graphdatascience/utils/util_endpoints.py b/graphdatascience/utils/util_endpoints.py index a4ef2df11..00b00da1c 100644 --- a/graphdatascience/utils/util_endpoints.py +++ b/graphdatascience/utils/util_endpoints.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List +from ..error.client_only_endpoint import client_only_endpoint from ..query_runner.query_runner import QueryResult, QueryRunner from .util_proc_runner import UtilProcRunner @@ -13,6 +14,7 @@ def __init__(self, query_runner: QueryRunner, namespace: str): def util(self) -> UtilProcRunner: return UtilProcRunner(self._query_runner, f"{self._namespace}.util") + @client_only_endpoint("gds") def find_node_id( self, labels: List[str] = [], properties: Dict[str, Any] = {} ) -> int: From 76fabda5c046d2b022647a2117172a4ffdd66090 Mon Sep 17 00:00:00 2001 From: Adam Schill Collberg Date: Fri, 21 Jan 2022 17:08:49 +0100 Subject: [PATCH 2/2] Use abstract class instea of Protocol To support all Python versions. --- graphdatascience/error/client_only_endpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/graphdatascience/error/client_only_endpoint.py b/graphdatascience/error/client_only_endpoint.py index f93cfa05a..c15716fe6 100644 --- a/graphdatascience/error/client_only_endpoint.py +++ b/graphdatascience/error/client_only_endpoint.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, Protocol, TypeVar, cast +from abc import ABC +from typing import Any, Callable, TypeVar, cast F = TypeVar("F", bound=Callable[..., Any]) -class WithNamespace(Protocol): +class WithNamespace(ABC): _namespace: str