Skip to content

Commit

Permalink
Fix sign error with log_sum_exp (#1689)
Browse files Browse the repository at this point in the history
* fix sign error with log_sum_exp

* docs

(cherry picked from commit 06854d9)
  • Loading branch information
SteveDiamond authored and rileyjmurray committed May 16, 2022
1 parent c395696 commit 0dcf847
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
3 changes: 2 additions & 1 deletion cvxpy/atoms/log_sum_exp.py
Expand Up @@ -69,7 +69,8 @@ def _column_grad(self, value):
def sign_from_args(self) -> Tuple[bool, bool]:
"""Returns sign (is positive, is negative) of the expression.
"""
return (False, False)
# Non-negative when arg is non-negative.
return (self.args[0].is_nonneg(), False)

def is_atom_convex(self) -> bool:
"""Is the atom convex?
Expand Down
15 changes: 15 additions & 0 deletions cvxpy/tests/test_atoms.py
Expand Up @@ -1167,3 +1167,18 @@ def test_loggamma(self) -> None:
[X == A])
result = prob.solve(solver=cp.SCS)
assert np.isclose(result, true_val.sum(), atol=1e0)

def test_log_sum_exp(self) -> None:
"""Test log_sum_exp sign.
"""
# Test for non-negative x
x = Variable(nonneg=True)
atom = cp.log_sum_exp(x)
self.assertEqual(atom.curvature, s.CONVEX)
self.assertEqual(atom.sign, s.NONNEG)

# Test for non-positive x
x = Variable(nonpos=True)
atom = cp.log_sum_exp(x)
self.assertEqual(atom.curvature, s.CONVEX)
self.assertEqual(atom.sign, s.UNKNOWN)

0 comments on commit 0dcf847

Please sign in to comment.