In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from scipy import constants as sc
import pickle

from DeepFMKit.experiments import Experiment
from DeepFMKit.factories import StandardWDFMIExperimentFactory
from DeepFMKit.waveforms import second_harmonic_distortion  # Passed to the factory as `waveform_function` below

# --- Configuration ---
fitter_to_test = 'wdfmi_ortho'  # Fitter method used by the 'wdfmi_fit' analysis below
SEED = 42  # Fixes the stochastic draws below so Restart & Run All reproduces the same trials

# Seed NumPy's global RNG, which the stochastic-variable lambdas below draw from.
# NOTE(review): if exp.run() evaluates those lambdas in worker processes, per-worker
# seeding may also be required for full reproducibility -- TODO confirm against DeepFMKit.
np.random.seed(SEED)

# --- Experiment definition ---
exp = Experiment(description="W-DFMI Residual Bias")
exp.set_config_factory(
    StandardWDFMIExperimentFactory(waveform_function=second_harmonic_distortion, opd_main=0.1)
)
exp.n_fit_buffers_per_trial = 1
exp.f_samp = 200e3
exp.n_trials = 100

# Parameter grid swept by the experiment (15 x 15 points).
exp.add_axis('m_witness', np.linspace(0.1, 1.0, 15))  # Outer loop
exp.add_axis('m_main', np.linspace(3, 25, 15))  # Inner loop

# Random draws: the distortion phase and phi are uniform on [0, 2*pi);
# the distortion amplitude is held fixed at 0.2.
exp.add_stochastic_variable(
    'waveform_kwargs',
    lambda: {'distortion_amp': 0.2, 'distortion_phase': np.random.uniform(0, 2 * np.pi)},
)
exp.add_stochastic_variable('phi', lambda: np.random.uniform(0, 2 * np.pi))

# Fit each trial with the chosen fitter and keep the estimated 'm'.
exp.add_analysis(name='wdfmi_fit', fitter_method=fitter_to_test, result_cols=['m'])

# --- Run the experiment ---
results = exp.run()

In [None]:
# --- Parameters ---
fitter_name = 'wdfmi_fit'  # Must match the `name` given to exp.add_analysis in the cell above
threshold = 0.1  # Upper colour limit (vmax) for the fractional-bias maps; larger values saturate

# --- 1. Extract Axes and Fitted Data ---
# np.asarray so the np.newaxis indexing below also works if the axes come back as lists.
m_main_axis = np.asarray(results['axes']['m_main'])
m_witness_axis = np.asarray(results['axes']['m_witness'])
m_fit_all_trials = np.asarray(results[fitter_name]['m']['all_trials'])  # Shape: (N_witness, N_main, N_trials)

# The broadcasting and pcolormesh calls below rely on this axis order; fail loudly
# instead of silently mis-broadcasting or transposing the maps.
expected_grid = (m_witness_axis.size, m_main_axis.size)
assert m_fit_all_trials.ndim == 3 and m_fit_all_trials.shape[:2] == expected_grid, (
    f"Expected fit array of shape (N_witness, N_main, N_trials) with leading dims "
    f"{expected_grid}, got {m_fit_all_trials.shape}"
)

# Non-finite fits (NaN/inf) propagate through the mean/max below into the maps.
n_nonfinite = np.count_nonzero(~np.isfinite(m_fit_all_trials))
if n_nonfinite:
    print(
        f"Warning: {n_nonfinite} of {m_fit_all_trials.size} fitted m values are non-finite; "
        "the affected grid cells will be non-finite in both maps."
    )

# --- 2. Compute Bias Metrics ---
# Broadcast the true m_main across the m_witness and trial axes.
m_true_broadcast = m_main_axis[np.newaxis, :, np.newaxis]  # Shape: (1, N_main, 1)
absolute_bias_all = m_fit_all_trials - m_true_broadcast  # Shape: (N_witness, N_main, N_trials)

# Fractional bias (m_est - m_true) / m_true; entries where m_true == 0 are left at 0.
fractional_bias_all = np.divide(
    absolute_bias_all,
    m_true_broadcast,
    out=np.zeros_like(absolute_bias_all, dtype=float),
    where=m_true_broadcast != 0,
)

# --- 3. Aggregate Metrics over trials ---
# Mean panel: |mean over trials of the signed fractional bias| (signed errors can cancel).
mean_fractional_bias = np.abs(np.mean(fractional_bias_all, axis=-1))  # Shape: (N_witness, N_main)
# Worst-case panel: largest |fractional bias| seen in any single trial.
worst_case_fractional_bias = np.max(np.abs(fractional_bias_all), axis=-1)  # Shape: (N_witness, N_main)


# --- 4. Plotting ---
def _plot_bias_map(ax, x, y, values, title, cmap, vmax):
    """Draw one fractional-bias heat map on `ax` and return the resulting QuadMesh.

    `x` runs along the horizontal axis and `y` along the vertical one; `values`
    must have shape (len(y), len(x)). Colour limits are fixed to [0, vmax].
    """
    mesh = ax.pcolormesh(x, y, values, cmap=cmap, shading='nearest', vmin=0, vmax=vmax)
    ax.set_title(title)
    ax.set_xlabel(r'$m_{\rm main}$ (rad)')
    ax.grid(False)
    return mesh


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6.875, 2.5), dpi=150, sharey=True)

# Custom colormap: white (zero bias) to dark blue (at or above `threshold`)
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    'custom_blue', [(0, 'white'), (1, 'darkblue')]
)

# Panel 1: Mean Fractional Bias
pcm1 = _plot_bias_map(
    ax1, m_main_axis, m_witness_axis, mean_fractional_bias, 'Mean', custom_cmap, threshold
)
ax1.set_ylabel(r'$m_{\rm witness}$ (rad)')  # y-axis is shared, so only the left panel is labelled

# Panel 2: Worst-Case Fractional Bias
pcm2 = _plot_bias_map(
    ax2, m_main_axis, m_witness_axis, worst_case_fractional_bias, 'Worst-case', custom_cmap, threshold
)

# One colorbar serves both panels (same cmap and [0, threshold] limits).
# extend='max' shows that values above `threshold` are clipped to the top colour.
fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.87, 0.15, 0.03, 0.7])
cbar = fig.colorbar(pcm2, cax=cbar_ax, extend='max')
cbar.set_label(r'Fractional bias $|\delta m / m|$')
cbar_ax.grid(False)

plt.show()