In [22]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
from jax._src.api import block_until_ready

In [4]:
# List the devices JAX can dispatch work to (the output below shows a single GPU).
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

## Updating Arrays Functionally

In [15]:
# Arrays are immutable in JAX, so in-place assignment like x[0] = y won't work.
# Instead, x.at[idx].set(value) returns a *new* array with the update applied
# and leaves x itself untouched.
# JAX also has no hidden global RNG state: a PRNG key is passed to
# random.normal() explicitly, which makes the draw reproducible.
key = jax.random.PRNGKey(100)
x = jax.random.normal(key=key, shape=(5,))
print('Original x:', x)
# Bind the result to a new name so the original array stays available and
# visibly unchanged.
x_updated = x.at[0].set(20.0)
print('Updated x:', x_updated)

Original x: [-0.69846076 -0.7054794   1.3151264  -2.0337505  -0.08158301]
Updated x: [20.         -0.7054794   1.3151264  -2.0337505  -0.08158301]


## Speed Test

In [21]:
# Two reproducible 1000x1000 standard-normal operands, each drawn from its own seed.
x = jax.random.normal(jax.random.PRNGKey(100), shape=(1000, 1000))
key = jax.random.PRNGKey(200)  # left bound to the last key used, as before
y = jax.random.normal(key, shape=(1000, 1000))

def numpy_matmul(x, y):
    """Return ``np.dot(x, y)``, i.e. the product computed by NumPy on the host."""
    product = np.dot(x, y)
    return product

def jnp_matmul(x, y):
    """Return ``jnp.dot(x, y)``, i.e. the product computed by JAX on its default device."""
    product = jnp.dot(x, y)
    return product

# Just-in-time (JIT) compilation can be used to speed up execution.
jnp_matmul_compiled = jit(jnp_matmul)

print("np matmul:")
%timeit numpy_matmul(x, y)
print("jnp matmul:")
%timeit jnp_matmul(x, y).block_until_ready()
print("jnp matmul compiled:")
%timeit jnp_matmul_compiled(x, y).block_until_ready()

np matmul:
48.7 ms ± 8.83 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
jnp matmul:
643 µs ± 20.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
jnp matmul compiled:
643 µs ± 4.73 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
