# Conditioning the Linear System Prior on Observations of the PDE Solution

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn
import scipy.sparse

import probnum_galerkin

In [None]:
%matplotlib inline

# Request PDF and SVG output for inline figures.
# NOTE(review): recent IPython releases deprecate
# `IPython.display.set_matplotlib_formats` in favour of
# `matplotlib_inline.backend_inline.set_matplotlib_formats` — confirm against
# the installed IPython version.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats("pdf", "svg")

In the following, we will again look at the Poisson equation on a 1D domain with Dirichlet boundary conditions.

In [None]:
# Endpoints of the 1D interval on which the PDE is posed.
domain = (-1.0, 1.0)

# Poisson boundary value problem on `domain` with right-hand side 2.0 and a
# Dirichlet boundary condition built from the values (0.0, 0.0).
# NOTE(review): presumably one boundary value per endpoint of `domain` —
# confirm against `DirichletBoundaryCondition`.
bvp = probnum_galerkin.problems.PoissonEquation(
    domain=domain,
    rhs=2.0,
    boundary_condition=probnum_galerkin.problems.DirichletBoundaryCondition(
        domain,
        (0.0, 0.0),
    ),
)

In [None]:
# Finite element basis on `domain` built from 103 elements; `linsys` is the
# linear system obtained by discretizing the boundary value problem in it.
basis = probnum_galerkin.bases.ZeroBoundaryFiniteElementBasis(domain, num_elements=103)
linsys = bvp.discretize(basis)

If we have (noisy) measurements of the solution of the PDE, we can use the information to speed up inference.

Let $(v_i)_{i = 1}^n$ be the chosen basis.
In our formulation, we posit a multivariate Gaussian prior over the coefficients $\vec{a} \in \mathbb{R}^n$ of the discretized solution $\hat{u} = \sum_{i = 1}^n a_i v_i$ to the PDE, i.e. $\vec{a} \sim \mathcal{N}(\mu_0, \Sigma_0)$.
We can relate the discretized solution $\hat{u}$ to the coefficients by a linear operator $$(\mathcal{L}_u \vec{a})(x) = \sum_{i = 1}^n a_i v_i(x).$$
Moreover, the solution can be evaluated at several locations $x_1, \dotsc, x_m \in \Omega$ by another linear operator $$(\mathcal{L}_\delta u)_j = \int_\Omega \delta(\chi - x_j) u(\chi) d \chi = u(x_j).$$
All in all, we obtain the following linear operator which maps $\vec{a}$ to a vector of measurements at $x_1, \dotsc, x_m \in \Omega$: $$(L_y \vec{a})_j = (\mathcal{L}_\delta \mathcal{L}_u \vec{a})_j = \int_\Omega \delta(\chi - x_j) (\mathcal{L}_u \vec{a})(\chi) d \chi = \sum_{i = 1}^n a_i \int_\Omega \delta(\chi - x_j) v_i(\chi) d\chi = \sum_{i = 1}^n a_i v_i(x_j)$$
If we now assume additive zero-mean Gaussian measurement noise with covariance matrix $\Lambda$ on independent observations $y_1, \dotsc, y_m$ of the solution at locations $x_1, \dotsc, x_m \in \Omega$, we obtain the following measurement likelihood:
$$p(y_1, \dotsc, y_m \mid u(x_1), \dotsc, u(x_m)) = \mathcal{N}(\vec{y} \mid \begin{pmatrix} u(x_1) & \dotsc & u(x_m) \end{pmatrix}^T, \Lambda),$$
or, equivalently,
$$p(y_1, \dotsc, y_m \mid \vec{a}) = \mathcal{N}(\vec{y} \mid L_y \vec{a}, \Lambda).$$
Since the model is linear-Gaussian, we can compute the posterior in closed form.
Note that this is exactly the supervised regression setting.

In [None]:
# Number of (noisy) point observations of the PDE solution.
num_measurements = 3

In [None]:
# Measure the solution at equidistant interior points
# (the grid has `num_measurements + 2` points; `[1:-1]` drops both endpoints).
meas_xs = np.linspace(*domain, num_measurements + 2)[1:-1]
true_ys = bvp.solution(meas_xs)

# Add measurement noise: zero mean, covariance given by a scaling operator
# with factor (1e-2) ** 2.
measurement_noise = pn.randvars.Normal(
    mean=np.zeros(num_measurements, dtype=np.double),
    cov=pn.linops.Scaling((1e-2) ** 2, shape=num_measurements, dtype=np.double),
)

# NOTE(review): no seed or random generator is passed here, so the noisy
# measurements presumably differ between runs — confirm whether the installed
# ProbNum version expects an explicit generator.
meas_ys = true_ys + measurement_noise.sample()

In [None]:
# Display the realized measurement noise (noisy minus exact values).
meas_ys - true_ys

In [None]:
# Build the prior: a zero-mean Gaussian over the basis coefficients whose
# covariance is the inverse of the system matrix (wrapped as a linear operator).
prior = pn.randvars.Normal(
    mean=np.zeros(len(basis), dtype=np.double),
    cov=pn.linops.aslinop(linsys.A.A.tocsc()).inv(),
)

# Display the condition number of the dense prior covariance.
np.linalg.cond(prior.dense_cov)

In [None]:
# Build the observation operator for the measurement locations `meas_xs`.
# NOTE(review): presumably this is the map from coefficients to point
# evaluations described in the text above — confirm against
# `observation_operator`.
L_yu = basis.observation_operator(meas_xs)

In [None]:
# Build the noise model: zero mean, and the same covariance that was used to
# generate the measurement noise. The commented-out line is an alternative
# with a 1e2 times larger covariance.
noise_model = pn.randvars.Normal(
    mean=np.zeros(num_measurements, dtype=np.double),
    cov=measurement_noise.cov,
    # cov=1e2 * measurement_noise.cov,
)

In [None]:
# Condition the prior on the measurements, using the observation operator and
# the noise model built above.
# NOTE(review): presumably the result is the Gaussian posterior over the
# coefficients — confirm against `linear_gaussian_model`.
prior_cond_meas = probnum_galerkin.inference.linear_gaussian_model(
    prior=prior,
    A=L_yu,
    measurement_noise=noise_model,
    measurements=meas_ys,
)

# Display the condition number of the dense covariance after conditioning.
np.linalg.cond(prior_cond_meas.dense_cov)

In [None]:
# Evaluation grid for plotting: 200 equidistant points on the domain.
xs_plot = np.linspace(*domain, 200)

# Turn the conditioned coefficient belief into an object that can be plotted
# over the domain.
u_prior_cond_meas = basis.coords2fn(prior_cond_meas)

# Plot it on the current axes together with the noisy measurements.
u_prior_cond_meas.plot(plt.gca(), xs_plot)
plt.scatter(meas_xs, meas_ys, marker="+")
plt.show()

In [None]:
from matplotlib import animation

def animate_probsolve_poisson_1d(
    basis,
    linsys=None,
    solver=probnum_galerkin.solvers.bayescg,
    **solver_kwargs
):
    """Run a linear solver on the discretized Poisson problem and animate it.

    The solver is invoked once. A callback records the solver's current
    iterate and two residual norms at every step, and each animation frame
    then visualizes one recorded step.

    This function reads the module-level names ``bvp`` and ``domain``.

    Parameters
    ----------
    basis
        Basis of the discretization. Must support ``len`` and provide
        ``coords2fn``.
    linsys
        Linear system with attributes ``A`` and ``b``. If ``None``, it is
        obtained from ``bvp.discretize(basis)``.
    solver
        Callable invoked as ``solver(A, b, callback=..., **solver_kwargs)``.
        It is expected to call the callback with the current iterate ``x``
        and the current ``residual``.
    **solver_kwargs
        Forwarded unchanged to ``solver``.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Animation with one frame per recorded solver step.
    """
    n = len(basis)

    if linsys is None:
        linsys = bvp.discretize(basis)

    # Run the algorithm and log step statistics
    step_xs = []
    step_residual_norms = []
    step_residual_A_norms = []

    def _callback(x: pn.randvars.Normal, residual: np.ndarray, **kwargs):
        # Record the iterate, the Euclidean norm of the residual and the
        # residual norm induced by the system matrix, sqrt(r^T A r).
        step_xs.append(x)
        step_residual_norms.append(np.linalg.norm(residual, ord=2))
        step_residual_A_norms.append(np.sqrt(np.inner(residual, linsys.A @ residual)))

    solver(
        linsys.A,
        linsys.b,
        callback=_callback,
        **solver_kwargs,
    )

    fig, ax = plt.subplots(ncols=3, figsize=(22, 6))
    # Close exactly this figure (not whichever figure happens to be current) so
    # that the notebook does not additionally render it as a static image.
    plt.close(fig)

    xs_plot = np.linspace(*domain, 200)
    # The exact solution does not depend on the frame; evaluate it only once.
    exact_solution = bvp.solution(xs_plot)

    # The zero-boundary basis is tested first, mirroring the original order of
    # the checks.
    if isinstance(basis, probnum_galerkin.bases.ZeroBoundaryFiniteElementBasis):
        basis_str = "Zero Boundary FEM"
    elif isinstance(basis, probnum_galerkin.bases.FiniteElementBasis):
        basis_str = "FEM"
    else:
        basis_str = "Unknown Basis"

    if solver is probnum_galerkin.solvers.bayescg:
        solver_str = "BayesCG"
    elif solver is probnum_galerkin.solvers.problinsolve:
        solver_str = "problinsolve"
    else:
        solver_str = "Unknown Solver"

    def animate(step_idx):
        # Redraw all three panels from scratch for the requested step.
        for axis in ax:
            axis.cla()

        u = basis.coords2fn(coords=step_xs[step_idx])

        fig.suptitle(f"1D Poisson - {basis_str} (N = {n}) - {solver_str} - Iteration {step_idx:03d}")

        ax[0].set_title("Solution")
        ax[0].plot(xs_plot, exact_solution, label="Exact Solution")
        u.plot(ax[0], xs_plot, color="C1", label="FEM Solution")
        ax[0].legend()

        # Residual histories up to and including the current step.
        ax[1].set_title("Residual Norm")
        ax[1].plot(step_residual_norms[:step_idx + 1], "C0")
        ax[1].set_xlabel("Iterations")

        ax[2].set_title("Residual A-norm")
        ax[2].plot(step_residual_A_norms[:step_idx + 1], "C0")
        ax[2].set_xlabel("Iterations")

    return animation.FuncAnimation(
        fig,
        func=animate,
        frames=len(step_xs),
        interval=200,
        repeat_delay=4000,
        blit=False,
    )

In [None]:
from IPython.display import HTML

# Animate BayesCG on the discretized problem, starting from the belief that
# was conditioned on the measurements. The commented-out lines are alternative
# settings (a different solver, a prior that only reuses the conditioned mean,
# and reorthogonalization).
anim = animate_probsolve_poisson_1d(
    basis,
    linsys=linsys,
    solver=probnum_galerkin.solvers.bayescg,
#     solver=probnum_galerkin.solvers.problinsolve,
    x0=prior_cond_meas,
#     x0=pn.randvars.Normal(
#         mean=prior_cond_meas.mean,
#         cov=prior.cov,
#     ),
    maxiter=len(basis),
#     reorthogonalize=True,
)

# Embed the animation in the notebook as an interactive HTML/JavaScript widget.
HTML(anim.to_jshtml())

In [None]:
# Save the animation as a GIF at 5 frames per second. The target directory is
# created first so that a missing `../results` folder does not make the save
# fail after the animation has already been rendered.
import pathlib

gif_path = pathlib.Path("../results/fem_probsolve_data.gif")
gif_path.parent.mkdir(parents=True, exist_ok=True)
anim.save(str(gif_path), animation.PillowWriter(fps=5))

In [None]:
# Eigendecomposition of A @ Sigma @ A.T, where A is the system matrix and Sigma
# the covariance of the belief conditioned on the measurements. `eigh` is the
# routine for symmetric/Hermitian matrices.
eigvals, eigvecs = np.linalg.eigh((linsys.A @ prior_cond_meas.cov @ linsys.A.T).todense())

In [None]:
# Display the eigenvalues computed in the previous cell.
eigvals

In [None]:
# Compare, on a logarithmic scale, the sorted spectrum of the system matrix A
# with that of A @ Sigma @ A (Sigma: covariance after conditioning).
# NOTE(review): the right factor here is A rather than A.T as in the `eigh`
# cell; the two agree only if A is symmetric — confirm.
plt.semilogy(np.sort(linsys.A.eigvals()))
plt.semilogy(np.sort((linsys.A @ prior_cond_meas.cov @ linsys.A).eigvals()))

In [None]:
# Smallest and largest eigenvalue of A @ Sigma @ A and their ratio
# (largest divided by smallest).
spec = np.sort((linsys.A @ prior_cond_meas.cov @ linsys.A).eigvals())
spec.min(), spec.max(), spec.max() / spec.min()

In [None]:
# Smallest and largest eigenvalue of the system matrix A itself, for
# comparison. Note that this rebinds `spec`.
spec = np.sort(linsys.A.eigvals())
spec.min(), spec.max()