In [None]:
from DeepFMKit.workers import run_single_trial

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import scipy.constants as sc
import multiprocessing
import os

def _run_trials(pool, job_params, desc):
    """Run one batch of trials on ``pool`` and return the results in job order.

    Each entry of ``job_params`` is passed to ``run_single_trial``; progress is
    shown with a ``tqdm`` bar labelled ``desc``.
    """
    results_iterator = pool.imap(run_single_trial, job_params)
    return list(tqdm(results_iterator, total=len(job_params), desc=desc))


def _plot_error_budget(T_acq_arr, delta_phi_stat, delta_phi_sys_mean,
                       delta_phi_sys_std, f0):
    """Draw the error-budget figure from per-acquisition-time phase errors.

    All array arguments are 1-D and aligned with ``T_acq_arr``. ``f0`` is used
    only to derive the wavelength for the secondary (length) axis.
    """
    # Total error is the quadrature sum of the UNCERTAINTIES (standard
    # deviations); the mean bias is plotted separately and not included.
    delta_phi_total = np.sqrt(delta_phi_stat**2 + delta_phi_sys_std**2)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Statistical uncertainty
    ax.loglog(T_acq_arr, delta_phi_stat, 'o-', color='tab:blue', label=r'Statistical Error (Uncertainty)')

    # Systematic uncertainty (std dev of bias)
    ax.loglog(T_acq_arr, delta_phi_sys_std, 's-', color='tab:red', label=r'Systematic Error from Drift (Std. Dev. of Bias)')

    # Systematic mean bias; absolute value because a log axis cannot show
    # negative values
    ax.loglog(T_acq_arr, np.abs(delta_phi_sys_mean), 'x--', color='tab:purple', label=r'Systematic Error from Drift (Mean Bias)')

    # Total error
    ax.loglog(T_acq_arr, delta_phi_total, 'k-', linewidth=3, label='Total Error (Quadrature Sum)')

    # Ambiguity limit
    ax.axhline(np.pi, color='k', linestyle='--', linewidth=2.5, label=r'Ambiguity Limit ($\pi$)')

    ax.set_xlabel(r'Acquisition Time, $T_{\rm acq}$ (s)', fontsize=14)
    ax.set_ylabel(r'Coarse Phase Error, $|\delta\Phi_{\rm coarse}|$ (rad)', fontsize=14)
    ax.set_title('DFMI Error Budget vs. Acquisition Time', fontsize=16)
    ax.grid(True, which='both', linestyle=':')

    # Find and mark the optimal point on the total error curve
    min_error_idx = np.argmin(delta_phi_total)
    optimal_T = T_acq_arr[min_error_idx]
    min_error = delta_phi_total[min_error_idx]
    ax.plot(optimal_T, min_error, 'p', color='gold', markersize=15, markeredgecolor='black', label=f'Optimal Point ({optimal_T:.2f} s)')

    # Secondary axis: the same limits rescaled from radians to nanometres.
    # The primary limits are read only after every artist has been added so
    # the secondary axis mirrors the final view.
    wavelength = sc.c / f0
    phi_to_length_nm = (wavelength / (2 * np.pi)) * 1e9
    ax2 = ax.twinx()
    ymin, ymax = ax.get_ylim()
    ax2.set_yscale('log')
    ax2.set_ylim(ymin * phi_to_length_nm, ymax * phi_to_length_nm)
    ax2.set_ylabel(r'Equivalent Length Error (nm)', fontsize=14)

    ax.legend(fontsize=12, loc='best')
    fig.tight_layout()
    plt.show()


def generate_error_budget_plot_hybrid_parallel(
    m_true=15.5,
    ndata=15,
    f0=193.55e12,
    delta_f=3e9,
    T_acq_range=np.logspace(0, 4, 6),
    n_trials=1000,
    amp_asd=1e-5,
    freq_asd=1e3,
    T_base_stat=0.1,
    n_cores=None
):
    """
    Generates the DFMI error budget plot using a hybrid parallel strategy.

    This function calculates and visualizes the trade-off between
    statistical and systematic errors. For the systematic error from laser
    drift, it calculates and plots both the mean bias and the standard
    deviation of the bias over many trials.

    The method uses three key techniques:
    1.  It characterizes a baseline statistical error at a short acquisition
        time (`T_base_stat`) via a parallel Monte Carlo simulation.
    2.  It analytically extrapolates this baseline error to all other
        acquisition times using the known `1/sqrt(T)` scaling of white noise.
    3.  It uses a parallelized hybrid simulation for laser drift, modeling
        the low-frequency drift accumulation over long times and injecting its
        characteristic distortion into short, high-rate signals for fitting.

    Parameters
    ----------
    m_true : float, optional
        The ground-truth modulation depth to test.
    ndata : int, optional
        The number of harmonics to use in the fit.
    f0 : float, optional
        The laser carrier frequency in Hz.
    delta_f : float, optional
        The laser frequency modulation amplitude in Hz.
    T_acq_range : array_like, optional
        The range of acquisition times (in seconds) to simulate. Every entry
        must be finite and strictly positive.
    n_trials : int, optional
        The number of Monte Carlo trials to run for each simulation point.
    amp_asd : float, optional
        The ASD of the white amplitude noise.
    freq_asd : float, optional
        The ASD of the 1/f laser frequency noise at 1 Hz.
    T_base_stat : float, optional
        The fixed, short acquisition time (in seconds) used to establish the
        baseline statistical uncertainty. Must be strictly positive.
    n_cores : int, optional
        Number of CPU cores to use. If None, uses all available cores.

    Returns
    -------
    None
        Progress is printed and the figure is shown with ``plt.show()``.

    Raises
    ------
    ValueError
        If `T_acq_range` is empty, not one-dimensional, or contains
        non-finite or non-positive values; if `T_base_stat` is not positive;
        or if `n_trials` is less than 1.
    """
    # Validate up front: the 1/sqrt(T) extrapolation and the log-log plot are
    # undefined for T <= 0, and failing here avoids running the (potentially
    # long) baseline simulation first.
    T_acq_arr = np.atleast_1d(np.asarray(T_acq_range, dtype=float))
    if T_acq_arr.ndim != 1 or T_acq_arr.size == 0:
        raise ValueError("T_acq_range must be a non-empty 1-D sequence of acquisition times.")
    if not np.all(np.isfinite(T_acq_arr)) or np.any(T_acq_arr <= 0):
        raise ValueError("All values in T_acq_range must be finite and strictly positive.")
    if not T_base_stat > 0:
        raise ValueError("T_base_stat must be strictly positive.")
    if n_trials < 1:
        raise ValueError("n_trials must be at least 1.")

    if n_cores is None:
        # os.cpu_count() may return None when the count is undetermined.
        n_cores = os.cpu_count() or 1
        print(f"Using all available {n_cores} CPU cores.")

    print("=" * 60)
    print("Generating Error Budget Plot (Hybrid Parallel Version)")
    print(f"Parameters: m_true={m_true}, n_trials={n_trials}, n_cores={n_cores}")
    print("=" * 60)

    # Settings shared by every trial of both simulation types.
    shared_params = {
        'm_true': m_true, 'delta_f': delta_f, 'amp_asd': amp_asd,
        'freq_asd': freq_asd, 'ndata': ndata
    }

    delta_m_stat_list = []
    mean_m_sys_list = []
    std_m_sys_list = []

    # One pool serves every batch, so worker processes are started once
    # instead of once per acquisition time.
    # NOTE(review): assumes run_single_trial does not rely on running in a
    # freshly started worker (e.g. for its random state) -- confirm.
    with multiprocessing.Pool(processes=n_cores) as pool:
        # --- 1. Characterize BASELINE Statistical Error ---
        print(f"\nCharacterizing baseline statistical error at T_acq = {T_base_stat} s...")
        job_params_base = [
            {'trial_num': i, 'sim_type': 'stat_base', 'T_acq': 0,
             'T_base': T_base_stat, **shared_params}
            for i in range(n_trials)
        ]
        m_fits_base = _run_trials(
            pool, job_params_base, f"Base Stat Trials (T={T_base_stat}s)"
        )
        m_std_base = np.std(m_fits_base)
        print(f"--> Baseline statistical uncertainty (std dev of m) = {m_std_base:.3e}")

        # --- 2. Build the two error curves over the full T_acq range ---
        for T_acq in T_acq_arr:
            print(f"\nProcessing T_acq = {T_acq:.4f} s...")

            # A) EXTRAPOLATE Statistical Error (fast, no simulation needed)
            delta_m_stat = m_std_base * np.sqrt(T_base_stat / T_acq)
            delta_m_stat_list.append(delta_m_stat)
            print(f"  --> Extrapolated stat error (m): {delta_m_stat:.3e}")

            # B) RUN HYBRID SIMULATION for Systematic Error (parallelized)
            job_params_sys = [
                {'trial_num': i, 'sim_type': 'sys_hybrid', 'T_acq': T_acq,
                 'T_base': 0, **shared_params}
                for i in range(n_trials)
            ]
            biases_sys = _run_trials(
                pool, job_params_sys, f"Sys Hybrid Trials (T={T_acq:.2f}s)"
            )

            # Mean and standard deviation of the bias, computed once each
            bias_mean = np.mean(biases_sys)
            bias_std = np.std(biases_sys)
            mean_m_sys_list.append(bias_mean)
            std_m_sys_list.append(bias_std)
            print(f"  --> Hybrid systematic error (m): mean={bias_mean:.3e}, std={bias_std:.3e}")

    # --- 3. Analysis and Plotting ---
    print("\nSimulation complete. Generating plot...")

    # Convert m errors to phase errors by scaling with f0 / delta_f
    m_to_phi_factor = f0 / delta_f
    delta_phi_stat = np.array(delta_m_stat_list) * m_to_phi_factor
    delta_phi_sys_mean = np.array(mean_m_sys_list) * m_to_phi_factor
    delta_phi_sys_std = np.array(std_m_sys_list) * m_to_phi_factor

    _plot_error_budget(
        T_acq_arr, delta_phi_stat, delta_phi_sys_mean, delta_phi_sys_std, f0
    )

In [None]:
# Guard the entry point: with the "spawn" start method, multiprocessing
# workers re-import the main module, and an unguarded call would re-run the
# whole simulation inside every worker. In a notebook __name__ is "__main__",
# so the call still executes there.
if __name__ == "__main__":
    generate_error_budget_plot_hybrid_parallel()