Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compare graphs for generator functions when running tests with backend #7066

Merged
merged 8 commits into from
Dec 7, 2023
2 changes: 1 addition & 1 deletion networkx/algorithms/bipartite/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
]


@nodes_or_number([0, 1])
@nx._dispatch(graphs=None)
@nodes_or_number([0, 1])
def complete_bipartite_graph(n1, n2, create_using=None):
"""Returns the complete bipartite graph `K_{n_1,n_2}`.

Expand Down
8 changes: 4 additions & 4 deletions networkx/algorithms/operators/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _init_product_graph(G, H):
return GH


@nx._dispatch(graphs=_G_H)
@nx._dispatch(graphs=_G_H, preserve_node_attrs=True)
def tensor_product(G, H):
r"""Returns the tensor product of G and H.

Expand Down Expand Up @@ -179,7 +179,7 @@ def tensor_product(G, H):
return GH


@nx._dispatch(graphs=_G_H)
@nx._dispatch(graphs=_G_H, preserve_node_attrs=True)
def cartesian_product(G, H):
r"""Returns the Cartesian product of G and H.

Expand Down Expand Up @@ -231,7 +231,7 @@ def cartesian_product(G, H):
return GH


@nx._dispatch(graphs=_G_H)
@nx._dispatch(graphs=_G_H, preserve_node_attrs=True)
def lexicographic_product(G, H):
r"""Returns the lexicographic product of G and H.

Expand Down Expand Up @@ -284,7 +284,7 @@ def lexicographic_product(G, H):
return GH


@nx._dispatch(graphs=_G_H)
@nx._dispatch(graphs=_G_H, preserve_node_attrs=True)
def strong_product(G, H):
r"""Returns the strong product of G and H.

Expand Down
2 changes: 1 addition & 1 deletion networkx/algorithms/regular.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def is_k_regular(G, k):

@not_implemented_for("directed")
@not_implemented_for("multigraph")
@nx._dispatch(edge_attrs="matching_weight")
@nx._dispatch(preserve_edge_attrs=True)
def k_factor(G, k, matching_weight="weight"):
"""Compute a k-factor of G

Expand Down
2 changes: 1 addition & 1 deletion networkx/algorithms/triads.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def triad_type(G):

@not_implemented_for("undirected")
@py_random_state(1)
@nx._dispatch
@nx._dispatch(preserve_all_attrs=True)
def random_triad(G, seed=None):
"""Returns a random triad from a directed graph.

Expand Down
6 changes: 3 additions & 3 deletions networkx/classes/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def from_scipy_sparse_array(self, *args, **kwargs):
side_effects.append(1) # Just to prove this was called
return self.convert_from_nx(
self.__getattr__("from_scipy_sparse_array")(*args, **kwargs),
preserve_edge_attrs=None,
preserve_node_attrs=None,
preserve_graph_attrs=None,
preserve_edge_attrs=True,
preserve_node_attrs=True,
preserve_graph_attrs=True,
)

@staticmethod
Expand Down
16 changes: 8 additions & 8 deletions networkx/generators/classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ def binomial_tree(n, create_using=None):
return G


@nodes_or_number(0)
@nx._dispatch(graphs=None)
@nodes_or_number(0)
def complete_graph(n, create_using=None):
"""Return the complete graph `K_n` with n nodes.

Expand Down Expand Up @@ -385,8 +385,8 @@ def circulant_graph(n, offsets, create_using=None):
return G


@nodes_or_number(0)
@nx._dispatch(graphs=None)
@nodes_or_number(0)
def cycle_graph(n, create_using=None):
"""Returns the cycle graph $C_n$ of cyclically connected nodes.

Expand Down Expand Up @@ -471,8 +471,8 @@ def dorogovtsev_goltsev_mendes_graph(n, create_using=None):
return G


@nodes_or_number(0)
@nx._dispatch(graphs=None)
@nodes_or_number(0)
def empty_graph(n=0, create_using=None, default=Graph):
"""Returns the empty graph with n nodes and zero edges.

Expand Down Expand Up @@ -585,8 +585,8 @@ def ladder_graph(n, create_using=None):
return G


@nodes_or_number([0, 1])
@nx._dispatch(graphs=None)
@nodes_or_number([0, 1])
def lollipop_graph(m, n, create_using=None):
"""Returns the Lollipop Graph; ``K_m`` connected to ``P_n``.

Expand Down Expand Up @@ -659,8 +659,8 @@ def null_graph(create_using=None):
return G


@nodes_or_number(0)
@nx._dispatch(graphs=None)
@nodes_or_number(0)
def path_graph(n, create_using=None):
"""Returns the Path graph `P_n` of linearly connected nodes.

Expand All @@ -681,8 +681,8 @@ def path_graph(n, create_using=None):
return G


@nodes_or_number(0)
@nx._dispatch(graphs=None)
@nodes_or_number(0)
def star_graph(n, create_using=None):
"""Return the star graph

Expand Down Expand Up @@ -716,8 +716,8 @@ def star_graph(n, create_using=None):
return G


@nodes_or_number([0, 1])
@nx._dispatch(graphs=None)
@nodes_or_number([0, 1])
def tadpole_graph(m, n, create_using=None):
"""Returns the (m,n)-tadpole graph; ``C_m`` connected to ``P_n``.

Expand Down Expand Up @@ -815,8 +815,8 @@ def turan_graph(n, r):
return G


@nodes_or_number(0)
@nx._dispatch(graphs=None)
@nodes_or_number(0)
def wheel_graph(n, create_using=None):
"""Return the wheel graph

Expand Down
2 changes: 1 addition & 1 deletion networkx/generators/ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import networkx as nx


@nx._dispatch(edge_attrs="distance")
@nx._dispatch(preserve_all_attrs=True)
def ego_graph(G, n, radius=1, center=True, undirected=False, distance=None):
"""Returns induced subgraph of neighbors centered at node n within
a given radius.
Expand Down
2 changes: 1 addition & 1 deletion networkx/generators/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
]


@nodes_or_number([0, 1])
@nx._dispatch(graphs=None)
@nodes_or_number([0, 1])
def grid_2d_graph(m, n, periodic=False, create_using=None):
"""Returns the two-dimensional grid graph.

Expand Down
90 changes: 84 additions & 6 deletions networkx/utils/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class WrappedSparse:
from functools import partial
from importlib.metadata import entry_points

import networkx as nx

from ..exception import NetworkXNotImplemented

__all__ = ["_dispatch"]
Expand Down Expand Up @@ -832,14 +834,51 @@ def _convert_and_call_for_tests(
msg += " with the given arguments"
pytest.xfail(msg)

from collections.abc import Iterator
from copy import copy
from io import BufferedReader, BytesIO
from itertools import tee
from random import Random

# We sometimes compare the backend result to the original result,
# so we need two sets of arguments. We tee iterators and copy
# random state so that they may be used twice.
if not args:
args1 = args2 = args
else:
args1, args2 = zip(
*(
(arg, copy(arg))
if isinstance(arg, Random | BytesIO)
dschult marked this conversation as resolved.
Show resolved Hide resolved
else tee(arg)
if isinstance(arg, Iterator) and not isinstance(arg, BufferedReader)
else (arg, arg)
for arg in args
)
)
if not kwargs:
kwargs1 = kwargs2 = kwargs
else:
kwargs1, kwargs2 = zip(
*(
((k, v), (k, copy(v)))
if isinstance(v, Random | BytesIO)
else ((k, (teed := tee(v))[0]), (k, teed[1]))
if isinstance(v, Iterator) and not isinstance(v, BufferedReader)
else ((k, v), (k, v))
for k, v in kwargs.items()
)
)
kwargs1 = dict(kwargs1)
kwargs2 = dict(kwargs2)
try:
converted_args, converted_kwargs = self._convert_arguments(
backend_name, args, kwargs
backend_name, args1, kwargs1
)
result = getattr(backend, self.name)(*converted_args, **converted_kwargs)
except (NotImplementedError, NetworkXNotImplemented) as exc:
if fallback_to_nx:
return self.orig_func(*args, **kwargs)
return self.orig_func(*args2, **kwargs2)
import pytest

pytest.xfail(
Expand All @@ -849,14 +888,15 @@ def _convert_and_call_for_tests(
if self.name in {
"edmonds_karp_core",
"barycenter",
"contracted_edge",
"contracted_nodes",
"stochastic_graph",
"relabel_nodes",
}:
# Special-case algorithms that mutate input graphs
bound = self.__signature__.bind(*converted_args, **converted_kwargs)
bound.apply_defaults()
bound2 = self.__signature__.bind(*args, **kwargs)
bound2 = self.__signature__.bind(*args2, **kwargs2)
bound2.apply_defaults()
if self.name == "edmonds_karp_core":
R1 = backend.convert_to_nx(bound.arguments["R"])
Expand All @@ -869,7 +909,10 @@ def _convert_and_call_for_tests(
attr = bound.arguments["attr"]
for k, v in G1.nodes.items():
G2.nodes[k][attr] = v[attr]
elif self.name == "contracted_nodes" and not bound.arguments["copy"]:
elif (
self.name in {"contracted_nodes", "contracted_edge"}
and not bound.arguments["copy"]
):
# Edges and nodes changed; node "contraction" and edge "weight" attrs
G1 = backend.convert_to_nx(bound.arguments["G"])
G2 = bound2.arguments["G"]
Expand All @@ -895,8 +938,43 @@ def _convert_and_call_for_tests(
G2._succ.clear()
G2._succ.update(G1._succ)
return G2

return backend.convert_to_nx(result, name=self.name)
return backend.convert_to_nx(result)

converted_result = backend.convert_to_nx(result)
if isinstance(converted_result, nx.Graph) and self.name not in {
"boykov_kolmogorov",
"preflow_push",
"quotient_graph",
"shortest_augmenting_path",
"spectral_graph_forge",
# We don't handle tempfile.NamedTemporaryFile arguments
"read_gml",
"read_graph6",
"read_sparse6",
# We don't handle io.BufferedReader arguments
"bipartite_read_edgelist",
"read_adjlist",
"read_edgelist",
"read_graphml",
"read_multiline_adjlist",
"read_pajek",
# graph comparison fails b/c of nan values
"read_gexf",
}:
# For graph return types (e.g. generators), we compare that results are
# the same between the backend and networkx, then return the original
# networkx result so the iteration order will be consistent in tests.
G = self.orig_func(*args2, **kwargs2)
if not nx.utils.graphs_equal(G, converted_result):
assert G.number_of_nodes() == converted_result.number_of_nodes()
assert G.number_of_edges() == converted_result.number_of_edges()
assert G.graph == converted_result.graph
assert G.nodes == converted_result.nodes
assert G.adj == converted_result.adj
assert type(G) is type(converted_result)
raise AssertionError("Graphs are not equal")
return G
return converted_result

def _make_doc(self):
if not self.backends:
Expand Down