Avoid tuple allreduce lowering of psum on TPUs (#2914)
Tuple-shaped allreduces aren't supported in an XLA:TPU optimization pass (see internal bug), but since our use of them on GPU is due to compiler nondeterminism that isn't present on TPU, it should be fine to avoid this bug by disabling tuple psum on TPU.
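
As a user-level sketch of the behavior this touches (not part of the commit; the function and variable names below are illustrative), a psum over a pytree of differently-typed operands produces one cross-device sum per leaf, and the change only affects how those sums are lowered to XLA on TPU, not their values:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name='i')
def sum_both(x, y):
  # One psum over a tuple of operands with different dtypes. On GPU this may
  # lower to a single tuple-shaped all-reduce (kept there for run-to-run
  # determinism); on CPU and, after this change, on TPU it lowers to one
  # all-reduce per operand. The numerical results are the same either way.
  return jax.lax.psum((x, y), axis_name='i')

n = jax.local_device_count()
sx, sy = sum_both(jnp.arange(n, dtype=jnp.float32),
                  jnp.arange(n, dtype=jnp.int32))
# Every element of sx and sy equals n * (n - 1) / 2, the sum across devices.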
jekbradbury committed May 1, 2020
1 parent 25e8280 commit 279a077
Showing 1 changed file with 5 additions and 3 deletions.
jax/lax/lax_parallel.py: 8 changes (5 additions & 3 deletions)
@@ -271,8 +271,8 @@ def _allreduce_translation_rule(prim, c, val, replica_groups, platform=None):
 
 # psum translation rule has special handling for complex dtypes
 def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
-  if platform == "cpu":
-    return _cpu_psum_translation_rule(c, *args, replica_groups=replica_groups)
+  if platform in ("cpu", "tpu"):
+    return _notuple_psum_translation_rule(c, *args, replica_groups=replica_groups)
 
   # XLA's tuple all-reduce doesn't support different dtypes in the same
   # allreduce. Instead, we perform once all-reduce for each argument input type.
@@ -307,7 +307,9 @@ def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
 # TODO(b/150476027): CPU doesn't support tuple all-reduce correctly. But
 # fortunately we don't really need it in that case because CPU doesn't support
 # cross-task communication either.
-def _cpu_psum_translation_rule(c, *args, replica_groups):
+# TODO(b/155446630): An XLA:TPU optimization pass also doesn't support
+# tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior.
+def _notuple_psum_translation_rule(c, *args, replica_groups):
   def _translate(val):
     psum = partial(_allreduce_translation_rule, lax.add_p, c,
                    replica_groups=replica_groups)
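
The tail of the second hunk is truncated above. Purely as a sketch of the per-argument strategy the renamed rule follows (this is not the actual function body, which is cut off in the diff), the idea is one non-tuple all-reduce per flattened operand instead of a single tuple-shaped all-reduce:

# Conceptual sketch only (not the JAX source; the real rule builds XLA ops
# through the builder `c` via _allreduce_translation_rule). The strategy is:
# emit one non-tuple all-reduce per flattened psum operand, then hand back
# one result per operand, matching what a single tuple all-reduce would give.
def notuple_psum_sketch(allreduce_one, *args):
  return tuple(allreduce_one(arg) for arg in args)

# Toy check with an identity "all-reduce": each operand maps to one result.
assert notuple_psum_sketch(lambda v: v, 1.0, 2) == (1.0, 2)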
