Cannot bind to primitive Zero(AbstractToken()) #16303

Open
PhilipVinc opened this issue Jun 7, 2023 · 10 comments
@PhilipVinc
Contributor

Description

In mpi4jax we make heavy use of tokens to prevent XLA from reordering our MPI calls, which is particularly a problem on XLA:GPU.

A standard function would look something like

def f(a):
  ta = create_token()
  b, tb = mpi_fun(a, token=ta)
  c, tc = mpi_fun2(a, token=tb)
  d = b + c
  return d

We do not return the token because, in general, jax.jit functions cannot return tokens. Even so, the tokens enforce a strict ordering within the compiled function.
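To illustrate why threading a token pins down the order, here is a toy dependency scheduler in plain Python. It is only a sketch: `schedule` and the op tuples are hypothetical stand-ins, not XLA or mpi4jax code. The scheduler is free to pick any op whose inputs are ready and deliberately prefers later ops, mimicking a compiler that reorders independent calls.

```python
def schedule(ops, available):
    """Run ops in some order that respects data dependencies only."""
    available = set(available)
    pending = list(ops)
    order = []
    while pending:
        # Scan from the end to mimic a compiler that is free to reorder.
        for op in reversed(pending):
            name, inputs, outputs = op
            if set(inputs) <= available:
                order.append(name)
                available |= set(outputs)
                pending.remove(op)
                break
        else:
            raise ValueError("unsatisfiable dependencies")
    return order


# Without tokens: both ops depend only on `a`, so they may be swapped.
no_tokens = [("mpi_fun", ["a"], ["b"]), ("mpi_fun2", ["a"], ["c"])]

# With tokens: mpi_fun2 consumes the token produced by mpi_fun.
with_tokens = [
    ("mpi_fun", ["a", "ta"], ["b", "tb"]),
    ("mpi_fun2", ["a", "tb"], ["c", "tc"]),
]

print(schedule(no_tokens, ["a"]))
print(schedule(with_tokens, ["a", "ta"]))
```

In the second program the token edge `tb` is the only thing that stops the swap, which is the role the tokens play in the function above.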

When we transpose our functions with jax.linear_transpose, we expect the transposed function to enforce the same strict ordering in reverse, for example:

f_t = jax.linear_transpose(f, a)
# should roughly match
def f_t(d_t):
  b_t, c_t = d_t, d_t
  tc_t = Zero(AbstractToken)  # appears automatically because we don't return tc
  a_t, tb_t = mpi_fun2_transpose(c_t, token=tc_t)
  a_t2, ta_t = mpi_fun_transpose(b_t, token=tb_t)
  _ = create_token_transpose(ta_t)
  return a_t + a_t2

However, when mpi4jax attempts to bind the transposed token Zero(AbstractToken), an error is raised saying that it cannot be bound because XLA does not know how to represent it.

I suspect that Zero(AbstractToken) should be treated exactly like a standard token, so that either:

  • tc_t in the example above is a normal token that is passed around to enforce ordering in the transposed program, or
  • binding a Zero(AbstractToken) follows the same path as binding a normal token.

To see how this impacts mpi4jax, install the branch pv/fix-token, for example by running
pip install git+https://github.com/mpi4jax/mpi4jax.git@pv/fix-token
and then run the following MWE:

from mpi4py import MPI

import jax
import jax.numpy as jnp

import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()


from mpi4jax import allreduce

arr = jnp.ones((3, 2))
_arr = arr.copy()

def f(x):
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(x)
    return res

res = jax.jit(f)(arr)

which raises the error:

TypeError: Argument 'Zero(AbstractToken())' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type
Full stack trace (from running python ex.py on the pv/fix-token branch, Python 3.11.2):
Traceback (most recent call last):
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 39, in <module>
    arr = jnp.ones((3, 2))
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line -1, in test_allreduce_transpose
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line -1, in <lambda>
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/mpi4jax/_src/collective_ops/allreduce.py", line -1, in allreduce
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Argument 'Zero(AbstractToken())' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 22, in <module>
    test_allreduce_transpose()
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 19, in test_allreduce_transpose
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(_arr)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/api.py", line 2270, in transposed_fun
    in_cts = ad.backward_pass(jaxpr, reduce_axes, True, const, dummies, out_cts)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 253, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/mpi4jax/_src/collective_ops/allreduce.py", line 209, in mpi_allreduce_transpose_rule
    res, token = mpi_allreduce_p.bind(
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
    in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/dispatch.py", line 131, in <listcomp>
    in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
                                          ^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/dispatch.py", line 102, in arg_spec
    aval = xla.abstractify(x)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Documents/pythonenvs/mpi4jax/python-3.11.2/lib/python3.11/site-packages/jax/_src/interpreters/xla.py", line 200, in abstractify
    raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Argument 'Zero(AbstractToken())' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 22, in <module>
    test_allreduce_transpose()
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/ex.py", line 19, in test_allreduce_transpose
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(_arr)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/filippo.vicentini/Dropbox/Ricerca/Codes/Python/mpi4jax/mpi4jax/_src/collective_ops/allreduce.py", line 209, in mpi_allreduce_transpose_rule
    res, token = mpi_allreduce_p.bind(
                 ^^^^^^^^^^^^^^^^^^^^^
TypeError: Argument 'Zero(AbstractToken())' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type

What jax/jaxlib version are you using?

jax 0.4.11 jaxlib 0.4.11

Which accelerator(s) are you using?

CPU

Additional system info

macOS (Apple M1)

NVIDIA GPU info

No response

@PhilipVinc PhilipVinc added the bug Something isn't working label Jun 7, 2023
@PhilipVinc
Contributor Author

PhilipVinc commented Jun 7, 2023

After diving deep into JAX's internals, I managed to devise a possible fix: every time a Zero(AbstractToken) would be constructed, build a token instead.

So simply editing the read_cotangent function to

  from jax._src.core import (AbstractToken, Token)

  def read_cotangent(v):
    if isinstance(v.aval, AbstractToken):
      return ct_env.pop(v, Token())
    else:
      return ct_env.pop(v, Zero(v.aval))

fixes my reproducer above.

Of course this might not be semantically correct but I'm sure you know more about it...
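To make the proposal concrete, here is a self-contained toy model of that lookup in plain Python. `Zero`, `Token`, `AbstractToken`, `AbstractArray`, `Var` and `make_read_cotangent` are simplified stand-ins defined locally for illustration, not the JAX classes. It shows where the symbolic zero comes from: any variable whose cotangent was never written, such as a token the function does not return, falls back to the default.

```python
from dataclasses import dataclass


class AbstractArray:
    """Stand-in for an array abstract value."""


class AbstractToken:
    """Stand-in for a token abstract value."""


class Token:
    """Stand-in for a concrete ordering token."""


@dataclass(frozen=True)
class Zero:
    """Stand-in for a symbolic zero cotangent."""
    aval_type: type


@dataclass(frozen=True)
class Var:
    name: str
    aval_type: type


def make_read_cotangent(ct_env, tokens_as_tokens):
    def read_cotangent(v):
        # Proposed behaviour: a missing token cotangent becomes a real token.
        if tokens_as_tokens and v.aval_type is AbstractToken:
            return ct_env.pop(v, Token())
        # Default behaviour: a missing cotangent becomes a symbolic zero.
        return ct_env.pop(v, Zero(v.aval_type))
    return read_cotangent


d = Var("d", AbstractArray)     # returned output, so it has a cotangent
tc = Var("tc", AbstractToken)   # token that is never returned

default_read = make_read_cotangent({d: 1.0}, tokens_as_tokens=False)
patched_read = make_read_cotangent({d: 1.0}, tokens_as_tokens=True)

ct_d = default_read(d)
ct_tc_default = default_read(tc)
ct_tc_patched = patched_read(tc)
print(ct_d, ct_tc_default, type(ct_tc_patched).__name__)
```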

@PhilipVinc
Contributor Author

@mattjj if the issue is not very clear, please let me know and I can try to clarify it further.
This is a big blocker for us.

@mattjj
Member

mattjj commented Jun 10, 2023

Thanks for the ping! I managed not to notice until just now.

because in general jax.jit functions cannot return them

Can you say more, and/or share a reproducer? If these are JAX tokens then they should be returnable from jitted functions, otherwise that's a JAX bug (even if it's not the main bug you're talking about).

As for the main issue, I understand the general outline, but I need to look at mpi4jax more closely, or alternatively set up a toy model, to understand better. I have two gut reactions:

  1. taking a narrow pigeon-holed view, anywhere you see a symbolic zero Zero(AbstractToken), i.e. in a JVP or transpose rule (not in ad.py's backward_pass), you probably want to instantiate it so that it's no longer symbolic; but in the bigger picture...
  2. I don't think we want to rely on tangents-of-tokens to be token-like at all, since throughout JAX's AD system we assume tangent types are vector-space-like, in particular in that they have zero elements which have the behavior that any linear function applied to them is zero.

In particular I don't think the fix in this comment is on the right track, unfortunately.
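One possible reading of the first point, sketched as a toy in plain Python: the transpose rule itself checks for a symbolic zero in the token slot and swaps in a fresh token before binding. `Zero`, `Token`, `create_token`, `bind_allreduce` and `transpose_rule` are hypothetical stand-ins here, not JAX or mpi4jax functions.

```python
class Zero:
    """Stand-in for a symbolic zero cotangent."""


class Token:
    """Stand-in for a concrete ordering token."""


def create_token():
    return Token()


def bind_allreduce(x, token):
    # Stand-in for binding a primitive: only concrete tokens are accepted.
    if not isinstance(token, Token):
        raise TypeError(f"{type(token).__name__} is not a valid token")
    return x, Token()


def transpose_rule(x_ct, token_ct):
    # Instantiate the symbolic zero locally, inside the rule, before binding.
    if isinstance(token_ct, Zero):
        token_ct = create_token()
    return bind_allreduce(x_ct, token_ct)


res, tok = transpose_rule(3.0, Zero())
print(res, type(tok).__name__)
```

This keeps the change inside the rule rather than in the backward pass, but it does not address the second point about tangent types needing to behave like vector spaces.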

an example code of how this impact mpi4jax can be had by installing the branch pv/fix-token by for example running pip install git+https://github.com/mpi4jax/mpi4jax.git@pv/fix-token and then using the following MWE:

Where is the token in this example?

@PhilipVinc
Copy link
Contributor Author

Thanks for answering!

In the example I shared above, the token is automatically generated by mpi4jax, but let me share a clearer example.
I hope you don't mind installing mpi4jax (unfortunately tokens are used nowhere in jax, so I can't build a reproducer there).

The reproducer is the following:

import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

def f(a, b):
    token_a = jax.lax.create_token()
    c, token_b = mpi4jax.allreduce(a, MPI.SUM, token=token_a)
    d, token_c = mpi4jax.allreduce(b, MPI.SUM, token=token_b)
    e = c+d
    return d

x = jnp.ones(1)
y = jnp.ones(1)
r = f(x, y)

# jax.make_jaxpr(f)(x, y)

# jax.make_jaxpr(jax.linear_transpose(f, x, y))(r)
jax.linear_transpose(f, x, y)(r)

Let me comment on what is going on here by inspecting the jaxpr:

In [4]: jax.make_jaxpr(f)(x,y)
Out[4]:
{ lambda ; a:f32[1] b:f32[1]. let
    c:Tok = create_token
    d:f32[1] e:Tok = allreduce_mpi[
      comm=<mpi4jax._src.utils.HashableMPIType object at 0x13f669b50>
      op=<mpi4jax._src.utils.HashableMPIType object at 0x13f668fd0>
      transpose=False
    ] a c
    f:f32[1] _:Tok = allreduce_mpi[
      comm=<mpi4jax._src.utils.HashableMPIType object at 0x13f509710>
      op=<mpi4jax._src.utils.HashableMPIType object at 0x13f0e03d0>
      transpose=False
    ] b e
    _:f32[1] = add d f
  in (f,) }

You can see that I have two calls to the primitive allreduce_mpi, which is defined in mpi4jax (mpi4jax/_src/collective_ops/allreduce.py). This primitive takes two inputs: the array to be reduced and a token to prevent reordering.

Now, what would be the correct transposition of this jaxpr?
I would assume it is the execution of the operations in reverse.
The transposition rule on master is essentially:

def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm,):
    _, token = x_args
    x_tan, token_tan = tan_args

    res, token = mpi_allreduce_transpose_p.bind(
        x_tan, token, op=op, comm=comm,
    )
    return res, token_tan

Notice that I bind the primal token instead of the tangent token. Is this correct?
It seems not, as this fails with the error

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/core.py:1326, in concrete_aval(x)
   1324 if hasattr(x, '__jax_array__'):
   1325   return concrete_aval(x.__jax_array__())
-> 1326 raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1327                  "type")

TypeError: Value UndefinedPrimal(AbstractToken()) with type <class 'jax._src.interpreters.ad.UndefinedPrimal'> is not a valid JAX type

Another reason to think that I should bind the tangent of the token rather than the primal token here is that I would like the linear transposition to run in reversed execution order, which I only get by binding the tangent token to the transposed primitive.
Does this make sense?
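The ordering argument can be sketched with a toy backward pass in plain Python. The names are hypothetical and this is not JAX's backward_pass. The equations are visited in reverse, and each transposed op takes the cotangent of the token its forward op produced and yields the cotangent of the token it consumed, so the token chain runs backwards too.

```python
# Forward program as (op name, input token var, output token var),
# listed in forward execution order.
forward_eqns = [
    ("allreduce_1", "token_a", "token_b"),
    ("allreduce_2", "token_b", "token_c"),
]


def backward_order(eqns, final_token):
    # Token cotangents are modelled as sequence numbers; the seed stands for
    # the cotangent of the last token, which the forward function never returns.
    ct_tokens = {final_token: 0}
    order = []
    for name, tok_in, tok_out in reversed(eqns):
        seq = ct_tokens.pop(tok_out)
        order.append((name + "_transpose", seq))
        ct_tokens[tok_in] = seq + 1
    return order


order = backward_order(forward_eqns, "token_c")
print(order)
```

The seed for `token_c` is exactly the value that shows up as Zero(AbstractToken) in the real program, which is why the chain cannot start without it.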

So in the branch mpi4jax@pv/fix-token I tried to modify the transposition rule to read

def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm):
    _, _ = x_args
    x_tan, token_tan = tan_args

    res, token = mpi_allreduce_transpose_p.bind(
        x_tan, token_tan, op=op, comm=comm,
    )
    return res, token

but this fails as well with the error I shared in the original post, namely

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/core.py:1326, in concrete_aval(x)
   1324 if hasattr(x, '__jax_array__'):
   1325   return concrete_aval(x.__jax_array__())
-> 1326 raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
   1327                  "type")

TypeError: Value Zero(AbstractToken()) with type <class 'jax._src.ad_util.Zero'> is not a valid JAX type

@PhilipVinc
Contributor Author

because in general jax.jit functions cannot return them?

Can you say more, and/or share a reproducer? If these are JAX tokens then they should be returnable from jitted functions, otherwise that's a JAX bug (even if it's not the main bug you're talking about).

Apparently I was not up to date: it is now possible to return tokens (I remember that about a year ago it was not).
However, it still errors if you try to transpose a function that returns a token:

import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

def f(a):
    token_a = jax.lax.create_token()
    b, token_b = mpi4jax.allreduce(a, MPI.SUM, token=token_a)
    return b, token_b

x = jnp.ones(1)
r,s = f(x)

jax.make_jaxpr(f)(x)

jax.make_jaxpr(jax.linear_transpose(f, x))(r)

which fails with

File ~/Documents/pythonenvs/mpi4jax/python-3.11.1/lib/python3.11/site-packages/jax/_src/dtypes.py:530, in dtype(x, canonicalize)
    528     dt = np.result_type(x)
    529   except TypeError as err:
--> 530     raise TypeError(f"Cannot determine dtype of {x}") from err
    531 if dt not in _jax_dtype_set and not is_opaque_dtype(dt):
    532   raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
    533                   "type. Only arrays of numeric types are supported by JAX.")

TypeError: Cannot determine dtype of AbstractToken()

Though that's a different bug from what I originally reported and I'm not so worried about this one because tokens usually remain inside the jitted functions...

@mattjj
Member

mattjj commented Jun 10, 2023

I hope you don't mind installing mpi4jax

I don't mind at all!

Thanks for the detailed repro and info. I'll take a look...

@mattjj
Member

mattjj commented Jun 12, 2023

I haven't had a chance yet :/ I expect I can in the next 48 hours or so.

@PhilipVinc
Contributor Author

Thanks for the update! Looking forward to a reply.

@PhilipVinc
Contributor Author

@mattjj any luck? Can I do anything to help you nail this problem down?

@PhilipVinc
Contributor Author

@mattjj pretty please 🥹?
