-
-
Notifications
You must be signed in to change notification settings - Fork 9.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BUG: accept zeros on numpy.random dirichlet function #23440
Conversation
Thanks for the PR. Just to let everybody know, this is in reference to #22547. |
numpy/tests/test_matlib.py
Outdated
def test_dirichlet(): | ||
dirichlet([5, 9, 0, 8]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def test_dirichlet(): | |
dirichlet([5, 9, 0, 8]) | |
def test_dirichlet(): | |
y = dirichlet([5, 9, 0, 8]) | |
assert y[2] == 0 |
We don't just want to show that the function doesn't throw an error, we also want to show that the function returns the right result when 0 is given as input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Matteo is right, we should check this, thanks. However, this isn't the right file. The correct tests are in the numpy/random
submodule.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, but all changes here need to be moved around, they should be OK changes, but they are at the wrong spot.
numpy/random/mtrand.pyx
Outdated
if np.any(np.less_equal(alpha_arr, 0)): | ||
raise ValueError('alpha <= 0') | ||
if np.any(np.less(alpha_arr, 0)): | ||
raise ValueError('alpha < 0') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please limit this change to generator.pyx
. I am not sure I mind this much (since it is an error being thrown), but the backward compat guarantees on RandomState
this file are very strong.
So, the change here is not correct, but the identical change on the generator would be.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added the changes to generator.pyx
as intended, and also passed the test to the numpy/random
submodule. However, the test would raise the alpha = 0
value error from mtrand.pyx
, so I maintained the changes that were made. The changes now seem OK, but I will wait for your feedback.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The value-error is desired in mtrand
(strict compatibility). The tests just need to be different. We should make sure that generator is tested, not mtrand
(i.e. not np.random.diriclet
but np.random.default_rng().dirichlet
. Those tests live in a different file.
In a sense, we want to nudge users towards preferring np.random.default_rng()
anyway, so there is probably not much reason to improve mtrand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, sorry for the long wait. I've read your comment and I've reverted the changes in mtrand
. Also fixed the test having np.random.default_rng()
in account. Hope these changes are good now.
numpy/tests/test_matlib.py
Outdated
def test_dirichlet(): | ||
dirichlet([5, 9, 0, 8]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Matteo is right, we should check this, thanks. However, this isn't the right file. The correct tests are in the numpy/random
submodule.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Thanks for the followup, looks good now. |
@@ -812,6 +812,10 @@ def test_dirichlet_bad_alpha(self): | |||
alpha = np.array([5.4e-01, -1.0e-16]) | |||
assert_raises(ValueError, random.dirichlet, alpha) | |||
|
|||
def test_dirichlet_zero_alpha(self): | |||
y = random.default_rng().dirichlet([5, 9, 0, 8]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is using the result of default_rng()
, which is a Generator
instance (not RandomState
), so this test should not be in this file. I'll fix this in the pull request I am working on for #24210.
Changed
alpha
value error to pass a null value. This way, dirichlet function wont raise a value exception at 0. Also added newtest_dirichlet
function onnumpy.tests.test_matlib.py
.