In [None]:
import DeepFMKit.core as dfm
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

def _mean_std(values):
    """Return ``(mean, std)`` of *values*, or ``(nan, nan)`` when *values* is empty."""
    # Avoid np.mean([]) / np.std([]) RuntimeWarnings while keeping the NaN result.
    if len(values) == 0:
        return np.nan, np.nan
    return np.mean(values), np.std(values)


def _run_phase_trial(eps, trial, m_main, m_witness, uniform):
    """Simulate one random-distortion-phase trial; return (standard bias, W-DFMI bias).

    Each bias is ``m_hat - m_main`` taken from the last fitted value of ``m``,
    or ``None`` when the corresponding fitter produced no result.
    """
    dff = dfm.DeepFitFramework()

    # 1. Main channel configuration with this trial's modulation distortion.
    laser_config = dfm.LaserConfig()
    main_ifo_config = dfm.InterferometerConfig()
    laser_config.df_2nd_harmonic_frac = eps
    laser_config.df_2nd_harmonic_phase = uniform(0, 2 * np.pi)

    # Pick the frequency-modulation amplitude that yields m = m_main for this
    # OPD: m = 2*pi*df*OPD/c  =>  df = m*c / (2*pi*OPD).  This is set BEFORE the
    # DFMIObject is built so the channel sees the final value whether it stores
    # a reference to the config or a copy of it.
    opd_main = main_ifo_config.meas_arml - main_ifo_config.ref_arml
    laser_config.df = (m_main * dfm.sc.c) / (2 * np.pi * opd_main)

    main_label = f"main_{eps:.2f}_{trial}"
    main_channel = dfm.DFMIObject(main_label, laser_config, main_ifo_config)
    dff.sims[main_label] = main_channel

    # 2. Witness channel linked to the main channel (same laser modulation).
    witness_label = f"witness_{eps:.2f}_{trial}"
    witness_channel = dff.create_witness_channel(
        main_channel_label=main_label,
        witness_channel_label=witness_label,
        m_witness=m_witness,
    )
    witness_channel.ifo.phi = 0.0

    # 3. Simulate both channels for exactly one fit buffer (fit_n modulation cycles).
    n_seconds = main_channel.fit_n / laser_config.f_mod
    dff.simulate(main_label, n_seconds=n_seconds, witness_label=witness_label)

    # 4. Standard NLS fitter (expected to be biased by the distortion).
    std_bias = None
    fit_obj_std = dff.fit(main_label, fit_label="std_fit", verbose=False, parallel=False)
    if fit_obj_std and fit_obj_std.m.size > 0:
        std_bias = fit_obj_std.m[-1] - m_main

    # 5. W-DFMI fitter (uses the witness channel to correct the distortion).
    wdfmi_bias = None
    fit_obj_wdfmi = dff.fit_wdfmi(main_label, witness_label, fit_label="wdfmi_fit", verbose=False)
    if fit_obj_wdfmi and fit_obj_wdfmi.m.size > 0:
        wdfmi_bias = fit_obj_wdfmi.m[-1] - m_main

    return std_bias, wdfmi_bias


def _plot_bias_comparison(distortion_range, standard_bias_mean, standard_bias_std,
                          wdfmi_bias_mean, wdfmi_bias_std):
    """Plot standard-vs-W-DFMI bias (mean +/- std) against distortion and return the Axes."""
    _, ax = plt.subplots(figsize=(12, 7))
    eps_percent = distortion_range * 100  # fractional amplitude -> percent

    ax.errorbar(eps_percent, standard_bias_mean, yerr=standard_bias_std,
                fmt='o-', capsize=5, color='tab:red', label='Standard NLS Fitter Bias')
    ax.errorbar(eps_percent, wdfmi_bias_mean, yerr=wdfmi_bias_std,
                fmt='s-', capsize=5, color='tab:green', label='W-DFMI Fitter Bias (Corrected)')

    ax.axhline(0, color='k', linestyle='--', linewidth=1, alpha=0.7)
    ax.set_xlabel('2nd Harmonic Distortion Amplitude (%)', fontsize=14)
    ax.set_ylabel(r"Bias in Modulation Depth, $\delta m = \hat{m} - m_{\rm true}$ (rad)", fontsize=14)
    ax.set_title('W-DFMI Correction of Systematic Bias from Modulation Non-Linearity', fontsize=16)
    ax.grid(True, which='both', linestyle=':')
    ax.legend(fontsize=12)
    plt.tight_layout()
    return ax


def validate_wdfmi_bias_correction(
    m_main=15.5,
    m_witness=0.05,
    ndata=15,
    distortion_range=np.linspace(0, 0.05, 11),
    n_phase_trials=50,
    seed=None,
):
    """
    Generate a plot validating the bias-correction capability of W-DFMI.

    For each distortion amplitude, this runs ``n_phase_trials`` noiseless
    simulations. In each one the laser modulation carries a second-harmonic
    distortion with a random phase, and a main channel plus a linked witness
    channel are driven by that same modulation. The main channel is then fitted
    in two ways:

    1. The standard NLS fitter (``fit``).
    2. The W-DFMI fitter (``fit_wdfmi``), which also uses the witness channel.

    The mean and standard deviation (over phase trials) of the bias
    ``m_hat - m_main`` are plotted against the distortion amplitude for both
    fitters.

    Parameters
    ----------
    m_main : float, optional
        Ground-truth modulation depth for the main channel.
    m_witness : float, optional
        Target modulation depth for the witness channel.
    ndata : int, optional
        Intended number of harmonics to use in the fit.
        NOTE(review): this value is currently only printed in the header. It
        is not passed to either fitter. TODO wire it into the fit configuration.
    distortion_range : array_like, optional
        Fractional second-harmonic distortion amplitudes (epsilon) to test.
        It is converted to a float ndarray, so plain lists are accepted.
    n_phase_trials : int, optional
        Number of random distortion-phase trials per distortion amplitude.
    seed : int or None, optional
        Seed for the distortion-phase draws. ``None`` (the default) uses the
        global ``np.random`` state, as earlier versions did. An integer gives
        reproducible phases from an independent generator.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes holding the comparison plot.

    Notes
    -----
    A distortion level where a fitter returned no result in every trial is
    reported as NaN. The total number of dropped fits is printed.
    """
    # Coerce so that ``distortion_range * 100`` scales element-wise; a plain
    # list would otherwise be repeated 100 times.
    distortion_range = np.asarray(distortion_range, dtype=float)

    if seed is None:
        uniform = np.random.uniform  # global RNG: honours a prior np.random.seed()
    else:
        uniform = np.random.default_rng(seed).uniform

    print("=" * 60)
    print("Validating W-DFMI Bias Correction vs. Modulation Non-Linearity")
    print(f"Parameters: m_main={m_main}, m_witness={m_witness}, ndata={ndata}")
    print("=" * 60)

    # Per-distortion-level statistics for plotting.
    standard_bias_mean, standard_bias_std = [], []
    wdfmi_bias_mean, wdfmi_bias_std = [], []
    n_missing_std = 0
    n_missing_wdfmi = 0

    # Outer loop: distortion amplitude.  Inner loop: Monte Carlo over phase.
    for eps in tqdm(distortion_range, desc="Distortion Level"):
        standard_biases = []
        wdfmi_biases = []

        for i in range(n_phase_trials):
            std_bias, wdfmi_bias = _run_phase_trial(eps, i, m_main, m_witness, uniform)
            if std_bias is None:
                n_missing_std += 1
            else:
                standard_biases.append(std_bias)
            if wdfmi_bias is None:
                n_missing_wdfmi += 1
            else:
                wdfmi_biases.append(wdfmi_bias)

        mean, std = _mean_std(standard_biases)
        standard_bias_mean.append(mean)
        standard_bias_std.append(std)
        mean, std = _mean_std(wdfmi_biases)
        wdfmi_bias_mean.append(mean)
        wdfmi_bias_std.append(std)

    if n_missing_std or n_missing_wdfmi:
        print(f"Warning: {n_missing_std} standard and {n_missing_wdfmi} W-DFMI fits "
              f"returned no result and were excluded from the statistics.")

    # Summary of the W-DFMI fitter, averaged over all distortion levels.
    print(f'Mean of the mean: {np.mean(wdfmi_bias_mean)}')
    print(f'Mean of the standard deviation: {np.mean(wdfmi_bias_std)}')

    return _plot_bias_comparison(distortion_range, standard_bias_mean, standard_bias_std,
                                 wdfmi_bias_mean, wdfmi_bias_std)

# Run the validation over five distortion amplitudes (0.01 % to 10 %) with
# 20 random-phase trials per amplitude, then display the comparison figure.
ax = validate_wdfmi_bias_correction(
    m_main=11.5,
    m_witness=0.05,
    ndata=30,
    distortion_range=np.linspace(0.0001, 0.1, 5),
    n_phase_trials=20,
)
plt.show()