diff --git a/tests/pjit_test.py b/tests/pjit_test.py index f7ff209bde39..73642ee1db6e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -486,30 +486,6 @@ def f(x): # Annotation from pjit self.assertIn('sharding = "{replicated}"', str(hlo)) - @jtu.with_mesh([('x', 2), ('y', 1)]) - def testShardingConstraint(self): - @partial(pjit, in_shardings=None, out_shardings=None) - def f(x): - y = x + 1 - y = with_sharding_constraint(y, P('x', 'y')) - return y * 2 - - shape = (8, 8) - x = np.arange(math.prod(shape)).reshape(shape) - expected = (x + 1) * 2 - actual = f(x) - self.assertAllClose(actual, expected, check_dtypes=False) - _check_instance(self, actual) - self.assertLen(actual.addressable_shards, 2) - self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected, - check_dtypes=False) - - hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo") - # Annotation from with_sharding_constraint - self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text()) - # Annotation from pjit - self.assertIn("sharding={replicated}", hlo.as_hlo_text()) - def testShardingConstraintWithArray(self): mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) @@ -562,60 +538,27 @@ def f(x): # Annotation from pjit self.assertIn("sharding={replicated}", hlo.as_hlo_text()) - @jtu.with_mesh([('x', 2), ('y', 1)]) - def testShardingConstraintPyTree(self): - @partial(pjit, in_shardings=None, out_shardings=None) - def f(x): - x = jax.lax.with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')]) - x = x.copy() - x[0]["a"] *= 2 - return x - - shape = (8, 8) - v = np.arange(math.prod(shape)).reshape(shape) - x = [{"a": v, "b": v * 2}, v * 3] - actual = f(x) - - expected = x.copy() - expected[0]["a"] *= 2 - self.assertAllClose(actual, expected, check_dtypes=False) - self.assertLen(actual[0]["a"].addressable_shards, 2) - - hlo = f.lower(x).compiler_ir(dialect="hlo") - # Annotations from with_sharding_constraint - self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text()) - self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text()) - # Annotation from pjit - self.assertIn("sharding={replicated}", hlo.as_hlo_text()) - def testShardingConstraintPyTreeWithArray(self): mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) - s = NamedSharding(mesh, P(None)) - @partial(pjit, in_shardings=s, out_shardings=s) + @jax.jit def f(x): - x = with_sharding_constraint(x, [ - NamedSharding(mesh, P('x', 'y')), - NamedSharding(mesh, P('y', 'x')) - ]) - x = x.copy() - x[0]["a"] *= 2 - return x + return with_sharding_constraint(x, NamedSharding(mesh, P('x', 'y'))) shape = (8, 8) v = np.arange(math.prod(shape)).reshape(shape) - x = [{"a": v, "b": v * 2}, v * 3] - actual = f(x) + x = [v, v * 2] + out = f(x) - expected = x.copy() - expected[0]["a"] *= 2 - self.assertAllClose(actual, expected, check_dtypes=False) - self.assertLen(actual[0]["a"].addressable_shards, 2) + self.assertArraysEqual(out[0], v) + self.assertArraysEqual(out[1], v * 2) + self.assertLen(out[0].addressable_shards, 2) + self.assertLen(out[1].addressable_shards, 2) hlo = f.lower(x).compiler_ir(dialect="hlo") # Annotations from with_sharding_constraint self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text()) - self.assertIn('sharding={devices=[1,2]<=[2]}', hlo.as_hlo_text()) + self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text()) # Annotation from pjit self.assertIn("sharding={replicated}", hlo.as_hlo_text())