In [None]:
from tqdm import tqdm
import time
import jax
import jax.numpy as jnp
from sklearn.preprocessing import StandardScaler
import diffrax
from flax import linen as nn
from typing import Sequence, Dict, Any
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
from interpolation import ZOHInterpolation
import nonlinear_benchmarks
import nonlinear_benchmarks.error_metrics as metrics

In [None]:
# Select the interactive widget backend (ipympl) for the figures created below.
%matplotlib widget

In [None]:
# Load the Cascaded Tanks benchmark as a (train/validation, test) pair.
train_val, test = nonlinear_benchmarks.Cascaded_Tanks(atleast_2d=True)

# The training record unpacks into its input and output sequences.
u_train, y_train = train_val
sampling_time = train_val.sampling_time

In [None]:
# Standardise input and output. The fitted scalers are kept so that other data
# can be mapped into (and predictions mapped out of) the same scaled units.
scaler_u = StandardScaler().fit(u_train)
u = scaler_u.transform(u_train).astype(jnp.float32)

scaler_y = StandardScaler().fit(y_train)
y = scaler_y.transform(y_train).astype(jnp.float32)

# Time stamp of every training sample, starting at t = 0.
ts = jnp.arange(0.0, u.shape[0]) * sampling_time


In [None]:
# Model dimensions: the state size is chosen by hand, the input/output
# channel counts are read off the last axis of the scaled data.
nx = 2
nu, ny = u.shape[-1], y.shape[-1]

In [None]:
class MLP(nn.Module):
    """Stack of Dense layers with tanh between them and a linear last layer."""

    # Width of each Dense layer, in order; the last entry is the output width.
    features: Sequence[int]
    # Extra keyword arguments for every Dense layer except the last (None means none).
    layer_kwargs: Dict[str, Any] = None
    # Extra keyword arguments for the last Dense layer only (None means none).
    last_layer_kwargs: Dict[str, Any] = None

    def setup(self):
        """Create one Dense layer per entry of `features`."""
        hidden_kw = self.layer_kwargs or {}
        output_kw = self.last_layer_kwargs or {}

        stack = []
        for width in self.features[:-1]:
            stack.append(nn.Dense(width, **hidden_kw))
        stack.append(nn.Dense(self.features[-1], **output_kw))
        self.layers = stack

    def __call__(self, inputs):
        """Apply the layers in order; tanh follows every layer except the last."""
        *hidden_layers, output_layer = self.layers
        out = inputs
        for layer in hidden_layers:
            out = nn.tanh(layer(out))
        return output_layer(out)
    
class StateUpdateMLP(nn.Module):
    """State-update map: an MLP on the stacked (x, u) vector, multiplied by `scale`."""

    # Layer widths handed to the inner MLP; the last entry is the size of the result.
    features: Sequence[int]
    # Constant factor applied to the network output.
    scale: float = 1e-3

    def setup(self):
        """Build the inner network with the default Dense initialisers."""
        # A custom initialiser for the last layer (e.g. small-variance kernel,
        # zero bias) could be supplied through MLP's `last_layer_kwargs`.
        self.net = MLP(self.features)

    def __call__(self, x, u):
        """Return `scale * net([x, u])` for state `x` and input `u`."""
        stacked = jnp.r_[x, u]
        return self.net(stacked) * self.scale


In [None]:
# State-update network f(x, u) and output network g(x).
f_xu = StateUpdateMLP(features=[32, 16, nx])
g_x = MLP(features=[16, ny])

# Initial state of the simulation, starting from zero.
x0 = jnp.zeros(nx)

# Initialise both parameter sets from dummy inputs of the right sizes.
params_f = f_xu.init(jax.random.key(0), jnp.ones(nx), jnp.ones(nu))
params_g = g_x.init(jax.random.key(0), jnp.ones(nx))

In [None]:
def simulate(params_f, params_g, x0, u):
    """Simulate the continuous-time model over an input sequence.

    Args:
        params_f: variables passed to `f_xu.apply` (state-update network).
        params_g: variables passed to `g_x.apply` (output network).
        x0: initial state handed to the ODE solver as `y0`.
        u: input sequence with one row per sample. It is flattened before
            interpolation, so a single input channel is assumed — TODO confirm.

    Returns:
        `g_x` applied to the solver states saved at every sample time.
    """
    # Build the time grid from the length of `u` itself rather than relying on
    # the notebook-level `ts` (which is tied to the training record), so that
    # sequences of any length are simulated on a matching grid.
    n_samples = u.shape[0]
    t_grid = jnp.arange(0.0, n_samples) * sampling_time

    # Input signal as a function of time (zero-order hold between samples).
    u_fun = ZOHInterpolation(ts=t_grid, ys=u.ravel())

    def vector_field(t, x, args):
        # Restore the trailing channel axis that `ravel` removed.
        ut = u_fun.evaluate(t)[..., None]
        return f_xu.apply(args, x, ut)

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Dopri5(),
        t_grid[0],
        t_grid[-1],
        dt0=sampling_time,
        y0=x0,
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
        saveat=diffrax.SaveAt(ts=t_grid),
        args=params_f,
        max_steps=int(1e6),
    )
    return g_x.apply(params_g, sol.ys)

In [None]:
# Everything the optimiser may change: both parameter sets and the initial state.
opt_variables = (params_f, params_g, x0)


def loss_fn(opt_variables, u, y):
    """Mean squared error between `y` and the model output simulated from `u`.

    `opt_variables` is the tuple `(params_f, params_g, x0)`.
    """
    theta_f, theta_g, init_state = opt_variables
    residual = y - simulate(theta_f, theta_g, init_state, u)
    return jnp.mean(residual ** 2)


# Compiled function returning the loss and its gradient w.r.t. `opt_variables`.
loss_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

In [None]:
# Adam with a fixed step size; its state is built for the tuple of trainable variables.
optimizer = optax.adam(1e-4)
opt_state = optimizer.init(opt_variables)

In [None]:
# Training loop: one gradient step per epoch on the full training record.
time_start = time.time()
LOSS = []
epochs = 10_000
for epoch in (pbar := tqdm(range(epochs))):
    loss_val, grads = loss_grad_fn(opt_variables, u, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    opt_variables = optax.apply_updates(opt_variables, updates)
    LOSS.append(loss_val)
    if epoch % 100 == 0:
        pbar.set_postfix_str(f"Loss step {epoch}: {loss_val}")

# JAX dispatches computations asynchronously: wait for the last updates to
# actually finish before reading the clock, otherwise the time is understated.
jax.block_until_ready(opt_variables)
train_time = time.time() - time_start
print(f"Training time: {train_time:.2f}")

In [None]:
# Loss history: one value per training epoch (MSE on the standardised output).
plt.figure()
plt.plot(LOSS)
plt.xlabel("Epoch")
plt.ylabel("Training loss (MSE)")
plt.title("Training loss");

In [None]:
# Unpack the trained variables and simulate the model on the scaled training input.
params_f, params_g, x0 = opt_variables
y_sim = simulate(params_f, params_g, x0, u)

# Compare measured and simulated output (both in standardised units).
plt.figure()
plt.plot(y, label="measured")
plt.plot(y_sim, label="simulated")
plt.xlabel("Sample index")
plt.ylabel("Output (standardised)")
plt.title("Training data: measured vs. simulated")
plt.legend();


In [None]:
# Evaluate the trained model on the test record.
u_test, y_test = test

# Scale with the scaler fitted on the training input and cast to float32,
# exactly as was done for the training input.
u_test = scaler_u.transform(u_test).astype(jnp.float32)
# NOTE(review): the x0 used here was optimised on the training record —
# confirm that reusing it as the initial state of the test record is intended.
y_test_hat = simulate(params_f, params_g, x0, u_test)
# Map the prediction back to the original output units before scoring.
y_test_hat = scaler_y.inverse_transform(y_test_hat)

fit = metrics.fit_index(y_test, y_test_hat)[0]
rmse = metrics.RMSE(y_test, y_test_hat)[0]
nrmse = metrics.NRMSE(y_test, y_test_hat)[0]

print(f"{fit=} \n{rmse=} \n{nrmse=}")
plt.figure()
plt.plot(y_test, label="measured")
plt.plot(y_test_hat, label="simulated")
plt.xlabel("Sample index")
plt.ylabel("Output")
plt.title("Test data: measured vs. simulated")
plt.legend();