# In [None]:
import DeepFMKit.core as dfm
import numpy as np
import scipy.optimize
from tqdm import tqdm
import matplotlib.pyplot as plt

def test_tau_estimation_varpro(m_true=15.5, distortion_eps=0.1):
    """
    Test and visualize the robustness of tau estimation using the
    Variable Projection (VarPro) / Orthogonal Demodulation method.

    For each trial delay ``tau`` the non-linear part of the model, the
    modulation phase difference ``phi(t - tau) - phi(t)``, is held fixed. The
    linear parameters (in-phase and quadrature amplitudes plus a DC offset)
    are then solved analytically by linear least squares. The remaining
    residual defines a one-dimensional cost landscape S(tau). S(tau) is
    scanned on a grid around the true delay, and the best grid point is
    refined with a bounded scalar minimization.

    Parameters
    ----------
    m_true : float
        Target modulation depth; used to set the laser frequency deviation
        ``laser.df``.
    distortion_eps : float
        Value assigned to ``laser.df_2nd_harmonic_frac`` (fractional 2nd
        harmonic distortion of the frequency modulation).

    Raises
    ------
    ValueError
        If the witness record is constant (no modulation basis can be built)
        or its length differs from the main-channel record.

    Notes
    -----
    Prints progress and the true/fitted delay, then shows a matplotlib
    figure of S(tau). Returns ``None``. The ``test_`` prefix means pytest
    may collect this function, and pytest warns on non-None returns.
    """
    print("=" * 60)
    print("Testing Robustness of Step 1: Tau Estimation (VarPro Method)")
    print(f"m_true = {m_true}, distortion = {distortion_eps*100}%")
    print("=" * 60)

    # --- 1. Simulate a single, representative signal ---
    dff = dfm.DeepFitFramework()
    laser = dfm.LaserConfig()
    laser.df_2nd_harmonic_frac = distortion_eps
    ifo = dfm.InterferometerConfig()
    main_channel = dfm.DFMIObject("main", laser, ifo)

    # Pick df so that m = 2*pi*df*opd/c equals m_true.
    opd = ifo.meas_arml - ifo.ref_arml
    laser.df = (m_true * dfm.sc.c) / (2 * np.pi * opd)
    dff.sims["main"] = main_channel

    # Presumably registers the "witness" channel on `dff` (it is referenced
    # by label in `simulate` below); the returned object is not needed here.
    dff.create_witness_channel("main", "witness", m_witness=0.1)

    # Presumably an integer number (`fit_n`) of modulation periods, so the
    # records can be treated as periodic when interpolating with wrap-around.
    n_seconds = main_channel.fit_n / laser.f_mod
    dff.simulate("main", n_seconds=n_seconds, witness_label="witness")

    # --- 2. Prepare Measured Data and Witness Basis ---
    main_raw = dff.raws["main"]
    main_buffer = np.asarray(main_raw.data).ravel()
    v_main_ac = main_buffer - np.mean(main_buffer)

    n_samples = main_buffer.size
    dt = 1.0 / main_raw.f_samp
    time_axis = np.arange(n_samples) / main_raw.f_samp
    # Length of one full record: the sample after the last one wraps back to
    # t = 0. Using time_axis[-1] instead would fold the last sample onto the
    # first and drop one sample interval from the period.
    record_period = n_samples * dt

    witness_raw = dff.raws["witness"]
    witness_buffer = np.asarray(witness_raw.data).ravel()
    if witness_buffer.size != n_samples:
        raise ValueError(
            f"witness record has {witness_buffer.size} samples but the main "
            f"record has {n_samples}; they must share one time axis"
        )
    v_w_ac = witness_buffer - np.mean(witness_buffer)
    w_peak = np.max(np.abs(v_w_ac))
    if w_peak == 0:
        raise ValueError("witness record is constant; cannot build a modulation basis")
    # Normalised modulation waveform (sign convention kept from the analysis).
    f_mod_basis = -v_w_ac / w_peak
    # Integrate the normalised frequency waveform into phase, scaled by 2*pi*df.
    phi_mod_basis = np.cumsum(f_mod_basis) * dt
    phi_mod_unscaled = 2 * np.pi * laser.df * phi_mod_basis

    # --- 3. Define the VarPro Cost Function for Tau ---
    # psi is fixed to its true value (0.0 in this simulation) for this test.
    psi_fixed = 0.0
    omega_mod = 2 * np.pi * laser.f_mod
    t_shift_psi = -psi_fixed / omega_mod
    phi_mod_shifted = np.interp(
        time_axis - t_shift_psi, time_axis, phi_mod_unscaled, period=record_period
    )
    # v_main_ac is mean-removed, but I*cos(dphi) + Q*sin(dphi) generally has
    # a non-zero mean, so the linear model needs a free constant offset.
    offset_column = np.ones(n_samples)

    def cost_function_tau_varpro(tau):
        """Return S(tau): the residual after solving for I, Q and the offset."""
        phi_mod_delayed = np.interp(
            time_axis - tau, time_axis, phi_mod_shifted, period=record_period
        )
        delta_phi_mod = phi_mod_delayed - phi_mod_shifted

        design = np.column_stack(
            [np.cos(delta_phi_mod), np.sin(delta_phi_mod), offset_column]
        )
        coeffs, *_ = np.linalg.lstsq(design, v_main_ac, rcond=None)
        # Compute the residual explicitly. lstsq returns an empty `residuals`
        # array for rank-deficient designs (e.g. tau == 0), and the
        # least-squares solution is still well defined in that case.
        residual = v_main_ac - design @ coeffs
        return float(residual @ residual)

    # --- 4. Scan the cost function and find the minimum ---
    tau_true = opd / dfm.sc.c
    tau_range = np.linspace(tau_true * 0.9, tau_true * 1.1, 200)

    print("Scanning VarPro cost function landscape for tau...")
    cost_values = np.array([cost_function_tau_varpro(t) for t in tqdm(tau_range)])

    # Refine between the neighbours of the best grid point. Brent with a
    # two-point `bracket` only seeds a downhill search; it can leave the
    # scanned interval or settle in a different local minimum.
    i_best = int(np.argmin(cost_values))
    neighbours = (
        tau_range[max(i_best - 1, 0)],
        tau_range[min(i_best + 1, tau_range.size - 1)],
    )
    lo, hi = min(neighbours), max(neighbours)
    grid_step = abs(tau_range[1] - tau_range[0])
    res_tau = scipy.optimize.minimize_scalar(
        cost_function_tau_varpro,
        bounds=(lo, hi),
        method="bounded",
        # xatol is absolute (seconds). Scale it to the grid spacing, because
        # the default of 1e-5 can exceed the whole scanned interval.
        options={"xatol": grid_step * 1e-6},
    )
    tau_fit = res_tau.x
    if i_best in (0, tau_range.size - 1):
        print("Warning: the cost minimum lies at the edge of the scanned tau range.")
    print(f"True tau = {tau_true*1e9:.6f} ns, fitted tau = {tau_fit*1e9:.6f} ns")

    # --- 5. Plot the results ---
    fig, ax = plt.subplots(figsize=(12, 6))

    ax.plot(tau_range * 1e9, cost_values, '.-', label='VarPro Cost Function S(τ)')
    ax.axvline(tau_true * 1e9, color='k', linestyle='--', label=f'True τ ({tau_true*1e9:.3f} ns)')
    ax.axvline(tau_fit * 1e9, color='r', linestyle=':', linewidth=2.5, label=f'Found τ ({tau_fit*1e9:.3f} ns)')

    ax.set_title('Robustness Test of τ Estimation (VarPro Method)', fontsize=16)
    ax.set_xlabel('Trial Time Delay, τ (ns)', fontsize=14)
    ax.set_ylabel('Cost Function (Linear Fit Residual)', fontsize=14)
    ax.legend()
    ax.grid(True)
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # Runs only when this file is executed as a script, not on import:
    # launch the tau-scan demo with its default parameters.
    test_tau_estimation_varpro()