In [None]:
## Non-linear regression with feedforward networks

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import nonlinear_benchmarks
import optax
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler

In [None]:
# Root PRNG key with a fixed seed, so every run draws the same random numbers
key = jr.key(42)
# Five subkeys: keys[0]..keys[4] each seed one parameter array of the model below
keys = jr.split(key, 5)

In [None]:
# Load the Cascaded Tanks benchmark as a (train, test) pair of records.
# atleast_2d=True presumably returns u and y as 2-D arrays (samples x channels)
# — confirm with the shape check in the next cell.
train, test = nonlinear_benchmarks.Cascaded_Tanks(atleast_2d=True)

In [None]:
# Shapes of the training input (u) and output (y) arrays
train.u.shape, train.y.shape

In [None]:
# Sampling time of the training record; used to build the time axis in the plot below
train.sampling_time

In [None]:
# Outputs (top row) and inputs (bottom row) of the training record (left column)
# and the test record (right column), plotted against time.
fig, ax = plt.subplots(2, 2, figsize=(10, 6))
fig.suptitle('Training (left) and test (right) data')

# Each record gets a time axis built from its OWN sampling time, so the test
# panels stay correct even if the two records were sampled at different rates.
train_t = train.sampling_time * jnp.arange(train.u.shape[0])
test_t = test.sampling_time * jnp.arange(test.u.shape[0])

ax[0, 0].plot(train_t, train.y)
ax[1, 0].plot(train_t, train.u)
ax[0, 1].plot(test_t, test.y)
ax[1, 1].plot(test_t, test.u)

# Label the axes so the figure can be read on its own
ax[0, 0].set_ylabel("y")
ax[1, 0].set_ylabel("u")
ax[1, 0].set_xlabel("Time")
ax[1, 1].set_xlabel("Time");

In [None]:
# Standardize input and output with statistics computed on the training record.
# The fitted scalers are kept so the same mapping can be applied to the test
# input and inverted on the simulated output later on.
scaler_u = StandardScaler().fit(train.u)
u = scaler_u.transform(train.u).astype(jnp.float32)

scaler_y = StandardScaler().fit(train.y)
y = scaler_y.transform(train.y).astype(jnp.float32)

In [None]:
# The scaled arrays have the same shapes as the raw training data...
u.shape, y.shape

In [None]:
# ... but are normalized to zero mean and unit variance (checked here for the input u)
u.mean(), u.std()

In [None]:
# Model dimensions and randomly initialized parameters, gathered in one dictionary

nu = 1   # number of inputs
nx = 2   # number of states
ny = 1   # number of outputs
nh = 16  # number of hidden units

params_init = dict(
    # hidden layer: acts on the nx states and nu inputs stacked in one vector
    W1=jr.normal(keys[0], shape=(nh, nu + nx)),
    b1=jr.normal(keys[1], shape=(nh,)),
    # output layer of the state-update network (nx outputs), scaled down by 1e-3
    W2=jr.normal(keys[2], shape=(nx, nh)) * 1e-3,
    b2=jr.normal(keys[3], shape=(nx,)) * 1e-3,
    # linear map from the nx states to the ny outputs
    C=jr.normal(keys[4], shape=(ny, nx)),
)

In [None]:
# Define the neural network as a function of parameters and inputs

def fg(p, x, u):
    """Advance the model by one step and compute the corresponding output.

    Args:
        p: dictionary holding the arrays "W1", "b1", "W2", "b2" and "C".
        x: current state vector.
        u: current input vector.

    Returns:
        Tuple ``(x_new, y)``: the updated state and the output computed from it.
    """
    # hidden layer acting on the stacked vector [x, u]
    stacked = jnp.concatenate([x, u])
    hidden = jnp.tanh(p["W1"] @ stacked + p["b1"])

    # residual state update: the network output is added to the current state
    x_next = x + p["W2"] @ hidden + p["b2"]

    # linear output equation, evaluated on the *updated* state
    return x_next, p["C"] @ x_next

In [None]:
# Smoke test: evaluate a single step from a zero state with a zero input
fg(params_init, jnp.zeros((nx,)), jnp.zeros((nu,)))

In [None]:
# Step-by-step simulation with an explicit Python loop
x0 = jnp.zeros((nx,))  # start from a zero state

x_step = x0
y_steps = []  # one output per time step, stacked after the loop
for t in range(u.shape[0]):
    x_step, y_step = fg(params_init, x_step, u[t])
    y_steps.append(y_step)

xf = x_step  # state reached after the last input sample
y_sim = jnp.stack(y_steps, axis=0)  # outputs stacked along the time axis
y_sim

In [None]:
# Alternative implementation using jax.lax.scan (harder to read for the novice, but more efficient)

# Step function with the parameters fixed to params_init, in the
# (carry, input) -> (carry, output) form that scan expects
def fg_p(x, u_t):
    return fg(params_init, x, u_t)

# Run the step function over the entire input sequence, starting from x0
xf, y_sim = jax.lax.scan(fg_p, x0, u)
y_sim

In [None]:
def simulate(p, x0, u):
    """Simulate the model over an input sequence with a plain Python loop.

    Args:
        p: parameter dictionary, passed through to ``fg``.
        x0: initial state.
        u: input sequence, indexed along its first axis (one entry per step).

    Returns:
        The per-step outputs of ``fg`` stacked along a new leading axis.
    """
    state = x0
    outputs = []
    for k in range(u.shape[0]):
        # fg returns the next state and the output for this step
        state, y_k = fg(p, state, u[k])
        outputs.append(y_k)
    return jnp.stack(outputs, axis=0)

In [None]:
# Equivalent implementation using jax.lax.scan (generally faster)
def simulate_scan(p, x0, u):
    """Simulate the model over an input sequence using ``jax.lax.scan``.

    Args:
        p: parameter dictionary, passed through to ``fg``.
        x0: initial state (the scan carry).
        u: input sequence, scanned along its first axis.

    Returns:
        The outputs collected by the scan; the final state is discarded.
    """
    def step(state, u_t):
        # fg already returns (next_state, output), the pair scan expects
        return fg(p, state, u_t)

    _, y_sim = jax.lax.scan(step, x0, u)
    return y_sim

In [None]:
# Everything the optimizer may change: the model parameters and the initial
# state, the latter starting at zero.
opt_vars = dict(params=params_init, x0=jnp.zeros((nx,)))

In [None]:
def loss_fn(ov, y, u):
    """Mean squared error between the measured and the simulated output.

    Args:
        ov: dictionary of optimization variables with the entries "params"
            (model parameters) and "x0" (initial state).
        y: measured output sequence.
        u: input sequence driving the simulation.

    Returns:
        Scalar mean of the squared difference between ``y`` and the simulation.
    """
    # simulate_scan is the scan-based equivalent of the Python-loop simulate.
    # Using it here keeps the traced computation a single scan instead of one
    # unrolled copy of fg per time step when this function is jit-compiled and
    # differentiated.
    y_sim = simulate_scan(ov["params"], ov["x0"], u)
    return jnp.mean((y - y_sim)**2)

In [None]:
# Loss on the (scaled) training data at the current optimization variables
loss_fn(opt_vars, y, u)

In [None]:
# Setup optimizer
lr = 1e-3
iters = 5_000
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(opt_vars)
# Loss value and its gradient w.r.t. the first argument (opt_vars), compiled together
loss_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

# Training loop
LOSS = []  # loss value recorded at every iteration
# The loop variable is named `it` so that the builtin `iter` is not shadowed
# for the rest of the notebook session.
for it in (pbar := tqdm(range(iters))):
    loss_val, grads = loss_grad_fn(opt_vars, y, u)
    updates, opt_state = optimizer.update(grads, opt_state)
    opt_vars = optax.apply_updates(opt_vars, updates)
    LOSS.append(loss_val)
    # refresh the progress-bar text only every 100 iterations
    if it % 100 == 0:
        pbar.set_postfix_str(f"Loss step {it}: {loss_val}")

In [None]:
# Training loss (mean squared error on the scaled data) per optimizer iteration
plt.figure()
plt.plot(LOSS)
plt.xlabel("Iteration")
plt.ylabel("Training loss (MSE)")
plt.title("Training loss");

In [None]:
# Scale the test input with the scaler that was fitted on the training input
u_test = scaler_u.transform(test.u).astype(jnp.float32)
# NOTE(review): the initial state learned on the training record is reused for
# the test record — confirm this is intended.
y_test_hat_scaled = simulate(opt_vars["params"], opt_vars["x0"], u_test)
# Map the simulated output back to the original scale of y
y_test_hat = scaler_y.inverse_transform(y_test_hat_scaled)

In [None]:
# Measured test output, simulated output and their difference, per sample
plt.figure()
plt.plot(test.y, "k", label="measured")
plt.plot(y_test_hat, "b", label="simulated")
plt.plot(y_test_hat - test.y, "r", label="error (simulated - measured)")
plt.xlabel("Sample")
plt.ylabel("y")
plt.title("Test data: measured vs. simulated output")
plt.legend()
plt.show()

In [None]:
# Root-mean-square error of the simulated test output, on the original scale of y
test_error = y_test_hat - test.y
rmse = jnp.sqrt(jnp.mean(test_error**2))
rmse