Skip to content

Commit

Permalink
Rollforward of Add a fastpath to pmap_lib for sharding np.ndarray dir…
Browse files Browse the repository at this point in the history
…ectly in c++.

py::array::ensure(arg) was not a strict enough check and scalars were matching.
PiperOrigin-RevId: 513322424
  • Loading branch information
pschuh authored and jax authors committed Mar 1, 2023
1 parent c73cc49 commit 3abae68
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2115,6 +2115,12 @@ def test_cache_uses_jax_key(self):
pmaped_f(inputs)
self.assertEqual(pmaped_f._cache_size, 1)

def test_constants_fallback(self):
fn = pmap(lambda x, y: x + y, in_axes=(0, None))

for _ in range(2):
fn(np.zeros((jax.device_count(), 5), dtype=np.float32), 2.0)


@jtu.pytest_mark_if_available('multiaccelerator')
class VmapOfPmapTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 3abae68

Please sign in to comment.