[Bug]: plt.hist takes significantly more time with torch and jax arrays #25882
The unpacking happens here: `matplotlib/lib/matplotlib/cbook.py`, lines 2237 to 2251 (at b61bb0b).

The pytorch tensor does not support any of the conversion methods, so Matplotlib doesn't really know what to do with it. There is a discussion in #22645 about this, but if I remember correctly we expect the libraries to support the conversion. (I could not install jax, but I suppose something similar goes on there.)
And when the conversion doesn't work, it ends up in this loop: `matplotlib/lib/matplotlib/cbook.py`, lines 1332 to 1348 (at b61bb0b), which is where most of the time is spent.
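Conceptually, that fallback degrades to element-by-element handling, while a recognized array type is converted in a single bulk call. A rough self-contained sketch of the cost difference (`Opaque` is a hypothetical stand-in for an unrecognized tensor type, not Matplotlib's actual code path):

```python
import time
import numpy as np

class Opaque:
    """Hypothetical stand-in for a tensor type Matplotlib cannot recognize:
    it is only iterable -- no to_numpy(), no .values, no __array__()."""
    def __init__(self, data):
        self._data = list(data)
    def __iter__(self):
        return iter(self._data)
    def __len__(self):
        return len(self._data)

raw = list(range(100_000))
wrapped = Opaque(raw)

# Slow path: per-element handling, roughly what the fallback loop costs.
t0 = time.perf_counter()
slow = np.array([float(v) for v in wrapped])
t_slow = time.perf_counter() - t0

# Fast path: one bulk conversion, which is what __array__() would enable.
t0 = time.perf_counter()
fast = np.asarray(raw, dtype=float)
t_fast = time.perf_counter() - t0

assert (slow == fast).all()
```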
Thanks for the quick response, @oscargus! Given that both these libraries support `__array__()`, could Matplotlib use it as a fallback?

```python
type(jax_array.__array__()), type(torch_array.__array__())
# Output: (numpy.ndarray, numpy.ndarray)
```
Yes, I also noted that. It probably makes sense. (I think the reason we do this somewhat carefully is so that unit information does not get lost.) Would you be interested in submitting a patch? I think that if this goes last in the conversion chain, it shouldn't break too many things... (A problem here is that we do not yet test all types that can possibly be used. There has been a discussion about a special test suite for that, but it has not yet been implemented.)
Sure, I'd submit a patch. I guess I only need to change the `_unpack_to_numpy` function:

```python
def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    if hasattr(x, '__array__'):
        # Assume that any __array__() method returns a numpy array
        # (e.g. TensorFlow, JAX or PyTorch arrays)
        return x.__array__()
    return x
```
Yes, but please verify that
Thank you for the important suggestion, @timhoffm. A temporary fix I could figure out is to add two more checks. More directions to solve this issue could be the following:

```python
if str(type(x)) == "<class 'torch.Tensor'>":
    return x.__array__()
if str(type(x)) == "<class 'jaxlib.xla_extension.ArrayImpl'>":
    return x.__array__()
```

I am open to your suggestions.

Edit1:

```python
def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy array, return directly
        return x
    if isinstance(x, np.generic):
        # If numpy scalar, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    if hasattr(x, '__array__'):
        # Assume that any __array__() method returns a numpy array
        # (e.g. TensorFlow, JAX or PyTorch arrays)
        x = x.__array__()
        # Anything that doesn't return an ndarray via the __array__() method
        # will be filtered by the following check
        if isinstance(x, np.ndarray):
            return x
    return x
```
I am not convinced this is a high-priority problem as there is a very simple user-side solution (cast to numpy before passing to Matplotlib!). Given that, to my understanding, some of these arrays may actually be in GPU memory, forcing them back to CPU memory before passing to Matplotlib is probably a good idea anyway.
I agree this isn't high priority, but if an object has a numpy array representation, why would we ignore it? |
I've just run into the issue and Googled my way here. While the solution is simple, it took me a while to find the offending function, so I'd support fixing it. To address the issue of the tensor being on the GPU, a simple check on what device the tensor is on could be added. My execution crashed with 10K elements on CPU, which I expected to take less than a second to plot.
Just to set expectations: that is a request to torch, or something the user needs to take care of, not to Matplotlib. We don't officially support torch or jax arrays, and we aren't going to have extra checks to support them beyond what
Sorry this slipped. The code suggested in Edit1 of #25882 (comment) looks reasonable to me. Tests should include the numpy scalar case and a mock object with an `__array__` method. @patel-zeel, do you want to pick this up again?
Bug summary
Hi,
Time taken to plot `plt.hist` directly on `jax` or `torch` arrays is significantly more than the combined time taken to first convert them to `numpy` and then use `plt.hist`. Shouldn't `matplotlib` internally convert them to `numpy` arrays before plotting?

To reproduce the bug, directly run the following snippet on Google Colab.
Code for reproduction
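The original reproduction snippet did not survive the scrape. Based on the reported timings, it was presumably along these lines (a sketch: the array size is an assumption, a headless backend is forced so it runs outside Colab, and torch is only timed when it is installed):

```python
import time
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs outside a notebook
import matplotlib.pyplot as plt

def time_hist(arr):
    """Plot a histogram of arr and return the elapsed wall-clock time."""
    start = time.perf_counter()
    fig, ax = plt.subplots()
    ax.hist(arr)
    plt.close(fig)
    return time.perf_counter() - start

data = np.random.rand(10_000)
t_numpy = time_hist(data)  # fast: numpy arrays are recognized directly

try:
    import torch  # optional: only timed if torch is available
    t_torch = time_hist(torch.from_numpy(data))  # slow path in 3.7.1
except ImportError:
    t_torch = None
```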
Actual outcome
Time to plot: 4.19 s
Time to plot: 2.61 s
Time to plot: 0.03 s
Time to plot: 0.04 s
Expected outcome
Time to plot: 0.03 s
Time to plot: 0.04 s
Time to plot: 0.03 s
Time to plot: 0.04 s
Additional information
It is happening with all kinds of shapes. Tested with the default Colab `matplotlib` version 3.7.1 and also with 3.6.3.

Not exactly sure. Maybe convert any Python object to a `numpy` array before plotting?

Operating system
Ubuntu 20.04.5 LTS
Matplotlib Version
3.7.1
Matplotlib Backend
module://matplotlib_inline.backend_inline
Python version
3.10.11
Jupyter version
6.4.8
Installation
None