
Update _unpack_to_numpy function to convert JAX and PyTorch arrays to NumPy #25887

Merged

13 commits merged into matplotlib:main on Mar 13, 2024

Conversation

patel-zeel
Contributor

@patel-zeel patel-zeel commented May 14, 2023

PR summary [Testing in progress]

This PR closes #25882 by modifying the _unpack_to_numpy function. The main changes are the following; a rough sketch follows the list.

  • Added an if condition that checks whether an object has an __array__ method and whether the object returned by calling __array__ is a NumPy array.
  • Added an if condition to capture NumPy scalars, which the earlier ndarray check did not catch. This was needed because NumPy scalars would otherwise loop endlessly through the __array__ check, since calling __array__ on them converts them to an ndarray.
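As a rough illustration (assumed from the summary above, not the exact merged diff), the shape of the change might look like this:

import numpy as np

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # Plain ndarrays pass through untouched.
        return x
    if isinstance(x, np.generic):
        # NumPy scalars are not ndarrays but do have __array__; capture
        # them here so they don't fall into the __array__ branch below.
        return x
    if hasattr(x, '__array__'):
        xtmp = x.__array__()
        # Only trust __array__ if it actually produced an ndarray.
        if isinstance(xtmp, np.ndarray):
            return xtmp
    return x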

PR checklist

@github-actions github-actions bot left a comment

Thank you for opening your first PR into Matplotlib!

If you have not heard from us in a while, please feel free to ping @matplotlib/developers or anyone who has commented on the PR. Most of our reviewers are volunteers and sometimes things fall through the cracks.

You can also join us on gitter for real-time discussion.

For details on testing, writing docs, and our review process, please see the developer guide

We strive to be a welcoming and open project. Please follow our Code of Conduct.

@jklymak
Member

jklymak commented May 14, 2023

Please remove the unrelated changes.

@tacaswell tacaswell added this to the v3.8.0 milestone May 14, 2023
@oscargus
Contributor

oscargus commented May 15, 2023

Hmm, this is probably the clearest indication of the test failures:

>       assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours"
E       AssertionError: assert None == 'hours'

Because the Quantity test class has an __array__ method

def __array__(self):
    return np.asarray(self.magnitude)

the units are dropped and the test fails.

I have no idea about the unit support though, so it's not really clear to me how to get around it...
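For concreteness, a toy stand-in (assumed here for illustration, not the actual test class) shows how the unit information silently disappears once __array__ is used:

import numpy as np

class Quantity:
    def __init__(self, magnitude, units):
        self.magnitude = magnitude
        self.units = units
    def __array__(self):
        return np.asarray(self.magnitude)

q = Quantity([1, 2, 3], "hours")
arr = np.asarray(q)  # array([1, 2, 3]) -- the "hours" unit is gone, so
                     # the axis never learns its units and get_units()
                     # returns None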

@mwaskom

mwaskom commented May 16, 2023

FYI I don't think you're supposed to call __array__ directly:

Users should not call this directly. Rather, it is invoked by numpy.array() and numpy.asarray().

But not sure how much it matters in practice, or whether another library like matplotlib is a "user" in this context.

@tacaswell
Member

This runs into the same reason we do not blindly call np.asanyarray(...) on all of our input: it can strip off unit information (which would break some of our users; changing that is off the table for now).

Between this and the note about not calling __array__ directly, this particular solution is probably not the right path.

@patel-zeel
Contributor Author

patel-zeel commented May 17, 2023

Thank you for the diverse feedback, @oscargus, @mwaskom, and @tacaswell. What do you think would be a better way forward?

  1. Implement to_numpy() method in JAX and PyTorch.
  2. JAX and PyTorch specific if check in _unpack_to_numpy function:
if str(type(x)) == "<class 'torch.Tensor'>":
    return x.numpy()
if str(type(x)) == "<class 'jaxlib.xla_extension.ArrayImpl'>":
    return np.asarray(x)

Or something completely different from these directions?

@oscargus
Contributor

Do not trust this fully, but I think that checking whether there is a numpy method and then calling that (and checking the output) would probably make sense and would solve torch.Tensor. But I also think that, from Matplotlib's perspective (and probably that of many other libraries), having a well-defined interface like to_numpy on everything that can be converted to NumPy would make the most sense.

@tacaswell
Member

I am very 👎 on adding string checking of types.

Unfortunately we are in an awkward bind: we are very permissive in what we take as input, we do not want to depend on any optional imports, and due to the diversity of input we cannot treat it all the same.

@timhoffm
Member

timhoffm commented May 18, 2023

I have not checked, but maybe there is something in the Python array API standard - at least it would belong there.

@timhoffm
Member

timhoffm commented May 18, 2023

Would __array_namespace__ be a solution for us?

@jakevdp
Contributor

jakevdp commented May 18, 2023

There's a third option not mentioned here: use __array__ as the standard convert-to-array function, and give libraries for which this is not the appropriate behavior some way to opt out. Adding a new standard (to_numpy) with poorly defined semantics is not a great way forward, in my opinion.

@mwaskom

mwaskom commented May 18, 2023

My sense is that the direction data libraries would like to move in is exchange via the __array__ method, so despite throwing some cold water on that above, it's probably a better path than trying to get torch to add to_numpy in addition to .numpy() and __array__.

@jklymak
Member

jklymak commented May 18, 2023

We discussed the history of this a bit on the dev call today. I think the below is close to correct, but I could be misunderstanding:

We officially support numpy arrays as inputs to our data plotting functions.

We also officially support mechanisms for objects to get passed that contain "unit" information (eg pint). Somewhat confusingly, this unit information is sometimes at the container level (eg pint), and sometimes at the element level, or in the dtype of the elements (eg ndarrays of datetime64, or lists of strings).

We unofficially support xarray and pandas objects, assuming they have no units, by calling their values or to_numpy methods.

At the level that _unpack_to_numpy is called, we cannot strip units from objects with units, because they have not been checked for yet. In the case discussed here, it is indeed the unit checking that is slowing things down.

After we have checked for units, we usually call np.asarray. But we can't call that right away because of our unit support.

I'm not sure what the path out of the conundrum is - I somewhat feel the unit conversion interface should have been less magical, and more explicit, so users would have to specify a converter on an axis manually, rather than us guessing the converter.

@rgommers

Would __array_namespace__ be a solution for us?

That isn't quite the right thing; the array API standard is meant to use "native functions", so this method is what you'd use if you want to retrieve the torch namespace and use torch.asarray & co. Here you specifically want numpy arrays instead.

I agree with @jakevdp and @mwaskom that use of __array__ is more idiomatic. The most standard thing is np.asarray (which relies on __array__, or the Python buffer protocol, or DLPack), but if that's too permissive then using __array__ directly is fine.

After we have checked for units, we usually call np.asarray. But we can't call that right away because of our unit support.

If units libraries silently lose data when np.asarray is called on their container objects, they really should implement __array__ and make it raise an exception. This is also what, for example, sparse arrays do.
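A minimal sketch of what that opt-out could look like for a toy unit container (illustrative, not pint's actual implementation):

import numpy as np

class GuardedQuantity:
    def __init__(self, magnitude, units):
        self.magnitude = np.asarray(magnitude)
        self.units = units

    def __array__(self, dtype=None):
        # Refuse silent unit-stripping instead of returning bare
        # magnitudes, mirroring what sparse arrays do; np.asarray()
        # will propagate this error to the caller.
        raise TypeError(
            "Implicit conversion to a NumPy array drops units; "
            "access .magnitude explicitly.")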

@ksunden ksunden modified the milestones: v3.8.0, v3.9.0 Aug 8, 2023
@timhoffm
Member

From an interface perspective, it's reasonable to rely on __array__. I think we should investigate how we can make this work internally.

@patel-zeel
Contributor Author

@timhoffm It has been a while since I worked on this PR. Could you say whether your latest suggestion in #25882 will resolve the issue we are discussing in this PR? If not, could you suggest potential workarounds?

@timhoffm
Member

timhoffm commented Jan 4, 2024

@patel-zeel in my comment #25882 (comment), I hadn't considered the unit problem. That indeed makes the problem much more complicated.

To all: To summarize and comment on the above proposed solutions:

  1. Implement to_numpy() method in JAX and PyTorch
    IMO this won't happen. With some justification, they say that __array__ is nowadays their idiomatic hook for turning them into numpy arrays.
  2. Import and type-check by type
    We don't want to try and import complex libraries just because someone might have passed an element of that type. This would be a performance hit for all users who happen to have that library installed but don't use it.
  3. Type-check by string
    This is inelegant and brittle.
  4. Rework our unit handling system
    Any change to unit handling that could help here would definitely be a major project in matplotlib, and would likely also require changes for some downstream users.

There is no easy solution here. Special situations sometimes require special measures:

Given all the boundary conditions, I'd be +0.5 on type-checking by string, despite @tacaswell being strongly 👎 on it. Usually, I'd agree, but it's the only realistic way forward: 1. won't happen; 2. introduces strong coupling, which IMHO is worse; 4. won't realistically happen, because we don't have the capacity for it.

So what would we buy into with type-checking by string? Drawbacks are (1) the str comparison is slower than a type check - but that should be negligible; and (2) it's brittle, because the str representation could change without us noticing, and then the functionality would be broken. To alleviate (2), we could use f"{type(x).__module__}.{type(x).__qualname__}". That leaves out unnecessary fluff and would only change if the libraries reorganize. Additionally, in the worst-case scenario that the string changes, we would fall back to the current solution.
In short: we can easily make using JAX/Torch arrays faster with the string-type check. With not-too-high likelihood, that can break in the future, which would bring us back to where we are now. Sounds like a reasonable deal to me.

The only other alternative would be to tell users to convert their JAX/Torch arrays explicitly (or live with the performance impact). But that would not be user-friendly.
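A rough sketch of the qualified-name comparison described above (helper names are illustrative, not the merged code):

def _full_type_name(x):
    cls = type(x)
    # "torch.Tensor" instead of "<class 'torch.Tensor'>": no fluff, and
    # it only changes if the library reorganizes its modules.
    return f"{cls.__module__}.{cls.__qualname__}"

def _is_torch_array(x):
    return _full_type_name(x) == "torch.Tensor"

def _is_jax_array(x):
    return _full_type_name(x) == "jaxlib.xla_extension.ArrayImpl"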

@tacaswell
Member

  4. won't realistically happen, because we don't have the capacity for it.

This is what Kyle is working on, but it is 1-2 years off, and I don't think we should wait for it.


I am convinced by @timhoffm 's analysis and am also +0.5 on string typing now.

@jakevdp
Contributor

jakevdp commented Jan 5, 2024

Couldn’t you accomplish option 2 without the performance impact by looking to see if certain modules are already in sys.modules?
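For example, a minimal version of that idea (a sketch, not the code eventually merged):

import sys

def _is_torch_array(x):
    # If torch was never imported, nothing in this process can be a
    # torch.Tensor, so consulting sys.modules avoids importing torch
    # ourselves and costs almost nothing.
    torch = sys.modules.get('torch')
    return torch is not None and isinstance(x, torch.Tensor)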

@patel-zeel
Contributor Author

patel-zeel commented Jan 21, 2024

@timhoffm Thanks for the review and suggestion for testing. I have applied the suggested changes and implemented the first version of testing for this feature.

Contributor

@jakevdp jakevdp left a comment

Looks good – one comment is that perhaps we should abstract this a bit: maybe have a configurable list of external objects to look for (e.g. external_objects = ['torch.Tensor', 'jax.Array']) and write just a single function that loops through and checks these.

Then in the future, if someone wanted to add cupy.ndarray or something it would be just a tiny change, and it could even be done at runtime if we wanted to provide that API.

@patel-zeel
Contributor Author

patel-zeel commented Jan 25, 2024

@jakevdp How would you suggest abstracting this? Some relevant points:

  • I tried creating a cupy array and getting the NumPy array via the __array__() method, but it fails and suggests using the .get() method instead.
import cupy
array = cupy.array([1, 2, 3.0])
np_array = array.__array__()  # fails
# np_array = array.get()  # this works

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 1
----> 1 array.__array__()

File cupy/_core/core.pyx:1475, in cupy._core.core._ndarray_base.__array__()

TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.
  • I came across this issue on PyTorch and realized that the .numpy(force=True) method can help in cases where arrays need .detach() and/or .cpu() before .__array__() can successfully return the underlying NumPy array. As far as matplotlib is concerned, .numpy(force=True) may be a better alternative to .__array__() for PyTorch (or even .detach().cpu().numpy(), which also supports older versions).

Considering both of the above cases, the discussion in #25882, and the discussion in this PR, would it be better to provide two methods, is_type and to_numpy, for each object, like the following?

import sys
import numpy as np
from abc import ABC, abstractmethod

class TypeArray(ABC):
    @staticmethod
    @abstractmethod
    def is_type(x):
        pass

    @staticmethod
    @abstractmethod
    def to_numpy(x):
        pass

class TorchArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a PyTorch Tensor."""
        try:
            # we're intentionally not attempting to import torch. If somebody
            # has created a torch array, torch should already be in sys.modules
            return isinstance(x, sys.modules['torch'].Tensor)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            # we're attempting to access attributes on imported modules which
            # may have arbitrary user code, so we deliberately catch all exceptions
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array"""
        # preferred over `.numpy(force=True)` to support older PyTorch versions
        return x.detach().cpu().numpy()

class JaxArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a JAX array."""
        try:
            # we're intentionally not attempting to import jax. If somebody
            # has created a jax array, jax should already be in sys.modules
            return isinstance(x, sys.modules['jax'].Array)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            # we're attempting to access attributes on imported modules which
            # may have arbitrary user code, so we deliberately catch all exceptions
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array"""
        return x.__array__()  # works even if `x` is on GPU

class CupyArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a CuPy array."""
        try:
            # we're intentionally not attempting to import cupy. If somebody
            # has created a cupy array, cupy should already be in sys.modules
            return isinstance(x, sys.modules['cupy'].ndarray)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            # we're attempting to access attributes on imported modules which
            # may have arbitrary user code, so we deliberately catch all exceptions
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array"""
        return x.get()

external_objects = [TorchArray, JaxArray, CupyArray]

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp

    for obj in external_objects:
        assert issubclass(obj, TypeArray)  # entries are classes, not instances
        if obj.is_type(x):
            xtmp = obj.to_numpy(x)

            # In case to_numpy() doesn't return a numpy array in the future
            if isinstance(xtmp, np.ndarray):
                return xtmp
    return x

@timhoffm
Member

IMHO further abstraction would be premature. The current implementation is simple and good enough. Paraphrased from https://youtu.be/UANN2Eu6ZnM?feature=shared

If something happens for the first time, do a concrete implementation. If it happens for the second time, copy and adapt the code. If it happens for the third time, factor out commonalities.

This has two major advantages: 1. You don't create abstractions that you don't use. 2. When you build the abstraction, you have three concrete use cases, so it's more likely the abstraction is suitable.

@jakevdp
Contributor

jakevdp commented Jan 25, 2024

I didn't mean to suggest any complicated abstraction; I was thinking something simple like this:

import sys
import numpy as np

ARRAYLIKE_OBJECTS = [('jax', 'Array'), ('torch', 'Tensor')]

def maybe_convert_to_array(x):
    for mod, name in ARRAYLIKE_OBJECTS:
        try:
            # only matches if the user has already imported the module
            is_array = isinstance(x, getattr(sys.modules[mod], name))
        except Exception:
            pass
        else:
            if is_array:
                return np.asarray(x)
    return x

It reduces duplication of logic and makes it easier to add additional types if/when needed.
If you wanted to add cupy support, it would just require doing ARRAYLIKE_OBJECTS.append(('cupy', 'ndarray'))

@@ -2358,6 +2382,12 @@ def _unpack_to_numpy(x):
         # so in this case we do not want to return a function
         if isinstance(xtmp, np.ndarray):
             return xtmp
+    if _is_torch_array(x) or _is_jax_array(x):
+        xtmp = x.__array__()
Contributor

Is there any reason not to do xtmp = np.asarray(x) here?

Member

I think these are equivalent for the handled cases. While pandas claims you shouldn't call __array__ directly, I haven't found any official recommendation for it in numpy (which defines the __array__ API).

For me, either works. __array__ is more explicit, which can be a good thing, but might be too low-level. OTOH I don't think np.asarray() will change its implementation, so the extra safety the asarray abstraction would buy is mostly theoretical.

Contributor Author

@jakevdp @timhoffm What's your opinion about using x.numpy(force=True) or x.detach().cpu().numpy() for PyTorch? As a user, I'd find this change useful for day-to-day coding since it saves me from manually writing it for every array I want to plot with matplotlib.

if _is_torch_array(x):
    xtmp = x.numpy(force=True)  # or x.detach().cpu().numpy()
if _is_jax_array(x):
    xtmp = x.__array__()

Member

Seems reasonable to use the x.numpy method provided by torch. Though I have to say, I don't know whether __array__ would do something different.

Contributor Author

@timhoffm To add more context: in JAX, the .__array__() method converts a JAX array to a NumPy array irrespective of whether the JAX array is on CPU, GPU, TPU, or another hardware accelerator. OTOH, PyTorch doesn't do this by default, due to uncertainty about the performance impact (the full discussion is in this issue). So x.numpy(force=True) forcefully converts it to NumPy irrespective of the device of the array (and handles a few other cases as well, e.g. if the array has a computation graph for backpropagation, then .detach() needs to be called first). I am not sure whether this should be handled in a separate issue or whether we can use x.numpy(force=True) in this PR itself, but I'm sure PyTorch users would love this change, to avoid writing x.detach().cpu().numpy() every time they plot a PyTorch array.

cc @jakevdp.
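Illustratively, assuming a CUDA tensor that participates in autograd (and a recent enough PyTorch for force=True):

import torch
import matplotlib.pyplot as plt

t = torch.randn(100, device="cuda", requires_grad=True)

# what users have to write manually today: leave the autograd graph,
# move to host memory, then convert to NumPy
plt.plot(t.detach().cpu().numpy())

# what force=True folds into a single call
plt.plot(t.numpy(force=True))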

Member

My hesitation about this is that memory overflows will get shunted to Matplotlib, and then we will get the bug reports, whereas if people did x.*.numpy() in their own code, they would see what the problem is. JAX arrays can be far larger than host memory allows, and Matplotlib blindly unpacking them for naive users seems like a bad idea.

@rgommers rgommers Jan 29, 2024

Though I have to say, I don't know whether __array__ would do something different.

__array__ is the last thing tried by NumPy when a user calls np.asarray, after the buffer protocol and __array_interface__ (as documented, for example, here). So they're not equivalent, and calling np.asarray is the idiomatic choice.
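A minimal illustration of that fallback order (assumed example): an object with no buffer protocol and no __array_interface__ is still consumed by np.asarray via __array__:

import numpy as np

class Duck:
    def __array__(self, dtype=None):
        # np.asarray reaches this only after trying the buffer
        # protocol and __array_interface__
        return np.array([1.0, 2.0, 3.0])

np.asarray(Duck())  # array([1., 2., 3.])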

@timhoffm
Member

I didn't mean to suggest any complicated abstraction; I was thinking something simple like this:

[...]

It reduces duplication of logic and makes it easier to add additional types if/when needed. If you wanted to add cupy support, it would just require doing ARRAYLIKE_OBJECTS.append(('cupy', 'ndarray'))

Yes, that would be marginally better, and can optionally be done. In the interest of not endlessly bikeshedding the PR, I have accepted the current version. After all, this is all internal and can be refactored any time.

@jklymak
Member

jklymak commented Jan 26, 2024

OK, apologies for not paying attention to this properly, but hard-coding workarounds for certain libraries seems incorrect and brittle to me. What criteria will we have if we get requests to support other libraries?

I think the fundamental problem is with where cbook._reshape_2D gets called. This method is only used for hist and a couple of the helpers for violin_stats and boxplot_stats.

I think it would be a mistake to change _unpack_to_numpy, which is used for all unit conversion, and hence almost every other plot method, to work around this problem. If it were me, I'd split the preprocessing in hist and friends to keep units where they are required (eg using Quantity), and strip them when we are ready to do so.

The methods where this gets used are all binning methods. The bins need to keep units (so they can be added to the axes properly). However, the data needs to be turned into an array to pass to histogram (or other stats functions). I think the proper solution here is to properly differentiate these roles in hist (and friends).

@timhoffm
Member

hard-coding workarounds for certain libraries seems incorrect and brittle to me.

This is indeed a workaround. The proper way would be for _unpack_to_numpy() to use the __array__ interface for all types if available (maybe through np.asarray()). However, in the current internal usages this may lead to loss of units. If you have an alternative proposal for how to make JAX and PyTorch arrays work, I'm more than happy to take it. To be honest, I don't fully grasp the unit handling and its implications.

Otherwise, I think this PR is good enough to be included in 3.9. It achieves the desired speedup and otherwise is completely internal, so we can still change the implementation whenever we like.

What criteria will we have if we get requests to support other libraries?

Case-by-case. Support them if it's easily possible; don't if it's not. There's little maintenance burden and no API liability. Also, I don't expect there would be more than a handful of such libraries.

@patel-zeel
Contributor Author

@timhoffm @jakevdp Getting back to this after a while. To summarize the pending changes:

  • I guess .numpy(force=True) for PyTorch can cause silent memory overflows, so we should avoid it and let users convert manually if needed.
  • The simple abstraction suggested by @jakevdp seems great, but np.asarray(x) doesn't work for cupy, so I think it'd be hard to generalize just yet.

I think the accepted changes are optimal given the current circumstances.

@tacaswell
Member

I'm going to merge this to move forward. On one hand, I think it is reasonable to expect users to get their data back to the CPU and into numpy before we plot it; on the other, we have gotten enough bug reports, and this is a light enough touch.

If using __array__() directly is something we really should not do, we can do a follow-up tweaking the exact behavior for jax / torch.

This also sets a reasonable pattern for how we would add support for cupy / the next big library.

@tacaswell tacaswell merged commit 3323ae8 into matplotlib:main Mar 13, 2024
39 of 42 checks passed
@tacaswell
Member

Thank you for your work on this @patel-zeel and congratulations on your first merged Matplotlib PR 🎉

I hope we hear from you again.

@patel-zeel
Contributor Author

patel-zeel commented Mar 14, 2024

Thank you, @tacaswell, and all the contributors to this PR. This PR has taught me a lot of dev tricks. This wouldn't have been possible without everybody's diverse inputs, @jakevdp's robust ideas, and @timhoffm's pivotal role.

Successfully merging this pull request may close these issues.

[Bug]: plt.hist takes significantly more time with torch and jax arrays
9 participants