# Patch summary (reviewer notes; the patch text below is unchanged).
#
# src/sizes.jl, inside `_infer_sizes`:
#   * The separate branches for `+`/`-`, `^` and `*` are replaced. `+`, `-` and `*`
#     now share one branch that
#       - sets `sizes.ndims[k]` to the maximum `ndims` over the children (0 if none),
#       - records `sizes.size_offset[k] = length(sizes.size)` and pushes
#         `sizes.ndims[k]` entries equal to 1 onto `sizes.size`,
#       - for every child dimension greater than 1, either copies it into the parent
#         entry that is still 1 or asserts that it equals the parent entry.
#   * The `^` branch now asserts that there are exactly two children and that the
#     second child has `ndims == 0` before copying the size of the first child.
#   NOTE(review): these checks use `@assert`; the Julia manual warns that assertions
#   may be disabled at some optimization levels. If they guard user-provided
#   expressions, an explicit `throw` may be preferable -- confirm intent.
#
# test/ArrayDiff.jl, `test_objective_broadcasted_product`:
#   * Expected `sizes.ndims[2]` changes from 2 to 1 and the expected `sizes.size`
#     changes from `[2, 2, 2, 1]` to `[2, 2, 2]`.
#
# test/JuMP.jl:
#   * `_eval` drops the `obj_val` / `grad_val` parameters and returns
#     `(sizes, val, g)`; `_test_neural` now performs the `≈` comparisons itself.
#   * Adds `test_broadcast_nonsquare_matrix` and
#     `test_broadcast_scalar_matrix_size_inference`.
#   NOTE(review): `_eval` still evaluates the gradient at `T.(collect(1:8))` with
#   `g = zero(x)`, while `test_broadcast_nonsquare_matrix` passes a 6-element `x`.
#   Presumably only the first entries are read -- verify against the evaluator, or
#   evaluate the gradient at a point whose length matches `x`.
diff --git a/src/sizes.jl b/src/sizes.jl index baf8f7e..462f3b0 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -346,47 +346,33 @@ function _infer_sizes( continue end op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] - if op == :+ || op == :- - # Broadcasted +/- preserves shape - _copy_size!(sizes, k, children_arr[first(children_indices)]) - elseif op == :^ - # Broadcasted ^ with scalar exponent preserves base shape - _copy_size!(sizes, k, children_arr[first(children_indices)]) - elseif op == :* - # TODO assert compatible sizes and all ndims should be 0 or 2 - first_matrix = findfirst(children_indices) do i - return !iszero(sizes.ndims[children_arr[i]]) + if op == :+ || op == :- || op == :* + sizes.ndims[k] = maximum(children_indices, init = 0) do i + return sizes.ndims[children_arr[i]] end - if !isnothing(first_matrix) - if sizes.ndims[children_arr[first(children_indices)]] == 0 - _add_size!(sizes, k, (1, 1)) - continue - else - if sizes.ndims[children_arr[first(children_indices)]] == - 1 - nb_cols = 1 - else - nb_cols = _size( - sizes, - children_arr[first(children_indices)], - 1, - ) + sizes.size_offset[k] = length(sizes.size) + for _ in 1:sizes.ndims[k] + push!(sizes.size, 1) + end + sz_parent = _size(sizes, k) + for i in children_indices + id = children_arr[i] + sz = _size(sizes, id) + for j in eachindex(sz) + if sz[j] > 1 + if sz_parent[j] == 1 + sz_parent[j] = sz[j] + else + @assert sz_parent[j] == sz[j] + end end - _add_size!( - sizes, - k, - ( - _size( - sizes, - children_arr[first(children_indices)], - 1, - ), - nb_cols, - ), - ) - continue end end + elseif op == :^ + # Broadcasted ^ with scalar exponent preserves base shape + @assert length(children_indices) == 2 "Expected two arguments for broadcasted operator `$op`, got $(length(children_indices))" + @assert iszero(sizes.ndims[children_arr[children_indices[2]]]) "Expected scalar exponent for broadcasted operator `$op`" + _copy_size!(sizes, k, children_arr[first(children_indices)]) end elseif node.type == 
NODE_CALL_UNIVARIATE if !( diff --git a/test/ArrayDiff.jl b/test/ArrayDiff.jl index 99ef909..067eae8 100644 --- a/test/ArrayDiff.jl +++ b/test/ArrayDiff.jl @@ -607,9 +607,9 @@ function test_objective_broadcasted_product() evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4]) MOI.initialize(evaluator, [:Grad]) sizes = evaluator.backend.objective.expr.sizes - @test sizes.ndims == [0, 2, 1, 0, 0, 1, 0, 0] + @test sizes.ndims == [0, 1, 1, 0, 0, 1, 0, 0] @test sizes.size_offset == [0, 2, 1, 0, 0, 0, 0, 0] - @test sizes.size == [2, 2, 2, 1] + @test sizes.size == [2, 2, 2] @test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11] x1 = 1.0 x2 = 2.0 diff --git a/test/JuMP.jl b/test/JuMP.jl index 9901621..533e631 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -164,13 +164,7 @@ function test_parse_moi() return end -function _eval( - model::JuMP.GenericModel{T}, - func, - x, - obj_val, - grad_val, -) where {T} +function _eval(model::JuMP.GenericModel{T}, func, x) where {T} mode = ArrayDiff.Mode{Vector{T}}() ad = ArrayDiff.model(mode) MOI.Nonlinear.set_objective(ad, JuMP.moi_function(func)) @@ -180,20 +174,20 @@ function _eval( JuMP.index.(JuMP.all_variables(model)), ) MOI.initialize(evaluator, [:Grad]) - x_grad = T.(collect(1:8)) - @test MOI.eval_objective(evaluator, x) ≈ obj_val + sizes = evaluator.backend.objective.expr.sizes + val = MOI.eval_objective(evaluator, x) if VERSION >= v"1.12" @test 0 == @allocated MOI.eval_objective(evaluator, x) end + x_grad = T.(collect(1:8)) g = zero(x) MOI.eval_objective_gradient(evaluator, g, x_grad) - @test g ≈ grad_val if VERSION >= v"1.12" @test 0 == @allocated MOI.eval_objective_gradient(evaluator, g, x_grad) end MOI.Nonlinear.set_objective(ad, nothing) @test isnothing(ad.objective) - return + return sizes, val, g end function _test_neural( @@ -280,7 +274,9 @@ function _test_neural( else grad_val = grad_sumsq end - _eval(model, loss, [vec(W1_val); vec(W2_val)], obj_val, grad_val) + _, val, g = _eval(model, loss, 
[vec(W1_val); vec(W2_val)]) + @test obj_val ≈ val + @test grad_val ≈ g return end @@ -428,6 +424,72 @@ function test_size_vec_vect() return end +function test_broadcast_nonsquare_matrix() + model = Model() + @variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables) + Y = [10.0 20.0 30.0; 40.0 50.0 60.0] + x = Float64.(collect(1:6)) + W_val = reshape(x, 2, 3) + @testset "$(op)" for (op, expr, ref_mat) in [ + (:+, LinearAlgebra.norm(W .+ Y), W_val .+ Y), + (:-, LinearAlgebra.norm(W .- Y), W_val .- Y), + (:*, LinearAlgebra.norm(W .* W), W_val .* W_val), + ] + sizes, val, g = _eval(model, expr, x) + # Outer norm scalar, then the broadcasted op produces a 2x3 matrix, + # then the two 2x3 leaves: 4 nodes, three of them ndims=2 with size + # (2, 3). The old bug would report (2, 2) for the broadcast node. + @test sizes.ndims == [0, 2, 2, 2] + @test sizes.size == [2, 3, 2, 3, 2, 3] + @test sizes.size_offset == [0, 4, 2, 0] + @test sizes.storage_offset == [0, 1, 7, 13, 19] + @test val ≈ LinearAlgebra.norm(ref_mat) + ref_g = if op == :+ + vec(W_val .+ Y) ./ LinearAlgebra.norm(ref_mat) + elseif op == :- + vec(W_val .- Y) ./ LinearAlgebra.norm(ref_mat) + else # :* + # d(norm(W .* W))/dW = 2 .* W .^ 3 / norm(W .* W) + vec(2 .* W_val .^ 3) ./ LinearAlgebra.norm(ref_mat) + end + @test g ≈ ref_g + end + return +end + +function test_broadcast_scalar_matrix_size_inference() + model = Model() + @variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables) + mode = ArrayDiff.Mode() + @testset "$(name)" for (name, expr) in [ + ("scalar .* M", LinearAlgebra.norm(2.5 .* W)), + ("M .* scalar", LinearAlgebra.norm(W .* 2.5)), + ("scalar .+ M", LinearAlgebra.norm(2.5 .+ W)), + ("M .+ scalar", LinearAlgebra.norm(W .+ 2.5)), + ("scalar .- M", LinearAlgebra.norm(2.5 .- W)), + ("M .- scalar", LinearAlgebra.norm(W .- 2.5)), + ] + ad = ArrayDiff.model(mode) + MOI.Nonlinear.set_objective(ad, JuMP.moi_function(expr)) + evaluator = MOI.Nonlinear.Evaluator( + ad, + mode, + 
JuMP.index.(JuMP.all_variables(model)), + ) + MOI.initialize(evaluator, [:Grad]) + sizes = evaluator.backend.objective.expr.sizes + # Broadcast node is at index 2; it should inherit the matrix child's + # (2, 3) shape, not the old `(1, 1)` stub. + @test sizes.ndims[2] == 2 + broadcast_size_off = sizes.size_offset[2] + @test sizes.size[broadcast_size_off+1] == 2 + @test sizes.size[broadcast_size_off+2] == 3 + # And the scalar leaf among the children stays ndims=0. + @test 0 in sizes.ndims[3:4] + end + return +end + end # module TestJuMP.runtests()