
[Bug]: plt.hist takes significantly more time with torch and jax arrays #25882

Closed
patel-zeel opened this issue May 13, 2023 · 13 comments · Fixed by #25887
Comments

@patel-zeel
Contributor

patel-zeel commented May 13, 2023

Bug summary

Hi,

The time taken by plt.hist called directly on jax or torch arrays is significantly greater than the combined time of first converting them to numpy and then calling plt.hist. Shouldn't matplotlib internally convert them to numpy arrays before plotting?

To reproduce the bug, directly run the following snippet on Google Colab.

Code for reproduction

from time import time
import numpy as np

import torch

import jax
import jax.random as jr
import jax.numpy as jnp

import matplotlib.pyplot as plt

jax_array = jr.normal(jr.PRNGKey(0), (1000, 150))
torch_array = torch.randn(1000, 150)

def plot_hist(array):
    init = time()
    plt.figure()
    plt.hist(array)
    print(f"Time to plot: {time() - init:.2f} s")
    plt.show()
    
plot_hist(jax_array.ravel())
plot_hist(torch_array.ravel())
plot_hist(np.array(jax_array.ravel()))
plot_hist(np.array(torch_array.ravel()))

Actual outcome

Time to plot: 4.19 s (jax array)

Time to plot: 2.61 s (torch array)

Time to plot: 0.03 s (jax array converted to numpy)

Time to plot: 0.04 s (torch array converted to numpy)

[histogram screenshots omitted]

Expected outcome

Time to plot: 0.03 s

Time to plot: 0.04 s

Time to plot: 0.03 s

Time to plot: 0.04 s

Additional information

What are the conditions under which this bug happens? input parameters, edge cases, etc?

It is happening with all kinds of shapes.

Has this worked in earlier versions?

Tested with default colab matplotlib version 3.7.1 and also with 3.6.3.

Do you know why this bug is happening?

Not exactly sure.

Do you maybe even know a fix?

Maybe convert any python object to a numpy array before plotting?

Operating system

Ubuntu 20.04.5 LTS

Matplotlib Version

3.7.1

Matplotlib Backend

module://matplotlib_inline.backend_inline

Python version

3.10.11

Jupyter version

6.4.8

Installation

None

@oscargus
Contributor

The unpacking happens here:

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    return x

The pytorch tensor does not support any of the conversion methods, so Matplotlib doesn't really know what to do with it. There is a discussion in #22645 about this, but if I remember correctly we expect the libraries to support the to_numpy method (but still support the values attribute).

(I could not install jax, but I suppose something similar goes on there.)

@oscargus
Contributor

And when the conversion doesn't work, it ends up in this loop:

result = []
is_1d = True
for xi in X:
    # check if this is iterable, except for strings which we
    # treat as singletons.
    if not isinstance(xi, str):
        try:
            iter(xi)
        except TypeError:
            pass
        else:
            is_1d = False
    xi = np.asanyarray(xi)
    nd = np.ndim(xi)
    if nd > 1:
        raise ValueError(f'{name} must have 2 or fewer dimensions')
    result.append(xi.reshape(-1))

which is where most of the time is spent.

@patel-zeel
Contributor Author

Thanks for the quick response, @oscargus! Given that both these libraries support the .__array__() method for conversion to a numpy array, wouldn't it be easier to add one more if condition in _unpack_to_numpy to cover them?

type(jax_array.__array__()), type(torch_array.__array__())
# Output: (numpy.ndarray, numpy.ndarray)

@oscargus
Contributor

Yes, I noticed that too. It probably makes sense.

(I think the reason we do this somewhat carefully is so that unit information does not get lost.)

Would you be interested in submitting a patch? I think that if this goes last in the conversion chain, it shouldn't break too many things... (A problem here is that we do not yet test all the types that could possibly be used and "work". There has been discussion of a special test suite for that, but it has not been implemented yet.)

@patel-zeel
Contributor Author

patel-zeel commented May 13, 2023

Even tensorflow supports the __array__() method. I guess these three libraries account for almost 99% of the machine-learning code available online :) It'd be great if this conversion landed without breaking too many things!

Sure, I'll submit a patch. I guess I only need to change _unpack_to_numpy to the following, right?

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    if hasattr(x, '__array__'):
        # Assume that any __array__() method returns a numpy array
        # (e.g. TensorFlow, JAX or PyTorch arrays)
        return x.__array__()
    return x

@timhoffm
Member

Yes, but please verify that __array__ actually returns a numpy array, like we do with values above.

@patel-zeel
Contributor Author

patel-zeel commented May 13, 2023

Thank you for the important suggestion, @timhoffm. The __array__ check works in theory for the cases I imagined, but np.float32-type objects get caught by it: calling __array__ on an np.float32 converts it to an ndarray, which eventually leads to infinite recursion.

A temporary fix I could figure out is to add two more if conditions checking whether the object is of type np.floating (covers all float types) or np.integer (covers all integer types, including uint). I could also include a boolean check. Would that cover everything, or does it already look unpythonic?
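The scalar pitfall can be reproduced with numpy alone (a minimal sketch, no torch/jax needed):

```python
import numpy as np

x = np.float32(1.5)

# numpy scalars also expose __array__, so a bare hasattr(x, '__array__')
# check would turn them into 0-d ndarrays instead of returning them as-is
arr = x.__array__()
print(type(arr))                  # <class 'numpy.ndarray'>

# np.generic is the base class of all numpy scalars (floats, ints, bools, ...)
print(isinstance(x, np.generic))  # True
```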

More directions to solve this issue could be the following:

  1. Raise an issue to add to_numpy() methods in the JAX and PyTorch repos.
  2. Raise an issue for NumPy to provide a universal numpy-object checker type, so that we can replace the ndarray check with it; any numpy object would then be captured by the first check.
  3. Add hard-coded checks for JAX and PyTorch like the following:
if str(type(x)) == "<class 'torch.Tensor'>":
    return x.__array__()
if str(type(x)) == "<class 'jaxlib.xla_extension.ArrayImpl'>":
    return x.__array__()

I am open to your suggestions.

Edit1: np.generic works for most (all?) scalars, so we can add if isinstance(x, np.generic) as the second check just after ndarray check like the following:

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy array, return directly
        return x
    if isinstance(x, np.generic):
        # If numpy scalar, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    if hasattr(x, '__array__'):
        # Assume that any __array__() method returns a numpy array
        # (e.g. TensorFlow, JAX or PyTorch arrays)
        x = x.__array__()
        # Anything that doesn't return an ndarray via __array__() is filtered
        # out by the following check
        if isinstance(x, np.ndarray):
            return x
    return x

@tacaswell
Member

I am not convinced this is a high-priority problem as there is a very simple user-side solution (cast to numpy before passing to Matplotlib!).

Given that, to my understanding, some of these arrays may actually be in GPU memory, forcing them back to CPU memory before passing to Matplotlib is probably a good idea anyway.
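The user-side workaround amounts to one explicit conversion before handing data to Matplotlib. A sketch with a hypothetical stand-in class (real torch/jax arrays expose the same __array__ hook, possibly copying off the accelerator):

```python
import numpy as np

class FakeDeviceArray:
    """Hypothetical stand-in for a torch/jax array (not a real library type)."""
    def __init__(self, data):
        self._data = list(data)
    def __array__(self, dtype=None):
        # a real library would move data from device to host memory here
        return np.asarray(self._data, dtype=dtype)

x = FakeDeviceArray([0.1, 0.2, 0.3])

# user-side fix: convert once up front, then plot the ndarray
x_np = np.asarray(x)
assert isinstance(x_np, np.ndarray)
# plt.hist(x_np)  # now hits the fast ndarray path
```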

@jklymak
Member

jklymak commented May 18, 2023

I agree this isn't high priority, but if an object has a numpy array representation, why would we ignore it?

@xstreck1

I've just run into this issue and Googled my way here. While the workaround is simple, it took me a while to find the offending function, so I'd support fixing it.

To address the issue of the tensor being on the GPU, a simple check of which device the tensor is on could be added. My run crashed with 10K elements on the CPU, which I expected to take less than a second to plot.

@jklymak
Member

jklymak commented Oct 22, 2023

> To address the issue of the tensor being in GPU, a simple check on what device the tensor is on could be added. My execution crashed with 10K elements on CPU, which I expected to take less than a second to plot.

Just to set expectations: that is a request for torch, or something the user needs to take care of, not Matplotlib. We don't officially support torch or jax arrays, and we aren't going to add extra checks for them beyond what __array__ might support.

@timhoffm
Member

Sorry this slipped.

The code suggested in Edit1 of #25882 (comment) looks reasonable to me.

Tests should include the numpy scalar case and a mock object with an __array__ method (we don’t want to depend on torch or JAX).

@patel-zeel do you want to pick this up again?

@patel-zeel
Contributor Author

patel-zeel commented Dec 30, 2023

Sure, @timhoffm! I will try this fix in my existing PR #25887.
