Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version = "1.46.0"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -23,7 +22,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
BenchmarkTools = "1"
CodecBzip2 = "0.6, 0.7, 0.8"
CodecZlib = "0.6, 0.7"
DataStructures = "0.18, 0.19"
ForwardDiff = "0.10, 1"
JSON3 = "1"
JSONSchema = "1"
Expand Down
2 changes: 1 addition & 1 deletion src/FileFormats/MPS/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module MPS

import ..FileFormats
import MathOptInterface as MOI
import DataStructures: OrderedDict
import OrderedCollections: OrderedDict

const IndicatorLessThanTrue{T} =
MOI.Indicator{MOI.ACTIVATE_ON_ONE,MOI.LessThan{T}}
Expand Down
27 changes: 7 additions & 20 deletions src/Nonlinear/ReverseAD/Coloring/Coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

module Coloring

import DataStructures

include("IntDisjointSet.jl")
include("topological_sort.jl")

"""
Expand Down Expand Up @@ -154,7 +153,7 @@ function _prevent_cycle(
forbiddenColors,
color,
)
er = DataStructures.find_root!(S, e_idx2)
er = _find_root!(S, e_idx2)
@inbounds first = firstVisitToTree[er]
p = first.source # but this depends on the order?
q = first.target
Expand All @@ -172,29 +171,18 @@ function _grow_star(v, w, e_idx, firstNeighbor, color, S)
@inbounds if p != v
firstNeighbor[color[w]] = _Edge(e_idx, v, w)
else
union!(S, e_idx, e.index)
_root_union!(S, e_idx, e.index)
end
return
end

function _merge_trees(eg, eg1, S)
e1 = DataStructures.find_root!(S, eg)
e2 = DataStructures.find_root!(S, eg1)
if e1 != e2
union!(S, eg, eg1)
function _merge_trees(S::_IntDisjointSet, eg::Int, eg1::Int)
if _find_root!(S, eg) != _find_root!(S, eg1)
_root_union!(S, eg, eg1)
end
return
end

# Work-around a deprecation in DataStructures@0.19
function _IntDisjointSet(n)
@static if isdefined(DataStructures, :IntDisjointSet)
return DataStructures.IntDisjointSet(n)
else
return DataStructures.IntDisjointSets(n) # COV_EXCL_LINE
end
end

"""
acyclic_coloring(g::UndirectedGraph)

Expand All @@ -214,7 +202,6 @@ function acyclic_coloring(g::UndirectedGraph)
firstNeighbor = _Edge[]
firstVisitToTree = fill(_Edge(0, 0, 0), _num_edges(g))
color = fill(0, _num_vertices(g))
# disjoint set forest of edges in the graph
S = _IntDisjointSet(_num_edges(g))
@inbounds for v in 1:_num_vertices(g)
n_neighbor = _num_neighbors(v, g)
Expand Down Expand Up @@ -293,7 +280,7 @@ function acyclic_coloring(g::UndirectedGraph)
continue
end
if color[x] == color[v]
_merge_trees(e_idx, e2_idx, S)
_merge_trees(S, e_idx, e2_idx)
end
end
end
Expand Down
47 changes: 47 additions & 0 deletions src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2017: Miles Lubin and contributors
# Copyright (c) 2017: Google Inc.
# Copyright (c) 2024: Guillaume Dalle and Alexis Montoison
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

# The code in this file was taken from
# https://github.com/gdalle/SparseMatrixColorings.jl/blob/main/src/Forest.jl
#
# It was copied at the suggestion of Alexis in his JuMP-dev 2025 talk.
#
# @odow made minor changes to match MOI coding styles.
#
# x-ref https://github.com/gdalle/SparseMatrixColorings.jl/pull/190

mutable struct _IntDisjointSet
# current number of distinct trees in the S
number_of_trees::Int
# vector storing the index of a parent in the tree for each edge, used in
# union-find operations
parents::Vector{Int}
# vector approximating the depth of each tree to optimize path compression
ranks::Vector{Int}

_IntDisjointSet(n::Integer) = new(n, collect(1:n), zeros(Int, n))
end

function _find_root!(S::_IntDisjointSet, x::Integer)
p = S.parents[x]
if S.parents[p] != p
S.parents[x] = p = _find_root!(S, p)
end
return p
end

function _root_union!(S::_IntDisjointSet, x::Int, y::Int)
rank1, rank2 = S.ranks[x], S.ranks[y]
if rank1 < rank2
x, y = y, x
elseif rank1 == rank2
S.ranks[x] += 1
end
S.parents[y] = x
S.number_of_trees -= 1
return
end
2 changes: 1 addition & 1 deletion test/FileFormats/MPS/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Test

import MathOptInterface as MOI
import MathOptInterface.FileFormats: MPS
import DataStructures: OrderedDict
import OrderedCollections: OrderedDict

function runtests()
for name in names(@__MODULE__; all = true)
Expand Down
33 changes: 30 additions & 3 deletions test/Nonlinear/ReverseAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import LinearAlgebra
import MathOptInterface as MOI
import SparseArrays

const Nonlinear = MOI.Nonlinear
const ReverseAD = Nonlinear.ReverseAD
const Coloring = ReverseAD.Coloring
import MathOptInterface.Nonlinear
import MathOptInterface.Nonlinear.ReverseAD
import MathOptInterface.Nonlinear.ReverseAD.Coloring

function runtests()
for name in names(@__MODULE__; all = true)
Expand Down Expand Up @@ -1421,6 +1421,33 @@ function test_hessian_reinterpret_unsafe()
return
end

function test_IntDisjointSet()
for case in [
[(1, 2) => [1, 1, 3], (1, 3) => [1, 1, 1]],
[(1, 2) => [1, 1, 3], (3, 1) => [1, 1, 1]],
[(2, 1) => [2, 2, 3], (1, 3) => [2, 2, 2]],
[(2, 1) => [2, 2, 3], (3, 1) => [3, 2, 3]],
[(1, 3) => [1, 2, 1], (2, 3) => [1, 2, 2]],
[(1, 3) => [1, 2, 1], (3, 2) => [1, 1, 1]],
[(3, 1) => [3, 2, 3], (2, 3) => [3, 3, 3]],
[(3, 1) => [3, 2, 3], (3, 2) => [3, 3, 3]],
[(2, 3) => [1, 2, 2], (1, 3) => [1, 2, 1]],
[(2, 3) => [1, 2, 2], (3, 1) => [2, 2, 2]],
[(3, 2) => [1, 3, 3], (1, 3) => [3, 3, 3]],
[(3, 2) => [1, 3, 3], (3, 1) => [3, 3, 3]],
]
S = Coloring._IntDisjointSet(3)
@test Coloring._find_root!.((S,), [1, 2, 3]) == [1, 2, 3]
@test S.number_of_trees == 3
for (i, (union, result)) in enumerate(case)
Coloring._root_union!(S, union[1], union[2])
@test Coloring._find_root!.((S,), [1, 2, 3]) == result
@test S.number_of_trees == 3 - i
end
end
return
end

end # module

TestReverseAD.runtests()
Loading