Skip to content

Commit

Permalink
fix vmap-of-pmap mapped_invars logic bug
Browse files Browse the repository at this point in the history
fixes google#3399

This crept in via google#1959, but more importantly it shows we don't have
good test coverage here!
  • Loading branch information
mattjj committed Jun 14, 2020
1 parent b2105ab commit 29fa935
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
8 changes: 5 additions & 3 deletions jax/interpreters/batching.py
Expand Up @@ -161,10 +161,12 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
if all(dim is not_mapped for dim in dims):
return map_primitive.bind(f, *vals, **params)
else:
mapped_invars = params['mapped_invars']
size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
vals = [moveaxis(x, d, 1) if d is not not_mapped and d != 1 else x
for x, d in zip(vals, dims)]
dims = tuple(not_mapped if d is not_mapped else 0 for d in dims)
vals = [moveaxis(x, d, 1) if d == 0 and mapped_invar else x
for x, d, mapped_invar in zip(vals, dims, mapped_invars)]
dims = tuple(not_mapped if d is not_mapped else max(0, d - mapped_invar)
for d, mapped_invar in zip(dims, mapped_invars))
f, dims_out = batch_subtrace(f, self.master, dims)
vals_out = map_primitive.bind(f, *vals, **params)
dims_out = tuple(d + 1 if d is not not_mapped else d for d in dims_out())
Expand Down
24 changes: 23 additions & 1 deletion tests/pmap_test.py
Expand Up @@ -854,6 +854,29 @@ def s(keys):
ans = s(keys) # doesn't crash
self.assertEqual(ans.shape, (13, N_DEVICES))

def testVmapOfPmap3(self):
# https://github.com/google/jax/issues/3399
device_count = xla_bridge.device_count()
if device_count < 2:
raise SkipTest("test requires at least two devices")

def map_version(qs, pts):
return jax.lax.map(lambda x: func(x, pts), qs)

def vmap_version(qs, pts):
return jax.vmap(func, in_axes=(0, None))(qs, pts)

def func(q, pts):
q_from_pmap = jax.pmap(lambda x, y: y, in_axes=(0, None))(pts, q)
return q, q_from_pmap

pts = jnp.ones(device_count)
qs = jnp.asarray(((0,0), (3,3), (2,2)))

_, expected = map_version(qs, pts)
_, ans = vmap_version(qs, pts)
self.assertAllClose(ans, expected, check_dtypes=False)

def testVmapOfPmapNonLeadingAxis(self):
device_count = xla_bridge.device_count()
f0 = lambda x: x
Expand Down Expand Up @@ -1210,7 +1233,6 @@ def testPsumOnBooleanDtype(self):
out = pmap(lambda x: jax.lax.pmean(x, 'i'), 'i')(x)
self.assertEqual(list(out), [1])


class PmapWithDevicesTest(jtu.JaxTestCase):

def testAllDevices(self):
Expand Down

0 comments on commit 29fa935

Please sign in to comment.