# Gravity inversion: trade-off parameter $\beta$ (Tikhonov) — interactive demo (with optional regional trend)

This notebook is a **teaching demo** inspired by the GIFTools Cookbook “trade-off parameter” applet, implemented with **SimPEG + ipywidgets**.

We simulate a simple 2D-looking gravity problem (a horizontal cylinder, circular in the x–z section, in a tensor mesh with `ny=1`), then solve a sequence of Tikhonov inversions along a **$\beta$ cooling schedule** to visualize the trade-off between:
- data misfit $\phi_d$ (fit to the observed gravity profile), and  
- model objective $\phi_m$ (smallness + smoothness + depth weighting).

## What’s new in this version: regional (trend) option
In practice, gravity profiles often contain a **regional trend** (e.g., a polynomial background) in addition to the local anomaly of interest.  
Here you can:
- **add a synthetic polynomial trend** of chosen order (0–5) to the data, and
- choose whether to **remove** a fitted polynomial trend **before** inversion.

Crucially, the **degree used to remove the trend can be different** from the degree of the true trend. This lets you demonstrate under/over-fitting of the regional.

> Tip for class:  
> Try adding a 3rd–5th order trend, then remove it using order 0–1, and watch how the inversion “explains” the leftover trend as spurious subsurface density.

---


In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
# Matplotlib mathtext: render Greek letters and subscripts with LaTeX-like syntax
import matplotlib as mpl
mpl.rcParams["mathtext.default"] = "it"

import scipy.sparse as sp
import scipy.sparse.linalg as spla
from scipy.optimize import lsq_linear

import ipywidgets as widgets
from IPython.display import display, clear_output

from discretize import TensorMesh
from simpeg import maps
from simpeg.potential_fields import gravity

# Ensure roma colormap is available
try:
    from cmcrameri import cm as cmc
except Exception:
    !{sys.executable} -m pip install -q cmcrameri
    from cmcrameri import cm as cmc

ROMA = cmc.roma_r

In [2]:
def _try_engine():
    """Choose the engine name based on whether ``choclo`` can be imported.

    Returns
    -------
    str
        ``"choclo"`` when ``import choclo`` succeeds, otherwise ``"geoana"``.
    """
    try:
        import choclo  # noqa: F401
    except Exception:
        # Any failure while importing choclo selects the fallback name.
        engine_name = "geoana"
    else:
        engine_name = "choclo"
    return engine_name


def power_iteration_maxeig(A, n_iter=35, seed=0):
    """Estimate the largest eigenvalue of a symmetric positive (semi)definite matrix.

    Runs a plain power iteration using only matrix-vector products, then
    evaluates the Rayleigh quotient ``x^T A x`` once on the final unit vector.

    Parameters
    ----------
    A : scipy.sparse.spmatrix
        Square matrix (assumed symmetric PSD in this demo). Only ``A.shape``
        and ``A @ x`` are used.
    n_iter : int, default: 35
        Number of power iterations. If ``n_iter <= 0`` no iteration is run and
        ``0.0`` is returned.
    seed : int, default: 0
        RNG seed for the initial vector.

    Returns
    -------
    lam_max : float
        Approximation of the largest eigenvalue of `A`.
    """
    if n_iter <= 0:
        return 0.0

    rng = np.random.default_rng(seed)
    x = rng.standard_normal(A.shape[0])
    x /= np.linalg.norm(x) + 1e-30  # tiny offset guards against division by zero

    for _ in range(n_iter):
        y = A @ x
        x = y / (np.linalg.norm(y) + 1e-30)

    # Only the final Rayleigh quotient is returned, so evaluate it once here
    # instead of paying for an extra mat-vec in every iteration.
    return float(x @ (A @ x))


def depth_weight_cells(mesh, z0=None, p=1.5, normalize=True):
    """Compute a depth-weighting vector on the mesh cells.

    The weight of each cell is

        w = (depth + z0) ** (-p),    depth = max(0, -z_center),

    where ``z_center`` is the third coordinate of the cell center. The result
    is optionally divided by its maximum so that ``max(w) = 1``.

    Parameters
    ----------
    mesh : discretize.TensorMesh
        Mesh providing ``cell_centers`` and the cell widths ``h``.
    z0 : float or None, default: None
        Shift added to the depth before taking the power. When None, the
        smallest cell width along the third axis (``min(mesh.h[2])``) is used.
    p : float, default: 1.5
        Depth-weighting exponent.
    normalize : bool, default: True
        If True, scale the weights so the maximum is 1.

    Returns
    -------
    w : (mesh.nC,) numpy.ndarray
        Depth-weight vector.
    """
    z_centers = mesh.cell_centers[:, 2]
    depth_below_surface = np.maximum(0.0, -z_centers)

    shift = float(np.min(mesh.h[2])) if z0 is None else z0
    weights = 1.0 / np.power(depth_below_surface + shift, p)

    if not normalize:
        return weights
    return weights / (np.max(weights) + 1e-30)


def face_weights_from_cell_weights(mesh, w_cell, orientation):
    """Map cell weights onto the x- or z-faces of a 3D tensor mesh.

    Interior faces take the mean of the two neighbouring cell weights;
    the first and last face along the chosen axis copy the adjacent cell.

    Parameters
    ----------
    mesh : discretize.TensorMesh
        Mesh whose ``vnC`` gives the three cell counts ``(nx, ny, nz)``.
    w_cell : numpy.ndarray
        Cell weights, reshaped to ``mesh.vnC`` in Fortran order.
    orientation : {"x", "z"}
        Axis normal to the faces.

    Returns
    -------
    w_face : numpy.ndarray
        Flattened (Fortran-order) face weights, with one more entry than
        cells along the chosen axis.

    Raises
    ------
    ValueError
        If `orientation` is neither ``"x"`` nor ``"z"``.
    """
    cell_grid = w_cell.reshape(mesh.vnC, order="F")
    nx, ny, nz = mesh.vnC

    axis_by_name = {"x": 0, "z": 2}
    if orientation not in axis_by_name:
        raise ValueError("orientation must be 'x' or 'z'")
    axis = axis_by_name[orientation]

    face_shape = [nx, ny, nz]
    face_shape[axis] += 1
    face_grid = np.zeros(tuple(face_shape))

    # Views with the working axis first; writes go through to face_grid.
    cells_first = np.moveaxis(cell_grid, axis, 0)
    faces_first = np.moveaxis(face_grid, axis, 0)
    faces_first[0] = cells_first[0]
    faces_first[1:-1] = 0.5 * (cells_first[1:] + cells_first[:-1])
    faces_first[-1] = cells_first[-1]

    return face_grid.reshape(-1, order="F")


def build_regularization_operators(mesh, active, w_cell, alpha_s=1.0, alpha_x=1.0, alpha_z=1.0):
    """
    Build a depth-weighted Tikhonov regularization operator L and quadratic form B = L^T L.

    Parameters
    ----------
    mesh : discretize.TensorMesh
        Tensor mesh (provides ``nC``, ``cell_gradient_x`` and ``cell_gradient_z``).
    active : (nC,) bool numpy.ndarray
        Active cell mask; inactive cells are zeroed before every operator.
    w_cell : (nC,) float
        Cell-based depth-weighting (normalized to max=1 typically).
    alpha_s, alpha_x, alpha_z : float
        Weights for smallness, x-smoothness and z-smoothness terms. A term is
        included only when its weight is > 0. Setting alpha_x=alpha_z=0 gives
        smallness-only regularization.

    Returns
    -------
    L : scipy.sparse.csr_matrix
        Stacked regularization operator with rows scaled by sqrt(alpha).
        If all three weights are <= 0, the unscaled smallness operator is
        returned so that L is never empty.
    B : scipy.sparse.csr_matrix
        Quadratic form B = L^T L.
    """
    n_cells = mesh.nC
    P = sp.diags(active.astype(float), 0, shape=(n_cells, n_cells), format="csr")

    # Smallness on (active) cells
    Ws = sp.diags(w_cell, 0, format="csr") @ P

    blocks = []
    if alpha_s > 0.0:
        blocks.append(np.sqrt(alpha_s) * Ws)

    # Smoothness terms: the gradient operator is only fetched from the mesh
    # when the corresponding term is actually requested.
    smoothness_terms = (
        (alpha_x, "x", "cell_gradient_x"),
        (alpha_z, "z", "cell_gradient_z"),
    )
    for alpha, orientation, gradient_attr in smoothness_terms:
        if alpha > 0.0:
            w_face = face_weights_from_cell_weights(mesh, w_cell, orientation)
            W_face = sp.diags(w_face, 0, format="csr")
            gradient = getattr(mesh, gradient_attr)
            blocks.append(np.sqrt(alpha) * (W_face @ (gradient @ P)))

    if not blocks:
        # degenerate (all alphas <= 0), but keep a usable operator
        blocks = [Ws]

    L = sp.vstack(blocks).tocsr()
    B = (L.T @ L).tocsr()
    return L, B


def beta_schedule(beta0, cooling=2.0, n_steps=8):
    """Generate a geometric cooling schedule for the trade-off parameter `beta`.

    The k-th entry is ``beta0 / cooling**k`` for ``k = 0, ..., n_steps - 1``.

    Parameters
    ----------
    beta0 : float
        Initial beta value (often called `beta_0`).
    cooling : float, default: 2.0
        Cooling factor. Each step divides beta by `cooling`.
    n_steps : int, default: 8
        Number of beta values to return.

    Returns
    -------
    betas : (n_steps,) numpy.ndarray
        Sequence of beta values starting at `beta0`.
    """
    step_index = np.arange(n_steps)
    divisors = np.power(cooling, step_index)
    return beta0 / divisors


def solve_tikhonov(G_w, d_w, L, beta, solver="direct", bounds=None):
    """Solve a Tikhonov-regularized least-squares problem for a fixed `beta`.

    The problem (in weighted form) is

        min_m |G_w m - d_w|_2^2 + beta |L m|_2^2,

    equivalently the stacked least-squares system ``[G_w; sqrt(beta) L] m = [d_w; 0]``
    or the normal equations ``(G_w^T G_w + beta L^T L) m = G_w^T d_w``.

    Parameters
    ----------
    G_w : scipy.sparse.csr_matrix
        Weighted forward operator.
    d_w : (nD,) numpy.ndarray
        Weighted data vector.
    L : scipy.sparse.csr_matrix
        Regularization operator.
    beta : float
        Trade-off parameter.
    solver : {"direct", "cg"}, default: "direct"
        ``"cg"`` runs conjugate gradients on the normal equations (with a
        tiny diagonal shift). Any other value solves the stacked system with
        LSQR. Ignored when `bounds` is given.
    bounds : tuple or None, optional
        Bounds `(lower, upper)` for `m`, passed to `scipy.optimize.lsq_linear`.
        If None, solves unconstrained.

    Returns
    -------
    m : (nP,) numpy.ndarray
        Estimated model.
    info : int
        ``lsq_linear`` status when `bounds` is given, the ``cg`` convergence
        flag for ``solver="cg"``, and always 0 for the LSQR path.
    """
    # Stacked (augmented) least-squares system shared by the LSQR and bounded paths.
    stacked_op = sp.vstack([G_w, np.sqrt(beta) * L]).tocsr()
    n_reg_rows = stacked_op.shape[0] - d_w.size
    stacked_rhs = np.r_[d_w, np.zeros(n_reg_rows)]

    if bounds is not None:
        result = lsq_linear(stacked_op, stacked_rhs, bounds=bounds, lsmr_tol="auto", verbose=0)
        return result.x, int(result.status)

    if solver != "cg":
        lsqr_out = spla.lsqr(stacked_op, stacked_rhs, atol=1e-10, btol=1e-10, iter_lim=2000)
        return lsqr_out[0], 0

    # Conjugate gradients on the normal equations.
    gram = (G_w.T @ G_w).tocsr()
    reg_form = (L.T @ L).tocsr()
    normal_rhs = (G_w.T @ d_w).astype(float)
    system = (gram + beta * reg_form).tocsr()
    n_unknowns = system.shape[0]
    mean_diag = system.diagonal().mean() if n_unknowns > 0 else 1.0
    jitter = 1e-12 * mean_diag
    system = system + jitter * sp.eye(n_unknowns, format="csr")
    m, info = spla.cg(system, normal_rhs, x0=np.zeros(n_unknowns), maxiter=400, rtol=1e-6)
    return m, info

In [3]:
# --- Mesh (2D-looking: ny=1) ---
# Uniform cells of width dx in x and z; a single cell of width Ly in y.
dx = 25.0
nx, ny, nz = 48, 1, 24
Lx = nx * dx  # total mesh extent in x
Ly = 1000.0   # width of the single y cell
Lz = nz * dx  # total mesh extent in z

# Origin chosen so the mesh is centered on x = 0 and y = 0 and spans z from -Lz to 0.
mesh = TensorMesh(
    [dx * np.ones(nx), np.array([Ly]), dx * np.ones(nz)],
    x0=[-Lx/2, -Ly/2, -Lz]
)
# Every cell is treated as active.
active = np.ones(mesh.nC, dtype=bool)

# "choclo" if that package is importable, otherwise "geoana".
engine = _try_engine()

def build_survey_simulation(n_data=61, x_min=-450.0, x_max=450.0, z_rx=5.0):
    """Build the gravity survey and integral simulation for a receiver profile.

    Uses the module-level ``mesh``, ``active`` and ``engine``.

    Parameters
    ----------
    n_data : int, default: 61
        Number of receivers, evenly spaced between `x_min` and `x_max`.
    x_min, x_max : float
        End points of the receiver line along x.
    z_rx : float, default: 5.0
        z-coordinate assigned to every receiver.

    Returns
    -------
    x_rx : (n_data,) numpy.ndarray
        Receiver x-coordinates.
    receiver_locations : (n_data, 3) numpy.ndarray
        Receiver positions ``(x, 0, z_rx)`` used by the simulation.
    receiver_locations_plot : (n_data, 3) numpy.ndarray
        Same stations placed at ``z = 0`` (for plotting).
    sim : gravity.simulation.Simulation3DIntegral
        Simulation with sensitivities stored in RAM.
    G : object
        The simulation's ``G`` attribute.
    """
    n_stations = int(n_data)
    x_rx = np.linspace(x_min, x_max, n_stations)
    zero_column = np.zeros_like(x_rx)
    elevation_column = z_rx * np.ones_like(x_rx)

    receiver_locations = np.column_stack([x_rx, zero_column, elevation_column])
    # For plotting, pretend stations are on the surface (z=0)
    receiver_locations_plot = np.column_stack([x_rx, zero_column, zero_column])

    gz_receivers = gravity.receivers.Point(receiver_locations, components="gz")
    survey = gravity.survey.Survey(
        gravity.sources.SourceField(receiver_list=[gz_receivers])
    )

    sim = gravity.simulation.Simulation3DIntegral(
        survey=survey,
        mesh=mesh,
        rhoMap=maps.IdentityMap(nP=mesh.nC),
        active_cells=active,
        store_sensitivities="ram",
        engine=engine,
    )

    return x_rx, receiver_locations, receiver_locations_plot, sim, sim.G

# Default build: 61 stations with the default line extent and receiver height.
# These module-level names are rebuilt locally inside the interactive _render().
x_rx, receiver_locations, receiver_locations_plot, sim, G = build_survey_simulation(n_data=61)
# Report the engine in use and the problem dimensions.
print(f"engine={engine}, mesh.nC={mesh.nC}, ndata={x_rx.size}, G shape={G.shape}")

engine=choclo, mesh.nC=1152, ndata=61, G shape=(61, 1152)


In [4]:
def run_beta_schedule_contrast(
    d_obs, Wd, G_use,
    depth_p=1.5,
    alpha_s=1.0, alpha_x=1.0, alpha_z=1.0,
    n_steps=8, cooling=2.0,
    beta_scale=1.0,
    min_beta_factor=1e-3,
    solver="direct",
    positivity=False,
    auto_beta0=True,
    beta0_frozen=None,
):
    """
    Run a simple beta cooling schedule and return models and trade-off metrics.

    Uses the module-level ``mesh``, ``active`` and ``dx``.

    Key teaching feature
    --------------------
    - If auto_beta0=True (default): beta_0 is recomputed each time from A/B, which tends to
      *normalize out* absolute changes in (alpha_s, alpha_x, alpha_z).
    - If auto_beta0=False: beta_0 is taken from beta0_frozen (computed once) so changing alphas
      has a strong visible impact.

    Parameters
    ----------
    d_obs : (nD,) array_like
        Observed data.
    Wd : sparse matrix/array or (nD,) array_like
        Data weighting. A sparse input is applied with ``@``; anything else is
        treated as the vector of diagonal weights.
    G_use : sparse or dense (nD, nC) matrix
        Forward operator (converted to CSR if dense).
    depth_p : float
        Depth-weighting exponent passed to `depth_weight_cells` (with ``z0=dx``).
    alpha_s, alpha_x, alpha_z : float
        Smallness / x-smoothness / z-smoothness weights.
    n_steps : int
        Number of beta values in the schedule.
    cooling : float
        Geometric cooling factor of the schedule.
    beta_scale : float
        Multiplier applied to beta_0 before building the schedule.
    min_beta_factor : float
        Schedule values are clipped from below at
        ``beta_0 * beta_scale * min_beta_factor``.
    solver : {"direct", "cg"}
        Forwarded to `solve_tikhonov`.
    positivity : bool
        If True, solve with bounds ``0 <= m < inf``.
    auto_beta0 : bool
        Whether to recompute beta_0 from lambda_max(A)/lambda_max(B) each call.
    beta0_frozen : float or None
        Frozen beta_0 value to use when auto_beta0=False. If None, will fall back to auto beta_0.

    Returns
    -------
    betas : (n_steps,) numpy.ndarray
        Beta values actually used (after clipping).
    models : (n_steps, mesh.nC) numpy.ndarray
        Recovered model for each beta.
    phi_d : (n_steps,) numpy.ndarray
        Weighted data misfit ``|Wd (G m - d_obs)|^2`` for each beta.
    phi_m : (n_steps,) numpy.ndarray
        Model objective ``m^T B m`` for each beta.
    beta0_used : float
        The (unscaled) beta_0 the schedule was built from.
    """
    n_steps = int(n_steps)
    d_obs = np.asarray(d_obs, dtype=float).ravel()

    # Normalize the data weighting to a sparse operator so it is always applied
    # with "@". ("*" is a matrix product only for legacy scipy spmatrix; it is
    # elementwise for sparse arrays.)
    if sp.issparse(Wd):
        Wd_op = Wd
    else:
        Wd_op = sp.diags(np.asarray(Wd, dtype=float).ravel(), 0, format="csr")

    if not sp.issparse(G_use):
        G_use = sp.csr_matrix(G_use)

    d_w = np.asarray(Wd_op @ d_obs, dtype=float).ravel()
    G_w = (Wd_op @ G_use).tocsr()

    w_cell = depth_weight_cells(mesh, z0=dx, p=depth_p, normalize=True)
    L, B = build_regularization_operators(
        mesh, active, w_cell, alpha_s=alpha_s, alpha_x=alpha_x, alpha_z=alpha_z
    )

    # Compute beta0 (or use frozen)
    if (not auto_beta0) and (beta0_frozen is not None):
        beta0_used = float(beta0_frozen)
    else:
        # G_w^T G_w is only needed for the eigenvalue estimate, so it is not
        # formed when a frozen beta_0 is supplied.
        A = (G_w.T @ G_w).tocsr()
        lamA = power_iteration_maxeig(A, n_iter=35, seed=0)
        lamB = power_iteration_maxeig(B, n_iter=35, seed=1)
        beta0_used = float(lamA / (lamB + 1e-30))

    beta0_eff = float(beta0_used * beta_scale)
    betas = beta_schedule(beta0_eff, cooling=cooling, n_steps=n_steps)

    # Floor the schedule so beta never drops below the requested fraction.
    beta_min = beta0_eff * float(min_beta_factor)
    betas = np.maximum(betas, beta_min)

    bounds = None
    if positivity:
        bounds = (np.zeros(mesh.nC), np.full(mesh.nC, np.inf))

    models = np.zeros((n_steps, mesh.nC))
    phi_d = np.zeros(n_steps)
    phi_m = np.zeros(n_steps)

    for k, beta_k in enumerate(betas):
        m, _info = solve_tikhonov(G_w, d_w, L, beta_k, solver=solver, bounds=bounds)
        models[k, :] = m
        # Weighted residual Wd (G m - d_obs), written with the pre-weighted pieces.
        r_w = np.asarray(G_w @ m).ravel() - d_w
        phi_d[k] = float(r_w @ r_w)
        phi_m[k] = float(m @ (B @ m))

    return betas, models, phi_d, phi_m, beta0_used

## Interactive controls: what each parameter means

This section documents the sliders/checkboxes used in the widget-based demo.

### Data acquisition

**`N data`**  
Number of gravity stations along the profile (uniformly spaced between **x min** and **x max**).  
- Increasing **N** adds constraints; if uncertainties are consistent, the expected target misfit is typically $\phi_d^* \approx N$.

**`x min`, `x max` (m)**  
Horizontal extent of the receiver line (meters).  
- Larger span improves coverage; smaller span localizes sensitivity.

**`z_rx` (m)**  
Receiver elevation used in the forward simulation.  
- Higher receivers are farther from the source → lower amplitude and broader anomalies.

### True model geometry (synthetic “cylinder”)

**`$\Delta\rho$ true` (g/cc)**  
Density **contrast** inside the cylinder relative to background. Outside the cylinder, $\Delta\rho=0$.  
- Larger $\Delta\rho$ → larger anomaly.

**`x0` (m)**  
Horizontal center of the cylinder.

**`z0` (m)**  
Vertical center of the cylinder. In this notebook, **negative values are depth**.

**`R` (m)**  
Cylinder radius. Larger $R$ increases the anomalous mass (and anomaly amplitude).

### Noise + target misfit

**`noise frac`**  
Noise level used to define the data uncertainty: the standard deviation is this fraction of the maximum absolute value of the noise-free data (anomaly plus any regional trend).  
- Larger noise → larger uncertainty → smaller weighted misfit for the same residual.

**`$\phi_d^*$`**  
Target (“expected”) data misfit. With correctly estimated Gaussian uncertainties, a common choice is  
$\phi_d^* \approx N$,
but you can change it to demonstrate over/under-estimated errors or alternate stopping targets.

### $\beta$ schedule (trade-off parameter)

**`$\beta$ steps`**  
Number of values in the $\beta$ cooling schedule.

**`$\beta$ idx`**  
Index of the selected $\beta$ within the schedule (the one highlighted in plots and used for the displayed recovered model).

**`cooling`**  
Geometric cooling factor used to decrease $\beta$:
$$
\beta_k = \beta_0\,\beta_{\text{scale}}\,(\text{cooling})^{-k}.
$$
- Larger cooling → faster decrease in $\beta$.

**`beta scale`**  
Scalar multiplier applied to $\beta_0$.  
- Larger values shift the full schedule to stronger regularization.

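**`min $\beta/\beta_0$`**  
Lower bound on $\beta$ relative to the scaled starting value $\beta_0\,\beta_{\text{scale}}$; schedule entries that fall below this floor are clipped to it.  
- Prevents extremely small $\beta$ that can yield unstable/noisy models.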

### Regularization / model objective

**`depth_p`**  
Exponent $p$ controlling depth weighting used in regularization.  
- Depth weighting counteracts the tendency of gravity inversions to recover overly shallow structure.

**`alpha_s`**  
Smallness weight (pulls the model toward a reference model, typically 0):
$$
\alpha_s \,\|W_s(m-m_{\text{ref}})\|_2^2.
$$

**`alpha_x`**  
Smoothness weight in the **x** direction:
$$
\alpha_x \,\|W_x\,\partial_x m\|_2^2.
$$

**`alpha_z`**  
Smoothness weight in the **z** direction:
$$
\alpha_z \,\|W_z\,\partial_z m\|_2^2.
$$

### Constraints + solver

**`Freeze $\beta_0$`**  
If **OFF**, $\beta_0$ is auto-normalized each time (often using a scale estimate like $\lambda_{\max}(A)/\lambda_{\max}(B)$).  
If **ON**, $\beta_0$ is held fixed — useful to make the effect of changing $\alpha$ values more visible for teaching.

**`$\Delta\rho \ge 0$`**  
Enforces non-negative density contrast (positivity constraint).

**`solver`**  
Choice of linear solver used in each Tikhonov solve (e.g., direct vs iterative). The impact is mostly computational (speed/robustness), not conceptual.

### Regional trend

**`Add regional trend`**  
Adds a synthetic polynomial “regional” component to the observed data:
$$
d_{\text{obs}}(x)=d_{\text{anomaly}}(x)+d_{\text{trend}}(x)+\epsilon.
$$

**`True trend order` (0–5)**  
Polynomial degree of the synthetic trend that is added to the data (0 = constant offset, 1 = linear, …, 5 = quintic).

**`Trend amp (×|max d|)`**  
Scales the trend amplitude relative to the anomaly scale.

**`Remove trend before inversion`**  
If ON (and **Add regional trend** is enabled), fits a polynomial of the chosen **Removal order** to the observed data and subtracts it before inversion (for plotting, the predicted anomaly is added back to the fitted trend).


### Regional trend controls (polynomial)

This demo can add a **synthetic regional trend** to the gravity profile and (optionally) remove a fitted trend before inversion.

- **Add regional trend**: if enabled, the *true* trend is generated as a polynomial of chosen **True trend order** (0–5) on a normalized x-axis.  
  This is the “regional/background” component that gets added to the *local anomaly*.

- **True trend order**: polynomial degree used to generate the *true* trend (0 = constant offset, 1 = linear, …, 5 = 5th order).

- **Trend amp (×|max d|)**: sets the *maximum absolute amplitude* of the true trend as a fraction of the anomaly amplitude (based on `max(|d_true|)`).

- **Remove trend before inversion**: if enabled, we fit a polynomial to the **observed** data and subtract it before inversion.

- **Removal trend order**: polynomial degree used for the **fitted** trend (this can differ from the true trend order).  
  *This is the key teaching control:* if removal order is too low, the residual trend contaminates the inversion; if too high, you risk removing anomaly content.

**Terminology in the plots**
- **True trend**: the synthetic polynomial we *added* to the data (known in the simulation).
- **Fitted trend**: the polynomial estimated from the noisy observed data (what you would do in real processing).
- **Observed (raw)**: anomaly + true trend + noise.
- **Observed (trend removed)**: observed(raw) − fitted trend (this is what the inversion sees when removal is ON).



In [5]:
# =============================================================================
# Interactive demo (widgets): gravity inversion + β-tradeoff + optional regional trend
# =============================================================================
#
# NOTE:
# - This cell is intentionally self-contained (robust to re-runs).
# - It expects that the notebook has already defined:
#     mesh, ROMA, build_survey_simulation, run_beta_schedule_contrast
#   and that SimPEG/discretize imports used by those functions are available.
#
# If you re-run this cell multiple times, it will close old widgets and reuse a
# single Output container (prevents duplicated UI and callback recursion).

import numpy as np
import matplotlib.pyplot as plt
import scipy.sparse as sp
from IPython.display import display, clear_output
import ipywidgets as widgets


# -----------------------------------------------------------------------------
# Trend utilities
# -----------------------------------------------------------------------------
def fit_poly_trend(x, d, degree, w=None):
    """Least-squares fit of a polynomial trend to a 1D profile.

    Parameters
    ----------
    x : (n,) array_like
        Profile coordinates.
    d : (n,) array_like
        Data values to fit.
    degree : int
        Polynomial degree (0 = constant, 1 = linear, ...). Values outside
        0..5 are clipped into that range.
    w : (n,) array_like or None, optional
        Optional weights forwarded to `np.polyfit` (typically `1/sigma`).

    Returns
    -------
    coeffs : (degree+1,) numpy.ndarray
        Polynomial coefficients (highest power first), compatible with `np.polyval`.
    """
    coords = np.asarray(x).ravel()
    values = np.asarray(d).ravel()
    clipped_degree = int(np.clip(degree, 0, 5))
    # np.polyfit treats w=None as "unweighted", so a single call covers both cases.
    weights = None if w is None else np.asarray(w).ravel()
    return np.polyfit(coords, values, clipped_degree, w=weights)


def eval_poly_trend(x, coeffs):
    """Evaluate a polynomial trend at the given coordinates.

    Parameters
    ----------
    x : (n,) array_like
        Coordinates at which to evaluate.
    coeffs : array_like
        Polynomial coefficients (highest power first), as used by `np.polyval`.

    Returns
    -------
    trend : (n,) numpy.ndarray
        Trend values.
    """
    coefficients = np.asarray(coeffs)
    coordinates = np.asarray(x)
    return np.polyval(coefficients, coordinates)


def make_synthetic_trend(x, degree, amp, seed=2026):
    """Create a reproducible synthetic polynomial trend for teaching.

    Random coefficients are drawn from a generator seeded with
    ``seed + degree``, the polynomial is evaluated on x rescaled to
    roughly [-1, 1], and the result is scaled so that ``max(|trend|) = amp``.

    Parameters
    ----------
    x : (n,) array_like
        Coordinates along profile.
    degree : int
        Polynomial degree; clipped to 0..5.
    amp : float
        Desired maximum absolute amplitude of the trend.
    seed : int, default: 2026
        Base seed for reproducibility.

    Returns
    -------
    trend : (n,) numpy.ndarray
        Trend values (after amplitude scaling).
    coeffs : (degree+1,) numpy.ndarray
        The drawn polynomial coefficients (highest power first, on normalized
        x). These are the values *before* the amplitude scaling is applied.
    """
    coords = np.asarray(x).ravel()
    deg = int(np.clip(degree, 0, 5))

    # Rescale the coordinates to about [-1, 1]; the small offset avoids a
    # zero divisor when all coordinates coincide.
    lo, hi = coords.min(), coords.max()
    center = 0.5 * (lo + hi)
    half_span = 0.5 * (hi - lo) + 1e-12
    xn = (coords - center) / half_span

    n_coeffs = deg + 1
    rng = np.random.default_rng(int(seed) + deg)
    coeffs = rng.normal(size=n_coeffs) / n_coeffs

    unscaled = np.polyval(coeffs, xn)
    peak = np.max(np.abs(unscaled)) + 1e-30
    scale = float(amp) / peak
    return unscaled * scale, coeffs


def choose_sign_factor(d):
    """Return the sign (+1.0 or -1.0) of the dominant excursion in `d`.

    The result is +1.0 when the maximum is at least as large as the absolute
    value of the minimum (ties and empty input included), otherwise -1.0.
    Used to keep plotted data/anomaly positive by default (teaching convenience).
    """
    values = np.asarray(d).ravel()
    if values.size == 0:
        return 1.0
    positive_dominates = np.max(values) >= abs(np.min(values))
    if positive_dominates:
        return 1.0
    return -1.0


# -----------------------------------------------------------------------------
# Fallback helper (ensures this cell is self-contained)
# -----------------------------------------------------------------------------
if "make_contrast_model" not in globals():
    def make_contrast_model(mesh, drho_cyl=0.30, x0=0.0, z0=-200.0, R=120.0):
        """Build a density-contrast model with a circular anomaly in the x-z plane.

        Cells whose centers lie within distance `R` of ``(x0, z0)`` (measured
        in x and z only) get the value `drho_cyl`; all others are zero.

        Parameters
        ----------
        mesh : discretize.TensorMesh
            Mesh providing ``cell_centers`` and ``nC``.
        drho_cyl : float
            Density contrast inside the cylinder (g/cc).
        x0, z0 : float
            Cylinder center (m). Note z is typically negative downward.
        R : float
            Cylinder radius (m).

        Returns
        -------
        m : (mesh.nC,) numpy.ndarray
            Density-contrast model on cell centers.
        """
        centers = mesh.cell_centers
        offset_x = centers[:, 0] - float(x0)
        offset_z = centers[:, 2] - float(z0)
        radial_distance = np.sqrt(offset_x ** 2 + offset_z ** 2)
        inside = radial_distance <= float(R)

        model = np.zeros(mesh.nC, dtype=float)
        model[inside] = float(drho_cyl)
        return model


# -----------------------------------------------------------------------------
# Single-instance widget / output management (prevents recursion / duplicates)
# -----------------------------------------------------------------------------
# Decide which Output container to render into and close widgets left over
# from a previous run of this cell.
if "_GRAV_BETA_APP" in globals():
    try:
        # Reuse the previous Output container when it is still a real Output widget.
        _prev_out = _GRAV_BETA_APP.get("out", None)
        if isinstance(_prev_out, widgets.Output):
            out = _prev_out
        else:
            out = widgets.Output()

        # Close every widget recorded by the previous run (best effort:
        # a widget that fails to close is skipped).
        for w in _GRAV_BETA_APP.get("widgets", []):
            try:
                w.close()
            except Exception:
                pass
    except Exception:
        # Anything unexpected about the stored state: start from a fresh Output.
        out = widgets.Output()
else:
    # First run of this cell in the session.
    out = widgets.Output()


# -----------------------------------------------------------------------------
# Widgets
# -----------------------------------------------------------------------------
# --- Data acquisition controls ---
# All sliders use continuous_update=False, so the value changes on release
# rather than while dragging.

# Number of receivers along the profile (slider moves in steps of 2).
n_data_widget = widgets.IntSlider(
    value=51, min=11, max=121, step=2, description="N data",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)
# Left end of the receiver line (always negative: -800 .. -100 m).
x_min_widget = widgets.FloatSlider(
    value=-450.0, min=-800.0, max=-100.0, step=25.0, description="x min (m)",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)
# Right end of the receiver line (always positive: 100 .. 800 m).
x_max_widget = widgets.FloatSlider(
    value=450.0, min=100.0, max=800.0, step=25.0, description="x max (m)",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)
# Receiver z-coordinate passed to build_survey_simulation.
z_rx_widget = widgets.FloatSlider(
    value=5.0, min=0.0, max=50.0, step=1.0, description="z_rx (m)",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)

# --- True model (cylinder) controls; values are passed to make_contrast_model ---

# Density contrast inside the cylinder.
drho_widget = widgets.FloatSlider(
    value=0.30, min=0.05, max=0.80, step=0.01, description=r"$\Delta\rho$ true",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)
# Horizontal position of the cylinder center.
x0_widget = widgets.FloatSlider(
    value=0.0, min=-300.0, max=300.0, step=10.0, description=r"$x_0$ (m)",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)
# Vertical position of the cylinder center (negative values only).
z0_widget = widgets.FloatSlider(
    value=-200.0, min=-550.0, max=-25.0, step=10.0, description=r"$z_0$ (m)",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)
# Cylinder radius.
R_widget = widgets.FloatSlider(
    value=120.0, min=40.0, max=300.0, step=10.0, description=r"$R$ (m)",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)

# --- Noise control ---
# Noise standard deviation as a fraction of max |noise-free data| (see _render).
noise_widget = widgets.FloatSlider(
    value=0.02, min=0.0, max=0.10, step=0.005, description="noise frac",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)

# --- Regional trend controls ---

# Switch for adding a synthetic polynomial trend to the data.
trend_enable_widget = widgets.Checkbox(value=False, description="Add regional trend", indent=False)
# Degree of the synthetic (true) trend.
trend_degree_widget = widgets.IntSlider(value=1, min=0, max=5, step=1, description="True trend order",
                                        continuous_update=False, style={"description_width":"90px"},
                                        layout=widgets.Layout(width="380px"))
# Separate slider for the degree of the trend to remove; it is independent of
# the true trend order above.
trend_remove_degree_widget = widgets.IntSlider(
    value=1, min=0, max=5, step=1,
    description="Removal order",
    continuous_update=False,
    style={"description_width":"90px"},
    layout=widgets.Layout(width="380px"),
)

# Trend amplitude as a multiple of max |anomaly| (see _render).
trend_amp_widget = widgets.FloatSlider(
    value=0.30, min=0.0, max=2.0, step=0.05,
    description=r"Trend amp ($\times$|max d|)",
    continuous_update=False, style={"description_width":"160px"},
    layout=widgets.Layout(width="520px")
)

# Switch for removing a fitted trend before inversion; _render only applies it
# when "Add regional trend" is also checked.
trend_remove_widget = widgets.Checkbox(value=True, description="Remove trend before inversion", indent=False)

# --- Target misfit ---
# Default equals the default N data (51). _render resets this to the current
# number of stations whenever its value is the slider minimum or already N.
phi_d_star_widget = widgets.IntSlider(
    value=51, min=1, max=500, step=1, description=r"$\phi_d^\ast$",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)

# --- Beta schedule controls ---

# Number of beta values in the cooling schedule.
beta_steps_widget = widgets.IntSlider(
    value=8, min=3, max=20, step=1, description=r"$\beta$ steps",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="720px"),
)
# Index of the selected beta; max=7 matches the default of 8 steps (indices 0..7).
beta_index_widget = widgets.IntSlider(
    value=0, min=0, max=7, step=1, description=r"$\beta$ idx",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="720px"),
)
# Geometric cooling factor (> 1, so beta decreases along the schedule).
cooling_widget = widgets.FloatSlider(
    value=2.0, min=1.05, max=6.0, step=0.05, description="cooling",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="720px"),
)
# Multiplier for beta_0 on a log10 scale: 10**-2 .. 10**2.
beta_scale_widget = widgets.FloatLogSlider(
    value=1.0, base=10, min=-2, max=2, step=0.1, description="beta scale",
    continuous_update=False, style={"description_width": "100px"},
    layout=widgets.Layout(width="420px"),
)
# Lower bound factor for beta on a log10 scale: 10**-6 .. 10**-1.
min_beta_factor_widget = widgets.FloatLogSlider(
    value=1e-3, base=10, min=-6, max=-1, step=0.1, description=r"min $\beta/\beta_0$",
    continuous_update=False, style={"description_width": "120px"},
    layout=widgets.Layout(width="420px"),
)

# Checkbox for freezing beta_0, plus an HTML label for the frozen value.
freeze_beta0_widget = widgets.Checkbox(value=False, description=r"Freeze $\beta_0$ (OFF)", indent=False)
beta0_display = widgets.HTML(value=r"<b>$\beta_0$ frozen:</b> (none)")

# --- Regularization controls ---

# Depth-weighting exponent p.
depthexp_widget = widgets.FloatSlider(
    value=0.9, min=0.0, max=3.5, step=0.1, description=r"depth $p$",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="720px"),
)
# Smallness / x-smoothness / z-smoothness weights, each on a log10 scale
# from 10**-3 to 10**3.
alpha_s_widget = widgets.FloatLogSlider(
    value=1.0, base=10, min=-3, max=3, step=0.1, description=r"$\alpha_s$",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)
alpha_x_widget = widgets.FloatLogSlider(
    value=1.0, base=10, min=-3, max=3, step=0.1, description=r"$\alpha_x$",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)
alpha_z_widget = widgets.FloatLogSlider(
    value=1.0, base=10, min=-3, max=3, step=0.1, description=r"$\alpha_z$",
    continuous_update=False, style={"description_width": "90px"},
    layout=widgets.Layout(width="380px"),
)

# --- Constraint + solver controls ---

# Positivity-constraint switch.
pos_widget = widgets.Checkbox(value=False, description=r"$\Delta\rho \ge 0$ constraint", indent=False)

# Dropdown values are the `solver` strings understood by solve_tikhonov;
# the default selection is "cg".
solver_widget = widgets.Dropdown(
    options=[("Stacked LS (lsqr)", "direct"), ("CG (normal eq)", "cg")],
    value="cg", description="solver",
    layout=widgets.Layout(width="380px"),
)

# Module-level state declared `global` in _render.
# NOTE(review): presumably the frozen beta_0 value and the previous state of
# the freeze checkbox — confirm against the rest of _render.
_BETA0_FROZEN = None
_FREEZE_LAST = False


def _render():
    """Render/update all plots for the current widget configuration.

    This is a re-entrancy guard around :func:`_render_once`. While a render is
    running it writes to observed widgets (the phi_d* value, the beta-index
    maximum), and those writes fire the ``value`` observers synchronously.
    Without the guard that would start a second, nested render inside the
    active ``with out:`` block. Nested calls are therefore ignored; the outer
    render reads the updated widget values itself.
    """
    if getattr(_render, "_busy", False):
        return
    _render._busy = True
    try:
        _render_once()
    finally:
        _render._busy = False


def _render_once():
    """Run one full simulate -> (detrend) -> invert -> plot pass into ``out``.

    Steps, all driven by the module-level widgets:

    1. Rebuild the survey, simulation and sensitivity matrix.
    2. Build the true contrast model and its (sign-normalised) gravity response.
    3. Optionally add a synthetic polynomial trend, then add Gaussian noise.
    4. Optionally fit and subtract a polynomial trend from the noisy data.
    5. Run the beta-schedule inversion.
    6. Draw the true model and a 2x2 summary figure.

    Side effects: updates ``_BETA0_FROZEN`` / ``_FREEZE_LAST`` and may write
    ``phi_d_star_widget.value``, ``beta_index_widget.max``,
    ``freeze_beta0_widget.description`` and ``beta0_display.value``.
    """
    global _BETA0_FROZEN, _FREEZE_LAST

    with out:
        clear_output(wait=True)

        # --- Rebuild survey/sim/G ---
        x_rx, receiver_locations, receiver_locations_plot, sim, G = build_survey_simulation(
            n_data=int(n_data_widget.value),
            x_min=float(x_min_widget.value),
            x_max=float(x_max_widget.value),
            z_rx=float(z_rx_widget.value),
        )

        trend_on = bool(trend_enable_widget.value)
        # True when the inversion works on detrended data.
        use_detrended = trend_on and bool(trend_remove_widget.value)

        # Keep phi_d* synced to N if user hasn't overridden.
        # "Not overridden" means the target still sits at the widget minimum, at the
        # current N, or at the N of the previous render. The last case is what lets
        # the target follow N when the number of data changes.
        n_now = int(x_rx.size)
        n_prev = getattr(_render, "_last_n", None)
        if int(phi_d_star_widget.value) in (int(phi_d_star_widget.min), n_now, n_prev):
            phi_d_star_widget.value = n_now
        _render._last_n = n_now

        # --- True model ---
        m_true = make_contrast_model(
            mesh,
            drho_cyl=float(drho_widget.value),
            x0=float(x0_widget.value),
            z0=float(z0_widget.value),
            R=float(R_widget.value),
        )

        # --- Synthetic anomaly ---
        d_raw = sim.dpred(m_true)
        sgn = choose_sign_factor(d_raw)
        d_true = sgn * d_raw

        # --- Optional regional trend ---
        trend = np.zeros_like(d_true)
        if trend_on:
            deg = int(trend_degree_widget.value)
            amp_frac = float(trend_amp_widget.value)
            # Trend amplitude is a fraction of the peak anomaly; 1e-30 avoids a zero scale.
            amp = amp_frac * (np.max(np.abs(d_true)) + 1e-30)
            trend, _ = make_synthetic_trend(x_rx, degree=deg, amp=amp, seed=2026)

        d_true_with_trend = d_true + trend

        # --- Noise ---
        # Constant noise floor as a fraction of the peak (anomaly + trend) amplitude.
        noise_floor = float(noise_widget.value) * (np.max(np.abs(d_true_with_trend)) + 1e-30)
        uncert = noise_floor * np.ones_like(d_true_with_trend)
        rng = np.random.default_rng(42)  # fixed seed: same noise realisation on every render
        d_obs = d_true_with_trend + rng.normal(0.0, uncert)
        Wd = sp.diags(1.0 / (uncert + 1e-30), format="csr")

        # --- Trend removal ---
        trend_fit = np.zeros_like(d_obs)
        d_for_inversion = d_obs.copy()
        if use_detrended:
            coeffs_fit = fit_poly_trend(x_rx, d_obs, degree=int(trend_remove_degree_widget.value))
            trend_fit = eval_poly_trend(x_rx, coeffs_fit)
            d_for_inversion = d_obs - trend_fit

        # Sign-consistent forward operator
        G_use = sgn * G

        # --- Freeze β0: falling edge ---
        # Drop the frozen value BEFORE inverting, so the render triggered by unticking
        # the box already uses the automatic β0 instead of the stale frozen one.
        freeze_on = bool(freeze_beta0_widget.value)
        if (not freeze_on) and _FREEZE_LAST:
            _BETA0_FROZEN = None

        # --- Inversion (β schedule) ---
        betas, models, phi_d, phi_m, beta0_used = run_beta_schedule_contrast(
            d_obs=d_for_inversion, Wd=Wd, G_use=G_use,
            depth_p=float(depthexp_widget.value),
            alpha_s=float(alpha_s_widget.value),
            alpha_x=float(alpha_x_widget.value),
            alpha_z=float(alpha_z_widget.value),
            n_steps=int(beta_steps_widget.value),
            cooling=float(cooling_widget.value),
            beta_scale=float(beta_scale_widget.value),
            min_beta_factor=float(min_beta_factor_widget.value),
            solver=str(solver_widget.value),
            positivity=bool(pos_widget.value),
            auto_beta0=(_BETA0_FROZEN is None),
            beta0_frozen=_BETA0_FROZEN,
        )

        # --- Freeze β0: rising edge ---
        # Capture the β0 that this (automatic) run just used.
        if freeze_on and (not _FREEZE_LAST):
            _BETA0_FROZEN = float(beta0_used)
        _FREEZE_LAST = freeze_on

        freeze_beta0_widget.description = (
            r"Freeze $\beta_0$ (ON)" if _BETA0_FROZEN is not None else r"Freeze $\beta_0$ (OFF)"
        )
        if _BETA0_FROZEN is None:
            beta0_display.value = rf"<b>$\beta_0$ frozen:</b> (none) | <b>current $\beta_0$:</b> {beta0_used:.2e}"
        else:
            beta0_display.value = rf"<b>$\beta_0$ frozen:</b> {_BETA0_FROZEN:.2e} | <b>current $\beta_0$:</b> {beta0_used:.2e}"

        # β index clamp
        # Clamp to the number of models actually returned as well as to the requested
        # step count, in case the schedule yields fewer models than requested.
        last_idx = max(0, min(int(beta_steps_widget.value), len(models)) - 1)
        beta_index_widget.max = last_idx
        beta_idx = int(np.clip(beta_index_widget.value, 0, last_idx))

        m_show = models[beta_idx]
        d_pred_show = np.asarray(G_use @ m_show)

        # Plot prediction in raw-data space
        d_pred_plot = d_pred_show + trend_fit if use_detrended else d_pred_show.copy()

        phi_d_star = float(phi_d_star_widget.value)

        # --- True model plot ---
        figA, axA = plt.subplots(1, 1, figsize=(7.8, 5.0), constrained_layout=True)
        qmA = mesh.plot_slice(
            m_true, normal="Y", ind=0, ax=axA, grid=True,
            pcolor_opts={"cmap": ROMA}
        )
        cbA = figA.colorbar(qmA[0], ax=axA)
        cbA.set_label(r"Density contrast $\Delta\rho$ (g/cc)")
        axA.scatter(
            receiver_locations_plot[:, 0],
            receiver_locations_plot[:, 2],
            s=28, marker="v", color="w", edgecolor="k", linewidth=0.5, zorder=5
        )
        axA.set_title(rf"True $\Delta\rho$ + stations (N={x_rx.size}; forward $z_{{rx}}$={z_rx_widget.value:.1f} m)")
        axA.set_xlabel("x (m)")
        axA.set_ylabel("z (m)")
        plt.show()

        # --- 2×2 panel ---
        fig = plt.figure(figsize=(15.5, 8.0), constrained_layout=True)
        gs = fig.add_gridspec(2, 2, width_ratios=[1.25, 1.0])

        # Top-left: data fit.
        # - With trend removal, the inversion sees d_for_inversion and d_pred_show is
        #   the matching prediction (both detrended).
        # - Otherwise the inversion sees d_obs and d_pred_plot == d_pred_show.
        ax1 = fig.add_subplot(gs[0, 0])

        if use_detrended:
            # Show raw data faintly for context, but emphasize what actually goes into the inversion.
            ax1.scatter(x_rx, d_obs, s=25, label="Observed (raw)", zorder=2, color="C0", alpha=0.25)
            ax1.scatter(x_rx, d_for_inversion, s=28, marker="x",
                        label="Observed (used in inversion)", zorder=4, color="C2")
            ax1.plot(x_rx, d_pred_show, lw=2.2, label="Predicted (fits inversion data)", color="k")
            # What the prediction looks like in raw space after adding the fitted trend back.
            ax1.plot(x_rx, d_pred_plot, lw=1.5, ls="--", color="k", alpha=0.6,
                     label="Predicted (+ fitted trend)")
        else:
            ax1.scatter(x_rx, d_obs, s=25, label="Observed (used in inversion)", zorder=3, color="C0")
            ax1.plot(x_rx, d_pred_plot, lw=2.2, label="Predicted (fits inversion data)", color="k")

        # Trend curves (if enabled)
        if trend_on:
            ax1.plot(x_rx, trend, ls="--", lw=1.5, color="0.6", label="True trend")
            if use_detrended:
                ax1.plot(
                    x_rx, trend_fit, ls=":", lw=2.0, color="0.3",
                    label=rf"Fitted trend (order={int(trend_remove_degree_widget.value)})"
                )

        ax1.set_xlabel("x (m)")
        ax1.set_ylabel("g (mGal, relative)")
        ax1.legend(loc="best")

        # Bottom-left: recovered model at the selected β.
        ax2 = fig.add_subplot(gs[1, 0])
        qm2 = mesh.plot_slice(
            m_show, normal="Y", ind=0, ax=ax2, grid=True,
            pcolor_opts={"cmap": ROMA}
        )
        cb2 = fig.colorbar(qm2[0], ax=ax2)
        cb2.set_label(r"Recovered $\Delta\rho$ (g/cc)")
        ax2.set_xlabel("x (m)")
        ax2.set_ylabel("z (m)")
        ax2.set_title(rf"Recovered model @ $\beta$ index {beta_idx}")

        # Top-right: Tikhonov curve; values are floored so the log axes never see <= 0.
        ax3 = fig.add_subplot(gs[0, 1])
        phi_d_pos = np.maximum(phi_d, 1e-30)
        phi_m_pos = np.maximum(phi_m, 1e-30)
        ax3.loglog(phi_m_pos, phi_d_pos, marker="o")
        ax3.loglog(phi_m_pos[beta_idx], phi_d_pos[beta_idx], marker="o", color="r", markersize=10)
        ax3.axhline(phi_d_star, ls="--", color="0.4", lw=1.5)
        ax3.set_xlabel(r"$\phi_m$")
        ax3.set_ylabel(r"$\phi_d$")
        ax3.set_title(r"Tikhonov ($L$-curve)")
        ax3.text(0.78, 0.92, rf"$\phi_d^\ast$={phi_d_star:.0f}", transform=ax3.transAxes, fontsize=10)

        # Bottom-right: φ_d (left axis) and φ_m (right axis) along the β schedule.
        ax4 = fig.add_subplot(gs[1, 1])
        it = np.arange(len(betas))
        ax4.plot(it, phi_d, marker="o", label=r"$\phi_d$", color="C0")
        ax4.axhline(phi_d_star, ls="--", color="C0", alpha=0.7, label=r"$\phi_d^\ast$")
        ax4.axvline(beta_idx, ls="--", color="0.3", lw=1.5, label=r"$\beta$ chosen")
        ax4.scatter([beta_idx], [phi_d[beta_idx]], s=90, marker="o", color="r", zorder=5)
        ax4.set_xlabel(r"$\beta$ schedule iteration")
        ax4.set_ylabel(r"$\phi_d$", color="C0")
        ax4.tick_params(axis="y", labelcolor="C0")

        ax4b = ax4.twinx()
        ax4b.plot(it, phi_m, marker="s", label=r"$\phi_m$", color="C1")
        ax4b.scatter([beta_idx], [phi_m[beta_idx]], s=90, marker="s", color="r", zorder=5)
        ax4b.set_ylabel(r"$\phi_m$", color="C1")
        ax4b.tick_params(axis="y", labelcolor="C1")

        # Merge the legends of the twin axes into one box.
        h1, l1 = ax4.get_legend_handles_labels()
        h2, l2 = ax4b.get_legend_handles_labels()
        ax4.legend(h1 + h2, l1 + l2, loc="best")

        fig.suptitle(
            rf"$\beta_0$ used={beta0_used:.2e} | Frozen={_BETA0_FROZEN is not None} | "
            rf"beta scale={beta_scale_widget.value:g} | $\beta$ idx={beta_idx}"
        )
        plt.show()


def _on_change(change=None):
    """Widget callback: update UI when any control changes.

    `change` is the notification passed by the widgets' ``observe`` mechanism;
    it is not inspected, so the function can also be called with no argument.

    The callback is guarded against re-entrancy. Shrinking
    ``beta_index_widget.max`` can clamp the slider's value, and widget values
    may also be written during the render; both fire the ``value`` observers
    synchronously. Without the guard each such write would trigger another
    full inversion and redraw.
    """
    if getattr(_on_change, "_busy", False):
        return
    _on_change._busy = True
    try:
        # Keep the β-index slider range consistent with the number of β steps.
        beta_index_widget.max = max(0, int(beta_steps_widget.value) - 1)
        _render()
    finally:
        # Always release the guard, even if rendering raised.
        _on_change._busy = False


# --- Layout -----------------------------------------------------------------
# Left column: survey geometry, true-model parameters and regional-trend options.
controls_geom = widgets.VBox(
    [
        *(
            widgets.HBox(list(row))
            for row in (
                (n_data_widget, phi_d_star_widget),
                (x_min_widget, x_max_widget),
                (z_rx_widget,),
                (drho_widget, noise_widget),
                (x0_widget, z0_widget),
                (R_widget,),
            )
        ),
        widgets.HTML("<b>Regional trend</b>"),
        trend_enable_widget,
        trend_degree_widget,
        trend_remove_degree_widget,
        trend_amp_widget,
        trend_remove_widget,
    ]
)

# Right column: β schedule, regularization weights, constraint and solver.
controls_beta = widgets.VBox(
    [
        beta_steps_widget,
        beta_index_widget,
        cooling_widget,
        widgets.HBox([beta_scale_widget, min_beta_factor_widget]),
        widgets.HBox([freeze_beta0_widget, beta0_display]),
        depthexp_widget,
        widgets.HBox([alpha_s_widget, alpha_x_widget, alpha_z_widget]),
        widgets.HBox([pos_widget, solver_widget]),
    ]
)

# Show both columns side by side, with the plot output area below them.
ui = widgets.HBox([controls_geom, controls_beta])
display(ui, out)

# Every control whose value change should trigger a re-render.
_all_widgets = [
    # survey geometry and true model
    n_data_widget,
    x_min_widget,
    x_max_widget,
    z_rx_widget,
    drho_widget,
    x0_widget,
    z0_widget,
    R_widget,
    noise_widget,
    # regional trend
    trend_enable_widget,
    trend_degree_widget,
    trend_remove_degree_widget,
    trend_amp_widget,
    trend_remove_widget,
    # target misfit
    phi_d_star_widget,
    # β schedule
    beta_steps_widget,
    beta_index_widget,
    cooling_widget,
    beta_scale_widget,
    min_beta_factor_widget,
    freeze_beta0_widget,
    # regularization
    depthexp_widget,
    alpha_s_widget,
    alpha_x_widget,
    alpha_z_widget,
    # constraint / solver
    pos_widget,
    solver_widget,
]
for w in _all_widgets:
    # Only react to `value` changes, not to other trait updates.
    w.observe(_on_change, names="value")

# Module-level handle on the output area and the controls.
# NOTE(review): presumably kept so the app objects stay referenced — confirm.
_GRAV_BETA_APP = dict(out=out, widgets=_all_widgets)

# Initial draw with the default widget values.
_render()

HBox(children=(VBox(children=(HBox(children=(IntSlider(value=51, continuous_update=False, description='N data'…

Output()