From 460c9259434a73a303d5e75e0a08436530cda21b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 4 May 2026 15:57:41 +0200 Subject: [PATCH 1/2] Arbitrary number type --- src/model.jl | 16 ++++----- src/parse.jl | 13 ++++--- src/parse_moi.jl | 6 ++-- src/types.jl | 57 ++++++++++++++++--------------- test/ArrayDiff.jl | 86 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 135 insertions(+), 43 deletions(-) diff --git a/src/model.jl b/src/model.jl index c0a310d..6010ae6 100644 --- a/src/model.jl +++ b/src/model.jl @@ -7,15 +7,15 @@ function set_objective(model::Model, obj) end function add_constraint( - model::Model, + model::Model{T}, func, set::Union{ - MOI.GreaterThan{Float64}, - MOI.LessThan{Float64}, - MOI.Interval{Float64}, - MOI.EqualTo{Float64}, + MOI.GreaterThan{T}, + MOI.LessThan{T}, + MOI.Interval{T}, + MOI.EqualTo{T}, }, -) +) where {T} f = parse_expression(model, func) model.last_constraint_index += 1 index = ConstraintIndex(model.last_constraint_index) @@ -23,8 +23,8 @@ function add_constraint( return index end -function add_parameter(model::Model, value::Float64) - push!(model.parameters, value) +function add_parameter(model::Model{T}, value::Real) where {T} + push!(model.parameters, convert(T, value)) return ParameterIndex(length(model.parameters)) end diff --git a/src/parse.jl b/src/parse.jl index 8382d2d..3c5d582 100644 --- a/src/parse.jl +++ b/src/parse.jl @@ -52,14 +52,19 @@ function parse_expression( return end -function parse_expression(data::Model, input) - expr = Expression() +function parse_expression(data::Model{T}, input) where {T} + expr = Expression{T}() parse_expression(data, expr, input, -1) return expr end -function parse_expression(::Model, expr::Expression, x::Real, parent_index::Int) - push!(expr.values, convert(Float64, x)::Float64) +function parse_expression( + ::Model, + expr::Expression{T}, + x::Real, + parent_index::Int, +) where {T} + push!(expr.values, convert(T, x)::T) push!(expr.nodes, Node(NODE_VALUE, length(expr.values), parent_index)) return end diff --git a/src/parse_moi.jl b/src/parse_moi.jl index 86c1c01..26b003b 100644 --- a/src/parse_moi.jl +++ b/src/parse_moi.jl @@ -51,7 +51,7 @@ function _parse_moi_stack!( ::Vector{Tuple{Int,Any}}, data::Model, expr::Expression, - x::Union{Float64,MOI.VariableIndex}, + x::Union{Real,MOI.VariableIndex}, parent_index::Int, ) return parse_expression(data, expr, x, parent_index) @@ -188,7 +188,7 @@ function _parse_moi_stack!( stack::Vector{Tuple{Int,Any}}, data::Model, expr::Expression, - x::Matrix{Float64}, + x::AbstractMatrix{<:Real}, parent_index::Int, ) m, n = size(x) @@ -210,7 +210,7 @@ function _parse_moi_stack!( stack::Vector{Tuple{Int,Any}}, data::Model, expr::Expression, - x::Vector{Float64}, + x::AbstractVector{<:Real}, parent_index::Int, ) vect_id = data.operators.multivariate_operator_to_id[:vect] diff --git a/src/types.jl b/src/types.jl index 97b2d51..d3601fb 100644 --- a/src/types.jl +++ b/src/types.jl @@ -14,10 +14,10 @@ The core type that represents a nonlinear expression. See the MathOptInterface documentation for information on how the nodes and values form an expression tree. """ -struct Expression +struct Expression{T} nodes::Vector{Node} - values::Vector{Float64} - Expression() = new(Node[], Float64[]) + values::Vector{T} + Expression{T}() where {T} = new{T}(Node[], T[]) end function Base.:(==)(x::Expression, y::Expression) @@ -38,13 +38,13 @@ end A type to hold information relating to the nonlinear constraint `f(x) in S`, where `f(x)` is defined by `.expression`, and `S` is `.set`. """ -struct Constraint - expression::Expression +struct Constraint{T} + expression::Expression{T} set::Union{ - MOI.LessThan{Float64}, - MOI.GreaterThan{Float64}, - MOI.EqualTo{Float64}, - MOI.Interval{Float64}, + MOI.LessThan{T}, + MOI.GreaterThan{T}, + MOI.EqualTo{T}, + MOI.Interval{T}, } end @@ -101,7 +101,7 @@ function _subexpression_and_linearity( return _SubexpressionStorage( nodes, adj, - expr.values, + convert(Vector{Float64}, expr.values), partials_storage_ϵ, linearity[1], ), @@ -192,30 +192,31 @@ It has the following fields: * `parameters::Vector{Float64}` : holds the current values of the parameters. * `operators::OperatorRegistry` : stores the operators used in the model. """ -mutable struct Model - objective::Union{Nothing,Expression} - expressions::Vector{Expression} - constraints::OrderedCollections.OrderedDict{ConstraintIndex,Constraint} - parameters::Vector{Float64} +mutable struct Model{T} + objective::Union{Nothing,Expression{T}} + expressions::Vector{Expression{T}} + constraints::OrderedCollections.OrderedDict{ConstraintIndex,Constraint{T}} + parameters::Vector{T} operators::OperatorRegistry # This is a private field, used only to increment the ConstraintIndex. last_constraint_index::Int64 - function Model() - model = new( + function Model{T}() where {T} + return new{T}( nothing, - Expression[], - OrderedCollections.OrderedDict{ConstraintIndex,Constraint}(), - Float64[], + Expression{T}[], + OrderedCollections.OrderedDict{ConstraintIndex,Constraint{T}}(), + T[], OperatorRegistry(), 0, ) - return model end end -mutable struct Evaluator{B} <: MOI.AbstractNLPEvaluator +Model() = Model{Float64}() + +mutable struct Evaluator{T,B} <: MOI.AbstractNLPEvaluator # The internal datastructure. - model::Model + model::Model{T} # The abstract-differentiation backend backend::B # ordered_constraints is needed because `OrderedDict` doesn't support @@ -223,7 +224,7 @@ mutable struct Evaluator{B} <: MOI.AbstractNLPEvaluator ordered_constraints::Vector{ConstraintIndex} # Storage for the NLPBlockDual, so that we can query the dual of individual # constraints without needing to query the full vector each time. - constraint_dual::Vector{Float64} + constraint_dual::Vector{T} # Timers initialize_timer::Float64 eval_objective_timer::Float64 @@ -236,14 +237,14 @@ mutable struct Evaluator{B} <: MOI.AbstractNLPEvaluator eval_hessian_lagrangian_timer::Float64 function Evaluator( - model::Model, + model::Model{T}, backend::B = nothing, - ) where {B<:Union{Nothing,MOI.AbstractNLPEvaluator}} - return new{B}( + ) where {T,B<:Union{Nothing,MOI.AbstractNLPEvaluator}} + return new{T,B}( model, backend, MOI.ConstraintIndex[], - Float64[], + T[], 0.0, 0.0, 0.0, diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl index 5666cc9..f05e0da 100644 --- a/test/ArrayDiff.jl +++ b/test/ArrayDiff.jl @@ -702,6 +702,92 @@ function test_objective_broadcasted_pow_cubed() return end +function test_model_typed_default_is_float64() + model = ArrayDiff.Model() + @test model isa ArrayDiff.Model{Float64} + @test model.parameters isa Vector{Float64} + @test model.expressions isa Vector{ArrayDiff.Expression{Float64}} + @test model.constraints isa + ArrayDiff.OrderedCollections.OrderedDict{ + ArrayDiff.ConstraintIndex, + ArrayDiff.Constraint{Float64}, + } + return +end + +function test_model_typed_float32_parse_value() + model = ArrayDiff.Model{Float32}() + x = MOI.VariableIndex(1) + ArrayDiff.set_objective(model, :($x + 1.5)) + obj = something(model.objective) + @test obj isa ArrayDiff.Expression{Float32} + @test obj.values isa Vector{Float32} + @test obj.values == Float32[1.5] + return +end + +function test_model_typed_float32_add_parameter() + model = ArrayDiff.Model{Float32}() + p = ArrayDiff.add_parameter(model, 2.5) + @test p isa ArrayDiff.ParameterIndex + @test model.parameters isa Vector{Float32} + @test model.parameters == Float32[2.5] + return +end + +function test_model_typed_float32_add_constraint() + model = ArrayDiff.Model{Float32}() + x = MOI.VariableIndex(1) + set = MOI.LessThan{Float32}(3.0f0) + idx = ArrayDiff.add_constraint(model, :($x + 1.0), set) + @test idx isa ArrayDiff.ConstraintIndex + c = model.constraints[idx] + @test c isa ArrayDiff.Constraint{Float32} + @test c.expression isa ArrayDiff.Expression{Float32} + @test c.expression.values == Float32[1.0] + @test c.set === set + return +end + +function test_model_typed_float32_add_expression() + model = ArrayDiff.Model{Float32}() + x = MOI.VariableIndex(1) + idx = ArrayDiff.add_expression(model, :($x * 2.0)) + @test idx isa ArrayDiff.ExpressionIndex + e = model[idx] + @test e isa ArrayDiff.Expression{Float32} + @test e.values == Float32[2.0] + return +end + +function test_model_typed_bigfloat_constraint_set() + model = ArrayDiff.Model{BigFloat}() + x = MOI.VariableIndex(1) + set = MOI.GreaterThan{BigFloat}(big"1.0") + idx = ArrayDiff.add_constraint(model, :($x), set) + c = model.constraints[idx] + @test c isa ArrayDiff.Constraint{BigFloat} + @test c.set === set + return +end + +function test_model_typed_float32_evaluator_runs() + # End-to-end smoke test: parsing happens in T = Float32, AD evaluation + # converts to Float64 internally. + model = ArrayDiff.Model{Float32}() + x = MOI.VariableIndex(1) + ArrayDiff.set_objective(model, :(2 * dot([$x], [$x]) + 1.0)) + evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x]) + @test evaluator isa ArrayDiff.Evaluator{Float32} + MOI.initialize(evaluator, [:Grad]) + xv = [1.5] + @test MOI.eval_objective(evaluator, xv) ≈ 2 * xv[1]^2 + 1.0 + g = ones(1) + MOI.eval_objective_gradient(evaluator, g, xv) + @test g[1] ≈ 4 * xv[1] + return +end + end # module TestArrayDiff.runtests() From 9802d59c96665da0c48b31fd79252eb642bbc609 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 4 May 2026 15:58:15 +0200 Subject: [PATCH 2/2] Fix format --- test/ArrayDiff.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl index f05e0da..06fc357 100644 --- a/test/ArrayDiff.jl +++ b/test/ArrayDiff.jl @@ -707,8 +707,7 @@ function test_model_typed_default_is_float64() @test model isa ArrayDiff.Model{Float64} @test model.parameters isa Vector{Float64} @test model.expressions isa Vector{ArrayDiff.Expression{Float64}} - @test model.constraints isa - ArrayDiff.OrderedCollections.OrderedDict{ + @test model.constraints isa ArrayDiff.OrderedCollections.OrderedDict{ ArrayDiff.ConstraintIndex, ArrayDiff.Constraint{Float64}, }