In [None]:
# Cell 1 — imports & helpers
import time, math, json, os
import numpy as np
import pandas as pd
from pathlib import Path

from flovopy.asl.asl import ASL, _resolve_misfit_backend
from flovopy.core.mvo import dome_location
from flovopy.asl.map import topo_map
from obspy import Stream

def _to_lonlat_from_xy_km(xk, yk, ref_lon, ref_lat):
    """
    Convert a local offset in km to geographic (lon, lat) degrees.

    ``xk`` is added to the reference longitude and ``yk`` to the reference
    latitude, after converting km -> m and dividing by the metres-per-degree
    factors that ``_meters_per_degree(ref_lat)`` returns (a linear scaling
    evaluated at the reference latitude).

    Returns
    -------
    (float, float)
        ``(lon, lat)`` as plain Python floats.
    """
    from flovopy.asl.grid import _meters_per_degree

    lat_scale_m, lon_scale_m = _meters_per_degree(ref_lat)
    east_m = xk * 1000.0
    north_m = yk * 1000.0
    lon_deg = ref_lon + east_m / lon_scale_m
    lat_deg = ref_lat + north_m / lat_scale_m
    return float(lon_deg), float(lat_deg)

def _source_row(tag, lon, lat, t_elapsed, extra=None):
    row = {"method": tag, "lon": lon, "lat": lat, "elapsed_s": round(t_elapsed,3)}
    if extra: row.update(extra)
    return row

def _pick_peak_time_index(aslobj):
    if getattr(aslobj, "source", None) is not None and "DR" in aslobj.source:
        return int(np.nanargmax(aslobj.source["DR"]))
    raise RuntimeError("ASL.source missing or no DR; run fast_locate() once or pass time_index")

In [None]:
# Cell 2 — method runners

def run_grid_search(stream: Stream, cfg, *, misfit_engine: str, min_stations=None, region=None):
    """
    Build an ASL for ``stream`` and run the grid-search ``fast_locate()``.

    Parameters
    ----------
    stream : obspy.Stream
        Preprocessed waveforms; wrapped in ``cfg.sam_class`` with
        ``sampling_interval=1.0``.
    cfg
        ASL configuration. ``cfg.build()`` is called first if
        ``node_distances_km``, ``ampcorr`` or ``inventory`` is missing.
    misfit_engine : str
        Engine name passed to ``_resolve_misfit_backend``.
    min_stations : int, optional
        Minimum stations passed to ``fast_locate``. ``None`` falls back to
        ``cfg.min_stations``; an explicit ``0`` is honoured.
    region : optional
        Accepted for call-site compatibility; not used here.

    Returns
    -------
    (asl, row)
        The located ASL object and a result row (see ``_source_row``) for the
        peak-DR time sample, tagged ``"grid:<misfit_engine>"`` and carrying
        ``nsta_used`` (NaN if unavailable), ``peak_DR`` and ``peak_t_index``.
    """
    # Ensure cfg is built (distances, ampcorr, inventory)
    if cfg.node_distances_km is None or cfg.ampcorr is None or getattr(cfg, "inventory", None) is None:
        cfg = cfg.build()

    # Build SAM + ASL
    sam = cfg.sam_class(stream=stream, sampling_interval=1.0)
    asl = ASL(sam, cfg)
    idx = cfg.gridobj.get_mask_indices()
    if idx is not None and getattr(idx, "size", 0):
        # Only attach a non-empty node mask (presumably limits the searched nodes — confirm in ASL).
        asl._node_mask = idx

    # Misfit backend
    backend = _resolve_misfit_backend(
        misfit_engine,
        peakf_hz=float(cfg.ampcorr.params.peakf),
        speed_kms=float(cfg.ampcorr.params.wave_speed_kms),
    )

    # `is None` rather than `or`, so an explicit min_stations=0 is not replaced by cfg's value.
    nmin = cfg.min_stations if min_stations is None else min_stations

    t0 = time.perf_counter()
    asl.fast_locate(min_stations=int(nmin), misfit_backend=backend)
    elapsed = time.perf_counter() - t0

    # Best location at peak DR
    jstar_t = _pick_peak_time_index(asl)
    lon = float(asl.source["lon"][jstar_t])
    lat = float(asl.source["lat"][jstar_t])

    # nsta may be absent, or NaN at the peak sample: report NaN instead of raising
    # (indexing a 1-element default list or int(nan) would fail).
    nsta = asl.source.get("nsta")
    nsta_used = np.nan
    if nsta is not None:
        nsta_val = float(np.asarray(nsta, dtype=float)[jstar_t])
        if np.isfinite(nsta_val):
            nsta_used = int(nsta_val)

    res = _source_row(f"grid:{misfit_engine}", lon, lat, elapsed, {
        "nsta_used": nsta_used,
        "peak_DR": float(asl.source["DR"][jstar_t]),
        "peak_t_index": jstar_t,
    })
    return asl, res


def _ensure_cfg_built(cfg):
    """Return ``cfg``, calling ``cfg.build()`` first if distances, ampcorr or inventory are missing."""
    if cfg.node_distances_km is None or cfg.ampcorr is None or getattr(cfg, "inventory", None) is None:
        cfg = cfg.build()
    return cfg


def _prepare_inverse(stream, cfg, seed_from_grid, grid_misfit_engine):
    """
    Build the ASL for an inverse run and, optionally, its grid-search seed.

    Returns ``(asl, init_lonlat)``. ``init_lonlat`` is None unless
    ``seed_from_grid``. Otherwise it is the ``(lon, lat)`` of a
    ``run_grid_search`` solution at that run's peak-DR time sample.
    """
    cfg = _ensure_cfg_built(cfg)
    sam = cfg.sam_class(stream=stream, sampling_interval=1.0)
    asl = ASL(sam, cfg)

    init_lonlat = None
    if seed_from_grid:
        asl_seed, _ = run_grid_search(stream, cfg, misfit_engine=grid_misfit_engine)
        ti = _pick_peak_time_index(asl_seed)
        init_lonlat = (float(asl_seed.source["lon"][ti]), float(asl_seed.source["lat"][ti]))
    return asl, init_lonlat


def _opt_float(value):
    """``float(value)``, or NaN when the solver did not report it (None)."""
    return np.nan if value is None else float(value)


def _opt_int(value, default=0):
    """``int(value)``, or ``default`` when the value is None or non-finite (int(nan) would raise)."""
    if value is None:
        return default
    value = float(value)
    return int(value) if math.isfinite(value) else default


def _inverse_fit_fields(out):
    """Fit diagnostics shared by both inverse runners, read from the solver's output dict."""
    return {
        "N_hat": _opt_float(out.get("N")),
        "k_hat": _opt_float(out.get("k")),
        "logA0": _opt_float(out.get("logA0")),
        "nsta_used": _opt_int(out.get("nsta"), 0),
        "sse": _opt_float(out.get("sse")),
    }


def run_inverse_linear(stream: Stream, cfg, *, seed_from_grid=True, grid_misfit_engine="l2"):
    """
    Hybrid: (optional) grid seed -> ``asl.inverse_locate()`` (linearized β, fixed distances).

    Parameters
    ----------
    stream : obspy.Stream
        Preprocessed waveforms.
    cfg
        ASL configuration (built on demand).
    seed_from_grid : bool
        If True, seed the inversion with a ``run_grid_search`` solution.
    grid_misfit_engine : str
        Misfit engine used for that seed search.

    Returns
    -------
    dict
        Result row (see ``_source_row``) tagged ``"inverse:linear"``.
        ``elapsed_s`` times the inverse step only (seed search excluded).
        Also carries ``N_hat``, ``k_hat``, ``logA0``, ``nsta_used`` and ``sse``.
        Diagnostics the solver omits or returns as None become NaN (``nsta_used``: 0).
    """
    asl, init_lonlat = _prepare_inverse(stream, cfg, seed_from_grid, grid_misfit_engine)

    t0 = time.perf_counter()
    out = asl.inverse_locate(init_lonlat=init_lonlat)
    elapsed = time.perf_counter() - t0

    return _source_row("inverse:linear", float(out["lon"]), float(out["lat"]), elapsed,
                       _inverse_fit_fields(out))


def run_inverse_nonlinear(stream: Stream, cfg, *,
                          seed_from_grid=True, grid_misfit_engine="l2",
                          use_3d=False, source_elev_mode="zero"):
    """
    Full nonlinear fit for (x, y[, z], logA0, N, k) via ``asl.inverse_locate_nonlinear()``.

    Parameters
    ----------
    stream, cfg, seed_from_grid, grid_misfit_engine
        As for ``run_inverse_linear``.
    use_3d : bool
        Passed through as ``use_3d``; also selects the ``2d``/``3d`` tag suffix.
    source_elev_mode : str
        Passed through; the solver is expected to accept "zero", "fixed" or "dem".

    Returns
    -------
    dict
        Result row tagged ``"inverse:nonlinear:2d"`` or ``"inverse:nonlinear:3d"``
        with the same diagnostics as ``run_inverse_linear``. It also carries the
        solver's ``success`` flag and its ``message``, truncated to 120 characters.
    """
    asl, init_lonlat = _prepare_inverse(stream, cfg, seed_from_grid, grid_misfit_engine)

    t0 = time.perf_counter()
    out = asl.inverse_locate_nonlinear(
        init_lonlat=init_lonlat,
        use_3d=bool(use_3d),
        source_elev_mode=source_elev_mode,  # "zero"|"fixed"|"dem"
        robust_loss="soft_l1",
        f_scale=1.0,
        verbose=False,
    )
    elapsed = time.perf_counter() - t0

    extra = _inverse_fit_fields(out)
    extra["success"] = bool(out.get("success", False))
    extra["message"] = str(out.get("message", ""))[:120]
    tag = f"inverse:nonlinear:{'3d' if use_3d else '2d'}"
    return _source_row(tag, float(out["lon"]), float(out["lat"]), elapsed, extra)

In [None]:
# Cell 3 — master comparison

def compare_location_methods(stream: Stream, cfg,
                             grid_misfits=("l2","std_over_mean","r2_decay"),
                             do_linear=True, do_nonlinear=True,
                             nonlinear_3d=(False, True),
                             plot=True, region=None, dem_tif=None, add_labels=True,
                             legend=True):
    """
    Run several ASL location methods on one Stream and collect their solutions.

    Runs:
      - grid search with each misfit backend in ``grid_misfits``
      - inverse_locate (linearized) if ``do_linear``
      - inverse_locate_nonlinear (2-D and/or 3-D, one run per entry of
        ``nonlinear_3d``) if ``do_nonlinear``

    The inverse methods are seeded from a grid search with ``grid_misfits[0]``.
    When ``grid_misfits`` is empty they run unseeded. A failing method only
    prints a ``[WARN]`` line; the remaining methods still run.

    Parameters
    ----------
    plot : bool
        If True and at least one method succeeded, draw all solutions plus the
        dome reference on a ``topo_map`` basemap and display it.
    region, dem_tif, add_labels
        Forwarded to ``topo_map``.
    legend : bool
        Add a legend naming each plotted method.

    Returns
    -------
    (df, fig)
        DataFrame with one row per successful method, and the figure returned by
        ``topo_map``. ``fig`` is None when nothing was plotted.
    """
    # Materialise once so any iterable works and can also be indexed for the seed engine below.
    grid_misfits = tuple(grid_misfits)
    results = []

    # Grid search variants
    for m in grid_misfits:
        try:
            _, row = run_grid_search(stream, cfg, misfit_engine=m)
            results.append(row)
        except Exception as e:
            print(f"[WARN] grid:{m} failed: {e}")

    # Seed the inverse runs with the first grid misfit. With none configured, run them
    # unseeded; otherwise they would fail on grid_misfits[0] with an IndexError.
    if grid_misfits:
        seed_kw = {"seed_from_grid": True, "grid_misfit_engine": grid_misfits[0]}
    else:
        seed_kw = {"seed_from_grid": False}

    # Linear hybrid
    if do_linear:
        try:
            results.append(run_inverse_linear(stream, cfg, **seed_kw))
        except Exception as e:
            print(f"[WARN] inverse:linear failed: {e}")

    # Nonlinear 2-D / 3-D
    if do_nonlinear:
        for flag3d in nonlinear_3d:
            dim = "3D" if flag3d else "2D"
            try:
                results.append(run_inverse_nonlinear(stream, cfg,
                                                     use_3d=bool(flag3d),
                                                     source_elev_mode=("dem" if flag3d else "zero"),
                                                     **seed_kw))
            except Exception as e:
                print(f"[WARN] inverse:nonlinear ({dim}) failed: {e}")

    df = pd.DataFrame(results)

    # Plot all solutions on the same basemap for visual compare
    fig = None
    if plot and not df.empty:
        fig = topo_map(region=region, dem_tif=dem_tif, add_labels=add_labels,
                       topo_color=True, cmap="topo", show=False, frame=True,
                       title="ASL method comparison")

        # method -> (marker style, fill). inverse:linear uses a diamond: a red star
        # would be indistinguishable from the red-star dome reference plotted below.
        styles = {
            "grid:l2": ("c0.28c", "black"),
            "grid:std_over_mean": ("c0.28c", "darkblue"),
            "grid:r2_decay": ("c0.28c", "darkgreen"),
            "inverse:linear": ("d0.32c", "red"),
            "inverse:nonlinear:2d": ("t0.32c", "orange"),
            "inverse:nonlinear:3d": ("t0.32c", "purple"),
        }

        for tag, lon, lat in zip(df["method"], df["lon"], df["lat"]):
            style, color = styles.get(tag, ("c0.24c", "black"))
            fig.plot(x=[float(lon)], y=[float(lat)], style=style, pen="0.6p,white",
                     fill=color, label=tag if legend else None)

        # Dome reference (atleast_1d: plot array-like coordinates, as for the solutions above)
        fig.plot(x=np.atleast_1d(dome_location["lon"]), y=np.atleast_1d(dome_location["lat"]),
                 style="a0.34c", fill="red", pen="0.6p,white",
                 label="dome" if legend else None)
        if legend:
            fig.legend()
        display(fig)  # IPython/Jupyter display builtin
    return df, fig

In [None]:
# Cell 4 — run on one preprocessed Stream + cfg
# Assumes you already created `stream` and `cfg` (your normal preprocessing).
# The Montserrat inventory comes via cfg (distances already in km); the dome
# reference marker uses `dome_location` imported from flovopy.core.mvo.

df_compare, fig = compare_location_methods(
    stream=stream,
    cfg=cfg,
    grid_misfits=("l2", "std_over_mean", "r2_decay"),
    do_linear=True,
    do_nonlinear=True,
    nonlinear_3d=(False, True),       # test both 2-D and 3-D nonlinear
    plot=True,
    region=getattr(cfg, "region", None),
    dem_tif=getattr(cfg, "dem_tif_for_bmap", None),
    add_labels=True,
)

# If every method failed, df_compare has no "elapsed_s" column and sort_values would raise KeyError.
if df_compare.empty:
    print("[WARN] no location method succeeded; see the warnings above")
else:
    print(df_compare.sort_values("elapsed_s"))

In [None]:
# Optional: add this cell if you want side-by-side comparisons inside the notebook
from flovopy.asl.compare import (
    compare_asl_sources,
    plot_asl_source_comparison,
)

# Expects ASL objects named asl_grid, asl_invlin, asl_invnl (rename to match your session).
# The cells above do not create them: run_inverse_linear/run_inverse_nonlinear return
# result rows, not ASL objects. Skip with a message instead of failing with NameError.
_needed_asl = ("asl_grid", "asl_invlin", "asl_invnl")
_missing_asl = [name for name in _needed_asl if name not in globals()]

if _missing_asl:
    print(f"[SKIP] define {', '.join(_missing_asl)} before running this comparison cell")
else:
    cmp1 = compare_asl_sources(asl_grid, asl_invlin)
    cmp2 = compare_asl_sources(asl_grid, asl_invnl)

    for label, stats in (("Grid vs Linear inverse", cmp1), ("Grid vs Nonlinear inverse", cmp2)):
        print(f"{label}:", {k: v for k, v in stats.items() if k != "location_km"})
        print("  mean sep (km):", stats["location_km"]["mean"], "max:", stats["location_km"]["max"])

    # quick plots
    plot_asl_source_comparison(asl_grid, asl_invlin, title="Grid vs Linear inverse", show=True)
    plot_asl_source_comparison(asl_grid, asl_invnl,  title="Grid vs Nonlinear inverse", show=True)

In [None]:
# smoke test
# --- Smoke test for flovopy.asl.compare (new + legacy wrappers) ---

import numpy as np
from pathlib import Path
import pandas as pd

from flovopy.asl.compare import (
    SourceTrack,
    compare_tracks,
    plot_source_comparison,
    compare_asl_sources,              # legacy wrapper (ASL objects)
    plot_asl_source_comparison,       # legacy wrapper (ASL objects)
)

OUT = Path("./_compare_smoketest")
OUT.mkdir(exist_ok=True, parents=True)

# Fixed seed so the station-count jitter (and hence the written PNGs) is reproducible run to run.
rng = np.random.default_rng(12345)

# -------------------------------------------------------------------
# 1) NEW API: build two SourceTrack objects with slight offsets
# -------------------------------------------------------------------
t0 = np.datetime64("2024-01-01T00:00:00")
t  = t0 + np.arange(20).astype("timedelta64[s]")

# Track A: straight line drifting NE near Montserrat (~0.001° N ≈ 110 m, ~0.0015° E ≈ 160 m over 20 samples)
latA = 16.72  + 0.001 * np.linspace(0, 1, t.size)
lonA = -62.19 + 0.0015 * np.linspace(0, 1, t.size)
DR_A = 10 + 2*np.sin(np.linspace(0, 2*np.pi, t.size))
nsta_A = 12 + (rng.random(t.size) > 0.2).astype(int)

trackA = SourceTrack(t=t, lat=latA, lon=lonA, DR=DR_A, nsta=nsta_A, tag="Track-A")

# Track B: Track A displaced around a 0.0004° circle (≈ 43–44 m), so expected separations are ~0.04 km
latB = latA + 0.0004*np.cos(np.linspace(0, 2*np.pi, t.size))
lonB = lonA + 0.0004*np.sin(np.linspace(0, 2*np.pi, t.size))
DR_B = DR_A * (1.0 + 0.02*np.sin(np.linspace(0, 4*np.pi, t.size)))  # tiny amplitude difference
nsta_B = nsta_A - (rng.random(t.size) > 0.8).astype(int)

trackB = SourceTrack(t=t, lat=latB, lon=lonB, DR=DR_B, nsta=nsta_B, tag="Track-B")

# Compare (new API)
cmp_new = compare_tracks(trackA, trackB)
print("[NEW] mean separation (km):", cmp_new["location_km"]["mean"])
print("[NEW] p90 separation (km):",  cmp_new["location_km"]["p90"])
print("[NEW] max separation (km):",  cmp_new["location_km"]["max"])

# Plot (new API)
plot_source_comparison(
    trackA, trackB,
    title="Smoke test: SourceTrack comparison",
    show=False,
    outfile=str(OUT / "smoketest_new_api.png"),
)
print(f"[NEW] Wrote: {OUT/'smoketest_new_api.png'}")

# -------------------------------------------------------------------
# 2) LEGACY WRAPPERS: mock two ASL objects with .source dicts
# -------------------------------------------------------------------
class MockASL:
    """
    Minimal stand-in for an ASL object, used by the legacy-wrapper smoke test.

    Exposes only ``source`` (a dict with keys t, lat, lon, DR, azgap, nsta)
    and a ``tag`` string. The optional series default to None.
    """

    def __init__(self, t, lat, lon, DR=None, azgap=None, nsta=None, tag="ASL"):
        # Same key order as the real ASL.source dict it imitates here.
        self.source = dict(t=t, lat=lat, lon=lon, DR=DR, azgap=azgap, nsta=nsta)
        self.tag = tag

# Two mock ASL objects wrapping the same synthetic tracks as section 1.
asl1 = MockASL(t=t, lat=latA, lon=lonA, DR=DR_A, nsta=nsta_A, tag="ASL-1")
asl2 = MockASL(t=t, lat=latB, lon=lonB, DR=DR_B, nsta=nsta_B, tag="ASL-2")

# Compare via the legacy wrapper and report the separation statistics.
cmp_legacy = compare_asl_sources(asl1, asl2)
legacy_sep_km = cmp_legacy["location_km"]
for heading, stat in (
    ("[LEGACY] mean separation (km):", "mean"),
    ("[LEGACY] p90 separation (km):", "p90"),
    ("[LEGACY] max separation (km):", "max"),
):
    print(heading, legacy_sep_km[stat])

# Plot via the legacy wrapper (written to disk, not shown).
legacy_png = OUT / "smoketest_legacy_api.png"
plot_asl_source_comparison(
    asl1,
    asl2,
    title="Smoke test: legacy ASL comparison",
    show=False,
    outfile=str(legacy_png),
)
print(f"[LEGACY] Wrote: {legacy_png}")

print("\n[OK] Smoke test complete.")