In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

In [None]:
# Root PRNG key: a fixed seed makes every random draw below reproducible.
key = jr.key(42)
# Derive two subkeys from the root: key_e (measurement noise) and key_p (parameter initialization).
key_e, key_p = jr.split(key)

## Linear regression

Linear function definition:

$$ f(p, x) : \mathbb{R}^2 \times \mathbb{R} \mapsto \mathbb{R} = p_1 x + p_2 $$

In [None]:
def f(p, x):
    """Evaluate the linear model ``p[0] * x + p[1]``.

    Args:
        p: indexable pair of parameters; ``p[0]`` is the slope, ``p[1]`` the intercept.
        x: input value; a scalar or an array (the expression is applied elementwise).

    Returns:
        The model output, with the same shape as ``x``.
    """
    slope, intercept = p[0], p[1]
    return slope * x + intercept

Apply function ``f`` with "true" parameter $p^o = [1.0\; 2.0]$ and $x=0.5$:

In [None]:
# "True" parameters p^o = [slope, intercept] and a single scalar test input.
p_o = jnp.array([1.0, 2.0])
x = jnp.array(0.5)

# Evaluate the model once; the bare name on the last line displays the result.
y = f(p_o, x)
y

Apply function ``f`` with "true" parameter $p^o = [1.0\; 2.0]$ and 100 linearly spaced points in $[-2, 2]$. Add Gaussian noise with zero mean and standard deviation 0.1 to the output.

In [None]:
# Noisy training set: N inputs evenly spaced on [-2, 2]; targets come from the
# true model plus zero-mean Gaussian noise with standard deviation 0.1.
N = 100
x = jnp.linspace(-2, 2, num=N)
y = f(p_o, x) + 0.1 * jr.normal(key_e, shape=(N,))

# Scatter plot of the generated samples (trailing ';' suppresses the text repr).
plt.plot(x, y, "k*", label="y")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Training dataset");

Function `f` works correctly both with scalar and vector input `x`.

## Loss definition

The Mean Squared Error (MSE) loss is:
$$ \mathcal{L}(p, y, x) : \mathbb{R}^{n_p} \times \mathbb{R}^N \times \mathbb{R}^N \mapsto \mathbb{R}
= \frac{1}{N}\sum_{i=1}^{N} \big(y_i - f(p, x_i)\big)^2
$$

In [None]:
def loss_fn(p, y, x):
    """Mean squared error between the targets ``y`` and the predictions ``f(p, x)``.

    Args:
        p: model parameters passed on to ``f``.
        y: observed output(s).
        x: input(s); a single data point (scalar) or a batch (vector).

    Returns:
        Scalar mean of the squared residuals ``(y - f(p, x)) ** 2``.
    """
    residual = y - f(p, x)
    return jnp.mean(residual ** 2)

In [None]:
# Random initial guess for the two parameters: a standard-normal draw using the subkey key_p.
p_hat = jr.normal(key_p, (2,))
p_hat

In [None]:
# MSE of the model at the randomly drawn initial parameters p_hat.
loss_fn(p_hat, y, x)

In [None]:
loss_fn(p_o, y, x) # loss close to 0.01 (noise variance: std^2 = 0.1^2) at true parameters, as expected!

## Automatic differentiation in JAX

Compute the gradient:

$$
\nabla_{1} \mathcal{L}(p, y, x): \mathbb{R}^{n_p} \times \mathbb{R}^N \times \mathbb{R}^N \mapsto \mathbb{R}^{n_p},
$$
i.e. the derivative of $\mathcal{L}$ with respect to its first argument: $p$.

In [None]:
# jax.grad returns a new function that evaluates the gradient of loss_fn
# with respect to its first positional argument (index 0), i.e. p.
grad_fn = jax.grad(loss_fn, argnums=0)

In [None]:
# Gradient of the MSE loss with respect to p, evaluated at the random initial guess p_hat.
grad_fn(p_hat, y, x)

In [None]:
grad_fn(p_o, y, x) # gradient close to (but not exactly) 0 at the true parameters: the noise in y slightly shifts the minimizer

## Fitting a model with JAX

### Plain gradient descent by hand

In [None]:
# Re-derive the starting point from key_p rather than copying the current p_hat.
# JAX random draws are a pure function of the key, so this reproduces the initial
# guess exactly and makes the cell idempotent: re-running it restarts the
# optimization from the same point, instead of continuing from the trained
# parameters and silently overwriting p_init with them.
p_init = jr.normal(key_p, (2,))
p_hat = p_init

lr = 1e-2  # learning rate

# Plain gradient descent, 200 steps: p <- p - lr * dL/dp
for step in range(200):
    g = grad_fn(p_hat, y, x)
    p_hat = p_hat - lr * g

In [None]:
# Compare the model before and after training against the noisy data.
plt.figure()
plt.plot(x, y, "k*", label="y")
plt.plot(x, f(p_hat, x), "g", label="$f(p^{200}, x)$")
# The initial parameters are iterate 0: 200 updates later we reach p^{200}.
plt.plot(x, f(p_init, x), "b", label="$f(p^{0}, x)$")
plt.title("Linear fit before and after gradient descent")
plt.xlabel("x")
plt.ylabel("y")
plt.legend();