# Conditioning the Linear System Prior on Observations of the PDE Solution

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn
import scipy.sparse

import linpde_gp

In [None]:
%matplotlib inline

from matplotlib_inline.backend_inline import set_matplotlib_formats
# Emit inline figures in the PDF and SVG (vector) formats instead of the default raster output
set_matplotlib_formats("pdf", "svg")

In the following, we will again look at the Poisson equation on a 1D domain with Dirichlet boundary conditions.

In [None]:
# 1D Poisson boundary value problem on [-1, 1] with constant right-hand side
# 2.0 and zero Dirichlet boundary values at both ends.
domain = (-1.0, 1.0)
bvp = linpde_gp.problems.pde.poisson_1d_bvp(
    domain=domain, rhs=2.0, boundary_values=(0.0, 0.0)
)

In [None]:
# Zero-boundary finite element basis with 103 elements on `domain`, and the
# linear system produced by Galerkin-discretizing `bvp` in that basis.
basis = linpde_gp.galerkin.bases.ZeroBoundaryFiniteElementBasis(
    domain,
    num_elements=103,
)
linsys = linpde_gp.galerkin.discretize(bvp, basis)

In [None]:
# Incomplete (pivoted) Cholesky factor of the dense system matrix, with the
# second argument 20 presumably the number of pivots/rank (TODO confirm), and
# the low-rank matrix built from it. Show its pseudo-inverse.
L_inc = linpde_gp.linalg.pivoted_cholesky(linsys.A.todense(), 20)
P = linpde_gp.linops.LowRankMatrix(L_inc)
plt.imshow(P.pinv.todense(), cmap="bwr")

In [None]:
# Incomplete (pivoted) Cholesky factor of the dense system matrix.
L_inc = linpde_gp.linalg.pivoted_cholesky(linsys.A.todense(), 20)

# Identity (scaled by 1.0) plus a low-rank update built from `L_inc`
# (presumably 1.0 * I + L_inc @ L_inc.T -- confirm with LowRankUpdate).
# The previous discarded `scipy.linalg.qr(L_inc, mode="economic")` call had no
# effect other than wasted work and relied on the un-imported `scipy.linalg`.
P = linpde_gp.linops.LowRankUpdate(
    1.0 * pn.linops.Identity(linsys.A.shape[0]),
    L_inc,
)

# Visualize the inverse on a symmetric color scale.
plt.imshow(P.inv().todense(), cmap="bwr", vmin=-1, vmax=1)

If (noisy) measurements of the PDE solution are available, we can use this information to speed up inference.

Let $(v_i)_{i = 1}^n$ be the chosen basis.
In our formulation, we posit a multivariate Gaussian prior over the coefficients $\vec{a} \in \mathbb{R}^n$ of the discretized solution $\hat{u} = \sum_{i = 1}^n a_i v_i$ to the PDE, i.e. $\vec{a} \sim \mathcal{N}(\mu_0, \Sigma_0)$.
We can relate the discretized solution $\hat{u}$ to the coefficients by a linear operator $$(\mathcal{L}_u \vec{a})(x) = \sum_{i = 1}^n a_i v_i(x).$$
Moreover, the solution can be evaluated at several locations $x_1, \dotsc, x_m \in \Omega$ by another linear operator $$(\mathcal{L}_\delta u)_j = \int_\Omega \delta(\chi - x_j) u(\chi) d \chi = u(x_j).$$
All in all, we obtain the following linear operator which maps $\vec{a}$ to a vector of measurements at $x_1, \dotsc, x_m \in \Omega$: $$(L_y \vec{a})_j = (\mathcal{L}_\delta \mathcal{L}_u \vec{a})_j = \int_\Omega \delta(\chi - x_j) (\mathcal{L}_u \vec{a})(\chi) d \chi = \sum_{i = 1}^n a_i \int_\Omega \delta(\chi - x_j) v_i(\chi) d\chi = \sum_{i = 1}^n a_i v_i(x_j)$$
If we now assume additive Gaussian measurement noise on independent observations $y_1, \dotsc, y_m$ of the solution at locations $x_1, \dotsc, x_m \in \Omega$, we obtain the following measurement likelihood:
$$p(y_1, \dotsc, y_m \mid u(x_1), \dotsc, u(x_m)) = \mathcal{N}\left(\vec{y} \mid \begin{pmatrix} u(x_1) & \cdots & u(x_m) \end{pmatrix}^T, \Lambda\right),$$
or, equivalently,
$$p(y_1, \dotsc, y_m \mid \vec{a}) = \mathcal{N}(\vec{y} \mid L_y \vec{a}, \Lambda).$$
Since the model is linear-Gaussian, we can compute the posterior in closed form.
Note that this is exactly the supervised regression setting.

In [None]:
rng = np.random.default_rng(seed=42)  # fixed seed for reproducible noise samples

In [None]:
num_measurements: int = 3  # number of noisy point observations of the solution

In [None]:
# Observation locations: `num_measurements` equidistant points strictly inside
# the domain (the two endpoints of the grid are dropped).
meas_xs = np.linspace(domain[0], domain[1], num=num_measurements + 2)[1:-1]
true_ys = bvp.solution(meas_xs)

# Zero-mean Gaussian measurement noise with i.i.d. standard deviation 1e-2
measurement_noise = pn.randvars.Normal(
    mean=np.zeros(num_measurements, dtype=np.double),
    cov=pn.linops.Scaling((1e-2) ** 2, shape=num_measurements, dtype=np.double),
)
meas_ys = true_ys + measurement_noise.sample(rng)

In [None]:
# Realized noise: noisy observations minus the exact solution values
np.subtract(meas_ys, true_ys)

In [None]:
# Gaussian prior over the basis coefficients: zero mean, covariance equal to
# the inverse of the (CSC-converted) system matrix.
# Alternative covariances that were tried:
#   cov=(P @ P.T).inv()
#   cov=pn.linops.aslinop(linsys.A.A.tocsc()).inv() - P.pinv
prior = pn.randvars.Normal(
    mean=np.zeros(shape=len(basis), dtype=np.float64),
    cov=pn.linops.aslinop(linsys.A.A.tocsc()).inv(),
)

In [None]:
# Linear map from basis coefficients to the solution values at `meas_xs`
L_yu = basis.observation_operator(meas_xs)

In [None]:
# Likelihood noise model: zero mean, same covariance as the simulated noise.
# (An inflated alternative that was tried: cov=1e2 * measurement_noise.cov)
noise_model = pn.randvars.Normal(
    mean=np.zeros(shape=num_measurements, dtype=np.float64),
    cov=measurement_noise.cov,
)

In [None]:
# Linear-Gaussian update of the coefficient prior with the point measurements
prior_cond_meas = prior.condition_on_observations(
    transform=L_yu,
    observations=meas_ys,
    noise=noise_model,
)

In [None]:
# Plot the function-space belief induced by the conditioned coefficient prior
# together with the raw measurements.
xs_plot = np.linspace(domain[0], domain[1], num=200)
u_prior_cond_meas = basis.coords2fn(prior_cond_meas)

u_prior_cond_meas.plot(plt.gca(), xs_plot)
plt.scatter(meas_xs, meas_ys, label="Measurements", marker="+")
plt.legend()
plt.show()

## Feed conditioned prior into probabilistic solver

In [None]:
from matplotlib import animation


def animate_probsolve_poisson_1d(
    basis,
    linsys,
    x0,
    solver=linpde_gp.linalg.solvers.bayescg,
    **solver_kwargs
):
    """Run a probabilistic linear solver on ``linsys`` and animate its iterates.

    The solver is called once with a logging callback. Each logged step becomes
    one frame of a 2x3 panel figure showing the current solution estimate, the
    residual 2-, A- and (A Sigma_0 A^T)-norms, and the normalized A- and
    (A Sigma_0 A^T)-inner products between the solver's actions.

    Parameters
    ----------
    basis
        Galerkin basis; ``basis.coords2fn`` turns coefficient beliefs into
        plottable functions and ``len(basis)`` is shown in the title.
    linsys
        Discretized linear system exposing ``A`` and ``b``.
    x0
        Initial belief over the coefficients, passed to the solver as ``x0``;
        ``x0.cov`` defines the ``A Sigma_0 A^T`` norm and inner product.
    solver
        Solver callable, called as ``solver(A, b, x0=..., callback=..., **solver_kwargs)``.
    **solver_kwargs
        Forwarded unchanged to ``solver``.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Animation with one frame per logged solver step.

    Notes
    -----
    Reads the notebook-level globals ``domain`` (plotting range) and ``bvp``
    (reference solution ``bvp.solution``).
    """
    n = len(basis)

    # Run the solver once and log the per-step statistics via the callback
    step_xs = []
    step_residuals = []
    step_actions = []

    def _callback(x: pn.randvars.Normal, residual: np.ndarray, action: np.ndarray, **kwargs):
        step_xs.append(x)
        step_residuals.append(residual.copy())

        # Steps reported without an action are not recorded as actions
        if action is not None:
            step_actions.append(action.copy())

    solver(
        linsys.A,
        linsys.b,
        x0=x0,
        callback=_callback,
        **solver_kwargs,
    )

    ASigmaA = linsys.A @ x0.cov @ linsys.A.T

    # Column k of `residuals` is the residual logged at step k
    residuals = np.vstack(step_residuals).T
    residual_2_norms = np.linalg.norm(residuals, ord=2, axis=0)
    residual_A_norms = np.sqrt(np.sum(residuals * (linsys.A @ residuals), axis=0))
    residual_ASigmaA_norms = np.sqrt(np.sum(residuals * (ASigmaA @ residuals), axis=0))

    action_A_inprods = linpde_gp.linalg.pairwise_inprods(
        step_actions,
        inprod=linsys.A,
        normalize=True,
    )
    action_ASigmaA_inprods = linpde_gp.linalg.pairwise_inprods(
        step_actions,
        inprod=ASigmaA,
        normalize=True,
    )

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(22, 12))

    # Close the figure explicitly, presumably so the inline backend does not
    # render a static copy; FuncAnimation still draws into `fig`.
    plt.close(fig)

    xs_plot = np.linspace(*domain, 200)

    # ZeroBoundaryFiniteElementBasis is checked first so it is not reported as
    # plain "FEM" (presumably it is a FiniteElementBasis subclass).
    if isinstance(basis, linpde_gp.galerkin.bases.ZeroBoundaryFiniteElementBasis):
        basis_str = "Zero Boundary FEM"
    elif isinstance(basis, linpde_gp.galerkin.bases.FiniteElementBasis):
        basis_str = "FEM"
    else:
        basis_str = "Unknown Basis"

    if solver is linpde_gp.linalg.solvers.bayescg:
        solver_str = "BayesCG"
    elif solver is linpde_gp.linalg.solvers.problinsolve:
        solver_str = "problinsolve"
    else:
        solver_str = "Unknown Solver"

    def animate(step_idx):
        """Redraw every panel for solver step ``step_idx``."""
        for panel in ax.flat:
            panel.cla()

        u = basis.coords2fn(coords=step_xs[step_idx])

        fig.suptitle(f"1D Poisson - {basis_str} (N = {n}) - {solver_str} - Iteration {step_idx:03d}")

        ax[0, 0].set_title("Solution")
        ax[0, 0].plot(xs_plot, bvp.solution(xs_plot), label="Exact Solution")
        u.plot(ax[0, 0], xs_plot, color="C1", label="FEM Solution")
        ax[0, 0].legend()

        ax[0, 1].set_title("Residual 2-norm")
        ax[0, 1].semilogy(residual_2_norms[:step_idx + 1], "C0")

        ax[0, 2].set_title("Residual $A$-norm")
        ax[0, 2].semilogy(residual_A_norms[:step_idx + 1], "C0")

        # Raw strings: "\S" is an invalid escape sequence in a normal literal
        ax[1, 0].set_title(r"Residual $A \Sigma_0 A^T$-norm")
        ax[1, 0].semilogy(residual_ASigmaA_norms[:step_idx + 1], "C0")

        ax[1, 1].set_title("Normalized $A$-inner products between the actions")

        if step_idx > 0:
            ax[1, 1].matshow(
                action_A_inprods[:step_idx, :step_idx],
                cmap="bwr",
                vmin=-1.0,
                vmax=1.0,
            )

        ax[1, 2].set_title(r"Normalized $A \Sigma_0 A^T$-inner products between the actions")

        if step_idx > 0:
            ax[1, 2].matshow(
                action_ASigmaA_inprods[:step_idx, :step_idx],
                cmap="bwr",
                vmin=-1.0,
                vmax=1.0,
            )

        # Leave room at the top for the suptitle
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    return animation.FuncAnimation(
        fig,
        func=animate,
        frames=len(step_xs),
        interval=200,
        repeat_delay=4000,
        blit=False,
    )

In [None]:
from IPython.display import HTML

# Animate BayesCG started from the measurement-conditioned prior, allowing at
# most len(basis) iterations and passing reorthogonalize=True to the solver.
# Variants that were experimented with (swap in as needed):
#   x0=pn.randvars.Normal(mean=prior_cond_meas.mean, cov=prior.cov)
#   solver=linpde_gp.linalg.solvers.problinsolve
#   noise_var=1e-6
#   rng=np.random.default_rng(50)
anim = animate_probsolve_poisson_1d(
    basis,
    linsys,
    x0=prior_cond_meas,
    solver=linpde_gp.linalg.solvers.bayescg,
    maxiter=len(basis),
    reorthogonalize=True,
)

HTML(anim.to_jshtml())

In [None]:
# Export the animation as a 5 fps GIF via Pillow (assumes ../results exists)
anim.save(filename="../results/fem_probsolve_data.gif", writer=animation.PillowWriter(fps=5))

**Observations:**
- instabilities
- no monotonic descent
- actions lose $A$-conjugacy for some clusters of steps

Nevertheless, the solver seems to converge to the correct solution.

**Hypotheses:**
- $A \Sigma_0 A^T$ is ill-conditioned or even numerically singular
- as a result, the Gram matrix in the belief update is very small (nearly singular), which leads to unstable updates

## Properties of $\Sigma_0$ and $A \Sigma_0 A^T$

### Spectra

In [None]:
# Eigendecompositions of the prior covariance before and after conditioning
# on the measurements (eigenvalues in ascending order).
(Sigma_eigvals, Sigma_eigvecs), (Sigma_meas_eigvals, Sigma_meas_eigvecs) = (
    np.linalg.eigh(cov.todense()) for cov in (prior.cov, prior_cond_meas.cov)
)

In [None]:
# Display the ascending eigenvalues of the measurement-conditioned prior covariance
Sigma_meas_eigvals

In [None]:
# Spectra of the prior covariance with and without conditioning on measurements.
# LaTeX labels are raw strings: "\l" and "\S" are invalid escape sequences in
# normal literals (SyntaxWarning since Python 3.12).
plt.semilogy(Sigma_eigvals, label="without measurements")
plt.semilogy(Sigma_meas_eigvals, label="with measurements")
plt.xlabel("$i$")
plt.ylabel(r"$\lambda_{i + 1}(\Sigma_0)$")
plt.title(r"Spectrum of $\Sigma_0$")
plt.legend()
plt.show()

In [None]:
# Eigendecompositions (ascending eigenvalues) of A and of A Sigma_0 A^T for
# the unconditioned and the measurement-conditioned prior covariance.
A_eigvals, A_eigvecs = np.linalg.eigh(linsys.A.todense())
(ASigmaA_eigvals, ASigmaA_eigvecs), (ASigmaA_meas_eigvals, ASigmaA_meas_eigvecs) = (
    np.linalg.eigh((linsys.A @ cov @ linsys.A.T).todense())
    for cov in (prior.cov, prior_cond_meas.cov)
)

In [None]:
# Spectra of A and A Sigma_0 A^T (with and without measurements).
# LaTeX labels are raw strings ("\S", "\l", "\c" are invalid escapes otherwise);
# markersize is numeric rather than the string "2".
plt.semilogy(A_eigvals, label="$A$")
plt.semilogy(ASigmaA_eigvals, label=r"$A \Sigma_0 A^T$, without measurements")
plt.semilogy(ASigmaA_meas_eigvals, marker="o", markersize=2, label=r"$A \Sigma_0 A^T$, with measurements")
plt.xlabel("$i$")
plt.ylabel(r"$\lambda_{i + 1}(\cdot)$")
plt.title(r"Spectra of $A$ and $A \Sigma_0 A^T$")
plt.legend()
plt.show()

### $\lVert \cdot \rVert_2$-condition numbers

In [None]:
# 2-norm condition number of symmetric A: max|lambda| / min|lambda| (robust to tiny negative eigenvalues from eigh)
np.abs(A_eigvals).max() / np.abs(A_eigvals).min()

In [None]:
# 2-norm condition number of symmetric A Sigma_0 A^T: max|lambda| / min|lambda| (robust to tiny negative eigenvalues)
np.abs(ASigmaA_eigvals).max() / np.abs(ASigmaA_eigvals).min()

In [None]:
# 2-norm condition number after conditioning: max|lambda| / min|lambda|; abs() guards against eigh returning tiny negative eigenvalues for a numerically singular matrix, which made [-1] / [0] negative
np.abs(ASigmaA_meas_eigvals).max() / np.abs(ASigmaA_meas_eigvals).min()

## Analysis of Instabilities

In [None]:
# Run BayesCG from the measurement-conditioned prior (reorthogonalization left
# disabled) and record every action it reports.
# A shell `curl` upload command (with an Authorization header) had been pasted
# into the middle of this call, which made the cell a SyntaxError; removed.
actions = []

def _callback(action: np.ndarray, **kwargs):
    # Steps reported without an action are skipped
    if action is not None:
        actions.append(action.copy())


_ = linpde_gp.linalg.solvers.bayescg(
    linsys.A,
    linsys.b,
    x0=prior_cond_meas,
    maxiter=len(basis),
    # reorthogonalize=True,
    callback=_callback,
)

In [None]:
# Normalized pairwise A-inner products between the recorded actions
action_A_inprods = linpde_gp.linalg.pairwise_inprods(
    actions,
    normalize=True,
    inprod=linsys.A,
)

In [None]:
# np.linalg.eigh sorts eigenvalues ascending, so these are the eigenvectors of the conditioned A Sigma_0 A^T with the smallest eigenvalues, one per measurement (presumably the directions informed by the measurements -- TODO confirm)
measurement_step_basis = list(ASigmaA_meas_eigvecs[:, :num_measurements].T)

In [None]:
# Normalized A-inner products among the measurement directions, and between
# the measurement directions and the solver's actions.
measurement_A_inprods = linpde_gp.linalg.pairwise_inprods(
    measurement_step_basis,
    normalize=True,
    inprod=linsys.A,
)
measurement_action_A_inprods = linpde_gp.linalg.pairwise_inprods(
    measurement_step_basis,
    actions,
    normalize=True,
    inprod=linsys.A,
)

In [None]:
# Block matrix of normalized A-inner products; rows/columns are ordered as
# (measurement directions, solver actions).
fig, ax = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(6.5, 6),
    gridspec_kw=dict(width_ratios=[1, 3], height_ratios=[1, 3]),
)

imshow_kwargs = dict(
    cmap="bwr",
    vmin=-1.0,
    vmax=1.0,
    aspect="auto",
    interpolation="nearest",
)

ax[0, 0].imshow(measurement_A_inprods, **imshow_kwargs)           # measurement-measurement
ax[0, 1].imshow(measurement_action_A_inprods, **imshow_kwargs)    # measurement-action
ax[1, 0].imshow(measurement_action_A_inprods.T, **imshow_kwargs)  # action-measurement
ax[1, 1].imshow(action_A_inprods, **imshow_kwargs)                # action-action

# Tick labels only along the outer edges of the block matrix
ax[0, 0].xaxis.tick_top()
ax[0, 1].xaxis.tick_top()
ax[0, 1].yaxis.set_ticks([])
ax[1, 0].xaxis.set_ticks([])
ax[1, 1].xaxis.set_ticks([])
ax[1, 1].yaxis.set_ticks([])

fig.tight_layout()
plt.show()