Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: accept zeros on numpy.random dirichlet function #23440

Merged
merged 5 commits into from
Apr 11, 2023

Conversation

pcralmeida
Copy link
Contributor

Changed alpha value error to pass a null value. This way, dirichlet function wont raise a value exception at 0. Also added new test_dirichlet function on numpy.tests.test_matlib.py.

@MatteoRaso
Copy link
Contributor

Thanks for the PR. Just to let everybody know, this is in reference to #22547.

Comment on lines 61 to 62
def test_dirichlet():
dirichlet([5, 9, 0, 8])
Copy link
Contributor

@MatteoRaso MatteoRaso Mar 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_dirichlet():
dirichlet([5, 9, 0, 8])
def test_dirichlet():
y = dirichlet([5, 9, 0, 8])
assert y[2] == 0

We don't just want to show that the function doesn't throw an error, we also want to show that the function returns the right result when 0 is given as input.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matteo is right, we should check this, thanks. However, this isn't the right file. The correct tests are in the numpy/random submodule.

Copy link
Member

@seberg seberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, but all changes here need to be moved around, they should be OK changes, but they are at the wrong spot.

if np.any(np.less_equal(alpha_arr, 0)):
raise ValueError('alpha <= 0')
if np.any(np.less(alpha_arr, 0)):
raise ValueError('alpha < 0')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please limit this change to generator.pyx. I am not sure I mind this much (since it is an error being thrown), but the backward compat guarantees on RandomState this file are very strong.

So, the change here is not correct, but the identical change on the generator would be.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the changes to generator.pyx as intended, and also passed the test to the numpy/random submodule. However, the test would raise the alpha = 0 value error from mtrand.pyx, so I maintained the changes that were made. The changes now seem OK, but I will wait for your feedback.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The value-error is desired in mtrand (strict compatibility). The tests just need to be different. We should make sure that generator is tested, not mtrand (i.e. not np.random.diriclet but np.random.default_rng().dirichlet. Those tests live in a different file.

In a sense, we want to nudge users towards preferring np.random.default_rng() anyway, so there is probably not much reason to improve mtrand.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, sorry for the long wait. I've read your comment and I've reverted the changes in mtrand. Also fixed the test having np.random.default_rng() in account. Hope these changes are good now.

Comment on lines 61 to 62
def test_dirichlet():
dirichlet([5, 9, 0, 8])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matteo is right, we should check this, thanks. However, this isn't the right file. The correct tests are in the numpy/random submodule.

@pcralmeida pcralmeida requested a review from seberg March 25, 2023 22:18
Copy link
Contributor

@MatteoRaso MatteoRaso left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@seberg seberg merged commit fe89374 into numpy:main Apr 11, 2023
@seberg
Copy link
Member

seberg commented Apr 11, 2023

Thanks for the followup, looks good now.

@@ -812,6 +812,10 @@ def test_dirichlet_bad_alpha(self):
alpha = np.array([5.4e-01, -1.0e-16])
assert_raises(ValueError, random.dirichlet, alpha)

def test_dirichlet_zero_alpha(self):
y = random.default_rng().dirichlet([5, 9, 0, 8])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is using the result of default_rng(), which is a Generator instance (not RandomState), so this test should not be in this file. I'll fix this in the pull request I am working on for #24210.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants