Skip to content

Commit

Permalink
Fix the bug where XLA doesn't provide shardings for all the outputs i…
Browse files Browse the repository at this point in the history
…f all the elements in the output tuple have the same sharding. XLA decides to run the `FusionTupleDeduplicator` to put the sharding on ROOT instead of the tuple.

PiperOrigin-RevId: 477343328
  • Loading branch information
yashk2810 authored and jax authors committed Sep 28, 2022
1 parent c8bff11 commit 933b6a2
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions jax/interpreters/pxla.py
Expand Up @@ -3126,8 +3126,16 @@ def _get_op_sharding_shardings_from_executable(

in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)

return ([OpShardingSharding(device_assignment, i) for i in in_op_shardings],
[OpShardingSharding(device_assignment, o) for o in out_op_shardings])
in_shardings_xla = [OpShardingSharding(device_assignment, i) for i in in_op_shardings]
out_shardings_xla = [OpShardingSharding(device_assignment, o) for o in out_op_shardings]
# This condition happens when all the elements in the output tuple have the
# same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
# put the sharding on ROOT instead of the tuple.
# TODO(b/245667823): Remove this when XLA fixes this.
if len(out_shardings_xla) == 1 and len(out_shardings_xla) < num_out_avals:
out_shardings_xla = out_shardings_xla * num_out_avals
return in_shardings_xla, out_shardings_xla



# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
Expand Down

0 comments on commit 933b6a2

Please sign in to comment.