From f9b005094b36ee81cb65e9511d52ed2851dcf8ec Mon Sep 17 00:00:00 2001 From: georgios-ts Date: Tue, 7 Sep 2021 19:13:37 +0300 Subject: [PATCH] Fixes `digraph_union` if `merge_edges` is set to true. Previously, `digraph_union` would falsely keep or delete edges if `merge_edges` is set to true. This commit fixes the logic of `digraph_union` to skip an edge from the second graph if both its endpoints were merged to nodes from the first graph and these nodes already share an edge with equal weight data. At the same time, a new function `graph_union` was added that returns the union of two `PyGraph`s. Closes #432. --- docs/source/api.rst | 1 + .../notes/bugfix-union-7da79789134a3028.yaml | 24 ++ src/lib.rs | 1 + src/union.rs | 241 ++++++++++++------ tests/digraph/test_union.py | 47 ++-- tests/graph/test_union.py | 76 ++++++ 6 files changed, 288 insertions(+), 102 deletions(-) create mode 100644 releasenotes/notes/bugfix-union-7da79789134a3028.yaml create mode 100644 tests/graph/test_union.py diff --git a/docs/source/api.rst b/docs/source/api.rst index c535383f3..5b0d78104 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -146,6 +146,7 @@ Other Algorithm Functions retworkx.core_number retworkx.graph_greedy_color retworkx.digraph_union + retworkx.graph_union retworkx.metric_closure Generators diff --git a/releasenotes/notes/bugfix-union-7da79789134a3028.yaml b/releasenotes/notes/bugfix-union-7da79789134a3028.yaml new file mode 100644 index 000000000..132195c10 --- /dev/null +++ b/releasenotes/notes/bugfix-union-7da79789134a3028.yaml @@ -0,0 +1,24 @@ +--- +fixes: + - | + Previously, :func:`~retworkx.digraph_union` would falsely keep or delete edges + if argument `merge_edges` is set to true. This has been fixed and an edge from + the second graph will be skipped if both its endpoints were merged to nodes from + the first graph and these nodes already share an edge with equal weight data. 
+ Fixed `#432 <https://github.com/Qiskit/retworkx/issues/432>`__ +features: + - | + Adds a new function :func:`~retworkx.graph_union` that returns the union + of two :class:`~retworkx.PyGraph` objects. + For example: + + .. jupyter-execute:: + + import retworkx + from retworkx.visualization import mpl_draw + + first = retworkx.generators.path_graph(3, weights=["a_0", "node", "a_1"]) + second = retworkx.generators.cycle_graph(3, weights=["node", "b_0", "b_1"]) + graph = retworkx.graph_union(first, second, merge_nodes=True) + mpl_draw(graph) + \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index f028c9cf3..1bf4adfc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -171,6 +171,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(digraph_vf2_mapping))?; m.add_wrapped(wrap_pyfunction!(graph_vf2_mapping))?; m.add_wrapped(wrap_pyfunction!(digraph_union))?; + m.add_wrapped(wrap_pyfunction!(graph_union))?; m.add_wrapped(wrap_pyfunction!(topological_sort))?; m.add_wrapped(wrap_pyfunction!(descendants))?; m.add_wrapped(wrap_pyfunction!(ancestors))?; diff --git a/src/union.rs b/src/union.rs index 307bcfccb..7bd307de5 100644 --- a/src/union.rs +++ b/src/union.rs @@ -10,106 +10,167 @@ // License for the specific language governing permissions and limitations // under the License. -use crate::{digraph, digraph::PyDiGraph}; -use hashbrown::{HashMap, HashSet}; -use petgraph::algo; -use petgraph::graph::EdgeIndex; +use crate::{digraph, graph}; + +use petgraph::stable_graph::{NodeIndex, StableGraph}; +use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable}; +use petgraph::{algo, EdgeType}; + use pyo3::prelude::*; use pyo3::Python; -use std::cmp::Ordering; -/// [Graph] Return a new PyDiGraph by forming a union from`a` and `b` graphs. -/// -/// The algorithm has three phases: -/// - adds all nodes from `b` to `a`. operates in O(n), n being number of nodes in `b`. -/// - merges nodes from `b` over `a` given that: -/// - `merge_nodes` is `true`. 
operates in O(n^2), n being number of nodes in `b`. -/// - respective node in`b` and `a` share the same weight -/// - adds all edges from `b` to `a`. -/// - `merge_edges` is `true` -/// - respective edge in`b` and `a` share the same weight -/// -/// with the same weight in graphs `a` and `b` and merged those nodes. -/// -/// The nodes from graph `b` will replace nodes from `a`. -/// -/// At this point, only `PyDiGraph` is supported. -fn _digraph_union( - py: Python, - a: &PyDiGraph, - b: &PyDiGraph, - merge_nodes: bool, - merge_edges: bool, -) -> PyResult { - let first = &a.graph; - let second = &b.graph; - let mut combined = PyDiGraph { - graph: first.clone(), - cycle_state: algo::DfsSpace::default(), - check_cycle: false, - node_removed: false, - multigraph: true, - }; - let mut node_map = HashMap::with_capacity(second.node_count()); - let mut edge_map = HashSet::with_capacity(second.edge_count()); +type StablePyGraph = StableGraph; - let compare_weights = |a: &PyAny, b: &PyAny| -> PyResult { - let res = a.compare(b)?; - Ok(res == Ordering::Equal) - }; - - for node in second.node_indices() { - let node_index = combined.add_node(second[node].clone_ref(py))?; - node_map.insert(node.index(), node_index); +fn find_node_by_weight( + py: Python, + graph: &StablePyGraph, + obj: &PyObject, +) -> PyResult> { + let mut index = None; + for node in graph.node_indices() { + let weight = graph.node_weight(node).unwrap(); + if obj + .as_ref(py) + .rich_compare(weight, pyo3::basic::CompareOp::Eq)? + .is_true()? 
+ { + index = Some(node); + break; + } } + Ok(index) +} - for edge in b.weighted_edge_list(py).edges { - let source = edge.0; - let target = edge.1; - let edge_weight = edge.2; - - let new_source = *node_map.get(&source).unwrap(); - let new_target = *node_map.get(&target).unwrap(); +#[derive(Copy, Clone)] +enum Entry { + Merged(T), + Added(T), + None, +} - let edge_index = combined.add_edge( - new_source, - new_target, - edge_weight.clone_ref(py), - )?; +fn extract(x: Entry) -> T { + match x { + Entry::Merged(val) => val, + Entry::Added(val) => val, + Entry::None => panic!("called `Entry::extract()` on a `None` value"), + } +} - let edge_node = EdgeIndex::new(edge_index); +fn union( + py: Python, + first: &StablePyGraph, + second: &StablePyGraph, + merge_nodes: bool, + merge_edges: bool, +) -> PyResult> { + let mut out_graph = first.clone(); - if combined.has_edge(source, target) { - let w = combined.graph.edge_weight(edge_node).unwrap(); - if compare_weights(edge_weight.as_ref(py), w.as_ref(py)).unwrap() { - edge_map.insert(edge_node); + let mut node_map: Vec> = + vec![Entry::None; second.node_bound()]; + for node in second.node_indices() { + let weight = &second[node]; + if merge_nodes { + if let Some(index) = find_node_by_weight(py, first, weight)? { + node_map[node.index()] = Entry::Merged(index); + continue; } } + + let index = out_graph.add_node(weight.clone_ref(py)); + node_map[node.index()] = Entry::Added(index); } - if merge_nodes { - for node in second.node_indices() { - let weight = &second[node].clone_ref(py); - let index = a.find_node_by_weight(py, weight.clone_ref(py)); - - if index.is_some() { - let other_node = node_map.get(&node.index()); - combined.merge_nodes( - py, - *other_node.unwrap(), - index.unwrap(), - )?; + let weights_equal = |a: &PyObject, b: &PyObject| -> PyResult { + a.as_ref(py) + .rich_compare(b, pyo3::basic::CompareOp::Eq)? 
+ .is_true() + }; + + for edge in second.edge_references() { + let source = edge.source().index(); + let target = edge.target().index(); + let new_weight = edge.weight(); + + let mut found = false; + if merge_edges { + // if both endpoints were merged, + // check if need to skip the edge as well. + if let (Entry::Merged(new_source), Entry::Merged(new_target)) = + (node_map[source], node_map[target]) + { + for edge in first.edges(new_source) { + if edge.target() == new_target + && weights_equal(new_weight, edge.weight())? + { + found = true; + break; + } + } } } - } - if merge_edges { - for edge in edge_map { - combined.graph.remove_edge(edge); + if !found { + let new_source = extract(node_map[source]); + let new_target = extract(node_map[target]); + out_graph.add_edge( + new_source, + new_target, + new_weight.clone_ref(py), + ); } } - Ok(combined) + Ok(out_graph) +} + +/// Return a new PyGraph by forming a union from two input PyGraph objects +/// +/// The algorithm in this function operates in three phases: +/// +/// 1. Add all the nodes from ``second`` into ``first``. Operates in O(n), +/// with n being the number of nodes in ``second``. +/// 2. Merge nodes from ``second`` over ``first`` given that: +/// +/// - The ``merge_nodes`` is ``True``. Operates in O(n^2), with n being the +/// number of nodes in ``second``. +/// - The respective node in ``second`` and ``first`` share the same +/// weight/data payload. +/// +/// 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges`` +/// parameter is ``True`` and the respective edge in ``second`` and +/// ``first`` share the same weight/data payload they will be merged +/// together. +/// +/// :param PyGraph first: The first undirected graph object +/// :param PyGraph second: The second undirected graph object +/// :param bool merge_nodes: If set to ``True`` nodes will be merged between +/// ``second`` and ``first`` if the weights are equal. 
+/// :param bool merge_edges: If set to ``True`` edges will be merged between +/// ``second`` and ``first`` if the weights are equal. +/// +/// :returns: A new PyGraph object that is the union of ``second`` and +/// ``first``. It's worth noting the weight/data payload objects are +/// passed by reference from ``first`` and ``second`` to this new object. +/// :rtype: PyGraph +#[pyfunction(merge_nodes = false, merge_edges = false)] +#[pyo3( + text_signature = "(first, second, /, merge_nodes=False, merge_edges=False)" +)] +fn graph_union( + py: Python, + first: &graph::PyGraph, + second: &graph::PyGraph, + merge_nodes: bool, + merge_edges: bool, +) -> PyResult { + let out_graph = + union(py, &first.graph, &second.graph, merge_nodes, merge_edges)?; + + Ok(graph::PyGraph { + graph: out_graph, + node_removed: first.node_removed, + multigraph: true, + }) } /// Return a new PyDiGraph by forming a union from two input PyDiGraph objects @@ -141,8 +202,10 @@ fn _digraph_union( /// ``first``. It's worth noting the weight/data payload objects are /// passed by reference from ``first`` and ``second`` to this new object. 
/// :rtype: PyDiGraph -#[pyfunction] -#[pyo3(text_signature = "(first, second, merge_nodes, merge_edges, /)")] +#[pyfunction(merge_nodes = false, merge_edges = false)] +#[pyo3( + text_signature = "(first, second, /, merge_nodes=False, merge_edges=False)" +)] fn digraph_union( py: Python, first: &digraph::PyDiGraph, @@ -150,6 +213,14 @@ fn digraph_union( merge_nodes: bool, merge_edges: bool, ) -> PyResult { - let res = _digraph_union(py, first, second, merge_nodes, merge_edges)?; - Ok(res) + let out_graph = + union(py, &first.graph, &second.graph, merge_nodes, merge_edges)?; + + Ok(digraph::PyDiGraph { + graph: out_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: first.node_removed, + multigraph: true, + }) } diff --git a/tests/digraph/test_union.py b/tests/digraph/test_union.py index d32b7a528..0257e478b 100644 --- a/tests/digraph/test_union.py +++ b/tests/digraph/test_union.py @@ -31,23 +31,6 @@ def test_union_merge_all(self): self.assertTrue(retworkx.is_isomorphic(dag_a, dag_c)) - def test_union_basic_merge_edges_only(self): - dag_a = retworkx.PyDiGraph() - dag_b = retworkx.PyDiGraph() - - node_a = dag_a.add_node("a_1") - dag_a.add_child(node_a, "a_2", "e_1") - dag_a.add_child(node_a, "a_3", "e_2") - - node_b = dag_b.add_node("a_1") - dag_b.add_child(node_b, "a_2", "e_1") - dag_b.add_child(node_b, "a_3", "e_2") - - dag_c = retworkx.digraph_union(dag_a, dag_b, False, True) - - self.assertTrue(len(dag_c.edge_list()) == 2) - self.assertTrue(len(dag_c.nodes()) == 6) - def test_union_basic_merge_nodes_only(self): dag_a = retworkx.PyDiGraph() dag_b = retworkx.PyDiGraph() @@ -82,3 +65,33 @@ def test_union_basic_merge_none(self): self.assertTrue(len(dag_c.nodes()) == 6) self.assertTrue(len(dag_c.edge_list()) == 4) + + def test_union_mismatch_edge_weight(self): + first = retworkx.PyDiGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyDiGraph() + nodes = 
second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.digraph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 1, "b")]) + + def test_union_node_hole(self): + first = retworkx.PyDiGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyDiGraph() + dummy = second.add_node("dummy") + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "a")]) + second.remove_node(dummy) + + final = retworkx.digraph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a")]) diff --git a/tests/graph/test_union.py b/tests/graph/test_union.py new file mode 100644 index 000000000..1136fd0ca --- /dev/null +++ b/tests/graph/test_union.py @@ -0,0 +1,76 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import unittest +import retworkx + + +class TestUnion(unittest.TestCase): + def test_union_basic_merge_none(self): + graph = retworkx.PyGraph() + graph.add_nodes_from(["a_1", "a_2", "a_3"]) + graph.extend_from_weighted_edge_list([(0, 1, "e_1"), (1, 2, "e_2")]) + final = retworkx.graph_union( + graph, graph, merge_nodes=False, merge_edges=False + ) + self.assertTrue(len(final.nodes()) == 6) + self.assertTrue(len(final.edge_list()) == 4) + + def test_union_merge_all(self): + graph = retworkx.PyGraph() + graph.add_nodes_from(["a_1", "a_2", "a_3"]) + graph.extend_from_weighted_edge_list([(0, 1, "e_1"), (1, 2, "e_2")]) + final = retworkx.graph_union( + graph, graph, merge_nodes=True, merge_edges=True + ) + self.assertTrue(retworkx.is_isomorphic(final, graph)) + + def test_union_basic_merge_nodes_only(self): + graph = retworkx.PyGraph() + graph.add_nodes_from(["a_1", "a_2", "a_3"]) + graph.extend_from_weighted_edge_list([(0, 1, "e_1"), (1, 2, "e_2")]) + final = retworkx.graph_union( + graph, graph, merge_nodes=True, merge_edges=False + ) + self.assertTrue(len(final.edge_list()) == 4) + self.assertTrue(len(final.get_all_edge_data(0, 1)) == 2) + self.assertTrue(len(final.nodes()) == 3) + + def test_union_mismatch_edge_weight(self): + first = retworkx.PyGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyGraph() + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "b")]) + + final = retworkx.graph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 1, "b")]) + + def test_union_node_hole(self): + first = retworkx.PyGraph() + nodes = first.add_nodes_from([0, 1]) + first.add_edges_from([(nodes[0], nodes[1], "a")]) + + second = retworkx.PyGraph() + dummy = second.add_node("dummy") + nodes = second.add_nodes_from([0, 1]) + second.add_edges_from([(nodes[0], nodes[1], "a")]) + 
second.remove_node(dummy) + + final = retworkx.graph_union( + first, second, merge_nodes=True, merge_edges=True + ) + self.assertEqual(final.weighted_edge_list(), [(0, 1, "a")])