In [1]:
import os, copy
import numpy as np
import matplotlib.pyplot as plt
# Select the CPU backend through the environment BEFORE jax is imported;
# the config call below repeats the same choice at the jax level.
os.environ["JAX_PLATFORMS"] = "cpu"
import jax
jax.config.update("jax_platform_name", "cpu")
# Enable 64-bit (double precision) arrays in jax.
jax.config.update("jax_enable_x64", True) 
import jax.numpy as jnp

# NOTE(review): the star import presumably supplies every project helper used
# below (Measurement, gen_dense_series_from_blocks, generate_sine_wave_with_fn,
# prepare_signal_stft, get_stft_model_int_freq_int_amp, EM_bisection,
# perform_EM, apply_filter, plot_* ...) and `nc` -- explicit imports would make
# these dependencies visible.
from KalmanMagnetometry import *

In [2]:
# Enable IPython's autoreload so edits to project modules (e.g. KalmanMagnetometry)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Fix the global (legacy) NumPy RNG so the simulated noise below is reproducible.
np.random.seed(seed=123)

# Simulate signal

In [4]:
### SIMULATION PARAMETERS

# --- sampling and decay ---
f_sampling = 500                     # sampling rate (samples per second)
t2 = 3142.64                         # s, fitted from experimental data

# --- noise ---
snr_dB = 10.0                        # signal-to-noise ratio in dB
var_mn = 1e2                         # variance of the additive measurement noise

# --- signal ---
f_0 = 84.61544342                    # signal frequency (same units as f_sampling)
D = 1e-10                            # frequency random-walk strength; increment variance is 2*D/f_sampling
L_signal = 3 * 60 * 60 * f_sampling  # number of samples: 3 h at f_sampling (= 5_400_000)

In [5]:
# Time axis (s) and additive white measurement noise.
# NOTE: the measurement noise is drawn BEFORE the frequency noise; this RNG
# call order fixes the realisation obtained for the seed set above.
t = np.arange(L_signal) / f_sampling
mn = np.sqrt(var_mn) * np.random.randn(L_signal)

# Initial amplitude from the SNR (snr = A_0**2 / 2 / var_mn), then T2 decay.
snr_lin = 10.0**(snr_dB / 10)
A_0 = np.sqrt(2 * snr_lin * var_mn)
A_true = A_0 * np.exp(-t/t2)

# Per-sample variance of the frequency random-walk increments.
var_fn = 2 * D / f_sampling


def f_func_lin(tt):
    """Deterministic frequency trend: constant f_0 (the linear-drift slope is 0.0)."""
    return f_0 + (tt / t[-1]) * 0.0


# Frequency noise as a cumulative random walk, then the noisy decaying signal.
fn = np.cumsum(np.sqrt(var_fn) * np.random.randn(L_signal))
y, f_true, phase = generate_sine_wave_with_fn(t, f_func_lin, fn)
y = A_true * y + mn


In [6]:
# Scalar metrics gathered by the cells below, keyed by metric name.
data_results = dict()

# Apply SinCos Fit

In [7]:

# Sweep the SinCos-fit block length (s), log-spaced from 10 s up to t[-2], and
# score each fit against the ground truth.
block_size_seconds = np.logspace(np.log10(10.0), np.log10(t[-2]), 40, endpoint=True)
block_size_samples = [int(np.floor(bz * f_sampling)) for bz in block_size_seconds]
n_blocksizes_sc = len(block_size_seconds)

# Per-block-size results. Every list receives exactly ONE entry per block size
# (np.inf / None placeholders when a fit fails), so an index into `mses_f_sc`
# is also a valid index into `block_size_seconds` and the plotting lists.
times_sc = []
freqs_sc = []
freqs_std_sc = []
amps_sc = []
amps_std_sc = []
mses_f_sc = []

y_sc = y.squeeze()  # loop-invariant input signal
for block_size in block_size_seconds:
    try:
        sc_fit = Measurement(y_sc, f_sampling)
        sc_fit.make_sincos_fit(f_0, block_size)
        fit = sc_fit.sincos_fit

        t_sc = fit.time.squeeze()
        f_sc = fit.frequency.squeeze()
        f_std_sc = fit.frequency_std.squeeze()
        amp_sc = fit.amp.squeeze()
        amp_std_sc = fit.amp_std.squeeze()

        # Expand the block-wise estimates onto the dense sample time axis `t`.
        f_sc_dense = gen_dense_series_from_blocks(t, t_sc, f_sc)
        amp_sc_dense = gen_dense_series_from_blocks(t, t_sc, amp_sc)
        f_std_sc_dense = gen_dense_series_from_blocks(t, t_sc, f_std_sc)
        amp_std_sc_dense = gen_dense_series_from_blocks(t, t_sc, amp_std_sc)

        abs_err_f = np.abs(f_true - f_sc_dense)
        abs_err_amp = np.abs(A_true - amp_sc_dense)

        # Collected locally and committed only if every metric succeeds, so a
        # failing block size never leaves partial entries in `data_results`.
        block_results = {
            f"mse_f_sc_{block_size}": np.mean((f_true - f_sc_dense)**2),
            f"mse_amp_sc_{block_size}": np.mean((A_true - amp_sc_dense)**2),
            # fraction of samples whose true value lies within 1 / 2 sigma
            f"p_f_1sigma_sc_{block_size}": np.sum(abs_err_f <= f_std_sc_dense) / L_signal,
            f"p_f_2sigma_sc_{block_size}": np.sum(abs_err_f <= 2 * f_std_sc_dense) / L_signal,
            f"p_amp_1sigma_sc_{block_size}": np.sum(abs_err_amp <= amp_std_sc_dense) / L_signal,
            f"p_amp_2sigma_sc_{block_size}": np.sum(abs_err_amp <= 2 * amp_std_sc_dense) / L_signal,
        }
    except Exception as err:
        # Best effort: skip this block size, but report why and keep every
        # list aligned with `block_size_seconds`.
        print(f"SinCos fit failed for block size {block_size:.1f} s: {err!r}")
        mses_f_sc.append(np.inf)
        for per_block_list in (times_sc, freqs_sc, freqs_std_sc, amps_sc, amps_std_sc):
            per_block_list.append(None)
        continue

    data_results.update(block_results)
    mses_f_sc.append(block_results[f"mse_f_sc_{block_size}"])

    # save for plotting
    times_sc.append(fit.time)
    freqs_sc.append(fit.frequency)
    freqs_std_sc.append(fit.frequency_std)
    amps_sc.append(fit.amp)
    amps_std_sc.append(fit.amp_std)

# np.argmin would return the first NaN, so a NaN MSE must never be selected.
mses_f_sc_arr = np.array(mses_f_sc, dtype=float)
mses_f_sc_arr[np.isnan(mses_f_sc_arr)] = np.inf
if not np.isfinite(mses_f_sc_arr).any():
    raise RuntimeError("All SinCos fits failed; cannot select a best block size.")
idx_min_f_mse_sc = int(np.argmin(mses_f_sc_arr))

# Apply the extended Kalman smoother (EKS)

### Fit model parameters with EM

In [None]:
# Initial guesses for the EM fit: start from the simulation's f_0 and A_0.
f_ini_est, A_ini_est = f_0, A_0

In [None]:
# effective filter bandwidth.
# for small drift, L = 1 should always be fine
L = 1
blocklength_s = 4.5 # sec


t_blocks_ekf, y_ekf, T, N, M, delta_f_ini = prepare_signal_stft(y, f_sampling, blocklength_s, f_ini_est, L=L)

# state space model
sm, DIM_Q, DIM_R, GQ = get_stft_model_int_freq_int_amp(M, L, T)
logamp = False


Q_ini =  jnp.diag(10.0**jnp.array([
    -5.0,  # amp, this value should not matter
    -5.0, # phase, this value should not matter
    -5.0, # freq, this value should not matter
    -5.0,  # integrated freq
    -4.0,  # integrated amp
    # -10.0,  # perturbation
    ][:DIM_Q])) # for const freq, lin amp
P_ini = jnp.diag(10.0**jnp.array([
    -1.0,  # amp
    -1.0, # phase
    -1.0, # freq
    -2.0,  # integrated freq
    -2.0,  # integrated amp
    # -10.0,  # perturbation
    ][:DIM_Q])) # for const freq, lin amp

m_ini = jnp.array([(A_ini_est/(4*np.pi))**0.5, 1.0, delta_f_ini, 0, 0][:DIM_Q]).reshape((DIM_Q, 1))

params = {
    "Q": copy.copy(Q_ini),
    "P_ini": copy.copy(P_ini),
    "R": 10.0**(2.0) * jnp.eye(DIM_R),
    "m_ini": m_ini,
    "f_sampling": f_sampling,
    "GQ": GQ,
    "alpha_R": jnp.array(0.995),
    "alpha_Q": jnp.array(0.995),
}

data_results["Q_ini"] = params["Q"]

init_params = copy.deepcopy(params)

## parameter optimization
progress = nc.jax.RunProgress(0)
L_EM = N

# ### EM bisection ###
q_bisection_idx = [i for i in range(DIM_Q) if GQ[i,i]!=0.0]
alpha_em = 0.8

params, progress = EM_bisection(y_ekf, sm, params, L_EM, n_iter_burn=200, n_iter_bisect = 20, n_max_switches=6, q_factor_ini=100.0, idx_qs=q_bisection_idx, alpha_Q=0.0, bisection_root_factor=0.75, alpha=alpha_em, verbose=False)

## regular EM
l_train_list = [L_EM] * 4
max_iter_list = [200, 200, 200, 400]

params, progress = perform_EM(y_ekf, l_train_list, max_iter_list, sm, params, progress=progress, verbose=True, alpha=alpha_em)


save_params = copy.deepcopy(params)


In [None]:
# Visualise the EM optimisation history recorded in `progress`.
fig, ax = plot_optimization_progress_stft(
    progress,
    M,
    T,
    f_sampling,
)

# Apply EKF + EKS

In [None]:

### apply EKF
print("Apply EKF")

# Filter over all N blocks; w_filter_full is the matching range of N*T samples.
w_filter = np.arange(0, N)
w_filter_full = np.arange(0, N*T)
L_filter = len(w_filter)

m_est, P_est, m_smooth, P_smooth, aux_data = apply_filter(y_ekf, w_filter, sm, save_params)


def _amp_with_sigma(m_state, P_state):
    """Return amplitude A = 4*pi*x**2 and its std for the amplitude state x.

    With x ~ N(mean, var), Var(x**2) = 4*mean**2*var + 2*var**2, so
    Var(A) = (8*pi*mean)**2 * var + 0.5*(8*pi)**2 * var**2.
    """
    mean = m_state[:, 0, 0]
    var = P_state[:, 0, 0]
    amp = 4*np.pi*(mean**2)
    sigma = np.sqrt((8*np.pi*mean)**2 * var + 0.5 * (8*np.pi)**2 * var**2)
    return amp, sigma


def _freq_with_sigma(m_state, P_state):
    """Return frequency (units of f_sampling) and its std from the frequency state.

    f = (x + M) / T * f_sampling is linear in x, so the std scales by the same
    factor: sigma_f = sqrt(Var(x)) / T * f_sampling. (The previous
    sqrt(Var(x) / T * f_sampling) was not a valid propagation and has the
    wrong units.)
    """
    freq = (m_state[:, 2, 0] + M) / (T) * f_sampling
    sigma = np.sqrt(P_state[:, 2, 2]) / (T) * f_sampling
    return freq, sigma


### AMPLITUDE ###
amp_ekf, sigma_amp_ekf = _amp_with_sigma(m_est, P_est)
amp_eks, sigma_amp_eks = _amp_with_sigma(m_smooth, P_smooth)

### FREQUENCY ###
f_ekf, sigma_f_ekf = _freq_with_sigma(m_est, P_est)
f_eks, sigma_f_eks = _freq_with_sigma(m_smooth, P_smooth)

## tracked amplitudes / frequencies (and their std) for plotting ##
data_filter = {
    "t_blocks": t_blocks_ekf,
    "amp_ekf": amp_ekf,
    "sigma_amp_ekf": sigma_amp_ekf,
    "amp_eks": amp_eks,
    "sigma_amp_eks": sigma_amp_eks,
    "f_ekf": f_ekf,
    "sigma_f_ekf": sigma_f_ekf,
    "f_eks": f_eks,
    "sigma_f_eks": sigma_f_eks,
}

# Calculate metrics (MSE and 1-sigma coverage)

In [None]:
####################### CALCULATE METRICS ######################
# Compare the block-wise EKF/EKS estimates with the ground truth on the dense
# sample grid covered by the filter window.
t_dense = t[w_filter_full]
t_blocks_dense = t_blocks_ekf[w_filter]
n_dense = w_filter_full.shape[0]


def _to_dense(block_values):
    """Expand a per-block series onto the dense filter time axis."""
    return gen_dense_series_from_blocks(t_dense, t_blocks_dense, block_values)


def _mse_and_coverage(true_dense, est_dense, sigma_dense):
    """Return (MSE, fraction of samples with |true - est| <= sigma).

    `<=` matches the 1-sigma coverage computed for the SinCos fits.
    """
    err = true_dense - est_dense
    return np.mean(err**2), np.sum(np.abs(err) <= sigma_dense) / n_dense


### AMPLITUDE ###
A_true_dense = A_true[w_filter_full]
amp_ekf_dense = _to_dense(amp_ekf)
amp_eks_dense = _to_dense(amp_eks)
sigma_amp_ekf_dense = _to_dense(sigma_amp_ekf)
sigma_amp_eks_dense = _to_dense(sigma_amp_eks)

mse_amp_ekf, p_amp_ekf = _mse_and_coverage(A_true_dense, amp_ekf_dense, sigma_amp_ekf_dense)
mse_amp_eks, p_amp_eks = _mse_and_coverage(A_true_dense, amp_eks_dense, sigma_amp_eks_dense)

data_results["mse_amp_ekf"] = mse_amp_ekf
data_results["mse_amp_eks"] = mse_amp_eks
data_results["p_amp_1sigma_ekf"] = p_amp_ekf
data_results["p_amp_1sigma_eks"] = p_amp_eks

### FREQUENCY ###
f_true_dense = f_true[w_filter_full]
f_ekf_dense = _to_dense(f_ekf)
f_eks_dense = _to_dense(f_eks)
sigma_f_ekf_dense = _to_dense(sigma_f_ekf)
sigma_f_eks_dense = _to_dense(sigma_f_eks)

mse_f_ekf, p_f_ekf = _mse_and_coverage(f_true_dense, f_ekf_dense, sigma_f_ekf_dense)
mse_f_eks, p_f_eks = _mse_and_coverage(f_true_dense, f_eks_dense, sigma_f_eks_dense)

data_results["mse_f_ekf"] = mse_f_ekf
data_results["mse_f_eks"] = mse_f_eks
data_results["p_f_1sigma_ekf"] = p_f_ekf
data_results["p_f_1sigma_eks"] = p_f_eks


In [None]:
# Frequency MSE of the Kalman filter / smoother versus the best SinCos fit.
for method in ("EKF", "EKS"):
    print(f"MSE F {method}: {data_results['mse_f_' + method.lower()]:.2e}")

print(f"MSE F SC BEST: {mses_f_sc[idx_min_f_mse_sc]:.2e} @ BL {block_size_seconds[idx_min_f_mse_sc]:.1f} s blocklength")


# Plot results

In [None]:
# Overlay the ground truth and the best SinCos fit (lowest frequency MSE) on
# the tracked EKF/EKS variables. Every band / error bar shows +- n_sigma std;
# previously the amplitude error bar used 1 sigma despite its "2 sigma" label.
n_sigma = 2
fig, ax = plot_tracked_vars_stft(data_filter, n_sigma=n_sigma)

idx_best_sc = idx_min_f_mse_sc
sc_label = f"best SC +- {n_sigma} sigma"

ax[0].plot(t, f_true, color="red", linestyle="--", zorder=-4, label="True")
ax[0].errorbar(times_sc[idx_best_sc], freqs_sc[idx_best_sc], yerr=n_sigma*freqs_std_sc[idx_best_sc], color="green", capsize=5, linestyle="", marker="^", label=sc_label)
ax[1].plot(t, A_true, color="red", linestyle="--", zorder=4, label="True")
ax[1].errorbar(times_sc[idx_best_sc], amps_sc[idx_best_sc], yerr=n_sigma*amps_std_sc[idx_best_sc], color="green", capsize=5, linestyle="", marker="^", label=sc_label)
for a in ax:
    a.legend()