Skip to content

Commit

Permalink
Add small and big matmul to api_benchmarks.
Browse files Browse the repository at this point in the history
name                                  cpu/op
jit_small_matmul                      2.96µs ± 2%
jit_big_matmul                        22.1µs ±21%

name                                  time/op

jit_small_matmul                      2.96µs ± 2%
jit_big_matmul                        22.7µs ±21%

PiperOrigin-RevId: 435453853
  • Loading branch information
zhangqiaorjc authored and jax authors committed Mar 17, 2022
1 parent 53f52cb commit 5d7f639
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions benchmarks/api_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,30 @@ def jit_simple(state):
f(a, b).block_until_ready()


@google_benchmark.register
def jit_small_matmul(state):
x = np.random.uniform(size=(2, 2)).astype(np.float32)
x = jax.device_put(x)

f = jax.jit(lambda x: jnp.dot(x, x))
f(x).block_until_ready()

while state:
f(x).block_until_ready()


@google_benchmark.register
def jit_big_matmul(state):
x = np.random.uniform(size=(100, 100)).astype(np.float32)
x = jax.device_put(x)

f = jax.jit(lambda x: jnp.dot(x, x))
f(x).block_until_ready()

while state:
f(x).block_until_ready()


def jit_simple_many_args_dispatch(n, state):
args = [jax.device_put(i) for i in range(n)]
f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
Expand Down

0 comments on commit 5d7f639

Please sign in to comment.