From 0dcf8477df6f2981ffce5d3c12e89b33e2892bfa Mon Sep 17 00:00:00 2001 From: Steven Diamond Date: Thu, 10 Mar 2022 18:57:25 -0800 Subject: [PATCH] Fix sign error with log_sum_exp (#1689) * fix sign error with log_sum_exp * docs (cherry picked from commit 06854d9d2cbcf68ccba4eac69112fecc10e4389e) --- cvxpy/atoms/log_sum_exp.py | 3 ++- cvxpy/tests/test_atoms.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/cvxpy/atoms/log_sum_exp.py b/cvxpy/atoms/log_sum_exp.py index 8303e0b185..2cef2d74af 100644 --- a/cvxpy/atoms/log_sum_exp.py +++ b/cvxpy/atoms/log_sum_exp.py @@ -69,7 +69,8 @@ def _column_grad(self, value): def sign_from_args(self) -> Tuple[bool, bool]: """Returns sign (is positive, is negative) of the expression. """ - return (False, False) + # Non-negative when arg is non-negative. + return (self.args[0].is_nonneg(), False) def is_atom_convex(self) -> bool: """Is the atom convex? diff --git a/cvxpy/tests/test_atoms.py b/cvxpy/tests/test_atoms.py index 2ba6d24db4..cd431a7c92 100644 --- a/cvxpy/tests/test_atoms.py +++ b/cvxpy/tests/test_atoms.py @@ -1167,3 +1167,18 @@ def test_loggamma(self) -> None: [X == A]) result = prob.solve(solver=cp.SCS) assert np.isclose(result, true_val.sum(), atol=1e0) + + def test_log_sum_exp(self) -> None: + """Test log_sum_exp sign. + """ + # Test for non-negative x + x = Variable(nonneg=True) + atom = cp.log_sum_exp(x) + self.assertEqual(atom.curvature, s.CONVEX) + self.assertEqual(atom.sign, s.NONNEG) + + # Test for non-positive x + x = Variable(nonpos=True) + atom = cp.log_sum_exp(x) + self.assertEqual(atom.curvature, s.CONVEX) + self.assertEqual(atom.sign, s.UNKNOWN)