In [1]:
import jax
import jax.numpy as jnp

@jax.jit
def power_scale(value, scale):
    """Raise ``value`` elementwise to the power ``scale`` (an exponent, not a multiplier).

    Baseline version: the power is always evaluated, even when ``scale`` is 1.
    """
    powered = jnp.power(value, scale)
    return powered

In [2]:
# Benchmark input: 10,000 evenly spaced floats 0.0, 1.0, ..., 9999.0.
value = jnp.arange(10_000.0)
# Exponents to time against; scale_A == 1 is the trivial (identity) case.
scale_A, scale_B, scale_C, scale_D = (jnp.array(s) for s in (1.0, 3.0, 5.0, 0.5))

In [3]:
# Sanity check: with scale 1 the output equals the input.
# Because power_scale is @jax.jit-decorated, this first call also compiles it for
# these argument shapes/dtypes, so the %timeit cells below presumably measure
# execution rather than compilation. Keep the positional call form: it should match
# the calls being timed.
power_scale(value, scale_A)

Array([0.000e+00, 1.000e+00, 2.000e+00, ..., 9.997e+03, 9.998e+03,
       9.999e+03], dtype=float32)

In [4]:
# Time the unconditional power with scale 1. JAX dispatches work asynchronously,
# so block_until_ready() makes each loop wait for the result to be computed.
%timeit power_scale(value, scale_A).block_until_ready()

13.9 µs ± 319 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [5]:
# Same benchmark with scale 3.
%timeit power_scale(value, scale_B).block_until_ready()

41.1 µs ± 46.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
# Same benchmark with scale 5.
%timeit power_scale(value, scale_C).block_until_ready()

41 µs ± 99 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
# Same benchmark with a fractional exponent (scale 0.5).
%timeit power_scale(value, scale_D).block_until_ready()

41.8 µs ± 544 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
# Repeat of the scale_A benchmark from In [4], presumably to confirm that timing
# is reproducible after the other exponents have been run.
%timeit power_scale(value, scale_A).block_until_ready()

13.8 µs ± 105 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [10]:
from jax import lax

@jax.jit
def power_scale(value, scale):
    """Return ``value ** scale``, skipping the power when every ``scale`` entry is 1.

    ``lax.cond`` picks a branch at run time, so the cheap identity branch is
    taken whenever ``scale`` is all ones. The unconditional version always
    evaluated the power.

    ``lax.cond`` needs a scalar predicate, and both branches must return the
    same shape and dtype. The predicate therefore reduces with ``jnp.all``, and
    the identity branch is broadcast and cast to exactly what ``value ** scale``
    would produce. Integer ``value`` with float ``scale``, and non-scalar
    ``scale``, thus keep working as they did with the unconditional version.

    Args:
        value: array (or scalar) to raise to a power.
        scale: exponent; broadcast-compatible with ``value``.

    Returns:
        An array equal to ``value ** scale``.
    """
    value = jnp.asarray(value)
    scale = jnp.asarray(scale)
    # Shapes are static under jit, so these are computed at trace time.
    out_shape = jnp.broadcast_shapes(value.shape, scale.shape)
    out_dtype = jnp.result_type(value, scale)
    return lax.cond(
        jnp.all(scale == 1),
        # Identity branch: must match the power branch's shape and dtype exactly.
        lambda: jnp.broadcast_to(value, out_shape).astype(out_dtype),
        lambda: value ** scale,
    )



In [11]:
# Re-time scale 1 with the lax.cond version, which takes the identity branch
# instead of evaluating the power.
%timeit power_scale(value, scale_A).block_until_ready()

3.49 µs ± 68.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [12]:
# Scale 3 still takes the power branch of the lax.cond version; compare with In [5].
%timeit power_scale(value, scale_B).block_until_ready()

41.2 µs ± 86.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
