# Moment-Based One-Step DM Estimator

This notebook implements and stress-tests the moment-based approach to dispersion measure estimation.

## Key Idea
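Instead of trial dedispersion, we compute weighted moments of the dynamic spectrum and solve a 2×2 linear system for (t₀, DM) in closed form. A few reweighting passes then refine the estimate by down-weighting pixels far from the current dispersion curve; each pass is another closed-form solve.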

## Outline
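1. Physical constants and the dispersion law
2. Synthetic dynamic-spectrum generation
3. Moment estimator: derivation, implementation, and validation against a known DM
4. Stress tests: bias and scatter vs. S/N and DM, scattering, RFI, very low S/N
5. Comparison with trial dedispersion (bowtie)
6. GPU implementation and benchmarks
7. Summary and conclusions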


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage, stats
from typing import Tuple, Optional, NamedTuple
import time

# Check for GPU availability. CuPy can import successfully on a machine without
# a usable CUDA driver/device; querying the device count then raises a CUDA
# runtime error (not ImportError), so treat that case as "no GPU" as well.
try:
    import cupy as cp
    _n_gpu_devices = cp.cuda.runtime.getDeviceCount()
    HAS_GPU = _n_gpu_devices > 0
    print(f"GPU available: {_n_gpu_devices} device(s)")
except ImportError:
    HAS_GPU = False
    print("CuPy not available, using CPU only")
except Exception as exc:  # e.g. cupy.cuda.runtime.CUDARuntimeError
    HAS_GPU = False
    print(f"CuPy installed but no usable CUDA device ({exc}), using CPU only")

# Reproducibility and plotting defaults
np.random.seed(42)
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['image.cmap'] = 'viridis'


---
## 1. Physical Constants and Dispersion Law


In [None]:
# Dispersion constant: delay = K_DM * DM * (nu^-2 - nu_ref^-2)
# K_DM in seconds when DM is in pc/cm^3 and frequency in MHz
K_DM = 4.148808e3  # s * MHz^2 / (pc cm^-3)

def dispersion_delay(dm: float, freq_mhz: np.ndarray, freq_ref_mhz: float) -> np.ndarray:
    """Compute the cold-plasma dispersion delay relative to a reference frequency.

    Args:
        dm: Dispersion measure in pc/cm^3
        freq_mhz: Frequency (scalar or array) in MHz; integer inputs are accepted.
        freq_ref_mhz: Reference frequency in MHz

    Returns:
        Time delay in seconds (positive = arrives later than at freq_ref_mhz).
        Scalar input gives a NumPy scalar; array input gives an array of the same shape.
    """
    # Cast to float: NumPy raises ValueError for integer arrays/scalars raised to a
    # negative integer power (e.g. np.array([1100, 1500]) ** -2).
    freq = np.asarray(freq_mhz, dtype=float)
    freq_ref = float(freq_ref_mhz)
    return K_DM * dm * (freq**-2 - freq_ref**-2)


---
## 2. Synthetic Dynamic Spectrum Generator

We need a realistic generator that supports:
- Configurable DM, pulse width, S/N
- Scattering (exponential tail)
- Spectral index
- RFI injection
- Frequency-dependent scattering timescale ($\tau_{\rm scat} \propto \nu^{\beta}$); intra-channel dispersion smearing is not modeled


In [None]:
class PulseParams(NamedTuple):
    """Intrinsic and propagation parameters of one synthetic pulse."""
    dm: float                    # dispersion measure, pc/cm^3
    t0: float                    # arrival time at the reference frequency (s)
    width: float                 # Gaussian sigma of the intrinsic pulse (s)
    amplitude: float = 1.0       # peak amplitude, before the spectral envelope is applied
    spectral_index: float = 0.0  # alpha in S ∝ ν^α
    scattering_time: float = 0.0 # scattering timescale at the reference frequency (s); 0 disables
    scattering_index: float = -4.0  # beta in τ_scat ∝ ν^β (typically -4 to -4.4)


class ObsParams(NamedTuple):
    """Observing setup: uniform frequency and time grids (both endpoints included)."""
    freq_lo: float      # bottom of the band, MHz
    freq_hi: float      # top of the band, MHz
    n_chan: int         # number of frequency channels
    t_start: float      # time of the first sample, s
    t_end: float        # time of the last sample, s
    n_time: int         # number of time samples

    @property
    def freq_ref(self) -> float:
        """Reference frequency for dispersion delays: the top of the band."""
        return self.freq_hi

    @property
    def freqs(self) -> np.ndarray:
        """Channel frequencies in MHz, in descending order (index 0 is freq_hi)."""
        top, bottom = self.freq_hi, self.freq_lo
        return np.linspace(top, bottom, self.n_chan)

    @property
    def times(self) -> np.ndarray:
        """Sample times in seconds from t_start to t_end inclusive."""
        first, last = self.t_start, self.t_end
        return np.linspace(first, last, self.n_time)

    @property
    def dt(self) -> float:
        """Spacing between consecutive samples of `times` (s)."""
        span = self.t_end - self.t_start
        return span / (self.n_time - 1)

    @property
    def df(self) -> float:
        """Magnitude of the spacing between consecutive channels of `freqs` (MHz)."""
        span = self.freq_hi - self.freq_lo
        return span / (self.n_chan - 1)


In [None]:
def generate_dynamic_spectrum(
    pulse: PulseParams,
    obs: ObsParams,
    noise_level: float = 1.0,
    rfi_fraction: float = 0.0,
    rfi_strength: float = 10.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate a synthetic dynamic spectrum.

    Each channel holds a Gaussian pulse (sigma = pulse.width) centred on the
    dispersed arrival time. It is optionally convolved with a causal one-sided
    exponential scattering kernel, scaled by a power-law spectral envelope, and
    then has white Gaussian noise added. Optionally, broadband RFI is added to a
    random subset of channels. Uses the global NumPy RNG (np.random).

    Args:
        pulse: Pulse parameters
        obs: Observation parameters
        noise_level: Standard deviation of Gaussian noise
        rfi_fraction: Fraction of channels with RFI
        rfi_strength: RFI amplitude multiplier

    Returns:
        dynamic_spectrum: (n_chan, n_time) array
        freqs: frequency array in MHz (descending, as returned by obs.freqs)
        times: time array in seconds
    """
    freqs = obs.freqs
    times = obs.times

    # Compute arrival time at each frequency
    delays = dispersion_delay(pulse.dm, freqs, obs.freq_ref)
    arrival_times = pulse.t0 + delays  # shape: (n_chan,)

    # Compute scattering timescale at each frequency
    if pulse.scattering_time > 0:
        tau_scat = pulse.scattering_time * (freqs / obs.freq_ref) ** pulse.scattering_index
    else:
        tau_scat = np.zeros_like(freqs)

    # Compute spectral envelope
    spectral_envelope = (freqs / obs.freq_ref) ** pulse.spectral_index

    # Generate pulse profile at each channel
    # shape: (n_chan, n_time)
    t_grid, arr_grid = np.meshgrid(times, arrival_times)

    # Intrinsic Gaussian pulse
    signal = pulse.amplitude * np.exp(-0.5 * ((t_grid - arr_grid) / pulse.width) ** 2)

    # Apply scattering (convolve with one-sided exponential)
    if pulse.scattering_time > 0:
        for i, tau in enumerate(tau_scat):
            if tau <= 0:
                continue
            # Kernel spans ~5 e-folds, capped at half the time series
            kernel_len = min(int(5 * tau / obs.dt) + 1, obs.n_time // 2)
            if kernel_len > 1:
                t_kernel = np.arange(kernel_len) * obs.dt
                kernel = np.exp(-t_kernel / tau)
                kernel /= kernel.sum()
                # Causal convolution: 'full' output sample n sums signal[n - k] * kernel[k]
                # for k >= 0, so keeping the first n_time samples puts the tail *after*
                # the pulse. mode='same' would centre the kernel and move the scattered
                # pulse earlier by about half the kernel length.
                signal[i] = np.convolve(signal[i], kernel)[:obs.n_time]

    # Apply spectral envelope
    signal *= spectral_envelope[:, np.newaxis]

    # Add noise
    noise = noise_level * np.random.randn(obs.n_chan, obs.n_time)
    dynamic_spectrum = signal + noise

    # Add RFI
    if rfi_fraction > 0:
        n_rfi_chans = int(rfi_fraction * obs.n_chan)
        rfi_chans = np.random.choice(obs.n_chan, n_rfi_chans, replace=False)
        for ch in rfi_chans:
            # Broadband RFI: constant offset + time-variable component
            dynamic_spectrum[ch] += rfi_strength * (1 + 0.5 * np.random.randn(obs.n_time))

    return dynamic_spectrum, freqs, times


In [None]:
# Test the generator
obs = ObsParams(
    freq_lo=1100, freq_hi=1500, n_chan=256,
    t_start=0.0, t_end=0.5, n_time=512
)
pulse = PulseParams(dm=500.0, t0=0.1, width=0.002, amplitude=5.0)

ds, freqs, times = generate_dynamic_spectrum(pulse, obs, noise_level=1.0)

fig, ax = plt.subplots(figsize=(12, 6))
# `freqs` is descending (row 0 = top of band). With origin='upper', row 0 is drawn at
# the extent's top edge (freqs[0]), so the y labels match the data. origin='lower'
# with this extent would draw the 1500 MHz row at the 1100 MHz label.
im = ax.imshow(ds, aspect='auto', origin='upper',
               extent=[times[0]*1e3, times[-1]*1e3, freqs[-1], freqs[0]])
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Frequency (MHz)')
ax.set_title(f'Synthetic Dynamic Spectrum (DM = {pulse.dm} pc/cm³, S/N ≈ {pulse.amplitude:.1f})')
plt.colorbar(im, ax=ax, label='Intensity')
plt.tight_layout()
plt.show()

sweep = dispersion_delay(pulse.dm, obs.freq_lo, obs.freq_hi)
print(f"Expected delay across band: {sweep*1e3:.2f} ms")
sweep_end = pulse.t0 + sweep
if sweep_end > obs.t_end:
    print(f"Note: the lowest channel's arrival ({sweep_end*1e3:.1f} ms) is after the end of the "
          f"window ({obs.t_end*1e3:.1f} ms), so the low-frequency part of the sweep is not in the data.")


---
## 3. Moment-Based DM Estimator Implementation

### Mathematical Foundation

Define $f_i = k_{\rm DM}(\nu_i^{-2} - \nu_{\rm ref}^{-2})$ as the "dispersion coordinate". 

The objective function is:
$$J(\text{DM}, t_0) = \sum_{i,k} w_{ik} \left[t_k - t_0 - f_i \cdot \text{DM}\right]^2$$

Setting $\partial J/\partial t_0 = 0$ and $\partial J/\partial \text{DM} = 0$ yields:

$$\begin{pmatrix} W & W_f \\ W_f & W_{ff} \end{pmatrix} \begin{pmatrix} t_0 \\ \text{DM} \end{pmatrix} = \begin{pmatrix} W_t \\ W_{ft} \end{pmatrix}$$

where:
- $W = \sum w_{ik}$
- $W_f = \sum w_{ik} f_i$  
- $W_{ff} = \sum w_{ik} f_i^2$
- $W_t = \sum w_{ik} t_k$
- $W_{ft} = \sum w_{ik} f_i t_k$

Solving by Cramer's rule:
$$\text{DM} = \frac{W \cdot W_{ft} - W_f \cdot W_t}{W \cdot W_{ff} - W_f^2}$$
$$t_0 = \frac{W_t - W_f \cdot \text{DM}}{W}$$


In [None]:
class MomentEstimatorResult(NamedTuple):
    """Result from moment-based DM estimation."""
    dm: float                          # estimated DM (pc/cm^3); NaN if the initial solve failed
    t0: float                          # estimated arrival time at freq_ref (s); NaN if it failed
    dm_err: Optional[float] = None     # not populated by moment_dm_estimator
    t0_err: Optional[float] = None     # not populated by moment_dm_estimator
    weights_sum: float = 0.0           # total weight W used in the reported solve
    condition_number: float = 0.0      # 2-norm condition number of the reported 2x2 system


def moment_dm_estimator(
    dynamic_spectrum: np.ndarray,
    freqs: np.ndarray,
    times: np.ndarray,
    freq_ref: Optional[float] = None,
    sigma_threshold: float = 2.0,
    max_iterations: int = 5,
    pulse_width_hint: Optional[float] = None,
    robust: bool = False,
    weight_type: str = 'intensity',
    huber_k: float = 1.345,
    dm_tol: float = 0.01,
) -> MomentEstimatorResult:
    """Estimate DM using moment-based closed-form solution with iterative refinement.

    Every step solves the 2x2 weighted least-squares system for (t0, DM) in closed
    form. The refinement loop reweights pixels by their timing residual from the
    current dispersion curve and re-solves.

    The method is intended for bright pulses. At low per-pixel S/N, above-threshold
    noise pixels dilute the moments and the estimate degrades; bowtie/trial
    dedispersion is preferable there.

    Args:
        dynamic_spectrum: (n_chan, n_time) intensity array
        freqs: frequency array in MHz (integer arrays are accepted)
        times: time array in seconds, uniformly spaced
        freq_ref: reference frequency (default: max freq)
        sigma_threshold: S/N threshold for pixel selection
        max_iterations: maximum number of refinement iterations (0 = closed-form only)
        pulse_width_hint: approximate pulse width in seconds (scale of the
            refinement weighting); default max(3*dt, 5 ms)
        robust: if True, (a) subtract a per-channel median baseline instead of a
            global one, which removes channel-wide offsets such as broadband RFI,
            and (b) replace the Gaussian proximity window in the refinement loop by
            Huber IRLS weights min(1, c/|r|) on the timing residual r, where
            c = huber_k * scale.
        weight_type: 'intensity' (default) weights each above-threshold pixel by
            its excess over the baseline. 'snr' uses a per-channel median baseline
            and per-channel MAD noise and weights each above-threshold pixel by its
            per-pixel S/N, which down-weights noisy channels.
        huber_k: Huber tuning constant in units of the refinement scale (robust only).
        dm_tol: stop refining once successive DM estimates differ by less than this.

    Returns:
        MomentEstimatorResult with DM, t0, and diagnostics. weights_sum and
        condition_number describe the same solve as the reported DM/t0.

    Raises:
        ValueError: if weight_type is not 'intensity'/'snr', or robust and huber_k <= 0.
    """
    if weight_type not in ('intensity', 'snr'):
        raise ValueError(f"weight_type must be 'intensity' or 'snr', got {weight_type!r}")
    if robust and huber_k <= 0:
        raise ValueError(f"huber_k must be positive, got {huber_k!r}")

    data = np.asarray(dynamic_spectrum, dtype=float)
    # Float casts: NumPy refuses integer ** negative integer (e.g. int freqs ** -2).
    freqs = np.asarray(freqs, dtype=float)
    times = np.asarray(times, dtype=float)
    freq_ref = float(freqs.max() if freq_ref is None else freq_ref)
    dt = float(times[1] - times[0]) if times.size > 1 else 0.0

    # Compute dispersion coordinate f = K_DM * (nu^-2 - nu_ref^-2)
    f = K_DM * (freqs**-2 - freq_ref**-2)
    t_grid, f_grid = np.meshgrid(times, f)

    # Baseline: global median by default; per-channel when robust or S/N weighting,
    # so that channel-wide offsets do not masquerade as signal.
    per_channel = robust or weight_type == 'snr'
    baseline = np.median(data, axis=1, keepdims=True) if per_channel else np.median(data)
    excess = data - baseline

    # Initial weights: only significant pixels (MAD-based robust noise estimate)
    if weight_type == 'snr':
        chan_sigma = 1.4826 * np.median(np.abs(excess), axis=1, keepdims=True)
        # Channels with zero MAD (constant/flagged) get infinite sigma -> zero weight.
        chan_sigma = np.where(chan_sigma > 0, chan_sigma, np.inf)
        pixel_snr = excess / chan_sigma
        w = np.where(pixel_snr > sigma_threshold, pixel_snr, 0.0)
    else:
        sigma = 1.4826 * np.median(np.abs(excess))
        w = np.where(excess > sigma_threshold * sigma, excess, 0.0)

    def solve_moments(weights):
        """Closed-form (DM, t0) from weighted moments; (None, None, 0, inf) if degenerate."""
        W = weights.sum()
        if W < 1e-10:
            return None, None, 0.0, np.inf
        W_f = (weights * f_grid).sum()
        W_ff = (weights * f_grid**2).sum()
        W_t = (weights * t_grid).sum()
        W_ft = (weights * f_grid * t_grid).sum()

        det = W * W_ff - W_f**2
        if abs(det) < 1e-20:
            return None, None, 0.0, np.inf

        dm = (W * W_ft - W_f * W_t) / det
        t0 = (W_t - W_f * dm) / W

        A = np.array([[W, W_f], [W_f, W_ff]])
        cond = np.linalg.cond(A)
        return dm, t0, W, cond

    # Initial estimate
    dm, t0, W, cond = solve_moments(w)
    if dm is None:
        return MomentEstimatorResult(dm=np.nan, t0=np.nan, condition_number=np.inf)

    # Iterative refinement: down-weight pixels far from the current dispersion curve
    scale = pulse_width_hint if pulse_width_hint else max(3 * dt, 0.005)
    huber_c = huber_k * scale

    for _ in range(max_iterations):
        residual_t = np.abs(t_grid - (t0 + f_grid * dm))
        if robust:
            # Huber IRLS weight psi(r)/r = min(1, c/|r|); dividing by max(|r|, c)
            # gives exactly 1 inside the quadratic core and never divides by zero.
            refine_w = huber_c / np.maximum(residual_t, huber_c)
        else:
            refine_w = np.exp(-0.5 * (residual_t / scale) ** 2)

        dm_new, t0_new, W_new, cond_new = solve_moments(w * refine_w)
        if dm_new is None:
            # Keep the last successful solution together with its own diagnostics.
            break

        converged = abs(dm_new - dm) < dm_tol
        dm, t0, W, cond = dm_new, t0_new, W_new, cond_new
        if converged:
            break

    return MomentEstimatorResult(
        dm=float(dm), t0=float(t0), weights_sum=float(W), condition_number=float(cond)
    )


In [None]:
# Sanity check: run the estimator on the clean synthetic pulse generated above
result = moment_dm_estimator(ds, freqs, times, freq_ref=obs.freq_ref)
dm_error = result.dm - pulse.dm

report_lines = [
    f"True DM:      {pulse.dm:.4f} pc/cm³",
    f"Estimated DM: {result.dm:.4f} pc/cm³",
    f"Error:        {dm_error:.4f} pc/cm³ ({100*dm_error/pulse.dm:.2f}%)",
    f"\nTrue t0:      {pulse.t0*1e3:.4f} ms",
    f"Estimated t0: {result.t0*1e3:.4f} ms",
    f"\nCondition number: {result.condition_number:.2e}",
]
for line in report_lines:
    print(line)


---
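## 4. Stress Tests

Conventions used below:
- "S/N" means the per-pixel peak amplitude of the pulse in units of the noise standard deviation (`amplitude` with `noise_level=1`). It is not the band-integrated detection S/N, which is much larger.
- Across 1100–1500 MHz the dispersion sweep is ≈1.585 ms per unit DM (≈0.79 s at DM = 500). With a 0.5 s window and $t_0 = 0.15$ s, the whole sweep fits only for DM ≲ 220 pc cm⁻³. At larger DM the low-frequency channels receive the pulse after the window ends, so unless the window is extended, only the upper part of the band carries signal.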

### 4.1 Accuracy vs S/N and DM


In [None]:
def run_accuracy_test(dm_values, snr_values, n_trials=50, t0=0.15, width=0.002, margin=0.1):
    """Test DM estimation accuracy across a grid of (DM, S/N).

    The base observation is 1100-1500 MHz, 256 channels, 0.5 s / 512 samples.
    The band-crossing delay is ~1.585 ms per unit DM, so with t0 = 0.15 s the full
    sweep fits in 0.5 s only for DM <~ 220. If the window were not extended,
    "higher DM" would be confounded with "less of the band contains the pulse".
    For each DM the window is therefore lengthened, keeping the sample interval
    dt unchanged, until it covers t0 + sweep + margin.

    Args:
        dm_values: iterable of true DMs (pc/cm^3).
        snr_values: iterable of per-pixel peak amplitudes (noise sigma = 1).
        n_trials: noise realisations per grid point.
        t0: pulse arrival time at the top of the band (s).
        width: intrinsic Gaussian pulse width (s).
        margin: extra time (s) kept after the lowest-frequency arrival.

    Returns:
        List of dicts with keys dm_true, snr, dm_mean, dm_std, bias, bias_pct, and
        n_time (the number of time samples used for that DM).
    """
    base_obs = ObsParams(
        freq_lo=1100, freq_hi=1500, n_chan=256,
        t_start=0.0, t_end=0.5, n_time=512
    )
    dt = base_obs.dt
    results = []

    for dm_true in dm_values:
        # Extend the window (same dt) so the whole sweep is observed.
        sweep = dispersion_delay(dm_true, base_obs.freq_lo, base_obs.freq_ref)
        t_needed = t0 + sweep + margin
        if t_needed <= base_obs.t_end:
            obs = base_obs
        else:
            n_time = int(np.ceil((t_needed - base_obs.t_start) / dt)) + 1
            obs = base_obs._replace(n_time=n_time, t_end=base_obs.t_start + (n_time - 1) * dt)

        for snr in snr_values:
            dm_estimates = np.empty(n_trials)
            for trial in range(n_trials):
                pulse = PulseParams(dm=dm_true, t0=t0, width=width, amplitude=snr)
                ds, freqs, times = generate_dynamic_spectrum(pulse, obs, noise_level=1.0)
                dm_estimates[trial] = moment_dm_estimator(
                    ds, freqs, times, freq_ref=obs.freq_ref
                ).dm

            dm_mean = np.mean(dm_estimates)
            results.append({
                'dm_true': dm_true,
                'snr': snr,
                'dm_mean': dm_mean,
                'dm_std': np.std(dm_estimates),
                'bias': dm_mean - dm_true,
                'bias_pct': 100 * (dm_mean - dm_true) / dm_true,
                'n_time': obs.n_time,
            })

    return results

# Grid of true DMs and per-pixel peak S/N values to sweep
dm_values = [100, 300, 500, 1000]
snr_values = [3, 5, 10, 20, 50]

print("Running accuracy tests (this may take a minute)...")
accuracy_results = run_accuracy_test(dm_values, snr_values, n_trials=30)

# Tabulate: header, rule, then one row per (DM, S/N) grid point
table_header = f"\n{'DM':>8} {'S/N':>6} {'Est. DM':>10} {'Std':>10} {'Bias':>10} {'Bias %':>10}"
print(table_header)
print("-" * 60)
for row in accuracy_results:
    row_text = (f"{row['dm_true']:>8.1f} {row['snr']:>6.1f} {row['dm_mean']:>10.2f} {row['dm_std']:>10.2f} "
                f"{row['bias']:>10.2f} {row['bias_pct']:>10.2f}%")
    print(row_text)


In [None]:
# Visualize accuracy results: left = relative bias, right = scatter; one curve per true DM
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
bias_ax, std_ax = axes

for dm in dm_values:
    rows = [r for r in accuracy_results if r['dm_true'] == dm]
    snr_axis = [r['snr'] for r in rows]
    bias_ax.plot(snr_axis, [r['bias_pct'] for r in rows], 'o-',
                 label=f'DM = {dm}', linewidth=2, markersize=8)
    std_ax.plot(snr_axis, [r['dm_std'] for r in rows], 'o-',
                label=f'DM = {dm}', linewidth=2, markersize=8)

bias_ax.axhline(0, color='k', linestyle='--', alpha=0.3)
bias_ax.set_xlabel('S/N', fontsize=12)
bias_ax.set_ylabel('Bias (%)', fontsize=12)
bias_ax.set_title('DM Estimation Bias vs S/N', fontsize=14)

std_ax.set_xlabel('S/N', fontsize=12)
std_ax.set_ylabel('DM Std. Dev. (pc/cm³)', fontsize=12)
std_ax.set_title('DM Estimation Uncertainty vs S/N', fontsize=14)
std_ax.set_yscale('log')

for panel in axes:
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


### 4.2 Scattering Test

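Scattering adds a one-sided exponential tail that delays each channel's intensity centroid by roughly the local scattering time, $\tau(\nu) \propto \nu^{-4}$. This extra delay grows toward low frequencies, as the dispersion delay does. A centroid-based fit can therefore partly absorb it into DM, so we expect a positive DM bias that grows with $\tau$. The refinement weighting truncates part of the tail, which limits the bias.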


In [None]:
def test_scattering_bias(scattering_times, snr=20, n_trials=30):
    """Measure DM bias and scatter of the standard and robust estimators vs scattering.

    For each scattering time (seconds at the reference frequency, index -4), draw
    `n_trials` DM=500 pulses and run both estimator variants on the same data.
    Returns one dict per scattering time with the mean bias and std of each variant.
    """
    obs = ObsParams(
        freq_lo=1100, freq_hi=1500, n_chan=256,
        t_start=0.0, t_end=0.5, n_time=512
    )
    dm_true = 500.0

    def _standard_and_robust(tau_scat):
        pulse = PulseParams(
            dm=dm_true, t0=0.15, width=0.002, amplitude=snr,
            scattering_time=tau_scat, scattering_index=-4.0
        )
        ds, freqs, times = generate_dynamic_spectrum(pulse, obs, noise_level=1.0)
        plain = moment_dm_estimator(ds, freqs, times, freq_ref=obs.freq_ref)
        hardened = moment_dm_estimator(ds, freqs, times, freq_ref=obs.freq_ref, robust=True)
        return plain.dm, hardened.dm

    summary = []
    for tau_scat in scattering_times:
        pairs = [_standard_and_robust(tau_scat) for _ in range(n_trials)]
        plain_dms = [p[0] for p in pairs]
        hardened_dms = [p[1] for p in pairs]
        summary.append({
            'tau_scat_ms': tau_scat * 1e3,
            'bias': np.mean(plain_dms) - dm_true,
            'bias_robust': np.mean(hardened_dms) - dm_true,
            'std': np.std(plain_dms),
            'std_robust': np.std(hardened_dms),
        })

    return summary

# Scattering timescales to test, in seconds at the reference frequency
scattering_times = [0, 0.001, 0.002, 0.005, 0.01, 0.02]
print("Testing scattering effects...")
scattering_results = test_scattering_bias(scattering_times)

table_header = f"\n{'τ_scat (ms)':>12} {'Bias':>10} {'Bias (robust)':>15} {'Std':>10} {'Std (robust)':>12}"
print(table_header)
print("-" * 65)
for row in scattering_results:
    print(
        f"{row['tau_scat_ms']:>12.1f} {row['bias']:>10.2f} {row['bias_robust']:>15.2f} "
        f"{row['std']:>10.2f} {row['std_robust']:>12.2f}"
    )


In [None]:
# Visualize scattering effects
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Show a scattered pulse
ax = axes[0]
pulse_scat = PulseParams(dm=500, t0=0.15, width=0.002, amplitude=20,
                         scattering_time=0.01, scattering_index=-4.0)
ds_scat, freqs, times = generate_dynamic_spectrum(pulse_scat, obs, noise_level=1.0)

# `freqs` is descending (row 0 = top of band); origin='upper' draws row 0 at the
# extent's top edge (freqs[0]) so the frequency labels match the data.
im = ax.imshow(ds_scat, aspect='auto', origin='upper',
               extent=[times[0]*1e3, times[-1]*1e3, freqs[-1], freqs[0]])
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Frequency (MHz)')
ax.set_title(f'Scattered Pulse (τ = {pulse_scat.scattering_time*1e3:g} ms at {obs.freq_ref} MHz)')
plt.colorbar(im, ax=ax)

# Bias vs scattering time
ax = axes[1]
taus = [r['tau_scat_ms'] for r in scattering_results]
biases = [r['bias'] for r in scattering_results]
biases_robust = [r['bias_robust'] for r in scattering_results]

ax.plot(taus, biases, 'o-', label='Standard', linewidth=2, markersize=8)
ax.plot(taus, biases_robust, 's--', label='Robust (Huber)', linewidth=2, markersize=8)
ax.axhline(0, color='k', linestyle=':', alpha=0.3)
ax.set_xlabel('Scattering Time at Ref. Freq (ms)', fontsize=12)
ax.set_ylabel('DM Bias (pc/cm³)', fontsize=12)
ax.set_title('Scattering-Induced DM Bias', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


### 4.3 RFI Contamination Test


In [None]:
def test_rfi_robustness(rfi_fractions, snr=20, n_trials=30):
    """Compare standard, robust, and SNR-weighted estimators under channel RFI.

    For each RFI channel fraction, draw `n_trials` DM=500 pulses with broadband RFI
    (rfi_strength=10) in that fraction of channels. Run the three estimator variants
    on the same data and record each variant's mean bias and standard deviation.
    """
    obs = ObsParams(
        freq_lo=1100, freq_hi=1500, n_chan=256,
        t_start=0.0, t_end=0.5, n_time=512
    )
    dm_true = 500.0
    # Estimator variants: extra keyword arguments passed to moment_dm_estimator.
    variants = {
        'standard': {},
        'robust': {'robust': True},
        'snr': {'weight_type': 'snr'},
    }

    summary = []
    for rfi_frac in rfi_fractions:
        estimates = {name: [] for name in variants}

        for _ in range(n_trials):
            pulse = PulseParams(dm=dm_true, t0=0.15, width=0.002, amplitude=snr)
            ds, freqs, times = generate_dynamic_spectrum(
                pulse, obs, noise_level=1.0,
                rfi_fraction=rfi_frac, rfi_strength=10.0
            )
            for name, extra_kwargs in variants.items():
                fit = moment_dm_estimator(ds, freqs, times, freq_ref=obs.freq_ref, **extra_kwargs)
                estimates[name].append(fit.dm)

        summary.append({
            'rfi_frac': rfi_frac,
            'bias_standard': np.mean(estimates['standard']) - dm_true,
            'bias_robust': np.mean(estimates['robust']) - dm_true,
            'bias_snr': np.mean(estimates['snr']) - dm_true,
            'std_standard': np.std(estimates['standard']),
            'std_robust': np.std(estimates['robust']),
            'std_snr': np.std(estimates['snr']),
        })

    return summary

rfi_fractions = [0, 0.05, 0.1, 0.2, 0.3, 0.4]  # fraction of channels carrying RFI
print("Testing RFI robustness...")
rfi_results = test_rfi_robustness(rfi_fractions)

table_header = f"\n{'RFI %':>8} {'Bias (std)':>12} {'Bias (robust)':>14} {'Bias (SNR)':>12}"
print(table_header)
print("-" * 50)
for row in rfi_results:
    row_text = f"{100*row['rfi_frac']:>8.0f} {row['bias_standard']:>12.2f} {row['bias_robust']:>14.2f} {row['bias_snr']:>12.2f}"
    print(row_text)


In [None]:
# Visualize RFI effects
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Show RFI-contaminated spectrum
ax = axes[0]
pulse = PulseParams(dm=500, t0=0.15, width=0.002, amplitude=20)
ds_rfi, freqs, times = generate_dynamic_spectrum(
    pulse, obs, noise_level=1.0, rfi_fraction=0.2, rfi_strength=10.0
)

# `freqs` is descending (row 0 = top of band); origin='upper' draws row 0 at the
# extent's top edge (freqs[0]) so the frequency labels match the data.
im = ax.imshow(ds_rfi, aspect='auto', origin='upper',
               extent=[times[0]*1e3, times[-1]*1e3, freqs[-1], freqs[0]],
               vmin=-3, vmax=25)
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Frequency (MHz)')
ax.set_title('RFI-Contaminated Spectrum (20% channels)')
plt.colorbar(im, ax=ax)

# Bias vs RFI fraction
ax = axes[1]
rfi_pcts = [100 * r['rfi_frac'] for r in rfi_results]

ax.plot(rfi_pcts, [r['bias_standard'] for r in rfi_results], 'o-', label='Standard', linewidth=2, markersize=8)
ax.plot(rfi_pcts, [r['bias_robust'] for r in rfi_results], 's--', label='Robust (Huber)', linewidth=2, markersize=8)
ax.plot(rfi_pcts, [r['bias_snr'] for r in rfi_results], '^:', label='SNR-weighted', linewidth=2, markersize=8)
ax.axhline(0, color='k', linestyle=':', alpha=0.3)
ax.set_xlabel('RFI Channel Fraction (%)', fontsize=12)
ax.set_ylabel('DM Bias (pc/cm³)', fontsize=12)
ax.set_title('RFI-Induced DM Bias', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


### 4.4 Very Low S/N Stress Test


In [None]:
def test_low_snr(snr_values, n_trials=100):
    """Characterise the moment estimator in the very-low-S/N regime.

    For each per-pixel peak amplitude in `snr_values`, run `n_trials` DM=500 pulses
    through `moment_dm_estimator`. Estimates more than 2000 pc/cm^3 from the truth
    (including NaN) count as outliers: they are excluded from mean/median/std and
    counted in `outlier_frac`. All raw estimates are returned for plotting.
    """
    obs = ObsParams(
        freq_lo=1100, freq_hi=1500, n_chan=256,
        t_start=0.0, t_end=0.5, n_time=512
    )
    dm_true = 500.0

    def _one_estimate(amplitude):
        pulse = PulseParams(dm=dm_true, t0=0.15, width=0.002, amplitude=amplitude)
        ds, freqs, times = generate_dynamic_spectrum(pulse, obs, noise_level=1.0)
        return moment_dm_estimator(ds, freqs, times, freq_ref=obs.freq_ref).dm

    summary = []
    for snr in snr_values:
        dm_estimates = np.array([_one_estimate(snr) for _ in range(n_trials)])
        # Remove extreme outliers for visualization
        inlier_mask = np.abs(dm_estimates - dm_true) < 2000
        inliers = dm_estimates[inlier_mask]
        have_inliers = inliers.size > 0

        summary.append({
            'snr': snr,
            'mean': np.mean(inliers) if have_inliers else np.nan,
            'median': np.median(inliers) if have_inliers else np.nan,
            'std': np.std(inliers) if have_inliers else np.nan,
            'outlier_frac': 1 - inlier_mask.mean(),
            'raw_estimates': dm_estimates,
        })

    return summary

print("Testing very low S/N regime...")
low_snr_values = [0.5, 1, 1.5, 2, 3, 4, 5, 7, 10]  # per-pixel peak amplitudes (noise sigma = 1)
low_snr_results = test_low_snr(low_snr_values, n_trials=100)

table_header = f"\n{'S/N':>6} {'Mean DM':>10} {'Median DM':>10} {'Std':>10} {'Outliers':>10}"
print(table_header)
print("-" * 50)
for row in low_snr_results:
    row_text = f"{row['snr']:>6.1f} {row['mean']:>10.1f} {row['median']:>10.1f} {row['std']:>10.1f} {100*row['outlier_frac']:>9.1f}%"
    print(row_text)


In [None]:
# Visualize low S/N behavior
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram of estimates at different S/N
ax = axes[0]
colors = plt.cm.viridis(np.linspace(0, 1, len(low_snr_results)))

for color, r in zip(colors, low_snr_results):
    estimates = r['raw_estimates']
    valid = np.abs(estimates - 500) < 500  # Within ±500 of true
    if valid.sum() > 5:
        ax.hist(estimates[valid], bins=30, alpha=0.5, label=f"S/N = {r['snr']}", color=color)

ax.axvline(500, color='red', linestyle='--', label='True DM', linewidth=2)
ax.set_xlabel('Estimated DM (pc/cm³)', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.set_title('DM Estimate Distribution vs S/N', fontsize=14)
ax.legend(loc='upper left', fontsize=9)
ax.set_xlim(300, 700)

# Precision vs S/N
ax = axes[1]
snrs = [r['snr'] for r in low_snr_results]
stds = [r['std'] for r in low_snr_results]

ax.loglog(snrs, stds, 'o-', linewidth=2, markersize=8)

# Theoretical scaling: σ_DM ∝ 1/S/N, anchored at the S/N = 5 measurement.
# Look the anchor up by value: in low_snr_values, S/N = 5 is not at a fixed index.
snr_ref = 5 if 5 in snrs else snrs[len(snrs) // 2]
i_ref = snrs.index(snr_ref)
snr_fit = np.array(snrs, dtype=float)
expected = stds[i_ref] * snrs[i_ref] / snr_fit
ax.loglog(snr_fit, expected, '--', color='gray', label=r'$\propto 1/{\rm S/N}$')

ax.set_xlabel('S/N', fontsize=12)
ax.set_ylabel('DM Std. Dev. (pc/cm³)', fontsize=12)
ax.set_title('Estimation Precision vs S/N', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


---
## 5. Comparison with Trial Dedispersion (Bowtie)

Implement a simple bowtie method to compare accuracy and speed.


In [None]:
def bowtie_dm_estimator(
    dynamic_spectrum: np.ndarray,
    freqs: np.ndarray,
    times: np.ndarray,
    dm_min: float = 0,
    dm_max: float = 1000,
    n_dm_trials: int = 1000,
    freq_ref: Optional[float] = None,
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Traditional incoherent trial dedispersion (the "bowtie" S/N-vs-DM search).

    For each trial DM, every channel is circularly shifted earlier by its integer
    sample delay. Channels whose delay falls outside [0, n_time) are skipped. The
    shifted channels are summed, and the trial is scored by peak / std of the
    summed time series. The best trial is refined by parabolic interpolation.

    Args:
        dynamic_spectrum: (n_chan, n_time) intensity array
        freqs: frequency array in MHz
        times: uniformly spaced time array in seconds
        dm_min, dm_max, n_dm_trials: linear grid of trial DMs
        freq_ref: reference frequency (default: max freq)

    Returns:
        (best DM, trial DM grid, S/N at each trial)
    """
    if freq_ref is None:
        freq_ref = freqs.max()

    dm_trials = np.linspace(dm_min, dm_max, n_dm_trials)
    n_chan, n_time = dynamic_spectrum.shape
    dt = times[1] - times[0]
    sample_idx = np.arange(n_time)

    snr_curve = np.zeros(n_dm_trials)

    for i, dm in enumerate(dm_trials):
        # Compute shift for each channel
        delays = dispersion_delay(dm, freqs, freq_ref)
        shifts = (delays / dt).astype(int)

        # Dedisperse all usable channels in one gather instead of a per-channel
        # Python loop: np.roll(x, -s)[k] == x[(k + s) % n_time].
        keep = (shifts >= 0) & (shifts < n_time)
        rows = np.flatnonzero(keep)
        cols = (sample_idx[None, :] + shifts[keep][:, None]) % n_time
        dedispersed = dynamic_spectrum[rows[:, None], cols].sum(axis=0)

        # S/N: peak / std
        snr_curve[i] = dedispersed.max() / (dedispersed.std() + 1e-10)

    # Find best DM with parabolic interpolation
    best_idx = np.argmax(snr_curve)

    if 0 < best_idx < n_dm_trials - 1:
        # Vertex of the parabola through the three points around the maximum
        y0, y1, y2 = snr_curve[best_idx-1:best_idx+2]
        delta = 0.5 * (y0 - y2) / (y0 - 2*y1 + y2 + 1e-10)
        dm_best = dm_trials[best_idx] + delta * (dm_trials[1] - dm_trials[0])
    else:
        dm_best = dm_trials[best_idx]

    return dm_best, dm_trials, snr_curve


# Compare methods on a single pulse
obs = ObsParams(
    freq_lo=1100, freq_hi=1500, n_chan=256,
    t_start=0.0, t_end=0.5, n_time=512
)
pulse = PulseParams(dm=500.0, t0=0.15, width=0.002, amplitude=10)
ds, freqs, times = generate_dynamic_spectrum(pulse, obs, noise_level=1.0)

# Time both methods with time.perf_counter(), the clock intended for measuring
# short durations. time.time() is a wall clock whose resolution is platform-
# dependent and can be too coarse for a millisecond-scale call. The timer is
# called `tic`, not `t0`, to avoid confusion with the pulse arrival time.
tic = time.perf_counter()
moment_result = moment_dm_estimator(ds, freqs, times, freq_ref=obs.freq_ref)
t_moment = time.perf_counter() - tic

# NOTE: the bowtie search is restricted to 400-600, bracketing the known DM;
# a blind search over a wider range would be proportionally slower.
tic = time.perf_counter()
bowtie_dm, dm_trials, snr_curve = bowtie_dm_estimator(
    ds, freqs, times, dm_min=400, dm_max=600, n_dm_trials=200, freq_ref=obs.freq_ref
)
t_bowtie = time.perf_counter() - tic

print(f"True DM: {pulse.dm:.2f} pc/cm³")
print(f"\nMoment method: DM = {moment_result.dm:.2f} pc/cm³, time = {t_moment*1e3:.2f} ms")
print(f"Bowtie method: DM = {bowtie_dm:.2f} pc/cm³, time = {t_bowtie*1e3:.2f} ms")
print(f"\nSpeedup: {t_bowtie/t_moment:.1f}x")


In [None]:
# Visualize the trial-dedispersion S/N curve with the true and estimated DMs marked
fig, ax = plt.subplots(figsize=(10, 5))

ax.plot(dm_trials, snr_curve, 'b-', linewidth=1.5)

reference_lines = [
    (pulse.dm, 'red', '--', 'True DM'),
    (bowtie_dm, 'green', ':', f'Bowtie: {bowtie_dm:.1f}'),
    (moment_result.dm, 'orange', '-.', f'Moment: {moment_result.dm:.1f}'),
]
for x_pos, line_color, line_style, line_label in reference_lines:
    ax.axvline(x_pos, color=line_color, linestyle=line_style, label=line_label, linewidth=2)

ax.set_xlabel('DM (pc/cm³)', fontsize=12)
ax.set_ylabel('S/N', fontsize=12)
ax.set_title('Trial Dedispersion S/N Curve', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


---
## 6. GPU Implementation


In [None]:
if HAS_GPU:
    def moment_dm_estimator_gpu(
        dynamic_spectrum: np.ndarray,
        freqs: np.ndarray,
        times: np.ndarray,
        freq_ref: Optional[float] = None,
        sigma_threshold: float = 2.0,
    ) -> MomentEstimatorResult:
        """GPU (CuPy) single-pass moment-based DM estimator.

        Uses the same pixel selection and weighting as the initial, non-iterative
        step of `moment_dm_estimator`: a global median/MAD noise estimate, with
        pixels more than `sigma_threshold` sigma above the median weighted by
        their excess. It should therefore agree with
        `moment_dm_estimator(..., max_iterations=0)` up to floating-point
        differences. (Weighting by max(I, 0) over all pixels would instead let
        the positive half of the noise dominate the moments.)

        Args:
            dynamic_spectrum: (n_chan, n_time) intensity array (host memory)
            freqs: frequency array in MHz
            times: time array in seconds
            freq_ref: reference frequency (default: max freq)
            sigma_threshold: S/N threshold for pixel selection

        Returns:
            MomentEstimatorResult. dm/t0 are NaN when no pixel passes the
            threshold or the 2x2 system is singular.
        """
        if freq_ref is None:
            freq_ref = freqs.max()
        freq_ref = float(freq_ref)

        # Transfer to GPU (float64, matching the CPU path)
        ds_gpu = cp.asarray(dynamic_spectrum, dtype=cp.float64)
        freqs_gpu = cp.asarray(freqs, dtype=cp.float64)
        times_gpu = cp.asarray(times, dtype=cp.float64)

        # Compute dispersion coordinate
        f = K_DM * (freqs_gpu**-2 - freq_ref**-2)

        # Create grids
        t_grid, f_grid = cp.meshgrid(times_gpu, f)

        # Weights: significant pixels only, same rule as the CPU estimator
        med = cp.median(ds_gpu)
        excess = ds_gpu - med
        sigma = 1.4826 * cp.median(cp.abs(excess))
        w = cp.where(excess > sigma_threshold * sigma, excess, 0.0)

        # Moments are reduced on the device; each float() copies one scalar to the host.
        W = float(w.sum())
        if W < 1e-10:
            return MomentEstimatorResult(dm=np.nan, t0=np.nan, condition_number=np.inf)
        W_f = float((w * f_grid).sum())
        W_ff = float((w * f_grid**2).sum())
        W_t = float((w * t_grid).sum())
        W_ft = float((w * f_grid * t_grid).sum())

        # Solve the 2x2 system on the host
        det = W * W_ff - W_f**2
        if abs(det) < 1e-20:
            return MomentEstimatorResult(dm=np.nan, t0=np.nan, condition_number=np.inf)
        dm = (W * W_ft - W_f * W_t) / det
        t0 = (W_t - W_f * dm) / W
        cond = float(np.linalg.cond(np.array([[W, W_f], [W_f, W_ff]])))

        return MomentEstimatorResult(dm=dm, t0=t0, weights_sum=W, condition_number=cond)

    # Benchmark GPU vs CPU on the *same* computation: the CPU estimator is run with
    # max_iterations=0 (closed-form step only) so both sides compute identical moments.
    print("Benchmarking GPU vs CPU (single-pass moments; GPU time includes host->device copy)...\n")

    sizes = [(256, 512), (512, 1024), (1024, 2048), (2048, 4096)]
    n_iter = 20

    for n_chan, n_time in sizes:
        obs_test = ObsParams(
            freq_lo=1100, freq_hi=1500, n_chan=n_chan,
            t_start=0.0, t_end=0.5, n_time=n_time
        )
        pulse_test = PulseParams(dm=500.0, t0=0.15, width=0.002, amplitude=10)
        ds_test, freqs_test, times_test = generate_dynamic_spectrum(pulse_test, obs_test, noise_level=1.0)

        # Warm up GPU (kernel compilation, memory pool)
        gpu_result = moment_dm_estimator_gpu(ds_test, freqs_test, times_test, freq_ref=obs_test.freq_ref)
        cp.cuda.Stream.null.synchronize()

        # CPU timing
        tic = time.perf_counter()
        for _ in range(n_iter):
            cpu_result = moment_dm_estimator(ds_test, freqs_test, times_test,
                                             freq_ref=obs_test.freq_ref, max_iterations=0)
        t_cpu = (time.perf_counter() - tic) / n_iter

        # GPU timing
        tic = time.perf_counter()
        for _ in range(n_iter):
            gpu_result = moment_dm_estimator_gpu(ds_test, freqs_test, times_test, freq_ref=obs_test.freq_ref)
            cp.cuda.Stream.null.synchronize()
        t_gpu = (time.perf_counter() - tic) / n_iter

        print(f"Size {n_chan}x{n_time}: CPU = {t_cpu*1e3:.3f} ms, GPU = {t_gpu*1e3:.3f} ms, "
              f"Speedup = {t_cpu/t_gpu:.1f}x, |ΔDM| = {abs(cpu_result.dm - gpu_result.dm):.2e}")

else:
    print("GPU not available, skipping GPU benchmarks")


---
## 7. Summary and Conclusions

### Strengths of Moment-Based Estimator
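1. **Speed**: each closed-form solve is one O(N_chan × N_time) pass over the data. The reweighting loop adds at most `max_iterations` further passes, and there is no search over trial DMs.
2. **Simplicity**: every step is a closed-form 2×2 solve, with no nonlinear optimizer.
3. **GPU-friendly**: apart from the median/MAD noise estimate, the work is elementwise arithmetic plus sum reductions.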

### Weaknesses
1. **Scattering bias**: Asymmetric pulse profiles cause systematic bias
2. **RFI sensitivity**: Bright RFI dominates intensity weights
3. **Low S/N**: Variance increases rapidly below S/N ~ 5

### Recommendations
- Use robust (Huber) weighting for RFI mitigation
- Consider SNR-based weighting in RFI-heavy environments
- For scattered pulses, consider Radon/Hough methods (next notebook)


In [None]:
# Final summary plot
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

obs = ObsParams(
    freq_lo=1100, freq_hi=1500, n_chan=256,
    t_start=0.0, t_end=0.5, n_time=512
)


def _show_spectrum(ax, data, freqs, times, title, **imshow_kwargs):
    """Draw a dynamic spectrum whose row 0 is the highest frequency, with matching labels.

    `freqs` is descending (as produced by ObsParams.freqs), so origin='upper'
    places row 0 at the extent's top edge, freqs[0]. origin='lower' with the same
    extent would draw the 1500 MHz row at the 1100 MHz label.
    """
    im = ax.imshow(data, aspect='auto', origin='upper',
                   extent=[times[0]*1e3, times[-1]*1e3, freqs[-1], freqs[0]],
                   **imshow_kwargs)
    ax.set_title(title)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Freq (MHz)')
    return im


# 1. Clean pulse estimation
pulse = PulseParams(dm=500.0, t0=0.15, width=0.002, amplitude=15)
ds, freqs, times = generate_dynamic_spectrum(pulse, obs, noise_level=1.0)
result = moment_dm_estimator(ds, freqs, times, freq_ref=obs.freq_ref)
_show_spectrum(axes[0, 0], ds, freqs, times,
               f'Clean Pulse: True DM = {pulse.dm}, Est = {result.dm:.1f}')

# 2. Scattered pulse
pulse_scat = PulseParams(dm=500.0, t0=0.15, width=0.002, amplitude=15,
                         scattering_time=0.01)
ds_scat, _, _ = generate_dynamic_spectrum(pulse_scat, obs, noise_level=1.0)
result_scat = moment_dm_estimator(ds_scat, freqs, times, freq_ref=obs.freq_ref)
_show_spectrum(axes[0, 1], ds_scat, freqs, times,
               f'Scattered: True DM = {pulse_scat.dm}, Est = {result_scat.dm:.1f}')

# 3. RFI contaminated (same pulse as panel 1, RFI in 20% of channels)
ds_rfi, _, _ = generate_dynamic_spectrum(pulse, obs, noise_level=1.0,
                                         rfi_fraction=0.2, rfi_strength=10)
result_rfi = moment_dm_estimator(ds_rfi, freqs, times, freq_ref=obs.freq_ref)
result_rfi_robust = moment_dm_estimator(ds_rfi, freqs, times, freq_ref=obs.freq_ref, robust=True)
_show_spectrum(axes[1, 0], ds_rfi, freqs, times,
               f'RFI: Est = {result_rfi.dm:.1f} (std), {result_rfi_robust.dm:.1f} (robust)',
               vmin=-3, vmax=20)

# 4. Low S/N
pulse_weak = PulseParams(dm=500.0, t0=0.15, width=0.002, amplitude=3)
ds_weak, _, _ = generate_dynamic_spectrum(pulse_weak, obs, noise_level=1.0)
result_weak = moment_dm_estimator(ds_weak, freqs, times, freq_ref=obs.freq_ref)
_show_spectrum(axes[1, 1], ds_weak, freqs, times,
               f'Low S/N ({pulse_weak.amplitude:g}σ): True DM = {pulse_weak.dm}, Est = {result_weak.dm:.1f}')

plt.tight_layout()
plt.savefig('moment_estimator_summary.png', dpi=150)
plt.show()

print("Notebook complete!")
