Skip to content

Commit

Permalink
Add pytree support to pjit's with_sharding_constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
skye committed Mar 18, 2021
1 parent d326b07 commit 4a8c129
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
11 changes: 8 additions & 3 deletions jax/experimental/pjit.py
Expand Up @@ -162,9 +162,14 @@ def _pjit_callable(


def with_sharding_constraint(x, axis_resources):
resource_env = maps.thread_resources.env
return sharding_constraint_p.bind(x, axis_resources=axis_resources,
resource_env=resource_env)
x_flat, tree = tree_flatten(x)
axis_resources_flat = tuple(
flatten_axes("with_sharding_constraint axis_resources",
tree, axis_resources))
env = maps.thread_resources.env
outs = [sharding_constraint_p.bind(y, axis_resources=r, resource_env=env)
for y, r in safe_zip(x_flat, axis_resources_flat)]
return tree_unflatten(tree, outs)

def _sharding_constraint_impl(x, axis_resources, resource_env):
# TODO(skye): can we also prevent this from being called in other
Expand Down
26 changes: 26 additions & 0 deletions tests/pjit_test.py
Expand Up @@ -174,6 +174,32 @@ def f(x):
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())

@with_mesh([('x', 2), ('y', 1)])
def testShardingConstraintPyTree(self):
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
def f(x):
x = with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
x = x.copy()
x[0]["a"] *= 2
return x

shape = (8, 8)
v = np.arange(prod(shape)).reshape(shape)
x = [{"a": v, "b": v * 2}, v * 3]
actual = f(x)

expected = x.copy()
expected[0]["a"] *= 2
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertLen(actual[0]["a"].device_buffers, 2)

hlo = jax.xla_computation(f)(x)
# Annotations from with_sharding_constraint
self.assertIn("sharding={devices=[2,1]0,1}", hlo.as_hlo_text())
self.assertIn("sharding={devices=[1,2]0,1}", hlo.as_hlo_text())
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())


# TODO(skye): add more unit tests once API is more finalized

Expand Down

0 comments on commit 4a8c129

Please sign in to comment.