diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index bebb872752..8841388dc1 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -320,9 +320,81 @@ function scalar_type(::Type{MOI.VectorQuadraticFunction{T}}) where {T} return MOI.ScalarQuadraticFunction{T} end -struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction} +""" + ScalarFunctionIterator{F<:MOI.AbstractVectorFunction} + +A type that allows iterating over the scalar-functions that comprise an +`AbstractVectorFunction`. +""" +struct ScalarFunctionIterator{F<:MOI.AbstractVectorFunction, C} f::F + # Cache that can be used to store a precomputed datastructure that allows + # an efficient implementation of `getindex`. + cache::C +end +function ScalarFunctionIterator(func::MOI.AbstractVectorFunction) + return ScalarFunctionIterator( + func, + scalar_iterator_cache(func), + ) +end + +scalar_iterator_cache(func::MOI.AbstractVectorFunction) = nothing + +function output_index_iterator(terms::AbstractVector, output_dimension) + start = zeros(Int, output_dimension) + next = Vector{Int}(undef, length(terms)) + last = zeros(Int, output_dimension) + for i in eachindex(terms) + j = terms[i].output_index + if iszero(last[j]) + start[j] = i + else + next[last[j]] = i + end + last[j] = i + end + for j in eachindex(last) + if !iszero(last[j]) + next[last[j]] = 0 + end + end + return ChainedIterator(start, next) +end +struct ChainedIterator + start::Vector{Int} + next::Vector{Int} +end +struct ChainedIteratorAtIndex + start::Int + next::Vector{Int} +end +function ChainedIteratorAtIndex(it::ChainedIterator, index::Int) + return ChainedIteratorAtIndex(it.start[index], it.next) end +#TODO We could also precompute the length for each `output_index`, +# check that it's a win. +Base.IteratorSize(::ChainedIteratorAtIndex) = Base.SizeUnknown() +function Base.iterate(it::ChainedIteratorAtIndex, i = it.start) + if iszero(i) + return nothing + else + return i, it.next[i] + end +end + +function ScalarFunctionIterator(f::MOI.VectorAffineFunction) + return ScalarFunctionIterator(f, output_index_iterator(f.terms, MOI.output_dimension(f))) +end + +function ScalarFunctionIterator(f::MOI.VectorQuadraticFunction) + return ScalarFunctionIterator( + f, + (output_index_iterator(f.affine_terms, MOI.output_dimension(f)), + output_index_iterator(f.quadratic_terms, MOI.output_dimension(f))), + ) +end + eachscalar(f::MOI.AbstractVectorFunction) = ScalarFunctionIterator(f) eachscalar(f::AbstractVector) = f @@ -344,70 +416,88 @@ Base.lastindex(it::ScalarFunctionIterator) = length(it) # Define getindex for Vector functions +# VectorOfVariables + function Base.getindex( it::ScalarFunctionIterator{MOI.VectorOfVariables}, - i::Integer, -) - return MOI.SingleVariable(it.f.variables[i]) -end -# Returns the scalar terms of output_index i -function scalar_terms_at_index( - terms::Vector{<:Union{MOI.VectorAffineTerm,MOI.VectorQuadraticTerm}}, - i::Int, + output_index::Integer, ) - return [term.scalar_term for term in terms if term.output_index == i] -end -function Base.getindex(it::ScalarFunctionIterator{<:VAF}, i::Integer) - return SAF(scalar_terms_at_index(it.f.terms, i), it.f.constants[i]) -end -function Base.getindex(it::ScalarFunctionIterator{<:VQF}, i::Integer) - lin = scalar_terms_at_index(it.f.affine_terms, i) - quad = scalar_terms_at_index(it.f.quadratic_terms, i) - return SQF(lin, quad, it.f.constants[i]) + return MOI.SingleVariable(it.f.variables[output_index]) end function Base.getindex( it::ScalarFunctionIterator{MOI.VectorOfVariables}, - I::AbstractVector, + output_indices::AbstractVector{<:Integer}, ) - return MOI.VectorOfVariables(it.f.variables[I]) + return MOI.VectorOfVariables(it.f.variables[output_indices]) +end + +# VectorAffineFunction + +function Base.getindex( + it::ScalarFunctionIterator{MOI.VectorAffineFunction{T}}, + output_index::Integer, +) where {T} + return MOI.ScalarAffineFunction{T}( + MOI.ScalarAffineTerm{T}[ + it.f.terms[i].scalar_term + for i in ChainedIteratorAtIndex(it.cache, output_index) + ], + it.f.constants[output_index], + ) end + function Base.getindex( - it::ScalarFunctionIterator{VAF{T}}, - I::AbstractVector, + it::ScalarFunctionIterator{MOI.VectorAffineFunction{T}}, + output_indices::AbstractVector{<:Integer}, ) where {T} terms = MOI.VectorAffineTerm{T}[] - # assume at least one term per index - sizehint!(terms, length(I)) - constant = it.f.constants[I] - for term in it.f.terms - idx = findfirst(Base.Fix1(==, term.output_index), I) - if idx !== nothing - push!(terms, MOI.VectorAffineTerm(idx, term.scalar_term)) + for (i, output_index) in enumerate(output_indices) + for j in ChainedIteratorAtIndex(it.cache, output_index) + push!(terms, MOI.VectorAffineTerm(i, it.f.terms[j].scalar_term)) end end - return VAF(terms, constant) + return MOI.VectorAffineFunction(terms, it.f.constants[output_indices]) end + +# VectorQuadraticFunction + function Base.getindex( - it::ScalarFunctionIterator{VQF{T}}, - I::AbstractVector, -) where {T} - affine_terms = MOI.VectorAffineTerm{T}[] - quadratic_terms = MOI.VectorQuadraticTerm{T}[] - constant = Vector{T}(undef, length(I)) - for (i, j) in enumerate(I) - g = it[j] - append!( - affine_terms, - map(t -> MOI.VectorAffineTerm(i, t), g.affine_terms), - ) - append!( - quadratic_terms, - map(t -> MOI.VectorQuadraticTerm(i, t), g.quadratic_terms), - ) - constant[i] = g.constant + it::ScalarFunctionIterator{MOI.VectorQuadraticFunction{T}}, + output_index::Integer, +) where {T} + return MOI.ScalarQuadraticFunction( + MOI.ScalarAffineTerm{T}[ + it.f.affine_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[1], output_index) + ], + MOI.ScalarQuadraticTerm{T}[ + it.f.quadratic_terms[i].scalar_term for i in ChainedIteratorAtIndex(it.cache[2], output_index) + ], + it.f.constants[output_index], + ) +end + +function Base.getindex( + it::ScalarFunctionIterator{MOI.VectorQuadraticFunction{T}}, + output_indices::AbstractVector{<:Integer}, +) where {T} + vat = MOI.VectorAffineTerm{T}[] + vqt = MOI.VectorQuadraticTerm{T}[] + for (i, output_index) in enumerate(output_indices) + for j in ChainedIteratorAtIndex(it.cache[1], output_index) + push!( + vat, + MOI.VectorAffineTerm(i, it.f.affine_terms[j].scalar_term), + ) + end + for j in ChainedIteratorAtIndex(it.cache[2], output_index) + push!( + vqt, + MOI.VectorQuadraticTerm(i, it.f.quadratic_terms[j].scalar_term), + ) + end end - return VQF(affine_terms, quadratic_terms, constant) + return MOI.VectorQuadraticFunction(vat, vqt, it.f.constants[output_indices]) end function zero_with_output_dimension(::Type{Vector{T}}, n::Integer) where {T}