diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index 7d3d475708..8841388dc1 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -326,39 +326,73 @@ end A type that allows iterating over the scalar-functions that comprise an `AbstractVectorFunction`. """ -struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction} +struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction, C} f::F - # Vectors which map output indices to their terms. - affine::Vector{Vector{Int}} - quadratic::Vector{Vector{Int}} + # Cache that can be used to store a precomputed datastructure that allows + # an efficient implementation of `getindex`. + cache::C end - -function ScalarFunctionIterator(f::MOI.VectorOfVariables) +function ScalarFunctionIterator(func::MOI.AbstractVectorFunction) return ScalarFunctionIterator( - f, - Vector{Int}[], - Vector{Int}[], + func, + scalar_iterator_cache(func), ) end -function ScalarFunctionIterator(f::MOI.VectorAffineFunction) - d = [Int[] for i = 1:MOI.output_dimension(f)] - for (i, term) in enumerate(f.terms) - push!(d[term.output_index], i) +scalar_iterator_cache(func::MOI.AbstractVectorFunction) = nothing + +function output_index_iterator(terms::AbstractVector, output_dimension) + start = zeros(Int, output_dimension) + next = Vector{Int}(undef, length(terms)) + last = zeros(Int, output_dimension) + for i in eachindex(terms) + j = terms[i].output_index + if iszero(last[j]) + start[j] = i + else + next[last[j]] = i + end + last[j] = i end - return ScalarFunctionIterator(f, d, Vector{Int}[]) + for j in eachindex(last) + if !iszero(last[j]) + next[last[j]] = 0 + end + end + return ChainedIterator(start, next) +end +struct ChainedIterator + start::Vector{Int} + next::Vector{Int} +end +struct ChainedIteratorAtIndex + start::Int + next::Vector{Int} +end +function ChainedIteratorAtIndex(it::ChainedIterator, index::Int) + return ChainedIteratorAtIndex(it.start[index], it.next) +end +#TODO We could also precompute the length for each `output_index`, +# check that it's a win. +Base.IteratorSize(::ChainedIteratorAtIndex) = Base.SizeUnknown() +function Base.iterate(it::ChainedIteratorAtIndex, i = it.start) + if iszero(i) + return nothing + else + return i, it.next[i] + end +end + +function ScalarFunctionIterator(f::MOI.VectorAffineFunction) + return ScalarFunctionIterator(f, output_index_iterator(f.terms, MOI.output_dimension(f))) end function ScalarFunctionIterator(f::MOI.VectorQuadraticFunction) - aff = [Int[] for i = 1:MOI.output_dimension(f)] - quad = [Int[] for i = 1:MOI.output_dimension(f)] - for (i, term) in enumerate(f.affine_terms) - push!(aff[term.output_index], i) - end - for (i, term) in enumerate(f.quadratic_terms) - push!(quad[term.output_index], i) - end - return ScalarFunctionIterator(f, aff, quad) + return ScalarFunctionIterator( + f, + (output_index_iterator(f.affine_terms, MOI.output_dimension(f)), + output_index_iterator(f.quadratic_terms, MOI.output_dimension(f))), + ) end eachscalar(f::MOI.AbstractVectorFunction) = ScalarFunctionIterator(f) @@ -407,7 +441,7 @@ function Base.getindex( return MOI.ScalarAffineFunction{T}( MOI.ScalarAffineTerm{T}[ it.f.terms[i].scalar_term - for i in it.affine[output_index] + for i in ChainedIteratorAtIndex(it.cache, output_index) ], it.f.constants[output_index], ) @@ -419,7 +453,7 @@ function Base.getindex( ) where {T} terms = MOI.VectorAffineTerm{T}[] for (i, output_index) in enumerate(output_indices) - for j in it.affine[output_index] + for j in ChainedIteratorAtIndex(it.cache, output_index) push!(terms, MOI.VectorAffineTerm(i, it.f.terms[j].scalar_term)) end end @@ -434,10 +468,10 @@ function Base.getindex( ) where {T} return MOI.ScalarQuadraticFunction( MOI.ScalarAffineTerm{T}[ - it.f.affine_terms[i].scalar_term for i in it.affine[output_index] + it.f.affine_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[1], output_index) ], MOI.ScalarQuadraticTerm{T}[ - it.f.quadratic_terms[i].scalar_term for i in it.quadratic[output_index] + it.f.quadratic_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[2], output_index) ], it.f.constants[output_index], ) @@ -450,13 +484,13 @@ function Base.getindex( vat = MOI.VectorAffineTerm{T}[] vqt = MOI.VectorQuadraticTerm{T}[] for (i, output_index) in enumerate(output_indices) - for j in it.affine[output_index] + for j in ChainedIteratorAtIndex(it.cache[1], output_index) push!( vat, MOI.VectorAffineTerm(i, it.f.affine_terms[j].scalar_term), ) end - for j in it.quadratic[output_index] + for j in ChainedIteratorAtIndex(it.cache[2], output_index) push!( vqt, MOI.VectorQuadraticTerm(i, it.f.quadratic_terms[j].scalar_term),