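## JAX 101
---
This notebook accompanies the *JAX 101* [blog post](link). It walks through JAX's NumPy-like array API, its explicit random-number keys, and the three core transformations: `grad`, `vmap`, and `jit`.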

In [7]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

%matplotlib inline

In [8]:
# The array-creation API is shared: each pair below is the same call made
# once through NumPy and once through jax.numpy.
arange_np, arange_jnp = np.arange(5), jnp.arange(5)

linspace_np, linspace_jnp = (
    np.linspace(-3, 3, 100),
    jnp.linspace(-3, 3, 100),
)

# An explicit dtype is accepted just as in NumPy.
zeros_np, zeros_jnp = (
    np.zeros((10, 10), dtype=np.float16),
    jnp.zeros((10, 10), dtype=jnp.float16),
)

In [None]:
# Compare the default integer dtypes of the two `arange` results; they may
# differ (JAX uses 32-bit types unless its 64-bit mode is enabled).
print(f"NumPy dtype: {arange_np.dtype}", f"Jax dtype: {arange_jnp.dtype}", sep="\n")

In [None]:
# Plot the 100 linspace values (from -3 to 3) against their sample index,
# using the explicit Axes API like the other plotting cells.
_, ax = plt.subplots()
ax.plot(linspace_jnp)
ax.set_title("Simple plot of jnp data")
ax.set_xlabel("index")
ax.set_ylabel("value");

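## Random Numbers

JAX has no hidden global random state. Every random function takes an explicit PRNG key, and the same key always produces the same numbers. To get new numbers, derive fresh keys with `random.split`.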

In [11]:
from jax import random

In [None]:
# Randomness in JAX is driven by an explicit PRNG key built from an integer seed.
key = random.key(seed=21)
print(key)

In [None]:
# These will give the same value! Reusing a key reproduces the same draw.
# `shape` must be a tuple of ints (not a bare int as in some NumPy APIs).
x1 = random.normal(key, shape=(3,))
x2 = random.normal(key, shape=(3,))
print(x1)
print(x2)
assert jnp.array_equal(x1, x2)

In [None]:
# `split` derives new keys from `key`; give each random call its own key so
# the draws differ.
newkey1, newkey2 = random.split(key, num=2)
print(newkey1, newkey2, sep="\n")

In [None]:
# Different keys give different draws. `shape` is passed as a tuple, as
# jax.random expects.
x1 = random.normal(newkey1, shape=(3,))
x2 = random.normal(newkey2, shape=(3,))
print(x1)
print(x2)
assert not jnp.array_equal(x1, x2)

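## `grad`

`grad(fun)` returns a new function that computes the gradient of a scalar-valued `fun` with respect to its first argument. Use `argnums` to pick a different argument.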

In [16]:
from jax import grad

In [17]:
def relu(x):
    """Rectified linear unit: ``max(0, x)`` computed with ``jnp.maximum``."""
    return jnp.maximum(0, x)

# Derivative of relu with respect to its (first and only) argument;
# argnums=0 is grad's default, spelled out here for clarity.
relu_grad = grad(relu, argnums=0)

In [None]:
xs = jnp.linspace(-3, 3, 200)
ys = relu(xs)
# `grad` needs a scalar-valued function, so relu_grad is called one point at a
# time here (the vmap section shows how to batch this instead).
ys_grad = [relu_grad(x) for x in xs]

fig, ax = plt.subplots()
ax.plot(xs, ys, label="relu")
ax.plot(xs, ys_grad, label="relu_grad")
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.set_xlabel("x")

ax.grid(True)
ax.legend();

In [19]:
def f(a, b):
    """Toy two-argument function: f(a, b) = 2*a**3 - b**2."""
    return 2 * a ** 3 - b ** 2

# `argnums` selects the argument to differentiate:
# index 0 gives df/da = 6*a**2, index 1 gives df/db = -2*b.
f_grad_0, f_grad_1 = (grad(f, argnums=i) for i in (0, 1))

In [None]:
# Both inputs sweep the same range point for point, so this plots f and its
# partial derivatives along the diagonal a = b.
xs = jnp.linspace(-3, 3, 200)
ys = jnp.linspace(-3, 3, 200)
ys_orig = f(xs, ys)
ys_grad_0 = [f_grad_0(x, y) for x, y in zip(xs, ys)]
ys_grad_1 = [f_grad_1(x, y) for x, y in zip(xs, ys)]

fig, ax = plt.subplots()
ax.plot(xs, ys_orig, label="f(a, b)")
ax.plot(xs, ys_grad_0, label="∂f/∂a")
ax.plot(xs, ys_grad_1, label="∂f/∂b")
ax.set_xlabel("a = b")

ax.grid(True)
ax.legend();

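## `vmap`

`vmap(fun)` vectorizes `fun` so that it maps over a batch axis (the leading axis by default) without a Python loop.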

In [21]:
from jax import vmap

In [22]:
# Reuse `relu` from the grad section instead of redefining it; a second
# identical `def` would silently shadow the first.
# `grad(relu)` handles one scalar at a time; wrapping it in `vmap` maps it
# over the leading axis of an array input.
relu_vmap_grad = vmap(grad(relu))

In [None]:
xs = jnp.linspace(-3, 3, 200)
ys = relu(xs)
# A single vectorized call now covers the whole batch: no per-point loop.
ys_grad = relu_vmap_grad(xs)

_, ax = plt.subplots()
ax.plot(xs, ys, label="relu")
ax.plot(xs, ys_grad, label="relu_grad")
ax.set(xlim=(-3, 3), ylim=(-3, 3))

ax.grid(True)
ax.legend();

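## `jit`

`jit(fun)` traces `fun` and compiles it with XLA. The first call pays the compilation cost; later calls with the same input shapes and dtypes reuse the compiled code.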

In [24]:
from jax import jit, make_jaxpr

In [25]:
def f(x):
    """Toy workload for jit: sum over all elements of (x + 2)**2 - 4."""
    shifted = x + 2
    return jnp.sum(shifted**2 - 4)

# Compiled version of `f`; it is traced and compiled on its first call.
f_jit = jit(f)

In [None]:
xs = jnp.linspace(-10, 10, 1_000_000)
# warm up jitted function (i.e. it compiles 1st time it runs)
_ = f_jit(xs)

%timeit f(xs)
%timeit f_jit(xs).block_until_ready()

In [None]:
# Trace `f` on `xs` and display its jaxpr: JAX's intermediate representation,
# listing the primitive operations `f` was traced into.
f_jaxpr = make_jaxpr(f)(xs)
f_jaxpr