diff --git a/src/Utilities/model.jl b/src/Utilities/model.jl index 81de8d35f9..0251a95a0e 100644 --- a/src/Utilities/model.jl +++ b/src/Utilities/model.jl @@ -957,7 +957,7 @@ macro model( sets = vector_sets end voc = map(sets) do set - return :(VectorOfConstraints{$(_typedfun(funs[i])),$(_typedset(set))}) + return :(VectorOfConstraints{$(_typed(funs[i])),$(_typed(set))}) end return _struct_of_constraints_type(cname, voc, true) end diff --git a/src/Utilities/struct_of_constraints.jl b/src/Utilities/struct_of_constraints.jl index dd38494ea2..ccbadd6ffc 100644 --- a/src/Utilities/struct_of_constraints.jl +++ b/src/Utilities/struct_of_constraints.jl @@ -2,33 +2,37 @@ abstract type StructOfConstraints <: MOI.ModelLike end function _throw_if_cannot_delete(model::StructOfConstraints, vis, fast_in_vis) broadcastcall(model) do constrs - return _throw_if_cannot_delete(constrs, vis, fast_in_vis) + if constrs !== nothing + _throw_if_cannot_delete(constrs, vis, fast_in_vis) + end + return end + return end + function _deleted_constraints( callback::Function, model::StructOfConstraints, vi, ) broadcastcall(model) do constrs - return _deleted_constraints(callback, constrs, vi) + if constrs !== nothing + _deleted_constraints(callback, constrs, vi) + end + return end + return end function MOI.add_constraint( model::StructOfConstraints, - func::MOI.AbstractFunction, - set::MOI.AbstractSet, -) - if MOI.supports_constraint(model, typeof(func), typeof(set)) - return MOI.add_constraint( - constraints(model, typeof(func), typeof(set)), - func, - set, - ) - else - throw(MOI.UnsupportedConstraint{typeof(func),typeof(set)}()) + func::F, + set::S, +) where {F<:MOI.AbstractFunction,S<:MOI.AbstractSet} + if !MOI.supports_constraint(model, F, S) + throw(MOI.UnsupportedConstraint{F,S}()) end + return MOI.add_constraint(constraints(model, F, S), func, set) end function constraints( @@ -40,6 +44,7 @@ function constraints( end return constraints(model, F, S) end + function MOI.get( model::StructOfConstraints, attr::Union{MOI.ConstraintFunction,MOI.ConstraintSet}, @@ -49,18 +54,18 @@ function MOI.get( end function MOI.delete(model::StructOfConstraints, ci::MOI.ConstraintIndex) - return MOI.delete(constraints(model, ci), ci) + MOI.delete(constraints(model, ci), ci) + return end function MOI.is_valid( model::StructOfConstraints, ci::MOI.ConstraintIndex{F,S}, ) where {F,S} - if MOI.supports_constraint(model, F, S) - return MOI.is_valid(constraints(model, ci), ci) - else + if !MOI.supports_constraint(model, F, S) return false end + return MOI.is_valid(constraints(model, ci), ci) end function MOI.modify( @@ -68,7 +73,8 @@ function MOI.modify( ci::MOI.ConstraintIndex, change::MOI.AbstractFunctionModification, ) - return MOI.modify(constraints(model, ci), ci, change) + MOI.modify(constraints(model, ci), ci, change) + return end function MOI.set( @@ -77,138 +83,113 @@ function MOI.set( ci::MOI.ConstraintIndex, func_or_set, ) - return MOI.set(constraints(model, ci), attr, ci, func_or_set) + MOI.set(constraints(model, ci), attr, ci, func_or_set) + return end function MOI.get( model::StructOfConstraints, - loc::MOI.ListOfConstraintTypesPresent, -) where {T} - return broadcastvcat(model) do v - return MOI.get(v, loc) + attr::MOI.ListOfConstraintTypesPresent, +) + return broadcastvcat(model) do constrs + if constrs === nothing + return Tuple{DataType,DataType}[] + end + return MOI.get(constrs, attr) end end function MOI.get( model::StructOfConstraints, - noc::MOI.NumberOfConstraints{F,S}, + attr::MOI.NumberOfConstraints{F,S}, ) where {F,S} - if MOI.supports_constraint(model, F, S) - return MOI.get(constraints(model, F, S), noc) - else + if !MOI.supports_constraint(model, F, S) return 0 end + return MOI.get(constraints(model, F, S), attr) end function MOI.get( model::StructOfConstraints, - loc::MOI.ListOfConstraintIndices{F,S}, + attr::MOI.ListOfConstraintIndices{F,S}, ) where {F,S} - if MOI.supports_constraint(model, F, S) - return MOI.get(constraints(model, F, S), loc) - else + if !MOI.supports_constraint(model, F, S) return MOI.ConstraintIndex{F,S}[] end + return MOI.get(constraints(model, F, S), attr) end function MOI.is_empty(model::StructOfConstraints) - return mapreduce_constraints(MOI.is_empty, &, model, true) + return mapreduce_constraints(&, model, true) do constrs + return constrs === nothing || MOI.is_empty(constrs) + end end + function MOI.empty!(model::StructOfConstraints) - return broadcastcall(MOI.empty!, model) + broadcastcall(model) do constrs + if constrs !== nothing + MOI.empty!(constrs) + end + return + end + return end # Can be used to access constraints of a model """ -broadcastcall(f::Function, model::AbstractModel) - -Calls `f(contrs)` for every vector `constrs::Vector{ConstraintIndex{F, S}, F, S}` of the model. + broadcastcall(f::Function, model::StructOfConstraints) -# Examples - -To add all constraints of the model to a solver `solver`, one can do -```julia -_addcon(solver, ci, f, s) = MOI.add_constraint(solver, f, s) -function _addcon(solver, constrs::Vector) - for constr in constrs - _addcon(solver, constr...) - end -end -MOIU.broadcastcall(constrs -> _addcon(solver, constrs), model) -``` +Calls `f(contrs)` for every field in `model`. """ function broadcastcall end """ -broadcastvcat(f::Function, model::AbstractModel) + broadcastvcat(f::Function, model::StructOfConstraints) -Calls `f(contrs)` for every vector `constrs::Vector{ConstraintIndex{F, S}, F, S}` of the model and concatenate the results with `vcat` (this is used internally for `ListOfConstraintTypesPresent`). - -# Examples - -To get the list of all functions: -```julia -_getfun(ci, f, s) = f -_getfun(cindices::Tuple) = _getfun(cindices...) -_getfuns(constrs::Vector) = _getfun.(constrs) -MOIU.broadcastvcat(_getfuns, model) -``` +Calls `f(contrs)` for every field in `model` and `vcat`s the results. """ function broadcastvcat end +""" + mapreduce_constraints( + f::Function, + op::Function, + model::StructOfConstraints, + init, + ) + +Call `mapreduce` on every field of `model` given an initial value `init`. Each +element in the map is computed as `f(x)` and the elements are reduced using +`op`. +""" function mapreduce_constraints end # Macro code abstract type SymbolFS end + struct SymbolFun <: SymbolFS s::Union{Symbol,Expr} typed::Bool end + struct SymbolSet <: SymbolFS s::Union{Symbol,Expr} typed::Bool end -# QuoteNode prevents s from being interpolated and keeps it as a symbol -# Expr(:., MOI, s) would be MOI.s -# Expr(:., MOI, $s) would be Expr(:., MOI, EqualTo) -# Expr(:., MOI, :($s)) would be Expr(:., MOI, :EqualTo) -# Expr(:., MOI, :($(QuoteNode(s)))) is Expr(:., MOI, :(:EqualTo)) <- what we want - -# (MOI, :Zeros) -> :(MOI.Zeros) -# (:Zeros) -> :(MOI.Zeros) -_set(s::SymbolSet) = esc(s.s) -_fun(s::SymbolFun) = esc(s.s) -function _typedset(s::SymbolSet) - if s.typed - T = esc(:T) - :($(_set(s)){$T}) - else - _set(s) - end -end -function _typedfun(s::SymbolFun) - if s.typed - T = esc(:T) - :($(_fun(s)){$T}) - else - _fun(s) - end -end +_typed(s::SymbolFS) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s) # Base.lowercase is moved to Unicode.lowercase in Julia v0.7 -using Unicode +import Unicode -_field(s::SymbolFS) = Symbol(replace(lowercase(string(s.s)), "." => "_")) - -_getC(s::SymbolSet) = :(VectorOfConstraints{F,$(_typedset(s))}) -_getC(s::SymbolFun) = _typedfun(s) +function _field(s::SymbolFS) + return Symbol(replace(Unicode.lowercase(string(s.s)), "." => "_")) +end _callfield(f, s::SymbolFS) = :($f(model.$(_field(s)))) -_broadcastfield(b, s::SymbolFS) = :($b(f, model.$(_field(s)))) + _mapreduce_field(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s))))) -_mapreduce_constraints(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s))))) """ struct_of_constraint_code(struct_name, types, field_types = nothing) @@ -223,100 +204,82 @@ If `types` is vector of `SymbolFun` (resp. `SymbolSet`) then the constraints of that function (resp. set) type are stored in the corresponding field. """ function struct_of_constraint_code(struct_name, types, field_types = nothing) - esc_struct_name = struct_name T = esc(:T) - typed_struct = :($(esc_struct_name){$T}) + typed_struct = :($(struct_name){$T}) type_parametrized = field_types === nothing if type_parametrized field_types = [Symbol("C$i") for i in eachindex(types)] append!(typed_struct.args, field_types) end - - struct_def = :(struct $typed_struct <: StructOfConstraints end) - - for (t, field_type) in zip(types, field_types) - field = _field(t) - push!(struct_def.args[3].args, :($field::$field_type)) - end code = quote - function $MOIU.broadcastcall(f::Function, model::$esc_struct_name) - return $(Expr(:block, _callfield.(Ref(:f), types)...)) + mutable struct $typed_struct <: StructOfConstraints end + + function $MOIU.broadcastcall(f::Function, model::$struct_name) + $(Expr(:block, _callfield.(Ref(:f), types)...)) + return end - function $MOIU.broadcastvcat(f::Function, model::$esc_struct_name) + + function $MOIU.broadcastvcat(f::Function, model::$struct_name) return vcat($(_callfield.(Ref(:f), types)...)) end + function $MOIU.mapreduce_constraints( f::Function, op::Function, - model::$esc_struct_name, + model::$struct_name, cur, ) return $(Expr(:block, _mapreduce_field.(types)...)) end end - for t in types - if t isa SymbolFun - fun = _fun(t) - set = :(MOI.AbstractSet) - else - fun = :(MOI.AbstractFunction) - set = _set(t) - end + for (t, field_type) in zip(types, field_types) field = _field(t) - code = quote - $code + push!(code.args[2].args[3].args, :($field::Union{Nothing,$field_type})) + fun = t isa SymbolFun ? esc(t.s) : :(MOI.AbstractFunction) + set = t isa SymbolFun ? :(MOI.AbstractSet) : esc(t.s) + constraints_code = :( function $MOIU.constraints( - model::$esc_struct_name, + model::$typed_struct, ::Type{<:$fun}, ::Type{<:$set}, - ) + )::$(field_type) where {$T} + if model.$field === nothing + model.$field = $(field_type)() + end return model.$field end + ) + if type_parametrized + append!(constraints_code.args[1].args, field_types) end + push!(code.args, constraints_code) end - supports_code = if eltype(types) <: SymbolFun - quote + is_func = eltype(types) <: SymbolFun + SuperF = is_func ? :(Union{$(_typed.(types)...)}) : :(MOI.AbstractFunction) + SuperS = is_func ? :(MOI.AbstractSet) : :(Union{$(_typed.(types)...)}) + push!( + code.args, + :( function $MOI.supports_constraint( - model::$esc_struct_name{$T}, + model::$struct_name{$T}, ::Type{F}, ::Type{S}, - ) where { - $T, - F<:Union{$(_typedfun.(types)...)}, - S<:MOI.AbstractSet, - } + ) where {$T,F<:$SuperF,S<:$SuperS} return $MOI.supports_constraint(constraints(model, F, S), F, S) end - end - else - @assert eltype(types) <: SymbolSet - quote - function $MOI.supports_constraint( - model::$esc_struct_name{$T}, - ::Type{F}, - ::Type{S}, - ) where { - $T, - F<:MOI.AbstractFunction, - S<:Union{$(_typedset.(types)...)}, - } - return $MOI.supports_constraint(constraints(model, F, S), F, S) - end - end - end - expr = Expr(:block, struct_def, supports_code, code) + ), + ) if !isempty(field_types) - constructors = [:($field_type()) for field_type in field_types] # If there is no field type, the default constructor is sufficient and # adding this constructor will make a `StackOverflow`. constructor_code = :(function $typed_struct() where {$T} - return $typed_struct($(constructors...)) + return $typed_struct($([:(nothing) for _ in field_types]...)) end) if type_parametrized append!(constructor_code.args[1].args, field_types) end - push!(expr.args, constructor_code) + push!(code.args, constructor_code) end - return expr + return code end diff --git a/src/Utilities/vector_of_constraints.jl b/src/Utilities/vector_of_constraints.jl index dd6236079d..35a9db648f 100644 --- a/src/Utilities/vector_of_constraints.jl +++ b/src/Utilities/vector_of_constraints.jl @@ -15,11 +15,10 @@ # vector, it readily gives the entries of `model.constrmap` that need to be # updated. -struct VectorOfConstraints{F<:MOI.AbstractFunction,S<:MOI.AbstractSet} <: - MOI.ModelLike - # FIXME: It is not ideal that we have `DataType` here, it might induce type - # instabilities. We should change `CleverDicts` so that we can just - # use `typeof(CleverDicts.index_to_key)` here. +mutable struct VectorOfConstraints{ + F<:MOI.AbstractFunction, + S<:MOI.AbstractSet, +} <: MOI.ModelLike constraints::CleverDicts.CleverDict{ MOI.ConstraintIndex{F,S}, Tuple{F,S}, @@ -38,12 +37,13 @@ MOI.is_empty(v::VectorOfConstraints) = isempty(v.constraints) MOI.empty!(v::VectorOfConstraints) = empty!(v.constraints) function MOI.supports_constraint( - v::VectorOfConstraints{F,S}, + ::VectorOfConstraints{F,S}, ::Type{F}, ::Type{S}, ) where {F<:MOI.AbstractFunction,S<:MOI.AbstractSet} return true end + function MOI.add_constraint( v::VectorOfConstraints{F,S}, func::F, @@ -153,7 +153,8 @@ function _remove_variable(v::VectorOfConstraints, vi::MOI.VariableIndex) end return end -function _filter_variables(keep::F, v::VectorOfConstraints) where {F<:Function} + +function _filter_variables(keep::Function, v::VectorOfConstraints) CleverDicts.map_values!(v.constraints) do (f, s) return filter_variables(keep, f, s) end @@ -176,7 +177,7 @@ function _throw_if_cannot_delete( vis, fast_in_vis, ) where {S<:MOI.AbstractVectorSet} - if MOI.supports_dimension_update(S) + if MOI.supports_dimension_update(S) || MOI.is_empty(v) return end for fs in values(v.constraints) @@ -203,10 +204,10 @@ function _delete_variables( end function _delete_variables( - callback::F, + callback::Function, v::VectorOfConstraints{MOI.VectorOfVariables,<:MOI.AbstractVectorSet}, vis::Vector{MOI.VariableIndex}, -) where {F<:Function} +) filter!(v.constraints) do p f = p.second[1] del = if length(f.variables) == 1 @@ -223,23 +224,22 @@ function _delete_variables( end function _deleted_constraints( - callback::F, + callback::Function, v::VectorOfConstraints, vi::MOI.VariableIndex, -) where {F<:Function} - vis = [vi] - _delete_variables(callback, v, vis) +) + _delete_variables(callback, v, [vi]) _remove_variable(v, vi) return end function _deleted_constraints( - callback::F, + callback::Function, v::VectorOfConstraints, vis::Vector{MOI.VariableIndex}, -) where {F<:Function} - removed = Set(vis) +) _delete_variables(callback, v, vis) + removed = Set(vis) _filter_variables(vi -> !(vi in removed), v) return end diff --git a/test/Utilities/model.jl b/test/Utilities/model.jl index ca6c62b217..c91e5f0aca 100644 --- a/test/Utilities/model.jl +++ b/test/Utilities/model.jl @@ -348,6 +348,7 @@ end push!(loc2, (F, S)) end end + _pushloc(::Nothing) = nothing function _pushloc(model::MOI.Utilities.StructOfConstraints) return MOIU.broadcastcall(_pushloc, model) end