Closed
Labels
bug: Something isn't working
Description
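I was trying to write a single function that supports both CPU and non-CPU backends, and @froystig suggested I use jax.lax.platform_dependent. However, it doesn't seem to do the job (with Pallas at least?): running the script below on a CPU-only machine still fails with "Only interpret mode is supported on CPU backend."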
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.megablox import gmm

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (128, 128))
weights = jax.random.normal(key, (3, 2, 128, 128))
group_sizes = jnp.array([60, 68])

def f(x, weight):
    out = jax.lax.platform_dependent(
        x, weight, group_sizes,
        tpu=partial(gmm, interpret=False),
        default=partial(gmm, interpret=True)
    )
    return out, out

out = jax.lax.scan(f, x, weights)
print(out)
...
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/Users/dlwh/src/marin/tests/platform_dep.py", line 21, in <module>
    out = jax.lax.scan(f, x, weights)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dlwh/src/marin/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1300, in _pallas_call_lowering
    return mlir.lower_per_platform(ctx, "pallas_call",
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dlwh/src/marin/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1264, in cpu_lowering
    raise ValueError("Only interpret mode is supported on CPU backend.")
ValueError: Only interpret mode is supported on CPU backend.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
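For contrast, here is a minimal sketch of what I expected jax.lax.platform_dependent to do, using plain jnp functions instead of a Pallas kernel. It is not part of the failing script: the names cpu_impl, default_impl and select are hypothetical, and the assumption is only that the `cpu=` branch is used when running on CPU and `default=` everywhere else.

```python
import jax
import jax.numpy as jnp


def cpu_impl(x):
    # Hypothetical CPU-only branch: add one to every element.
    return x + 1.0


def default_impl(x):
    # Hypothetical fallback branch for any other platform: double every element.
    return x * 2.0


@jax.jit
def select(x):
    # Both branches take the same positional arguments and must return
    # outputs of the same shape and dtype.
    return jax.lax.platform_dependent(x, cpu=cpu_impl, default=default_impl)


x = jnp.arange(3.0)
result = select(x)
print(jax.default_backend(), result)
```

If this selects the branch matching the backend while the gmm version above still fails, that would point at how the non-interpret pallas_call branch is lowered rather than at platform_dependent in general.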
System info (python version, jaxlib version, accelerator, etc.)
>>> import jax; jax.print_environment_info()
jax: 0.6.0
jaxlib: 0.6.0
numpy: 1.26.4
python: 3.11.11 (main, Dec 3 2024, 17:20:40) [Clang 16.0.0 (clang-1600.0.26.4)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='MacBook-Pro-57.local', release='24.4.0', version='Darwin Kernel Version 24.4.0: Fri Apr 11 18:33:47 PDT 2025; root:xnu-11417.101.15~117/RELEASE_ARM64_T6000', machine='arm64')