From ed37131eba2b9b63fef4aa21dfd468b7309d0ed8 Mon Sep 17 00:00:00 2001 From: Orion Sehn Date: Wed, 17 Apr 2024 20:56:49 -0600 Subject: [PATCH 1/2] linting --- networkx/algorithms/approximation/steinertree.py | 15 ++++++++++++--- .../approximation/tests/test_steinertree.py | 11 +++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/networkx/algorithms/approximation/steinertree.py b/networkx/algorithms/approximation/steinertree.py index c6c834f422c..475bea6890d 100644 --- a/networkx/algorithms/approximation/steinertree.py +++ b/networkx/algorithms/approximation/steinertree.py @@ -114,9 +114,18 @@ def _kou_steiner_tree(G, terminal_nodes, weight): def _remove_nonterminal_leaves(G, terminals): terminals_set = set(terminals) - for n in list(G.nodes): - if n not in terminals_set and G.degree(n) == 1: - G.remove_node(n) + degree_one_nodes = {n for n in G if G.degree(n) == 1} + nonterminal_leaves = degree_one_nodes - terminals_set + while nonterminal_leaves: + possible_nonterminal_leaves = set() + for n in nonterminal_leaves: + possible_nonterminal_leaves = possible_nonterminal_leaves | set( + G.neighbors(n) + ) + G.remove_nodes_from(nonterminal_leaves) + nonterminal_leaves = { + leaf for leaf in possible_nonterminal_leaves if G.degree(leaf) == 1 + } - terminals_set ALGORITHMS = { diff --git a/networkx/algorithms/approximation/tests/test_steinertree.py b/networkx/algorithms/approximation/tests/test_steinertree.py index 23c3193e42e..299781df12b 100644 --- a/networkx/algorithms/approximation/tests/test_steinertree.py +++ b/networkx/algorithms/approximation/tests/test_steinertree.py @@ -190,6 +190,17 @@ def test_multigraph_steiner_tree(self): S = steiner_tree(G, terminal_nodes, method=method) assert edges_equal(S.edges(data=True, keys=True), expected_edges) + def test_remove_nonterminal_leaves(self): + from networkx.algorithms.approximation.steinertree import ( + _remove_nonterminal_leaves, + ) + + G = nx.Graph() + G.add_edges_from([(1, 2), (2, 3), (3, 4), (4, 5)], weight=1) + _remove_nonterminal_leaves(G, [2, 3]) + + assert list(G.nodes) == [2, 3] # only the terminal nodes are left + @pytest.mark.parametrize("method", ("kou", "mehlhorn")) def test_steiner_tree_weight_attribute(method): From 01e0f127c2ee3475adbc3f835c0d04ffcc64ce9d Mon Sep 17 00:00:00 2001 From: Orion Sehn <97123736+OrionSehn@users.noreply.github.com> Date: Tue, 21 May 2024 17:00:39 -0600 Subject: [PATCH 2/2] Update networkx/algorithms/approximation/tests/test_steinertree.py improve test so that there are multiple non-terminal nodes on both ends of the path Co-authored-by: Ross Barnowski --- .../approximation/tests/test_steinertree.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/networkx/algorithms/approximation/tests/test_steinertree.py b/networkx/algorithms/approximation/tests/test_steinertree.py index 299781df12b..763a9618f3c 100644 --- a/networkx/algorithms/approximation/tests/test_steinertree.py +++ b/networkx/algorithms/approximation/tests/test_steinertree.py @@ -191,15 +191,10 @@ def test_multigraph_steiner_tree(self): assert edges_equal(S.edges(data=True, keys=True), expected_edges) def test_remove_nonterminal_leaves(self): - from networkx.algorithms.approximation.steinertree import ( - _remove_nonterminal_leaves, - ) - - G = nx.Graph() - G.add_edges_from([(1, 2), (2, 3), (3, 4), (4, 5)], weight=1) - _remove_nonterminal_leaves(G, [2, 3]) + G = nx.path_graph(10) + nx.approximation.steinertree._remove_nonterminal_leaves(G, [4, 5, 6]) - assert list(G.nodes) == [2, 3] # only the terminal nodes are left + assert list(G) == [4, 5, 6] # only the terminal nodes are left @pytest.mark.parametrize("method", ("kou", "mehlhorn"))