In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import constants as sc
import pickle

from DeepFMKit.experiments import Experiment
from DeepFMKit.factories import StandardWDFMIExperimentFactory
from DeepFMKit.waveforms import second_harmonic_distortion

# --- 1. Declaratively Define the Experiment ---
# Sweep the second-harmonic distortion amplitude of the modulation waveform and
# run both the W-DFMI ('wdfmi_ortho') and standard NLS DFMI fitters at each point.
SEED = 42  # seed for NumPy's global RNG (used for the random distortion phases)

exp = Experiment(description="Systematic Error due to Modulation Non-Linearity and Calibration")
exp.set_config_factory(
    StandardWDFMIExperimentFactory(waveform_function=second_harmonic_distortion)
)

# Distortion amplitude from 0.00 to 0.10 in 11 evenly spaced steps (0.01 increments).
exp.add_axis('distortion_amp', np.linspace(0.00, 0.1, 11))
exp.set_static({
    'm_main': 20.0,
    'm_witness': 0.04,
})
# Only a few trials per sweep point: the worst-case (max |error|) statistic in the
# analysis cell is taken over just these trials, so treat it as a coarse estimate.
exp.n_trials = 4
exp.n_seconds_per_trial = 0.5

# The waveform kwargs combine the current sweep amplitude with a uniformly random
# distortion phase in [0, 2*pi). NOTE(review): presumably re-drawn per trial by
# the Experiment — confirm.
exp.add_stochastic_variable(
    'waveform_kwargs', 
    lambda dist_amp: {'distortion_amp': dist_amp, 'distortion_phase': np.random.uniform(0, 2*np.pi)},
    depends_on='distortion_amp'
)

exp.add_analysis(name='wdfmi_fit', fitter_method='wdfmi_ortho', result_cols=['tau'])
exp.add_analysis(name='nls_fit', fitter_method='nls', result_cols=['m'], fitter_kwargs={'ndata': 30, 'parallel': False})

print("Experiment configured.")

# --- 2. Run the Experiment ---
# Seed immediately before running so the random distortion phases are
# reproducible on Restart & Run All. NOTE(review): this only covers draws made via
# np.random in this process — confirm whether Experiment.run uses other RNGs or
# worker processes with their own state.
np.random.seed(SEED)
results = exp.run()

# NOTE: unpickling executes arbitrary code; only reload this file from a trusted source.
results_filename = '4_panel_comparison_results.pkl'
with open(results_filename, 'wb') as f:
    pickle.dump(results, f)
print(f"Results saved to {results_filename}")

In [None]:
# Fractional miscalibration of the modulation depth df that the experimenter
# assumes when converting the NLS modulation index m into a delay.
df_cal_error_frac = 0.01

# --- 1. Ground Truth and Experimenter's Estimate Calculation ---
# Ground truth is read from the config at the first sweep point.
# NOTE(review): this assumes arm lengths, wavelength and df do not vary along the
# 'distortion_amp' axis — confirm against the config factory.
sample_params = exp.get_params_for_point(axis_idx=0)
configs = exp.config_factory(sample_params)
opd_true = configs['main_ifo_config'].meas_arml - configs['main_ifo_config'].ref_arml
tau_true = opd_true / sc.c
wavelength = configs['laser_config'].wavelength
df_true = configs['laser_config'].df
# The experimenter's (biased) belief about the modulation depth.
df_est = df_true * (1.0 - df_cal_error_frac)

# --- 2. Extract and Process Fitter Results ---
# Coerce both fitter outputs to float arrays so the arithmetic below is
# element-wise even if the results container holds plain (nested) lists;
# dtype=float also maps any None entries to NaN for the nan-aware statistics.
wdfmi_tau_all = np.asarray(results['wdfmi_fit']['tau']['all_trials'], dtype=float)
nls_m_all = np.asarray(results['nls_fit']['m']['all_trials'], dtype=float)
# Standard DFMI: m = 2*pi*df*tau, inverted with the *estimated* df, so the
# calibration error propagates into the recovered delay.
nls_tau_all = nls_m_all / (2 * np.pi * df_est)

# Calculate absolute length error (metres) for all trials
wdfmi_len_err_all = (wdfmi_tau_all - tau_true) * sc.c
nls_len_err_all = (nls_tau_all - tau_true) * sc.c


# --- 3. Calculate Statistics ---
def _error_stats(len_err_all):
    """Return (mean, std, max |error|) along the last axis, ignoring NaNs.

    The last axis holds the individual trials (presumably one entry per trial
    of each sweep point — confirm against the results layout).
    """
    return (
        np.nanmean(len_err_all, axis=-1),
        np.nanstd(len_err_all, axis=-1),
        np.nanmax(np.abs(len_err_all), axis=-1),  # worst-case error
    )


wdfmi_mean_err, wdfmi_std_err, wdfmi_worst_err = _error_stats(wdfmi_len_err_all)
nls_mean_err, nls_std_err, nls_worst_err = _error_stats(nls_len_err_all)

# --- 4. Plotting ---
fig, axes = plt.subplots(2, 2, figsize=(14, 10), sharex=True, sharey=True)
# suptitle is top-aligned, so anchor it above the 0.95 layout rect used below;
# at y=0.95 its text would extend down into the panel titles.
fig.suptitle(exp.description, fontsize=18, y=0.98)

# asarray guards against a list-valued axis, where "* 100" would repeat the list.
dist_axis_pct = np.asarray(results['axes']['distortion_amp'], dtype=float) * 100
ambiguity_limit_um = (wavelength / 2) * 1e6

# Top row: mean error with +/-1 sigma error bars (NLS left, W-DFMI right).
for ax, mean_err, std_err, fmt, color, title in (
    (axes[0, 0], nls_mean_err, nls_std_err, 'o-', 'tab:red', 'Standard DFMI: Mean Error'),
    (axes[0, 1], wdfmi_mean_err, wdfmi_std_err, 's-', 'tab:green', 'W-DFMI: Mean Error'),
):
    ax.errorbar(dist_axis_pct, mean_err * 1e6, yerr=std_err * 1e6,
                fmt=fmt, capsize=4, color=color, label='Mean ± 1σ')
    ax.set_title(title, fontsize=14)

# Bottom row: worst-case (maximum absolute) error (NLS left, W-DFMI right).
for ax, worst_err, fmt, color, title in (
    (axes[1, 0], nls_worst_err, 'o-', 'maroon', 'Standard DFMI: Worst-Case Error'),
    (axes[1, 1], wdfmi_worst_err, 's-', 'darkgreen', 'W-DFMI: Worst-Case Error'),
):
    ax.plot(dist_axis_pct, worst_err * 1e6, fmt, color=color, label='Max Absolute Error')
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('2nd Harmonic Distortion (%)', fontsize=12)

for ax in axes[:, 0]:
    ax.set_ylabel(r'Absolute Length Error ($\mu$m)', fontsize=12)

# Apply common formatting to all panels. The legend is built last, after the
# labelled limit line exists, so every panel lists it exactly once.
for ax in axes.flat:
    ax.axhline(ambiguity_limit_um, color='k', linestyle='--', linewidth=1.5, alpha=0.8, label=r'$\pm\lambda_0/2$ Limit')
    ax.axhline(-ambiguity_limit_um, color='k', linestyle='--', linewidth=1.5, alpha=0.8)
    ax.axhline(0, color='k', linestyle='-', linewidth=1, alpha=0.5)
    ax.grid(True, linestyle=':', alpha=0.7)
    ax.legend()

plt.tight_layout(rect=[0, 0, 1, 0.95]) # Adjust for suptitle
plt.show()