# 2) Notebook variants (what to toggle)

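Create one notebook per item below. In Cell 4, change only the listed field between configs A and B. MiniConfig is a frozen dataclass, so read "A.speed = 1.5" as shorthand: set the value in the MiniConfig(...) call (or derive B with dataclasses.replace(A, speed=3.2)) instead of assigning the attribute, which raises FrozenInstanceError.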

	1.	Wave speed: A.speed = 1.5, B.speed = 3.2

	2.	Q factor: A.Q = 50, B.Q = 200

	3.	Grid kind: A.grid_kind = "regular", B.grid_kind = "streams"

(For streams, set CHANNELS_DIR in Cell 1, e.g. to the CHANNELSDIR path defined there; keep all other settings identical.)

	4.	Wave kind: A.wave_kind = "surface", B.wave_kind = "body"

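	5.	Station corrections switch: A.station_corr = False, B.station_corr = True

(The template always passes station_gains_df=None to asl_sausage, so station_corr currently changes only the run tag and output directory, not the computation. To make it effective, load a gains CSV and pass it where station_gains_df is set. For now, the goal is just to verify that toggling works end-to-end.)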

	6.	Distance mode: A.dist_mode = "2d", B.dist_mode = "3d"

	7.	Misfit engine: A.misfit_engine = "l2", B.misfit_engine = "huber"

Notes / tips:

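	•	These notebooks avoid shared caches so each run is clean and self-contained for debugging. Distance and amplitude-correction caches live in a per-configuration folder (OUTPUT_BASE/_nbcache/…/<tag>), and distances are always recomputed. The preview basemap still uses the shared GLOBAL_CACHE. Once you're happy, you can point these at your shared cache directories for speed.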

	•	If you want the map outputs to always land at the same filename for easy diffing, keep the same PEAKF, Q, etc., and your ASL plotting code will produce deterministically named files (e.g., map_Q50_F8.png).
	
	•	If you want to include the misfit heatmap in these quick tests, make sure your asl_sausage (or the called ASL plotting) writes it (e.g., misfit_heatmap.png). The summary table checks for its existence.

Possible follow-up: a minimal "compare images side by side" cell that displays map_Q*_F*.png from runs A and B inline.

In [None]:
from __future__ import annotations
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional, Dict, Any, Tuple

import os, sys, time, traceback
import numpy as np
import pandas as pd
from obspy import read, read_inventory
from glob import glob

# flovopy bits
from flovopy.core.mvo import dome_location
from flovopy.asl.asl_old import asl_sausage
from flovopy.asl.grid import make_grid, nodegrid_from_channel_csvs, Grid, NodeGrid
from flovopy.asl.distances import compute_or_load_distances, distances_signature
from flovopy.asl.ampcorr import AmpCorr, AmpCorrParams

# plotting helpers
from flovopy.asl.map import topo_map

homedir = os.path.expanduser('~')
projectdir = os.path.join(homedir, 'Dropbox', 'BRIEFCASE', 'SSADenver')
localprojectdir = os.path.join(homedir, 'work', 'PROJECTS', 'SSADenver_local')

# Source DEM; fixedDEM1999 below is a horizontally flipped copy made from it.
originalDEM = os.path.join(homedir, "Dropbox/PROFESSIONAL/DATA/WadgeDEMs/conversions/DEM_1999_WGS84_rotated.tif")
# channel_finder directory. Assign this to CHANNELS_DIR below to run grid_kind="streams".
CHANNELSDIR = os.path.join(projectdir, 'channel_finder')
fixedDEM1999 = os.path.join(CHANNELSDIR, '02_dem_flipped_horizontal.tif')  # made from tif above, after flipping
# Shared PyGMT DEM cache; run_single_event passes it as DEM_DIR for the grid-preview basemap.
GLOBAL_CACHE = os.path.join(localprojectdir, 'asl_global_cache')
invXML = os.path.join(projectdir, 'MV.xml')
MSEED_DIR = os.path.join(projectdir, 'ASL_inputs', 'biggest_pdc_events')
mseedfiles = sorted(glob(os.path.join(MSEED_DIR, '*.cleaned')))
if not mseedfiles:
    # Fail with an actionable message instead of a bare IndexError on mseedfiles[0].
    raise FileNotFoundError(f"No '*.cleaned' waveform files found in {MSEED_DIR}")

# ==== EDIT THESE ====
MSEED_FILE      = mseedfiles[0]          # one event (first file in sorted order)
INVENTORY_XML   = invXML                              # StationXML
OUTPUT_BASE     = os.path.join(localprojectdir, "asl_notebooks")                # where runs are written
CHANNELS_DIR    = None                                # required if grid_kind="streams" (e.g. CHANNELSDIR)
CHANNELS_STEP_M = 100.0
CHANNELS_DEM    = None                                # optional GeoTIFF for NodeGrid
REGULAR_GRID_DEM = "pygmt:01s"                        # or "geotiff:/abs/path.tif" or None
DEM_TIF_BMAP    = None                                # optional basemap GeoTIFF
NODE_SPACING_M  = 50
MIN_STATIONS    = 5                                   # minimum distinct Z-component stations
METRIC          = "VT"                                # or "mean" etc.
WINDOW_SECONDS  = 5
PEAKF           = 8.0
ISLAND_REGION = [-62.255, -62.135, 16.66, 16.84]      # [lon_min, lon_max, lat_min, lat_max] for maps


In [None]:
@dataclass(frozen=True)
class MiniConfig:
    """Immutable bundle of the ASL settings that differ between notebook variants.

    The class is frozen, so build variants with a new constructor call (or
    ``dataclasses.replace``) rather than by assigning attributes.
    """

    grid_kind: str           # 'streams' or 'regular'
    wave_kind: str           # 'surface' or 'body'
    station_corr: bool       # on/off switch only; no gains CSV is loaded here
    speed: float             # wave speed, km/s
    Q: int                   # quality factor Q
    dist_mode: str           # '2d' or '3d'
    misfit_engine: str       # 'l2' or 'huber'

    def tag(self) -> str:
        """Return a label encoding every field, e.g. ``G_regular__W_surface__...``.

        The label is built as ``<key>_<value>`` pieces joined by ``__``.
        """
        corr_state = "on" if self.station_corr else "off"
        labelled = (
            ("G", self.grid_kind),
            ("W", self.wave_kind),
            ("SC", corr_state),
            ("V", format(self.speed, "g")),
            ("Q", self.Q),
            ("D", self.dist_mode),
            ("M", self.misfit_engine),
        )
        return "__".join(f"{key}_{value}" for key, value in labelled)

In [None]:
def _regular_grid_dem_spec(
    regular_grid_dem: Optional[str],
) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Translate a DEM spec string into the ``dem`` tuple passed to ``make_grid``.

    Accepted forms are ``"pygmt:<resolution>"`` (an empty resolution means
    ``"01s"``) and ``"geotiff:<path>"``. ``None`` or ``""`` means "no DEM".

    Raises:
        ValueError: for any other non-empty string, rather than silently
            building the grid without a DEM.
    """
    if not regular_grid_dem:
        return None
    kind, sep, value = regular_grid_dem.partition(":")
    if sep and kind == "pygmt":
        res = value or "01s"
        return ("pygmt", {"resolution": res, "cache_dir": None, "tag": res})
    if sep and kind == "geotiff":
        return ("geotiff", {"path": value, "tag": Path(value).name})
    raise ValueError(
        f"Unrecognised regular_grid_dem {regular_grid_dem!r}; "
        "expected 'pygmt:<res>', 'geotiff:<path>' or None"
    )


def _build_grid(
    cfg: MiniConfig,
    *,
    node_spacing_m: int,
    channels_dir: Optional[str],
    channels_step_m: float,
    channels_dem_tif: Optional[str],
    regular_grid_dem: Optional[str],
    log,
):
    """Build the node grid: a channel NodeGrid for 'streams', otherwise a land-masked regular Grid."""
    if cfg.grid_kind == "streams":
        if not channels_dir:
            raise ValueError("grid_kind='streams' requires CHANNELS_DIR")
        return nodegrid_from_channel_csvs(
            channels_dir=channels_dir,
            step_m=channels_step_m,
            dem_tif=channels_dem_tif,
            approx_spacing_m=channels_step_m,
            max_points=None,
        )

    # Any grid_kind other than 'streams' falls through to the regular grid.
    gridobj = make_grid(
        center_lat=dome_location["lat"],
        center_lon=dome_location["lon"],
        node_spacing_m=node_spacing_m,
        grid_size_lat_m=10_000,
        grid_size_lon_m=14_000,
        dem=_regular_grid_dem_spec(regular_grid_dem),
    )
    # Mask out ocean nodes once. NOTE: the mask is not folded into the AmpCorr
    # cache key (mask_sig=None). That is fine while caches are per-tag; if you
    # share caches, include a mask signature (e.g. gridobj.mask_signature(),
    # if your Grid provides it).
    land = gridobj.apply_land_mask_from_dem(sea_level=1.0)
    log(f"Kept {land.sum()} of {land.size} nodes")
    return gridobj


def _save_grid_preview(gridobj, inv, outdir: Path, dem_tif_for_bmap: Optional[str], log) -> None:
    """Best-effort PNG of the grid nodes over the island basemap; failures are only logged."""
    try:
        grid_id = str(getattr(gridobj, "id", "grid"))[:8]
        preview_png = outdir / f"grid_preview_{grid_id}.png"
        topo_kw = {
            "inv": inv,
            "add_labels": False,
            "topo_color": True,          # colored topo (False for grayscale)
            "region": ISLAND_REGION,     # consistent extent across runs
            "DEM_DIR": GLOBAL_CACHE,     # shared PyGMT cache for basemap if no dem_tif
        }
        if dem_tif_for_bmap:
            topo_kw["dem_tif"] = dem_tif_for_bmap
        # Assumes the Grid/NodeGrid exposes .plot(show=..., topo_map_kwargs=..., outfile=...);
        # if yours differs, point this at a helper that calls topo_map() + fig.plot().
        gridobj.plot(show=False, topo_map_kwargs=topo_kw, outfile=str(preview_png))
        log(f"[GRID] Preview saved: {preview_png}")
    except Exception as e:
        # Deliberately best-effort: a failed preview must not abort the ASL run.
        log(f"[GRID:WARN] Could not plot grid preview: {e}")


def _summarize_outputs(tag: str, outdir: Path, event_dir: Path) -> Dict[str, Any]:
    """Locate the ASL products written into *event_dir* and report which exist."""
    def first_match(pattern: str) -> Optional[Path]:
        return next(event_dir.glob(pattern), None)

    src_csv = first_match("source_*.csv")
    qml = first_match("event_*.qml")
    map_png = first_match("map_Q*_F*.png")
    misfit_png = first_match("*misfit*heatmap*.png")
    return {
        "tag": tag,
        "outdir": str(outdir),
        "event_dir": str(event_dir),
        "source_csv": str(src_csv) if src_csv else None,
        "event_qml": str(qml) if qml else None,
        "map_png_exists": map_png is not None,
        "misfit_png_exists": misfit_png is not None,
    }


def run_single_event(
    cfg: MiniConfig,
    *,
    mseed_file: str,
    inventory_xml: str,
    output_base: str,
    node_spacing_m: int = 50,
    metric: str = "VT",
    window_seconds: int = 5,
    peakf: float = 8.0,
    channels_dir: Optional[str] = None,
    channels_step_m: float = 100.0,
    channels_dem_tif: Optional[str] = None,
    regular_grid_dem: Optional[str] = "pygmt:01s",
    dem_tif_for_bmap: Optional[str] = None,
    simple_basemap: bool = True,
    min_stations: Optional[int] = None,
) -> Dict[str, Any]:
    """Run ASL on one event under a single MiniConfig and summarise the outputs.

    Each configuration writes to ``<output_base>/<cfg.tag()>`` and uses its own
    caches under ``<output_base>/_nbcache/{distances,ampcorr}/<cfg.tag()>``.
    Distances are always recomputed (``force_recompute=True``).

    Args:
        cfg: Variant configuration; ``cfg.tag()`` names the output/cache dirs.
        mseed_file: Event waveform file; only Z-component traces are used.
        inventory_xml: StationXML path.
        output_base: Root directory for run outputs and caches.
        node_spacing_m: Regular-grid node spacing (unused for 'streams').
        metric: Passed to ASL as ``vsam_metric``.
        window_seconds: Passed to ASL as ``window_seconds``.
        peakf: Passed to ``AmpCorrParams`` as ``peakf``.
        channels_dir: Channel CSV directory; required when ``cfg.grid_kind == 'streams'``.
        channels_step_m: Step and approximate spacing for the channel NodeGrid.
        channels_dem_tif: Optional GeoTIFF for the channel NodeGrid.
        regular_grid_dem: ``"pygmt:<res>"``, ``"geotiff:<path>"`` or None.
        dem_tif_for_bmap: Optional GeoTIFF for the grid-preview basemap.
        simple_basemap: Currently unused; kept for interface compatibility.
        min_stations: Minimum number of distinct Z-component stations; defaults
            to the notebook-level ``MIN_STATIONS``.

    Returns:
        On success, a summary dict with ``tag``, ``outdir``, ``event_dir``,
        ``source_csv``, ``event_qml``, ``map_png_exists`` and ``misfit_png_exists``.
        On any exception, ``{"tag", "error", "outdir"}``. The traceback is
        printed, not raised, so one failing variant doesn't stop the notebook.

    Note:
        ``cfg.station_corr`` is not wired in yet: ``station_gains_df`` is always None.
    """
    t0 = time.time()
    tag = cfg.tag()
    n_min = MIN_STATIONS if min_stations is None else min_stations

    outdir = Path(output_base) / tag
    outdir.mkdir(parents=True, exist_ok=True)

    # --- tiny caches for this run (one folder per configuration tag) ---
    cache_root = Path(output_base) / "_nbcache"
    dist_cache_dir = cache_root / "distances" / tag
    ampcorr_cache = cache_root / "ampcorr" / tag
    for cache_dir in (dist_cache_dir, ampcorr_cache):
        cache_dir.mkdir(parents=True, exist_ok=True)

    def log(msg: str) -> None:
        ts = time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{ts}] [{tag}] {msg}")

    try:
        inv = read_inventory(inventory_xml)

        gridobj = _build_grid(
            cfg,
            node_spacing_m=node_spacing_m,
            channels_dir=channels_dir,
            channels_step_m=channels_step_m,
            channels_dem_tif=channels_dem_tif,
            regular_grid_dem=regular_grid_dem,
            log=log,
        )
        _save_grid_preview(gridobj, inv, outdir, dem_tif_for_bmap, log)

        # --- Distances (written to the per-tag cache, but always recomputed)
        use_3d = cfg.dist_mode.lower() == "3d"
        node_distances_km, station_coords, dist_meta = compute_or_load_distances(
            gridobj,
            inventory=inv,
            stream=None,
            cache_dir=str(dist_cache_dir),
            force_recompute=True,
            use_elevation=use_3d,
        )
        log(f"DIST meta: {dist_meta}")

        # --- AmpCorr (per-tag cache dir)
        params = AmpCorrParams(
            surface_waves=(cfg.wave_kind == "surface"),
            wave_speed_kms=cfg.speed,
            Q=cfg.Q,
            peakf=float(peakf),
            grid_sig=gridobj.signature(),
            inv_sig=tuple(sorted(node_distances_km.keys())),
            dist_sig=distances_signature(node_distances_km),
            mask_sig=None,
            code_version="v1",
        )
        ampcorr = AmpCorr(params, cache_dir=str(ampcorr_cache))
        ampcorr.compute_or_load(node_distances_km, inventory=inv)
        if hasattr(ampcorr, "validate_against_nodes"):
            ampcorr.validate_against_nodes(gridobj.gridlat.size)

        # --- ASL config
        asl_config: Dict[str, Any] = {
            "window_seconds": window_seconds,
            "min_stations": n_min,
            "Q": cfg.Q,
            "surfaceWaveSpeed_kms": cfg.speed,
            "vsam_metric": metric,
            "gridobj": gridobj,
            "node_distances_km": node_distances_km,
            "station_coords": station_coords,
            "ampcorr": ampcorr,
            "inventory": inv,
            "interactive": False,
            "numtrials": 200,
            "dist_meta": dist_meta,
            "misfit_engine": cfg.misfit_engine,
        }

        # --- Read stream and run ASL
        st = read(mseed_file).select(component="Z")
        # Count distinct stations, not traces: gappy data can yield several
        # traces per station and would otherwise inflate the count.
        n_stations = len({(tr.stats.network, tr.stats.station) for tr in st})
        if n_stations < n_min:
            raise RuntimeError(f"Not enough stations: {n_stations} < {n_min}")

        event_dir = outdir / Path(mseed_file).stem
        event_dir.mkdir(exist_ok=True)

        log(f"Running ASL on {mseed_file}")
        asl_sausage(
            stream=st,
            event_dir=str(event_dir),
            asl_config=asl_config,
            output_dir=str(outdir),
            dry_run=False,
            peakf_override=None,
            station_gains_df=None,
            allow_station_fallback=True,
        )

        summ = _summarize_outputs(tag, outdir, event_dir)
        log(f"Summary: {summ}")
        log(f"Elapsed: {time.time() - t0:.1f}s")
        return summ

    except Exception as e:
        # Notebook boundary: report and return an error record instead of raising.
        log(f"ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
        return {
            "tag": tag,
            "error": f"{type(e).__name__}: {e}",
            "outdir": str(outdir),
        }

In [None]:
from dataclasses import fields, replace

# Baseline configuration shared by every variant notebook.
A = MiniConfig(grid_kind="regular", wave_kind="surface", station_corr=False,
               speed=1.5, Q=50, dist_mode="2d", misfit_engine="l2")

# Derive B from A so only the field under test can differ. MiniConfig is frozen,
# so use replace() rather than assigning attributes (A.speed = ... would raise).
B = replace(A, speed=3.2)  # <-- only one field changes per notebook

# Guard against accidental drift: report which field differs, warn otherwise.
_changed_fields = [f.name for f in fields(A) if getattr(A, f.name) != getattr(B, f.name)]
if len(_changed_fields) == 1:
    print(f"Variant field: {_changed_fields[0]}")
else:
    print(f"[WARN] A and B differ in {len(_changed_fields)} fields: {_changed_fields}")

In [None]:
# Shared output root; run_single_event creates one subdirectory per config tag inside it.
_output_root = Path(OUTPUT_BASE)
_output_root.mkdir(parents=True, exist_ok=True)

# Notebook-wide run settings (from the config cell) for variant A.
_run_kwargs_A = dict(
    mseed_file=MSEED_FILE,
    inventory_xml=INVENTORY_XML,
    output_base=OUTPUT_BASE,
    node_spacing_m=NODE_SPACING_M,
    metric=METRIC,
    window_seconds=WINDOW_SECONDS,
    peakf=PEAKF,
    channels_dir=CHANNELS_DIR,
    channels_step_m=CHANNELS_STEP_M,
    channels_dem_tif=CHANNELS_DEM,
    regular_grid_dem=REGULAR_GRID_DEM,
    dem_tif_for_bmap=DEM_TIF_BMAP,
)
resA = run_single_event(A, **_run_kwargs_A)



In [None]:
# Variant B: same inputs as run A; only the MiniConfig differs.
_run_kwargs_B = {
    "mseed_file": MSEED_FILE,
    "inventory_xml": INVENTORY_XML,
    "output_base": OUTPUT_BASE,
    "node_spacing_m": NODE_SPACING_M,
    "metric": METRIC,
    "window_seconds": WINDOW_SECONDS,
    "peakf": PEAKF,
    "channels_dir": CHANNELS_DIR,
    "channels_step_m": CHANNELS_STEP_M,
    "channels_dem_tif": CHANNELS_DEM,
    "regular_grid_dem": REGULAR_GRID_DEM,
    "dem_tif_for_bmap": DEM_TIF_BMAP,
}
resB = run_single_event(B, **_run_kwargs_B)

In [None]:
_SUMMARY_COLUMNS = ("tag", "source_csv", "event_qml", "map_png_exists", "misfit_png_exists", "error")


def flatten(d: Dict[str, Any]) -> Dict[str, Any]:
    """Project a run result onto the summary columns; absent keys become None."""
    return dict(zip(_SUMMARY_COLUMNS, map(d.get, _SUMMARY_COLUMNS)))


# One row per variant; the bare `df` on the last line renders the table in Jupyter.
df = pd.DataFrame(list(map(flatten, (resA, resB))))
df