A Python package based on JAX for linear and nonlinear system identification of state-space models, recurrent neural network (RNN) training, and nonlinear regression/classification.
jax-sysid is a Python package based on JAX for linear and nonlinear system identification of state-space models, recurrent neural network (RNN) training, and nonlinear regression/classification. The algorithm can handle L1-regularization and group-Lasso regularization and relies on L-BFGS optimization for accurate modeling, fast convergence, and good sparsification of model coefficients.
The package implements the approach described in the following paper:
[1] A. Bemporad, "Linear and nonlinear system identification under $\ell_1$- and group-Lasso regularization via L-BFGS-B," submitted for publication, 2024. Available on arXiv at http://arxiv.org/abs/2403.03827 (see the BibTeX entry at the end of this document).
The package can be installed with:

pip install jax-sysid
Given input/output training data $(u_0,y_0),\ldots,(u_{N-1},y_{N-1})$, with $u_k\in\mathbb{R}^{n_u}$ and $y_k\in\mathbb{R}^{n_y}$, we want to identify a linear state-space model in the form

$$x_{k+1}=Ax_k+Bu_k,\qquad \hat y_k=Cx_k+Du_k$$

where $x_k\in\mathbb{R}^{n_x}$ is the hidden state and $A,B,C,D$ are matrices of appropriate dimensions to be learned.

The training problem to solve is

$$\min_{z}\ r(z)+\frac{1}{N}\sum_{k=0}^{N-1}\|y_k-Cx_k-Du_k\|_2^2\qquad \mathrm{s.t.}\ \ x_{k+1}=Ax_k+Bu_k,\ k=0,\ldots,N-2$$

where $z=(\theta,x_0)$, $\theta$ collects the entries of $A,B,C,D$, and $x_0$ is the initial state.

The regularization term $r(z)$ combines $\ell_2$-, $\ell_1$-, and group-Lasso penalties,

$$r(z)=\rho_{x_0}\|x_0\|_2^2+\rho_{\theta}\|\theta\|_2^2+\tau\|\theta\|_1+\tau_g\sum_{i}\|\theta_{g_i}\|_2$$

with $\rho_{\theta}>0$, $\rho_{x_0}>0$, $\tau\geq 0$, $\tau_g\geq 0$, where the groups $\theta_{g_i}$ are subsets of entries of $\theta$ (for example, all the entries related to the same state or to the same input).
Let's start by training a discrete-time linear model on a sequence of inputs U and corresponding outputs Y:
from jax_sysid.models import LinearModel

model = LinearModel(nx, ny, nu)              # nx states, ny outputs, nu inputs
model.loss(rho_x0=1.e-3, rho_th=1.e-2)       # L2-regularization on initial state and parameters
model.optimization(lbfgs_epochs=1000)        # optimization settings
model.fit(Y, U)                              # run training
Yhat, Xhat = model.predict(model.x0, U)      # open-loop simulation of the identified model
After identifying the model, to retrieve the resulting state-space realization you can use the following:
A,B,C,D = model.ssdata()
Given a new test sequence of inputs and outputs, an initial state that is compatible with the identified model can be reconstructed by running an extended Kalman filter and Rauch–Tung–Striebel smoothing (cf. [1]) and used to simulate the model:
x0_test = model.learn_x0(U_test, Y_test)
Yhat_test, Xhat_test = model.predict(x0_test, U_test)
R2-scores on training and test data can be computed as follows:
from jax_sysid.utils import compute_scores
R2_train, R2_test, msg = compute_scores(Y, Yhat, Y_test, Yhat_test, fit='R2')
print(msg)
It is good practice to scale the input and output signals. To identify a model on scaled signals, you can use the following:
from jax_sysid.utils import standard_scale, unscale
Ys, ymean, ygain = standard_scale(Y)
Us, umean, ugain = standard_scale(U)
model.fit(Ys, Us)
Yshat, Xhat = model.predict(model.x0, Us)
Yhat = unscale(Yshat, ymean, ygain)
Let us now retrain the model using L1-regularization and check the sparsity of the resulting model:
model.loss(rho_x0=1.e-3, rho_th=1.e-2, tau_th=0.03)
model.fit(Ys, Us)
print(model.sparsity_analysis())
To reduce the number of states in the model, you can use group-Lasso regularization as follows:
model.loss(rho_x0=1.e-3, rho_th=1.e-2, tau_g=0.1)
model.group_lasso_x()
model.fit(Ys, Us)
Groups here collect, for each state $x_i$, the corresponding row and column of $A$, row of $B$, and column of $C$, so that zeroing an entire group removes that state from the model.
Group-Lasso can also be used to try to reduce the number of inputs that are relevant in the model. You can do this as follows:
model.loss(rho_x0=1.e-3, rho_th=1.e-2, tau_g=0.15)
model.group_lasso_u()
model.fit(Ys, Us)
Groups here collect, for each input $u_j$, the corresponding column of $B$ (and of $D$, if a feedthrough term is present), so that zeroing an entire group removes that input from the model.
jax-sysid also supports multiple training experiments. In this case, the sequences of training inputs and outputs are passed as a list of arrays. For example, if three experiments are available for training, use the following command:
model.fit([Ys1, Ys2, Ys3], [Us1, Us2, Us3])
If the initial state $x_0$ should not be optimized, you can keep it fixed by passing train_x0=False when calling model.loss.
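For instance, reusing the regularization weights from the earlier examples, a minimal sketch is:

model.loss(rho_x0=1.e-3, rho_th=1.e-2, train_x0=False)  # do not optimize the initial state x0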
To attempt to force the identified linear model to be asymptotically stable, i.e., the matrix $A$ to have all eigenvalues strictly inside the unit circle, you can call
model.force_stability()
before calling the fit function. This introduces a custom regularization penalty that attempts to enforce the stability constraint during training.
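For example, the complete training sequence with the stability penalty could look as follows (same scaled data and settings as above):

model = LinearModel(nx, ny, nu)
model.loss(rho_x0=1.e-3, rho_th=1.e-2)
model.optimization(lbfgs_epochs=1000)
model.force_stability()   # add the stability-promoting penalty before fitting
model.fit(Ys, Us)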
To introduce a penalty that attempts to force the identified linear model to have a given DC-gain matrix M, you can use the following commands:
dcgain_loss = model.dcgain_loss(DCgain = M)
model.loss(rho_x0=1.e-3, rho_th=1.e-2, custom_regularization = dcgain_loss)
before calling the fit function. Similarly, to instead fit the DC-gain of the model to steady-state input data Uss and corresponding output data Yss, you can use
dcgain_loss = model.dcgain_loss(Uss = Uss, Yss = Yss)
and use dcgain_loss as the custom regularization function.
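For example, with the same regularization weights used above:

model.loss(rho_x0=1.e-3, rho_th=1.e-2, custom_regularization=dcgain_loss)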
Given input/output training data $(u_0,y_0),\ldots,(u_{N-1},y_{N-1})$, with $u_k\in\mathbb{R}^{n_u}$ and $y_k\in\mathbb{R}^{n_y}$, jax-sysid can also identify nonlinear parametric state-space models of the form

$$x_{k+1}=f(x_k,u_k,\theta),\qquad \hat y_k=g(x_k,u_k,\theta)$$

where $x_k\in\mathbb{R}^{n_x}$ is the hidden state and $\theta$ collects the trainable parameters of the state-update function $f$ and output function $g$.

As for the linear case, the training problem to solve is

$$\min_{z}\ r(z)+\frac{1}{N}\sum_{k=0}^{N-1}\|y_k-g(x_k,u_k,\theta)\|_2^2\qquad \mathrm{s.t.}\ \ x_{k+1}=f(x_k,u_k,\theta),\ k=0,\ldots,N-2$$

where $z=(\theta,x_0)$ and $r(z)$ is the regularization term described above.

For example, let us consider the following residual RNN model without input/output feedthrough:

$$x_{k+1}=Ax_k+Bu_k+W_3\,\sigma(W_1x_k+W_2u_k+b_1)+b_2,\qquad \hat y_k=Cx_k+W_5\,\sigma(W_4x_k+b_3)+b_4$$

where $\sigma$ is the sigmoid function. The model can be defined and trained as follows:
import numpy as np
import jax
import jax.numpy as jnp
from jax_sysid.models import Model
Ys, ymean, ygain = standard_scale(Y)
Us, umean, ugain = standard_scale(U)
def sigmoid(x):
return 1. / (1. + jnp.exp(-x))
@jax.jit
def state_fcn(x,u,params):
A,B,C,W1,W2,W3,b1,b2,W4,W5,b3,b4=params
return A@x+B@u+W3@sigmoid(W1@x+W2@u+b1)+b2
@jax.jit
def output_fcn(x,u,params):
A,B,C,W1,W2,W3,b1,b2,W4,W5,b3,b4=params
return C@x+W5@sigmoid(W4@x+b3)+b4
model = Model(nx, ny, nu, state_fcn=state_fcn, output_fcn=output_fcn)
nnx = 5 # number of hidden neurons in state-update function
nny = 5 # number of hidden neurons in output function
# Parameter initialization:
A = 0.5*np.eye(nx)
B = 0.1*np.random.randn(nx,nu)
C = 0.1*np.random.randn(ny,nx)
W1 = 0.1*np.random.randn(nnx,nx)
W2 = 0.5*np.random.randn(nnx,nu)
W3 = 0.5*np.random.randn(nx,nnx)
b1 = np.zeros(nnx)
b2 = np.zeros(nx)
W4 = 0.5*np.random.randn(nny,nx)
W5 = 0.5*np.random.randn(ny,nny)
b3 = np.zeros(nny)
b4 = np.zeros(ny)
model.init(params=[A,B,C,W1,W2,W3,b1,b2,W4,W5,b3,b4])
model.loss(rho_x0=1.e-4, rho_th=1.e-4)
model.optimization(adam_epochs=1000, lbfgs_epochs=1000)
model.fit(Ys, Us)
Yshat, Xshat = model.predict(model.x0, Us)
Yhat = unscale(Yshat, ymean, ygain)
As the training problem, in general, is a nonconvex optimization problem, the obtained model often depends on the initial value of the parameters. The jax-sysid library supports training models in parallel (including static models) using the joblib library. In the example above, we can train 10 different models using 10 jobs in joblib as follows:
def init_fcn(seed):
np.random.seed(seed)
A = 0.5*np.eye(nx)
B = 0.1*np.random.randn(nx,nu)
C = 0.1*np.random.randn(ny,nx)
W1 = 0.1*np.random.randn(nnx,nx)
W2 = 0.5*np.random.randn(nnx,nu)
W3 = 0.5*np.random.randn(nx,nnx)
b1 = np.zeros(nnx)
b2 = np.zeros(nx)
W4 = 0.5*np.random.randn(nny,nx)
W5 = 0.5*np.random.randn(ny,nny)
b3 = np.zeros(nny)
b4 = np.zeros(ny)
return [A,B,C,W1,W2,W3,b1,b2,W4,W5,b3,b4]
models = model.parallel_fit(Ys, Us, init_fcn=init_fcn, seeds=range(10), n_jobs=10)
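parallel_fit returns the list of trained models; to select one, you can, for example, compare their open-loop predictions on the training data. The snippet below is a minimal sketch, assuming each returned model exposes the same predict method and x0 attribute used in the single-model examples above:

best_R2, best_model = -np.inf, None
for m in models:
    Yshat_m, _ = m.predict(m.x0, Us)   # open-loop prediction on the (scaled) training data
    R2_m = 1. - np.sum((Ys - Yshat_m)**2) / np.sum((Ys - np.mean(Ys, axis=0))**2)
    if R2_m > best_R2:
        best_R2, best_model = R2_m, m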
jax-sysid also supports recurrent neural networks defined via the flax.linen library (the flax package can be installed via pip install flax):
from jax_sysid.models import RNN
from flax import linen as nn
# state-update function
class FX(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=5)(x)
x = nn.swish(x)
x = nn.Dense(features=5)(x)
x = nn.swish(x)
x = nn.Dense(features=nx)(x)
return x
# output function
class FY(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=5)(x)
x = nn.tanh(x)
x = nn.Dense(features=ny)(x)
return x
model = RNN(nx, ny, nu, FX=FX, FY=FY, x_scaling=0.1)
model.loss(rho_x0=1.e-4, rho_th=1.e-4, tau_th=0.0001)
model.optimization(adam_epochs=0, lbfgs_epochs=2000)
model.fit(Ys, Us)
where the extra parameter x_scaling is used to scale down (when x_scaling < 1) the default initialization of the network weights.
jax-sysid also supports custom loss functions penalizing the deviations of $\hat y$ from $y$. For example, to fit a model with binary outputs $y_k\in\{0,1\}$, one can use the (modified) cross-entropy loss

$$\mathcal{L}(\hat Y,Y)=\frac{1}{N}\sum_{k=0}^{N-1}\left[-y_k\log(\epsilon+\hat y_k)-(1-y_k)\log(\epsilon+1-\hat y_k)\right]$$

where $\epsilon>0$ is a small constant that prevents the logarithm from diverging:
epsil=1.e-4
@jax.jit
def cross_entropy_loss(Yhat,Y):
loss=jnp.sum(-Y*jnp.log(epsil+Yhat)-(1.-Y)*jnp.log(epsil+1.-Yhat))/Y.shape[0]
return loss
model.loss(rho_x0=0.01, rho_th=0.001, output_loss=cross_entropy_loss)
By default, jax-sysid minimizes the classical mean squared error $\frac{1}{N}\sum_{k=0}^{N-1}\|y_k-\hat y_k\|_2^2$.
jax-sysid also supports custom regularization terms $r_c(\theta,x_0)$. For example, to softly enforce the constraint $\|\theta\|_2^2\leq 1$, one can minimize the regularization term

$$\frac{1}{2}\rho_{\theta}\|\theta\|_2^2+\rho_{x_0}\|x_0\|_2^2+\rho_c\max\{\|\theta\|_2^2-1,0\}^2$$

with $\rho_c\gg\rho_{\theta}$, for instance $\rho_c=1000$, as in the following example:
@jax.jit
def custom_reg_fcn(th,x0):
return 1000.*jnp.maximum(jnp.sum(th**2)-1.,0.)**2
model.loss(rho_x0=0.01, rho_th=0.001, custom_regularization= custom_reg_fcn)
As for linear systems, a special case of custom regularization function to fit the DC-gain of the model to steady-state input data Uss and corresponding output data Yss is obtained using the following commands:
dcgain_loss = model.dcgain_loss(Uss = Uss, Yss = Yss)
model.loss(rho_x0=1.e-3, rho_th=1.e-2, custom_regularization = dcgain_loss)
before calling the fit function. Note that this penalty involves solving a system of nonlinear equations for every input/output steady-state pair to evaluate the loss function, so it can be slow if many steady-state data pairs are given.
To include lower and upper bounds on the parameters of the model and/or the initial state, use the following additional arguments when specifying the optimization problem:
model.optimization(params_min=lb, params_max=ub, x0_min=xmin, x0_max=xmax, ...)
where lb and ub are lists of arrays with the same structure as model.params, while xmin and xmax are arrays with the same dimension model.nx as the state vector. By default, each of these arguments is None, i.e., the corresponding constraint is not enforced. See example_linear_positive.py for examples of how to use nonnegative constraints to fit a positive linear system.
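As a minimal sketch (not the full example), and assuming model.params has already been populated (e.g., by a previous initialization or fit), nonnegativity of all model parameters could be imposed as follows:

lb = [np.zeros_like(p) for p in model.params]   # elementwise lower bound 0, same structure as model.params
ub = None                                       # no upper bounds (default)
model.optimization(lbfgs_epochs=1000, params_min=lb, params_max=ub)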
As a special case of nonlinear dynamical models, jax-sysid supports the identification of quasi-LPV (linear parameter-varying) models of the form

$$x_{k+1}=\Big(A+\sum_{i=1}^{n_p}p_{k,i}A_i\Big)x_k+\Big(B+\sum_{i=1}^{n_p}p_{k,i}B_i\Big)u_k$$

$$\hat y_k=\Big(C+\sum_{i=1}^{n_p}p_{k,i}C_i\Big)x_k+\Big(D+\sum_{i=1}^{n_p}p_{k,i}D_i\Big)u_k$$

where the scheduling vector $p_k\in\mathbb{R}^{n_p}$ is generated by a parametric function $p_k=f_p(x_k,u_k,\theta_p)$ of the current state and input, and the matrices $A,B,C,D$, $A_i,B_i,C_i,D_i$ and the parameters $\theta_p$ are learned during training.
For both LTI and qLPV models, the direct feedthrough term is only included if you pass feedthrough=True when defining the model (by default, no feedthrough is in place). Moreover, for all linear, nonlinear, and qLPV models, one can force the output to coincide with a subset of the states by passing y_in_x=True in the object constructor.
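As an illustration, the scheduling function could be a small one-hidden-layer network mapping the current state and input into the scheduling vector. The sketch below is only an example: the assumed signature qlpv_fcn(x, u, qlpv_params) and the parameter layout are illustrative and must match what qLPVModel expects (see the package documentation and examples):

import jax
import jax.numpy as jnp
import numpy as np

npar = 2   # number of scheduling variables
nh = 4     # hidden neurons of the scheduling function

@jax.jit
def qlpv_fcn(x, u, qlpv_params):
    # p = tanh(W2 @ tanh(W1 @ [x; u] + b1) + b2), bounded in (-1, 1)
    W1, b1, W2, b2 = qlpv_params
    z = jnp.hstack((x, u))
    return jnp.tanh(W2 @ jnp.tanh(W1 @ z + b1) + b2)

qlpv_params_init = [0.1*np.random.randn(nh, nx+nu), np.zeros(nh),
                    0.1*np.random.randn(npar, nh), np.zeros(npar)]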
Let's now train a quasi-LPV model on the scaled signals, given a scheduling function qlpv_fcn, its initial parameters qlpv_params_init, the number npar of scheduling variables, and the same regularization settings used above:
from jax_sysid.models import qLPVModel
model = qLPVModel(nx, ny, nu, npar, qlpv_fcn, qlpv_params_init)
model.loss(rho_x0=1.e-3, rho_th=1.e-2)
model.optimization(lbfgs_epochs=1000)
model.fit(Ys, Us, LTI_training=True)
Yshat, Xhat = model.predict(model.x0, Us)
Yhat = unscale(Yshat, ymean, ygain)
where Us, Ys are the scaled input and output signals. The flag LTI_training=True tells the training algorithm to initialize the matrices $A,B,C,D$ by first fitting a linear time-invariant model.
After identifying the model, to retrieve the resulting system matrices $A,B,C,D$ and $A_i,B_i,C_i,D_i$ you can use the following:
A, B, C, D, Ap, Bp, Cp, Dp = model.ssdata()
where Ap, Bp, Cp, Dp are tensors (3D arrays) collecting the corresponding matrices, i.e., Ap[i,:,:] = $A_i$, Bp[i,:,:] = $B_i$, Cp[i,:,:] = $C_i$, Dp[i,:,:] = $D_i$.
To attempt to reduce the number of scheduling variables in the model, you can use group-Lasso regularization as follows:
model.loss(rho_x0=1.e-3, rho_th=1.e-2, tau_g=0.1)
model.group_lasso_p()
model.fit(Ys, Us)
Each group collects the entries of $A_i,B_i,C_i,D_i$ associated with the same scheduling variable $p_i$, and tau_g is the weight of the corresponding group-Lasso penalty; zeroing an entire group removes that scheduling variable from the model.
Parallel training from different initial guesses is also supported for qLPV models. To this end, you must define a function qlpv_param_init_fcn(seed) that returns an initial guess of the scheduling-function parameters for a given random seed. For example, you can train the model for 10 different random seeds on 10 jobs by running:
models = model.parallel_fit(Y, U, qlpv_param_init_fcn=qlpv_param_init_fcn, seeds=range(10), n_jobs=10)
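A minimal initialization function consistent with the scheduling-function sketch above could be (the parameter layout is again illustrative):

def qlpv_param_init_fcn(seed):
    np.random.seed(seed)
    return [0.1*np.random.randn(nh, nx+nu), np.zeros(nh),
            0.1*np.random.randn(npar, nh), np.zeros(npar)]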
The same optimization algorithms used to train dynamical models can be used to train static models, i.e., to solve the nonlinear regression problem

$$\min_{\theta}\ r(\theta)+\frac{1}{N}\sum_{k=0}^{N-1}\|y_k-f(u_k,\theta)\|_2^2$$

where $f$ is a parametric function of the input $u_k$, $\theta$ is the vector of parameters to learn, and $r(\theta)$ is the regularization term.
For example, if the model is a shallow neural network you can use the following code:
import numpy as np
import jax
import jax.numpy as jnp
from jax_sysid.models import StaticModel
from jax_sysid.utils import standard_scale, unscale
@jax.jit
def output_fcn(u, params):
W1,b1,W2,b2=params
y = W1@u.T+b1
y = W2@jnp.arctan(y)+b2
return y.T
model = StaticModel(ny, nu, output_fcn)
nn=10 # number of neurons
model.init(params=[np.random.randn(nn,nu), np.random.randn(nn,1), np.random.randn(1,nn), np.random.randn(1,1)])
model.loss(rho_th=1.e-4, tau_th=tau_th)  # tau_th: user-chosen L1-regularization coefficient
model.optimization(lbfgs_epochs=500)
model.fit(Ys, Us)
jax-sysid also supports feedforward neural networks defined via the flax.linen library:
from jax_sysid.models import FNN
from flax import linen as nn
# output function
class FY(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=20)(x)
x = nn.tanh(x)
x = nn.Dense(features=20)(x)
x = nn.tanh(x)
x = nn.Dense(features=ny)(x)
return x
model = FNN(ny, nu, FY)
model.loss(rho_th=1.e-4, tau_th=tau_th)
model.optimization(lbfgs_epochs=500)
model.fit(Ys, Us)
To include lower and upper bounds on the parameters of the model, use the following additional arguments when specifying the optimization problem:
model.optimization(lbfgs_epochs=500, params_min=lb, params_max=ub)
where lb and ub are lists of arrays with the same structure as model.params. See example_static_convex.py for examples of how to use nonnegative constraints to fit input-convex neural networks.
To solve classification problems, you need to define a custom loss function to replace the default mean-squared-error loss. For example, to train a classifier for a multi-category classification problem in which the target outputs Y are one-hot encoded, you can use the following cross-entropy loss combined with a softmax:
def cross_entropy(Yhat,Y):
return -jax.numpy.sum(jax.nn.log_softmax(Yhat, axis=1)*Y)/Y.shape[0]
model.loss(rho_th=1.e-4, output_loss=cross_entropy)
See example_static_fashion_mnist.py for an example using Keras with the JAX backend to define the neural network model.
This package was coded by Alberto Bemporad.
This software is distributed without any warranty. Please cite the paper below if you use this software.
We thank Roland Toth for suggesting the use of Kung's method for initializing linear state-space models and Kui Xie for feedback on the reconstruction of the initial state via EKF + RTS smoothing.
@article{Bem24,
author={A. Bemporad},
title={Linear and nonlinear system identification under $\ell_1$- and group-{Lasso} regularization via {L-BFGS-B}},
note = {submitted for publication. Also available on arXiv
at \url{http://arxiv.org/abs/2403.03827}},
year=2024
}
Apache 2.0
(C) 2024 A. Bemporad