In [None]:
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm

import DeepFMKit.physics as physics
from DeepFMKit.fitters import StandardNLSFitter, EKFFitter
from DeepFMKit.helpers import set_laser_df_for_effect

In [None]:
# Sampling rate and laser modulation frequency, in Hz
# (the generator is later asked for n_seconds=1/f_mod).
f_samp = 200_000.0
f_mod = 1_000.0

# Default laser configuration with the modulation frequency overridden.
laser = physics.LaserConfig()
laser.f_mod = f_mod

# Interferometer, the DFMIObject built from both configs, and the signal generator.
ifo = physics.InterferometerConfig()
main = physics.DFMIObject(
    label='Hello',
    laser_config=laser,
    ifo_config=ifo,
    f_samp=f_samp,
)
sg = physics.SignalGenerator()

# Ask the helper to tune the laser so the target "effect" value is 6.
# NOTE(review): this mutates `laser` *after* `main` was constructed; it only
# affects `main` if DFMIObject keeps a reference rather than a copy — confirm.
set_laser_df_for_effect(laser, ifo, 6.)

In [None]:
# SNR sweep used both for this preview plot and the Monte-Carlo study below.
dB_list = np.linspace(20, 120, 40)

# One colour per SNR curve (len(dB_list) entries, lowest SNR darkest).
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
# Matplotlib 3.7 and removed in 3.9.
viridis = plt.get_cmap('viridis_r', len(dB_list))

fig, ax = plt.subplots(figsize=(4, 2), dpi=300)
for i, dB in enumerate(dB_list):
    # One modulation period (1/f_mod seconds) of the 'main' channel at this SNR,
    # always with the same trial_num so only the SNR changes between curves.
    raw = sg.generate(main, n_seconds=1/f_mod, mode='snr', snr_db=dB, trial_num=3)['main']
    ax = raw.plot(ax=ax, ls='--', lw=0.8, color=viridis(i))
    ax.grid(False)

plt.show()

In [None]:
# Fitter instances; also used by process_trial in the Monte-Carlo sweep.
nls = StandardNLSFitter({'n': 1, 'ndata': 15})
ekf = EKFFitter({'n': 1})

# Demo signal at the highest SNR of the sweep. It is generated explicitly, with the
# same arguments as the plotting loop's last iteration, so this cell does not
# depend on the loop variable left behind by the plotting cell.
# NOTE(review): identical to that trace only if sg.generate is deterministic
# for a fixed trial_num — confirm.
raw_demo = sg.generate(main, n_seconds=1/f_mod, mode='snr', snr_db=dB_list[-1], trial_num=3)['main']

print(nls.fit(raw_demo, parallel=False))
print(ekf.fit(raw_demo, verbose=False))

In [None]:
# B: raw samples per fit buffer, i.e. samples in one modulation period.
# int() truncates, so B is exact only when f_samp is an integer multiple of f_mod.
B = int(f_samp / laser.f_mod)

# Duration of one fit buffer in seconds; equals 1/laser.f_mod whenever the
# ratio above is an integer.
n_seconds_to_simulate = B / f_samp

In [None]:
num_trials = 500


def process_trial(dB, trial_num):
    """Simulate one noisy buffer and fit it with both fitters.

    Parameters
    ----------
    dB : float
        SNR passed to ``sg.generate`` as ``snr_db``.
    trial_num : int
        Trial index forwarded to ``sg.generate`` (presumably selects the noise
        realisation — TODO confirm).

    Returns
    -------
    tuple
        ``(m_nls, m_ekf)``: the first ``'m'`` value returned by the NLS and
        EKF fitters respectively.
    """
    trial_signal = sg.generate(main, n_seconds=1/f_mod, mode='snr', snr_db=dB, trial_num=trial_num)['main']
    df_nls = nls.fit(trial_signal, parallel=False)
    df_ekf = ekf.fit(trial_signal, verbose=False)
    return df_nls['m'].iloc[0], df_ekf['m'].iloc[0]


# nls_result[i] / ekf_result[i] hold num_trials fitted m values for dB_list[i].
nls_result = []
ekf_result = []

# One Parallel object (and worker pool) reused for every SNR, rather than a
# fresh Parallel per dB value.
with Parallel(n_jobs=-1, backend='loky') as parallel:
    for dB in tqdm(dB_list, desc='SNR sweep'):
        results = parallel(delayed(process_trial)(dB, j) for j in range(num_trials))
        nls_result.append([m_nls for m_nls, _ in results])
        ekf_result.append([m_ekf for _, m_ekf in results])

In [None]:
# Per-SNR spread of the fitted m values: np.var with its default ddof=0,
# one entry per element of dB_list.
snr_indices = range(len(dB_list))
nls_var = [np.var(nls_result[k]) for k in snr_indices]
ekf_var = [np.var(ekf_result[k]) for k in snr_indices]

In [None]:
fig, ax = plt.subplots(figsize=(3, 2))

# Estimator variances are scaled by 2*B (B = samples per fit buffer) before
# being compared with the CRLB curve, which carries no B dependence as plotted.
variance_scale = 2 * B
ax.semilogy(dB_list, variance_scale * np.array(nls_var), label='NLS')
ax.semilogy(dB_list, variance_scale * np.array(ekf_var), label='EKF', ls='--')

# CRLB reference curve: 8 / SNR**2, with SNR = 10**(dB/20) the linear voltage ratio.
snr_linear = 10**(dB_list / 20)
ax.semilogy(dB_list, 8 / snr_linear**2, ls=':', label='CRLB')

ax.legend(edgecolor='k', framealpha=1)
ax.set_title('Estimator variance vs. SNR')
ax.set_xlabel('Voltage SNR (dB)')
# The y values are 2*B*Var(m), not the raw variance, so label them as such.
ax.set_ylabel(r'Scaled variance $2B\,\delta m^2$')
plt.show()