Skip to content

Commit

Permalink
Use sharded shape to compute aliasing.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 433762389
  • Loading branch information
jax authors committed Mar 10, 2022
1 parent 97fbb3a commit 3c49cb5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
41 changes: 39 additions & 2 deletions jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,31 @@ def flatten_lowering_ir_args(
_module_unique_id = itertools.count()
_module_name_regex = re.compile(r"[^\w.-]")

def sharded_aval(aval: core.ShapedArray,
sharding: Optional[xc.OpSharding]) -> core.ShapedArray:
"""Returns the new aval sharded based on sharding proto."""
if sharding is None:
return aval

if (sharding.type == xc.OpSharding.Type.REPLICATED or
sharding.type == xc.OpSharding.Type.MANUAL):
return aval

sharded_shape = []
tile_rank = len(sharding.tile_assignment_dimensions)
if sharding.replicate_on_last_tile_dim:
tile_rank -= 1
if sharding.last_tile_dims:
tile_rank -= len(sharding.last_tile_dims)
if tile_rank == 0:
return aval

for i in range(tile_rank):
partitions = sharding.tile_assignment_dimensions[i]
assert partitions > 0
sharded_shape.append((aval.shape[i] + partitions - 1) // partitions)
return aval.update(tuple(sharded_shape))

def lower_jaxpr_to_module(
module_name: str, jaxpr: core.ClosedJaxpr, platform: str,
axis_context: AxisContext,
Expand All @@ -426,13 +451,25 @@ def lower_jaxpr_to_module(
Handles the quirks of the argument/return value passing conventions of the
runtime."""
input_output_aliases = None
in_avals = jaxpr.in_avals
if arg_shardings is not None:
in_avals = [
sharded_aval(in_aval, in_sharding)
for in_aval, in_sharding in zip(in_avals, arg_shardings)
]
out_avals = jaxpr.out_avals
if result_shardings is not None:
out_avals = [
sharded_aval(out_aval, out_sharding)
for out_aval, out_sharding in zip(out_avals, result_shardings)
]
platforms_with_donation = ("gpu", "tpu")
if platform in platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
jaxpr.in_avals, jaxpr.out_avals, donated_args)
in_avals, out_avals, donated_args)
if any(donated_args):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(a) for a, d in zip(jaxpr.in_avals, donated_args)
unused_donations = [str(a) for a, d in zip(in_avals, donated_args)
if d]
msg = "See an explanation at https://jax.readthedocs.io/en/latest/notebooks/faq.html#buffer-donation."
if platform not in platforms_with_donation:
Expand Down
8 changes: 1 addition & 7 deletions tests/xmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,7 @@ def testBufferDonationNamedShape(self):
axis_resources=dict(axis_resources))
x = shard(jnp.zeros((4, 5)))
f(x)
if isinstance(self, SPMDTestMixin):
# The buffer should be deleted when using SPMD partitioner too, if this
# assertion starts failing then congratulations, you've fixed a bug!
# TODO(apaszke,tomhennigan): Xmap is possibly introducing an extra axis.
self.assertNotDeleted(x)
else:
self.assertDeleted(x)
self.assertDeleted(x)

def testControlFlow(self):
x = jnp.arange(5)
Expand Down

0 comments on commit 3c49cb5

Please sign in to comment.