Skip to content

Commit

Permalink
Fixes digraph_union if merge_edges is set to true.
Browse files Browse the repository at this point in the history
Previously, `digraph_union` would falsely keep or delete
edges if `merge_edges` is set to true. This commit fixes
the logic of `digraph_union` to skip an edge from the second
graph if both its endpoints were merged to nodes from the
first graph and these nodes already share an edge with equal
weight data. At the same time, a new function `graph_union`
was added that returns the union of two `PyGraph`s.
Closes Qiskit#432.
  • Loading branch information
georgios-ts committed Sep 7, 2021
1 parent f2f3a09 commit f9b0050
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 102 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Expand Up @@ -146,6 +146,7 @@ Other Algorithm Functions
retworkx.core_number
retworkx.graph_greedy_color
retworkx.digraph_union
retworkx.graph_union
retworkx.metric_closure

Generators
Expand Down
24 changes: 24 additions & 0 deletions releasenotes/notes/bugfix-union-7da79789134a3028.yaml
@@ -0,0 +1,24 @@
---
fixes:
- |
Previously, :func:`~retworkx.digraph_union` would falsely keep or delete edges
if argument `merge_edges` is set to true. This has been fixed and an edge from
the second graph will be skipped if both its endpoints were merged to nodes from
the first graph and these nodes already share an edge with equal weight data.
Fixed `#432 <https://github.com/Qiskit/retworkx/issues/432>`__
features:
- |
Adds a new function :func:`~retworkx.graph_union` that returns the union
of two :class:`~retworkx.PyGraph` objects.
For example:
.. jupyter-execute::
import retworkx
from retworkx.visualization import mpl_draw
first = retworkx.generators.path_graph(3, weights=["a_0", "node", "a_1"])
second = retworkx.generators.cycle_graph(3, weights=["node", "b_0", "b_1"])
graph = retworkx.graph_union(first, second, merge_nodes=True)
mpl_draw(graph)
1 change: 1 addition & 0 deletions src/lib.rs
Expand Up @@ -171,6 +171,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(digraph_vf2_mapping))?;
m.add_wrapped(wrap_pyfunction!(graph_vf2_mapping))?;
m.add_wrapped(wrap_pyfunction!(digraph_union))?;
m.add_wrapped(wrap_pyfunction!(graph_union))?;
m.add_wrapped(wrap_pyfunction!(topological_sort))?;
m.add_wrapped(wrap_pyfunction!(descendants))?;
m.add_wrapped(wrap_pyfunction!(ancestors))?;
Expand Down
241 changes: 156 additions & 85 deletions src/union.rs
Expand Up @@ -10,106 +10,167 @@
// License for the specific language governing permissions and limitations
// under the License.

use crate::{digraph, digraph::PyDiGraph};
use hashbrown::{HashMap, HashSet};
use petgraph::algo;
use petgraph::graph::EdgeIndex;
use crate::{digraph, graph};

use petgraph::stable_graph::{NodeIndex, StableGraph};
use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable};
use petgraph::{algo, EdgeType};

use pyo3::prelude::*;
use pyo3::Python;
use std::cmp::Ordering;

/// [Graph] Return a new PyDiGraph by forming a union from`a` and `b` graphs.
///
/// The algorithm has three phases:
/// - adds all nodes from `b` to `a`. operates in O(n), n being number of nodes in `b`.
/// - merges nodes from `b` over `a` given that:
/// - `merge_nodes` is `true`. operates in O(n^2), n being number of nodes in `b`.
/// - respective node in`b` and `a` share the same weight
/// - adds all edges from `b` to `a`.
/// - `merge_edges` is `true`
/// - respective edge in`b` and `a` share the same weight
///
/// with the same weight in graphs `a` and `b` and merged those nodes.
///
/// The nodes from graph `b` will replace nodes from `a`.
///
/// At this point, only `PyDiGraph` is supported.
fn _digraph_union(
py: Python,
a: &PyDiGraph,
b: &PyDiGraph,
merge_nodes: bool,
merge_edges: bool,
) -> PyResult<PyDiGraph> {
let first = &a.graph;
let second = &b.graph;
let mut combined = PyDiGraph {
graph: first.clone(),
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph: true,
};
let mut node_map = HashMap::with_capacity(second.node_count());
let mut edge_map = HashSet::with_capacity(second.edge_count());
type StablePyGraph<Ty> = StableGraph<PyObject, PyObject, Ty>;

let compare_weights = |a: &PyAny, b: &PyAny| -> PyResult<bool> {
let res = a.compare(b)?;
Ok(res == Ordering::Equal)
};

for node in second.node_indices() {
let node_index = combined.add_node(second[node].clone_ref(py))?;
node_map.insert(node.index(), node_index);
fn find_node_by_weight<Ty: EdgeType>(
py: Python,
graph: &StablePyGraph<Ty>,
obj: &PyObject,
) -> PyResult<Option<NodeIndex>> {
let mut index = None;
for node in graph.node_indices() {
let weight = graph.node_weight(node).unwrap();
if obj
.as_ref(py)
.rich_compare(weight, pyo3::basic::CompareOp::Eq)?
.is_true()?
{
index = Some(node);
break;
}
}
Ok(index)
}

for edge in b.weighted_edge_list(py).edges {
let source = edge.0;
let target = edge.1;
let edge_weight = edge.2;

let new_source = *node_map.get(&source).unwrap();
let new_target = *node_map.get(&target).unwrap();
#[derive(Copy, Clone)]
enum Entry<T> {
Merged(T),
Added(T),
None,
}

let edge_index = combined.add_edge(
new_source,
new_target,
edge_weight.clone_ref(py),
)?;
fn extract<T>(x: Entry<T>) -> T {
match x {
Entry::Merged(val) => val,
Entry::Added(val) => val,
Entry::None => panic!("called `Entry::extract()` on a `None` value"),
}
}

let edge_node = EdgeIndex::new(edge_index);
fn union<Ty: EdgeType>(
py: Python,
first: &StablePyGraph<Ty>,
second: &StablePyGraph<Ty>,
merge_nodes: bool,
merge_edges: bool,
) -> PyResult<StablePyGraph<Ty>> {
let mut out_graph = first.clone();

if combined.has_edge(source, target) {
let w = combined.graph.edge_weight(edge_node).unwrap();
if compare_weights(edge_weight.as_ref(py), w.as_ref(py)).unwrap() {
edge_map.insert(edge_node);
let mut node_map: Vec<Entry<NodeIndex>> =
vec![Entry::None; second.node_bound()];
for node in second.node_indices() {
let weight = &second[node];
if merge_nodes {
if let Some(index) = find_node_by_weight(py, first, weight)? {
node_map[node.index()] = Entry::Merged(index);
continue;
}
}

let index = out_graph.add_node(weight.clone_ref(py));
node_map[node.index()] = Entry::Added(index);
}

if merge_nodes {
for node in second.node_indices() {
let weight = &second[node].clone_ref(py);
let index = a.find_node_by_weight(py, weight.clone_ref(py));

if index.is_some() {
let other_node = node_map.get(&node.index());
combined.merge_nodes(
py,
*other_node.unwrap(),
index.unwrap(),
)?;
let weights_equal = |a: &PyObject, b: &PyObject| -> PyResult<bool> {
a.as_ref(py)
.rich_compare(b, pyo3::basic::CompareOp::Eq)?
.is_true()
};

for edge in second.edge_references() {
let source = edge.source().index();
let target = edge.target().index();
let new_weight = edge.weight();

let mut found = false;
if merge_edges {
// if both endpoints were merged,
// check if need to skip the edge as well.
if let (Entry::Merged(new_source), Entry::Merged(new_target)) =
(node_map[source], node_map[target])
{
for edge in first.edges(new_source) {
if edge.target() == new_target
&& weights_equal(new_weight, edge.weight())?
{
found = true;
break;
}
}
}
}
}

if merge_edges {
for edge in edge_map {
combined.graph.remove_edge(edge);
if !found {
let new_source = extract(node_map[source]);
let new_target = extract(node_map[target]);
out_graph.add_edge(
new_source,
new_target,
new_weight.clone_ref(py),
);
}
}

Ok(combined)
Ok(out_graph)
}

/// Return a new PyGraph by forming a union from two input PyGraph objects
///
/// The algorithm in this function operates in three phases:
///
/// 1. Add all the nodes from ``second`` into ``first``. operates in O(n),
/// with n being number of nodes in `b`.
/// 2. Merge nodes from ``second`` over ``first`` given that:
///
/// - The ``merge_nodes`` is ``True``. operates in O(n^2), with n being the
/// number of nodes in ``second``.
/// - The respective node in ``second`` and ``first`` share the same
/// weight/data payload.
///
/// 3. Adds all the edges from ``second`` to ``first``. If the ``merge_edges``
/// parameter is ``True`` and the respective edge in ``second`` and
/// first`` share the same weight/data payload they will be merged
/// together.
///
/// :param PyGraph first: The first directed graph object
/// :param PyGraph second: The second directed graph object
/// :param bool merge_nodes: If set to ``True`` nodes will be merged between
/// ``second`` and ``first`` if the weights are equal.
/// :param bool merge_edges: If set to ``True`` edges will be merged between
/// ``second`` and ``first`` if the weights are equal.
///
/// :returns: A new PyGraph object that is the union of ``second`` and
/// ``first``. It's worth noting the weight/data payload objects are
/// passed by reference from ``first`` and ``second`` to this new object.
/// :rtype: PyGraph
#[pyfunction(merge_nodes = false, merge_edges = false)]
#[pyo3(
text_signature = "(first, second, /, merge_nodes=False, merge_edges=False)"
)]
fn graph_union(
py: Python,
first: &graph::PyGraph,
second: &graph::PyGraph,
merge_nodes: bool,
merge_edges: bool,
) -> PyResult<graph::PyGraph> {
let out_graph =
union(py, &first.graph, &second.graph, merge_nodes, merge_edges)?;

Ok(graph::PyGraph {
graph: out_graph,
node_removed: first.node_removed,
multigraph: true,
})
}

/// Return a new PyDiGraph by forming a union from two input PyDiGraph objects
Expand Down Expand Up @@ -141,15 +202,25 @@ fn _digraph_union(
/// ``first``. It's worth noting the weight/data payload objects are
/// passed by reference from ``first`` and ``second`` to this new object.
/// :rtype: PyDiGraph
#[pyfunction]
#[pyo3(text_signature = "(first, second, merge_nodes, merge_edges, /)")]
#[pyfunction(merge_nodes = false, merge_edges = false)]
#[pyo3(
text_signature = "(first, second, /, merge_nodes=False, merge_edges=False)"
)]
fn digraph_union(
py: Python,
first: &digraph::PyDiGraph,
second: &digraph::PyDiGraph,
merge_nodes: bool,
merge_edges: bool,
) -> PyResult<digraph::PyDiGraph> {
let res = _digraph_union(py, first, second, merge_nodes, merge_edges)?;
Ok(res)
let out_graph =
union(py, &first.graph, &second.graph, merge_nodes, merge_edges)?;

Ok(digraph::PyDiGraph {
graph: out_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: first.node_removed,
multigraph: true,
})
}
47 changes: 30 additions & 17 deletions tests/digraph/test_union.py
Expand Up @@ -31,23 +31,6 @@ def test_union_merge_all(self):

self.assertTrue(retworkx.is_isomorphic(dag_a, dag_c))

def test_union_basic_merge_edges_only(self):
dag_a = retworkx.PyDiGraph()
dag_b = retworkx.PyDiGraph()

node_a = dag_a.add_node("a_1")
dag_a.add_child(node_a, "a_2", "e_1")
dag_a.add_child(node_a, "a_3", "e_2")

node_b = dag_b.add_node("a_1")
dag_b.add_child(node_b, "a_2", "e_1")
dag_b.add_child(node_b, "a_3", "e_2")

dag_c = retworkx.digraph_union(dag_a, dag_b, False, True)

self.assertTrue(len(dag_c.edge_list()) == 2)
self.assertTrue(len(dag_c.nodes()) == 6)

def test_union_basic_merge_nodes_only(self):
dag_a = retworkx.PyDiGraph()
dag_b = retworkx.PyDiGraph()
Expand Down Expand Up @@ -82,3 +65,33 @@ def test_union_basic_merge_none(self):

self.assertTrue(len(dag_c.nodes()) == 6)
self.assertTrue(len(dag_c.edge_list()) == 4)

def test_union_mismatch_edge_weight(self):
first = retworkx.PyDiGraph()
nodes = first.add_nodes_from([0, 1])
first.add_edges_from([(nodes[0], nodes[1], "a")])

second = retworkx.PyDiGraph()
nodes = second.add_nodes_from([0, 1])
second.add_edges_from([(nodes[0], nodes[1], "b")])

final = retworkx.digraph_union(
first, second, merge_nodes=True, merge_edges=True
)
self.assertEqual(final.weighted_edge_list(), [(0, 1, "a"), (0, 1, "b")])

def test_union_node_hole(self):
first = retworkx.PyDiGraph()
nodes = first.add_nodes_from([0, 1])
first.add_edges_from([(nodes[0], nodes[1], "a")])

second = retworkx.PyDiGraph()
dummy = second.add_node("dummy")
nodes = second.add_nodes_from([0, 1])
second.add_edges_from([(nodes[0], nodes[1], "a")])
second.remove_node(dummy)

final = retworkx.digraph_union(
first, second, merge_nodes=True, merge_edges=True
)
self.assertEqual(final.weighted_edge_list(), [(0, 1, "a")])

0 comments on commit f9b0050

Please sign in to comment.