Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 93 additions & 45 deletions src/Utilities/struct_of_constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,31 +188,11 @@ struct SymbolFun <: SymbolFS
s::Union{Symbol,Expr}
typed::Bool
end
SymbolFun(s::Symbol) = SymbolFun(s, false)
function SymbolFun(s::Expr)
if Meta.isexpr(s, :curly)
@assert length(s.args) == 2
@assert s.args[2] == :T
return SymbolFun(s.args[1], true)
else
return SymbolFun(s, false)
end
end

struct SymbolSet <: SymbolFS
s::Union{Symbol,Expr}
typed::Bool
end
SymbolSet(s::Symbol) = SymbolSet(s, false)
function SymbolSet(s::Expr)
if Meta.isexpr(s, :curly)
@assert length(s.args) == 2
@assert s.args[2] == :T
return SymbolSet(s.args[1], true)
else
return SymbolSet(s, false)
end
end

_typed(s::SymbolFS) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s)

Expand All @@ -223,21 +203,61 @@ function _field(s::SymbolFS)
return Symbol(replace(Unicode.lowercase(string(s.s)), "." => "_"))
end

_callfield(f, s::SymbolFS) = :($f(model.$(_field(s))))
# Represents a union of function or set types
struct _UnionSymbolFS{S<:SymbolFS}
s::Vector{S}
end

function _typed(s::_UnionSymbolFS)
tt = _typed.(s.s)
return Expr(:curly, :Union, tt...)
end

_field(s::_UnionSymbolFS) = _field(s.s[1])

function _mapreduce_field(s::Union{SymbolFS,_UnionSymbolFS})
return :(cur = op(cur, f(model.$(_field(s)))))
end

_mapreduce_field(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s)))))
_callfield(f, s::Union{SymbolFS,_UnionSymbolFS}) = :($f(model.$(_field(s))))

function _parse_expr(::Type{S}, expr::Symbol) where {S<:SymbolFS}
return S(expr, false)
end

function _parse_expr(::Type{S}, expr::Expr) where {S<:SymbolFS}
if Meta.isexpr(expr, :curly)
@assert length(expr.args) >= 1
if expr.args[1] == :Union
# `Union{:A, :B}` parses as
# `Expr(:curly, :Union, :A, :B)`
@assert length(expr.args) >= 3
return _UnionSymbolFS{S}(_parse_expr.(S, expr.args[2:end]))
else
# Typed set, e.g. `MOI.EqualTo{T}` parses as:
# `Expr(:curly, :(MOI.EqualTo), :T)`
@assert length(expr.args) == 2
@assert expr.args[2] == :T
return S(expr.args[1], true)
end
else
return S(expr, false)
end
end

"""
struct_of_constraint_code(struct_name, types, field_types = nothing)

Given a vector of `n` `SymbolFun` or `SymbolSet` in `types`, defines
a subtype of `StructOfConstraints` of name `name` and which type parameters
Given a vector of `n` `Union{SymbolFun,_UnionSymbolFS{SymbolFun}}` or
`Union{SymbolSet,_UnionSymbolFS{SymbolSet}}` in `types`, defines a subtype of
`StructOfConstraints` of name `name` and which type parameters
`{T, F1, F2, ..., Fn}` if `field_types` is `nothing` and
a `{T}` otherwise.
It contains `n` field where the `i`th field has type `Ci` if `field_types` is
`nothing` and type `field_types[i]` otherwise.
If `types` is vector of `SymbolFun` (resp. `SymbolSet`) then the constraints
of that function (resp. set) type are stored in the corresponding field.
If `types` is vector of `Union{SymbolFun,_UnionSymbolFS{SymbolFun}}` (resp.
`Union{SymbolSet,_UnionSymbolFS{SymbolSet}}`) then the constraints of that
function (resp. set) type are stored in the corresponding field.

This function is used by the macros [`@model`](@ref),
[`@struct_of_constraints_by_function_types`](@ref) and
Expand Down Expand Up @@ -278,8 +298,8 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)
for (t, field_type) in zip(types, field_types)
field = _field(t)
push!(code.args[2].args[3].args, :($field::Union{Nothing,$field_type}))
fun = t isa SymbolFun ? esc(t.s) : :(MOI.AbstractFunction)
set = t isa SymbolFun ? :(MOI.AbstractSet) : esc(t.s)
fun = t isa SymbolFun ? _typed(t) : :(MOI.AbstractFunction)
set = t isa SymbolFun ? :(MOI.AbstractSet) : _typed(t)
constraints_code = :(
function $MOIU.constraints(
model::$typed_struct,
Expand All @@ -301,21 +321,41 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)
end
push!(code.args, constraints_code)
end
is_func = eltype(types) <: SymbolFun
SuperF = is_func ? :(Union{$(_typed.(types)...)}) : :(MOI.AbstractFunction)
SuperS = is_func ? :(MOI.AbstractSet) : :(Union{$(_typed.(types)...)})
push!(
code.args,
:(
function $MOI.supports_constraint(
model::$struct_name{$T},
::Type{F},
::Type{S},
) where {$T,F<:$SuperF,S<:$SuperS}
return $MOI.supports_constraint(constraints(model, F, S), F, S)
end
),
)
if !isempty(types)
is_func = any(types) do t
return t isa Union{SymbolFun,_UnionSymbolFS{SymbolFun}}
end
is_set = any(types) do t
return t isa Union{SymbolSet,_UnionSymbolFS{SymbolSet}}
end
@assert xor(is_func, is_set)
if is_func
SuperF = :(Union{$(_typed.(types)...)})
else
SuperF = :(MOI.AbstractFunction)
end
if is_set
SuperS = :(Union{$(_typed.(types)...)})
else
SuperS = :(MOI.AbstractSet)
end
push!(
code.args,
:(
function $MOI.supports_constraint(
model::$struct_name{$T},
::Type{F},
::Type{S},
) where {$T,F<:$SuperF,S<:$SuperS}
return $MOI.supports_constraint(
constraints(model, F, S),
F,
S,
)
end
),
)
end
constructor_code = :(function $typed_struct() where {$T}
return $typed_struct(0, $([:(nothing) for _ in field_types]...))
end)
Expand All @@ -334,9 +374,12 @@ a subtype of `StructOfConstraints` of name `name` and which type parameters
`{T, C1, C2, ..., Cn}`.
It contains `n` field where the `i`th field has type `Ci` and stores the
constraints of function type `Fi`.

The expression `Fi` can also be a union in which case any constraint for which
the function type is in the union is stored in the field with type `Ci`.
"""
macro struct_of_constraints_by_function_types(name, func_types...)
funcs = SymbolFun.(func_types)
funcs = _parse_expr.(SymbolFun, func_types)
return struct_of_constraint_code(esc(name), funcs)
end

Expand All @@ -348,8 +391,13 @@ a subtype of `StructOfConstraints` of name `name` and which type parameters
`{T, C1, C2, ..., Cn}`.
It contains `n` field where the `i`th field has type `Ci` and stores the
constraints of set type `Si`.
The expression `Si` can also be a union in which case any constraint for which
the set type is in the union is stored in the field with type `Ci`.
This can be useful if `Ci` is a [`MatrixOfConstraints`](@ref) in order to
concatenate the coefficients of constraints of several different set types in
the same matrix.
"""
macro struct_of_constraints_by_set_types(name, set_types...)
sets = SymbolSet.(set_types)
sets = _parse_expr.(SymbolSet, set_types)
return struct_of_constraint_code(esc(name), sets)
end
55 changes: 55 additions & 0 deletions test/Utilities/matrix_of_constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,61 @@ function test_modif()
@test_throws err MOI.add_constraint(model, func, set)
end

MOIU.@struct_of_constraints_by_set_types(
ZerosOrNot,
MOI.Zeros,
Union{MOI.Nonnegatives,MOI.Nonpositives},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify why this is needed in more detail?!? Why not just two separate fields.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows to have one sparse matrix with the two instead of two separate sparse matrices. ECOS needs one sparse matrix with the three non-Zeros cones

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but you could concatenate the sparse matrices? It seems weird that we go to all this effort and then allow Union

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd argue that this PR is simpler than implementing concatenation in the solver wrapper though. DiffOpt also needs to split the matrix onto equality and non-equality. This is quite common so having this feature is not just for ECOS

)

function test_multicone()
T = Int
Indexing = MOIU.OneBasedIndexing
model = MOIU.GenericOptimizer{
T,
ZerosOrNot{T}{
MOIU.MatrixOfConstraints{
T,
MOIU.MutableSparseMatrixCSC{T,Int,Indexing},
Vector{T},
Zeros{T},
},
MOIU.MatrixOfConstraints{
T,
MOIU.MutableSparseMatrixCSC{T,Int,Indexing},
Vector{T},
NonnegNonpos{T},
},
},
}()
#return model
x = MOI.add_variable(model)
fx = MOI.SingleVariable(x)
y = MOI.add_variable(model)
fy = MOI.SingleVariable(y)
MOI.add_constraint(model, MOIU.vectorize([T(5) * fx + T(2)]), MOI.Zeros(1))
MOI.add_constraint(
model,
MOIU.vectorize([T(3) * fy + T(1)]),
MOI.Nonnegatives(1),
)
MOI.add_constraint(
model,
MOIU.vectorize([T(6), T(7) * fx, T(4)]),
MOI.Nonpositives(1),
)
MOIU.final_touch(model, nothing)
_test_matrix_equal(
model.constraints.moi_zeros.coefficients,
sparse([1], [1], T[5], 1, 2),
)
@test model.constraints.moi_zeros.constants == T[2]
_test_matrix_equal(
model.constraints.moi_nonnegatives.coefficients,
sparse([1, 3], [2, 1], T[3, 7], 4, 2),
)
@test model.constraints.moi_nonnegatives.constants == T[1, 6, 0, 4]
end

end

TestMatrixOfConstraints.runtests()