Skip to content

Commit

Permalink
Allow backends to implement should_run (#7257)
Browse files Browse the repository at this point in the history
* Allow backends to implement `should_run`

* Handle str return types for `can_run` and `should_run`
  • Loading branch information
eriknw committed Mar 5, 2024
1 parent 41fd8df commit 1f8d279
Showing 1 changed file with 41 additions and 4 deletions.
45 changes: 41 additions & 4 deletions networkx/utils/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ class WrappedSparse:
If a backend only partially implements some algorithms, it can define
a ``can_run(name, args, kwargs)`` function that returns True or False
indicating whether it can run the algorithm with the given arguments.
It may also return a string indicating why the algorithm can't be run;
this string may be used in the future to give helpful info to the user.
A backend may also define ``should_run(name, args, kwargs)`` that is similar
to ``can_run``, but answers whether the backend *should* be run (converting
if necessary). Like ``can_run``, it receives the original arguments so it
can decide whether it should be run by inspecting the arguments. ``can_run``
runs before ``should_run``, so ``should_run`` may assume ``can_run`` is True.
If not implemented by the backend, ``can_run`` and ``should_run`` are
assumed to always return True if the backend implements the algorithm.
A special ``on_start_tests(items)`` function may be defined by the backend.
It will be called with the list of NetworkX tests discovered. Each item
Expand Down Expand Up @@ -135,10 +146,18 @@ def _get_backends(group, *, load_and_call=False):
_loaded_backends = {} # type: ignore[var-annotated]


def _always_run(name, args, kwargs):
return True


def _load_backend(backend_name):
if backend_name in _loaded_backends:
return _loaded_backends[backend_name]
rv = _loaded_backends[backend_name] = backends[backend_name].load()
if not hasattr(rv, "can_run"):
rv.can_run = _always_run
if not hasattr(rv, "should_run"):
rv.should_run = _always_run
return rv


Expand Down Expand Up @@ -579,6 +598,7 @@ def __call__(self, /, *args, backend=None, **kwargs):
# Only networkx graphs; try to convert and run with a backend with automatic
# conversion, but don't do this by default for graph generators or loaders,
# or if the functions mutates an input graph or returns a graph.
# Only convert and run if `backend.should_run(...)` returns True.
if (
not self._returns_graph
and (
Expand All @@ -603,7 +623,7 @@ def __call__(self, /, *args, backend=None, **kwargs):
):
# Should we warn or log if we don't convert b/c the input will be mutated?
for backend_name in self._automatic_backends:
if self._can_backend_run(backend_name, *args, **kwargs):
if self._should_backend_run(backend_name, *args, **kwargs):
return self._convert_and_call(
backend_name,
args,
Expand All @@ -614,10 +634,27 @@ def __call__(self, /, *args, backend=None, **kwargs):
return self.orig_func(*args, **kwargs)

def _can_backend_run(self, backend_name, /, *args, **kwargs):
"""Can the specified backend run this algorithms with these arguments?"""
"""Can the specified backend run this algorithm with these arguments?"""
backend = _load_backend(backend_name)
# `backend.can_run` and `backend.should_run` may return strings that describe
# why they can't or shouldn't be run. We plan to use the strings in the future.
return (
hasattr(backend, self.name)
and (can_run := backend.can_run(self.name, args, kwargs))
and not isinstance(can_run, str)
)

def _should_backend_run(self, backend_name, /, *args, **kwargs):
"""Can/should the specified backend run this algorithm with these arguments?"""
backend = _load_backend(backend_name)
return hasattr(backend, self.name) and (
not hasattr(backend, "can_run") or backend.can_run(self.name, args, kwargs)
# `backend.can_run` and `backend.should_run` may return strings that describe
# why they can't or shouldn't be run. We plan to use the strings in the future.
return (
hasattr(backend, self.name)
and (can_run := backend.can_run(self.name, args, kwargs))
and not isinstance(can_run, str)
and (should_run := backend.should_run(self.name, args, kwargs))
and not isinstance(should_run, str)
)

def _convert_arguments(self, backend_name, args, kwargs):
Expand Down

0 comments on commit 1f8d279

Please sign in to comment.