Error in np.mean with where keyword (added in 1.20 version) #18552

Closed
mtsokol opened this issue Mar 5, 2021 · 8 comments
Comments

@mtsokol
Member

mtsokol commented Mar 5, 2021

Hi!
While porting the latest API changes introduced in 1.20 to the JAX library, I encountered an error when calling np.mean(..., where=mask).
I've also included a Colab reproduction link below. Is this expected behavior?

Thank you for any help!

Reproducing code example:

import numpy as np
print(np.__version__) # should be `1.20`

a = np.random.randn(2,3,4)

# correctly computes sums along axis=2 with 'where' mask
np.sum(a, axis=2, keepdims=False, where=[False, True, False, True])

# but computing means with newly added 'where' keyword fails
np.mean(a, axis=2, keepdims=False, where=[False, True, False, True])

Here's also a reproduction in colab: https://colab.research.google.com/drive/1KFvp4BjMwO27iNVT4R9O3IvVfwMSw0nB?usp=sharing

Error message:

Traceback (most recent call last):
  File "bug.py", line 10, in <module>
    np.mean(a, axis=2, keepdims=False, where=[False, True, False, True])
  File "<__array_function__ internals>", line 5, in mean
  File "/usr/local/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 3419, in mean
    return _methods._mean(a, axis=axis, dtype=dtype,
  File "/usr/local/lib/python3.8/site-packages/numpy/core/_methods.py", line 167, in _mean
    if rcount == 0 if where is True else umr_any(rcount == 0):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

NumPy/Python version information:

1.20.1 3.8.8 (default, Feb 21 2021, 10:35:39)
[Clang 12.0.0 (clang-1200.0.32.29)]
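
Until a fixed release is available, a masked mean can be built from np.sum, which the report above shows working with `where`. This is only a minimal sketch of a workaround, not NumPy's own implementation; the names `mask`, `total`, `count` and `masked_mean` are illustrative, and a seeded generator is assumed so the result is reproducible.

```python
import numpy as np

rng = np.random.default_rng(0)          # seeded for reproducibility (assumption)
a = rng.standard_normal((2, 3, 4))
mask = np.array([False, True, False, True])

# Sum only the selected elements along the last axis.
total = np.sum(a, axis=2, where=mask)

# Count how many elements the mask selects for each output position.
count = np.sum(np.broadcast_to(mask, a.shape), axis=2)

# Masked mean, shape (2, 3). A count of zero would divide by zero here.
masked_mean = total / count
```

The result should match boolean indexing, i.e. `a[..., mask].mean(axis=2)`, for a mask that only varies along the reduced axis.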
@mtsokol mtsokol changed the title from "Error in np.mean with where keword (added in 1.20 version)" to "Error in np.mean with where keyword (added in 1.20 version)" Mar 5, 2021
@seberg seberg added this to the 1.20.2 release milestone Mar 5, 2021
@seberg
Member

seberg commented Mar 5, 2021

Annoying that this slipped past, but luckily it is a new feature. I guess the bug is specific to 1-D masks or so? I am tagging this for 1.20.2, but it's not a regression, so if that doesn't happen it's probably OK. It's probably also a fairly straightforward issue for new people to look into.

@mtsokol
Member Author

mtsokol commented Mar 5, 2021

@seberg Thank you for the confirmation! If you consider it a good first issue, I can try working on it.

@seberg
Member

seberg commented Mar 5, 2021

Sure, if you are up for it, that is great! I expect the fix is fairly simple, and definitely local to that chunk of code. I'm not quite sure what's wrong. Maybe the `if where is True` check is not detecting things correctly for some reason?
The only note, maybe, is that the `if where is True` part is there to avoid calling into the ufunc machinery, as a micro-optimization for when `where` is not used.
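
To make the fast path described above concrete, here is a rough sketch of how such a shortcut can look. `count_reduce_items` is a hypothetical helper written for this illustration, not NumPy's actual code: when `where is True` it uses plain Python arithmetic on the shape and returns a scalar, otherwise it counts the selected elements and returns an array.

```python
import numpy as np

def count_reduce_items(arr, axis, where=True):
    """Hypothetical helper: number of elements feeding each output value."""
    if where is True:
        # Fast path: no ufunc call, just multiply the reduced axis lengths.
        if axis is None:
            axes = range(arr.ndim)
        elif isinstance(axis, int):
            axes = (axis,)
        else:
            axes = tuple(axis)
        count = 1
        for ax in axes:
            count *= arr.shape[ax]
        return count
    # Masked path: broadcast the mask and count True entries per output.
    mask = np.broadcast_to(np.asarray(where, dtype=bool), arr.shape)
    return mask.sum(axis=axis)

a = np.zeros((2, 3, 4))
fast = count_reduce_items(a, axis=2)
masked = count_reduce_items(a, axis=2, where=[False, True, False, True])
```

Note that `where is True` is an identity check, so a list or array mask (even an all-True one) always takes the masked path and yields an array-valued count.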

@eric-wieser
Member

eric-wieser commented Mar 5, 2021

That code is nonsense. rcount == 0 if where is True else umr_any(rcount == 0) means:

rcount == (0 if where is True else umr_any(rcount == 0))

which is not the intended

(rcount == 0) if where is True else umr_any(rcount == 0)

Never mind, Python already parses it the intended way.
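
For anyone wanting to check the parse for themselves: a conditional expression binds more loosely than `==`, so the comparison ends up as the "true" branch of the conditional. A small sketch using the standard `ast` module; the names in the expression string are never evaluated, only parsed.

```python
import ast

expr = "rcount == 0 if where is True else umr_any(rcount == 0)"
tree = ast.parse(expr, mode="eval")

top = tree.body              # outermost node of the expression
true_branch = top.body       # value used when the condition holds
condition = top.test         # the `where is True` check
false_branch = top.orelse    # value used otherwise

print(type(top).__name__)          # IfExp
print(type(true_branch).__name__)  # Compare
```

The outermost node is an `IfExp`, not a `Compare`, which is why the second reading, `(rcount == 0) if where is True else umr_any(rcount == 0)`, is the one Python uses.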

@eric-wieser
Member

#17896 is the offending PR.

@eric-wieser
Member

I expect var to be broken in the same way.

@seberg
Member

seberg commented Mar 5, 2021

Hmm, maybe it's not quite local to here. When `where` is True, rcount should be a scalar? var should have the same issue and get the same fix. I am still confused about where the array comes from, but I still think that it should be an easy fix once that is clear...
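
One plausible source of the array, offered as a guess rather than a confirmed diagnosis: assuming `umr_any` is an alias for `np.logical_or.reduce`, a ufunc `reduce` call defaults to `axis=0`, so a 2-D rcount is only reduced along its first axis and the `if` still receives an array. The `rcount` below is made up to mimic the shape from the report.

```python
import numpy as np

# Stand-in for the per-output counts of a (2, 3, 4) array reduced over axis=2.
rcount = np.full((2, 3), 2)
is_empty = rcount == 0

# Default axis=0: only the first axis is reduced, leaving shape (3,).
partial = np.logical_or.reduce(is_empty)

# axis=None reduces over every axis and gives a single boolean.
full = np.logical_or.reduce(is_empty, axis=None)

# Using the partial result in a truth test raises the reported ValueError.
try:
    bool(partial)
    ambiguous = False
except ValueError:
    ambiguous = True
```

If that is what happens, a 1-D rcount would not trigger the error, since reducing over axis 0 already yields a scalar in that case.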

@n2cholas

n2cholas commented Mar 8, 2021

I believe #18491 refers to the same issue (so that can be closed when this one is). That issue confirms var and std are affected.
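
For var and std, the same kind of workaround as for the mean can be sketched with np.sum, whose `where` argument works in the report above. This is an assumption-laden illustration with made-up names, using the population variance (no degrees-of-freedom correction), and it is not how NumPy implements these functions.

```python
import numpy as np

rng = np.random.default_rng(0)          # seeded for reproducibility (assumption)
a = rng.standard_normal((2, 3, 4))
mask = np.array([False, True, False, True])

# Number of selected elements per output position.
count = np.sum(np.broadcast_to(mask, a.shape), axis=2)

# Masked mean, then squared deviations of the selected elements only.
masked_mean = np.sum(a, axis=2, where=mask) / count
deviations = a - masked_mean[..., np.newaxis]
masked_var = np.sum(deviations ** 2, axis=2, where=mask) / count
masked_std = np.sqrt(masked_var)
```

Both results should agree with `a[..., mask].var(axis=2)` and `a[..., mask].std(axis=2)` when the mask only varies along the reduced axis.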

@seberg seberg closed this as completed in c5de5b5 Mar 11, 2021
charris pushed a commit to charris/numpy that referenced this issue Mar 14, 2021
…mpygh-18560)

* Fixed where keyword bug

* Added test case

* Reverted to original notation

* Added tests for var and std

Closes numpygh-18552