-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added 'where' keyword to 'jnp.{mean, var, std}' #5966
Added 'where' keyword to 'jnp.{mean, var, std}' #5966
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Regarding the numpy incompatibility you mentioned... we should gate those cases with a numpy version check
I assumed it will be 1.20.2 so I added (...,2) to required version. |
Oh, I see... but you disabled the entire test in this case, right? It would be better to just skip the We can't merge this code if all tests that cover it are skipped. |
tests/lax_numpy_test.py
Outdated
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) | ||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) | ||
args_maker = lambda: [rng(shape, dtype)] | ||
if numpy_version >= (1, 20, 2): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the numpy issue, it looks like this works correctly for sum
but not for mean
. Can we only skip the tests in specific cases where numpy 1.20.1 returns the wrong result?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just pushed changes that includes any
and all
from this test suite - the rest is failing. sum
is not present in this test suite as it's already tested with where
in testReducerInitialWhere
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, thanks!
I marked this PR so it's not a draft anymore. I think it's ready but in CI one test is failing that wasn't affected here. One of the |
Try with:
|
@jakevdp Hi! Thanks for a guide - it's all green now (changing existing test's name caused this failure - no idea why). Also numpy/numpy#18560 was merged yesterday so I guess it will be present in 1.20.2. Is it ready to merge or should I add anything? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
Regarding the test name change causing failures: I believe it's related to the random seed being changed. One last request: could you squash all your commits into one and force-push to the branch? |
7914552
to
d743aa5
Compare
Sure - squashed and pushed, all green now. |
Thanks for all the work on this! |
Thank you for assistance! I guess #5842 can be closed. |
Addresses #5842
Hi @jakevdp!
Here's WiP PR for adding
where
keyword support tojnp.{mean, var, std}
as added in numpy1.20
.While working on it I encountered a bug in numpy's
np.mean
: support forwhere
keyword whenaxis
param is present results in internal error. I filed an issue: numpy/numpy#18552 and fix PR: numpy/numpy#18560 and I hope it will be included in1.20.2
but for now this is blocked (impressivejax
s test case generators cached it!).I followed numpy API when adding
where
keyword (also its source code). I also added a separate test suite as those methods do not takeinitial
argument (as they have their own identity defined). Also when generating thosewhere
masks it often masks the whole row that results in division by zero (but is legal operation, as in numpy) so I included additional warning suppressors. Is it correct to do?Thank you for any help!