jnp.linalg.svd etc. do not respect __jax_array__ #10065
I think the moral of the story here is we need every API entrypoint in JAX to call […]
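Below is a minimal sketch of what that kind of entry-point conversion could look like. It assumes a hypothetical decorator, coerce_arraylikes, which is not a JAX API. The decorator replaces any argument whose type defines __jax_array__ with the array it returns, before the wrapped function runs. The 2x2 MyArray is illustrative only.

```python
import functools

import jax.numpy as jnp


def coerce_arraylikes(fn):
    """Hypothetical decorator: unwrap __jax_array__ objects before calling fn."""

    def _coerce(x):
        # Look the hook up on the type, as Python does for other dunder methods.
        if hasattr(type(x), "__jax_array__"):
            return jnp.asarray(x.__jax_array__())
        return x

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = tuple(_coerce(a) for a in args)
        kwargs = {k: _coerce(v) for k, v in kwargs.items()}
        return fn(*args, **kwargs)

    return wrapper


class MyArray:
    def __jax_array__(self):
        return jnp.array([[3.0, 0.0], [0.0, 4.0]])


# Wrap an existing entrypoint; the conversion happens before svd is entered.
svd = coerce_arraylikes(jnp.linalg.svd)
u, s, vh = svd(MyArray())
```

The conversion happens before the wrapped function is entered. In this sketch, the result therefore does not depend on whether JIT is disabled. It also means that an object which is a PyTree as well is always treated as an array at this entrypoint.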
import jax
import jax.numpy as jnp
import typing

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.asarray(MyArray())
# TypeError: Value '<__main__.MyArray object at 0x7f67b007c4c0>' with dtype object is not a valid
# JAX array type. Only arrays of numeric types are supported by JAX.

[…] unfortunately. (With or without the […].) Nor in the PyTree case, in which it instead fails silently:

import jax
import jax.numpy as jnp
import typing

class MyArray(typing.NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

out = jnp.asarray(MyArray())
print(repr(out))
# DeviceArray([], dtype=float32)
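The silent failure seems to follow from MyArray being a NamedTuple with no fields. Every instance is an empty tuple, so if it is converted as a sequence instead of through __jax_array__, the result is an empty 1-D array. The small check below compares the two interpretations explicitly, without depending on how a given JAX version dispatches jnp.asarray.

```python
import typing

import jax.numpy as jnp


class MyArray(typing.NamedTuple):
    # No fields: every instance is an empty tuple.
    def __jax_array__(self):
        return jnp.array([[1.0]])


m = MyArray()
# Interpreted as a sequence, an empty tuple converts to an empty 1-D array.
as_sequence = jnp.asarray(tuple(m))
# Interpreted via the hook, it is the intended 1x1 array.
via_hook = jnp.asarray(m.__jax_array__())
```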
Well, in that case, we need […] The real issue is that […] I'll assign the issue to @mattjj, who didn't anticipate any support burden 😀
Haha! FWIW, as a developer I would also not have added […]
@jakevdp that was a good one. Though the reason the support burden was limited is we didn't tell anyone about it or promise anything about it! (Also I haven't paged in this whole thread yet but I'd be happy to delete […])
IIUC this issue only arises with an object that is both a pytree and has a […]
:( Please don't!
Nope -- it's just that I'm actually finding that it's very natural to have something that is both a PyTree and a JAX array:

import typing
import jax.numpy as jnp
from jax import lax

class Buffer(typing.NamedTuple):
    value: jnp.ndarray

    def __jax_array__(self):
        return lax.stop_gradient(self.value)

class SpectralNorm(typing.NamedTuple):
    weight: jnp.ndarray
    u: jnp.ndarray
    v: jnp.ndarray

    def __jax_array__(self):
        # power_iteration: helper not shown in this comment.
        u, v = power_iteration(self.u, self.v, self.weight)
        σ = jnp.einsum("i,ij,j->", u, self.weight, v)
        return self.weight / σ

class Symmetric(typing.NamedTuple):
    value: jnp.ndarray

    def __jax_array__(self):
        return self.value + self.value.T

etc. etc. In each case it's a PyTree with respect to JIT etc., as we need to transparently see the wrapped value. It's a […]
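The SpectralNorm example calls a power_iteration helper that the comment does not define. Below is a minimal sketch of one possible version. It is hypothetical, including its num_iters parameter. It alternately multiplies by the weight and its transpose, so that u and v approach the leading left and right singular vectors. Then u @ weight @ v estimates the largest singular value σ.

```python
import typing

import jax
import jax.numpy as jnp


def power_iteration(u, v, weight, num_iters=50):
    """Hypothetical helper: refine (u, v) toward the top singular vectors of weight."""

    def body(_, carry):
        u, v = carry
        v = weight.T @ u
        v = v / jnp.linalg.norm(v)
        u = weight @ v
        u = u / jnp.linalg.norm(u)
        return u, v

    return jax.lax.fori_loop(0, num_iters, body, (u, v))


class SpectralNorm(typing.NamedTuple):
    weight: jnp.ndarray
    u: jnp.ndarray
    v: jnp.ndarray

    def __jax_array__(self):
        u, v = power_iteration(self.u, self.v, self.weight)
        sigma = jnp.einsum("i,ij,j->", u, self.weight, v)
        return self.weight / sigma


# Illustrative weight with singular values 3, 1 and 0.5.
weight = jnp.diag(jnp.array([3.0, 1.0, 0.5]))
sn = SpectralNorm(weight, jnp.ones(3), jnp.ones(3))
normalized = jnp.asarray(sn.__jax_array__())
```

SpectralNorm is a NamedTuple, so jax.jit and jax.grad see weight, u and v as ordinary leaves. The normalization only happens when __jax_array__ is actually called.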
If […] AFAICT, PyTrees and […]
Very interesting points @patrick-kidger – I started writing a comment pushing back against this idea, but while writing it I ended up realizing you're right 😁. This idea of the orthogonality of tree flattening and […]
:D
I don't think so. (a) For consistency between pytree/non-pytree; (b) for consistency between jit/non-jit. For example, the following would work without JIT but would fail with JIT:

import jax
import jax.numpy as jnp

class M:
    def __jax_array__(self):
        return jnp.array(1.)

    def foo(self):
        return jnp.array(2)

# @jax.jit
def call(m):
    return m + m.foo()

call(M())

Which is admittedly a bit contrived -- I don't think there are that many classes being passed into JIT that aren't also pytrees -- but I don't see a compelling reason to perform the […]
I see what you're saying there, but I think it's probably counter to the original intent of the […]
Right! Just, interpreted by […] Perhaps our mental models as to what […]
I disagree about orthogonality. For example, in […] I think we've got two discussions here, one narrowly about […] For […] For the latter discussion about […] I think we should […]
WDYT?
Hmm, good point about […] Anyway, everything you say sounds reasonable. I enjoy using […] (In the end I think these kinds of proposals all end up being special cases of multiple dispatch.)
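Another problematic point might be:

import typing
import jax.numpy as jnp
import jax.scipy.sparse.linalg

class LazyMatMul(typing.NamedTuple):  # NamedTuple makes this a pytree
    a: jnp.ndarray
    b: jnp.ndarray

    def __call__(self, v):
        return self.b @ (self.a @ v)

    def __matmul__(self, v):
        return self.b @ (self.a @ v)

    def __jax_array__(self):
        return self.b @ self.a

ab = LazyMatMul(jnp.ones((3, 3)), jnp.ones((3, 3)))
jax.scipy.sparse.linalg.cg(ab, jnp.ones(3))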
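One way to sidestep the ambiguity at the call site is to choose the interpretation explicitly. You either pass the object itself, which cg treats as a matvec function, or pass the dense matrix returned by __jax_array__. The sketch below assumes diagonal, positive-definite factors (chosen only so that cg converges) and uses a plain class, since pytree registration does not matter for this point.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


class LazyMatMul:
    """Usable both as a matvec (via __call__) and as an array (via __jax_array__)."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __call__(self, v):
        # Matrix-free application of b @ a.
        return self.b @ (self.a @ v)

    def __jax_array__(self):
        # Dense materialization of the same operator.
        return self.b @ self.a


ab = LazyMatMul(2.0 * jnp.eye(3), 3.0 * jnp.eye(3))
rhs = jnp.ones(3)

# Interpretation 1: a linear operator given as a function.
x_fn, _ = cg(ab, rhs)
# Interpretation 2: an explicit dense matrix.
x_dense, _ = cg(jnp.asarray(ab.__jax_array__()), rhs)
```

Here both paths represent the same operator (6 times the identity), so both solutions are ones / 6. The design question in the thread is which path an API should pick when the caller does not say.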
@mattjj In favour of not removing […]
Maybe the API could follow some priorities, e.g. method-specific type (e.g. […]
This relates to the long discussion in google#4725 and google#10065.
Hello @patrick-kidger and @mattjj, sorry to unearth this but I had the same problem with part of the JAX API. I want to make a PyTree that is a valid JAX array, to attach metadata to arrays. I noticed that […] In addition, […] With these two patches, it seems that my PyArray wrapper is compatible with a good part of […]
Can you say more about this use-case? There may be better approaches to doing what you have in mind.
Hello @jakevdp

I am writing a small JAX library (Inox) in which modules are PyTrees (similar to Equinox) whose leaves are the internal arrays. However, I need a way to distinguish between arrays that are constants, parameters, running statistics, ... for updates. My approach is to wrap arrays into a shallow PyTree with static metadata (what I call a […] With the […] I also tried to make […]

Indeed, it was fairly painful to find a solution. But I think the API has potential!
We don't have any support for this kind of implicit dispatch to user-defined types. It's something we've discussed, but we haven't yet found use-cases that warrant the kind of investment it would require. I'm fairly certain, though, that if we did this, it would not rely on […] For your use-case, I would suggest explicitly unwrapping your wrapped arrays before passing them to […]
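Below is a minimal sketch of the explicit-unwrapping approach. Parameter and unwrap are hypothetical names, not part of Inox or JAX. Parameter stores the array as a pytree child and its role as static auxiliary data. unwrap strips the wrappers before the arrays are handed to a JAX function such as jnp.linalg.svd.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Parameter:
    """Hypothetical wrapper: one array leaf plus static metadata."""

    def __init__(self, value, role="param"):
        self.value = value
        self.role = role

    def tree_flatten(self):
        # The array is a child (traced by jit/grad); the role is static aux data.
        return (self.value,), self.role

    @classmethod
    def tree_unflatten(cls, role, children):
        return cls(children[0], role)


def unwrap(tree):
    """Replace every Parameter in a pytree by its underlying array."""
    return jax.tree_util.tree_map(
        lambda x: x.value if isinstance(x, Parameter) else x,
        tree,
        is_leaf=lambda x: isinstance(x, Parameter),
    )


params = {"w": Parameter(jnp.ones((2, 2))), "scale": Parameter(jnp.array(2.0), role="const")}

# Unwrap first, then call into JAX with plain arrays.
plain = unwrap(params)
_, s, _ = jnp.linalg.svd(plain["w"])

# The wrapped tree still passes through jit; unwrapping happens inside.
total = jax.jit(lambda p: unwrap(p)["w"].sum())(params)

# Metadata stays available for filtering, e.g. selecting trainable arrays only.
trainable = {k: v.value for k, v in params.items() if v.role == "param"}
```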
An alternative would be to enable metadata in arrays directly (names, roles, ...). In some sense that is the concept of […]
This can absolutely be done, actually -- take a look at Quax. This allows for tracing "array-ish" objects, and then doing multiple dispatch on them at the level of a primitive bind. This looks something like:

class LoraArray(quax.ArrayValue):
    ...

@quax.register(lax.dot_general_p)
def _(x: LoraArray, y: Array, ...):
    ...  # implement LoRA matmuls for this new array-ish type.

quax.quaxify(some_function)(LoraArray(...), jnp.array(...), ...)

with the […] FWIW right now this is pretty experimental, but it may be useful to you as a starting point.
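To make "multiple dispatch on array-ish types" concrete without relying on Quax's actual API, here is a toy sketch in plain Python with JAX arrays. It uses a hypothetical registry keyed on the types of both operands. Its LoRA-style rule computes x @ (w + a @ b) as x @ w + (x @ a) @ b, without forming the low-rank update. The names register, matmul and LoraArray here are illustrative and not Quax's.

```python
import typing

import jax
import jax.numpy as jnp

# Hypothetical dispatch table: (type of x, type of y) -> rule.
_MATMUL_RULES = {}


def register(x_type, y_type):
    """Hypothetical decorator adding a rule for a pair of operand types."""

    def decorator(rule):
        _MATMUL_RULES[(x_type, y_type)] = rule
        return rule

    return decorator


def matmul(x, y):
    """Pick an implementation from the types of *both* operands."""
    for (x_type, y_type), rule in _MATMUL_RULES.items():
        if isinstance(x, x_type) and isinstance(y, y_type):
            return rule(x, y)
    # No rule matched: fall back to the ordinary dense matmul.
    return jnp.matmul(x, y)


class LoraArray(typing.NamedTuple):
    """Stands for w + a @ b without forming the low-rank product a @ b."""

    w: jax.Array
    a: jax.Array
    b: jax.Array

    def materialize(self):
        return self.w + self.a @ self.b


@register(jax.Array, LoraArray)
def _dense_times_lora(x, y):
    # x @ (w + a @ b) == x @ w + (x @ a) @ b
    return x @ y.w + (x @ y.a) @ y.b


x = jnp.ones((2, 4))
y = LoraArray(
    w=jnp.eye(4),
    a=jnp.arange(8.0).reshape(4, 2),
    b=jnp.arange(8.0).reshape(2, 4) / 10.0,
)
out = matmul(x, y)  # uses the LoRA rule
dense = matmul(x, y.materialize())  # no rule for (Array, Array): falls back
```

A real system would hook this lookup into every primitive rather than one function. As described above, that is what Quax does at the level of a primitive bind.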
That looks like what I need! Although, I have trouble understanding some parts of the […] Also, instead of using a decorator, would it be possible to add a […]
It does, via lines 525 to 535 in afa2f1e.
Hmm, that's an interesting idea! And one that I really like, actually. I'm not completely sure -- it might end up having to touch JAX internals? (E.g. it might end up being morally equivalent to monkey-patching […]

Okay, let's do some inside baseball -- let me just say a bit more about your idea of putting a […] A big part of why I like that idea so much is that right now it's a fairly complicated business to write a dispatch rule. Using LoRA as an example, we actually need our Quax rule to call back into a special version of […] in order to redispatch on the type of […] or to handle the possibility that […]

This also ties in with a personal (private) project I've got, reimplementing + varying some of JAX's ideas. This does something very similar to your suggestion -- it has multiple dispatch as a first-class citizen of the bottom-of-stack evaluation trace, and in doing so actually manages to handle stuff like abstract evaluation (and maybe also JIT'ing?) as a special case of this single notion of evaluation. And this is pretty neat!

Besides the above, I can see that you're also interested in this because it removes the need for the […] If you're interested in playing with this idea then I'd love to know what you find. And I'm definitely open to changing Quax in this way if there's a better design choice. Maybe we can do something interesting with this!
I will probably try to add multiple dispatch to the autodidax tutorial instead of the actual JAX API. I was also curious about caching dispatch results (e.g. in the case of reparameterizations). @jakevdp Would adding a permanent […]
We would probably not make any change of that scope without a more comprehensive design process: i.e. writing out the goals and non-goals, evaluating the space of possible approaches and their advantages and disadvantages, giving stakeholders time to provide feedback, and only then starting to write the code.
A couple things going on here. First of all, the following is an example of jnp.linalg.svd failing to respect __jax_array__: […] Remove the disable_jit and this works.

The reason it works without disable_jit is that jnp.linalg.svd and friends all have jax.jit wrappers, which is what spots the __jax_array__ and handles things appropriately... unless the JAX arraylike is also a PyTree, in which case they don't. So this also fails (with a different error message this time): […]

So whilst it takes either a disable_jit or a PyTree to actually trigger it, I think the fundamental issue here is that jnp.linalg.svd and friends do not check for JAX arraylikes.
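The PyTree case can be seen directly in how JAX flattens arguments. A NamedTuple is a built-in pytree node, so flattening yields its fields as leaves and never consults __jax_array__. The minimal check below uses the Symmetric class from earlier in the thread.

```python
import typing

import jax
import jax.numpy as jnp


class Symmetric(typing.NamedTuple):
    value: jnp.ndarray

    def __jax_array__(self):
        return self.value + self.value.T


x = jnp.array([[1.0, 2.0], [0.0, 1.0]])
s = Symmetric(x)

# Flattening sees the raw field, not the symmetrized array.
leaves = jax.tree_util.tree_leaves(s)

# Calling the hook explicitly gives the intended array.
dense = jnp.asarray(s.__jax_array__())
```

Anything that flattens its inputs first, such as a jax.jit wrapper, therefore receives x rather than x + x.T for such an object.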