# In [None]:
import deepfmkit.core as dfm
import numpy as np
from scipy import signal
from deepfmkit.plotting import default_rc
import matplotlib.pyplot as plt

# Use the package's default publication-quality plot settings.
# This runs at import time and mutates matplotlib's global rcParams, so it
# affects every figure created afterwards in this process / notebook session.
# NOTE(review): default_rc comes from deepfmkit.plotting; presumably a mapping
# of rc keys to values accepted by rcParams.update — confirm.
plt.rcParams.update(default_rc)

def create_noise_validation_figure(
    asd_at_1hz=1000.0,
    num_points=int(2e6),
    fs=200e3,
    output_filename="fig_noise.pdf",
):
    """
    Generate a figure validating the colored-noise generator for alpha = 0, 1 and 2.

    The function handles three noise types: white, flicker and random walk.
    For each one it generates a laser-frequency-noise time series with
    DeepFMKit's signal generator. It then estimates the Amplitude Spectral
    Density (ASD) of that series with Welch's method. Finally it plots the
    ASD together with the theoretical ``asd_at_1hz / f**(alpha / 2)`` line
    on a single log-log plot. The figure is saved to ``output_filename`` and
    then shown.

    Parameters
    ----------
    asd_at_1hz : float
        Target noise ASD at 1 Hz, in Hz/sqrt(Hz). Passed to
        ``dfm.LaserConfig`` as ``f_n``.
    num_points : int
        Number of samples generated for each noise type.
    fs : float
        Sampling frequency in Hz.
    output_filename : str
        Path the figure is written to.
    """
    print("=" * 60)
    print("Generating Noise Validation Figure for Paper")
    print(f"Target ASD @ 1 Hz: {asd_at_1hz:.1f} Hz/sqrt(Hz)")
    print(f"Sampling Frequency: {fs / 1e3:.1f} kHz")
    print(f"Number of Points: {num_points}")
    print("=" * 60)

    # --- 1. Setup Figure ---
    fig, ax = plt.subplots(figsize=(3.875, 3))
    alphas_to_plot = [0.0, 1.0, 2.0]

    # Plot styles for each alpha
    styles = {
        0.0: {'color': 'gray', 'label': r'White noise ($\alpha=0$)'},
        1.0: {'color': 'C1', 'label': r'Flicker noise ($\alpha=1$)'},
        2.0: {'color': 'C2', 'label': r'Random walk ($\alpha=2$)'},
    }

    # Welch segment length: at most 10 s of data and at most a tenth of the
    # record, so several segments get averaged. Independent of alpha, so it
    # is computed once.
    nperseg = min(int(fs * 10), num_points // 10)

    # --- 2. Loop through alpha values, generate, and plot ---
    for alpha in alphas_to_plot:
        print(f"Processing alpha = {alpha}...")
        style = styles[alpha]

        # --- Configure the simulation for this noise type ---
        laser_config = dfm.LaserConfig(f_n=asd_at_1hz, f_n_alpha=alpha)
        sim_config = dfm.SimConfig(
            "noise_test", laser_config, dfm.IfoConfig(), f_samp=fs
        )

        # --- Generate the noise time series ---
        generator = dfm.SignalGenerator()
        noise_dict = generator._generate_noise_arrays(sim_config, num_points)
        frequency_noise_t = noise_dict["laser_frequency"]

        # --- Calculate the Amplitude Spectral Density (ASD) ---
        freqs, psd = signal.welch(
            frequency_noise_t, fs=fs, window="hann", nperseg=nperseg, scaling="density"
        )
        # Drop the DC bin (f = 0). It cannot be drawn on a log axis, and the
        # theoretical 1/f^(alpha/2) line diverges there. All remaining f > 0.
        freqs, psd = freqs[1:], psd[1:]
        asd = np.sqrt(psd)

        # --- Plot the generated noise ASD ---
        ax.loglog(freqs, asd, color=style['color'], label=style['label'])

        # --- Plot the theoretical ASD slope ---
        # ASD = sqrt(PSD), so PSD ~ 1/f^alpha gives ASD ~ 1/f^(alpha/2).
        # The line is anchored at asd_at_1hz at 1 Hz; alpha = 0 yields a flat
        # line at asd_at_1hz.
        theoretical_asd = asd_at_1hz / freqs ** (alpha / 2.0)
        ax.loglog(freqs, theoretical_asd, '--', color='k')

    # --- 3. Finalize and Save Plot ---
    ax.legend(loc='lower left', framealpha=1, handlelength=2.9, fontsize=7, edgecolor='k', fancybox=False)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel(r"ASD (Hz / $\sqrt{\rm Hz}$)")
    # A log axis cannot start at 0: matplotlib ignores a non-positive limit
    # with a warning. Start at the first non-DC Welch bin, fs / nperseg.
    ax.set_xlim([fs / nperseg, fs / 2])
    # The lower limit is where the random-walk (alpha = 2) line sits at
    # Nyquist, asd_at_1hz / (fs / 2). The upper limit is 5x the 1 Hz level.
    # With the default arguments this is exactly [1e-2, 5e3].
    ax.set_ylim([asd_at_1hz / (fs / 2), 5 * asd_at_1hz])
    ax.grid(True, which="both", ls="--", alpha=0.6)

    fig.tight_layout()
    fig.savefig(output_filename)
    print(f"\nFigure saved as {output_filename}")
    plt.show()


if __name__ == "__main__":
    # When run as a script, build and save the figure using the default
    # arguments of create_noise_validation_figure.
    create_noise_validation_figure()