# Jax musings


In [1]:
import jax
import jax.numpy as jnp

import numpy as np

# Jax as a numpy replacement
Is JAX faster than NumPy? Can we just use JAX as a drop-in replacement now?

## Init random arrays
Note that JAX doesn't have stateful random number generation: every random call takes an explicit PRNG key.

In [47]:
def create_numpy_array(seed, shape, dtype):
    """Create a numpy random array.

    Args:
        seed: Seed handed to ``np.random.RandomState``.
        shape: Iterable of dimension sizes; unpacked into ``rand``.
        dtype: dtype the sampled values are cast to.

    Returns:
        The sampled array, cast to ``dtype``.
    """
    # A private RandomState keeps the draw independent of numpy's global RNG.
    rng = np.random.RandomState(seed)
    samples = rng.rand(*shape)
    return samples.astype(dtype)

def create_jax_array(seed, shape, dtype):
    """Create a jax random array.

    Args:
        seed: Integer seed used to build the PRNG key.
        shape: Tuple of dimension sizes for the result.
        dtype: dtype the sampled values are cast to.

    Returns:
        An array of shape ``shape`` drawn with ``jax.random.uniform`` and
        cast to ``dtype``.
    """
    # Derive the key from the caller's seed so that different seeds give
    # different arrays and the same seed is reproducible.
    key = jax.random.PRNGKey(seed)
    return jax.random.uniform(key, shape).astype(dtype)

In [41]:
# Build one 1000 x 1000 float32 test array per library, both from seed 42.
jar = create_jax_array(seed=42, shape=(1_000, 1_000), dtype=jnp.float32)
nar = create_numpy_array(seed=42, shape=(1_000, 1_000), dtype=np.float32)

In [54]:
# Time element-wise addition on the jax array; block_until_ready() is called on the result inside the timed statement.
%timeit (jar + 1).block_until_ready()

136 µs ± 5.65 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [55]:
# Time the same element-wise addition on the numpy array.
%timeit (nar + 1)

275 µs ± 885 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Simple tests
Run some very simple linear algebra timing tests, comparing NumPy and JAX on the same operation.

In [60]:
def timeit(func, ar, *args, **kwargs):
    """Time ``func(ar, *args, **kwargs)`` with the ``%timeit`` magic.

    Args:
        func: Callable to benchmark; invoked as ``func(ar, *args, **kwargs)``.
        ar: Array passed as the first argument. When it is a ``jnp.ndarray``
            the timed statement also calls ``block_until_ready()`` on the
            result of ``func``.
        *args: Extra positional arguments forwarded to ``func``.
        **kwargs: Extra keyword arguments forwarded to ``func``.

    Returns:
        The ``best`` attribute of the object returned by ``%timeit -o``.

    Note:
        The body uses the IPython ``%timeit`` magic, so this function only
        works when defined in an IPython/Jupyter session.
    """
    if isinstance(ar, jnp.ndarray): 
        # jax is async: without block_until_ready() the timed statement would
        # not wait for the result, which would flatter jax in the comparison.
        out = %timeit -o func(ar, *args, **kwargs).block_until_ready()
    else:
        out = %timeit -o func(ar, *args, **kwargs)
    
    return out.best

In [61]:
def axpy_self(ar, scalar):
    """Return ``ar * scalar + ar`` (an axpy where x and y are the same array).

    Args:
        ar: Operand used both as the scaled term and as the added term.
        scalar: Multiplier applied to ``ar``.

    Returns:
        The result of ``ar * scalar + ar``.
    """
    scaled = ar * scalar
    return scaled + ar

In [62]:
# Benchmark axpy_self on the numpy array with scalar 1; the cell displays the value returned by timeit.
timeit(axpy_self, nar, 1)

553 µs ± 4.04 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


0.0005484555149996594

In [63]:
# Benchmark axpy_self on the jax array with scalar 1; the cell displays the value returned by timeit.
timeit(axpy_self, jar, 1)

335 µs ± 25.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


0.0003004327329999796

In [30]:
# NOTE(review): this timing does not call block_until_ready() on the jax result.
# Since jax is async, it may not measure the full computation; confirm this is intentional.
%timeit jar + 1

141 µs ± 5.15 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
