# Parallel Wiener-Hammerstein system

See http://arxiv.org/pdf/1708.06543 for more information on the dataset.


In [1]:
import sys
sys.path.insert(0, '..')

import matplotlib.pyplot as plt
import nonlinear_benchmarks as nlb
import numpy as np
import optimistix as optx

from src import best_linear_approximation as bla
from src import basis_functions, data_manager, inference_and_learning, nonlinear_functions

seed = 42
np.random.seed(seed)  # for reproducibility

# Initialise variables
N = 16384  # number of samples per period
R = 5  # number of random phase multisine realisations
P = 2  # number of periods
amplitude_level = 4  # must be one of {0, 1, 2, 3, 4}

nu, ny = 1, 1  # SISO system

fs = 78e3  # [Hz]
f_idx = np.arange(1, 4096)  # frequency lines of interest (excludes DC)

# Load data (a single call; the benchmark loader returns train and test sets)
ParWH_full_train, ParWH_full_test = nlb.ParWH()

# Keep the R estimation records at the selected amplitude level, in the
# order in which the loader returns them.
train_names = {
    f'Est-phase-{phase}-amp-{amplitude_level}' for phase in range(R)
}
ParWH_train = [data for data in ParWH_full_train if data.name in train_names]
assert len(ParWH_train) == R, (
    f'expected {R} estimation records, found {len(ParWH_train)}'
)

ParWH_test = [
    data for data in ParWH_full_test
    if data.name == f'Val-amp-{amplitude_level}'
][0]

# Preprocess data: stack the records and bring them to (N, channel, R, P)
# NOTE(review): reshaping the trailing axis to (N, P) assumes the period
# index varies fastest in each flat record -- confirm against the loader.
u_train = np.array([data.u for data in ParWH_train]).reshape(R, nu, N, P)
y_train = np.array([data.y for data in ParWH_train]).reshape(R, ny, N, P)
u_train = np.transpose(u_train, (2, 1, 0, 3))
y_train = np.transpose(y_train, (2, 1, 0, 3))

# The validation record is reshaped with a hard-coded 2 periods
u_test = np.transpose(ParWH_test.u.reshape(1, nu, N, 2), (2, 1, 0, 3))
y_test = np.transpose(ParWH_test.y.reshape(1, ny, N, 2), (2, 1, 0, 3))

# Create input-output training data object
io_data = data_manager.create_data_object(u_train, y_train, f_idx, fs)

Step 1: Best Linear Approximation

In [2]:
def _nrmse_percent(y_ref, y_sim):
    """Return the RMS of (y_ref - y_sim) relative to the RMS of y_ref, in percent."""
    rms_error = np.sqrt(np.mean((y_ref - y_sim)**2))
    rms_ref = np.sqrt(np.mean(y_ref**2))
    return 100 * rms_error / rms_ref


##### (i) Nonparametric estimate #####
G_nonpar = bla.compute_nonparametric(io_data)

##### (ii) Parametrize using frequency-domain subspace identification method #####
nx = 12  # number of states
q = nx + 1  # subspace dimensioning parameter
bla_fsid = bla.freq_subspace_id(G_nonpar, nx, q)

# Period-averaged signals used for every time-domain check below
u_bar = np.mean(io_data.time.u, axis=-1)
y_bar = np.mean(io_data.time.y, axis=-1)
# Number of samples to start 'ahead of time' so transient effects die out
# (only works for periodic data!)
handicap = 1000

# Time-domain performance of the FSID model
y_sim_bla_fsid = bla_fsid.simulate(u_bar, handicap=handicap)[0]
NRMSE_bla_fsid = _nrmse_percent(y_bar, y_sim_bla_fsid)

print(f'NRMSE of FSID BLA: {NRMSE_bla_fsid:.2f}%\n')

##### (iii) Frequency-domain iterative optimization starting from FSID BLA #####
solver = optx.LevenbergMarquardt(rtol=1e-3, atol=1e-6)
max_iter = 100

bla_opti = bla.freq_iterative_optimization(G_nonpar, bla_fsid, solver, max_iter)

# Time-domain performance of the optimized model
y_sim_bla_opti = bla_opti.simulate(u_bar, handicap=handicap)[0]
NRMSE_bla_opti = _nrmse_percent(y_bar, y_sim_bla_opti)

print(f'NRMSE of optimized BLA: {NRMSE_bla_opti:.2f}%\n')

NRMSE of FSID BLA: 13.31%

Starting iterative optimization...
   Iteration 0, Loss: 2.1803e+00
   Iteration 1, Loss: 1.7492e+00
   Iteration 2, Loss: 1.7492e+00
   Iteration 3, Loss: 1.7492e+00
   Iteration 4, Loss: 1.6736e+00
   Iteration 5, Loss: 1.6576e+00
   Iteration 6, Loss: 1.5595e+00
   Iteration 7, Loss: 1.5328e+00
   Iteration 8, Loss: 1.5101e+00
   Iteration 9, Loss: 1.4770e+00
   Iteration 10, Loss: 1.4434e+00
   Iteration 11, Loss: 1.4219e+00
   Iteration 12, Loss: 1.4133e+00
   Iteration 13, Loss: 1.4100e+00
   Iteration 14, Loss: 1.4081e+00
   Iteration 15, Loss: 1.4067e+00
   Iteration 16, Loss: 1.4054e+00
   Iteration 17, Loss: 1.4042e+00
   Iteration 18, Loss: 1.4030e+00
   Iteration 19, Loss: 1.4016e+00
   Iteration 20, Loss: 1.4001e+00
   Iteration 21, Loss: 1.3982e+00
   Iteration 22, Loss: 1.3957e+00
   Iteration 23, Loss: 1.3918e+00
   Iteration 24, Loss: 1.3918e+00
   Iteration 25, Loss: 1.3857e+00
   Iteration 26, Loss: 1.3857e+00
   Iteration 27, Loss: 1.3753e

In [3]:
from typing import NamedTuple, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx

from src.data_manager import FrequencyData, InputOutputData
from src.basis_functions import AbstractBasisFunction
from src.nonlinear_functions import create_custom_basis_function_model
from src._model_structures import ModelBLA, ModelNonlinearLFR
from src._solve import solve 


class _ThetaWZ(eqx.Module):
    """Decision variables optimised by the inference-and-learning step.

    Every field is passed through ``jnp.asarray`` on construction. The
    ``*_star`` matrices for the z channel are stored in scaled form: the
    model matrices are recovered as ``Tz_inv @ C_z_star`` and
    ``Tz_inv @ D_zu_star``.
    """
    # initialised with shape (nx, nw) in `_prepare_problem`
    B_w_star: jnp.ndarray = eqx.field(converter=jnp.asarray)
    # initialised with shape (nz, nx); scaled counterpart of C_z
    C_z_star: jnp.ndarray = eqx.field(converter=jnp.asarray)
    # initialised with shape (ny, nw)
    D_yw_star: jnp.ndarray = eqx.field(converter=jnp.asarray)
    # initialised with shape (nz, nu); scaled counterpart of D_zu
    D_zu_star: jnp.ndarray = eqx.field(converter=jnp.asarray)


class _OptiArgs(NamedTuple):
    """Fixed (non-optimised) quantities handed to `_loss_fn`.

    Built once by `_prepare_problem` and passed unchanged to the solver.
    """
    # (A, B_u, C_y) taken from the BLA model; kept fixed during optimisation
    theta_uy: tuple
    # basis-function object; `_loss_fn` calls its `compute_features`
    phi: AbstractBasisFunction
    # regularisation weight applied to `Theta` in the inference step
    lambda_w: jnp.ndarray
    # number of fixed-point iterations run inside `_loss_fn`
    fixed_point_iters: int
    # (f, fs, U, Y, G_yu) with U and Y averaged over the periods
    f_data: tuple
    # per-frequency weighting matrices, built with shape (F, ny, ny)
    Lambda: jnp.ndarray
    # diagonal scaling applied to C_z_star and D_zu_star
    Tz_inv: jnp.ndarray
    # frequency response of the BLA model
    G_yu: jnp.ndarray
    # number of time samples per period; used as `n` for the inverse FFTs
    N: int


def solverrr(
    io_data: InputOutputData,
    bla: ModelBLA,
    phi: AbstractBasisFunction,
    nw: int,
    lambda_w: float,
    fixed_point_iters: int,
    solver: Union[optx.AbstractLeastSquaresSolver, optx.AbstractMinimiser],
    max_iter: int,
    seed: int
) -> ModelNonlinearLFR:
    """Fit a nonlinear LFR model on top of a fixed BLA model.

    Parameters
    ----------
    io_data : InputOutputData
        Training data; also provides the sampling time ``io_data.time.ts``.
    bla : ModelBLA
        Linear model providing ``A``, ``B_u``, ``C_y`` (via the prepared
        arguments) and ``D_yu``.
    phi : AbstractBasisFunction
        Basis functions of the static nonlinearity.
    nw : int
        Number of nonlinearity outputs ``w``.
    lambda_w : float
        Regularisation weight forwarded to the problem setup.
    fixed_point_iters : int
        Number of fixed-point iterations used inside the loss.
    solver, max_iter
        Optimiser and iteration budget handed to ``solve``.
    seed : int
        Seed for the random initialisation of the decision variables.

    Returns
    -------
    ModelNonlinearLFR
        Model assembled from the fixed linear part, the optimised
        parameters, and the basis-function coefficients taken from the
        last entry of the solver's auxiliary output.
    """
    theta_init, args = _prepare_problem(
        io_data, bla, phi, nw, lambda_w, fixed_point_iters, seed
    )

    # Run the optimiser, bracketing its log with a header and a blank line
    print('Starting iterative optimization...')
    result = solve(theta_init, solver, args, _loss_fn, max_iter)
    print('\n')

    theta_opt = result.theta
    # presumably the `beta_hat` emitted as the last aux value of `_loss_fn`
    beta = result.aux[-1]

    A, B_u, C_y = args.theta_uy
    Tz_inv = args.Tz_inv

    # Undo the z scaling on the optimised C_z / D_zu parameters
    C_z = Tz_inv @ theta_opt.C_z_star
    D_zu = Tz_inv @ theta_opt.D_zu_star
    f_static = create_custom_basis_function_model(nw, phi, beta)

    return ModelNonlinearLFR(
        A=A,
        B_u=B_u,
        C_y=C_y,
        D_yu=bla.D_yu,
        B_w=theta_opt.B_w_star,
        C_z=C_z,
        D_yw=theta_opt.D_yw_star,
        D_zu=D_zu,
        f_static=f_static,
        ts=io_data.time.ts
    )


def _loss_fn(theta: _ThetaWZ, args: _OptiArgs) -> tuple:
    """Evaluate the weighted frequency-domain output residual for `theta`.

    The computation proceeds in five stages:

    1. build the parametric transfer functions ``G_yw``, ``G_zu``, ``G_zw``
       from the fixed ``(A, B_u, C_y)`` and the current ``theta``;
    2. nonparametric inference: per frequency line, solve a regularised
       linear system for ``W_hat`` and compute the matching ``Z_hat``;
    3. parametric learning: least-squares fit of ``beta_hat`` mapping the
       basis features of ``z*`` onto ``w*`` in the time domain;
    4. fixed-point iterations that re-evaluate the features through the
       ``w -> z`` loop with ``beta_hat`` held fixed;
    5. weighted residual between the measured and modelled output spectra.

    Parameters
    ----------
    theta : _ThetaWZ
        Current decision variables.
    args : _OptiArgs
        Fixed problem data.

    Returns
    -------
    tuple
        ``(loss, (MSE_loss, beta_hat))`` where ``loss`` is the pair
        ``(real part, imaginary part)`` of the weighted residual,
        ``MSE_loss`` is the sum of its squared magnitudes, and ``beta_hat``
        holds the fitted basis-function coefficients.
    """
    f_full, fs, U, Y, G_yu = args.f_data

    A, B_u, C_y = args.theta_uy

    B_w = theta.B_w_star
    C_z = args.Tz_inv @ theta.C_z_star
    D_yw = theta.D_yw_star
    D_zu = args.Tz_inv @ theta.D_zu_star

    nw = D_yw.shape[1]
    nz, nu = D_zu.shape
    F = U.shape[0]
    R = U.shape[2]

    # Gram matrix of the stacked [B_w; D_yw]; scales the regulariser below
    BD_w = jnp.vstack((B_w, D_yw))
    Theta = BD_w.T @ BD_w

    # Points exp(j * 2*pi*f/fs) at which the state-space model is evaluated
    omega = 2 * jnp.pi * f_full / fs
    zj = jnp.exp(omega * 1j)

    I_nw = jnp.eye(nw)
    I_nx = jnp.eye(A.shape[0])

    def _compute_parametric_Gs(k):
        # One linear solve gives (zI - A)^{-1} [B_u, B_w]; the first nu
        # columns belong to u, the remaining ones to w.
        G_x = jnp.linalg.solve(zj[k] * I_nx - A, jnp.hstack((B_u, B_w)))
        return (
            C_y @ G_x[:, nu:] + D_yw,  # G_yw
            C_z @ G_x[:, :nu] + D_zu,  # G_zu
            C_z @ G_x[:, nu:]          # G_zw
        )

    G_yw, G_zu, G_zw = jax.vmap(_compute_parametric_Gs)(jnp.arange(F))

    # --- Nonparametric inference ---
    def _infer_nonparametric_signals(k):
        # NOTE(review): `.T` is a plain transpose although G_yw is complex;
        # confirm that a conjugate transpose is not intended here.
        Psi = G_yw[k, ...].T @ args.Lambda[k, ...]
        # The small diagonal term keeps the system solvable
        Phi = Psi @ G_yw[k, ...] + args.lambda_w * Theta + 1e-10 * I_nw
        W_hat = jnp.linalg.solve(
            Phi,
            Psi @ (Y[k, ...] - G_yu[k, ...] @ U[k, ...])
        )
        Z_hat = G_zu[k, ...] @ U[k, ...] + G_zw[k, ...] @ W_hat
        return W_hat, Z_hat

    W_star, Z_star = jax.vmap(_infer_nonparametric_signals)(jnp.arange(F))

    # --- Parametric learning ---
    w_star = jnp.fft.irfft(W_star, n=args.N, axis=0)
    z_star = jnp.fft.irfft(Z_star, n=args.N, axis=0)

    # Stack all realisations along the sample axis: (N * R, channels)
    w_star_stacked = jnp.transpose(w_star, (2, 0, 1)).reshape(args.N * R, nw)
    z_star_stacked = jnp.transpose(z_star, (2, 0, 1)).reshape(args.N * R, nz)

    # Normal-equation least-squares fit of the basis coefficients
    phi_z_star = args.phi.compute_features(z_star_stacked)
    beta_hat = jnp.linalg.solve(
        phi_z_star.T @ phi_z_star,
        phi_z_star.T @ w_star_stacked
    )

    # --- Fixed-point iterations ---
    def _fixed_point_iteration(_, phi_z):
        # features -> w -> W -> Z -> z -> features, with beta_hat fixed
        w_iter_stacked = phi_z @ beta_hat
        w_iter = jnp.transpose(
            w_iter_stacked.reshape(R, args.N, nw), (1, 2, 0)
        )
        W_iter = jnp.fft.rfft(w_iter, axis=0)
        Z_iter = G_zu @ U + G_zw @ W_iter
        z_iter = jnp.fft.irfft(Z_iter, n=args.N, axis=0)
        z_iter_stacked = jnp.transpose(
            z_iter, (2, 0, 1)
        ).reshape(args.N * R, nz)
        return args.phi.compute_features(z_iter_stacked)

    phi_z = jax.lax.fori_loop(
        0, args.fixed_point_iters, _fixed_point_iteration, phi_z_star,
        unroll=True
    )

    w_hat_stacked = phi_z @ beta_hat
    w_hat = jnp.transpose(w_hat_stacked.reshape(R, args.N, nw), (1, 2, 0))
    W_beta = jnp.fft.rfft(w_hat, axis=0)

    # --- Loss computation ---
    Y_hat = G_yu @ U + G_yw @ W_beta
    # Element-wise square root of the weighting, normalised by R * N
    loss_Y = jnp.sqrt(args.Lambda / (R * args.N)) @ (Y - Y_hat)

    # Real and imaginary parts are returned separately as the residual
    loss = (loss_Y.real, loss_Y.imag)

    MSE_loss = jnp.sum(jnp.abs(loss_Y)**2)
    return loss, (MSE_loss, beta_hat)


def _prepare_problem(
    io_data: InputOutputData,
    bla: ModelBLA,
    phi: AbstractBasisFunction,
    nw: int,
    lambda_w: float,
    fixed_point_iters: int,
    seed: int
) -> tuple[_ThetaWZ, _OptiArgs]:
    """Build the initial decision variables and the fixed solver arguments.

    Parameters
    ----------
    io_data : InputOutputData
        Training data; ``io_data.time.u`` is unpacked as ``(N, nu, R, P)``.
    bla : ModelBLA
        Linear model supplying ``A``, ``B_u``, ``C_y``, ``D_yu`` and the
        frequency response ``G_yu``.
    phi : AbstractBasisFunction
        Basis functions; ``phi.nz`` sets the number of ``z`` channels.
    nw : int
        Number of nonlinearity outputs ``w``.
    lambda_w : float
        Regularisation weight (stored as a float32 array).
    fixed_point_iters : int
        Number of fixed-point iterations used inside the loss.
    seed : int
        Seed of the JAX PRNG key used for the random initialisation.

    Returns
    -------
    tuple[_ThetaWZ, _OptiArgs]
        ``(theta_wz, args)``: randomly initialised decision variables and
        the fixed problem data consumed by `_loss_fn`.
    """
    nz = phi.nz
    ny, nx = bla.C_y.shape
    N, nu, R, P = io_data.time.u.shape
    F = io_data.freq.U.shape[0]

    u = io_data.time.u.mean(axis=-1)

    # Initialize theta_wz with independent standard-normal draws
    key = jax.random.PRNGKey(seed)
    key_B_w, key_C_z, key_D_yw, key_D_zu = jax.random.split(key, 4)

    B_w_star = jax.random.normal(key_B_w, (nx, nw))
    C_z_star = jax.random.normal(key_C_z, (nz, nx))
    D_zu_star = jax.random.normal(key_D_zu, (nz, nu))
    D_yw_star = jax.random.normal(key_D_yw, (ny, nw))

    theta_wz = _ThetaWZ(B_w_star, C_z_star, D_yw_star, D_zu_star)
    theta_uy = (jnp.asarray(bla.A), jnp.asarray(bla.B_u), jnp.asarray(bla.C_y))

    # Compute z_star normalization: simulate a model whose w path is
    # switched off (zero beta, B_w and D_yw) and scale each z channel by
    # 2 / (max - min) of the simulated signal.
    beta_dummy = np.zeros((phi.num_features(), nw))
    f_static_dummy = create_custom_basis_function_model(
        nw, phi, beta_dummy
    )
    nonlin_lfr_dummy = ModelNonlinearLFR(
        A=bla.A,
        B_u=bla.B_u,
        C_y=bla.C_y,
        D_yu=bla.D_yu,
        B_w=np.zeros_like(B_w_star),
        C_z=C_z_star,
        D_yw=np.zeros_like(D_yw_star),
        D_zu=D_zu_star,
        f_static=f_static_dummy,
        ts=io_data.time.ts
    )
    handicap = int(np.ceil(0.25 * N))  # a quarter period of warm-up samples
    z_star = nonlin_lfr_dummy.simulate(u, handicap=handicap)[-1]
    z_star_min, z_star_max = z_star.min(axis=(0, 2)), z_star.max(axis=(0, 2))
    T_z_inv = jnp.diag(2 / (z_star_max - z_star_min))

    # Compute Lambda: diagonal weighting per frequency line
    Lambda = np.zeros((F, ny, ny))
    diag_idx = np.arange(ny)

    if P > 1:
        # Inverse of the sample noise variance estimated across periods
        Y = io_data.freq.Y
        Y_P = Y.mean(axis=3)  # Average over periods
        var_noise = ((np.abs(Y - Y_P[..., None])**2).sum(axis=(2, 3))
                     / R / (P - 1))
        Lambda[:, diag_idx, diag_idx] = 1 / np.asarray(var_noise)
    else:
        # No noise estimate with a single period: use identity weighting.
        # Passing a 2-D identity to np.fill_diagonal would write its
        # flattened values (1, 0, 0, ...) along the diagonal when ny > 1.
        Lambda[:, diag_idx, diag_idx] = 1.0

    U_bar = jnp.asarray(io_data.freq.U.mean(axis=3))
    Y_bar = jnp.asarray(io_data.freq.Y.mean(axis=3))
    f_full = jnp.asarray(io_data.freq.f)
    # Evaluated once and shared between `f_data` and `args.G_yu`
    G_yu = jnp.asarray(bla.frequency_response(f_full))
    f_data = (f_full, 1 / io_data.time.ts, U_bar, Y_bar, G_yu)

    args = _OptiArgs(
        theta_uy=theta_uy,
        phi=phi,
        lambda_w=jnp.asarray(lambda_w, dtype=jnp.float32),
        fixed_point_iters=fixed_point_iters,
        f_data=f_data,
        Lambda=jnp.asarray(Lambda),
        Tz_inv=jnp.asarray(T_z_inv),
        G_yu=G_yu,
        N=N
    )
    return theta_wz, args


Step 2: Inference and Learning

In [11]:
import optax

# Define the nonlinear basis function
polynomial_degree = 7
nw = 2
nz = 2
phi = basis_functions.Polynomial(nz, polynomial_degree)

# Candidate solvers; exactly one is selected explicitly below
solver_adam = optx.OptaxMinimiser(optax.adam(learning_rate=1e-3), rtol=1e-3, atol=1e-6)
solver_bfgs = optx.BFGS(rtol=1e-3, atol=1e-6)
solver = solver_bfgs

# Define inference and learning hyperparameters
lambda_w = 1
fixed_point_iterations = 5
max_iter_lfr = 10000  # iteration budget for the nonlinear LFR optimisation

# Solve the problem
nonlin_lfr = solverrr(
    io_data, bla_opti, phi, nw, lambda_w, fixed_point_iterations,
    solver, max_iter_lfr, seed
)

# Simulate and check time-domain performance
y_sim_nonlin_lfr = nonlin_lfr.simulate(u_bar, handicap=handicap)[0]
NRMSE_nonlin_lfr = 100 * np.sqrt(np.mean((y_bar - y_sim_nonlin_lfr)**2)) / np.sqrt(np.mean(y_bar**2))

print(f'NRMSE of nonlin_lfr: {NRMSE_nonlin_lfr:.2f}%\n')




Starting iterative optimization...
   Iteration 0, Loss: 2.2337e+04
   Iteration 1, Loss: 2.2337e+04
   Iteration 2, Loss: 2.2337e+04
   Iteration 3, Loss: 2.2337e+04
   Iteration 4, Loss: 2.2337e+04
   Iteration 5, Loss: 2.2337e+04
   Iteration 6, Loss: 2.2337e+04
   Iteration 7, Loss: 2.2337e+04
   Iteration 8, Loss: 2.2337e+04
   Iteration 9, Loss: 2.2337e+04
   Iteration 10, Loss: 2.2337e+04
   Iteration 11, Loss: 2.2337e+04
   Iteration 12, Loss: 1.9471e+04
   Iteration 13, Loss: 1.9471e+04
   Iteration 14, Loss: 1.9471e+04
   Iteration 15, Loss: 1.9471e+04
   Iteration 16, Loss: 1.9471e+04
   Iteration 17, Loss: 1.9471e+04
   Iteration 18, Loss: 1.9471e+04
   Iteration 19, Loss: 1.9471e+04
   Iteration 20, Loss: 1.9471e+04
   Iteration 21, Loss: 1.9471e+04
   Iteration 22, Loss: 1.9471e+04
   Iteration 23, Loss: 1.4297e+04
   Iteration 24, Loss: 1.4297e+04
   Iteration 25, Loss: 1.4297e+04
   Iteration 26, Loss: 1.4297e+04
   Iteration 27, Loss: 1.4297e+04
   Iteration 28, Loss: 

In [5]:
# Show the basis-function coefficients stored on the fitted static nonlinearity
print(nonlin_lfr.f_static.beta)

[[ 2.1222291e-04 -2.4430049e-04]
 [-9.1545859e-05 -1.2276375e-04]
 [ 7.2191661e-04 -1.1256315e-03]
 [-1.3463055e-04 -2.7200184e-04]
 [ 4.3029609e-04  1.3532129e-04]
 [-1.9501754e-03  3.2395015e-03]
 [-5.0938519e-04  7.8019591e-05]
 [-4.4485510e-05 -5.3317507e-04]
 [ 5.7714613e-04 -2.1349544e-04]
 [-1.5150787e-03  2.0597121e-03]
 [ 9.7047101e-05 -1.1007484e-03]
 [-2.4303029e-06  2.7526354e-05]
 [ 4.6904184e-04  1.1575153e-04]
 [-6.5723871e-04  4.5630496e-04]
 [ 1.7549624e-03 -3.3680950e-03]
 [ 1.3637821e-03  1.5401845e-03]
 [-8.9733355e-04  7.8945112e-04]
 [ 5.9545587e-04 -3.4795579e-05]
 [ 1.5972934e-04 -4.5064121e-04]
 [-1.1526899e-04  2.1643669e-04]
 [-9.9257415e-04  2.5774606e-03]]


In [6]:
# Optimise one-step-ahead state predictions (optional)
# FIXME(review): this cell raises NameError on execution -- neither
# `nonlinear_lfr_opti` nor `IDM` is defined or imported anywhere in the
# visible notebook. It looks like a leftover from an earlier API; port it
# to the current one or remove the cell.
ModelNonlinearLFR_opti = nonlinear_lfr_opti.optimise_state_predictions(
    IDM,
    solver=optx.BFGS(rtol=1e-8, atol=1e-8),
)

NameError: name 'nonlinear_lfr_opti' is not defined

In [None]:
# # Optimise one-step-ahead state predictions (optional)
# ModelNonlinearLFR_opti = nonlinear_lfr_opti.optimise_state_predictions(
#     IDM,
#     solver=optx.BFGS(rtol=1e-8, atol=1e-8),
# )

In [None]:
# Optimise simulation error
#
# Disabled alternative: optimise one-step-ahead state predictions (optional)
# ModelNonlinearLFR_opti = nonlinear_lfr_opti.optimise_state_predictions(
#     IDM,
#     solver=optx.BFGS(rtol=1e-8, atol=1e-8),
# )
#
# FIXME(review): `nonlinear_lfr_opti` and `IDM` are not defined anywhere in
# the visible notebook, so this cell cannot run on a fresh kernel.

from reinbos.utils import misc

# NOTE(review): `ModelNonlinearLFR` is the class imported from
# `src._model_structures`; `B_w` / `D_yw` are read from the class itself
# here, not from a fitted model instance -- presumably an instance such as
# `nonlin_lfr` was intended. Confirm before re-enabling this cell.
custom_init = nonlinear_lfr_opti.DecisionVars(
    B_w=misc.OptiParam(1/10*ModelNonlinearLFR.B_w),
    D_yw=misc.OptiParam(1/10*ModelNonlinearLFR.D_yw),
)


ModelNonlinearLFR_opti_sim = nonlinear_lfr_opti.optimise_simulation_error(
    IDM,
    solver=optx.LevenbergMarquardt(rtol=1e-3, atol=1e-6),
    max_iter=200,
    custom_init=custom_init,
)

In [None]:
# Evaluate the model on test data
#
# FIXME(review): this cell still depends on names that are not defined
# anywhere in the visible notebook -- `IDM`, `u_mean`, `u_std`, `y_mean`,
# `y_std` -- and on `ModelNonlinearLFR_opti_sim` from the preceding cell.
# It cannot run on a fresh kernel until those are provided.
ParWH_test = [
    data for data in ParWH_full_test
    if data.name == f'Val-amp-{amplitude_level}' or data.name == 'ValArr'
]

test_ms = ParWH_test[0]
# NOTE(review): index 2 requires at least three records to match the filter
# above -- confirm against the record names returned by the loader.
test_arr = ParWH_test[2]

# Multisine test
u_test_ms = np.transpose(test_ms.u.reshape(1, nu, N, 2), (2, 1, 0, 3))
y_test_ms = test_ms.y[::2]
u_test_ms = (u_test_ms - u_mean) / u_std

y_bla_ms = IDM.bla.opti.model.simulate(u_test_ms)[1]
y_lfr_ms = ModelNonlinearLFR_opti_sim.simulate(u_test_ms)[2]
y_bla_ms = np.squeeze(y_bla_ms[..., 1] * y_std + y_mean)
y_lfr_ms = np.squeeze(y_lfr_ms[..., 1] * y_std + y_mean)

e_bla_ms = y_test_ms - y_bla_ms
e_lfr_ms = y_test_ms - y_lfr_ms

E_bla_ms = 1 / N * np.fft.rfft(e_bla_ms, axis=0)
E_lfr_ms = 1 / N * np.fft.rfft(e_lfr_ms, axis=0)
Y_test_ms = 1 / N * np.fft.rfft(y_test_ms, axis=0)

print(f'Multisine test RMSE BLA: {np.sqrt(np.mean(e_bla_ms**2)):.4e} ({100*np.std(e_bla_ms)/np.std(y_test_ms):.2f}%)')
print(f'Multisine test RMSE nonlinear LFR: {np.sqrt(np.mean(e_lfr_ms**2)):.4e} ({100*np.std(e_lfr_ms)/np.std(y_test_ms):.2f}%)')

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
axs[0].plot(IDM.data.time.t, y_test_ms, label='system output')
axs[0].plot(IDM.data.time.t, e_bla_ms, label='BLA error')
axs[0].plot(IDM.data.time.t, e_lfr_ms, label='nonlinear LFR error')
axs[0].set_title('Multisine - Time Domain')
axs[0].set_xlabel('time [s]')
axs[0].set_ylabel('amplitude [-]')
axs[0].legend()

axs[1].plot(IDM.data.freq.f[f_idx], 20*np.log10(np.abs(Y_test_ms[f_idx])), label='system output')
axs[1].plot(IDM.data.freq.f[f_idx], 20*np.log10(np.abs(E_bla_ms[f_idx])), label='BLA error')
axs[1].plot(IDM.data.freq.f[f_idx], 20*np.log10(np.abs(E_lfr_ms[f_idx])), label='nonlinear LFR error')
axs[1].set_title('Multisine - Frequency Domain')
axs[1].set_xlabel('frequency [Hz]')
axs[1].set_ylabel('magnitude [dB]')
axs[1].legend()
plt.tight_layout()
plt.show()


# Arrow test
u_test_arr = test_arr.u.reshape(-1, 1, 1)
u_test_arr = (u_test_arr - u_mean) / u_std
y_test_arr = test_arr.y

y_bla_ss = IDM.bla.opti.model.simulate(u_test_arr, P_trans=1)[1]
y_lfr_ss = ModelNonlinearLFR_opti_sim.simulate(u_test_arr, P_trans=1)[2]
y_bla_ss = np.squeeze(y_bla_ss * y_std + y_mean)
y_lfr_ss = np.squeeze(y_lfr_ss * y_std + y_mean)

e_bla_ss = y_test_arr - y_bla_ss
e_lfr_ss = y_test_arr - y_lfr_ss

print(f'Arrow test RMSE BLA: {np.sqrt(np.mean(e_bla_ss**2)):.4e} ({100*np.std(e_bla_ss)/np.std(y_test_arr):.2f}%)')
print(f'Arrow test RMSE nonlinear LFR: {np.sqrt(np.mean(e_lfr_ss**2)):.4e} ({100*np.std(e_lfr_ss)/np.std(y_test_arr):.2f}%)')

# Sample instants n / fs for n = 0 .. len - 1, i.e. exact 1 / fs spacing
t_ss = np.arange(len(y_test_arr)) / fs

plt.figure(figsize=(8, 5))
plt.plot(t_ss, y_test_arr, label='system output')
plt.plot(t_ss, e_bla_ss, label='BLA error')
plt.plot(t_ss, e_lfr_ss, label='nonlinear LFR error')
plt.title('Arrow test - Time Domain')
plt.xlabel('time [s]')
plt.ylabel('amplitude [-]')
plt.legend()
plt.tight_layout()
plt.show()
