Skip to content

Commit

Permalink
cleanup & bugfix dot (#524)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericphanson committed Dec 29, 2023
1 parent f47b5a4 commit 15f05fb
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 83 deletions.
7 changes: 3 additions & 4 deletions docs/src/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,15 @@ LP solver.
| `x-y` or `x.-y` | subtraction | affine | increasing in $x$ decreasing in $y$ | none none |
| `x*y` | multiplication | affine | increasing if constant term $\ge 0$ decreasing if constant term $\le 0$ not monotonic otherwise | PR: one argument is constant |
| `x/y` | division | affine | increasing | PR: $y$ is scalar constant |
| `dot(*)(x, y)` | elementwise multiplication | affine | increasing | PR: one argument is constant |
| `dot(/)(x, y)` | elementwise division | affine | increasing | PR: one argument is constant |
| `x .* y` | elementwise multiplication | affine | increasing | PR: one argument is constant |
| `x ./ y` | elementwise division | affine | increasing | PR: one argument is constant |
| `x[1:4, 2:3]` | indexing and slicing | affine | increasing | none |
| `diag(x, k)` | $k$-th diagonal of a matrix | affine | increasing | none |
| `diagm(x)` | construct diagonal matrix | affine | increasing | PR: $x$ is a vector |
| `x'` | transpose | affine | increasing | none |
| `vec(x)` | vector representation | affine | increasing | none |
| `dot(x,y)` | $\sum_i x_i y_i$ | affine | increasing | PR: one argument is constant |
| `kron(x,y)` | Kronecker product | affine | increasing | PR: one argument is constant |
| `vecdot(x,y)` | `dot(vec(x),vec(y))` | affine | increasing | PR: one argument is constant |
| `sum(x)` | $\sum_{ij} x_{ij}$ | affine | increasing | none |
| `sum(x, k)` | sum elements across dimension $k$ | affine | increasing | none |
| `sumlargest(x, k)` | sum of $k$ largest elements of $x$ | convex | increasing | none |
Expand Down Expand Up @@ -82,7 +81,7 @@ any solver that can solve both LPs and SOCPs can solve the problem.
| `sumsquares(x)` | $\sum x_i^2$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | none |
| `sqrt(x)` | $\sqrt{x}$ | concave | decreasing | IC: $x>0$ |
| `square(x), x^2` | $x^2$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | PR : $x$ is scalar |
| `dot(^)(x,2)` | $x.^2$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | elementwise |
| `x .^ 2` | $x.^2$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | elementwise |
| `geomean(x, y)` | $\sqrt{xy}$ | concave | increasing | IC: $x\ge0$, $y\ge0$ |
| `huber(x, M=1)` | $\begin{cases} x^2 &\|x\| \leq M \\ 2M\|x\| - M^2 &\|x\| > M \end{cases}$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | PR: $M>=1$ |

Expand Down
3 changes: 3 additions & 0 deletions docs/src/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Breaking changes:
* `x + A` will error if `x` is a scalar variable and `A` is an array. Instead, use `x * ones(size(A)) + A`.
* The `RelativeEntropyAtom` now returns a scalar value instead of elementwise values. This does not affect the result of `relative_entropy`.
* The function `constant` should be used instead of the type `Constant` (which now refers to exclusively real constants).
* The syntaxes `dot(*)`, `dot(/)` and `dot(^)` have been removed in favor of explicit broadcasting (`x .* y`, `x ./ y`, and `x .^ y`). These were (mild) type piracy.
* `vecdot(x,y)` has been removed. Call `dot(vec(x), vec(y))` instead.


Other changes:
Expand All @@ -15,6 +17,7 @@ Other changes:
* `geomean` supports more than 2 arguments
* [Type piracy](https://docs.julialang.org/en/v1/manual/style-guide/#Avoid-type-piracy) of `imag` and `real` has been removed. This should not affect use of Convex. Unfortunately, piracy of `hcat`, `vcat`, and `hvcat` still remains.
* `sumlargesteigs` now enforces that it's argument is hermitian.
* Bugfix: `dot` now correctly complex-conjugates its first argument

## v0.15.4 (October 24, 2023)

Expand Down
1 change: 0 additions & 1 deletion src/Convex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ include("utilities/tree_print.jl")
include("utilities/tree_interface.jl")
include("utilities/show.jl")
include("utilities/iteration.jl")
include("utilities/broadcast.jl")
include("problem_depot/problem_depot.jl")

end
9 changes: 1 addition & 8 deletions src/atoms/affine/dot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,8 @@ ismatrix(::Any) = false
# as extending singleton dimensions. We need to ensure that the inputs have the same
# length, which broadcast will check for us if both inputs are vectors.
asvec(x) = convert(AbstractExpr, ismatrix(x) ? vec(x) : x)
_vecdot(x, y) = sum(broadcast(*, asvec(x), asvec(y)))
_vecdot(x, y) = sum(broadcast(*, conj(asvec(x)), asvec(y)))

dot(x::AbstractExpr, y::AbstractExpr) = _vecdot(x, y)
dot(x::Value, y::AbstractExpr) = _vecdot(x, y)
dot(x::AbstractExpr, y::Value) = _vecdot(x, y)

if isdefined(LinearAlgebra, :vecdot) # defined but deprecated
import LinearAlgebra: vecdot
end
Base.@deprecate vecdot(x::AbstractExpr, y::AbstractExpr) dot(x, y)
Base.@deprecate vecdot(x::Value, y::AbstractExpr) dot(x, y)
Base.@deprecate vecdot(x::AbstractExpr, y::Value) dot(x, y)
35 changes: 6 additions & 29 deletions src/atoms/affine/multiply_divide.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ end
# end

function dotmultiply(x, y)
if size(x) == (1, 1) || size(y) == (1, 1)
return x * y
end

if vexity(x) != ConstVexity()
if vexity(y) != ConstVexity()
error(
Expand Down Expand Up @@ -223,39 +227,12 @@ function dotmultiply(x, y)
return reshape(const_multiplier * vec(var), size(var)...)
end

function broadcasted(
::typeof(*),
x::Union{Constant,ComplexConstant},
y::AbstractExpr,
)
if x.size == (1, 1) || y.size == (1, 1)
return x * y
elseif size(y, 1) < size(x, 1) && size(y, 1) == 1
return dotmultiply(x, ones(size(x, 1)) * y)
elseif size(y, 2) < size(x, 2) && size(y, 2) == 1
return dotmultiply(x, y * ones(1, size(x, 1)))
else
return dotmultiply(x, y)
end
end
function broadcasted(
::typeof(*),
y::AbstractExpr,
x::Union{Constant,ComplexConstant},
)
return dotmultiply(x, y)
end

# if neither is a constant it's not DCP, but might be nice to support anyway for eg MultiConvex
function broadcasted(::typeof(*), x::AbstractExpr, y::AbstractExpr)
if x.size == (1, 1) || y.size == (1, 1)
return x * y
elseif vexity(x) == ConstVexity()
return dotmultiply(x, y)
elseif isequal(x, y)
if isequal(x, y)
return square(x)
else
return dotmultiply(y, x)
return dotmultiply(x, y)
end
end
function broadcasted(::typeof(*), x::Value, y::AbstractExpr)
Expand Down
10 changes: 10 additions & 0 deletions src/atoms/second_order_cone/qol_elementwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ function broadcasted(::typeof(^), x::AbstractExpr, k::Int)
error("raising variables to powers other than 2 is not implemented")
end

# handle literal case
function broadcasted(
::typeof(Base.literal_pow),
::typeof(^),
x::AbstractExpr,
::Val{k},
) where {k}
return broadcasted(^, x, k)
end

invpos(x::AbstractExpr) = QolElemAtom(constant(ones(x.size[1], x.size[2])), x)
function broadcasted(::typeof(/), x::Value, y::AbstractExpr)
return dotmultiply(constant(x), invpos(y))
Expand Down
19 changes: 9 additions & 10 deletions src/problem_depot/problems/affine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,31 +405,31 @@ end
::Type{T},
) where {T,test}
x = Variable(3)
p = maximize(sum(dot(*)(x, [1, 2, 3])), x <= 1; numeric_type = T)
p = maximize(sum(x .* [1, 2, 3]), x <= 1; numeric_type = T)

if test
@test problem_vexity(p) == AffineVexity()
end
handle_problem!(p)
if test
@test p.optval 6 atol = atol rtol = rtol
@test evaluate(sum((dot(*))(x, [1, 2, 3]))) 6 atol = atol rtol = rtol
@test evaluate(sum(x .* [1, 2, 3])) 6 atol = atol rtol = rtol
end

x = Variable(3, 3)
p = maximize(sum(dot(*)(x, eye(3))), x <= 1; numeric_type = T)
p = maximize(sum(x .* eye(3)), x <= 1; numeric_type = T)

if test
@test problem_vexity(p) == AffineVexity()
end
handle_problem!(p)
if test
@test p.optval 3 atol = atol rtol = rtol
@test evaluate(sum((dot(*))(x, eye(3)))) 3 atol = atol rtol = rtol
@test evaluate(sum(x .* eye(3))) 3 atol = atol rtol = rtol
end

x = Variable(5, 5)
p = minimize(x[1, 1], dot(*)(3, x) >= 3; numeric_type = T)
p = minimize(x[1, 1], 3 .* x >= 3; numeric_type = T)

if test
@test problem_vexity(p) == AffineVexity()
Expand All @@ -441,7 +441,7 @@ end
end

x = Variable(3, 1)
p = minimize(sum(dot(*)(ones(3, 3), x)), x >= 1; numeric_type = T)
p = minimize(sum(ones(3, 3) .* x), x >= 1; numeric_type = T)

if test
@test problem_vexity(p) == AffineVexity()
Expand All @@ -453,7 +453,7 @@ end
end

x = Variable(1, 3)
p = minimize(sum(dot(*)(ones(3, 3), x)), x >= 1; numeric_type = T)
p = minimize(sum(ones(3, 3) .* x), x >= 1; numeric_type = T)

if test
@test problem_vexity(p) == AffineVexity()
Expand All @@ -465,16 +465,15 @@ end
end

x = Variable(1, 3, Positive())
p = maximize(sum(dot(/)(x, [1 2 3])), x <= 1; numeric_type = T)
p = maximize(sum(x ./ [1 2 3]), x <= 1; numeric_type = T)

if test
@test problem_vexity(p) == AffineVexity()
end
handle_problem!(p)
if test
@test p.optval 11 / 6 atol = atol rtol = rtol
@test evaluate(sum((dot(/))(x, [1 2 3]))) 11 / 6 atol = atol rtol =
rtol
@test evaluate(sum(x ./ [1 2 3])) 11 / 6 atol = atol rtol = rtol
end

# Broadcast fusion works
Expand Down
23 changes: 18 additions & 5 deletions src/problem_depot/problems/socp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ end
A = [1 2; 2 1; 3 4]
b = [2; 3; 4]
expr = A * x + b
p = minimize(sum(dot(^)(expr, 2)); numeric_type = T) # elementwise ^
# `literal_pow` case:
p = minimize(sum(expr .^ 2); numeric_type = T) # elementwise ^
if test
@test problem_vexity(p) == ConvexVexity()
end
Expand All @@ -188,16 +189,28 @@ end
rtol
end

p = minimize(sum(dot(*)(expr, expr)); numeric_type = T) # elementwise *
# Test non-literal case:
k = 2
p = minimize(sum(expr .^ k); numeric_type = T) # elementwise ^
if test
@test problem_vexity(p) == ConvexVexity()
end
handle_problem!(p)
if test
@test p.optval 0.42105 atol = atol rtol = rtol
@test evaluate(sum((dot(*))(expr, expr))) 0.42105 atol = atol rtol =
@test evaluate(sum(broadcast(^, expr, 2))) 0.42105 atol = atol rtol =
rtol
end

p = minimize(sum(expr .* expr); numeric_type = T) # elementwise *
if test
@test problem_vexity(p) == ConvexVexity()
end
handle_problem!(p)
if test
@test p.optval 0.42105 atol = atol rtol = rtol
@test evaluate(sum(expr .* expr)) 0.42105 atol = atol rtol = rtol
end
end

@add_problem socp function socp_inv_pos_atom(
Expand Down Expand Up @@ -227,13 +240,13 @@ end
end

x = Variable(3)
p = minimize(sum(dot(/)([3, 6, 9], x)), x <= 3; numeric_type = T)
p = minimize(sum([3, 6, 9] ./ x), x <= 3; numeric_type = T)

handle_problem!(p)
if test
@test evaluate(x) fill(3.0, (3, 1)) atol = atol rtol = rtol
@test p.optval 6 atol = atol rtol = rtol
@test evaluate(sum((dot(/))([3, 6, 9], x))) 6 atol = atol rtol = rtol
@test evaluate(sum([3, 6, 9] ./ x)) 6 atol = atol rtol = rtol
end

x = Variable()
Expand Down
26 changes: 0 additions & 26 deletions src/utilities/broadcast.jl

This file was deleted.

6 changes: 6 additions & 0 deletions test/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,9 @@ end
set_value!(x, 1.5)
@test evaluate(expr) log(1 + exp(1.5))
end

@testset "`dot` (issue #508)" begin
x = [1.0 + 1.0im]
y = [-1.0im]
@test dot(x, y) evaluate(dot(constant(x), y))
end

0 comments on commit 15f05fb

Please sign in to comment.