Skip to content

Commit

Permalink
Support AbstractArray in relative_entropy and log_perspective (#695)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Jun 17, 2024
1 parent b1ea4aa commit 6031860
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
20 changes: 17 additions & 3 deletions src/supported_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1056,7 +1056,12 @@ julia> size(atom)
(1, 1)
```
"""
log_perspective(x::AbstractExpr, y::AbstractExpr) = -relative_entropy(y, x)
function log_perspective(
x::Union{AbstractExpr,AbstractArray},
y::Union{AbstractExpr,AbstractArray},
)
return -relative_entropy(y, x)
end

"""
LinearAlgebra.logdet(X::Convex.AbstractExpr)
Expand Down Expand Up @@ -1821,7 +1826,7 @@ Base.real(x::Constant) = x
"""
relative_entropy(x::Convex.AbstractExpr, y::Convex.AbstractExpr)
The epigraph of \$\\sum y_i*\\log \\frac{x_i}{y_i}\$.
The epigraph of \$\\sum x_i*\\log \\frac{x_i}{y_i}\$.
## Examples
Expand All @@ -1845,7 +1850,6 @@ julia> x = Variable(3);
julia> y = Variable(3);
julia> atom = relative_entropy(x, y)
julia> atom = relative_entropy(x, y)
relative_entropy (convex; real)
├─ 3-element real variable (id: 906…671)
Expand All @@ -1857,6 +1861,16 @@ julia> size(atom)
"""
relative_entropy(x::AbstractExpr, y::AbstractExpr) = RelativeEntropyAtom(x, y)

function relative_entropy(x::AbstractExpr, y::AbstractArray)
return RelativeEntropyAtom(x, constant(y))
end

function relative_entropy(x::AbstractArray, y::AbstractExpr)
return RelativeEntropyAtom(constant(x), y)
end

relative_entropy(x::AbstractArray, y::AbstractArray) = sum(x .* log.(x ./ y))

"""
Base.reshape(x::AbstractExpr, m::Int, n::Int)
Expand Down
38 changes: 37 additions & 1 deletion test/test_atoms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,22 @@ function test_RelativeEntropyAtom()
return relative_entropy(x, y)
end
target = """
variables: u, v1, v2
minobjective: 1.0 * u + 0.0
[1.0*u, 1.0*v1, 1.0*v2, 2.0, 3.0] in RelativeEntropyCone(5)
"""
_test_atom(target) do context
return relative_entropy([2, 3], Variable(2))
end
target = """
variables: u, w1, w2
minobjective: 1.0 * u + 0.0
[1.0*u, 2.0, 3.0, 1.0*w1, 1.0*w2] in RelativeEntropyCone(5)
"""
_test_atom(target) do context
return relative_entropy(Variable(2), [2, 3])
end
target = """
variables: u, v1, v2, w1, w2
minobjective: -1.0 * u + 1.0
[1.0*u, 1.0*v1, 1.0*v2, 1.0*w1, 1.0*w2] in RelativeEntropyCone(5)
Expand All @@ -1205,6 +1221,24 @@ function test_RelativeEntropyAtom()
y = Variable(2)
return 1.0 + log_perspective(x, y)
end
target = """
variables: u, w1, w2
minobjective: -1.0 * u + 1.0
[1.0*u, 1.0*w1, 1.0*w2, 2.0, 3.0] in RelativeEntropyCone(5)
"""
_test_atom(target) do context
return 1.0 + log_perspective(Variable(2), [2, 3])
end
target = """
variables: u, w1, w2
minobjective: -1.0 * u + 1.0
[1.0*u, 2.0, 3.0, 1.0*w1, 1.0*w2] in RelativeEntropyCone(5)
"""
_test_atom(target) do context
x = [2.0, 3.0]
y = Variable(2)
return 1.0 + log_perspective(x, y)
end
x, y = Variable(2), im * Variable(2)
@test_throws(
ErrorException(
Expand All @@ -1229,7 +1263,9 @@ function test_RelativeEntropyAtom()
@test evaluate(atom) log(0.5)
x.value = [5.0, 1.0]
y.value = [3.0, 2.0]
@test evaluate(atom) 5 * log(5 / 3) + log(0.5)
u = 5 * log(5 / 3) + log(0.5)
@test evaluate(atom) u
@test relative_entropy([5.0, 1.0], [3.0, 2.0]) u
return
end

Expand Down

0 comments on commit 6031860

Please sign in to comment.