Skip to content

Commit

Permalink
Fix the pjit flakey test. The test was weirdly written in the first p…
Browse files Browse the repository at this point in the history
…lace. The current suspicion is that `x.copy()` inside jit made it flakey. Also delete some duplicate tests from the time of migrating to jax.Array

PiperOrigin-RevId: 590692653
  • Loading branch information
yashk2810 authored and jax authors committed Dec 13, 2023
1 parent 2e63352 commit e888806
Showing 1 changed file with 9 additions and 66 deletions.
75 changes: 9 additions & 66 deletions tests/pjit_test.py
Expand Up @@ -486,30 +486,6 @@ def f(x):
# Annotation from pjit
self.assertIn('sharding = "{replicated}"', str(hlo))

@jtu.with_mesh([('x', 2), ('y', 1)])
def testShardingConstraint(self):
@partial(pjit, in_shardings=None, out_shardings=None)
def f(x):
y = x + 1
y = with_sharding_constraint(y, P('x', 'y'))
return y * 2

shape = (8, 8)
x = np.arange(math.prod(shape)).reshape(shape)
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
_check_instance(self, actual)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected,
check_dtypes=False)

hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
# Annotation from with_sharding_constraint
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())

def testShardingConstraintWithArray(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
s = NamedSharding(mesh, P(None))
Expand Down Expand Up @@ -562,60 +538,27 @@ def f(x):
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())

@jtu.with_mesh([('x', 2), ('y', 1)])
def testShardingConstraintPyTree(self):
@partial(pjit, in_shardings=None, out_shardings=None)
def f(x):
x = jax.lax.with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
x = x.copy()
x[0]["a"] *= 2
return x

shape = (8, 8)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{"a": v, "b": v * 2}, v * 3]
actual = f(x)

expected = x.copy()
expected[0]["a"] *= 2
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertLen(actual[0]["a"].addressable_shards, 2)

hlo = f.lower(x).compiler_ir(dialect="hlo")
# Annotations from with_sharding_constraint
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())

def testShardingConstraintPyTreeWithArray(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
s = NamedSharding(mesh, P(None))

@partial(pjit, in_shardings=s, out_shardings=s)
@jax.jit
def f(x):
x = with_sharding_constraint(x, [
NamedSharding(mesh, P('x', 'y')),
NamedSharding(mesh, P('y', 'x'))
])
x = x.copy()
x[0]["a"] *= 2
return x
return with_sharding_constraint(x, NamedSharding(mesh, P('x', 'y')))

shape = (8, 8)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{"a": v, "b": v * 2}, v * 3]
actual = f(x)
x = [v, v * 2]
out = f(x)

expected = x.copy()
expected[0]["a"] *= 2
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertLen(actual[0]["a"].addressable_shards, 2)
self.assertArraysEqual(out[0], v)
self.assertArraysEqual(out[1], v * 2)
self.assertLen(out[0].addressable_shards, 2)
self.assertLen(out[1].addressable_shards, 2)

hlo = f.lower(x).compiler_ir(dialect="hlo")
# Annotations from with_sharding_constraint
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
self.assertIn('sharding={devices=[1,2]<=[2]}', hlo.as_hlo_text())
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())

Expand Down

0 comments on commit e888806

Please sign in to comment.