# JAX as accelerated NumPy - Sep 19
### Replicating code from Google's JAX documentation

In [1]:
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax import random

In [2]:
# JAX has no global random state: every sampler takes an explicit PRNG key.
SEED = 0
key = random.PRNGKey(SEED)

# Draw ten standard-normal samples from that key.
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

103 ms ± 793 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

107 ms ± 556 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
# The integers 0 through 9 as a JAX array.
x = jnp.arange(0, 10)

In [9]:
# Bare expression as the cell's last line: the notebook displays the array's repr.
x

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [10]:
long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()

3.95 ms ± 33.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Automatic differentiation with `grad`

In [11]:
def sum_of_squares(x):
    """Return the sum of the squared entries of ``x`` as a scalar.

    The body uses only ``jax.numpy`` operations, so the function can be
    handed to JAX transformations such as ``grad``.
    """
    squares = jnp.power(x, 2)
    return jnp.sum(squares)

In [13]:
# grad() builds a new function that returns the gradient of the scalar output
# with respect to the first argument.
sum_of_squares_dx = grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

# First line: the function value at x. Second line: its gradient at x
# (elementwise 2 * x for a sum of squares).
print(sum_of_squares(x), sum_of_squares_dx(x), sep="\n")

30.0
[2. 4. 6. 8.]
