diff --git a/src/Utilities/struct_of_constraints.jl b/src/Utilities/struct_of_constraints.jl index 4390e199e9..af514b0b98 100644 --- a/src/Utilities/struct_of_constraints.jl +++ b/src/Utilities/struct_of_constraints.jl @@ -188,31 +188,11 @@ struct SymbolFun <: SymbolFS s::Union{Symbol,Expr} typed::Bool end -SymbolFun(s::Symbol) = SymbolFun(s, false) -function SymbolFun(s::Expr) - if Meta.isexpr(s, :curly) - @assert length(s.args) == 2 - @assert s.args[2] == :T - return SymbolFun(s.args[1], true) - else - return SymbolFun(s, false) - end -end struct SymbolSet <: SymbolFS s::Union{Symbol,Expr} typed::Bool end -SymbolSet(s::Symbol) = SymbolSet(s, false) -function SymbolSet(s::Expr) - if Meta.isexpr(s, :curly) - @assert length(s.args) == 2 - @assert s.args[2] == :T - return SymbolSet(s.args[1], true) - else - return SymbolSet(s, false) - end -end _typed(s::SymbolFS) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s) @@ -223,21 +203,61 @@ function _field(s::SymbolFS) return Symbol(replace(Unicode.lowercase(string(s.s)), "." => "_")) end -_callfield(f, s::SymbolFS) = :($f(model.$(_field(s)))) +# Represents a union of function or set types +struct _UnionSymbolFS{S<:SymbolFS} + s::Vector{S} +end + +function _typed(s::_UnionSymbolFS) + tt = _typed.(s.s) + return Expr(:curly, :Union, tt...) +end + +_field(s::_UnionSymbolFS) = _field(s.s[1]) + +function _mapreduce_field(s::Union{SymbolFS,_UnionSymbolFS}) + return :(cur = op(cur, f(model.$(_field(s))))) +end -_mapreduce_field(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s))))) +_callfield(f, s::Union{SymbolFS,_UnionSymbolFS}) = :($f(model.$(_field(s)))) + +function _parse_expr(::Type{S}, expr::Symbol) where {S<:SymbolFS} + return S(expr, false) +end + +function _parse_expr(::Type{S}, expr::Expr) where {S<:SymbolFS} + if Meta.isexpr(expr, :curly) + @assert length(expr.args) >= 1 + if expr.args[1] == :Union + # `Union{:A, :B}` parses as + # `Expr(:curly, :Union, :A, :B)` + @assert length(expr.args) >= 3 + return _UnionSymbolFS{S}(_parse_expr.(S, expr.args[2:end])) + else + # Typed set, e.g. `MOI.EqualTo{T}` parses as: + # `Expr(:curly, :(MOI.EqualTo), :T)` + @assert length(expr.args) == 2 + @assert expr.args[2] == :T + return S(expr.args[1], true) + end + else + return S(expr, false) + end +end """ struct_of_constraint_code(struct_name, types, field_types = nothing) -Given a vector of `n` `SymbolFun` or `SymbolSet` in `types`, defines -a subtype of `StructOfConstraints` of name `name` and which type parameters +Given a vector of `n` `Union{SymbolFun,_UnionSymbolFS{SymbolFun}}` or +`Union{SymbolSet,_UnionSymbolFS{SymbolSet}}` in `types`, defines a subtype of +`StructOfConstraints` of name `name` and which type parameters `{T, F1, F2, ..., Fn}` if `field_types` is `nothing` and a `{T}` otherwise. It contains `n` field where the `i`th field has type `Ci` if `field_types` is `nothing` and type `field_types[i]` otherwise. -If `types` is vector of `SymbolFun` (resp. `SymbolSet`) then the constraints -of that function (resp. set) type are stored in the corresponding field. +If `types` is vector of `Union{SymbolFun,_UnionSymbolFS{SymbolFun}}` (resp. +`Union{SymbolSet,_UnionSymbolFS{SymbolSet}}`) then the constraints of that +function (resp. set) type are stored in the corresponding field. This function is used by the macros [`@model`](@ref), [`@struct_of_constraints_by_function_types`](@ref) and @@ -278,8 +298,8 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing) for (t, field_type) in zip(types, field_types) field = _field(t) push!(code.args[2].args[3].args, :($field::Union{Nothing,$field_type})) - fun = t isa SymbolFun ? esc(t.s) : :(MOI.AbstractFunction) - set = t isa SymbolFun ? :(MOI.AbstractSet) : esc(t.s) + fun = t isa SymbolFun ? _typed(t) : :(MOI.AbstractFunction) + set = t isa SymbolFun ? :(MOI.AbstractSet) : _typed(t) constraints_code = :( function $MOIU.constraints( model::$typed_struct, @@ -301,21 +321,41 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing) end push!(code.args, constraints_code) end - is_func = eltype(types) <: SymbolFun - SuperF = is_func ? :(Union{$(_typed.(types)...)}) : :(MOI.AbstractFunction) - SuperS = is_func ? :(MOI.AbstractSet) : :(Union{$(_typed.(types)...)}) - push!( - code.args, - :( - function $MOI.supports_constraint( - model::$struct_name{$T}, - ::Type{F}, - ::Type{S}, - ) where {$T,F<:$SuperF,S<:$SuperS} - return $MOI.supports_constraint(constraints(model, F, S), F, S) - end - ), - ) + if !isempty(types) + is_func = any(types) do t + return t isa Union{SymbolFun,_UnionSymbolFS{SymbolFun}} + end + is_set = any(types) do t + return t isa Union{SymbolSet,_UnionSymbolFS{SymbolSet}} + end + @assert xor(is_func, is_set) + if is_func + SuperF = :(Union{$(_typed.(types)...)}) + else + SuperF = :(MOI.AbstractFunction) + end + if is_set + SuperS = :(Union{$(_typed.(types)...)}) + else + SuperS = :(MOI.AbstractSet) + end + push!( + code.args, + :( + function $MOI.supports_constraint( + model::$struct_name{$T}, + ::Type{F}, + ::Type{S}, + ) where {$T,F<:$SuperF,S<:$SuperS} + return $MOI.supports_constraint( + constraints(model, F, S), + F, + S, + ) + end + ), + ) + end constructor_code = :(function $typed_struct() where {$T} return $typed_struct(0, $([:(nothing) for _ in field_types]...)) end) @@ -334,9 +374,12 @@ a subtype of `StructOfConstraints` of name `name` and which type parameters `{T, C1, C2, ..., Cn}`. It contains `n` field where the `i`th field has type `Ci` and stores the constraints of function type `Fi`. + +The expression `Fi` can also be a union in which case any constraint for which +the function type is in the union is stored in the field with type `Ci`. """ macro struct_of_constraints_by_function_types(name, func_types...) - funcs = SymbolFun.(func_types) + funcs = _parse_expr.(SymbolFun, func_types) return struct_of_constraint_code(esc(name), funcs) end @@ -348,8 +391,13 @@ a subtype of `StructOfConstraints` of name `name` and which type parameters `{T, C1, C2, ..., Cn}`. It contains `n` field where the `i`th field has type `Ci` and stores the constraints of set type `Si`. +The expression `Si` can also be a union in which case any constraint for which +the set type is in the union is stored in the field with type `Ci`. +This can be useful if `Ci` is a [`MatrixOfConstraints`](@ref) in order to +concatenate the coefficients of constraints of several different set types in +the same matrix. """ macro struct_of_constraints_by_set_types(name, set_types...) - sets = SymbolSet.(set_types) + sets = _parse_expr.(SymbolSet, set_types) return struct_of_constraint_code(esc(name), sets) end diff --git a/test/Utilities/matrix_of_constraints.jl b/test/Utilities/matrix_of_constraints.jl index 74f71a2fdf..4259f72869 100644 --- a/test/Utilities/matrix_of_constraints.jl +++ b/test/Utilities/matrix_of_constraints.jl @@ -526,6 +526,61 @@ function test_modif() @test_throws err MOI.add_constraint(model, func, set) end +MOIU.@struct_of_constraints_by_set_types( + ZerosOrNot, + MOI.Zeros, + Union{MOI.Nonnegatives,MOI.Nonpositives}, +) + +function test_multicone() + T = Int + Indexing = MOIU.OneBasedIndexing + model = MOIU.GenericOptimizer{ + T, + ZerosOrNot{T}{ + MOIU.MatrixOfConstraints{ + T, + MOIU.MutableSparseMatrixCSC{T,Int,Indexing}, + Vector{T}, + Zeros{T}, + }, + MOIU.MatrixOfConstraints{ + T, + MOIU.MutableSparseMatrixCSC{T,Int,Indexing}, + Vector{T}, + NonnegNonpos{T}, + }, + }, + }() + #return model + x = MOI.add_variable(model) + fx = MOI.SingleVariable(x) + y = MOI.add_variable(model) + fy = MOI.SingleVariable(y) + MOI.add_constraint(model, MOIU.vectorize([T(5) * fx + T(2)]), MOI.Zeros(1)) + MOI.add_constraint( + model, + MOIU.vectorize([T(3) * fy + T(1)]), + MOI.Nonnegatives(1), + ) + MOI.add_constraint( + model, + MOIU.vectorize([T(6), T(7) * fx, T(4)]), + MOI.Nonpositives(1), + ) + MOIU.final_touch(model, nothing) + _test_matrix_equal( + model.constraints.moi_zeros.coefficients, + sparse([1], [1], T[5], 1, 2), + ) + @test model.constraints.moi_zeros.constants == T[2] + _test_matrix_equal( + model.constraints.moi_nonnegatives.coefficients, + sparse([1, 3], [2, 1], T[3, 7], 4, 2), + ) + @test model.constraints.moi_nonnegatives.constants == T[1, 6, 0, 4] +end + end TestMatrixOfConstraints.runtests()