In [None]:
## Non-linear regression with feedforward networks

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

In [None]:
# One master PRNG key, split once into independent subkeys: one for sampling
# the inputs and one for each parameter array of the network.
key = jr.key(42)
key_x, key_W1, key_b1, key_W2, key_b2 = jr.split(key, num=5)

### Data Generation

We generate synthetic training data from the function:

$$y = 2 \sin(3x) - 3 \cos(7x) $$

In [None]:
def f(x):
    """Ground-truth target y = 2 sin(3x) - 3 cos(7x), applied elementwise to x."""
    sine_part = 2 * jnp.sin(3 * x)
    cosine_part = 3 * jnp.cos(7 * x)
    return sine_part - cosine_part

In [None]:
# Sampling interval [a, b], dataset size and observation-noise level.
a = -1  # lower limit of x
b = 1  # upper limit of x
n_samples = 200  # number of data points
sigma_e = 0.1  # noise standard deviation

# x ~ U(a, b), drawn directly with shape (n_samples, 1): one scalar feature per row.
x_train = a + jr.uniform(key_x, shape=(n_samples, 1)) * (b - a)

# Noisy observations y = f(x) + e with e ~ N(0, sigma_e^2).
# A separate noise key is derived from key_x with fold_in, so x_train is
# exactly the same as without noise.
key_e = jr.fold_in(key_x, 1)
y_train = f(x_train) + sigma_e * jr.normal(key_e, shape=x_train.shape)

# Scatter of the raw training samples (explicit Figure/Axes interface).
fig, ax = plt.subplots()
ax.plot(x_train, y_train, "*k")
ax.set(title="Training data", xlabel="x", ylabel="y");

It is always good practice to check data types and shapes: doing so saves roughly 80% of debugging time!

In [None]:
# Both arrays should be (n_samples, 1): one row per sample.
jnp.shape(x_train), jnp.shape(y_train)

In [None]:
# Inputs and targets should share a floating-point dtype.
(x_train.dtype, y_train.dtype)

### Model Definition

Define the feedforward neural network with one hidden layer:

  $$\hat y = W_2 \tanh (W_1 x + b_1) + b_2 $$

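The parameters to be tuned are:

  $$p = \mathrm{vec}(W_1, b_1, W_2, b_2),$$
  $$W_1 \in \mathbb{R}^{n_h \times n_x}, \quad b_1 \in \mathbb{R}^{n_h}, \quad W_2 \in \mathbb{R}^{n_y \times n_h}, \quad b_2 \in \mathbb{R}^{n_y},$$

where $n_x$ is the input dimension, $n_h$ is the number of hidden units, and $n_y$ is the output dimension.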

In [None]:
# Layer sizes: input dimension, output dimension, number of hidden units.
nx = 1
ny = 1
nh = 16

# All trainable parameters live in one dictionary (a JAX pytree). Each array
# is drawn from a standard normal distribution with its own PRNG subkey.
p_hat = dict(
    W1=jr.normal(key_W1, shape=(nh, nx)),
    b1=jr.normal(key_b1, shape=(nh,)),
    W2=jr.normal(key_W2, shape=(ny, nh)),
    b2=jr.normal(key_b2, shape=(ny,)),
)

p_hat

In [None]:
# The network is a pure function of (parameters, single input).

def nn(p, x):
    """One-hidden-layer feedforward network for a single (unbatched) input.

    Computes W2 @ tanh(W1 @ x + b1) + b2.

    Args:
        p: dict with keys "W1" (nh, nx), "b1" (nh,), "W2" (ny, nh), "b2" (ny,).
        x: one input vector of shape (nx,).

    Returns:
        The output vector, shape (ny,).
    """
    hidden = jnp.tanh(jnp.matmul(p["W1"], x) + p["b1"])
    return jnp.matmul(p["W2"], hidden) + p["b2"]

In [None]:
# Evaluate the untrained network on one sample. A row of x_train has shape (1,),
# i.e. (nx,), which is exactly what nn expects.
sample_idx = 10
nn(p_hat, x_train[sample_idx])

In [None]:
# Calling nn on the whole batch fails: W1 has shape (nh, nx) while x_train has
# shape (n_samples, nx), so the matrix product W1 @ x_train has mismatched inner
# dimensions. We catch and show the error; the fix is to vectorize nn (next cell).
try:
    nn(p_hat, x_train)
except (TypeError, ValueError) as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
# Vectorize nn over a batch. The parameters (first argument) are shared and not
# mapped (None). The inputs (second argument) are mapped along their leading axis (0).
batched_nn = jax.vmap(nn, in_axes=(None, 0), out_axes=0)

In [None]:
# vmap stacks the outputs along the leading axis (0) as well, so y has one row
# per training sample. Just what we want!
y = batched_nn(p_hat, x_train)
jnp.shape(y)

In [None]:
# Check that the vmapped result for sample 10 matches the unbatched call, if you
# don't believe it! The tolerances leave room for float32 rounding differences.
y_single = nn(p_hat, x_train[10])
assert jnp.allclose(y_single, y[10], rtol=1e-5, atol=1e-5), "batched_nn disagrees with nn"
y_single, y[10]

### Model Training

From here on, the procedure is essentially the same as for linear regression: define a mean-squared-error loss and minimize it by gradient descent.

In [None]:
def loss_fn(p, y, x):
    """Mean squared error of the batched network on a dataset.

    Args:
        p: parameter dict with keys "W1", "b1", "W2", "b2".
        y: targets, same shape as batched_nn(p, x) (here (n_samples, ny)).
        x: inputs with the batch along axis 0 (here (n_samples, nx)).

    Returns:
        Scalar mean of the squared residuals y - batched_nn(p, x).
    """
    residual = y - batched_nn(p, x)
    return jnp.mean(jnp.square(residual))


# Differentiate with respect to the parameters only (argument 0), returning both
# the loss and its gradient. JIT-compile the result, since this is the
# compute-intensive part of training.
loss_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=0))

In [None]:
# Keep a reference to the initial parameters for the before/after comparison.
# Training rebinds p_hat to new dicts produced by jax.tree.map (JAX arrays are
# immutable), so p_init keeps pointing at the initial values.
p_init = p_hat

In [None]:
# Full-batch gradient descent. Training always restarts from p_init, so
# re-running this cell is idempotent and LOSS always matches the final p_hat.
lr = 1e-2  # learning rate
n_iter = 10_000  # number of gradient-descent updates
p_hat = p_init
LOSS = []  # LOSS[i] is the loss before update i
for _ in range(n_iter):
    loss_val, grads = loss_grad_fn(p_hat, y_train, x_train)
    # Update every leaf of the parameter pytree: p <- p - lr * dL/dp
    p_hat = jax.tree.map(lambda p, dp: p - lr * dp, p_hat, grads)
    LOSS.append(loss_val)

In [None]:
# Evaluate the models on sorted inputs so the curves are drawn left to right.
x_train_srt = jnp.sort(x_train, axis=0)
n_updates = len(LOSS)  # one loss entry is recorded per gradient-descent update

fig, ax = plt.subplots()
ax.set_title("Model fit")
ax.plot(x_train, y_train, "k*", label="y")
# The superscript counts gradient-descent updates: p^0 is the initialization.
ax.plot(x_train_srt, batched_nn(p_hat, x_train_srt), "g", label=f"$f(p^{{{n_updates}}}, x)$")
ax.plot(x_train_srt, batched_nn(p_init, x_train_srt), "b", label="$f(p^{0}, x)$")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend();

In [None]:
# Training loss per iteration. The MSE falls by orders of magnitude over
# training, so a log-scaled y-axis keeps the late-stage progress visible.
fig, ax = plt.subplots()
ax.set_title("Loss vs. Iteration")
ax.plot(LOSS)
ax.set_yscale("log")
ax.set_xlabel("Iteration (-)")
ax.set_ylabel("Loss");