Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: keepdims passed through 'min' correctly #2419

Closed
wants to merge 1 commit into from

Conversation

allenlawrence94
Copy link
Contributor

Description

The min atom ignores the keepdims flag.

Type of change

  • New feature (backwards compatible)
  • New feature (breaking API changes)
  • Bug fix
  • Other (Documentation, CI, ...)

Contribution checklist

  • Add our license to new files.
  • Check that your code adheres to our coding style.
  • Write unittests.
  • Run the unittests and check that they’re passing.
  • Run the benchmarks to make sure your change doesn’t introduce a regression.

@SteveDiamond
Copy link
Collaborator

Good catch!

@SteveDiamond
Copy link
Collaborator

We should add a test that would fail without this fix.

@phschiele
Copy link
Collaborator

Thanks @allenlawrence94! Would you mind adding a simple unit test as well?

@allenlawrence94
Copy link
Contributor Author

@phschiele - sure! Are there unit tests for the canonicalizers? Or would we rather just have a test on the min atom keepdims=True?

@SteveDiamond
Copy link
Collaborator

@phschiele - sure! Are there unit tests for the canonicalizers? Or would we rather just have a test on the min atom keepdims=True?

Do a test on the min atom. i.e. construct a problem then solve it that requires passing through keepdims=True

@SteveDiamond
Copy link
Collaborator

@allenlawrence94 please update the test_min function in test_atoms to the following code. We are trying to merge all the outstanding bugfix MRs.

    def test_min(self) -> None:
        """Test min.
        """
        # One arg, test sign.
        self.assertEqual(cp.min(1).sign, s.NONNEG)
        self.assertEqual(cp.min(-2).sign, s.NONPOS)
        self.assertEqual(cp.min(Variable()).sign, s.UNKNOWN)
        self.assertEqual(cp.min(0).sign, s.ZERO)

        # Test with axis argument.
        self.assertEqual(cp.min(Variable(2), axis=0).shape, tuple())
        self.assertEqual(cp.min(Variable(2), axis=1).shape, (2,))
        self.assertEqual(cp.min(Variable((2, 3)), axis=0).shape, (3,))
        self.assertEqual(cp.min(Variable((2, 3)), axis=1).shape, (2,))

        # Invalid axis.
        with self.assertRaises(Exception) as cm:
            cp.min(self.x, axis=4)
        self.assertEqual(str(cm.exception), "Invalid argument for axis.")
        with self.assertRaises(ValueError) as cm:
            cp.min(self.x, self.x)  # a common erroneous use-case
        self.assertEqual(str(cm.exception), cp.min.__EXPR_AXIS_ERROR__)

        # Test canonicalization with keepdims=True
        # https://github.com/cvxpy/cvxpy/pull/2419
        X = cp.Variable((2, 3))
        X_val = np.arange(6).reshape((2, 3))
        c = np.ones((1, 3))
        expr = cp.min(X, axis=0, keepdims=True)
        obj = cp.Maximize(cp.sum(expr + c))
        prob = cp.Problem(obj, [X == X_val])
        prob.solve()

@SteveDiamond
Copy link
Collaborator

I can't push to your fork for some reason, and we like to keep track of contributors so would rather not make a new PR.

@phschiele phschiele mentioned this pull request May 1, 2024
9 tasks
@SteveDiamond
Copy link
Collaborator

Picked up in #2431

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants