Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code to handle multi-graph in mst #7454

Merged
merged 7 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions networkx/algorithms/tree/mst.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,19 +1126,28 @@ def _write_partition(self, partition):
A Partition dataclass describing a partition on the edges of the
graph.
"""
for u, v, d in self.G.edges(data=True):
if (u, v) in partition.partition_dict:
d[self.partition_key] = partition.partition_dict[(u, v)]
else:
d[self.partition_key] = EdgePartition.OPEN

partition_dict = partition.partition_dict
partition_key = self.partition_key
G = self.G

edges = (
G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
)
for *e, d in edges:
d[partition_key] = partition_dict.get(tuple(e), EdgePartition.OPEN)

def _clear_partition(self, G):
"""
Removes partition data from the graph
"""
for u, v, d in G.edges(data=True):
if self.partition_key in d:
del d[self.partition_key]
partition_key = self.partition_key
edges = (
G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
)
Comment on lines +1145 to +1147
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we're here, I think this looks cleaner. Feel free to disregard.

Suggested change
edges = (
G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
)
edges = G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)

for *e, d in edges:
Comment on lines +1145 to +1148
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this pattern! What do you think about using it above when setting d[partition_key]? This way, we wouldn't need separate for loops for if G.is_multigraph() and else.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looked neat this way.
Will it affect the performance of the code? I would love to hear your suggestions. I will make sure to change the code accordingly.

if partition_key in d:
del d[partition_key]


@nx._dispatchable(edge_attrs="weight")
Expand Down
62 changes: 62 additions & 0 deletions networkx/algorithms/tree/tests/test_mst.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,68 @@ def test_maximum_spanning_tree_iterator(self):
tree_index -= 1


class TestSpanningTreeMultiGraphIterator:
"""
Uses the same graph as the above class but with an added edge of twice the weight.
"""

def setup_method(self):
# New graph
edges = [
(0, 1, 5),
(0, 1, 10),
(1, 2, 4),
(1, 2, 8),
(1, 4, 6),
(1, 4, 12),
(2, 3, 5),
(2, 3, 10),
(2, 4, 7),
(2, 4, 14),
(3, 4, 3),
(3, 4, 6),
]
self.G = nx.MultiGraph()
self.G.add_weighted_edges_from(edges)

# There are 128 trees. I'd rather not list all 128 here, and computing them
# on such a small graph actually doesn't take that long.
from itertools import combinations

self.spanning_trees = []
for e in combinations(self.G.edges, 4):
tree = self.G.edge_subgraph(e)
if nx.is_tree(tree):
self.spanning_trees.append(sorted(tree.edges(keys=True, data=True)))

def test_minimum_spanning_tree_iterator_multigraph(self):
"""
Tests that the spanning trees are correctly returned in increasing order
"""
tree_index = 0
last_weight = 0
for tree in nx.SpanningTreeIterator(self.G):
actual = sorted(tree.edges(keys=True, data=True))
weight = sum([e[3]["weight"] for e in actual])
assert actual in self.spanning_trees
assert weight >= last_weight
tree_index += 1

def test_maximum_spanning_tree_iterator_multigraph(self):
"""
Tests that the spanning trees are correctly returned in decreasing order
"""
tree_index = 127
# Maximum weight tree is 46
last_weight = 50
for tree in nx.SpanningTreeIterator(self.G, minimum=False):
actual = sorted(tree.edges(keys=True, data=True))
weight = sum([e[3]["weight"] for e in actual])
assert actual in self.spanning_trees
assert weight <= last_weight
tree_index -= 1


def test_random_spanning_tree_multiplicative_small():
"""
Using a fixed seed, sample one tree for repeatability.
Expand Down
Loading