In [None]:
import DeepFMKit.core as dfm
from DeepFMKit.workers import run_efficiency_trial # Import the corrected worker
from DeepFMKit.fit import coeffs                  # Import from the correct module
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from tqdm import tqdm
import multiprocessing
import os

def calculate_crlb_for_m(m_true, ndata, snr_db, buffer_size, signal_power=0.5):
    """
    Calculates the Cramér-Rao Lower Bound for the precision of 'm'.

    The normal matrix J^T J of the fit model is obtained from ``coeffs``,
    evaluated at the noise-free parameters ``[1.0, m_true, 0.0, 0.0]``. The
    CRLB covariance is ``sigma_iq_sq * (J^T J)^-1``. The returned value is the
    square root of its ``m`` (index 1) diagonal entry.

    Parameters
    ----------
    m_true : float
        Modulation depth at which the bound is evaluated.
    ndata : int
        Forwarded to ``coeffs``. The dummy data vector has length
        ``2 * ndata`` (presumably one I and one Q value per harmonic; confirm).
    snr_db : float
        Signal-to-noise ratio in dB.
    buffer_size : int
        Number of time-domain samples per fit buffer. The per-quadrature
        noise variance is ``noise_power / (2 * buffer_size)``.
    signal_power : float, optional
        Time-domain signal power used to turn the SNR into a noise power.
        The default 0.5 is the previously hard-coded value (presumably the
        mean power of a unit-amplitude sinusoid).

    Returns
    -------
    float
        CRLB standard deviation of ``m``. Returns ``nan`` if J^T J is
        singular or the resulting variance is negative or not finite.
    """
    snr_linear = 10**(snr_db / 10.0)
    noise_power_td = signal_power / snr_linear
    sigma_iq_sq = noise_power_td / (2 * buffer_size)

    perfect_params = np.array([1.0, m_true, 0.0, 0.0])
    dummy_data = np.zeros(2 * ndata)

    # Assumes coeffs' second return value is the flattened 4x4 J^T J
    # (unnormalised: the noise variance is applied after inversion).
    # TODO confirm against DeepFMKit.fit.coeffs.
    _, jtj_flat, _ = coeffs(ndata, dummy_data, perfect_params)
    jtj = np.asarray(jtj_flat).reshape(4, 4)

    # Only (J^T J)^-1[1, 1] is needed. Solve for column 1 instead of forming
    # the full inverse, which is cheaper and numerically better behaved.
    unit_m = np.zeros(4)
    unit_m[1] = 1.0
    try:
        inv_col_m = np.linalg.solve(jtj, unit_m)
    except np.linalg.LinAlgError:
        return np.nan

    variance_m = inv_col_m[1] * sigma_iq_sq
    # An ill-conditioned J^T J can yield a negative or non-finite diagonal.
    # Report that explicitly rather than letting np.sqrt warn.
    if not np.isfinite(variance_m) or variance_m < 0:
        return np.nan
    return np.sqrt(variance_m)

def validate_fitter_efficiency(
    m_true=15.5,
    ndata=15,
    snr_db=70.0,
    n_trials=500,
    n_cores=None
):
    """
    Performs a Monte Carlo simulation to test the NLS fitter's performance
    against the theoretical Cramér-Rao Lower Bound (CRLB).

    Each trial is delegated to ``run_efficiency_trial`` in a process pool.
    The spread of the returned estimates of ``m`` is compared with the bound
    from ``calculate_crlb_for_m``. The summary is printed and a histogram of
    the estimates is shown.

    Parameters
    ----------
    m_true : float
        True modulation depth. The laser ``df`` is derived from it and the
        interferometer arm-length difference.
    ndata : int
        Forwarded to the CRLB calculation and to every trial (presumably the
        number of harmonics used by the fitter; confirm).
    snr_db : float
        Signal-to-noise ratio in dB, forwarded to every trial.
    n_trials : int
        Number of Monte Carlo trials.
    n_cores : int, optional
        Number of worker processes. Defaults to ``os.cpu_count()``.

    Returns
    -------
    dict
        Keys ``m_estimates`` (finite estimates only), ``delta_m_crlb``,
        ``delta_m_measured``, ``mean_m_measured``, ``bias`` and
        ``efficiency``. ``efficiency`` is in percent and is 0 when the
        measured spread is zero. If fewer than two finite estimates remain,
        the measured statistics are ``nan`` and the analysis and plot are
        skipped.
    """
    if n_cores is None:
        n_cores = os.cpu_count()

    print("="*60)
    print("Fitter Efficiency Validation: Comparing Measured vs. Theoretical Precision")
    print(f"Parameters: m = {m_true}, ndata = {ndata}, SNR = {snr_db} dB")
    print(f"Number of Monte Carlo trials: {n_trials}")
    print("="*60)

    # --- 1. Setup Simulation Base Configurations ---
    laser_config = dfm.LaserConfig()
    ifo_config = dfm.InterferometerConfig()

    # Invert m = 2*pi*df*opd / c to pick the df that yields m_true.
    opd = ifo_config.meas_arml - ifo_config.ref_arml
    laser_config.df = (m_true * dfm.sc.c) / (2 * np.pi * opd)

    # A throwaway object supplies the per-fit buffer length (in samples and
    # in seconds) that the CRLB and the workers need.
    temp_obj = dfm.DFMIObject("temp", laser_config, ifo_config)
    buffer_size = int(temp_obj.fit_n * (temp_obj.f_samp / temp_obj.laser.f_mod))
    n_seconds_per_buffer = temp_obj.fit_n / temp_obj.laser.f_mod

    # --- 2. Calculate Theoretical CRLB ---
    print("\nCalculating theoretical precision (CRLB)...")
    delta_m_crlb = calculate_crlb_for_m(m_true, ndata, snr_db, buffer_size)
    print(f"Theoretical Precision (CRLB): δm = {delta_m_crlb:.4e}")

    # --- 3. Run Monte Carlo Simulation in Parallel ---
    jobs = [
        {
            'trial_num': i,
            'laser_config': laser_config,
            'ifo_config': ifo_config,
            'n_seconds': n_seconds_per_buffer,
            'snr_db': snr_db,
            'ndata': ndata,
            'm_true': m_true,
        }
        for i in range(n_trials)
    ]

    print(f"\nRunning {n_trials} Monte Carlo simulations in parallel...")
    # The pool is created unconditionally. The former in-function
    # `__name__ == "__main__"` check only made the simulation silently skip
    # when called from an importing module. Spawned workers are protected by
    # the module-level entry-point guard instead.
    raw_results = _run_trials_in_parallel(jobs, n_cores)

    # --- 4. Analyze and Print Results ---
    all_results = np.asarray(raw_results, dtype=float)
    m_estimates = all_results[np.isfinite(all_results)]
    n_valid = m_estimates.size
    n_discarded = n_trials - n_valid
    if n_discarded:
        print(f"Warning: discarded {n_discarded} of {n_trials} trials with non-finite estimates.")

    results = {
        'm_estimates': m_estimates,
        'delta_m_crlb': delta_m_crlb,
        'delta_m_measured': np.nan,
        'mean_m_measured': np.nan,
        'bias': np.nan,
        'efficiency': np.nan,
    }
    if n_valid < 2:
        print("\nFewer than two valid estimates; cannot estimate precision. Skipping analysis and plot.")
        return results

    # ddof=1 makes delta_m_measured**2 the unbiased sample variance, which is
    # the quantity the efficiency ratio compares against the CRLB variance.
    delta_m_measured = np.std(m_estimates, ddof=1)
    mean_m_measured = np.mean(m_estimates)
    bias = mean_m_measured - m_true

    efficiency = (delta_m_crlb**2 / delta_m_measured**2) * 100 if delta_m_measured > 0 else 0

    results.update(
        delta_m_measured=delta_m_measured,
        mean_m_measured=mean_m_measured,
        bias=bias,
        efficiency=efficiency,
    )

    print("\n--- Results ---")
    print(f"Measured Mean of estimates:  <m> = {mean_m_measured:.6f}")
    print(f"Estimator Bias (<m> - m_true):   = {bias:.4e}")
    print(f"Measured Precision (Std Dev): δm = {delta_m_measured:.4e}")
    print(f"Estimator Efficiency (CRLB² / Measured²): {efficiency:.1f}%")

    # --- 5. Plot the Distribution of Estimates ---
    fig, ax = plt.subplots(figsize=(12, 7))

    # The raw-string prefix keeps the LaTeX backslashes literal. The count
    # shown is the number of estimates actually plotted.
    ax.hist(
        m_estimates, bins=50, density=True, alpha=0.6,
        label=r'Histogram of $\hat{m}$ estimates' f'\n($N_{{trials}}={n_valid}$)',
    )
    mu, std = norm.fit(m_estimates)
    bins = np.linspace(mu - 4*std, mu + 4*std, 100)
    p = norm.pdf(bins, mu, std)
    ax.plot(bins, p, 'k--', linewidth=2, label='Fitted Normal Distribution')

    ax.axvline(m_true, color='red', linestyle='-', linewidth=2, label=f'True m = {m_true:.4f}')
    ax.axvline(mean_m_measured, color='black', linestyle='--', linewidth=2, label=f'Measured Mean = {mean_m_measured:.4f}')

    ax.set_title(f"Estimator Performance vs. CRLB (m={m_true}, ndata={ndata}, SNR={snr_db}dB)", fontsize=16)
    ax.set_xlabel(r'Estimated Modulation Depth ($\hat{m}$)', fontsize=14)
    ax.set_ylabel('Probability Density', fontsize=14)

    results_str = (f"Bias = {bias:.2e}\n"
                   f"Measured = {delta_m_measured:.3e}\n"
                   f"CRLB = {delta_m_crlb:.3e}\n"
                   f"Efficiency = {efficiency:.1f}%")
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05, 0.95, results_str, transform=ax.transAxes, fontsize=12, verticalalignment='top', bbox=props)

    ax.legend()
    ax.grid(True, linestyle=':')
    fig.tight_layout()
    plt.show()

    return results


def _run_trials_in_parallel(jobs, n_cores):
    """
    Map ``run_efficiency_trial`` over ``jobs`` using a process pool.

    Parameters
    ----------
    jobs : list of dict
        One configuration dict per Monte Carlo trial.
    n_cores : int or None
        Number of worker processes (``None`` lets ``multiprocessing`` choose).

    Returns
    -------
    list
        Worker results in the same order as ``jobs``. ``Pool.imap`` preserves
        input order.
    """
    with multiprocessing.Pool(processes=n_cores) as pool:
        results_iterator = pool.imap(run_efficiency_trial, jobs)
        # Materialise inside the `with`: leaving it terminates the pool.
        return list(tqdm(results_iterator, total=len(jobs), desc="Running Trials"))

# --- Run the script ---
# Entry-point guard: with multiprocessing's spawn/forkserver start methods,
# worker processes re-import the main module. This check stops them from
# re-launching the validation.
if __name__ == "__main__":
    validate_fitter_efficiency(m_true=15.5, ndata=20, snr_db=70.0, n_trials=1000)