In [None]:
import numpy as np
import matplotlib.pyplot as plt

from speckit import compute_spectrum
from speckit.noise import white_noise, pink_noise, red_noise
from speckit.plotting import default_rc

# Apply speckit's default plotting style by updating matplotlib's global rcParams.
# This runs once at import time and affects all later plotting in this process.
plt.rcParams.update(default_rc)


def create_noise_validation_figure(
    asd_at_1hz: float = 1000.0,
    num_points: int = int(1e6),
    fs: float = 100e3,
    output_filename: str = "fig_noise_validation.pdf",
    f_min: float = 1e-2,
):
    """
    Generate a figure validating the colored noise generators against the speckit ASD estimator.

    For white (alpha=0), pink (alpha=1) and red (alpha=2) noise, this function generates a
    time series with the matching ``speckit.noise`` generator and rescales it. It then
    estimates the Amplitude Spectral Density (ASD) with ``speckit.compute_spectrum`` and
    plots it. Dashed black lines show the theoretical power law
    ``asd_at_1hz / f**(alpha / 2)``.

    Parameters
    ----------
    asd_at_1hz : float
        Target ASD at 1 Hz (units/sqrt(Hz)) for every noise color. It is also the value at
        1 Hz of the theoretical slope lines.
    num_points : int
        Number of samples generated for each noise color.
    fs : float
        Sampling frequency in Hz.
    output_filename : str
        Path the figure is saved to.
    f_min : float
        Low-frequency corner passed to the pink and red noise generators. It is also the
        lower x-axis limit, and only spectral bins above it are used for the
        theoretical lines.

    Returns
    -------
    tuple
        ``(fig, ax)``: the matplotlib figure and axes, so callers can customize or close them.
    """
    print("=" * 60)
    print("Generating Noise Validation Figure")
    print(f"Target ASD @ 1 Hz: {asd_at_1hz:.1f} units/sqrt(Hz)")
    print(f"Sampling Frequency: {fs / 1e3:.1f} kHz")
    print(f"Number of Points: {num_points} ({num_points / fs:.1f} seconds)")
    print("=" * 60)

    # --- 1. Setup figure and per-color configuration ---
    fig, ax = plt.subplots()

    # alpha -> (plot style, generator factory taking a seed).
    # The insertion order fixes both the plotting order and the seeds (0, 1, 2).
    noise_configs = {
        0.0: (
            {"color": "gray", "label": r"White Noise ($\alpha=0$)"},
            lambda seed: white_noise(f_sample=fs, seed=seed),
        ),
        1.0: (
            {"color": "crimson", "label": r"Pink/Flicker Noise ($\alpha=1$)"},
            # The upper corner is kept slightly below Nyquist (fs / 2).
            lambda seed: pink_noise(
                f_sample=fs, f_min=f_min, f_max=fs / 2.1, seed=seed
            ),
        ),
        2.0: (
            {"color": "dodgerblue", "label": r"Red/Brownian Noise ($\alpha=2$)"},
            lambda seed: red_noise(f_sample=fs, f_min=f_min, seed=seed),
        ),
    }

    # NOTE(review): this assumes the speckit.noise generators output a two-sided PSD of
    # 1.0 at 1 Hz, i.e. a one-sided ASD of sqrt(2). Under that assumption, this factor
    # gives an ASD of asd_at_1hz at 1 Hz. Confirm against the speckit.noise docs.
    scaling_factor = asd_at_1hz / np.sqrt(2.0)

    # Frequency axis for the theoretical lines, taken from the first computed spectrum.
    theoretical_freqs = None

    # --- 2. Generate each noise color, estimate its ASD, and plot it ---
    for seed, (alpha, (style, make_generator)) in enumerate(noise_configs.items()):
        print(f"Processing alpha = {alpha}...")

        generator = make_generator(seed)
        scaled_noise_t = generator.get_series(num_points) * scaling_factor

        # ASD estimate using compute_spectrum's default settings (only fs is passed).
        result = compute_spectrum(scaled_noise_t, fs=fs)
        ax.loglog(result.f, result.asd, **style)

        if theoretical_freqs is None:
            theoretical_freqs = result.f[result.f > f_min]

    # --- 3. Theoretical slopes: ASD(f) = asd_at_1hz / f**(alpha / 2) ---
    # For alpha = 0 this reduces to a flat line at asd_at_1hz, because f**0 == 1.
    for idx, alpha in enumerate(noise_configs):
        theoretical_asd = asd_at_1hz / theoretical_freqs ** (alpha / 2.0)
        ax.loglog(
            theoretical_freqs,
            theoretical_asd,
            "--",
            color="k",
            label="Theoretical Slopes" if idx == 0 else None,  # one legend entry only
            lw=1.5,
        )

    # --- 4. Finalize and save the plot ---
    ax.legend(loc="lower left", fontsize=8)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel(r"ASD (units / $\sqrt{\rm Hz}$)")
    ax.set_xlim([f_min, fs / 2])
    # The lower y-limit equals the theoretical red-noise (alpha=2) ASD at Nyquist.
    ax.set_ylim([asd_at_1hz / (fs / 2), asd_at_1hz * 10])

    fig.tight_layout(pad=0.5)
    fig.savefig(output_filename)
    print(f"\nFigure saved as {output_filename}")
    plt.show()
    return fig, ax


def _main() -> None:
    """Script entry point: build and save the noise validation figure using default settings."""
    create_noise_validation_figure()


if __name__ == "__main__":
    _main()