In [5]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
from dmphasev2 import DMPhaseEstimator

__all__ = [
    "make_synthetic_waterfall",
    "dedisperse_fd",
    "plot_power_spectrum",
    "plot_power_map",
    "plot_dm_diagnostics",
    "test_estimator",
]

K_DM_S = 4.148808e3  # MHz^2 pc^-1 cm^3 ms ---------------------------------------------------


def make_synthetic_waterfall(
    n_time: int = 2048,
    n_chan: int = 256,
    dt: float = 1e-4,
    f_hi: float = 800.0,
    f_lo: float = 400.0,
    dm_true: float = 350.0,
    pulse_width: float = 0.002,
    snr: float = 10.0,
    seed: int | None = 0,
):
    """Create a toy waterfall with a dispersed Gaussian pulse."""
    rng = np.random.default_rng(seed)
    freqs = np.linspace(f_hi, f_lo, n_chan)  # descending order
    t = np.arange(n_time) * dt

    # Dispersion delay relative to top of band
    ref_inv = 1.0 / (f_hi**2)
    delay_sec = 1e-3 * K_DM_S * dm_true * (1.0 / freqs**2 - ref_inv)

    waterfall = np.empty((n_time, n_chan), dtype="complex128")
    t0 = 0.5 * n_time * dt
    sigma = pulse_width / (2.0 * np.sqrt(2.0 * np.log(2.0)))

    for ic, d in enumerate(delay_sec):
        envelope = np.exp(-0.5 * ((t - (t0 + d)) / sigma) ** 2)
        noise = rng.standard_normal(n_time) + 1j * rng.standard_normal(n_time)
        waterfall[:, ic] = snr * envelope + noise

    return waterfall, freqs


# -------------------------------------------------------------------------------------------
# Helper: single‑DM frequency‑domain dedispersion -------------------------------------------


def dedisperse_fd(
    waterfall: np.ndarray, freqs: np.ndarray, dt: float, dm: float, nu_ref: float
):
    """Dedisperse a dynamic spectrum at one DM by FFT‑phase rotation (returns real part)."""
    n_t = waterfall.shape[0]
    freq_axis = np.fft.fftfreq(n_t, d=dt)
    delay_sec = 1e-3 * K_DM_S * dm * (1.0 / freqs**2 - 1.0 / nu_ref**2)
    phase = np.exp(-2j * np.pi * freq_axis[:, None] * delay_sec[None, :])
    spec = np.fft.fft(waterfall, axis=0)
    spec *= phase
    return np.fft.ifft(spec, axis=0).real


# -------------------------------------------------------------------------------------------
# Diagnostics: 1‑D power spectrum ------------------------------------------------------------


def plot_power_spectrum(est: DMPhaseEstimator, ax: plt.Axes | None = None):
    """Plot ω²‑weighted coherent power P'(ω) at DM_best and save PNG."""
    spec_cube = est._phase_cube()
    best_idx = np.argmin(np.abs(est.dm_grid - est.dm_best))
    power = est._coherent_power(spec_cube)[best_idx]
    win_mask = est._select_freq_window(power[None, :])

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 4))
    else:
        fig = ax.figure
    idx = np.fft.fftshift(np.arange(power.size))
    ax.plot(est.freq_axis[idx], power[idx], lw=0.8)
    ax.fill_between(est.freq_axis[idx], 0, power[idx], where=win_mask, alpha=0.2)
    ax.set_xlabel("Fluctuation frequency (Hz)")
    ax.set_ylabel(r"$P'_{\mathrm{Co}}(\omega)$ (arb.)")
    ax.set_title("Coherent power spectrum at DM$_{\mathrm{best}}$")
    #fig.tight_layout()
    plt.show()
    fig.savefig("power_spectrum_v2.png", dpi=150)
    if ax is None:
        plt.close(fig)


# -------------------------------------------------------------------------------------------
# NEW: 2‑D power map (DM × fluctuation frequency) -------------------------------------------


def plot_power_map(est: DMPhaseEstimator):
    """Render and save the 2‑D coherent‑power map P'(ω, DM)."""
    # Full power cube
    spec_cube = est._phase_cube()                  # (n_dm, n_t, n_chan)
    power_cube = est._coherent_power(spec_cube)       # (n_dm, n_t)

    freq_axis = est.freq_axis
    pos_mask = freq_axis >= 0                         # positive fluct. freqs
    power_sel = power_cube[:, pos_mask]
    freq_sel = freq_axis[pos_mask]

    y0, y1 = est.dm_grid.min(), est.dm_grid.max()

    fig, ax = plt.subplots(figsize=(7, 4))
    im = ax.imshow(
        power_sel,
        aspect="auto",
        origin="lower",
        extent=[freq_sel[0], freq_sel[-1], y0, y1],
        cmap="viridis",
    )

    ax.axhline(est.dm_best, color="w", lw=1, ls="--", label=f"DM_best = {est.dm_best:.3f}")
    ax.set_xlabel("Fluctuation frequency (Hz)")
    ax.set_ylabel("Trial DM (pc cm$^{-3}$)")
    ax.set_title("Coherent power map P'(ω, DM)")
    ax.set_ylim(y0, y1)             # ensure no padding below lowest DM
    ax.legend(loc="upper right", frameon=False)

    fig.colorbar(im, ax=ax, label=r"$P'_{m Co}$ (arb.)")
    #fig.tight_layout()
    plt.show()
    fig.savefig("power_map_v2.png", dpi=150)
    plt.close(fig)


# -------------------------------------------------------------------------------------------
# Composite diagnostics ---------------------------------------------------------------------


def plot_dm_diagnostics(est: DMPhaseEstimator, waterfall: np.ndarray, freqs: np.ndarray, dt: float):
    """Generate DM curve, power diagnostics, and dedispersed waterfalls."""
    res = est.result()
    dm_best, dm_sig, nu_ref = res["dm_best"], res["dm_sigma"], est.nu_ref

    # 2. 1‑D power spectrum
    plot_power_spectrum(est)

    # 3. 2‑D power map
    plot_power_map(est)

    # 4. Dedispersed waterfalls
    dms = [dm_best - 5 * dm_sig, dm_best, dm_best + 5 * dm_sig]
    titles = [f"DM = {d:.2f}" for d in dms]
    fig2, axs = plt.subplots(3, 1, figsize=(6, 7), sharex=True)
    t = np.arange(waterfall.shape[0]) * dt
    for ax, d, ttl in zip(axs, dms, titles):
        w_dedisp = dedisperse_fd(waterfall, freqs, dt, d, nu_ref)
        im = ax.imshow(
            np.abs(w_dedisp.T),
            aspect="auto",
            origin="upper",  # high frequencies at top
            extent=[t[0], t[-1], freqs[-1], freqs[0]],
        )
        ax.set_ylabel("Freq (MHz)")
        ax.set_title(ttl)
    axs[-1].set_xlabel("Time (s)")
    fig2.colorbar(im, ax=axs, shrink=0.8, label="|Voltage|")
    #fig2.tight_layout()
    plt.show()
    fig2.savefig("dedispersed_waterfalls_v2.png", dpi=150)
    plt.close(fig2)


# -------------------------------------------------------------------------------------------
# End‑to‑end test harness -------------------------------------------------------------------


def test_estimator(verbose: bool = True, make_plots: bool = True):
    """Run estimator on synthetic burst and (optionally) save diagnostic plots."""
    true_dm = 350.0
    waterfall, freqs = make_synthetic_waterfall(dm_true=true_dm, n_time=1024, n_chan=32, seed=42)
    dt = 1e-4
    dm_grid = np.linspace(348, 352, 101) #201)

    est = DMPhaseEstimator(
        waterfall,
        freqs,
        dt,
        dm_grid,
        ref="top",
        n_boot=21, #200,
        random_state=42,
    )
    res = est.result()

    if verbose:
        print(
            f"Recovered DM = {res['dm_best']:.3f} ± {res['dm_sigma']:.3f} pc cm^-3 (true {true_dm})"
        )
    #assert abs(res["dm_best"] - true_dm) < 0.5, "DM recovery outside tolerance"

    if make_plots:
        plot_dm_diagnostics(est, waterfall, freqs, dt)
        if verbose:
            print(
                "Diagnostic plots saved: dm_curve_v2.png, power_spectrum_v2.png, "
                "power_map_v2.png, dedispersed_waterfalls_v2.png"
            )

    return res


if __name__ == "__main__":
    test_estimator()
    print("Synthetic DM recovery test passed.")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Recovered DM = 348.877 ± 1.754 pc cm^-3 (true 350.0)


AttributeError: 'DMPhaseEstimator' object has no attribute '_select_freq_window'