From 6031860fd96cd0f55482e69923986ee6c49cc148 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Mon, 17 Jun 2024 16:08:43 +1200 Subject: [PATCH] Support AbstractArray in relative_entropy and log_perspective (#695) --- src/supported_operations.jl | 20 ++++++++++++++++--- test/test_atoms.jl | 38 ++++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/supported_operations.jl b/src/supported_operations.jl index c97cfeac3..f189efafe 100644 --- a/src/supported_operations.jl +++ b/src/supported_operations.jl @@ -1056,7 +1056,12 @@ julia> size(atom) (1, 1) ``` """ -log_perspective(x::AbstractExpr, y::AbstractExpr) = -relative_entropy(y, x) +function log_perspective( + x::Union{AbstractExpr,AbstractArray}, + y::Union{AbstractExpr,AbstractArray}, +) + return -relative_entropy(y, x) +end """ LinearAlgebra.logdet(X::Convex.AbstractExpr) @@ -1821,7 +1826,7 @@ Base.real(x::Constant) = x """ relative_entropy(x::Convex.AbstractExpr, y::Convex.AbstractExpr) -The epigraph of \$\\sum y_i*\\log \\frac{x_i}{y_i}\$. +The epigraph of \$\\sum x_i*\\log \\frac{x_i}{y_i}\$. ## Examples @@ -1845,7 +1850,6 @@ julia> x = Variable(3); julia> y = Variable(3); -julia> atom = relative_entropy(x, y) julia> atom = relative_entropy(x, y) relative_entropy (convex; real) ├─ 3-element real variable (id: 906…671) @@ -1857,6 +1861,16 @@ julia> size(atom) """ relative_entropy(x::AbstractExpr, y::AbstractExpr) = RelativeEntropyAtom(x, y) +function relative_entropy(x::AbstractExpr, y::AbstractArray) + return RelativeEntropyAtom(x, constant(y)) +end + +function relative_entropy(x::AbstractArray, y::AbstractExpr) + return RelativeEntropyAtom(constant(x), y) +end + +relative_entropy(x::AbstractArray, y::AbstractArray) = sum(x .* log.(x ./ y)) + """ Base.reshape(x::AbstractExpr, m::Int, n::Int) diff --git a/test/test_atoms.jl b/test/test_atoms.jl index aa9ba5ef6..2fa320be3 100644 --- a/test/test_atoms.jl +++ b/test/test_atoms.jl @@ -1196,6 +1196,22 @@ function test_RelativeEntropyAtom() return relative_entropy(x, y) end target = """ + variables: u, v1, v2 + minobjective: 1.0 * u + 0.0 + [1.0*u, 1.0*v1, 1.0*v2, 2.0, 3.0] in RelativeEntropyCone(5) + """ + _test_atom(target) do context + return relative_entropy([2, 3], Variable(2)) + end + target = """ + variables: u, w1, w2 + minobjective: 1.0 * u + 0.0 + [1.0*u, 2.0, 3.0, 1.0*w1, 1.0*w2] in RelativeEntropyCone(5) + """ + _test_atom(target) do context + return relative_entropy(Variable(2), [2, 3]) + end + target = """ variables: u, v1, v2, w1, w2 minobjective: -1.0 * u + 1.0 [1.0*u, 1.0*v1, 1.0*v2, 1.0*w1, 1.0*w2] in RelativeEntropyCone(5) @@ -1205,6 +1221,24 @@ function test_RelativeEntropyAtom() y = Variable(2) return 1.0 + log_perspective(x, y) end + target = """ + variables: u, w1, w2 + minobjective: -1.0 * u + 1.0 + [1.0*u, 1.0*w1, 1.0*w2, 2.0, 3.0] in RelativeEntropyCone(5) + """ + _test_atom(target) do context + return 1.0 + log_perspective(Variable(2), [2, 3]) + end + target = """ + variables: u, w1, w2 + minobjective: -1.0 * u + 1.0 + [1.0*u, 2.0, 3.0, 1.0*w1, 1.0*w2] in RelativeEntropyCone(5) + """ + _test_atom(target) do context + x = [2.0, 3.0] + y = Variable(2) + return 1.0 + log_perspective(x, y) + end x, y = Variable(2), im * Variable(2) @test_throws( ErrorException( @@ -1229,7 +1263,9 @@ function test_RelativeEntropyAtom() @test evaluate(atom) ≈ log(0.5) x.value = [5.0, 1.0] y.value = [3.0, 2.0] - @test evaluate(atom) ≈ 5 * log(5 / 3) + log(0.5) + u = 5 * log(5 / 3) + log(0.5) + @test evaluate(atom) ≈ u + @test relative_entropy([5.0, 1.0], [3.0, 2.0]) ≈ u return end