jax.lax.platform_dependent doesn't stop Pallas from trying to lower for other backends? #28594

@dlwh

Description

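I was trying to write a single function that supports both CPU and non-CPU backends, and @froystig suggested I use jax.lax.platform_dependent. However, it doesn't seem to do the job (with Pallas at least?): when running on CPU, the tpu branch with interpret=False still gets lowered for the CPU backend and raises.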

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.megablox import gmm

key = jax.random.PRNGKey(42)

x = jax.random.normal(key, (128, 128))
weights = jax.random.normal(key, (3, 2, 128, 128))
group_sizes = jnp.array([60, 68])

def f(x, weight):
    out = jax.lax.platform_dependent(
        x, weight, group_sizes,
        tpu=partial(gmm, interpret=False),
        default=partial(gmm, interpret=True)
    )
    return out, out

out = jax.lax.scan(f, x, weights)

print(out)
...


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/dlwh/src/marin/tests/platform_dep.py", line 21, in <module>
    out = jax.lax.scan(f, x, weights)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dlwh/src/marin/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1300, in _pallas_call_lowering
    return mlir.lower_per_platform(ctx, "pallas_call",
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dlwh/src/marin/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1264, in cpu_lowering
    raise ValueError("Only interpret mode is supported on CPU backend.")
ValueError: Only interpret mode is supported on CPU backend.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
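
One possible workaround, sketched here under assumptions and not confirmed as a fix in this issue, is to choose the implementation in Python before tracing, using jax.default_backend(). The helper select_by_backend and the two stand-in functions below are hypothetical; in the real case they would be partial(gmm, interpret=False) and partial(gmm, interpret=True). Unlike jax.lax.platform_dependent, this bakes the choice in at trace time, so the traced function is tied to one backend.

```python
import jax
import jax.numpy as jnp


def select_by_backend(tpu_impl, fallback_impl):
    """Hypothetical helper: pick an implementation in Python, before tracing."""
    return tpu_impl if jax.default_backend() == "tpu" else fallback_impl


def tpu_double(x):
    # Stand-in for the TPU-only kernel, e.g. partial(gmm, interpret=False).
    return x * 2.0


def fallback_double(x):
    # Stand-in for the portable path, e.g. partial(gmm, interpret=True).
    return x + x


# Only the selected function is ever traced, so the other is never lowered.
double = select_by_backend(tpu_double, fallback_double)


def step(carry, w):
    out = double(carry) * w
    return out, out


final, stacked = jax.lax.scan(step, jnp.ones((4,)), jnp.array([1.0, 0.5, 2.0]))
print(final.shape, stacked.shape)
```

Because the branch that is not selected never reaches lowering, the CPU-side "Only interpret mode is supported" check is not triggered. The trade-off is that the same jitted function can no longer serve several platforms in one program.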

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.6.0
jaxlib: 0.6.0
numpy:  1.26.4
python: 3.11.11 (main, Dec  3 2024, 17:20:40) [Clang 16.0.0 (clang-1600.0.26.4)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='MacBook-Pro-57.local', release='24.4.0', version='Darwin Kernel Version 24.4.0: Fri Apr 11 18:33:47 PDT 2025; root:xnu-11417.101.15~117/RELEASE_ARM64_T6000', machine='arm64')

Metadata

Labels

bug: Something isn't working
