In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import constants as sc
import pickle

from DeepFMKit import physics
from DeepFMKit.experiments import Experiment
from DeepFMKit.waveforms import second_harmonic_distortion

def create_distortion_configs(params: dict) -> dict:
    """
    Build the physics configuration objects for one simulation trial.

    Parameters
    ----------
    params : dict
        Trial parameters. Keys read:
        - 'm_main' (required): target modulation depth of the main interferometer.
        - 'distortion_amp' (required): forwarded to ``second_harmonic_distortion``
          through ``laser_config.waveform_kwargs``.
        - 'distortion_phase' (optional, default 0.0): forwarded the same way.
        - 'm_witness' (optional): target modulation depth of the witness
          interferometer. If absent (or None), the witness keeps the default
          ``InterferometerConfig`` settings.

    Returns
    -------
    dict
        Keys 'laser_config', 'main_ifo_config' and 'witness_ifo_config'.

    Raises
    ------
    KeyError
        If 'm_main' or 'distortion_amp' is missing from ``params``.
    """
    m_main = params['m_main']
    # Optional: previously read with params[...] which made the guard below dead
    # and raised KeyError whenever the witness was not requested.
    m_witness = params.get('m_witness')
    distortion_amp = params['distortion_amp']
    distortion_phase = params.get('distortion_phase', 0.0)

    # --- Laser driven by the distorted modulation waveform ---
    laser_config = physics.LaserConfig()
    laser_config.waveform_func = second_harmonic_distortion
    # Parameters for the waveform function travel via waveform_kwargs.
    laser_config.waveform_kwargs = {
        'distortion_amp': distortion_amp,
        'distortion_phase': distortion_phase,
    }

    main_ifo_config = physics.InterferometerConfig()

    # Pick df so that m = 2*pi*df*OPD/c equals m_main for the main interferometer.
    # A zero OPD would divide by zero, so fall back to 0.2 (same length units as
    # the arm lengths -- presumably metres; confirm).
    opd_main = main_ifo_config.meas_arml - main_ifo_config.ref_arml
    if opd_main == 0:
        opd_main = 0.2
    laser_config.df = (m_main * sc.c) / (2 * np.pi * opd_main)

    # --- Witness interferometer sharing the same laser ---
    witness_ifo_config = physics.InterferometerConfig()
    if m_witness is not None and laser_config.df > 0:
        # Invert m = 2*pi*df*OPD/c for the witness OPD at the shared df.
        opd_witness = (m_witness * sc.c) / (2 * np.pi * laser_config.df)
        witness_ifo_config.ref_arml = 0.01
        witness_ifo_config.meas_arml = witness_ifo_config.ref_arml + opd_witness
        # Cancel the static fringe phase 2*pi*f0*OPD/c and add pi/2 so the
        # witness sits at mid-fringe, where it is most sensitive.
        f0 = sc.c / laser_config.wavelength
        static_fringe_phase = (2 * np.pi * f0 * opd_witness) / sc.c
        witness_ifo_config.phi = (np.pi / 2.0) - static_fringe_phase

    return {
        'laser_config': laser_config,
        'main_ifo_config': main_ifo_config,
        'witness_ifo_config': witness_ifo_config,
    }

In [None]:
# Seeded generator so the Monte Carlo over distortion phase is reproducible
# under Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

exp = Experiment(description="Systematic Error from Modulation Non-Linearity")

m_true = 20.0  # target modulation depth of the main interferometer

# Sweep axis: 2nd-harmonic distortion amplitude, 0 to 0.2 in 11 steps (0.02 apart).
exp.add_axis('distortion_amp', np.linspace(0, 0.2, 11))
exp.set_static({
    'm_main': m_true,
    'm_witness': 1.0
})

# Monte Carlo over the distortion phase: n_trials draws, uniform on [0, 2*pi).
exp.n_trials = 100
# NOTE(review): assumes Experiment calls this sampler in the notebook process;
# if run() forks workers that each draw, they would start from the same state.
exp.add_stochastic_variable('distortion_phase', lambda: rng.uniform(0, 2 * np.pi))

# Function that creates the physics configs for each trial
exp.set_config_factory(create_distortion_configs)

# Both analyses run on the *same* simulated data
exp.add_analysis(name='wdfmi_fit', fitter_method='wdfmi_ortho')
exp.add_analysis(name='nls_fit', fitter_method='nls', fitter_kwargs={'parallel': False})

results = exp.run()

In [None]:
results_filename = 'nonlinearity_analysis_refactored.pkl'
# Everything below indexes into `results`, so an empty run must stop here with a
# clear message instead of failing later with an opaque KeyError/TypeError.
if not results:
    raise RuntimeError("exp.run() returned no results; nothing to save or plot.")
with open(results_filename, 'wb') as f:
    pickle.dump(results, f)
print(f"Results saved to {results_filename}")

# --- Analysis and Plotting ---
# np.asarray guards against the axis being stored as a list, where `* 100`
# would repeat the list rather than scale it.
dist_range = np.asarray(results['axes']['distortion_amp'])
m_true = exp.static_params['m_main']

# Y-axis conversion. Take f0 and df from the same config factory the simulation
# used instead of re-hard-coding the wavelength (1064 nm) and OPD (0.2): the
# factory only falls back to an OPD of 0.2 when the default interferometer OPD is 0.
reference_configs = create_distortion_configs({**exp.static_params, 'distortion_amp': 0.0})
f0 = sc.c / reference_configs['laser_config'].wavelength
df_true = reference_configs['laser_config'].df
# Main-interferometer OPD implied by the factory's relation m = 2*pi*df*OPD/c.
opd_main = (m_true * sc.c) / (2 * np.pi * df_true)
# Phi = 2*pi*f0*OPD/c and m = 2*pi*df*OPD/c, so a bias dm maps to a coarse-phase
# error of (f0/df)*dm; the minus sign is the convention kept from the original.
m_to_phase_error_factor = -f0 / df_true

# Stats from raw 'm' data, reduced over the last axis (presumably the Monte Carlo
# trials); nan-aware so NaN trial results are ignored rather than propagated.
wdfmi_m_all_trials = np.asarray(results['wdfmi_fit']['m']['all_trials'])
dfmi_m_all_trials = np.asarray(results['nls_fit']['m']['all_trials'])
wdfmi_bias_all = wdfmi_m_all_trials - m_true
dfmi_bias_all = dfmi_m_all_trials - m_true
wdfmi_phase_error_all = wdfmi_bias_all * m_to_phase_error_factor
dfmi_phase_error_all = dfmi_bias_all * m_to_phase_error_factor
wdfmi_phase_error_mean = np.nanmean(wdfmi_phase_error_all, axis=-1)
wdfmi_phase_error_std = np.nanstd(wdfmi_phase_error_all, axis=-1)
dfmi_phase_error_mean = np.nanmean(dfmi_phase_error_all, axis=-1)
dfmi_phase_error_std = np.nanstd(dfmi_phase_error_all, axis=-1)

# Plotting: mean line plus ±1σ band per method (DFMI drawn first, as before).
fig, ax = plt.subplots(figsize=(12, 7))
dist_percent = dist_range * 100
for err_mean, err_std, fmt, color, method_label in (
    (dfmi_phase_error_mean, dfmi_phase_error_std, 's-', 'tab:red', 'Conventional DFMI'),
    (wdfmi_phase_error_mean, wdfmi_phase_error_std, 'o-', 'tab:blue', 'W-DFMI'),
):
    ax.plot(dist_percent, err_mean, fmt, color=color, lw=2.5, label=f'{method_label} Mean Error')
    ax.fill_between(dist_percent,
                    err_mean - err_std,
                    err_mean + err_std,
                    color=color, alpha=0.2, label=f'{method_label} ±1σ')
ambiguity_limit_rad = np.pi
ax.axhline(ambiguity_limit_rad, color='k', linestyle='--', linewidth=2, label=r'Ambiguity Limit ($\pm\pi$)')
ax.axhline(-ambiguity_limit_rad, color='k', linestyle='--', linewidth=2)
ax.set_xlabel('2nd Harmonic Distortion Amplitude (%)', fontsize=14)
ax.set_ylabel(r'Coarse Phase Error, $\delta\Phi_{\rm coarse}$ (rad)', fontsize=14)
ax.set_title(f"Systematic Error from Modulation Non-Linearity ($m = {m_true:.1f}$)", fontsize=16)
ax.grid(True, which='both', linestyle=':')
ax.legend(fontsize=12)
plt.tight_layout()
plt.show()