Skip to content

Commit

Permalink
Reverts 49bd4d6
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633221195
  • Loading branch information
cheshire authored and jax authors committed May 13, 2024
1 parent e66a234 commit de14e3b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 30 deletions.
86 changes: 57 additions & 29 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,51 +20,64 @@
import itertools
import operator
from typing import Any, Callable, TypeVar
import weakref

import jax
import weakref
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map, tree_flatten_with_path, keystr)
from jax._src.api_util import shaped_abstractify
from jax._src.tree_util import equality_errors
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.api_util import shaped_abstractify
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lax.control_flow.common import (
_abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr,
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
_typecheck_param)
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src import state
from jax._src.state import discharge as state_discharge
from jax._src.numpy.ufuncs import logaddexp
from jax._src.state import discharge as state_discharge
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import equality_errors
from jax._src.typing import Array
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
split_list_checked, unzip2, weakref_lru_cache,
merge_lists)
from jax._src.util import (
merge_lists,
partition_list,
safe_map,
safe_zip,
split_list,
split_list_checked,
unzip2,
weakref_lru_cache,
)
from jax.tree_util import (
keystr,
tree_flatten,
tree_flatten_with_path,
tree_map,
tree_unflatten,
treedef_is_leaf,
)
import numpy as np

from jax._src.lax.control_flow.common import (
_abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr,
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
_typecheck_param)

_map = safe_map
zip = safe_zip

Expand Down Expand Up @@ -2434,16 +2447,31 @@ def register_lowering(fn, platform=None):
mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)),
platform=platform)

# Default for platforms not treated specially below.
register_lowering(partial(associative_scan, reduce_fn))
# On GPU, we choose between window reduction and associative scan
# based on the input size.
for platform in ['cuda', 'rocm']:
if xla_extension_version >= 263:
# In XLA, there's a rewriter for an O(N^2) reduce-window implementation.
for platform in ['cuda', 'rocm', 'tpu']:
register_lowering(
partial(cumred_reduce_window_impl, reduce_window_fn), platform
)

# TODO(https://github.com/llvm/llvm-project/issues/91883): Re-enable rewrite
# for CPU once the vectorizer crash is fixed.
register_lowering(partial(associative_scan, reduce_fn), 'cpu')
else:
# Older XLA versions only have this rewrite for TPU.
register_lowering(
partial(cumred_gpu_impl, reduce_window_fn, reduce_fn), platform)
# On TPU, an implementation using reduce_window is handled specially by the
# compiler and is efficient. On other backends, it is O(n^2).
register_lowering(partial(cumred_reduce_window_impl, reduce_window_fn), 'tpu')
partial(cumred_reduce_window_impl, reduce_window_fn), 'tpu'
)
# Default for platforms not treated specially below.
register_lowering(partial(associative_scan, reduce_fn))

# On GPU, we choose between window reduction and associative scan
# based on the input size.
for platform in ['cuda', 'rocm']:
register_lowering(
partial(cumred_gpu_impl, reduce_window_fn, reduce_fn), platform
)

return reducer_p

cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, windowed_reductions._reduce_window_sum)
Expand Down
2 changes: 1 addition & 1 deletion tests/for_loop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def body(i, refs):

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8,))
np.testing.assert_allclose(cumsum(x), jnp.cumsum(x))
np.testing.assert_allclose(cumsum(x), jnp.cumsum(x), rtol=1e-6)

def for_body_swap(i, refs):
a_ref, b_ref = refs
Expand Down

0 comments on commit de14e3b

Please sign in to comment.