[Nonlinear.ReverseAD.Coloring] fix acyclic coloring algorithm #2898
Conversation
It was a bit hard to reason about the two subsequent diffs, so here's the combined diff of what the original PR should have been:

diff --git a/src/Nonlinear/ReverseAD/Coloring/Coloring.jl b/src/Nonlinear/ReverseAD/Coloring/Coloring.jl
index f847ce4e..5f00563a 100644
--- a/src/Nonlinear/ReverseAD/Coloring/Coloring.jl
+++ b/src/Nonlinear/ReverseAD/Coloring/Coloring.jl
@@ -6,8 +6,7 @@
module Coloring
-import DataStructures
-
+include("IntDisjointSet.jl")
include("topological_sort.jl")
"""
@@ -154,7 +153,7 @@ function _prevent_cycle(
forbiddenColors,
color,
)
- er = DataStructures.find_root!(S, e_idx2)
+ er = _find_root!(S, e_idx2)
@inbounds first = firstVisitToTree[er]
p = first.source # but this depends on the order?
q = first.target
@@ -172,29 +171,11 @@ function _grow_star(v, w, e_idx, firstNeighbor, color, S)
@inbounds if p != v
firstNeighbor[color[w]] = _Edge(e_idx, v, w)
else
- union!(S, e_idx, e.index)
- end
- return
-end
-
-function _merge_trees(eg, eg1, S)
- e1 = DataStructures.find_root!(S, eg)
- e2 = DataStructures.find_root!(S, eg1)
- if e1 != e2
- union!(S, eg, eg1)
+ _union!(S, e_idx, e.index)
end
return
end
-# Work-around a deprecation in DataStructures@0.19
-function _IntDisjointSet(n)
- @static if isdefined(DataStructures, :IntDisjointSet)
- return DataStructures.IntDisjointSet(n)
- else
- return DataStructures.IntDisjointSets(n) # COV_EXCL_LINE
- end
-end
-
"""
acyclic_coloring(g::UndirectedGraph)
@@ -214,7 +195,6 @@ function acyclic_coloring(g::UndirectedGraph)
firstNeighbor = _Edge[]
firstVisitToTree = fill(_Edge(0, 0, 0), _num_edges(g))
color = fill(0, _num_vertices(g))
- # disjoint set forest of edges in the graph
S = _IntDisjointSet(_num_edges(g))
@inbounds for v in 1:_num_vertices(g)
n_neighbor = _num_neighbors(v, g)
@@ -293,7 +273,7 @@ function acyclic_coloring(g::UndirectedGraph)
continue
end
if color[x] == color[v]
- _merge_trees(e_idx, e2_idx, S)
+ _union!(S, e_idx, e2_idx)
end
end
end
diff --git a/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl b/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl
new file mode 100644
index 00000000..4fb6ea26
--- /dev/null
+++ b/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl
@@ -0,0 +1,56 @@
+# Copyright (c) 2017: Miles Lubin and contributors
+# Copyright (c) 2017: Google Inc.
+# Copyright (c) 2024: Guillaume Dalle and Alexis Montoison
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+# The code in this file was taken from
+# https://github.com/gdalle/SparseMatrixColorings.jl/blob/main/src/Forest.jl
+#
+# It was copied at the suggestion of Alexis in his JuMP-dev 2025 talk.
+#
+# @odow made minor changes to match MOI coding styles.
+#
+# x-ref https://github.com/gdalle/SparseMatrixColorings.jl/pull/190
+
+mutable struct _IntDisjointSet
+ # current number of distinct trees in the S
+ number_of_trees::Int
+ # vector storing the index of a parent in the tree for each edge, used in
+ # union-find operations
+ parents::Vector{Int}
+ # vector approximating the depth of each tree to optimize path compression
+ ranks::Vector{Int}
+
+ _IntDisjointSet(n::Integer) = new(n, collect(1:n), zeros(Int, n))
+end
+
+function _find_root!(S::_IntDisjointSet, x::Integer)
+ p = S.parents[x]
+ if S.parents[p] != p
+ S.parents[x] = p = _find_root!(S, p)
+ end
+ return p
+end
+
+function _root_union!(S::_IntDisjointSet, x::Int, y::Int)
+ rank1, rank2 = S.ranks[x], S.ranks[y]
+ if rank1 < rank2
+ x, y = y, x
+ elseif rank1 == rank2
+ S.ranks[x] += 1
+ end
+ S.parents[y] = x
+ S.number_of_trees -= 1
+ return
+end
+
+function _union!(S, x::Int, y::Int)
+ root_x = _find_root!(S, x)
+ root_y = _find_root!(S, y)
+ if root_x != root_y
+ _root_union!(S, root_x, root_y)
+ end
+ return
+end
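For reference, here is a small standalone illustration of how the union-find introduced in IntDisjointSet.jl behaves. It assumes the definitions from the diff above are loaded; the specific indices are illustrative only.

S = _IntDisjointSet(5)            # five singleton trees, parents == 1:5
_union!(S, 1, 2)                  # merge the trees containing edges 1 and 2
_union!(S, 4, 5)
_union!(S, 2, 5)                  # merges {1, 2} with {4, 5} via their roots
@assert _find_root!(S, 1) == _find_root!(S, 5)
@assert _find_root!(S, 3) != _find_root!(S, 1)
@assert S.number_of_trees == 2    # the remaining trees are {1, 2, 4, 5} and {3}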
As described at #2882 (comment), for #2882 I ran all the PureJuMP models in OptimizationProblems and verified the Hessians before and after the change. I think we should do that here.
Yip. I'll also re-run the solver tests.
@odow It is because the edge and the star don't have a "shared" edge, so we can avoid merging them.
I see the issue: you checked that the "roots" are different, but you merged the trees with the edge indices instead of the root indices.
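Concretely, a hypothetical sketch of the difference (not the literal #2885 code; the names follow the helpers in this PR):

e1 = _find_root!(S, eg)
e2 = _find_root!(S, eg1)
if e1 != e2
    # buggy: eg and eg1 are edge indices, not necessarily roots, so the
    # rank/parent bookkeeping in the merge is corrupted
    _root_union!(S, eg, eg1)
    # fixed (what _union! in this PR does): merge the roots themselves
    # _root_union!(S, e1, e2)
end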
Now there are many large differences between SCT and JuMP:

┌ Warning: Inconsistencies were detected
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:74
┌ Warning: Inconsistency for Jacobian of hs117: SCT (75 nz) ⊃ JuMP (62 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:76
┌ Warning: Inconsistency for Jacobian of lincon: SCT (19 nz) ⊃ JuMP (17 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:76
┌ Warning: Inconsistency for Hessian of argauss: SCT (8 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of biggs5: SCT (9 nz) ⊂ JuMP (12 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of biggs6: SCT (9 nz) ⊂ JuMP (12 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of britgas: SCT (1087 nz) ⊂ JuMP (1111 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of chain: SCT (75 nz) ⊂ JuMP (100 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of channel: SCT (696 nz) ⊃ JuMP (672 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of dixmaane: SCT (493 nz) ⊃ JuMP (297 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of dixmaani: SCT (493 nz) ⊃ JuMP (297 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of dixmaanm: SCT (493 nz) ⊃ JuMP (297 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs114: SCT (19 nz) ⊂ JuMP (21 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs119: SCT (256 nz) ⊃ JuMP (76 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs250: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs251: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs36: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs37: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs40: SCT (15 nz) ⊂ JuMP (16 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs41: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs45: SCT (20 nz) ⊂ JuMP (25 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs56: SCT (10 nz) ⊂ JuMP (13 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs68: SCT (9 nz) ⊂ JuMP (10 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs69: SCT (9 nz) ⊂ JuMP (10 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs87: SCT (9 nz) ⊂ JuMP (11 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs93: SCT (34 nz) ⊂ JuMP (36 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of polygon1: SCT (550 nz) ⊂ JuMP (600 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of polygon2: SCT (350 nz) ⊂ JuMP (400 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of robotarm: SCT (252 nz) ⊂ JuMP (276 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
Co-authored-by: Alexis Montoison <35051714+amontoison@users.noreply.github.com>
Okay. I ran the following. Running it also triggered the assert, so if I had re-run the Optimisation tests after #2885 we would have found this sooner.

using Revise
using JuMP, OptimizationProblems, SparseArrays
function compute_random_hessian(name::String)
println(name)
try
model = getfield(OptimizationProblems.PureJuMP, Symbol(name))()
return _compute_random_hessian(model)
catch err
if err isa MOI.UnsupportedNonlinearOperator
return nothing
end
rethrow(err)
end
end
function _compute_random_hessian(model::Model)
rows = Any[]
nlp = MOI.Nonlinear.Model()
for (F, S) in list_of_constraint_types(model)
for ci in all_constraints(model, F, S)
push!(rows, ci)
object = constraint_object(ci)
MOI.Nonlinear.add_constraint(nlp, object.func, object.set)
end
end
MOI.Nonlinear.set_objective(nlp, objective_function(model))
x = all_variables(model)
backend = MOI.Nonlinear.SparseReverseMode()
evaluator = MOI.Nonlinear.Evaluator(nlp, backend, index.(x))
MOI.initialize(evaluator, [:Hess])
hessian_sparsity = MOI.hessian_lagrangian_structure(evaluator)
I = [i for (i, _) in hessian_sparsity]
J = [j for (_, j) in hessian_sparsity]
V = zeros(length(hessian_sparsity))
primal = sin.(1:length(x))
dual = cos.(1:length(rows))
MOI.eval_hessian_lagrangian(evaluator, V, primal, 1.234, dual)
return SparseArrays.sparse(I, J, V, length(x), length(x))
end
log = Dict(
name => compute_random_hessian(name)
for name in OptimizationProblems.meta[!, :name]
)
open("/tmp/log.txt", "w") do io
for name in OptimizationProblems.meta[!, :name]
H = log[name]
if H === nothing
println(io, name)
else
println(io, name, " ", nnz(H), " ", hash(H))
end
end
end

for both this PR and for MOI@1.46.0. There are a few differences:

% diff log_pr.txt log_1.46.0.txt
54c54
< britgas 759 15207358802602782180
---
> britgas 911 7370492992218032816
132c132
< hs107 17 6271903269068003906
---
> hs107 17 5313540568828064766
140c140
< hs114 15 14596070845963864172
---
> hs114 16 14723025369218343364
335c335
< polygon3 200 5952677266859018177
---
> polygon3 200 1326136011177673121
366,368c366,368
< triangle_deer 15454 1481367466358415038
< triangle_pacman 9510 4334017736271540586
< triangle_turtle 31682 1790788003688105916
---
> triangle_deer 15454 9191873022530188051
> triangle_pacman 9510 2601781815016602201
> triangle_turtle 31682 4301116183540000934

There are a few that have numerical differences, but that's just to some small tolerance. Some have reduced non-zeros in the correct places.

screen_recording.mov
Closes #2897
The issue was here:
https://github.com/jump-dev/MathOptInterface.jl/pull/2885/files#diff-42ba053a9aef9ff60f40635dd168d724ad628a29f2c47038a2d78d5b12b4c680R174
I assumed @amontoison had just renamed some things, but I didn't make the corresponding change to:
https://github.com/gdalle/SparseMatrixColorings.jl/blob/9b52faccdaae41d3ce27158434cc5597d1a61a36/src/coloring.jl#L390-L392
Here's the original upstream of Base.union!: https://github.com/JuliaCollections/DataStructures.jl/blob/b67c498a11402f6c18e5e74c69d95e2621f75aa0/src/disjoint_set.jl#L83-L94
There's still one small difference.
In SparseMatrixColorings.jl, the code to merge two trees is
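roughly the following (a sketch using the ported names from IntDisjointSet.jl above; the variable names in the upstream coloring.jl may differ):

root1 = _find_root!(S, e_idx1)
root2 = _find_root!(S, e_idx2)
if root1 != root2
    # merge using the roots that were just computed
    _root_union!(S, root1, root2)
end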
but in the DataStructures version of JuMP it was equivalent to
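the old _merge_trees shown in the removed hunk above, which, once DataStructures.union! is expanded (it re-finds both roots before merging, per the disjoint_set.jl link above), behaves like:

e1 = DataStructures.find_root!(S, eg)
e2 = DataStructures.find_root!(S, eg1)
if e1 != e2
    # union! looks up the roots of eg and eg1 again internally and then
    # merges those roots by rank, so the merge still happens on roots
    union!(S, eg, eg1)
end

So the two versions are equivalent; the DataStructures path just finds the roots a second time inside union!.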