This notebook illustrates some features of the JAX library in the context of a simple linear regression problem. In practice, we could fit this model much more simply with the least squares estimator
$$
\hat{\beta}=(X^T X)^{-1}X^T y,
$$
but here we will instead minimize the mean squared error (MSE) loss by gradient descent.
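For reference, the closed-form estimator can also be computed directly in JAX. This is a minimal sketch, assuming a single feature stored as an `(n, 1)` array as in this notebook; the helper name `least_squares_fit` and the slope and intercept used to generate the data are illustrative. A column of ones is appended to $X$ so that $\hat{\beta}$ holds both the slope $w$ and the intercept $b$, and the normal equations $(X^T X)\hat{\beta} = X^T y$ are solved rather than forming the inverse explicitly.

```python
import jax
import jax.numpy as jnp


def least_squares_fit(xs: jax.Array, ys: jax.Array) -> jax.Array:
    """Hypothetical helper: return [w, b] minimizing the squared error of ys ~ w * xs + b.

    xs and ys are assumed to have shape (n, 1).
    """
    # Append a column of ones so the intercept b becomes an extra coefficient.
    X = jnp.hstack([xs, jnp.ones_like(xs)])
    # Solve (X^T X) beta = X^T y, equivalent to beta = (X^T X)^{-1} X^T y
    # without computing the matrix inverse.
    beta_hat = jnp.linalg.solve(X.T @ X, X.T @ ys)
    return beta_hat.ravel()


# Illustrative noiseless data with slope 2.0 and intercept 0.5.
key = jax.random.key(0)
xs = jax.random.normal(key, shape=(200, 1))
ys = 2.0 * xs + 0.5

beta_hat = least_squares_fit(xs, ys)
print(beta_hat)  # approximately [2.0, 0.5]
```

Solving a linear system is generally cheaper and numerically more stable than computing $(X^T X)^{-1}$; `jnp.linalg.lstsq` is another option.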

In [1]:
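import jax
import jax.numpy as jnp
import jax.random as random
from collections import namedtuple
import time

# Seed from the clock, so each run draws different data; set SEED to a
# fixed integer to reproduce a particular run.
SEED = int(time.time())
print(f"seed is {SEED}")
key = random.key(SEED)

# A namedtuple is a JAX pytree, so jit, vmap and grad accept it directly.
ModelParameters = namedtuple('ModelParameters', 'w b')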

seed is 1692330859


In [2]:
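@jax.jit
def predict(params: ModelParameters, x: jax.Array) -> jax.Array:
    # Prediction for a single example x of shape (1,).
    return params.w.dot(x) + params.b

# Vectorize over a batch: share params (None), map over axis 0 of xs.
vpredict = jax.vmap(predict, in_axes=(None, 0))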
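`jax.vmap` turns `predict`, which is written for a single example, into a batched function. The sketch below uses illustrative parameter values: `in_axes=(None, 0)` means the parameters are shared across the batch (not mapped) while the inputs are mapped along their first axis, so an `(n, 1)` input produces an `(n, 1)` output that matches a hand-written matrix product.

```python
import jax
import jax.numpy as jnp
from collections import namedtuple

ModelParameters = namedtuple('ModelParameters', 'w b')


def predict(params: ModelParameters, x: jax.Array) -> jax.Array:
    # x is one example of shape (1,); w.dot(x) is a scalar, and adding b gives shape (1,).
    return params.w.dot(x) + params.b


# params is not mapped (None); xs is mapped over axis 0.
vpredict = jax.vmap(predict, in_axes=(None, 0))

params = ModelParameters(w=jnp.array([2.0]), b=jnp.array([0.5]))
xs = jnp.arange(4.0).reshape(4, 1)

ys_vmap = vpredict(params, xs)                     # shape (4, 1)
ys_matmul = xs @ params.w[:, None] + params.b      # same result as one matrix product
print(ys_vmap.shape, jnp.allclose(ys_vmap, ys_matmul))
```

Writing the per-example function and letting `vmap` add the batch dimension keeps `predict` simple, at no cost in expressiveness for this model.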

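JAX random number generation differs from NumPy's: there is no hidden global state. Every random function takes an explicit PRNG `key`, and we call `random.split` to derive fresh keys, using each key only once.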
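A small illustration of this key discipline, using an arbitrary seed of 0: reusing a key reproduces exactly the same numbers, so fresh randomness comes from splitting the key and consuming each subkey once.

```python
import jax.numpy as jnp
import jax.random as random

key = random.key(0)
key, subkey = random.split(key)

# Calling a random function twice with the same key returns identical numbers.
a = random.normal(subkey, shape=(3,))
b = random.normal(subkey, shape=(3,))

# Fresh randomness needs fresh keys: split, then use each new subkey once.
key, subkey1, subkey2 = random.split(key, 3)
c = random.normal(subkey1, shape=(3,))
d = random.normal(subkey2, shape=(3,))

print(jnp.array_equal(a, b))  # True
print(jnp.array_equal(c, d))  # False
```

Because the state lives in the key rather than in a global generator, functions that draw random numbers stay pure, which is what lets `jit` and `vmap` transform them safely.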

In [3]:
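# Synthetic inputs and noiseless targets from randomly drawn true parameters.
# Each random call gets its own fresh subkey.
key, subkey = random.split(key)
xs = random.normal(subkey, shape=(200, 1))
key, subkey = random.split(key)
Wtrue = random.normal(subkey, shape=(1,))
key, subkey = random.split(key)
btrue = random.normal(subkey, shape=(1,))
true_params = ModelParameters(Wtrue, btrue)
true_ys = vpredict(true_params, xs)

# Random initial guess for the parameters to be learned.
key, subkey = random.split(key)
W = random.normal(subkey, shape=(1,))
key, subkey = random.split(key)
b = random.normal(subkey, shape=(1,))
params = ModelParameters(W, b)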

Here we define our loss function, the mean of the square of the errors.
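For this one-feature model the loss is $L(w, b) = \frac{1}{n}\sum_{i=1}^{n}(w x_i + b - y_i)^2$, whose gradients are

$$
\frac{\partial L}{\partial w} = \frac{2}{n}\sum_{i=1}^{n}(w x_i + b - y_i)\,x_i,
\qquad
\frac{\partial L}{\partial b} = \frac{2}{n}\sum_{i=1}^{n}(w x_i + b - y_i).
$$

The sketch below checks that `jax.grad` reproduces these formulas. It assumes a self-contained, vectorized `mse` (equivalent to the `vpredict`-based one for a single feature), a hypothetical `analytic_grad` helper, and random illustrative data.

```python
import jax
import jax.numpy as jnp
from collections import namedtuple

ModelParameters = namedtuple('ModelParameters', 'w b')


def mse(params: ModelParameters, xs: jax.Array, ys: jax.Array) -> jax.Array:
    # xs and ys have shape (n, 1); predictions are w * x + b.
    y_hats = xs @ params.w[:, None] + params.b
    return jnp.mean(jnp.square(y_hats - ys))


def analytic_grad(params: ModelParameters, xs: jax.Array, ys: jax.Array) -> ModelParameters:
    residuals = xs @ params.w[:, None] + params.b - ys  # shape (n, 1)
    n = xs.shape[0]
    grad_w = 2.0 / n * (xs.T @ residuals).ravel()       # shape (1,)
    grad_b = 2.0 / n * jnp.sum(residuals, axis=0)       # shape (1,)
    return ModelParameters(grad_w, grad_b)


key_x, key_y = jax.random.split(jax.random.key(0))
xs = jax.random.normal(key_x, shape=(50, 1))
ys = jax.random.normal(key_y, shape=(50, 1))
params = ModelParameters(jnp.array([0.3]), jnp.array([-0.2]))

auto = jax.grad(mse)(params, xs, ys)      # same ModelParameters structure as params
manual = analytic_grad(params, xs, ys)
print(auto)
print(manual)
```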

In [4]:
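@jax.jit
def mse(parameters: ModelParameters, xs: jax.Array, ys: jax.Array) -> jax.Array:
    y_hats = vpredict(parameters, xs)
    return jnp.mean(jnp.square(y_hats - ys))

# jax.grad differentiates with respect to the first argument and returns
# gradients with the same ModelParameters structure as `parameters`.
grad_mse = jax.jit(jax.grad(mse))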

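Now we fit the model with 1000 steps of full-batch gradient descent (learning rate 0.01), printing the parameters every 100 steps.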

In [5]:
lr = 1e-2
for i in range(1000):
    batch_grads = grad_mse(params, xs, true_ys)
    params = ModelParameters(params.w - lr * batch_grads.w, params.b - lr * batch_grads.b)
    if i % 100 == 0:
        print(params)

ModelParameters(w=Array([0.2682777], dtype=float32), b=Array([0.1782908], dtype=float32))
ModelParameters(w=Array([1.8501179], dtype=float32), b=Array([0.6013752], dtype=float32))
ModelParameters(w=Array([2.0439496], dtype=float32), b=Array([0.6323448], dtype=float32))
ModelParameters(w=Array([2.0680373], dtype=float32), b=Array([0.63335055], dtype=float32))
ModelParameters(w=Array([2.0710773], dtype=float32), b=Array([0.6330958], dtype=float32))
ModelParameters(w=Array([2.0714667], dtype=float32), b=Array([0.63301265], dtype=float32))
ModelParameters(w=Array([2.0715175], dtype=float32), b=Array([0.6329955], dtype=float32))
ModelParameters(w=Array([2.0715194], dtype=float32), b=Array([0.6329933], dtype=float32))
ModelParameters(w=Array([2.0715194], dtype=float32), b=Array([0.6329933], dtype=float32))
ModelParameters(w=Array([2.0715194], dtype=float32), b=Array([0.6329933], dtype=float32))
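The loop above rebuilds `ModelParameters` by hand and evaluates only the gradient. A common alternative, sketched here under the same single-feature setup, is to jit a whole update step, use `jax.value_and_grad` to get the loss together with the gradients, and apply the update to every parameter with `jax.tree_util.tree_map`. The vectorized `mse`, the `train_step` helper, and the data (slope 2.0, intercept 0.5) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from collections import namedtuple

ModelParameters = namedtuple('ModelParameters', 'w b')


def mse(params, xs, ys):
    y_hats = xs @ params.w[:, None] + params.b
    return jnp.mean(jnp.square(y_hats - ys))


@jax.jit
def train_step(params, xs, ys, lr):
    # One full-batch gradient descent step; also returns the loss before the update.
    loss, grads = jax.value_and_grad(mse)(params, xs, ys)
    # tree_map updates every leaf (w and b) and keeps the ModelParameters type.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


key = jax.random.key(0)
xs = jax.random.normal(key, shape=(200, 1))
ys = 2.0 * xs + 0.5
params = ModelParameters(jnp.zeros(1), jnp.zeros(1))

losses = []
for i in range(1000):
    params, loss = train_step(params, xs, ys, 1e-2)
    losses.append(float(loss))
print(params)
```

Because `tree_map` works on any pytree, the same update line keeps working if more parameters are added to `ModelParameters`.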


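Finally, let's compare the true parameters to the learned ones. Because the targets were generated without noise, the MSE is minimized exactly at the true parameters, so gradient descent should recover them up to floating-point tolerance.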

In [6]:
assert jnp.allclose(true_params.w, params.w)
assert jnp.allclose(true_params.b, params.b)