diff --git a/networkx/algorithms/approximation/steinertree.py b/networkx/algorithms/approximation/steinertree.py index c6c834f422c..475bea6890d 100644 --- a/networkx/algorithms/approximation/steinertree.py +++ b/networkx/algorithms/approximation/steinertree.py @@ -114,9 +114,18 @@ def _kou_steiner_tree(G, terminal_nodes, weight): def _remove_nonterminal_leaves(G, terminals): terminals_set = set(terminals) - for n in list(G.nodes): - if n not in terminals_set and G.degree(n) == 1: - G.remove_node(n) + degree_one_nodes = {n for n in G if G.degree(n) == 1} + nonterminal_leaves = degree_one_nodes - terminals_set + while nonterminal_leaves: + possible_nonterminal_leaves = set() + for n in nonterminal_leaves: + possible_nonterminal_leaves = possible_nonterminal_leaves | set( + G.neighbors(n) + ) + G.remove_nodes_from(nonterminal_leaves) + nonterminal_leaves = { + leaf for leaf in possible_nonterminal_leaves if G.degree(leaf) == 1 + } - terminals_set ALGORITHMS = { diff --git a/networkx/algorithms/approximation/tests/test_steinertree.py b/networkx/algorithms/approximation/tests/test_steinertree.py index 23c3193e42e..763a9618f3c 100644 --- a/networkx/algorithms/approximation/tests/test_steinertree.py +++ b/networkx/algorithms/approximation/tests/test_steinertree.py @@ -190,6 +190,12 @@ def test_multigraph_steiner_tree(self): S = steiner_tree(G, terminal_nodes, method=method) assert edges_equal(S.edges(data=True, keys=True), expected_edges) + def test_remove_nonterminal_leaves(self): + G = nx.path_graph(10) + nx.approximation.steinertree._remove_nonterminal_leaves(G, [4, 5, 6]) + + assert list(G) == [4, 5, 6] # only the terminal nodes are left + @pytest.mark.parametrize("method", ("kou", "mehlhorn")) def test_steiner_tree_weight_attribute(method):