# Ordinary Linear Regression Example

## Linear regression as a geophysical inverse problem (matrix form)

We model density as a linear function of iron concentration:

$$
\rho = a\,Fe + b,
$$

with unknown model parameters

$$
\mathbf{m} =
\begin{bmatrix}
a\\
b
\end{bmatrix}.
$$

Given $N$ measurements $\{(Fe_i,\rho_i)\}_{i=1}^N$, define the data vector

$$
\mathbf{d} =
\begin{bmatrix}
\rho_1\\
\rho_2\\
\vdots\\
\rho_N
\end{bmatrix}
\in \mathbb{R}^{N},
$$

and the design (kernel) matrix

$$
\mathbf{G} =
\begin{bmatrix}
Fe_1 & 1\\
Fe_2 & 1\\
\vdots & \vdots\\
Fe_N & 1
\end{bmatrix}
\in \mathbb{R}^{N\times 2}.
$$

The linear forward model is:

$$
\mathbf{d} \approx \mathbf{G}\,\mathbf{m}.
$$
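As a quick check of the matrix form, the sketch below builds $\mathbf{G}$ with columns $[Fe, 1]$ and confirms that each row of $\mathbf{G}\mathbf{m}$ equals $a\,Fe_i + b$. The Fe values and the helper name `build_G` are illustrative assumptions, not part of the notebook's code.

```python
import numpy as np

def build_G(fe):
    """Hypothetical helper: design matrix with columns [Fe, 1] for rho = a*Fe + b."""
    fe = np.asarray(fe, dtype=float).reshape(-1)
    return np.column_stack([fe, np.ones_like(fe)])

fe = np.array([1.0, 2.5, 4.0])   # illustrative Fe concentrations (%)
m = np.array([[0.6], [2.5]])     # model vector [a, b]^T as a (2, 1) column
G = build_G(fe)                  # shape (3, 2)
d_pred = G @ m                   # forward model, shape (3, 1)

# Row i of G m is a*Fe_i + b
print(d_pred.ravel(), np.allclose(d_pred[:, 0], 0.6 * fe + 2.5))
```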

---

## Data errors, weighting, and geophysical data misfit

Assume additive Gaussian noise:

$$
\mathbf{d} = \mathbf{G}\mathbf{m} + \boldsymbol{\epsilon},
\qquad
\boldsymbol{\epsilon}\sim\mathcal{N}(\mathbf{0},\mathbf{C}_d).
$$

Define the data-weighting matrix $\mathbf{W}_d$ as any matrix satisfying:

$$
\mathbf{W}_d^\top\mathbf{W}_d = \mathbf{C}_d^{-1}.
$$

For independent errors with standard deviation $\sigma_i$:

$$
\mathbf{W}_d = \mathrm{diag}\left(\frac{1}{\sigma_1},\ldots,\frac{1}{\sigma_N}\right).
$$
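The relation $\mathbf{W}_d^\top\mathbf{W}_d = \mathbf{C}_d^{-1}$ fixes only the product, so $\mathbf{W}_d$ is not unique. A minimal sketch, assuming a symmetric positive-definite $\mathbf{C}_d$: factor $\mathbf{C}_d = \mathbf{L}\mathbf{L}^\top$ (Cholesky) and take $\mathbf{W}_d = \mathbf{L}^{-1}$, which satisfies the relation and reduces to $\mathrm{diag}(1/\sigma_i)$ for independent errors. This is an alternative to the symmetric square root $\mathbf{C}_d^{-1/2}$; the helper `weighting_from_covariance` and the covariance values are hypothetical.

```python
import numpy as np

def weighting_from_covariance(Cd):
    """Hypothetical helper: return Wd = L^{-1}, where Cd = L L^T (Cholesky).

    Then Wd^T Wd = L^{-T} L^{-1} = (L L^T)^{-1} = Cd^{-1}.
    """
    L = np.linalg.cholesky(Cd)                         # lower-triangular factor
    return np.linalg.solve(L, np.eye(Cd.shape[0]))     # L^{-1}

# Independent errors: Cd = diag(sigma_i^2) gives Wd = diag(1/sigma_i)
sigma = np.array([0.2, 0.4, 0.5])
Wd_diag = weighting_from_covariance(np.diag(sigma**2))

# Correlated errors (illustrative SPD covariance)
Cd = np.array([[0.04, 0.01, 0.00],
               [0.01, 0.16, 0.02],
               [0.00, 0.02, 0.25]])
Wd = weighting_from_covariance(Cd)
print(np.allclose(Wd.T @ Wd, np.linalg.inv(Cd)))
```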

Define the residual:

$$
\mathbf{r}(\mathbf{m}) = \mathbf{G}\mathbf{m} - \mathbf{d}.
$$

We use the data misfit:

$$
\phi_d(\mathbf{m}) = \|\mathbf{W}_d\,\mathbf{r}(\mathbf{m})\|_2^2.
$$

A common diagnostic is the RMS misfit:

$$
\mathrm{RMS}(\mathbf{m}) = \sqrt{\frac{1}{N}\,\phi_d(\mathbf{m})}.
$$

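If $\mathbf{W}_d = \mathbf{C}_d^{-1/2}$ reflects the true data uncertainties, then at the optimal fit we expect:

$$
\mathrm{RMS}(\hat{\mathbf{m}}) \approx 1
\quad\text{and}\quad
\phi_d(\hat{\mathbf{m}}) \approx N.
$$

More precisely, with Gaussian noise and $M$ fitted parameters, $\phi_d(\hat{\mathbf{m}})$ follows a $\chi^2$ distribution with $N-M$ degrees of freedom. Its expected value is therefore $N-M$ ($N-2$ here), which is close to $N$ when $N \gg 2$.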
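The $\mathrm{RMS}\approx 1$ target can be checked by simulation. This sketch assumes i.i.d. Gaussian noise and the notebook's default synthetic settings ($N=60$, $\sigma=0.4$, $a=0.6$, $b=2.5$). It refits many synthetic datasets and averages $\phi_d(\hat{\mathbf{m}})$. For a linear model with two fitted parameters the expected value is $N-2$, slightly below $N$, so the mean RMS lands just under 1.

```python
import numpy as np

rng = np.random.default_rng(0)
N, sigma, n_trials = 60, 0.4, 2000
a_true, b_true = 0.6, 2.5

phi_vals = []
for _ in range(n_trials):
    fe = rng.uniform(0.0, 10.0, size=N)
    G = np.column_stack([fe, np.ones(N)])
    d = G @ np.array([a_true, b_true]) + rng.normal(0.0, sigma, size=N)
    Wd = np.diag(np.full(N, 1.0 / sigma))          # Wd = Cd^{-1/2} for i.i.d. noise
    m_hat, *_ = np.linalg.lstsq(Wd @ G, Wd @ d, rcond=None)
    r_w = Wd @ (G @ m_hat - d)                     # weighted residual at the optimum
    phi_vals.append(float(r_w @ r_w))              # phi_d(m_hat)

phi_vals = np.array(phi_vals)
print(f"mean phi_d(m_hat) = {phi_vals.mean():.2f}  (N = {N}, N - 2 = {N - 2})")
print(f"mean RMS(m_hat)   = {np.sqrt(phi_vals / N).mean():.3f}")
```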

---

## Three “solutions” we will compare in this notebook

### 1) True model (synthetic reference)

We generate synthetic data from a known model

$$
\mathbf{m}_{true} =
\begin{bmatrix}
a_{true}\\
b_{true}
\end{bmatrix},
$$

and then add noise to produce $\mathbf{d}$. This is **only available in synthetic experiments**.

### 2) Normal-equations solution (closed-form least squares)

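Setting the gradient $\nabla\phi_d(\mathbf{m}) = 2\,\mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d(\mathbf{G}\mathbf{m}-\mathbf{d})$ to zero yields the weighted normal equations

$$
\mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d\mathbf{G}\,\hat{\mathbf{m}}_{NE}
=
\mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d\,\mathbf{d}.
$$

Provided $\mathbf{G}$ has full column rank (at least two distinct $Fe_i$ values), their solution is

$$
\hat{\mathbf{m}}_{NE}
=
\left(\mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d\mathbf{G}\right)^{-1}
\mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d\,\mathbf{d}.
$$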

We solve this $2\times 2$ system numerically to obtain $\hat{\mathbf{m}}_{NE}$.
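For a $2\times 2$ system the normal equations are harmless. Note, however, that forming $\mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d\mathbf{G}$ squares the condition number of $\mathbf{W}_d\mathbf{G}$. The sketch below compares two routes to the same minimizer, assuming illustrative heteroscedastic $\sigma_i$. The first solves the normal equations directly. The second solves the whitened least-squares problem $\mathbf{W}_d\mathbf{G}\,\mathbf{m}\approx\mathbf{W}_d\mathbf{d}$ with `numpy.linalg.lstsq`, which works from an SVD of $\mathbf{W}_d\mathbf{G}$ and never forms the product.

```python
import numpy as np

rng = np.random.default_rng(1)
N = 40
fe = rng.uniform(0.0, 10.0, size=N)
sigma = rng.uniform(0.2, 0.6, size=N)          # illustrative per-datum errors
G = np.column_stack([fe, np.ones(N)])
d = G @ np.array([0.6, 2.5]) + rng.normal(0.0, sigma)

Wd = np.diag(1.0 / sigma)
WtW = Wd.T @ Wd

# (1) Weighted normal equations: (G^T Wd^T Wd G) m = G^T Wd^T Wd d
m_ne = np.linalg.solve(G.T @ WtW @ G, G.T @ WtW @ d)

# (2) Least squares on the whitened system Wd G m ≈ Wd d
m_ls, *_ = np.linalg.lstsq(Wd @ G, Wd @ d, rcond=None)

c_whitened = np.linalg.cond(Wd @ G)
c_normal = np.linalg.cond(G.T @ WtW @ G)        # equals c_whitened**2
print("normal equations:", m_ne)
print("lstsq (whitened):", m_ls)
print(f"cond(Wd G) = {c_whitened:.1f}, cond(G^T Wd^T Wd G) = {c_normal:.1f}")
```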

### 3) Optimization (iterative) fit

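We also estimate the model iteratively, by gradient descent or by Newton's method, applied to $J(\mathbf{m})=\tfrac12\phi_d(\mathbf{m})$, which has the same minimizer as $\phi_d$:

- **Gradient Descent**  
  $$\mathbf{m}_{k+1} = \mathbf{m}_k - s\,\nabla J(\mathbf{m}_k), \quad s = \frac{\alpha}{\lambda_{\max}(\mathbf{H})},\ 0<\alpha\le 1.$$

- **Newton (damped)**  
  $$\left(\mathbf{H}+\lambda\mathbf{I}\right)\,\Delta\mathbf{m}_k = \nabla J(\mathbf{m}_k),
  \quad \mathbf{m}_{k+1}=\mathbf{m}_k-\alpha_N\Delta\mathbf{m}_k.$$

Here $\nabla J(\mathbf{m}) = \mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d(\mathbf{G}\mathbf{m}-\mathbf{d})$ and $\mathbf{H} = \nabla^2 J = \mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d\mathbf{G}$, which is constant for this linear problem.

In GD, $\alpha$ is a dimensionless step fraction. Dividing by $\lambda_{\max}(\mathbf{H})$ means the same $\alpha$ works regardless of the overall scale of $\mathbf{H}$, which grows with $N$ and with $1/\sigma^2$.

In Newton, $\lambda \ge 0$ is a Levenberg-style damping parameter and $\alpha_N\in(0,1]$ scales the step. The code applies Newton to $\phi_d$, whose gradient and Hessian are twice those of $J$. The undamped step is therefore identical, and a damping value $\lambda_c$ in the code corresponds to $\lambda = \lambda_c/2$ here.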
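A minimal sketch of both iterations on synthetic data follows. It assumes i.i.d. noise, so that $\mathbf{W}_d^\top\mathbf{W}_d = \mathbf{I}/\sigma^2$. It uses $J = \tfrac12\phi_d$ with $\nabla J = \mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d(\mathbf{G}\mathbf{m}-\mathbf{d})$, $\mathbf{H} = \mathbf{G}^\top\mathbf{W}_d^\top\mathbf{W}_d\mathbf{G}$, and GD step $s = \alpha/\lambda_{\max}(\mathbf{H})$. The stopping tolerance, iteration cap, and variable names are illustrative choices, not the notebook's.

```python
import numpy as np

rng = np.random.default_rng(2)
N, sigma = 60, 0.4
fe = rng.uniform(0.0, 10.0, size=N)
G = np.column_stack([fe, np.ones(N)])
d = G @ np.array([0.6, 2.5]) + rng.normal(0.0, sigma, size=N)
WtW = np.eye(N) / sigma**2                     # Wd^T Wd = Cd^{-1} for i.i.d. noise

H = G.T @ WtW @ G                              # Hessian of J = phi_d / 2 (constant)
rhs = G.T @ WtW @ d

def grad_J(m):
    return G.T @ WtW @ (G @ m - d)

m_ne = np.linalg.solve(H, rhs)                 # reference: normal equations

# Gradient descent with Lipschitz-scaled step s = alpha / lambda_max(H)
alpha = 0.8
s = alpha / np.linalg.eigvalsh(H).max()
m_gd = np.zeros(2)
for k in range(20000):
    g = grad_J(m_gd)
    if np.linalg.norm(g) < 1e-10 * np.linalg.norm(rhs):
        break                                  # k steps have been taken
    m_gd = m_gd - s * g

# Damped Newton: (H + lam*I) dm = grad J, then m <- m - alpha_N * dm
lam, alpha_N = 0.0, 1.0
m_nt = np.zeros(2)
dm = np.linalg.solve(H + lam * np.eye(2), grad_J(m_nt))
m_nt = m_nt - alpha_N * dm

print("GD steps:", k, " m_gd =", m_gd)
print("Newton (1 step):  m_nt =", m_nt)
print("Normal equations: m_ne =", m_ne)
```

With $\lambda = 0$ and $\alpha_N = 1$, the single Newton step lands on $\hat{\mathbf{m}}_{NE}$ up to rounding. GD needs many iterations because the $Fe$ and constant columns of $\mathbf{G}$ are strongly correlated when $Fe$ is not centered, which makes $\mathbf{H}$ comparatively ill-conditioned. A positive $\lambda$ shortens each Newton step, so more than one step is then needed.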

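At convergence, the optimization estimate $\hat{\mathbf{m}}_{opt}$ should match the normal-equations solution, because $J$ is a convex quadratic with the same minimizer. Undamped Newton ($\lambda = 0$, $\alpha_N = 1$) reaches it in a single step from any starting model: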

$$
\hat{\mathbf{m}}_{opt} \to \hat{\mathbf{m}}_{NE}.
$$


## Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import traceback

import ipywidgets as widgets
from IPython.display import display, clear_output

from dataclasses import dataclass, field

## OLS of the density and Fe concentration - GD or Newton Methods

In [2]:
import numpy as np
import matplotlib.pyplot as plt

import ipywidgets as widgets
from IPython.display import display, clear_output

from dataclasses import dataclass, field


# =============================================================================
# 1) Forward model utilities
# =============================================================================

def build_design_matrix(fe: np.ndarray) -> np.ndarray:
    """
    Build the design matrix G for the linear model rho = a*Fe + b.

    Parameters
    ----------
    fe : (N,) ndarray
        Iron concentration values.

    Returns
    -------
    G : (N, 2) ndarray
        Design matrix with columns [Fe, 1].
    """
    fe = np.asarray(fe).reshape(-1)
    return np.column_stack([fe, np.ones_like(fe)])


def make_synthetic_data_matrix(
    n: int = 60,
    a_true: float = 0.6,
    b_true: float = 2.5,
    fe_min: float = 0.0,
    fe_max: float = 10.0,
    noise_std: float = 0.4,
    seed: int = 0,
):
    """
    Generate synthetic data for the linear inverse problem:

        d = G m_true + epsilon,

    where epsilon ~ N(0, sigma^2 I) (homoscedastic noise).

    Parameters
    ----------
    n : int
        Number of samples.
    a_true : float
        True slope.
    b_true : float
        True intercept.
    fe_min, fe_max : float
        Range of Fe values.
    noise_std : float
        Standard deviation of additive Gaussian noise on rho (same for all points).
    seed : int
        RNG seed for reproducibility.

    Returns
    -------
    fe : (N,) ndarray
        Iron concentration values.
    d : (N, 1) ndarray
        Noisy densities (data) as a column vector.
    d_clean : (N, 1) ndarray
        Noise-free densities as a column vector.
    G : (N, 2) ndarray
        Design matrix.
    m_true : (2, 1) ndarray
        True model parameters [a_true, b_true]^T.
    sigma : (N, 1) ndarray
        Per-datum standard deviation (all equal to noise_std here).
    """
    rng = np.random.default_rng(seed)

    fe = rng.uniform(float(fe_min), float(fe_max), size=int(n))
    G = build_design_matrix(fe)

    m_true = np.array([[float(a_true)], [float(b_true)]])
    d_clean = G @ m_true

    eps = rng.normal(loc=0.0, scale=float(noise_std), size=(int(n), 1))
    d = d_clean + eps

    sigma = np.full((int(n), 1), float(noise_std))
    return fe, d, d_clean, G, m_true, sigma


def predict_data(G: np.ndarray, m: np.ndarray) -> np.ndarray:
    """
    Forward prediction:

        d_pred = G m

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    m : (2, 1) ndarray
        Model parameters [a, b]^T.

    Returns
    -------
    d_pred : (N, 1) ndarray
        Predicted data.
    """
    return G @ m


# =============================================================================
# 2) Data weighting and misfit
# =============================================================================

def build_Wd_from_sigma(sigma: np.ndarray) -> np.ndarray:
    """
    Build a diagonal data-weighting matrix Wd from per-datum standard deviations.

    For independent errors:
        C_d = diag(sigma_i^2)
        Wd = C_d^{-1/2} = diag(1/sigma_i)

    Parameters
    ----------
    sigma : (N, 1) or (N,) ndarray
        Standard deviations for each datum.

    Returns
    -------
    Wd : (N, N) ndarray
        Diagonal weighting matrix.
    """
    sigma = np.asarray(sigma).reshape(-1)
    if np.any(sigma <= 0):
        raise ValueError("All sigma values must be > 0.")
    return np.diag(1.0 / sigma)


def residual_vector(G: np.ndarray, d: np.ndarray, m: np.ndarray) -> np.ndarray:
    """
    Compute the residual vector:

        r(m) = G m - d

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    d : (N, 1) ndarray
        Observed data.
    m : (2, 1) ndarray
        Model parameters.

    Returns
    -------
    r : (N, 1) ndarray
        Residual vector.
    """
    return (G @ m) - d


def weighted_residual(G: np.ndarray, d: np.ndarray, m: np.ndarray, Wd: np.ndarray | None) -> np.ndarray:
    """
    Compute the (optionally) weighted residual:

        r_w(m) = Wd * (G m - d)

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    d : (N, 1) ndarray
        Observed data.
    m : (2, 1) ndarray
        Model parameters.
    Wd : (N, N) ndarray or None
        Data weighting matrix. If None, returns the unweighted residual.

    Returns
    -------
    r_w : (N, 1) ndarray
        Weighted residual vector (or unweighted if Wd is None).
    """
    r = residual_vector(G, d, m)
    return r if Wd is None else (Wd @ r)


def data_misfit_phi_d(G: np.ndarray, d: np.ndarray, m: np.ndarray, Wd: np.ndarray | None) -> float:
    """
    Data misfit in geophysical inversion form (NOT normalized by N):

        phi_d(m) = || Wd (G m - d) ||_2^2

    If Wd is None, this reduces to:
        phi_d(m) = || G m - d ||_2^2

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    d : (N, 1) ndarray
        Observed data.
    m : (2, 1) ndarray
        Model parameters.
    Wd : (N, N) ndarray or None
        Data weighting matrix.

    Returns
    -------
    phi_d : float
        Scalar data misfit.
    """
    rw = weighted_residual(G, d, m, Wd)
    return float((rw.T @ rw)[0, 0])


def rms_misfit(G: np.ndarray, d: np.ndarray, m: np.ndarray, Wd: np.ndarray | None) -> float:
    """
    RMS misfit associated with the data misfit:

        RMS(m) = sqrt( phi_d(m) / N )

    Interpretation:
    - If Wd = C_d^{-1/2} and the noise model is correct, RMS ~ 1 at optimal fit.

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    d : (N, 1) ndarray
        Observed data.
    m : (2, 1) ndarray
        Model parameters.
    Wd : (N, N) ndarray or None
        Data weighting matrix.

    Returns
    -------
    rms : float
        RMS misfit.
    """
    N = G.shape[0]
    return float(np.sqrt(data_misfit_phi_d(G, d, m, Wd) / N))


# =============================================================================
# 3) Gradient, Hessian, and solvers (GD + Newton + Normal Equations)
# =============================================================================

def grad_phi_d(G: np.ndarray, d: np.ndarray, m: np.ndarray, Wd: np.ndarray | None) -> np.ndarray:
    """
    Gradient of the geophysical data misfit:

        phi_d(m) = || Wd (G m - d) ||_2^2

    Let r = (G m - d). Then:
        phi_d = r^T (Wd^T Wd) r
        ∇phi_d = 2 G^T (Wd^T Wd) r

    If Wd is None, Wd^T Wd = I.

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    d : (N, 1) ndarray
        Observed data.
    m : (2, 1) ndarray
        Model parameters.
    Wd : (N, N) ndarray or None
        Data weighting matrix.

    Returns
    -------
    grad : (2, 1) ndarray
        Gradient vector.
    """
    r = residual_vector(G, d, m)  # (N,1)
    if Wd is None:
        A = r
    else:
        A = (Wd.T @ Wd) @ r
    return 2.0 * (G.T @ A)


def hessian_phi_d(G: np.ndarray, Wd: np.ndarray | None) -> np.ndarray:
    """
    Hessian of the geophysical data misfit:

        phi_d(m) = || Wd (G m - d) ||_2^2

    For linear G, the Hessian is constant:
        H = 2 G^T (Wd^T Wd) G

    If Wd is None, Wd^T Wd = I.

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    Wd : (N, N) ndarray or None
        Data weighting matrix.

    Returns
    -------
    H : (2, 2) ndarray
        Hessian matrix.
    """
    if Wd is None:
        return 2.0 * (G.T @ G)
    return 2.0 * (G.T @ (Wd.T @ Wd) @ G)


def gd_step(G: np.ndarray, d: np.ndarray, m: np.ndarray, Wd: np.ndarray | None, alpha: float):
    """
    One *stable* gradient descent step for the geophysical data misfit.

    We keep the geophysical (non-normalized) misfit definition:

        phi_d(m) = || Wd (G m - d) ||_2^2

    but we perform gradient descent on the equivalent quadratic objective

        J(m) = (1/2) * phi_d(m),

    which has the same minimizer and a cleaner gradient/Hessian.

    A common reason for "diverging" behavior after switching from a machine-learning
    style loss (e.g., (1/2N)||r||^2) to a geophysical misfit (||Wd r||^2) is that the
    gradient magnitude scales with:
      - the number of data N, and
      - the weighting level (roughly sigma^{-2} when Wd = C_d^{-1/2}).

    To make the iteration robust without hand-tuning tiny learning rates, we use a
    Lipschitz-stable step size based on the largest eigenvalue of the Hessian:

        grad J(m) = G^T Wd^T Wd (G m - d)
        H_J       = G^T Wd^T Wd G
        s         = alpha / lambda_max(H_J)
        m_{k+1}   = m_k - s * grad J(m_k)

    Here, ``alpha`` is a dimensionless step fraction. Because grad J is Lipschitz
    with constant lambda_max(H_J), any 0 < alpha < 2 decreases J monotonically and
    converges for this quadratic. The range 0 < alpha <= 1 is the conservative
    choice, and alpha close to 1 is usually fast.

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    d : (N, 1) ndarray
        Observed data.
    m : (2, 1) ndarray
        Current model.
    Wd : (N, N) ndarray or None
        Data weighting matrix (Wd = C_d^{-1/2}). If None, Wd is treated as identity.
    alpha : float
        Dimensionless step fraction in (0, 1]. This is *not* an absolute learning rate.

    Returns
    -------
    m_new : (2, 1) ndarray
        Updated model.
    grad_norm : float
        Euclidean norm of grad J at the current iterate (not of grad phi_d).
    step_size : float
        Actual step size ``s = alpha / lambda_max(H_J)`` used in this update.
    lambda_max : float
        Largest eigenvalue of H_J (a Lipschitz constant for grad J).
    """
    alpha = float(alpha)
    if not (alpha > 0.0):
        raise ValueError("alpha must be > 0")
    # grad_phi_d = 2 * G^T WTW r  -> grad_J = 0.5 * grad_phi_d
    gJ = 0.5 * grad_phi_d(G, d, m, Wd)

    # hessian_phi_d = 2 * G^T WTW G -> H_J = 0.5 * hessian_phi_d
    HJ = 0.5 * hessian_phi_d(G, Wd)

    # Largest eigenvalue of a 2x2 SPD matrix (use eigvalsh for numerical stability)
    evals = np.linalg.eigvalsh(HJ)
    lambda_max = float(np.max(evals))

    # Avoid divide-by-zero in pathological cases (e.g., degenerate design matrix)
    eps = 1e-30
    step_size = alpha / (lambda_max + eps)

    m_new = m - step_size * gJ
    return m_new, float(np.linalg.norm(gJ)), float(step_size), float(lambda_max)


def newton_step(G: np.ndarray, d: np.ndarray, m: np.ndarray, Wd: np.ndarray | None, damping: float = 0.0, step_scale: float = 1.0):
    """
    One (optionally damped) Newton step on the geophysical misfit:

        phi_d(m) = || Wd (G m - d) ||^2

    Newton solves:
        (H + damping*I) delta = grad
        m_new = m - step_scale * delta

    Notes
    -----
    - For this 2-parameter linear problem, Newton typically reaches the minimizer
      in one step (up to numerical precision), especially with step_scale=1 and small damping.
    - damping >= 0 provides Levenberg-style stabilization.

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    d : (N, 1) ndarray
        Observed data.
    m : (2, 1) ndarray
        Current model.
    Wd : (N, N) ndarray or None
        Data weighting matrix.
    damping : float
        Damping coefficient λ >= 0.
    step_scale : float
        Step scale α_N in (0, 1]; 1 is standard Newton.

    Returns
    -------
    m_new : (2, 1) ndarray
        Updated model.
    grad_norm : float
        Norm of gradient at current iterate.
    step_norm : float
        Norm of Newton step delta.
    """
    damping = float(damping)
    if damping < 0:
        raise ValueError("damping must be >= 0")

    step_scale = float(step_scale)

    g = grad_phi_d(G, d, m, Wd)
    H = hessian_phi_d(G, Wd)
    H_d = H + damping * np.eye(H.shape[0])

    try:
        delta = np.linalg.solve(H_d, g)
    except np.linalg.LinAlgError:
        delta, *_ = np.linalg.lstsq(H_d, g, rcond=None)

    m_new = m - step_scale * delta
    return m_new, float(np.linalg.norm(g)), float(np.linalg.norm(delta))


def normal_equation_solution(G: np.ndarray, d: np.ndarray, Wd: np.ndarray | None) -> np.ndarray:
    """
    Compute the normal-equation (closed-form) solution for weighted least squares:

        minimize ||Wd (G m - d)||^2

    Normal equations:
        (G^T Wd^T Wd G) m = (G^T Wd^T Wd d)

    If Wd is None, this reduces to:
        (G^T G) m = (G^T d)

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix.
    d : (N, 1) ndarray
        Observed data.
    Wd : (N, N) ndarray or None
        Data weighting matrix.

    Returns
    -------
    m_ne : (2, 1) ndarray
        Normal-equation solution.
    """
    if Wd is None:
        A = G.T @ G
        b = G.T @ d
    else:
        WTW = Wd.T @ Wd
        A = G.T @ WTW @ G
        b = G.T @ WTW @ d
    return np.linalg.solve(A, b)


def misfit_surface_grid(G: np.ndarray, d: np.ndarray, Wd: np.ndarray | None, a_vals: np.ndarray, b_vals: np.ndarray) -> np.ndarray:
    """
    Compute phi_d(a,b) on a grid (vectorized), where:

        phi_d(m) = || Wd (G m - d) ||^2

    Parameters
    ----------
    G : (N, 2) ndarray
        Design matrix (first col is Fe, second is ones).
    d : (N, 1) ndarray
        Observed data.
    Wd : (N, N) ndarray or None
        Data weighting matrix.
    a_vals : (Na,) ndarray
        Grid values for slope a.
    b_vals : (Nb,) ndarray
        Grid values for intercept b.

    Returns
    -------
    Phi : (Nb, Na) ndarray
        Misfit surface values.
    """
    Fe = G[:, 0:1]    # (N,1) Fe column
    ones = G[:, 1:2]  # (N,1) column of ones

    A, B = np.meshgrid(a_vals, b_vals)  # (Nb,Na)
    A_flat = A.reshape(1, -1)
    B_flat = B.reshape(1, -1)

    d_pred = Fe @ A_flat + ones @ B_flat  # (N, Nb*Na)
    r = d_pred - d                         # (N, Nb*Na)

    if Wd is not None:
        r = Wd @ r  # still works: (N,N)@(N,M) -> (N,M)

    phi_flat = np.sum(r**2, axis=0)        # ||r||^2 (no /N)
    return phi_flat.reshape(A.shape)


# =============================================================================
# 4) State container
# =============================================================================

@dataclass
class RegressionState:
    """
    Container for dataset + inversion state.

    Stores:
    - data (fe, G, d, d_clean, sigma),
    - weighting matrix Wd (or None),
    - true model m_true (for synthetic reference),
    - normal-equation solution m_ne,
    - current model estimate m,
    - histories for phi_d, RMS, gradient norms, Newton step norms,
    - path in (a,b) space.

    Attributes
    ----------
    fe : (N,) ndarray
        Iron concentration samples (for plotting).
    G : (N,2) ndarray
        Design matrix.
    d : (N,1) ndarray
        Observed data.
    d_clean : (N,1) ndarray
        Noise-free data (synthetic reference).
    sigma : (N,1) ndarray
        Per-datum standard deviations used to generate synthetic noise.
    Wd : (N,N) ndarray or None
        Data weighting matrix. If None, inversion is unweighted.
    m_true : (2,1) ndarray
        True model parameters used for simulation.
    m_ne : (2,1) ndarray
        Normal-equation solution (weighted or unweighted, consistent with Wd).
    m : (2,1) ndarray
        Current estimate.
    it : int
        Iteration counter.
    phi_history : list[float]
        Data misfit values phi_d(m) per iteration (includes iteration 0).
    rms_history : list[float]
        RMS misfit values per iteration (includes iteration 0).
    gradnorm_history : list[float]
        Gradient norm evaluated at the iterate before each step
        (GD: ||grad J||, J = phi_d / 2; Newton: ||grad phi_d||).
    stepnorm_history : list[float]
        Step metric per iteration (GD: step size s; Newton: ||Δm||).
    m_history : list[np.ndarray]
        Model vector at each iteration (includes iteration 0).
    """
    fe: np.ndarray = field(default_factory=lambda: np.array([]))
    G: np.ndarray = field(default_factory=lambda: np.empty((0, 2)))
    d: np.ndarray = field(default_factory=lambda: np.empty((0, 1)))
    d_clean: np.ndarray = field(default_factory=lambda: np.empty((0, 1)))
    sigma: np.ndarray = field(default_factory=lambda: np.empty((0, 1)))

    Wd: np.ndarray | None = None

    m_true: np.ndarray = field(default_factory=lambda: np.zeros((2, 1)))
    m_ne: np.ndarray = field(default_factory=lambda: np.zeros((2, 1)))
    m: np.ndarray = field(default_factory=lambda: np.zeros((2, 1)))

    it: int = 0
    phi_history: list = field(default_factory=list)
    rms_history: list = field(default_factory=list)
    gradnorm_history: list = field(default_factory=list)
    stepnorm_history: list = field(default_factory=list)
    m_history: list = field(default_factory=list)

    def reset_model(self, a0: float = 0.0, b0: float = 0.0):
        """
        Reset current model estimate and all histories.

        Parameters
        ----------
        a0, b0 : float
            Initial guess for slope and intercept.
        """
        self.m = np.array([[float(a0)], [float(b0)]])
        self.it = 0

        phi0 = data_misfit_phi_d(self.G, self.d, self.m, self.Wd)
        rms0 = rms_misfit(self.G, self.d, self.m, self.Wd)

        self.phi_history = [phi0]
        self.rms_history = [rms0]
        self.gradnorm_history = []
        self.stepnorm_history = []
        self.m_history = [self.m.copy()]

    @property
    def a(self) -> float:
        """Current slope estimate."""
        return float(self.m[0, 0])

    @property
    def b(self) -> float:
        """Current intercept estimate."""
        return float(self.m[1, 0])

    def update_normal_equations_solution(self):
        """
        Recompute the normal-equation solution m_ne consistent with current Wd.
        """
        self.m_ne = normal_equation_solution(self.G, self.d, self.Wd)

    def step(self, method: str, lr: float = 1e-3, damping: float = 0.0, step_scale: float = 1.0):
        """
        Perform one iteration using either gradient descent or Newton.

        Parameters
        ----------
        method : str
            "gd" or "newton".
        lr : float
            GD step fraction alpha (dimensionless, typically in (0, 1]); the actual
            step size is alpha / lambda_max(H_J). Ignored for Newton.
        damping : float
            Newton damping λ >= 0. Ignored for GD.
        step_scale : float
            Newton step scale α_N in (0,1]. Ignored for GD.
        """
        method = method.lower().strip()

        if method == "gd":
            # lr is interpreted as the dimensionless alpha, not an absolute learning rate.
            self.m, gnorm, step_size, lambda_max = gd_step(self.G, self.d, self.m, self.Wd, alpha=float(lr))
            # For GD, record the step size s = alpha / lambda_max(H_J) as the step metric.
            snorm = float(step_size)
        elif method == "newton":
            self.m, gnorm, snorm = newton_step(self.G, self.d, self.m, self.Wd, damping=float(damping), step_scale=float(step_scale))
        else:
            raise ValueError("method must be 'gd' or 'newton'")

        self.it += 1

        phi = data_misfit_phi_d(self.G, self.d, self.m, self.Wd)
        rms = rms_misfit(self.G, self.d, self.m, self.Wd)

        self.phi_history.append(phi)
        self.rms_history.append(rms)
        self.gradnorm_history.append(float(gnorm))
        self.stepnorm_history.append(float(snorm))
        self.m_history.append(self.m.copy())


# =============================================================================
# 5) Plotting + summary
# =============================================================================

def plot_fits_and_histories(state: RegressionState, show_true: bool = True) -> None:
    """
    Plot:
      1) data + current fit + normal-equation fit (+ true line if requested),
      2) data misfit history phi_d with phi_d(m_NE) reference,
      3) RMS misfit history with RMS(m_NE) and target RMS ≈ 1.

    Layout notes (to avoid rendering problems in widget outputs):
      - labels use only short, valid mathtext expressions
      - constrained_layout is used instead of tight_layout
      - the figure is closed explicitly after show()

    Parameters
    ----------
    state : RegressionState
        Current inversion state.
    show_true : bool
        If True, overlay the synthetic "true" line for reference.
    """
    fe = state.fe
    if fe.size == 0:
        return

    d = state.d.reshape(-1)

    x_line = np.linspace(fe.min(), fe.max(), 250)
    G_line = build_design_matrix(x_line)

    y_cur = (G_line @ state.m).reshape(-1)
    y_ne = (G_line @ state.m_ne).reshape(-1)
    y_true = (G_line @ state.m_true).reshape(-1)

    phi_ne = data_misfit_phi_d(state.G, state.d, state.m_ne, state.Wd)
    rms_ne = rms_misfit(state.G, state.d, state.m_ne, state.Wd)

    fig = plt.figure(figsize=(14, 4), constrained_layout=True)

    ax1 = fig.add_subplot(1, 3, 1)
    ax1.scatter(fe, d, marker="o", color="k")
    ax1.plot(x_line, y_cur, linewidth=2, color="limegreen", label="fit (current)")
    ax1.plot(x_line, y_ne, linewidth=2, linestyle="--", color="tab:blue", label="fit (Normal Eq.)")
    if show_true:
        ax1.plot(x_line, y_true, linewidth=2, linestyle="--", color="r", label="true")
    ax1.set_xlabel("Fe concentration (%)")
    ax1.set_ylabel("Density (g/cm³)")
    ax1.set_title("Data and fits")
    ax1.legend()

    ax2 = fig.add_subplot(1, 3, 2)
    # Misfit history, with the normal-equation misfit as a reference line.
    ax2.plot(state.phi_history, color="limegreen", label=r"$\phi_d(m)$")
    ax2.axhline(phi_ne, linestyle="--", color="tab:blue", label=r"$\phi_d(m_{NE})$")
    ax2.set_xlabel("Iteration")
    ax2.set_ylabel(r"$\phi_d$")
    ax2.set_title("Data misfit vs iteration")
    ax2.legend()

    ax3 = fig.add_subplot(1, 3, 3)
    ax3.plot(state.rms_history, color="limegreen", label="RMS(m)")
    ax3.axhline(rms_ne, linestyle="--", color="tab:blue", label="RMS(m_NE)")
    ax3.axhline(1.0, linestyle=":", color="r", label="target RMS ≈ 1")
    ax3.set_xlabel("Iteration")
    ax3.set_ylabel("RMS misfit")
    ax3.set_title("RMS misfit vs iteration")
    ax3.legend()

    plt.show()
    plt.close(fig)


def plot_misfit_surface(
    state: RegressionState,
    span_factor: float = 2.0,
    ngrid: int = 140,
    log_surface: bool = False,
) -> None:
    """
    Plot a contour (or image fallback) of the data misfit surface phi_d(a,b),
    overlaid with:
      - solver path in (a,b),
      - normal-equation solution,
      - true model.

    The function closes the figure after display to prevent duplicate/empty figures
    in widget output contexts.

    Parameters
    ----------
    state : RegressionState
        Current inversion state.
    span_factor : float
        Window around (a_NE, b_NE) sized by span_factor*max(|param|,1).
    ngrid : int
        Grid points per axis.
    log_surface : bool
        If True, plot log10(phi_d + eps).
    """
    if state.G.size == 0:
        return

    a_ne = float(state.m_ne[0, 0])
    b_ne = float(state.m_ne[1, 0])

    a_span = float(span_factor) * max(abs(a_ne), 1.0)
    b_span = float(span_factor) * max(abs(b_ne), 1.0)

    a_vals = np.linspace(a_ne - a_span, a_ne + a_span, int(ngrid))
    b_vals = np.linspace(b_ne - b_span, b_ne + b_span, int(ngrid))

    Phi = misfit_surface_grid(state.G, state.d, state.Wd, a_vals, b_vals)

    # sanitize
    finite = np.isfinite(Phi)
    fill = float(np.nanmax(Phi[finite])) if np.any(finite) else 0.0
    Phi = np.nan_to_num(Phi, nan=fill, posinf=fill, neginf=0.0)

    if log_surface:
        eps = 1e-18
        Z = np.log10(Phi + eps)
        title = r"$\log_{10}(\phi_d(a,b))$ with path"
    else:
        Z = Phi
        title = r"$\phi_d(a,b)$ with path"

    zmin = float(np.nanmin(Z))
    zmax = float(np.nanmax(Z))

    path = np.hstack(state.m_history) if state.m_history else np.array([[a_ne], [b_ne]])
    a_path, b_path = path[0, :], path[1, :]

    a_true = float(state.m_true[0, 0])
    b_true = float(state.m_true[1, 0])

    fig = plt.figure(figsize=(7.2, 6.2))
    ax = fig.add_subplot(1, 1, 1)

    if not np.isfinite(zmin) or not np.isfinite(zmax) or abs(zmax - zmin) < 1e-12:
        im = ax.imshow(
            Z,
            origin="lower",
            aspect="auto",
            extent=[a_vals.min(), a_vals.max(), b_vals.min(), b_vals.max()],
        )
        ax.set_title(title + " (imshow fallback)")
        fig.colorbar(im, ax=ax, shrink=0.85)
    else:
        levels = np.linspace(zmin, zmax, 20)
        ax.contour(a_vals, b_vals, Z, levels=levels)
        ax.set_title(title)

    ax.plot(a_path, b_path, marker="o", markersize=3, linewidth=1, color="k", label="path")
    ax.scatter([a_ne], [b_ne], marker="x", s=80, label="Normal Eq.")
    ax.scatter([a_true], [b_true], marker="*", s=120, label="True")

    ax.set_xlabel("a (slope)")
    ax.set_ylabel("b (intercept)")
    ax.legend()

    plt.tight_layout()
    plt.show()
    plt.close(fig)


def state_summary_text(state: RegressionState, solver_name: str) -> str:
    """
    Build a human-readable summary of the current inversion state.

    This summary is intended for interactive notebook use (e.g., ipywidgets callbacks).
    It reports the current parameter estimate, reference solutions (normal-equation and
    true model), and key diagnostics consistent with geophysical inversion practice.

    Definitions
    -----------
    Data misfit (not normalized by N):
        phi_d(m) = ||Wd (G m - d)||_2^2

    RMS misfit:
        RMS(m) = sqrt( phi_d(m) / N )

    Interpretation:
    - If Wd = C_d^{-1/2} and the noise model is correct, RMS ~ 1 at optimal fit.

    Parameters
    ----------
    state : RegressionState
        State container holding data, weighting, current model, reference solutions,
        and iteration histories.
    solver_name : str
        Human-readable name of the active solver (e.g., "Gradient Descent" or "Newton").

    Returns
    -------
    text : str
        Multi-line formatted string for printing.
    """
    phi_cur = state.phi_history[-1] if state.phi_history else np.nan
    rms_cur = state.rms_history[-1] if state.rms_history else np.nan

    phi_ne = data_misfit_phi_d(state.G, state.d, state.m_ne, state.Wd)
    rms_ne = rms_misfit(state.G, state.d, state.m_ne, state.Wd)

    gnorm = state.gradnorm_history[-1] if state.gradnorm_history else np.nan
    snorm = state.stepnorm_history[-1] if state.stepnorm_history else np.nan

    a_true = float(state.m_true[0, 0])
    b_true = float(state.m_true[1, 0])

    a_ne = float(state.m_ne[0, 0])
    b_ne = float(state.m_ne[1, 0])

    err_to_ne = float(np.linalg.norm(state.m - state.m_ne))

    weighting_label = "weighted (Wd = C_d^{-1/2})" if state.Wd is not None else "unweighted (Wd = I)"

    return (
        f"Solver: {solver_name}\n"
        f"Weighting: {weighting_label}\n"
        f"Iteration: {state.it}\n"
        f"Current        m = [a, b]^T = [{state.a:.6f}, {state.b:.6f}]^T\n"
        f"Normal Eq. sol m = [a, b]^T = [{a_ne:.6f}, {b_ne:.6f}]^T\n"
        f"True           m = [a, b]^T = [{a_true:.6f}, {b_true:.6f}]^T\n"
        f"||m - m_NE||:  {err_to_ne:.6e}\n"
        f"phi_d(m):      {phi_cur:.6e}     RMS(m):      {rms_cur:.6e}\n"
        f"phi_d(m_NE):   {phi_ne:.6e}     RMS(m_NE):   {rms_ne:.6e}\n"
        f"||grad||:      {gnorm:.6e}  (GD: ||∇J||, J = phi_d/2; Newton: ||∇phi_d||)\n"
        f"Step metric:   {snorm:.6e}  (GD: step size s; Newton: ||Δm||)"
    )

## Interactive widgets (iterations step-by-step)

In [3]:
# =============================================================================
# 6) Interactive UI (widgets) - robust, single-instance + per-instance callbacks
# =============================================================================

import traceback
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import ipywidgets as widgets

# -----------------------------------------------------------------------------
# Single-instance management
# -----------------------------------------------------------------------------
# We keep ONE persistent app_out and close prior widget instances on re-run.
if "_FE_RHO_APP" in globals():
    try:
        # Reuse the same container output so the UI doesn't duplicate in the notebook
        _prev_out = _FE_RHO_APP.get("app_out", None)
        if isinstance(_prev_out, widgets.Output):
            app_out = _prev_out
        # Close prior widgets to stop callbacks and accidental cross-wiring
        for w in _FE_RHO_APP.get("widgets", []):
            try:
                w.close()
            except Exception:
                pass
    except Exception:
        pass

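# Create the container only if no valid one exists, but display it on EVERY run:
# Jupyter clears a cell's output on re-run, so a reused container that is not
# re-displayed leaves the UI invisible.
if "app_out" not in globals() or not isinstance(app_out, widgets.Output):
    app_out = widgets.Output()
display(app_out)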

# Close stray matplotlib figures to avoid blank/extra renders
plt.close("all")

# -----------------------------------------------------------------------------
# Build UI inside the persistent container (all callbacks close over THIS instance)
# -----------------------------------------------------------------------------
with app_out:
    clear_output(wait=True)

    # Local re-entrancy lock (per UI instance)
    _refresh_lock = {"busy": False}

    # Outputs for this instance
    out_text = widgets.Output()
    out_plot = widgets.Output()

    # ------------------------
    # Controls (this instance)
    # ------------------------
    # Data generation controls
    n_w = widgets.IntSlider(value=60, min=10, max=300, step=5, description="N")
    noise_w = widgets.FloatSlider(value=0.4, min=1e-6, max=1000.0, step=0.1, description="noise σ")
    seed_w = widgets.IntSlider(value=42, min=0, max=9999, step=1, description="seed")
    a_true_w = widgets.FloatSlider(value=0.6, min=-3.0, max=100.0, step=0.05, description="a_true")
    b_true_w = widgets.FloatSlider(value=2.5, min=-10.0, max=10.0, step=0.1, description="b_true")

    # Weighting toggle
    use_weighting_w = widgets.Checkbox(value=True, description="use Wd (target RMS≈1)")

    # Initialization
    a0_w = widgets.FloatText(value=0.0, description="a0")
    b0_w = widgets.FloatText(value=0.0, description="b0")

    # Solver selection (values are exactly what RegressionState.step expects)
    optimizer_w = widgets.Dropdown(
        options=[("Gradient descent", "gd"), ("Newton", "newton")],
        value="gd",
        description="solver",
    )

    # GD control: α in (0,1] (dimensionless)
    alpha_gd_w = widgets.FloatSlider(value=0.8, min=0.01, max=1.0, step=0.01, description="α (GD)")

    # Newton controls
    newton_damp_w = widgets.FloatLogSlider(value=1e-12, base=10, min=-16, max=2, step=0.25, description="λ (damp)")
    newton_alpha_w = widgets.FloatSlider(value=1.0, min=0.05, max=1.0, step=0.05, description="α (Newton)")

    # Convergence + surface
    tol_w = widgets.FloatLogSlider(value=1e-6, base=10, min=-12, max=-2, step=0.1, description="tol (||∇||)")
    maxit_w = widgets.IntSlider(value=5000, min=10, max=10000, step=10, description="maxit")
    span_w = widgets.FloatSlider(value=2.0, min=0.25, max=6.0, step=0.25, description="surface span")
    logsurf_w = widgets.Checkbox(value=False, description="log10 surface")
    show_true_w = widgets.Checkbox(value=True, description="show true line")

    # Buttons
    btn_regen = widgets.Button(description="Regenerate data", button_style="primary")
    btn_reset = widgets.Button(description="Reset model", button_style="")
    btn_step = widgets.Button(description="Step (1 iter)", button_style="success")
    btn_10 = widgets.Button(description="Step (10 iters)", button_style="success")
    btn_converge = widgets.Button(description="Run to convergence", button_style="warning")

    # State for this instance
    state = RegressionState()

    # ------------------------
    # Helpers (this instance)
    # ------------------------
    def _solver_label() -> str:
        """Return a human-readable label for the dropdown solver."""
        return "Gradient Descent" if optimizer_w.value == "gd" else "Newton"

    def _refresh_outputs() -> None:
        """
        Refresh text + plots for this UI instance.

        Notes
        -----
        Uses a per-instance lock to prevent duplicate redraws caused by chained widget events.
        """
        if _refresh_lock["busy"]:
            return
        _refresh_lock["busy"] = True
        try:
            with out_text:
                clear_output(wait=True)
                print(state_summary_text(state, _solver_label()))

            with out_plot:
                clear_output(wait=True)
                try:
                    plt.close("all")
                    plot_fits_and_histories(state, show_true=bool(show_true_w.value))
                    plot_misfit_surface(
                        state,
                        span_factor=float(span_w.value),
                        ngrid=140,
                        log_surface=bool(logsurf_w.value),
                    )
                except Exception:
                    print("Plotting failed. Full traceback:\n")
                    print(traceback.format_exc())
        finally:
            _refresh_lock["busy"] = False

    def _regenerate_data(_=None) -> None:
        """Generate a new synthetic dataset and reset the inversion state."""
        fe, d, d_clean, G, m_true, sigma = make_synthetic_data_matrix(
            n=int(n_w.value),
            a_true=float(a_true_w.value),
            b_true=float(b_true_w.value),
            noise_std=float(noise_w.value),
            seed=int(seed_w.value),
        )

        state.fe = fe
        state.G = G
        state.d = d
        state.d_clean = d_clean
        state.m_true = m_true
        state.sigma = sigma

        state.Wd = build_Wd_from_sigma(sigma) if bool(use_weighting_w.value) else None
        state.update_normal_equations_solution()
        state.reset_model(a0=float(a0_w.value), b0=float(b0_w.value))

        _refresh_outputs()

    def _reset_model(_=None) -> None:
        """Reset the current model estimate and histories to the initial guess."""
        state.reset_model(a0=float(a0_w.value), b0=float(b0_w.value))
        _refresh_outputs()

    def _step_k(k: int) -> None:
        """Advance the selected solver by k iterations (bounded by maxit)."""
        method = optimizer_w.value  # "gd" or "newton"
        alpha_gd = float(alpha_gd_w.value)
        damp = float(newton_damp_w.value)
        alphaN = float(newton_alpha_w.value)
        maxit = int(maxit_w.value)

        for _ in range(int(k)):
            if state.it >= maxit:
                break
            state.step(method=method, lr=alpha_gd, damping=damp, step_scale=alphaN)

    def _on_step(_=None) -> None:
        _step_k(1)
        _refresh_outputs()

    def _on_step10(_=None) -> None:
        _step_k(10)
        _refresh_outputs()

    def _on_converge(_=None) -> None:
        """
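        Run gradient descent until ||∇phi_d|| < tol or maxit is reached;
        for Newton, take a single step.

        Notes
        -----
        phi_d is quadratic in m, so its Hessian is constant and a single
        undamped (λ ≈ 0), full (α = 1) Newton step reaches the minimizer up
        to rounding. With λ > 0 or α < 1, press the button again for more steps.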
        """
        method = optimizer_w.value
        alpha_gd = float(alpha_gd_w.value)
        damp = float(newton_damp_w.value)
        alphaN = float(newton_alpha_w.value)
        tol = float(tol_w.value)
        maxit = int(maxit_w.value)

        if method == "newton":
            if state.it < maxit:
                state.step(method="newton", lr=alpha_gd, damping=damp, step_scale=alphaN)
            _refresh_outputs()
            return

        # Gradient descent: iterate to tolerance
        while state.it < maxit:
            # Take a first step so gradnorm_history has an entry to test
            if state.it == 0 and not state.gradnorm_history:
                state.step(method="gd", lr=alpha_gd, damping=damp, step_scale=alphaN)

            gnorm = state.gradnorm_history[-1] if state.gradnorm_history else np.inf
            if gnorm < tol:
                break

            state.step(method="gd", lr=alpha_gd, damping=damp, step_scale=alphaN)

        _refresh_outputs()

    # ------------------------
    # Wire callbacks (this instance)
    # ------------------------
    btn_regen.on_click(_regenerate_data)
    btn_reset.on_click(_reset_model)
    btn_step.on_click(_on_step)
    btn_10.on_click(_on_step10)
    btn_converge.on_click(_on_converge)

    # Make solver changes immediately visible in the printed header (and keep plots consistent)
    optimizer_w.observe(lambda ch: _refresh_outputs(), names="value")

    def _on_weighting_toggle(_):
        """Update Wd and recompute m_NE without regenerating data."""
        if getattr(state, "sigma", np.array([])).size == 0:
            return
        state.Wd = build_Wd_from_sigma(state.sigma) if bool(use_weighting_w.value) else None
        state.update_normal_equations_solution()
        _refresh_outputs()

    use_weighting_w.observe(_on_weighting_toggle, names="value")

    # ------------------------
    # Layout
    # ------------------------
    controls_left = widgets.VBox([
        widgets.HTML("<b>Data generation</b>"),
        n_w, noise_w, seed_w, a_true_w, b_true_w,
        use_weighting_w,
        btn_regen,
    ])

    controls_mid = widgets.VBox([
        widgets.HTML("<b>Initialization</b>"),
        a0_w, b0_w,
        btn_reset,
        widgets.HTML("<b>Optimization</b>"),
        optimizer_w,
        alpha_gd_w,
        newton_damp_w,
        newton_alpha_w,
        tol_w, maxit_w,
        widgets.HTML("<b>Misfit surface</b>"),
        span_w, logsurf_w, show_true_w,
    ])

    controls_right = widgets.VBox([
        widgets.HTML("<b>Iterate</b>"),
        btn_step,
        btn_10,
        btn_converge,
    ])

    ui = widgets.HBox([controls_left, controls_mid, controls_right])
    display(ui, out_text, out_plot)

    # Save references: the next run closes these widgets and reuses the container
    _FE_RHO_APP = {
        "app_out": app_out,
        "widgets": [
            out_text, out_plot,
            n_w, noise_w, seed_w, a_true_w, b_true_w, use_weighting_w,
            a0_w, b0_w,
            optimizer_w, alpha_gd_w, newton_damp_w, newton_alpha_w,
            tol_w, maxit_w, span_w, logsurf_w, show_true_w,
            btn_regen, btn_reset, btn_step, btn_10, btn_converge,
            ui, controls_left, controls_mid, controls_right
        ]
    }

# Initialize once
_regenerate_data()

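`_refresh_outputs` guards itself with the per-instance flag `_refresh_lock = {"busy": False}`. A redraw that triggers another widget event therefore does not start a second, nested redraw.

The pattern does not depend on ipywidgets. Below is a minimal plain-Python sketch in which `RefreshGuard` and the toy event chain are hypothetical stand-ins for the widget callbacks.

```python
class RefreshGuard:
    """Flag-based re-entrancy guard (hypothetical): nested calls made while busy are dropped."""

    def __init__(self) -> None:
        self.busy = False

    def run(self, fn) -> bool:
        """Call fn() unless a guarded call is already in progress; return whether it ran."""
        if self.busy:
            return False              # nested refresh: skip it instead of recursing
        self.busy = True
        try:
            fn()
        finally:
            self.busy = False         # release even if fn() raises
        return True


# Toy event chain: each refresh fires a "value changed" event that requests another refresh.
guard = RefreshGuard()
calls = []


def refresh() -> None:
    calls.append("refresh")
    on_value_change()                 # stands in for a widget observer firing mid-redraw


def on_value_change() -> None:
    guard.run(refresh)


ran = guard.run(refresh)
print(ran, calls)                     # True ['refresh']
```

Like the notebook's dict flag, this is a plain boolean rather than a `threading.Lock`, aimed at synchronous callback chains. The `finally` clause mirrors how the notebook releases the flag even when plotting raises.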

In [4]:
# =============================================================================
# 6) Interactive UI (widgets) - robust, single-instance + per-instance callbacks
# =============================================================================

import traceback
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import ipywidgets as widgets

# -----------------------------------------------------------------------------
# Single-instance management
# -----------------------------------------------------------------------------
# We keep ONE persistent app_out and close prior widget instances on re-run.

# Try to reuse an existing app_out if present
_prev_app_out = None
if "_FE_RHO_APP" in globals():
    try:
        _prev_app_out = _FE_RHO_APP.get("app_out", None)
    except Exception:
        _prev_app_out = None

# Reuse if it's a valid Output widget, else create a new one
if isinstance(_prev_app_out, widgets.Output):
    app_out = _prev_app_out
else:
    app_out = widgets.Output()

# IMPORTANT: Always display app_out on every run.
# Jupyter replaces the cell output on re-run; if you don't re-display, you see nothing.
display(app_out)

# Close prior widgets (but NEVER close the persistent container app_out)
if "_FE_RHO_APP" in globals():
    try:
        for w in _FE_RHO_APP.get("widgets", []):
            try:
                # Skip closing the persistent container if it ever appears here
                if w is app_out:
                    continue
                w.close()
            except Exception:
                pass
    except Exception:
        pass

# Close stray matplotlib figures to avoid blank/extra renders
plt.close("all")

# -----------------------------------------------------------------------------
# Build UI inside the persistent container (all callbacks close over THIS instance)
# -----------------------------------------------------------------------------
with app_out:
    clear_output(wait=True)

    # Local re-entrancy lock (per UI instance)
    _refresh_lock = {"busy": False}

    # Outputs for this instance
    out_text = widgets.Output()
    out_plot = widgets.Output()

    # ------------------------
    # Controls (this instance)
    # ------------------------
    # Data generation controls
    n_w = widgets.IntSlider(value=60, min=10, max=300, step=5, description="N")
    noise_w = widgets.FloatSlider(value=0.4, min=1e-6, max=1000.0, step=0.1, description="noise σ")
    seed_w = widgets.IntSlider(value=42, min=0, max=9999, step=1, description="seed")
    a_true_w = widgets.FloatSlider(value=0.6, min=-3.0, max=100.0, step=0.05, description="a_true")
    b_true_w = widgets.FloatSlider(value=2.5, min=-10.0, max=10.0, step=0.1, description="b_true")

    # Weighting toggle
    use_weighting_w = widgets.Checkbox(value=True, description="use Wd (target RMS≈1)")

    # Initialization
    a0_w = widgets.FloatText(value=0.0, description="a0")
    b0_w = widgets.FloatText(value=0.0, description="b0")

    # Solver selection (values are exactly what RegressionState.step expects)
    optimizer_w = widgets.Dropdown(
        options=[("Gradient descent", "gd"), ("Newton", "newton")],
        value="gd",
        description="solver",
    )

    # GD control: α in (0,1] (dimensionless)
    alpha_gd_w = widgets.FloatSlider(value=0.8, min=0.01, max=1.0, step=0.01, description="α (GD)")

    # Newton controls
    newton_damp_w = widgets.FloatLogSlider(value=1e-12, base=10, min=-16, max=2, step=0.25, description="λ (damp)")
    newton_alpha_w = widgets.FloatSlider(value=1.0, min=0.05, max=1.0, step=0.05, description="α (Newton)")

    # Convergence + surface
    tol_w = widgets.FloatLogSlider(value=1e-6, base=10, min=-12, max=-2, step=0.1, description="tol (||∇||)")
    maxit_w = widgets.IntSlider(value=5000, min=10, max=10000, step=10, description="maxit")
    span_w = widgets.FloatSlider(value=2.0, min=0.25, max=6.0, step=0.25, description="surface span")
    logsurf_w = widgets.Checkbox(value=False, description="log10 surface")
    show_true_w = widgets.Checkbox(value=True, description="show true line")

    # Buttons
    btn_regen = widgets.Button(description="Regenerate data", button_style="primary")
    btn_reset = widgets.Button(description="Reset model", button_style="")
    btn_step = widgets.Button(description="Step (1 iter)", button_style="success")
    btn_10 = widgets.Button(description="Step (10 iters)", button_style="success")
    btn_converge = widgets.Button(description="Run to convergence", button_style="warning")

    # State for this instance
    state = RegressionState()

    # ------------------------
    # Helpers (this instance)
    # ------------------------
    def _solver_label() -> str:
        """Return a human-readable label for the dropdown solver."""
        return "Gradient Descent" if optimizer_w.value == "gd" else "Newton"

    def _refresh_outputs() -> None:
        """
        Refresh text + plots for this UI instance.

        Notes
        -----
        Uses a per-instance lock to prevent duplicate redraws caused by chained widget events.
        """
        if _refresh_lock["busy"]:
            return
        _refresh_lock["busy"] = True
        try:
            with out_text:
                clear_output(wait=True)
                print(state_summary_text(state, _solver_label()))

            with out_plot:
                clear_output(wait=True)
                try:
                    plt.close("all")
                    plot_fits_and_histories(state, show_true=bool(show_true_w.value))
                    plot_misfit_surface(
                        state,
                        span_factor=float(span_w.value),
                        ngrid=140,
                        log_surface=bool(logsurf_w.value),
                    )
                except Exception:
                    print("Plotting failed. Full traceback:\n")
                    print(traceback.format_exc())
        finally:
            _refresh_lock["busy"] = False

    def _regenerate_data(_=None) -> None:
        """Generate a new synthetic dataset and reset the inversion state."""
        fe, d, d_clean, G, m_true, sigma = make_synthetic_data_matrix(
            n=int(n_w.value),
            a_true=float(a_true_w.value),
            b_true=float(b_true_w.value),
            noise_std=float(noise_w.value),
            seed=int(seed_w.value),
        )

        state.fe = fe
        state.G = G
        state.d = d
        state.d_clean = d_clean
        state.m_true = m_true
        state.sigma = sigma

        state.Wd = build_Wd_from_sigma(sigma) if bool(use_weighting_w.value) else None
        state.update_normal_equations_solution()
        state.reset_model(a0=float(a0_w.value), b0=float(b0_w.value))

        _refresh_outputs()

    def _reset_model(_=None) -> None:
        """Reset the current model estimate and histories to the initial guess."""
        state.reset_model(a0=float(a0_w.value), b0=float(b0_w.value))
        _refresh_outputs()

    def _step_k(k: int) -> None:
        """Advance the selected solver by k iterations (bounded by maxit)."""
        method = optimizer_w.value  # "gd" or "newton"
        alpha_gd = float(alpha_gd_w.value)
        damp = float(newton_damp_w.value)
        alphaN = float(newton_alpha_w.value)
        maxit = int(maxit_w.value)

        for _ in range(int(k)):
            if state.it >= maxit:
                break
            state.step(method=method, lr=alpha_gd, damping=damp, step_scale=alphaN)

    def _on_step(_=None) -> None:
        _step_k(1)
        _refresh_outputs()

    def _on_step10(_=None) -> None:
        _step_k(10)
        _refresh_outputs()

    def _on_converge(_=None) -> None:
        """
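        Run gradient descent until ||∇phi_d|| < tol or maxit is reached;
        for Newton, take a single step.

        Notes
        -----
        phi_d is quadratic in m, so its Hessian is constant and a single
        undamped (λ ≈ 0), full (α = 1) Newton step reaches the minimizer up
        to rounding. With λ > 0 or α < 1, press the button again for more steps.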
        """
        method = optimizer_w.value
        alpha_gd = float(alpha_gd_w.value)
        damp = float(newton_damp_w.value)
        alphaN = float(newton_alpha_w.value)
        tol = float(tol_w.value)
        maxit = int(maxit_w.value)

        if method == "newton":
            if state.it < maxit:
                state.step(method="newton", lr=alpha_gd, damping=damp, step_scale=alphaN)
            _refresh_outputs()
            return

        # Gradient descent: iterate to tolerance
        while state.it < maxit:
            # Take a first step so gradnorm_history has an entry to test
            if state.it == 0 and not state.gradnorm_history:
                state.step(method="gd", lr=alpha_gd, damping=damp, step_scale=alphaN)

            gnorm = state.gradnorm_history[-1] if state.gradnorm_history else np.inf
            if gnorm < tol:
                break

            state.step(method="gd", lr=alpha_gd, damping=damp, step_scale=alphaN)

        _refresh_outputs()

    # ------------------------
    # Wire callbacks (this instance)
    # ------------------------
    btn_regen.on_click(_regenerate_data)
    btn_reset.on_click(_reset_model)
    btn_step.on_click(_on_step)
    btn_10.on_click(_on_step10)
    btn_converge.on_click(_on_converge)

    # Make solver changes immediately visible in the printed header (and keep plots consistent)
    optimizer_w.observe(lambda ch: _refresh_outputs(), names="value")

    def _on_weighting_toggle(_):
        """Update Wd and recompute m_NE without regenerating data."""
        if getattr(state, "sigma", np.array([])).size == 0:
            return
        state.Wd = build_Wd_from_sigma(state.sigma) if bool(use_weighting_w.value) else None
        state.update_normal_equations_solution()
        _refresh_outputs()

    use_weighting_w.observe(_on_weighting_toggle, names="value")

    # ------------------------
    # Layout
    # ------------------------
    controls_left = widgets.VBox([
        widgets.HTML("<b>Data generation</b>"),
        n_w, noise_w, seed_w, a_true_w, b_true_w,
        use_weighting_w,
        btn_regen,
    ])

    controls_mid = widgets.VBox([
        widgets.HTML("<b>Initialization</b>"),
        a0_w, b0_w,
        btn_reset,
        widgets.HTML("<b>Optimization</b>"),
        optimizer_w,
        alpha_gd_w,
        newton_damp_w,
        newton_alpha_w,
        tol_w, maxit_w,
        widgets.HTML("<b>Misfit surface</b>"),
        span_w, logsurf_w, show_true_w,
    ])

    controls_right = widgets.VBox([
        widgets.HTML("<b>Iterate</b>"),
        btn_step,
        btn_10,
        btn_converge,
    ])

    ui = widgets.HBox([controls_left, controls_mid, controls_right])
    display(ui, out_text, out_plot)

    # Save references: the next run closes these widgets and reuses the container
    _FE_RHO_APP = {
        "app_out": app_out,
        "widgets": [
            # DO NOT include app_out here; we want to reuse it forever
            out_text, out_plot,
            n_w, noise_w, seed_w, a_true_w, b_true_w, use_weighting_w,
            a0_w, b0_w,
            optimizer_w, alpha_gd_w, newton_damp_w, newton_alpha_w,
            tol_w, maxit_w, span_w, logsurf_w, show_true_w,
            btn_regen, btn_reset, btn_step, btn_10, btn_converge,
            ui, controls_left, controls_mid, controls_right
        ]
    }

    # Initialize once (inside the container, so the output updates immediately)
    _regenerate_data()


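`_on_weighting_toggle` swaps $\mathbf{W}_d$ between $\mathrm{diag}(1/\sigma_i)$ and $\mathbf{I}$ and recomputes $\mathbf{m}_{NE}$ without regenerating data. The short NumPy sketch below shows what the toggle should do, assuming independent errors.

- **Uniform σ:** a single "noise σ" slider suggests this case, though we have not checked `make_synthetic_data_matrix`. Here $\mathbf{W}_d = \mathbf{I}/\sigma$ scales $\phi_d$ by $1/\sigma^2$ and RMS by $1/\sigma$ but leaves the minimizer unchanged. Only the RMS ≈ 1 interpretation changes.
- **Non-uniform $\sigma_i$:** the estimate itself changes.

The helper `fit` and the heteroscedastic noise model are hypothetical.

```python
import numpy as np


def fit(G, d, sigma=None):
    """Least-squares m_NE and RMS with Wd = diag(1/sigma), or Wd = I if sigma is None."""
    w = np.ones((d.shape[0], 1)) if sigma is None else 1.0 / sigma.reshape(-1, 1)
    A, b = w * G, w * d                                # Wd G and Wd d
    m, *_ = np.linalg.lstsq(A, b, rcond=None)
    rms = float(np.sqrt(np.mean((A @ m - b) ** 2)))    # sqrt(phi_d / N)
    return m, rms


rng = np.random.default_rng(1)
N = 60
fe = rng.uniform(0.0, 10.0, size=(N, 1))               # ASSUMPTION: arbitrary Fe range
G = np.hstack([fe, np.ones((N, 1))])
d_clean = G @ np.array([[0.6], [2.5]])
noise = rng.standard_normal((N, 1))

# Case 1: uniform sigma. Wd = I / sigma rescales phi_d but not its minimizer.
sig_u = np.full(N, 0.4)
d_u = d_clean + sig_u.reshape(-1, 1) * noise
m_unw, rms_unw = fit(G, d_u)
m_w, rms_w = fit(G, d_u, sig_u)

# Case 2 (hypothetical): noisier data at high Fe. Weighting now changes m_NE.
sig_h = np.where(fe[:, 0] > 5.0, 2.0, 0.1)
d_h = d_clean + sig_h.reshape(-1, 1) * noise
m_unw_h, _ = fit(G, d_h)
m_w_h, rms_w_h = fit(G, d_h, sig_h)
```

In the uniform case, the unweighted RMS sits near σ rather than 1. The weighted RMS is `rms_unw / sigma`, which is why the checkbox label targets RMS ≈ 1 only when $\mathbf{W}_d$ is on.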