In [None]:
import time
from contextlib import contextmanager

import matplotlib.pyplot as plt
import numpy as np

from speckit import compute_spectrum
import speckit.core as _core
import speckit.analysis as _analysis


# --- helpers to toggle numba path globally in speckit ------------------------
def _set_numba_enabled(flag: bool):
    """Select speckit's Numba path (True) or its NumPy fallback path (False).

    The same boolean is written to ``_NUMBA_ENABLED`` in both ``speckit.core``
    and ``speckit.analysis`` so the two modules always agree.
    """
    enabled = bool(flag)
    for module in (_core, _analysis):
        module._NUMBA_ENABLED = enabled


# --- config ------------------------------------------------------------------
N_values = np.array([1e3, 1e4, 1e5, 1e6], dtype=int)
n_runs = 5  # Number of timed runs per (path, order, N); the mean is reported
FS = 2  # Sampling frequency passed to every compute_spectrum call
orders = [-1, 0, 1, 2]

# Inputs: one uniform-random series per length. A local Generator makes them
# reproducible without reseeding NumPy's global RNG.
rng = np.random.default_rng(123)
inputs = [rng.random(N) for N in N_values]


# --- helpers -----------------------------------------------------------------
@contextmanager
def _numba_mode(flag: bool):
    """Temporarily force speckit's Numba flag to ``flag``.

    The previous value is restored on exit, even if the body raises, so this
    cell never leaves speckit stuck on the fallback path for later cells.
    """
    # NOTE(review): assumes speckit.core defines _NUMBA_ENABLED; if it does not,
    # Numba is treated as previously enabled — confirm against speckit.
    previous = bool(getattr(_core, "_NUMBA_ENABLED", True))
    _set_numba_enabled(flag)
    try:
        yield
    finally:
        _set_numba_enabled(previous)


def _time_call(x: np.ndarray, order: int, runs: int = n_runs) -> float:
    """Return the mean wall-clock time (s) of ``runs`` calls to compute_spectrum.

    time.perf_counter is used instead of time.time: it is monotonic and has the
    highest available resolution, so millisecond-scale runs are not rounded away.
    """
    run_times = []
    for _ in range(runs):
        t0 = time.perf_counter()
        compute_spectrum(x, fs=FS, order=order)
        run_times.append(time.perf_counter() - t0)
    return float(np.mean(run_times))


def _benchmark(numba_enabled: bool) -> dict:
    """Time compute_spectrum on every input, for every order, on one code path.

    Returns ``{order: [mean_time for each N in N_values]}``.
    """
    label = "Numba" if numba_enabled else "NumPy"
    timings = {}
    with _numba_mode(numba_enabled):
        for order in orders:
            print(f"[{label}] order={order}")
            # Untimed warm-up for *each* order, not just order=0: Numba compiles
            # lazily on first call, and different orders may hit different
            # jitted kernels, which would put compile time in the first timed
            # run. Done on both paths so the comparison stays symmetric.
            compute_spectrum(inputs[0], fs=FS, order=order)
            timings[order] = [_time_call(x, order) for x in inputs]
    return timings


def _rel_err(a, b, eps=1e-300):
    """Max-abs difference of ``a`` and ``b``, relative to the max-abs of ``b``."""
    num = np.max(np.abs(a - b))
    den = np.max(np.abs(b)) + eps  # eps avoids 0/0 for an all-zero reference
    return num / den


# --- benchmark: Numba path, then NumPy fallback path -------------------------
times_numba = _benchmark(True)
times_numpy = _benchmark(False)

# --- Calculate speedups ------------------------------------------------------
# A zero Numba time would divide by zero; report NaN rather than inf.
speedups = {
    order: [
        t_np / t_nb if t_nb > 0 else float("nan")
        for t_np, t_nb in zip(times_numpy[order], times_numba[order])
    ]
    for order in orders
}

print("\nSpeedup factors (NumPy / Numba):")
for order in orders:
    print(f"  order={order}: {[f'{s:.1f}x' for s in speedups[order]]}")

# --- quick numerical consistency check (choose an index into N_values) -------
check_idx = 1  # 0..len(N_values)-1; here N=1e4
x_check = inputs[check_idx]

print("\nNumerical consistency check (Numba vs NumPy):")
max_rel_err = {}
for order in orders:
    with _numba_mode(True):
        res_nb = compute_spectrum(x_check, fs=FS, order=order)
    with _numba_mode(False):
        res_np = compute_spectrum(x_check, fs=FS, order=order)

    # Compare Gxx; also ASD as a sanity check when the result provides one.
    rel_gxx = _rel_err(res_nb.Gxx, res_np.Gxx)
    if res_nb.asd is not None:
        rel_asd = _rel_err(res_nb.asd, res_np.asd)
        # np.maximum propagates NaN, so a NaN mismatch is reported, not hidden
        max_rel_err[order] = float(np.maximum(rel_gxx, rel_asd))
    else:
        rel_asd = np.nan
        max_rel_err[order] = float(rel_gxx)

    print(f"  order={order}: max rel err Gxx={rel_gxx:.3e}, ASD={rel_asd:.3e}")

# --- Timing data summary (printed for copying into the paper) ----------------
print("\nTiming data summary:")
print("N_values:", N_values)
for order in orders:
    print(f"\nOrder {order}:")
    print(f"  Numba: {[f'{t:.4f}' for t in times_numba[order]]}")
    print(f"  NumPy:  {[f'{t:.4f}' for t in times_numpy[order]]}")
    print(f"  Speedup: {[f'{s:.1f}x' for s in speedups[order]]}")

In [None]:
fig, ax = plt.subplots(figsize=(6, 4.5), dpi=300)

# Per-order style derived from `orders` itself, so editing the order list in the
# benchmark cell cannot raise a KeyError here. For orders=[-1, 0, 1, 2] this
# reproduces the fixed mapping o/C0, s/C1, ^/C2, D/C3.
_marker_cycle = ["o", "s", "^", "D", "v", "P", "X", "*"]
markers = {order: _marker_cycle[i % len(_marker_cycle)] for i, order in enumerate(orders)}
colors = {order: f"C{i}" for i, order in enumerate(orders)}

for order in orders:
    # Numba path (solid lines)
    ax.loglog(
        N_values,
        times_numba[order],
        marker=markers[order],
        ls="-",
        color=colors[order],
        label=f"Numba (order={order})",
        markersize=6,
        linewidth=1.5,
    )
    # NumPy path (dashed lines)
    ax.loglog(
        N_values,
        times_numpy[order],
        marker=markers[order],
        ls="--",
        color=colors[order],
        label=f"NumPy (order={order})",
        markersize=6,
        linewidth=1.5,
        alpha=0.7,
    )

ax.set_xlabel("Length of time series $N$", fontsize=11)
ax.set_ylabel("Computation time (s)", fontsize=11)
ax.legend(
    loc="upper left",
    fontsize=9,
    framealpha=0.9,
    ncol=2,
    columnspacing=0.8,
)
ax.grid(True, which="both", color="lightgray", linestyle="-", linewidth=0.5, alpha=0.5)
# 20% padding around the benchmarked lengths (800 .. 1.2e6 for the default
# N_values) instead of hard-coded limits that would clip a changed config.
ax.set_xlim(0.8 * N_values.min(), 1.2 * N_values.max())
fig.tight_layout()

plt.show()