Skip to content

Commit

Permalink
Async dispatch expensive computations on the JAX CPU backend.
Browse files Browse the repository at this point in the history
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior.

PiperOrigin-RevId: 633264117
  • Loading branch information
yueshengys authored and jax authors committed May 13, 2024
1 parent 54d4072 commit 9e7830d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.29

* Changes
* Computations on the CPU backend may now be dispatched asynchronously in
more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.

## jaxlib 0.4.29

## jax 0.4.28 (May 9, 2024)
Expand Down
4 changes: 1 addition & 3 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,9 @@
'"gloo" or "mpi"'
)

# TODO(yueshengys): turn default back to True after resolving memory increase
# issue.
_CPU_ENABLE_ASYNC_DISPATCH = config.DEFINE_bool(
name="jax_cpu_enable_async_dispatch",
default=False,
default=True,
help="Only applies to non-parallel computations. If False, run computations"
"inline without async dispatch.",
)
Expand Down

0 comments on commit 9e7830d

Please sign in to comment.