Skip to content

Commit

Permalink
fix benchmark sums (#4329)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 18, 2020
1 parent 2911bcd commit 6a89f60
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions benchmarks/pmap_benchmark.py
Expand Up @@ -39,7 +39,7 @@ def pmap_shard_sharded_device_array_benchmark():
"""

def get_benchmark_fn(nargs, nshards):
pmap_fn = pmap(lambda *args: jnp.sum(args))
pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
shape = (nshards, 4)
args = [np.random.random(shape) for _ in range(nargs)]
sharded_args = pmap(lambda x: x)(args)
Expand Down Expand Up @@ -69,7 +69,7 @@ def pmap_shard_device_array_benchmark():
"""

def get_benchmark_fn(nargs, nshards):
pmap_fn = pmap(lambda *args: jnp.sum(args))
pmap_fn = pmap(lambda *args: jnp.sum(jnp.array(args)))
shape = (nshards, 4)
args = [jnp.array(np.random.random(shape)) for _ in range(nargs)]
assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)
Expand Down

0 comments on commit 6a89f60

Please sign in to comment.