Cannot bind to primitive Zero(AbstractToken()) #16303
After diving deep into JAX's internals, I managed to devise this possible fix: every time you try to construct a … so simply editing this:

from jax._src.core import AbstractToken, Token

def read_cotangent(v):
    if isinstance(v.aval, AbstractToken):
        return ct_env.pop(v, Token())
    else:
        return ct_env.pop(v, Zero(v.aval))

fixes my reproducer above. Of course, this might not be semantically correct, but I'm sure you know more about it...
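To make the proposed patch concrete, here is a toy, pure-Python model of the read_cotangent fallback. All classes are simplified stand-ins invented for illustration (they are not JAX's real AbstractToken, Zero or Token), and token_fallback is a hypothetical switch for comparing the default-to-Zero behaviour with the patched one. In JAX the helper closes over ct_env instead of taking it as an argument.

```python
from dataclasses import dataclass


# Toy stand-ins for JAX's internal types (NOT jax._src.core / ad_util).
@dataclass(frozen=True)
class AbstractToken:
    """Abstract value of a token-typed variable."""


@dataclass(frozen=True)
class ShapedArray:
    """Abstract value of an array-typed variable (shape only)."""
    shape: tuple


@dataclass(frozen=True)
class Zero:
    """Symbolic zero cotangent carrying the abstract value it stands for."""
    aval: object


class Token:
    """A concrete token value that could be passed to a primitive."""


@dataclass(frozen=True)
class Var:
    """A jaxpr variable: a name plus its abstract value."""
    name: str
    aval: object


def read_cotangent(ct_env, v, token_fallback=True):
    """Pop the cotangent of v, using a default if nothing was accumulated.

    token_fallback=False: always default to a symbolic Zero (the else branch).
    token_fallback=True: token-typed variables default to a fresh Token,
    mirroring the proposed patch.
    """
    if token_fallback and isinstance(v.aval, AbstractToken):
        return ct_env.pop(v, Token())
    return ct_env.pop(v, Zero(v.aval))


if __name__ == "__main__":
    d = Var("d", ShapedArray((1,)))
    tok = Var("_", AbstractToken())  # an output token nobody consumed
    print(read_cotangent({d: 1.0}, tok, token_fallback=False))  # a symbolic Zero
    print(type(read_cotangent({d: 1.0}, tok)).__name__)         # Token
```

The sketch only shows what the patch changes: a transpose rule would always receive something bindable for its token cotangent, even when nothing downstream wrote a cotangent for that token. Whether that is semantically right is exactly the open question in this thread.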
@mattjj if the issue is not very clear, please let me know and I can try to clarify it further.
Thanks for the ping! I managed not to notice until just now.
Can you say more, and/or share a reproducer? If these are JAX tokens then they should be returnable from jitted functions, otherwise that's a JAX bug (even if it's not the main bug you're talking about). As for the main issue, I understand the general outline, but I need to look at mpi4jax more closely, or alternatively set up a toy model, to understand better. I have two gut reactions:
In particular I don't think the fix in this comment is on the right track, unfortunately.
Where is the token in this example?
Thanks for answering! In the example I shared above, the token is generated automatically by mpi4jax, but let me share a clearer example. The reproducer is the following:

import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

def f(a, b):
    token_a = jax.lax.create_token()
    c, token_b = mpi4jax.allreduce(a, MPI.SUM, token=token_a)
    d, token_c = mpi4jax.allreduce(b, MPI.SUM, token=token_b)
    e = c + d
    return d

x = jnp.ones(1)
y = jnp.ones(1)
r = f(x, y)
# jax.make_jaxpr(f)(x, y)
# jax.make_jaxpr(jax.linear_transpose(f, x, y))(r)
jax.linear_transpose(f, x, y)(r)

Let me comment on what is going on here by inspecting the jaxpr:

In [4]: jax.make_jaxpr(f)(x,y)
Out[4]:
{ lambda ; a:f32[1] b:f32[1]. let
    c:Tok = create_token
    d:f32[1] e:Tok = allreduce_mpi[
      comm=<mpi4jax._src.utils.HashableMPIType object at 0x13f669b50>
      op=<mpi4jax._src.utils.HashableMPIType object at 0x13f668fd0>
      transpose=False
    ] a c
    f:f32[1] _:Tok = allreduce_mpi[
      comm=<mpi4jax._src.utils.HashableMPIType object at 0x13f509710>
      op=<mpi4jax._src.utils.HashableMPIType object at 0x13f0e03d0>
      transpose=False
    ] b e
    _:f32[1] = add d f
  in (f,) }

You can see that I have two calls to the primitive … Now, what would be the correct transposition of this jaxpr? Consider this transpose rule:

def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm):
    _, token = x_args
    x_tan, token_tan = tan_args
    res, token = mpi_allreduce_transpose_p.bind(
        x_tan, token, op=op, comm=comm,
    )
    return res, token_tan

Notice that I bind the primal token instead of the tangent token. Is this correct? With this rule, the reproducer fails with:

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/core.py:1326, in concrete_aval(x)
   1324 if hasattr(x, '__jax_array__'):
   1325     return concrete_aval(x.__jax_array__())
-> 1326 raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1327                 "type")

TypeError: Value UndefinedPrimal(AbstractToken()) with type <class 'jax._src.interpreters.ad.UndefinedPrimal'> is not a valid JAX type

Another reason that suggests I should bind the tangent of the token instead of the primal token is that I would like the linear transposition to execute in the reverse order, which I only get by binding the tangent token to the transpose primitive. So, in the branch, I changed the rule to bind the tangent token instead:

def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm):
    _, _ = x_args
    x_tan, token_tan = tan_args
    res, token = mpi_allreduce_transpose_p.bind(
        x_tan, token_tan, op=op, comm=comm,
    )
    return res, token

This fails as well, with the error I shared in the original post, namely:

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/core.py:1326, in concrete_aval(x)
   1324 if hasattr(x, '__jax_array__'):
   1325     return concrete_aval(x.__jax_array__())
-> 1326 raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1327                 "type")

TypeError: Value Zero(AbstractToken()) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type
Apparently I was not up to date: it seems that it is now possible to return tokens from jitted functions (I remember that about a year ago it was not possible). Here is an example that returns a token:

import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

def f(a):
    token_a = jax.lax.create_token()
    b, token_b = mpi4jax.allreduce(a, MPI.SUM, token=token_a)
    return b, token_b

x = jnp.ones(1)
r, s = f(x)
jax.make_jaxpr(f)(x)
jax.make_jaxpr(jax.linear_transpose(f, x))(r)

It fails with:

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/dtypes.py:530, in dtype(x, canonicalize)
    528     dt = np.result_type(x)
    529 except TypeError as err:
--> 530     raise TypeError(f"Cannot determine dtype of {x}") from err
    531 if dt not in _jax_dtype_set and not is_opaque_dtype(dt):
    532     raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
    533                     "type. Only arrays of numeric types are supported by JAX.")

TypeError: Cannot determine dtype of AbstractToken()

That's a different bug from the one I originally reported, though, and I'm not too worried about it because tokens usually stay inside the jitted functions...
I don't mind at all! Thanks for the detailed repro and info. I'll take a look...
I haven't had a chance yet :/ I expect I can in the next 48 hours or so.
Thanks for the update! Looking forward to a reply.
@mattjj any luck? Can I do anything to help you nail this problem down?
@mattjj pretty please 🥹?
Description
In mpi4jax we make heavy use of tokens to prevent XLA from reordering our MPI calls, which is a particular problem on XLA:GPU. A standard function would look something like
and we do not return the token because, in general, jax.jit functions cannot return them; still, a strong order is enforced within the compiled function because of the tokens.
When we transpose our functions with jax.linear_transpose, we expect the transposed function to also enforce a strong ordering, in the reverse order, for example:
However, when mpi4jax attempts to bind the transposed token Zero(AbstractToken), an error is raised saying that it cannot be bound because XLA does not know how to represent it.
I suspect that Zero(AbstractToken) should be treated exactly like a standard token, such that when tc_t in the example above is Zero(AbstractToken), you follow the same path as when binding a normal token.
An example of how this impacts mpi4jax can be reproduced by installing the branch pv/fix-token, for example by running
pip install git+https://github.com/mpi4jax/mpi4jax.git@pv/fix-token
and then using the following MWE:
which raises the error:
What jax/jaxlib version are you using?
jax 0.4.11 jaxlib 0.4.11
Which accelerator(s) are you using?
CPU
Additional system info
macOS (M1)
NVIDIA GPU info
No response