
jnp.linalg.svd etc. do not respect __jax_array__ #10065

Open
patrick-kidger opened this issue Mar 29, 2022 · 27 comments
Assignees
Labels
contributions welcome (The JAX team has not prioritized work on this. Community contributions are welcome.) · enhancement (New feature or request)

Comments

@patrick-kidger
Collaborator

There are a couple of things going on here. First of all, the following is an example of jnp.linalg.svd failing to respect __jax_array__.

import jax
import jax.numpy as jnp

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

with jax.disable_jit():
    jnp.linalg.svd(MyArray())

# TypeError: Value '<__main__.MyArray object at 0x7f5ec0f584c0>' with dtype object is not a valid
# JAX array type. Only arrays of numeric types are supported by JAX.

Remove the disable_jit and this works.

The reason it works without disable_jit is that jnp.linalg.svd and friends all have jax.jit wrappers, and it is the jit wrapper that spots the __jax_array__ and handles things appropriately... unless the JAX arraylike is also a PyTree, in which case it doesn't. So this also fails (with a different error message this time):

import jax
import jax.numpy as jnp
from typing import NamedTuple

class MyArray(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.linalg.svd(MyArray())
# ValueError: Argument to singular value decomposition must have ndims >= 2

So whilst it takes either a disable_jit or a PyTree to actually trigger it, I think the fundamental issue here is that jnp.linalg.svd and friends do not check for JAX arraylikes.

@patrick-kidger added the bug (Something isn't working) label Mar 29, 2022
@patrick-kidger changed the title from "jnp.linalg.svd etc. do not respect __jax_array__ (+disable_jit sometimes failing)" to "jnp.linalg.svd etc. does not respect __jax_array__" Mar 29, 2022
@patrick-kidger changed the title from "jnp.linalg.svd etc. does not respect __jax_array__" to "jnp.linalg.svd etc. do not respect __jax_array__" Mar 29, 2022
@jakevdp
Collaborator

jakevdp commented Mar 29, 2022

I think the moral of the story here is we need every API entrypoint in JAX to call jnp.asarray() on array-like inputs.
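
For illustration, a minimal sketch of that pattern (svd_entrypoint is a hypothetical name, not any actual JAX function):

import jax.numpy as jnp

def svd_entrypoint(a):
    # Normalise array-like inputs up front, before any shape/dtype checks.
    a = jnp.asarray(a)
    return jnp.linalg.svd(a)

(As the next comment shows, jnp.asarray itself does not yet respect __jax_array__, so this alone would not be enough.)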

@patrick-kidger
Collaborator Author

patrick-kidger commented Mar 29, 2022

import jax
import jax.numpy as jnp
import typing

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.asarray(MyArray())

# TypeError: Value '<__main__.MyArray object at 0x7f67b007c4c0>' with dtype object is not a valid
# JAX array type. Only arrays of numeric types are supported by JAX.

unfortunately. (With or without the disable_jit.)

Nor does it work in the PyTree case, where it instead fails silently:

import jax
import jax.numpy as jnp
import typing

class MyArray(typing.NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

out = jnp.asarray(MyArray())
print(repr(out))
# DeviceArray([], dtype=float32)

@jakevdp
Collaborator

jakevdp commented Mar 29, 2022

Well, in that case, we need __jax_array__ to be respected by jnp.array().

The real issue is that __jax_array__ is only partially supported throughout the JAX package, with very little test coverage, and making it fully supported and tested will take a lot of work. (For what it's worth, this is the kind of issue I had in mind when I initially advocated against adding it.)

I'll assign the issue to @mattjj, who didn't anticipate any support burden 😀

@patrick-kidger
Collaborator Author

I'll assign the issue to @mattjj, who didn't anticipate any support burden 😀

Haha!

FWIW, as a developer I would also not have added __jax_array__ for exactly this reason. But as an end user I'm actually really enjoying having it around! For example over in patrick-kidger/equinox#53 we're discussing using it to implement spectral norm as a "parameterised weight" -- much like how PyTorch does via torch.nn.utils.parameterize.

@mattjj
Member

mattjj commented Mar 29, 2022

@jakevdp that was a good one. Though the reason the support burden was limited is we didn't tell anyone about it or promise anything about it!

(Also I haven't paged in this whole thread yet but I'd be happy to delete __jax_array__.)

@mattjj added the enhancement (New feature or request) label and removed the bug (Something isn't working) label Mar 29, 2022
@mattjj
Member

mattjj commented Mar 29, 2022

IIUC this issue only arises with an object that is both a pytree and has a __jax_array__ defined. I think we should say that in that case the pytree semantics take precedence, right? (In general that situation seems ambiguous as to whether something should be considered a container or an arraylike leaf. So we could also just say it's undefined behavior.)
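
For concreteness, here is a quick illustration of the ambiguity, using the field-less MyArray NamedTuple defined above (a sketch; the result is shown as a comment):

import jax

# As a pytree container, the NamedTuple flattens to zero leaves --
# which is exactly why jnp.asarray silently produced an empty array above.
jax.tree_util.tree_leaves(MyArray())  # []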

@patrick-kidger
Collaborator Author

patrick-kidger commented Mar 29, 2022

(Also I haven't paged in this whole thread yet but I'd be happy to delete __jax_array__.)

:( Please don't!

IIUC this issue only arises with an object that is both a pytree and has a __jax_array__ defined. I think we should say that in that case the pytree semantics take precedence, right?

Nope -- it's just that __jax_array__ isn't being respected in some places.

I'm actually finding that it's very natural to have something that is both a PyTree and a JAX array:

import typing
import jax.numpy as jnp
from jax import lax

class Buffer(typing.NamedTuple):
    value: jnp.ndarray

    def __jax_array__(self):
        return lax.stop_gradient(self.value)

class SpectralNorm(typing.NamedTuple):
    weight: jnp.ndarray
    u: jnp.ndarray
    v: jnp.ndarray

    def __jax_array__(self):
        # power_iteration is a user-supplied helper (not shown here).
        u, v = power_iteration(self.u, self.v, self.weight)
        σ = jnp.einsum("i,ij,j->", u, self.weight, v)
        return self.weight / σ

class Symmetric(typing.NamedTuple):
    value: jnp.ndarray

    def __jax_array__(self):
        return self.value + self.value.T

etc. etc.

In each case it's a PyTree with respect to jit etc., as we need to transparently see the wrapped values. And it's a __jax_array__ with respect to jnp operations, which see the transformed value.
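
A small sketch of that split, using Symmetric from above (the second call assumes __jax_array__ were consistently respected by jnp functions, which is exactly what this issue is about):

import jax
import jax.numpy as jnp

sym = Symmetric(jnp.eye(2))

# As a pytree (e.g. when passed through jit/grad), the raw leaf is visible:
jax.tree_util.tree_leaves(sym)  # [the underlying `value` array]

# As an arraylike, jnp operations would see the transformed value,
# i.e. value + value.T:
jnp.trace(sym)  # intended to give trace(value + value.T)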

@patrick-kidger
Collaborator Author

patrick-kidger commented Mar 29, 2022

If jnp.whatever happened to be generic over PyTrees then there would indeed be a problem. But they're not! They only accept JAX arrays. Meanwhile jax.jit only accepts PyTrees etc. and doesn't/shouldn't care about __jax_array__, I think. (So IMO it's wrong that jax.jit currently handles __jax_array__ in some places.)

AFAICT, PyTrees and __jax_array__ are pretty much orthogonal. Off the top of my head, there isn't an API which needs to make a choice between the two interpretations.

@jakevdp
Collaborator

jakevdp commented Mar 29, 2022

Very interesting points @patrick-kidger – I started writing a comment pushing back against this idea, but while writing it I ended up realizing you're right 😁 . This idea of the orthogonality of tree flattening and __jax_array__ is pretty compelling. That said, if there is a class that is not a pytree and defines __jax_array__, ISTM that __jax_array__ should be called at the JIT boundary. What do you think?

@patrick-kidger
Collaborator Author

patrick-kidger commented Mar 29, 2022

:D

That said, if there is a class that is not a pytree and defines __jax_array__, ISTM that __jax_array__ should be called at the JIT boundary. What do you think?

I don't think so. (a) For consistency between pytree/non-pytree; (b) for consistency between jit/non-jit. For example the following would work without JIT but would fail with JIT.

import jax
import jax.numpy as jnp

class M:
    def __jax_array__(self):
        return jnp.array(1.)

    def foo(self):
        return jnp.array(2)

# @jax.jit
def call(m):
    return m + m.foo()

call(M())

Which is admittedly a bit contrived -- I don't think there's that many classes being passed into JIT that aren't also pytrees -- but I don't see a compelling reason to perform the __jax_array__ conversion across JIT boundaries either.

@jakevdp
Collaborator

jakevdp commented Mar 29, 2022

I see what you're saying there, but I think it's probably counter to the original intent of the __jax_array__ mechanism, which (as I understand it) imagined it as a way of making an arbitrary object be interpreted by JAX as an array.

@patrick-kidger
Collaborator Author

patrick-kidger commented Mar 29, 2022

Right! Just interpreted by jnp and friends, rather than by jax.jit. (What about jax.vmap, jax.grad, etc.?)

Perhaps our mental models as to what jax.jit should do are slightly different. (I see the argument the other way.) Anyway, wrt this latter point it's not a strong feeling on my part.

@mattjj
Member

mattjj commented Mar 30, 2022

I disagree about orthogonality. For example, in jax.lax.scan(f, None, Symmetric(jnp.ones((1, 2)))), what's the length of the scanned-over axis? (In words, for APIs which accept pytrees-of-arrays like scan, an object which is both registered as a non-leaf pytree and has a __jax_array__ method can either be interpreted as a non-leaf pytree or as an array, and a decision must be made.)
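
Spelling that out with concrete shapes (a sketch using the Symmetric class defined earlier in this thread):

import jax.numpy as jnp

value = jnp.ones((1, 2))
xs = Symmetric(value)

# As a pytree, scan would iterate over the leading axis of the leaf
# `value`, giving a scan length of value.shape[0] == 1.
# As an array, __jax_array__ gives value + value.T, which broadcasts
# to shape (2, 2), giving a scan length of 2.
(value + value.T).shape  # (2, 2)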

I think we've got two discussions here, one narrowly about jnp.linalg.svd and co (i.e. the original issue), and the other about how __jax_array__ should behave more generally, e.g. whether jit boundaries should call __jax_array__.

For jnp.linalg.svd and co, there may be a quick way to extend __jax_array__ support to those functions even with disable_jit (and/or with pytree registration). But doing that is not a high-priority enhancement. (I call it an enhancement and not a bug because #4725 made no promises about the functionality of __jax_array__ beyond whatever was handled in that PR.)

For the latter discussion about __jax_array__'s behavior e.g. at jit boundaries, while the behavior Patrick wants could be reasonable, it is indeed counter to the original intent of the __jax_array__ mechanism, which indeed was quite narrow.

I think we should

  1. eventually delete the __jax_array__ API, but also
  2. eventually provide more general user-level mechanisms for overriding jnp functions' behavior as well as transformation behaviors (we have some work on this already), and
  3. to better support users like @patrick-kidger, we can try to (a) sequence item 1 to be after item 2, and (b) consider making small fixes to __jax_array__ until item 2 lands.

WDYT?

@mattjj added the contributions welcome (The JAX team has not prioritized work on this. Community contributions are welcome.) label Mar 30, 2022
@patrick-kidger
Collaborator Author

Hmm, good point about lax.scan. In a few cases a choice does have to be made.

Anyway, everything you say sounds reasonable. I enjoy using __jax_array__; any equivalent/more-general functionality also sounds great.

(In the end I think these kinds of proposals all end up being special cases of multiple dispatch.)

@PhilipVinc
Contributor

PhilipVinc commented Mar 30, 2022

Another problematic point might be jax.scipy.sparse.linalg.cg & co: imagine a PyTree object that behaves as a matrix but is not one, and imagine that we can convert it to a JAX array.
What should the API do in that case? It should do the iterative solve using the lazy version, not the dense operator, which is there just for users' ease of use.

import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg
from typing import NamedTuple

class LazyMatMul(NamedTuple):  # NamedTuples are already registered as pytrees
    a: jnp.ndarray
    b: jnp.ndarray

    def __call__(self, v):
        return self.b @ (self.a @ v)

    def __matmul__(self, v):
        return self.b @ (self.a @ v)

    def __jax_array__(self):
        return self.b @ self.a

ab = LazyMatMul(jnp.ones((3, 3)), jnp.ones((3, 3)))
jax.scipy.sparse.linalg.cg(ab, jnp.ones(3))

@PhilipVinc
Contributor

PhilipVinc commented Mar 30, 2022

@mattjj In favour of not removing __jax_array__: it allows one to write backend-agnostic code following NEP 47.

@YouJiacheng
Contributor

YouJiacheng commented Apr 1, 2022

Maybe the API could follow some priorities, e.g. method-specific type (e.g. callable, for cg) > pytree > __jax_array__.
But then what about a pytree containing objects that are themselves both pytrees and __jax_array__s?

@francois-rozet

francois-rozet commented Jan 2, 2024

Hello @patrick-kidger and @mattjj, sorry to unearth this, but I had the same problem with part of the JAX API. I want to make a PyTree which is a valid JAX array, in order to attach metadata to arrays.

I noticed that jax._src.lax.lax.asarray does not comply with __jax_array__, which is easy to fix.

In addition, jax.core.Primitive.bind checks whether arguments are jax.core.valid_jaxtype, but does not call __jax_array__() on arguments that need it. This is problematic if the result of __jax_array__() would be a Tracer. Once again, this is easy to fix.

With these two patches, it seems that my PyArray wrapper is compatible with a good part of jax.numpy and jax.lax. It worked with everything I tried actually.
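
For illustration, a minimal sketch of the first of those two patches, under the assumption that the internal helper simply needs to materialise __jax_array__ before converting (this is not the actual JAX source; the real helper lives at jax._src.lax.lax.asarray and may differ):

import jax.numpy as jnp

def _asarray(x):
    # Honour __jax_array__ before falling back to the usual conversion.
    if hasattr(x, "__jax_array__"):
        x = x.__jax_array__()
    return jnp.asarray(x)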

@jakevdp
Collaborator

jakevdp commented Jan 2, 2024

__jax_array__ is undocumented and (mostly) untested. We have no intent to support __jax_array__ universally in the JAX package, and I would suggest not writing code that relies on it.

I want to make a PyTree which is a valid JAX array to attach metadata to arrays.

Can you say more about this use-case? There may be better approaches to doing what you have in mind.

@francois-rozet

francois-rozet commented Jan 2, 2024

Hello @jakevdp

Can you say more about this use-case? There may be better approaches to doing what you have in mind.

I am writing a small JAX library (Inox) in which modules are PyTrees (similar to Equinox) whose leaves are the internal arrays. However, I need a way to distinguish between arrays that are constants, parameters, running statistics, ... for updates. My approach is to wrap arrays into a shallow PyTree with static metadata (what I call a PyArray). This is similar to the way PyTorch and Flax indicate parameters.

With the __jax_array__ interface, I can make PyArray instances valid JAX arrays, meaning that users don't have to unpack PyArray instances to use them as arrays (same as torch.nn.Parameter). Note that I don't care about propagating the metadata.
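
For readers unfamiliar with the pattern, a minimal sketch of such a wrapper (names are illustrative; Inox's actual PyArray may differ):

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class PyArray:
    def __init__(self, value, role="parameter"):
        self.value = value  # dynamic leaf
        self.role = role    # static metadata (constant, parameter, ...)

    def __jax_array__(self):
        return self.value

    def tree_flatten(self):
        return (self.value,), self.role

    @classmethod
    def tree_unflatten(cls, role, children):
        (value,) = children
        return cls(value, role)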

I also tried to make PyArray a Tracer with the EvalTrace to not propagate the trace, but did not succeed.

__jax_array__ is undocumented

Indeed, it was fairly painful to find a solution. But I think the API has potential!

@jakevdp
Collaborator

jakevdp commented Jan 2, 2024

We don't have any support for this kind of implicit dispatch to user-defined types. It's something we've discussed, but we haven't yet found use-cases that warrant the kind of investment it would require. I'm fairly certain though that if we did this, it would not rely on __jax_array__ for dispatch.

For your use-case, I would suggest explicitly unwrapping your wrapped arrays before passing them to jax APIs.
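
Concretely, the suggested workflow is something like the following (PyArray here is the hypothetical wrapper sketched earlier, assumed to expose its raw array as `.value`):

import jax.numpy as jnp

x = PyArray(jnp.ones((3, 3)))       # hypothetical wrapper with a `.value` attribute
u, s, vt = jnp.linalg.svd(x.value)  # unwrap explicitly before calling JAX APIs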

@francois-rozet

francois-rozet commented Jan 2, 2024

An alternative would be to enable metadata in arrays directly (names, roles, ...). In some sense that is the concept of Tracer.

@patrick-kidger
Collaborator Author

patrick-kidger commented Jan 2, 2024

An alternative would be to enable metadata in arrays directly (names, roles, ...). In some sense that is the concept of Tracer.

This can absolutely be done, actually -- take a look at Quax. This allows for tracing "array-ish" objects, and then doing multiple dispatch on them at the level of a primitive bind. This looks something like:

class LoraArray(quax.ArrayValue):
    ...

@quax.register(lax.dot_general_p)
def _(x: LoraArray, y: Array, ...):
    ...  # implement LoRA matmuls for this new array-ish type.

quax.quaxify(some_function)(LoraArray(...), jnp.array(...), ...)

with quaxify transforming the function (just as jax.jit etc. do); during tracing, the multiple dispatch rules are looked up.

FWIW right now this is pretty experimental, but it may be useful to you as a starting point.

@francois-rozet

francois-rozet commented Jan 2, 2024

That looks like what I need! Although, I have trouble understanding some parts of the Tracer and Trace interface, which is not very well documented. For example, shouldn't _QuaxTrace have a main attribute?

Also, instead of using a decorator, would it be possible to add a QuaxTrace in the trace_stack without ever removing it? Like the EvalTrace at the bottom.

@patrick-kidger
Collaborator Author

That looks like what I need! Although, I have trouble understanding some parts of the Tracer and Trace interface, which is not very well documented. For example, shouldn't _QuaxTrace have a main attribute?

It does, via jax.core.Trace itself:

jax/jax/_src/core.py

Lines 525 to 535 in afa2f1e

class Trace(Generic[TracerType]):
  __slots__ = ['main', 'level', 'sublevel']

  main: MainTrace
  level: int
  sublevel: Sublevel

  def __init__(self, main: MainTrace, sublevel: Sublevel) -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel

Also, instead of using a decorator, would it be possible to add a QuaxTrace in the trace_stack without ever removing it? Like the EvalTrace at the bottom.

Hmm, that's an interesting idea! And one that I really like, actually.

I'm not completely sure -- it might end up having to touch JAX internals? (E.g. it might end up being morally equivalent to monkey-patching EvalTrace, which wouldn't be great.)


Okay, let's do some inside baseball -- let me just say a bit more about your idea of putting a QuaxTrace at the bottom of the trace_stack.

A big part of why I like that idea so much is that right now it's a fairly complicated business to write a dispatch rule. Using LoRA as an example, we actually need our Quax rule to call back into a special version of quaxify:

https://github.com/patrick-kidger/quax/blob/1a4d4e5c8ad6f5289673ea7c4209f6f23b8c21e4/quax/lora/_core.py#L168-L169

in order to redispatch on the type of rhs here:

https://github.com/patrick-kidger/quax/blob/1a4d4e5c8ad6f5289673ea7c4209f6f23b8c21e4/quax/lora/_core.py#L188-L197

or to handle the possibility that lhs.{w,a,b} are themselves array-ish values. This ends up making it a fairly tricky business to write rules correctly.

This also ties in with a personal (private) project I've got, reimplementing + varying some of JAX's ideas. This does something very similar to your suggestion -- it has multiple dispatch as a first-class citizen of the bottom-of-stack evaluation trace, and in doing so actually manages to handle stuff like abstract evaluation (and maybe also JIT'ing?) as a special case of this single notion of evaluation. And this is pretty neat!

Besides the above, I can see that you're also interested in this because it removes the need for the quaxify wrapper itself, which is a plus for usability.

If you're interested in playing with this idea then I'd love to know what you find. And I'm definitely open to changing Quax in this way if there's a better design choice. Maybe we can do something interesting with this!

@francois-rozet

francois-rozet commented Jan 4, 2024

I will probably try to add multiple dispatch to the autodidax tutorial instead of the actual JAX API. I was also curious about caching dispatch results (e.g. in the case of reparameterizations).

@jakevdp Would adding a permanent DispatchTrace in trace_stack, or modifying EvalTrace, be a viable solution to enable multiple dispatch in JAX?

@jakevdp
Collaborator

jakevdp commented Jan 4, 2024

We would probably not make any change of that scope without a more comprehensive design process: i.e. writing out the goals and non-goals, evaluating the space of possible approaches and their advantages and disadvantages, giving stakeholders time to provide feedback, and only then starting to write the code.
