In [1]:
import math

import numpy as np
import optax
import optimistix as optx

import freq_statespace as fss

# Load data
amplitude = 300  # [mV] either 100, 200, 300
assert amplitude in (100, 200, 300), f"no training data recorded at {amplitude} mV"

u_train = np.load(f"data/u_{amplitude}mV_train.npy")
y_train = np.load(f"data/y_{amplitude}mV_train.npy")

# First axis is the sample axis; N also sets the frequency resolution fs / N below.
N = u_train.shape[0]

nu, ny = u_train.shape[1], y_train.shape[1]

fs = 6400  # [Hz]
f_max = 3000  # [Hz]
# Frequency lines of interest: bins k = 1, 2, ... with k * fs / N < f_max (DC excluded).
# The bound is computed with one division (f_max * N / fs) instead of f_max / (fs / N):
# the two-step float form can land just above an exact integer and admit one extra bin.
f_idx = np.arange(1, math.ceil(f_max * N / fs))

data = fss.create_data_object(u_train, y_train, f_idx, fs)

In [None]:
# Fit Best Linear Approximation
# An initial model from subspace identification is refined by gradient-based
# optimisation (Adam, wrapped as an optimistix minimiser) on the training `data`.
nx = 28  # model order (number of states)
q = nx + 1  # third argument of fss.lin.subspace_id, kept one above nx

# Iteration budget for the refinement: the 100 mV set has the poorest SNR, where
# running longer overfits the noise, so it gets only 1000 iterations.
max_iter = 5000
if amplitude == 100:
    max_iter = 1000

solver = optx.OptaxMinimiser(optax.adam(learning_rate=1e-4), rtol=1e-3, atol=1e-5)

bla = fss.lin.optimize(
    fss.lin.subspace_id(data, nx, q),
    data,
    solver=solver,
    max_iter=max_iter,
)



In [4]:
## Evaluate BLA on test data at all 3 amplitudes
#
# The per-amplitude work lives in helper functions so that the notebook-level
# N, nu, ny (training-set values) are not overwritten by test-set shapes.

# Passed as `offset` to bla._simulate(); its meaning is defined by freq_statespace
# (presumably a warm-up / transient length in samples -- TODO confirm).
SIM_OFFSET = 1000


def _stack_periods(x):
    """Stack the periods of a 4-D test array along the time axis.

    Takes an array unpacked as (N, channels, R, P) and returns shape
    (N * P, channels, R), with period 0 first, then period 1, and so on
    (the layout expected by _simulate()).
    """
    n_samples, n_channels, n_real, n_periods = x.shape
    # Put the period axis right after the time axis, then flatten both in Fortran
    # order so the stacked time index is n + N * p.
    moved = np.transpose(x, (0, 3, 1, 2))
    return moved.reshape(n_samples * n_periods, n_channels, n_real, order="F")


def _nrmse_percent(y_meas, y_sim):
    """Return the NRMSE per output channel, in percent.

    Squared errors are averaged over axes 0 and 2 (stacked time and realisations).
    The error is normalised by the RMS of the measured output, not by its standard
    deviation.
    """
    mse = np.mean((y_meas - y_sim) ** 2, axis=(0, 2))
    power = np.mean(y_meas**2, axis=(0, 2))
    return 100 * np.sqrt(mse / power)


def _evaluate_bla_at(amp, model, norm):
    """Simulate `model` on the `amp` mV test set and return the per-output NRMSE [%].

    `norm` holds the training-set statistics (u_mean, u_std, y_mean, y_std) used to
    normalise the input and de-normalise the simulated output.
    """
    u_test = np.load(f"data/u_{amp}mV_test.npy")
    y_test = np.load(f"data/y_{amp}mV_test.npy")
    assert u_test.ndim == 4 and y_test.ndim == 4, "expected 4-D (N, channels, R, P) test arrays"
    assert y_test.shape[0] == u_test.shape[0] and y_test.shape[2:] == u_test.shape[2:], \
        "input and output test arrays disagree on N, R or P"

    u_stacked = _stack_periods(u_test)
    y_stacked = _stack_periods(y_test)

    # Normalise input
    u_normed = (u_stacked - norm.u_mean.reshape(1, -1, 1)) / norm.u_std.reshape(1, -1, 1)

    # Simulate (element 0 of the returned value is used as the normalised output)
    y_sim_normed = model._simulate(u_normed, offset=SIM_OFFSET)[0]

    # Denormalise output
    y_sim = y_sim_normed * norm.y_std.reshape(1, -1, 1) + norm.y_mean.reshape(1, -1, 1)

    return _nrmse_percent(y_stacked, y_sim)


amplitudes = [100, 200, 300]
for amp in amplitudes:
    nrmse = _evaluate_bla_at(amp, bla, data.norm)

    print(f"BLA test error at {amp}mV")
    for i, val in enumerate(nrmse):
        print(f"    Output {i+1}: {val:.2f}%")
    print()


BLA test error at 100mV
    Output 1: 26.42%
    Output 2: 40.76%
    Output 3: 30.97%

BLA test error at 200mV
    Output 1: 11.91%
    Output 2: 17.56%
    Output 3: 14.34%

BLA test error at 300mV
    Output 1: 4.44%
    Output 2: 4.57%
    Output 3: 4.58%

