Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added 'where' keyword to 'jnp.{mean, var, std}' #5966

Merged
merged 1 commit into from
Mar 12, 2021

Conversation

mtsokol
Copy link
Contributor

@mtsokol mtsokol commented Mar 6, 2021

Addresses #5842

Hi @jakevdp!

Here's WiP PR for adding where keyword support to jnp.{mean, var, std} as added in numpy 1.20.

While working on it I encountered a bug in numpy's np.mean: support for where keyword when axis param is present results in internal error. I filed an issue: numpy/numpy#18552 and fix PR: numpy/numpy#18560 and I hope it will be included in 1.20.2 but for now this is blocked (impressive jaxs test case generators cached it!).

I followed numpy API when adding where keyword (also its source code). I also added a separate test suite as those methods do not take initial argument (as they have their own identity defined). Also when generating those where masks it often masks the whole row that results in division by zero (but is legal operation, as in numpy) so I included additional warning suppressors. Is it correct to do?

Thank you for any help!

@google-cla google-cla bot added the cla: yes label Mar 6, 2021
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Regarding the numpy incompatibility you mentioned... we should gate those cases with a numpy version check

tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Show resolved Hide resolved
tests/lax_numpy_test.py Show resolved Hide resolved
@mtsokol
Copy link
Contributor Author

mtsokol commented Mar 9, 2021

Looks good! Regarding the numpy incompatibility you mentioned... we should gate those cases with a numpy version check

I assumed it will be 1.20.2 so I added (...,2) to required version.

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 9, 2021

I assumed it will be 1.20.2 so I added (...,2) to required version.

Oh, I see... but you disabled the entire test in this case, right? It would be better to just skip the CheckAgainstNumpy, because the JAX code by itself should work regardless of what numpy version is installed. And try to only skip the actual cases that are broken, because if I understand correctly, some of them should work correctly with 1.20.0.

We can't merge this code if all tests that cover it are skipped.

jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
if numpy_version >= (1, 20, 2):
Copy link
Collaborator

@jakevdp jakevdp Mar 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the numpy issue, it looks like this works correctly for sum but not for mean. Can we only skip the tests in specific cases where numpy 1.20.1 returns the wrong result?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just pushed changes that includes any and all from this test suite - the rest is failing. sum is not present in this test suite as it's already tested with where in testReducerInitialWhere

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, thanks!

@mtsokol mtsokol marked this pull request as ready for review March 10, 2021 10:02
@mtsokol
Copy link
Contributor Author

mtsokol commented Mar 10, 2021

I marked this PR so it's not a draft anymore. I think it's ready but in CI one test is failing that wasn't affected here. One of the tests/lax_numpy_test.py::LaxBackedNumpyTests::testReducerInitialWhereSum fails with being out of tolerance. I haven't modified sum or reduce methods here and I also can't reproduce it locally (with pytest -n auto -k "testReducerInitialWhere" tests/lax_numpy_test.py) so I don't know where it comes from.

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 12, 2021

also can't reproduce it locally (with pytest -n auto -k "testReducerInitialWhere" tests/lax_numpy_test.py) so I don't know where it comes from.

Try with:

$ JAX_NUM_GENERATED_CASES=25 pytest -n auto -k "testReducerInitialWhere" tests/lax_numpy_test.py

@mtsokol
Copy link
Contributor Author

mtsokol commented Mar 12, 2021

@jakevdp Hi! Thanks for a guide - it's all green now (changing existing test's name caused this failure - no idea why). Also numpy/numpy#18560 was merged yesterday so I guess it will be present in 1.20.2. Is it ready to merge or should I add anything?

@jakevdp jakevdp self-assigned this Mar 12, 2021
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 12, 2021

Regarding the test name change causing failures: I believe it's related to the random seed being changed.

One last request: could you squash all your commits into one and force-push to the branch?

@mtsokol
Copy link
Contributor Author

mtsokol commented Mar 12, 2021

Sure - squashed and pushed, all green now.

@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Mar 12, 2021
@copybara-service copybara-service bot merged commit 77c1f31 into jax-ml:master Mar 12, 2021
@jakevdp
Copy link
Collaborator

jakevdp commented Mar 12, 2021

Thanks for all the work on this!

@mtsokol
Copy link
Contributor Author

mtsokol commented Mar 12, 2021

Thank you for assistance! I guess #5842 can be closed.
Also if there's a next not-so-hard issue to take please let me know! I can try sth in coming weeks. If there isn't right now I will check issues page somewhere in the future.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants