diff --git a/src/Nonlinear/ReverseAD/Coloring/Coloring.jl b/src/Nonlinear/ReverseAD/Coloring/Coloring.jl index b03860a48f..52e5d6f431 100644 --- a/src/Nonlinear/ReverseAD/Coloring/Coloring.jl +++ b/src/Nonlinear/ReverseAD/Coloring/Coloring.jl @@ -171,14 +171,9 @@ function _grow_star(v, w, e_idx, firstNeighbor, color, S) @inbounds if p != v firstNeighbor[color[w]] = _Edge(e_idx, v, w) else - _root_union!(S, e_idx, e.index) - end - return -end - -function _merge_trees(S::_IntDisjointSet, eg::Int, eg1::Int) - if _find_root!(S, eg) != _find_root!(S, eg1) - _root_union!(S, eg, eg1) + root1 = _find_root!(S, e_idx) + root2 = _find_root!(S, e.index) + _root_union!(S, root1, root2) end return end @@ -280,7 +275,7 @@ function acyclic_coloring(g::UndirectedGraph) continue end if color[x] == color[v] - _merge_trees(S, e_idx, e2_idx) + _union!(S, e_idx, e2_idx) end end end diff --git a/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl b/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl index b22ba407a1..4fb6ea261d 100644 --- a/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl +++ b/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl @@ -45,3 +45,12 @@ function _root_union!(S::_IntDisjointSet, x::Int, y::Int) S.number_of_trees -= 1 return end + +function _union!(S, x::Int, y::Int) + root_x = _find_root!(S, x) + root_y = _find_root!(S, y) + if root_x != root_y + _root_union!(S, root_x, root_y) + end + return +end diff --git a/test/Nonlinear/ReverseAD.jl b/test/Nonlinear/ReverseAD.jl index 0da06f93c1..f1a6cc4fb0 100644 --- a/test/Nonlinear/ReverseAD.jl +++ b/test/Nonlinear/ReverseAD.jl @@ -1448,6 +1448,16 @@ function test_IntDisjointSet() return end +function test_issue_2897() + I = [4, 5, 4, 6, 5, 6] + J = [2, 1, 1, 2, 3, 3] + g = Coloring.UndirectedGraph(I, J, length(I)) + color, num_colors = Coloring.acyclic_coloring(g) + @test color == [1, 1, 1, 2, 2, 3] + @test num_colors == 3 + return +end + end # module TestReverseAD.runtests()