# `nb02`: JAX: NumPy, but differentiable and on GPUs

![](figures/nb02/jax.png)

JAX is a near drop-in replacement for NumPy that adds automatic differentiation and compiles code with XLA so it runs on both CPUs and GPUs.

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
import matplotlib.pyplot as plt

# Synthetic regression data: points on the line y = 3x - 1 plus Gaussian noise
# (standard deviation 0.1). A seeded Generator makes the data -- and therefore
# the parameters fitted later in the notebook -- reproducible across runs.
rng = np.random.default_rng(0)
xs = rng.normal(size=(100,))
noise = rng.normal(scale=0.1, size=(100,))
ys = xs * 3 - 1 + noise

plt.scatter(xs, ys)
plt.show()

In [None]:
def model(theta, x):
    """Evaluate the linear model ``slope * x + intercept``.

    Args:
        theta: A two-element sequence/array ``(slope, intercept)``; unpacking
            fails if it does not have exactly two elements.
        x: Input value(s); scalars or arrays both work via broadcasting.

    Returns:
        The model prediction(s) for ``x``.
    """
    slope, intercept = theta
    return slope * x + intercept

In [None]:
# Start from an arbitrary initial guess (w = 1, b = 1) and plot that line
# against the data before any training.
theta = jnp.array([1., 1.])
plt.scatter(xs, ys)
plt.plot(xs, model(theta, xs))
# Render explicitly, as in the data cell; this also keeps the list of Line2D
# objects returned by plt.plot from being echoed as the cell's output.
plt.show()

In [None]:
def loss_fn(theta, x, y):
    """Half mean squared error of ``model(theta, x)`` against targets ``y``.

    The 0.5 factor cancels the 2 produced when differentiating the square, so
    the gradient is simply the mean residual-weighted input.

    Returns:
        A scalar ``0.5 * mean((model(theta, x) - y) ** 2)``.
    """
    residual = model(theta, x) - y
    return 0.5 * jnp.mean(residual ** 2)

In [None]:
def update(theta, x, y, lr=0.1):
    """Perform one gradient-descent step on ``loss_fn``.

    Args:
        theta: Current parameters ``(w, b)``.
        x: Inputs.
        y: Targets.
        lr: Learning rate (step size).

    Returns:
        The new parameters ``theta - lr * d(loss_fn)/d(theta)``.
    """
    # jax.grad differentiates with respect to the first argument (theta) by default.
    loss_grad = jax.grad(loss_fn)
    step = lr * loss_grad(theta, x, y)
    return theta - step

In [None]:
# Plain gradient descent for a fixed number of steps.
num_steps = 1000
for _ in range(num_steps):
    theta = update(theta, xs, ys)

plt.scatter(xs, ys)
plt.plot(xs, model(theta, xs))
# Render the fitted line now, as in the data cell, so the figure appears
# before the printed parameters rather than after them.
plt.show()

# The data were generated from y = 3x - 1, so w and b should land close to 3 and -1.
w, b = theta
# float() converts the 0-d JAX arrays to Python floats before formatting.
print(f"w: {float(w):.2f}, b: {float(b):.2f}")