Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/Utilities/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,17 @@ function map_indices(index_map::F, ci::MOI.ConstraintIndex) where {F<:Function}
end

function map_indices(index_map::F, x::AbstractArray) where {F<:Function}
return [map_indices(index_map, xi) for xi in x]
# @odow tried other alternatives here, like
# [map_indices(index_map, xi) for xi in x]
# and
# map(Base.Fix1(map_indices, index_map), x)
# but this was the fastest. The insight is that we know that map_indices
# does not modify the element type of `x`.
y = similar(x)
for i in eachindex(x)
y[i] = map_indices(index_map, x[i])
end
return y
end

function map_indices(index_map::F, t::MOI.ScalarAffineTerm) where {F<:Function}
Expand Down Expand Up @@ -337,11 +347,32 @@ function map_indices(
index_map::F,
f::MOI.ScalarNonlinearFunction,
) where {F<:Function}
# TODO(odow): this uses recursion. We should remove at some point.
return MOI.ScalarNonlinearFunction(
f.head,
convert(Vector{Any}, map_indices(index_map, f.args)),
)
root = MOI.ScalarNonlinearFunction(f.head, similar(f.args))
stack = Tuple{MOI.ScalarNonlinearFunction,Int,MOI.ScalarNonlinearFunction}[]
for (i, fi) in enumerate(f.args)
if fi isa MOI.ScalarNonlinearFunction
push!(stack, (root, i, fi))
else
root.args[i] = MOI.Utilities.map_indices(index_map, fi)
end
end
while !isempty(stack)
parent, i, arg = pop!(stack)
if arg isa MOI.ScalarNonlinearFunction
child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args))
for (j, argj) in enumerate(arg.args)
if argj isa MOI.ScalarNonlinearFunction
push!(stack, (child, j, argj))
else
child.args[j] = MOI.Utilities.map_indices(index_map, argj)
end
end
parent.args[i] = child
else
parent.args[i] = MOI.Utilities.map_indices(index_map, arg)
end
end
return root
end

function map_indices(
Expand Down
54 changes: 53 additions & 1 deletion test/Utilities/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1954,13 +1954,65 @@ function test_ScalarNonlinearFunction_map_indices()
x = MOI.add_variable(src)
f = MOI.ScalarNonlinearFunction(:log, Any[x])
c = MOI.add_constraint(src, f, MOI.LessThan(1.0))
dest = MOI.Utilities.Model{Float64}()
dest = MOI.Utilities.MockOptimizer(MOI.Utilities.Model{Float64}())
index_map = MOI.copy_to(dest, src)
new_f = MOI.Utilities.map_indices(index_map, f)
@test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[index_map[x]])
return
end

function test_ScalarNonlinearFunction_map_indices_nested()
src = MOI.Utilities.Model{Float64}()
x = MOI.add_variables(src, 5)
f = MOI.ScalarNonlinearFunction(
:+,
Any[
x[1],
MOI.ScalarNonlinearFunction(
:+,
Any[MOI.ScalarNonlinearFunction(:+, Any[x[2], x[3]]), x[4]],
),
x[5],
],
)
c = MOI.add_constraint(src, f, MOI.LessThan(1.0))
dest = MOI.Utilities.MockOptimizer(MOI.Utilities.Model{Float64}())
index_map = MOI.copy_to(dest, src)
new_f = MOI.Utilities.map_indices(index_map, f)
y = [index_map[xi] for xi in x]
@test new_f ≈ MOI.ScalarNonlinearFunction(
:+,
Any[
y[1],
MOI.ScalarNonlinearFunction(
:+,
Any[MOI.ScalarNonlinearFunction(:+, Any[y[2], y[3]]), y[4]],
),
y[5],
],
)
return
end

function test_ScalarNonlinearFunction_map_indices_deep_recursion()
src = MOI.Utilities.Model{Float64}()
x = MOI.add_variable(src)
f = MOI.ScalarNonlinearFunction(:log, Any[x])
for _ in 1:50_000
f = MOI.ScalarNonlinearFunction(:log, Any[f])
end
c = MOI.add_constraint(src, f, MOI.LessThan(1.0))
dest = MOI.Utilities.MockOptimizer(MOI.Utilities.Model{Float64}())
index_map = MOI.copy_to(dest, src)
new_f = MOI.Utilities.map_indices(index_map, f)
g = new_f
for _ in 1:50_000
g = g.args[1]
end
@test g ≈ MOI.ScalarNonlinearFunction(:log, Any[index_map[x]])
return
end

function test_ScalarNonlinearFunction_substitute_variables()
x = MOI.VariableIndex(1)
f = MOI.ScalarNonlinearFunction(:log, Any[1.0*x])
Expand Down