# Probabilistic Numerical Solution of the Linear System
\begin{equation}
    \renewcommand{\vec}[1]{\mathbf{#1}}
    \newcommand{\inprod}[2]{\left\langle #1, #2 \right\rangle}
    \newcommand{\norm}[1]{\left\lVert #1 \right\rVert}
\end{equation}

We now solve the linear system $A x = b$ using a solution-based probabilistic linear solver.
This essentially boils down to online inference in a linear Gaussian model.
We assume that we have a prior $$p(x) = \mathcal{N}(x; \mu_0, \Sigma_0)$$ over the solution of the linear system.
In each iteration of the solver, we collect information about the solution by projecting onto a one-dimensional subspace defined by $s_i$: $$y_i := s_i^T b = (s_i^T A) x.$$
Let $S_m := \begin{pmatrix} s_1, \dots, s_m \end{pmatrix}$ and $\vec{y}_m := \begin{pmatrix} y_1, \dots, y_m \end{pmatrix}^T$.
We can now infer $x$ using the Dirac likelihood $$p(y \mid x) = \delta(\vec{y}_m - S_m^T A x).$$
Since this is an observation of a linear function of $x$, we can perform inference in closed form.
The posterior for step $m$ is then given by $$p(x \mid y) = \mathcal{N}(x; \mu_m, \Sigma_m),$$ with
\begin{align}
    \mu_m & := \mu_0 + \Sigma_0 A^T S_m \Lambda_m^{-1} S_m^T r_0 \\
    \Sigma_m & := \Sigma_0 - \Sigma_0 A^T S_m \Lambda_m^{-1} S_m^T A \Sigma_0,
\end{align}
where $r_0 := b - A \mu_0$ is the initial residual and $\Lambda_m := S_m^T A \Sigma_0 A^T S_m$.

To avoid solving a dense $m \times m$ system with $\Lambda_m$ in every step, and to obtain a recursive update, we choose $S_m$ such that $\Lambda_m$ is diagonal.
This can be achieved by considering $A \Sigma_0 A^T$-orthonormal search directions, i.e., $\inprod{s_i}{s_j}_{A \Sigma_0 A^T} = \delta_{ij}$. In this case, we have $(\Lambda_m)_{i,j} = \delta_{ij}$, i.e., $\Lambda_m = I$, and thus
\begin{align}
    \mu_m
    & = \mu_0 + \Sigma_0 A^T S_m S_m^T r_0 \\
    & = \mu_0 + \Sigma_0 A^T \sum_{i = 1}^m s_i s_i^T r_0 \\
    & = \mu_0 + \Sigma_0 A^T \sum_{i = 1}^{m - 1} s_i s_i^T r_0 + \Sigma_0 A^T s_m s_m^T r_0 \\
    & = \mu_0 + \Sigma_0 A^T S_{m - 1} S_{m - 1}^T r_0 + \Sigma_0 A^T s_m (s_m^T b - s_m^T A \mu_0) \\
    & = \mu_{m - 1} + \Sigma_0 A^T s_m (s_m^T b - s_m^T A \mu_0 - \underbrace{s_m^T A \Sigma_0 A^T S_{m - 1}}_{= 0} S_{m - 1}^T r_0) \\
    & = \mu_{m - 1} + \Sigma_0 A^T s_m s_m^T (b - A (\mu_0 + \Sigma_0 A^T S_{m - 1} S_{m - 1}^T r_0)) \\
    & = \mu_{m - 1} + \Sigma_0 A^T s_m s_m^T \underbrace{(b - A \mu_{m - 1})}_{=: r_{m - 1}}
\end{align}

and

\begin{align}
    \Sigma_m
    & = \Sigma_0 - \Sigma_0 A^T S_m S_m^T A \Sigma_0 \\
    & = \Sigma_0 - \Sigma_0 A^T \sum_{i = 1}^m s_i s_i^T A \Sigma_0 \\
    & = \Sigma_0 - \Sigma_0 A^T \sum_{i = 1}^{m - 1} s_i s_i^T A \Sigma_0 - \Sigma_0 A^T s_m s_m^T A \Sigma_0 \\
    & = \Sigma_0 - \Sigma_0 A^T S_{m - 1} S_{m - 1}^T A \Sigma_0 - \Sigma_0 A^T s_m (\Sigma_0 A^T s_m)^T \\
    & = \Sigma_{m - 1} - \Sigma_0 A^T s_m (\Sigma_0 A^T s_m)^T.
\end{align}

We can generate $A \Sigma_0 A^T$-orthonormal search directions $s_m$ on-the-fly as in CG, i.e. $s_m = \frac{\tilde{s}_m}{\norm{\tilde{s}_m}_{A \Sigma_0 A^T}}$, $\tilde{s}_1 = r_0$ and

\begin{align}
    \tilde{s}_m = r_{m - 1} - \inprod{s_{m - 1}}{r_{m - 1}}_{A \Sigma_0 A^T} s_{m - 1}
\end{align}

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn
import scipy.sparse

import probnum_galerkin

In [None]:
%matplotlib inline

from IPython.display import set_matplotlib_formats
set_matplotlib_formats("pdf", "svg")

In [None]:
# 1D Poisson boundary value problem on `domain` with constant right-hand side
# 2.0 and Dirichlet boundary data (0.0, 0.0) (presumably u = 0 at both ends of
# the interval — confirm against DirichletBoundaryCondition).
domain = (-1.0, 1.0)

bvp = probnum_galerkin.problems.PoissonEquation(
    boundary_condition=probnum_galerkin.problems.DirichletBoundaryCondition(
        domain,
        (0.0, 0.0),
    ),
    rhs=2.0,
    domain=domain,
)

In [None]:
# Finite element basis with 103 elements; judging by the class name its
# functions vanish on the boundary (TODO confirm).
basis = probnum_galerkin.bases.ZeroBoundaryFiniteElementBasis(domain, num_elements=103)
# Galerkin discretization of the BVP; `linsys.A` and `linsys.b` are the
# linear system solved by the probabilistic solvers below.
linsys = bvp.discretize(basis)

In [None]:
def bayescg(A, b, x0=None, maxiter=None, atol=1e-5, rtol=1e-5, callback=None):
    """Solve ``A x = b`` with the Bayesian conjugate gradient method (BayesCG).

    Starting from a Gaussian prior ``N(x0, cov0)`` over the solution, every
    iteration conditions the belief on the projected observation
    ``s_m^T b = s_m^T A x`` along a CG-style search direction ``s_m``,
    following the recursions for ``mu_m`` and ``Sigma_m`` derived above.

    Parameters
    ----------
    A
        System matrix. Must support ``@``, ``.T``, ``.shape`` and ``.dtype``.
    b
        Right-hand side vector.
    x0
        Prior over the solution. A ``pn.randvars.Normal`` provides mean and
        covariance; an ``np.ndarray`` provides only the mean; ``None`` means a
        zero mean. In the last two cases the prior covariance is ``A^{-1}``.
    maxiter
        Maximum number of iterations. Defaults to ``10 * x0.size``.
    atol, rtol
        Iteration stops once ``||b - A x||_2 < max(rtol * ||b||_2, atol)``.
    callback
        Called with keyword arguments once after initialization
        (``iteration=0``) and once after every iteration.

    Returns
    -------
    pn.randvars.Normal
        Posterior belief over the solution. After ``m >= 1`` iterations the
        covariance is scaled by ``nu / m``, where ``nu`` accumulates
        ``(s_i^T r_{i-1})^2 / (s_i^T A Sigma_0 A^T s_i)``. If no iteration was
        performed, the prior covariance is returned unscaled.

    Raises
    ------
    TypeError
        If ``x0`` is neither ``None``, an ``np.ndarray`` nor a
        ``pn.randvars.Normal``.
    """
    # Prior construction
    x_dtype = np.result_type(A.dtype, b.dtype)

    if isinstance(x0, pn.randvars.Normal):
        (x0, cov0) = (x0.mean, x0.cov)

        x0 = x0.astype(x_dtype, copy=True)
        cov0 = cov0.astype(x_dtype, copy=True)
    else:
        if isinstance(x0, np.ndarray):
            x0 = x0.astype(x_dtype, copy=True)
        elif x0 is None:
            x0 = np.zeros(A.shape[1], x_dtype)
        else:
            raise TypeError(
                "`x0` must be None, an np.ndarray or a pn.randvars.Normal, "
                f"not {type(x0).__name__}."
            )

        # Default prior covariance: A^{-1}.
        cov0 = pn.linops.aslinop(A).inv()

    # Stopping criteria
    if maxiter is None:
        maxiter = 10 * x0.size

    res_norm_thresh = np.maximum(rtol * np.linalg.norm(b), atol)

    # Callback
    if callback is None:
        callback = lambda **kwargs: None

    # Initialization
    x = x0
    cov = cov0
    nu = 0.0
    num_iters = 0

    residual = b - A @ x
    s = residual

    # Check if the initialization already meets the stopping criterion
    stop = np.linalg.norm(residual, ord=2) < res_norm_thresh

    callback(
        iteration=0,
        x=pn.randvars.Normal(mean=x0.copy(), cov=cov0),
        residual=residual.copy(),
        stop=stop,
        action=None,
    )

    # Iterate
    if not stop:
        for m in range(1, maxiter + 1):
            num_iters = m

            # Update belief. Multiply A^T s first so that the (possibly dense)
            # matrix-matrix product Sigma_0 A^T is never formed.
            At_s = A.T @ s
            stepdir = cov0 @ At_s  # Sigma_0 A^T s
            E_sq = At_s @ stepdir  # s^T A Sigma_0 A^T s
            residual_norm_sq = residual @ residual
            alpha = residual_norm_sq / E_sq

            x += alpha * stepdir
            # Not in place: `cov` may alias `cov0`, which is still needed and
            # has already been handed to the callback.
            cov = cov - np.outer(stepdir, stepdir) / E_sq
            # Scale-invariant contribution (s^T r)^2 / (s^T A Sigma_0 A^T s)
            # to the covariance calibration factor nu / m.
            nu += (s @ residual) ** 2 / E_sq

            # Update residual
            prev_residual_norm_sq = residual_norm_sq
            residual = b - A @ x

            # Check stopping criterion
            stop = np.linalg.norm(residual, ord=2) < res_norm_thresh

            # Callback
            cov_scale = nu / m

            callback(
                iteration=m,
                x=pn.randvars.Normal(mean=x.copy(), cov=cov_scale * cov),
                cov_scale=cov_scale,
                residual=residual,
                stop=stop,
                action=s,
            )

            # Apply stopping criterion
            if stop:
                break

            # Update search direction (CG recurrence)
            beta = (residual @ residual) / prev_residual_norm_sq
            s = residual + beta * s

    # Without any iteration there is no calibration data; keep the prior scale.
    if num_iters > 0:
        cov = (nu / num_iters) * cov

    return pn.randvars.Normal(
        mean=x,
        cov=cov,
    )

In [None]:
from probnum_galerkin.solvers import bayescg

In [None]:
from typing import Callable, Optional, Union


def problinsolve(
    A: pn.linops.LinearOperatorLike,
    b: np.ndarray,
    prior_x: Optional[Union[pn.randvars.Normal, np.ndarray]] = None,
    auto_cov_type: str = "cg",
    max_num_steps: Optional[int] = None,
    rtol: float = 1e-5,
    atol: float = 1e-5,
    callback: Optional[Callable[..., None]] = None,
) -> pn.randvars.Normal:
    """Solve ``A x = b`` with a solution-based probabilistic linear solver.

    Each step observes ``action^T r`` for the current residual ``r`` and
    conditions the Gaussian belief over ``x`` on it. Actions follow the CG
    recurrence ``action = r + (||r||^2 / ||r_prev||^2) * action``.

    Parameters
    ----------
    A
        System matrix (anything accepted by ``pn.linops.aslinop``).
    b
        Right-hand side vector.
    prior_x
        Prior over the solution. A ``pn.randvars.Normal`` provides mean and
        covariance; an ``np.ndarray`` provides only the mean; ``None`` means a
        zero mean.
    auto_cov_type
        Prior covariance used when ``prior_x`` is not a ``pn.randvars.Normal``:
        ``"id"`` for the identity, ``"cg"`` for ``A^{-1}``.
    max_num_steps
        Maximum number of steps. Defaults to ``10 * x0.size``.
    rtol, atol
        Iteration stops once ``||b - A x||_2 < max(rtol * ||b||_2, atol)``.
    callback
        Called with keyword arguments after initialization (``step_idx=0``)
        and after every step.

    Returns
    -------
    pn.randvars.Normal
        Posterior belief over the solution.

    Raises
    ------
    TypeError
        If ``prior_x`` is neither ``None``, an ``np.ndarray`` nor a
        ``pn.randvars.Normal``.
    ValueError
        If ``auto_cov_type`` is neither ``"id"`` nor ``"cg"`` and the prior
        covariance is not given by ``prior_x``.
    """
    A = pn.linops.aslinop(A)
    x_dtype = np.result_type(A.dtype, b.dtype)

    # Prior construction
    if isinstance(prior_x, pn.randvars.Normal):
        x0 = prior_x.mean.astype(x_dtype, copy=True)
        cov0 = prior_x.cov.astype(x0.dtype, copy=True)
    else:
        if isinstance(prior_x, np.ndarray):
            x0 = prior_x.astype(x_dtype, subok=True, copy=True)
        elif prior_x is None:
            x0 = np.zeros(A.shape[1], x_dtype)
        else:
            raise TypeError(
                "`prior_x` must be None, an np.ndarray or a pn.randvars.Normal, "
                f"not {type(prior_x).__name__}."
            )

        if auto_cov_type == "id":
            cov0 = np.eye(x0.size, dtype=x0.dtype)
        elif auto_cov_type == "cg":
            cov0 = A.inv()
        else:
            raise ValueError(
                f"Unknown `auto_cov_type` {auto_cov_type!r}; expected 'id' or 'cg'."
            )

    # Stopping criteria
    if max_num_steps is None:
        max_num_steps = 10 * x0.size

    res_norm_thresh = np.maximum(rtol * np.linalg.norm(b), atol)

    # Callback
    if callback is None:
        callback = lambda **kwargs: None

    # Initialization
    step_idx = 0

    x = x0
    cov = cov0

    residual = b - A @ x
    residual_norm_sq = np.inner(residual, residual)
    residual_norm = np.sqrt(residual_norm_sq)

    stop = (
        step_idx >= max_num_steps
        or residual_norm < res_norm_thresh
    )

    callback(
        step_idx=step_idx,
        x=pn.randvars.Normal(mean=x.copy(), cov=cov),
        residual=residual.copy(),
        residual_norm_sq=residual_norm_sq,
        residual_norm=residual_norm,
        stop=stop,
        action=None,
        observation=None,
        stepdir=None,
        stepsize=None,
    )

    action = residual

    # Iteration
    while not stop:
        observation = np.inner(action, residual)

        matvec = A @ action

        # Update solution
        stepdir = cov @ matvec

        gram = np.inner(matvec, stepdir)
        # NOTE: there is no safeguard against a vanishing `gram` (breakdown).
        gram_pinv = 1.0 / gram

        stepsize = gram_pinv * observation

        x += stepsize * stepdir
        # Not in place: `cov` is shared with beliefs already passed to the
        # callback, which must keep their own covariance.
        cov = cov - np.outer(stepdir, stepdir) * gram_pinv

        # Update residual
        prev_residual_norm_sq = residual_norm_sq

        residual = b - A @ x
        residual_norm_sq = np.inner(residual, residual)
        residual_norm = np.sqrt(residual_norm_sq)

        # Check stopping criteria
        step_idx += 1

        stop = (
            step_idx >= max_num_steps
            or residual_norm < res_norm_thresh
        )

        # Callback (same keywords as for the initial state)
        callback(
            step_idx=step_idx,
            x=pn.randvars.Normal(mean=x.copy(), cov=cov),
            residual=residual,
            residual_norm_sq=residual_norm_sq,
            residual_norm=residual_norm,
            stop=stop,
            action=action,
            observation=observation,
            stepdir=stepdir,
            stepsize=stepsize,
        )

        # Apply stopping criteria
        if stop:
            break

        # Update action (CG recurrence)
        action = residual + (residual_norm_sq / prev_residual_norm_sq) * action

    return pn.randvars.Normal(
        mean=x,
        cov=cov,
    )

In [None]:
# Small random symmetric positive definite test system. A seeded generator
# makes the cells below reproducible.
rng = np.random.default_rng(42)
M = rng.standard_normal((10, 10))
A = M @ M.T + np.eye(10)  # Gram matrix plus identity: eigenvalues >= 1
b = rng.standard_normal(10)

In [None]:
# Reference solution from a dense direct solve, for comparison with the
# BayesCG posterior mean computed below.
np.linalg.solve(A, b)

In [None]:
# Solve the small random system with `bayescg` (the library version imported
# from probnum_galerkin.solvers) and display posterior mean and variance.
sol = bayescg(A, b)
(sol.mean, sol.var)

In [None]:
# Run BayesCG on the FEM system for at most 40 iterations. rtol = atol = 0 sets
# the residual threshold to 0, so (assuming the library solver uses the same
# criterion as the prototype above) all 40 iterations are performed.
u_fem_coords = bayescg(linsys.A, linsys.b, maxiter=40, rtol=0, atol=0)
# Map the belief over basis coefficients to a belief over functions.
u_fem = basis.coords2fn(u_fem_coords)

In [None]:
# Evaluation grid for plotting.
xs_plot = np.linspace(*domain, 200)

# Pointwise mean and standard deviation of the FEM solution belief.
mean = u_fem.mean(xs_plot)
std = u_fem.std(xs_plot[:, None])

# Exact solution (C0) vs. FEM posterior mean (C1) with a +/- 1.96 std band
# (a pointwise ~95% interval if the marginals are Gaussian).
plt.plot(xs_plot, bvp.solution(xs_plot))
plt.plot(xs_plot, mean, c="C1")
plt.fill_between(
    xs_plot,
    mean - 1.96 * std,
    mean + 1.96 * std,
    alpha=0.1,
    color="C1"
)

In [None]:
from matplotlib import animation

def animate_poisson_1d_bayescg(basis, linsys=None, **bayescg_kwargs):
    """Animate the iterates of ``bayescg`` on the 1D Poisson problem.

    Runs ``bayescg`` on the discretized problem and logs, through the callback,
    the solution belief and the residual norms of every reported step. Returns
    an animation with one frame per logged step and three panels:

    - the belief over the solution against ``bvp.solution``,
    - the Euclidean residual norm ``||r||_2``,
    - the residual A-norm ``sqrt(r^T A r)``.

    Uses the notebook globals ``bvp``, ``domain`` and ``bayescg``.

    Parameters
    ----------
    basis
        Finite element basis; ``len(basis)`` is shown in the title.
    linsys
        Discretized linear system. Defaults to ``bvp.discretize(basis)``.
    **bayescg_kwargs
        Forwarded to ``bayescg`` (e.g. ``x0``, ``maxiter``).

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    n = len(basis)

    if linsys is None:
        linsys = bvp.discretize(basis)

    # Run the algorithm and log step statistics
    step_xs = []
    step_residual_norms = []
    step_residual_A_norms = []

    def _callback(x: pn.randvars.Normal, residual: np.ndarray, **kwargs):
        step_xs.append(x)
        step_residual_norms.append(np.linalg.norm(residual, ord=2))
        step_residual_A_norms.append(np.sqrt(np.inner(residual, linsys.A @ residual)))

    bayescg(
        linsys.A,
        linsys.b,
        callback=_callback,
        **bayescg_kwargs
    )

    fig, ax = plt.subplots(ncols=3, figsize=(22, 6))
    # Close the figure so that it is only shown through the animation and not
    # additionally as a static image.
    plt.close()

    xs_plot = np.linspace(*domain, 200)

    def animate(step_idx):
        """Redraw all three panels for logged step ``step_idx``."""
        for axis in ax:
            axis.cla()

        u = basis.coords2fn(coords=step_xs[step_idx])

        u_mean_plot = u.mean(xs_plot)
        u_std_plot = u.std(xs_plot[:, None])

        fig.suptitle(f"1D Poisson - FEM (N = {n}) - BayesCG - Iteration {step_idx:03d}")

        ax[0].set_title("Solution")
        ax[0].plot(xs_plot, bvp.solution(xs_plot), label="Exact Solution")
        ax[0].plot(xs_plot, u_mean_plot, c="C1", label="FEM Solution")
        # Shade mean +/- 2 standard deviations.
        ax[0].fill_between(
            xs_plot,
            u_mean_plot - 2 * u_std_plot,
            u_mean_plot + 2 * u_std_plot,
            color="C1",
            alpha=0.2,
        )
        ax[0].legend()

        ax[1].set_title("Residual Norm")
        ax[1].plot(step_residual_norms[:step_idx + 1], "C0", label="residual norm")
        ax[1].set_xlabel("Iterations")

        ax[2].set_title("Residual A-norm")
        ax[2].plot(step_residual_A_norms[:step_idx + 1], "C0", label="residual A-norm")
        ax[2].set_xlabel("Iterations")

    return animation.FuncAnimation(
        fig,
        func=animate,
        frames=len(step_xs),
        interval=200,
        repeat_delay=4000,
        blit=False,
    )

In [None]:
from IPython.display import HTML

# Animate BayesCG with the default prior, allowing at most len(basis)
# iterations, and render the animation as an interactive HTML/JS widget.
anim = animate_poisson_1d_bayescg(
    basis,
    maxiter=len(basis),
)

HTML(anim.to_jshtml())

In [None]:
# Export the animation as a GIF at 5 frames per second (the target directory
# ../results must already exist).
anim.save("../results/fem_probsolve.gif", animation.PillowWriter(fps=5))

## Conditioning the Prior on Observations of the Solution

If we have (noisy) measurements of the solution of the PDE, we can use this information to speed up inference.

Let $(v_i)_{i = 1}^n$ be the chosen basis.
In our formulation, we posit a multivariate Gaussian prior over the coefficients $\vec{a} \in \mathbb{R}^n$ of the discretized solution $\hat{u} = \sum_{i = 1}^n a_i v_i$ to the PDE, i.e. $\vec{a} \sim \mathcal{N}(\mu_0, \Sigma_0)$.
We can relate the discretized solution $\hat{u}$ to the coefficients by a linear operator $$(\mathcal{L}_u \vec{a})(x) = \sum_{i = 1}^n a_i v_i(x).$$
Moreover, the solution can be evaluated at several locations $x_1, \dotsc, x_m \in \Omega$ by another linear operator $$(\mathcal{L}_\delta u)_j = \int_\Omega \delta(\chi - x_j) u(\chi) d \chi = u(x_j).$$
All in all, we obtain the following linear operator, which maps $\vec{a}$ to a vector of measurements at $x_1, \dotsc, x_m \in \Omega$: $$(L_{\vec{y}} \vec{a})_j = (\mathcal{L}_\delta \mathcal{L}_u \vec{a})_j = \int_\Omega \delta(\chi - x_j) (\mathcal{L}_u \vec{a})(\chi) d \chi = \sum_{i = 1}^n a_i \int_\Omega \delta(\chi - x_j) v_i(\chi) d\chi = \sum_{i = 1}^n a_i v_i(x_j)$$
If we now assume additive Gaussian measurement noise on independent observations $y_1, \dotsc, y_m$ of the solution at locations $x_1, \dotsc, x_m \in \Omega$, we obtain the following measurement likelihood:
$$p(y_1, \dots, y_m \mid u(x_1), \dotsc, u(x_m)) = \mathcal{N}(\vec{y}; \begin{pmatrix} u(x_1), \dotsc, u(x_m) \end{pmatrix}^T, \Lambda),$$
or, equivalently,
$$p(y_1, \dots, y_m \mid \vec{a}) = \mathcal{N}(\vec{y}; L_{\vec{y}} \vec{a}, \Lambda).$$
Since the model is linear-Gaussian, we can compute the posterior in closed form.
Note that this is exactly the supervised regression setting.

In [None]:
# Number of noisy point measurements of the PDE solution.
num_measurements = 3

In [None]:
# Measure the solution at equidistant interior points: the linspace includes
# both ends of the domain, which are dropped by [1:-1].
meas_xs = np.linspace(*domain, num_measurements + 2)[1:-1]
true_ys = bvp.solution(meas_xs)

# Add measurement noise: zero mean, covariance (1e-2)**2 * I, i.e. independent
# errors with standard deviation 1e-2.
measurement_noise = pn.randvars.Normal(
    mean=np.zeros(num_measurements, dtype=np.double),
    cov=pn.linops.Scaling((1e-2) ** 2, shape=num_measurements, dtype=np.double),
)

# NOTE(review): sample() is called without an explicit RNG, so the data are not
# reproducible across runs; newer probnum versions may require an `rng`
# argument here — confirm against the installed version.
meas_ys = true_ys + measurement_noise.sample()

In [None]:
# Display the realized measurement noise.
meas_ys - true_ys

In [None]:
# Build the prior over the FE coefficients: zero mean, covariance 2 * A^{-1}.
prior = pn.randvars.Normal(
    mean=np.zeros(len(basis), dtype=np.double),
    cov=2.0 * linsys.A.inv(),
)

# Display the condition number of the densified prior covariance.
np.linalg.cond(prior.dense_cov)

In [None]:
# Build the observation operator from FE coefficients to the values of the
# discretized solution at `meas_xs` (presumably the operator L_y from the text,
# with entries v_i(x_j) — confirm against `observation_operator`).
L_yu = basis.observation_operator(meas_xs)

In [None]:
# Build the noise model assumed during inference. As written it has the same
# zero mean and covariance as the noise used to simulate the data; the
# commented-out alternative inflates the covariance by a factor of 1e2.
noise_model = pn.randvars.Normal(
    mean=np.zeros(num_measurements, dtype=np.double),
    cov=measurement_noise.cov,
    # cov=1e2 * measurement_noise.cov,
)

In [None]:
# Condition the prior on the noisy point measurements (closed-form inference
# in the linear-Gaussian model described above).
prior_cond_meas = probnum_galerkin.inference.linear_gaussian_model(
    prior=prior,
    A=L_yu,
    measurement_noise=noise_model,
    measurements=meas_ys,
)

# Display the condition number of the conditioned covariance.
np.linalg.cond(prior_cond_meas.dense_cov)

In [None]:
# Evaluation grid for plotting.
xs_plot = np.linspace(*domain, 200)

# Belief over the solution values on the grid, induced by the conditioned
# coefficient belief.
u_prior_cond_meas = basis.coords2fn(prior_cond_meas)(xs_plot[:, None])

# Conditioned prior mean with a +/- 1.96 std band, plus the measurements.
plt.plot(xs_plot, u_prior_cond_meas.mean)
plt.fill_between(
    xs_plot,
    np.squeeze(u_prior_cond_meas.mean - 1.96 * u_prior_cond_meas.std),
    np.squeeze(u_prior_cond_meas.mean + 1.96 * u_prior_cond_meas.std),
    alpha=0.1
)
plt.scatter(meas_xs, meas_ys, marker="+")
# plt.errorbar(meas_xs, meas_ys, yerr=2 * measurement_noise.std, marker="+", linestyle="", capsize=2)
plt.show()

In [None]:
from IPython.display import HTML

# Animate BayesCG again, now starting from the prior conditioned on the
# measurements, with at most len(basis) iterations.
anim = animate_poisson_1d_bayescg(
    basis,
    linsys=linsys,
    x0=prior_cond_meas,
    maxiter=len(basis),
    #atol=0,
    #rtol=0,
    #reorthogonalize=True,
)

HTML(anim.to_jshtml())

In [None]:
# Export the measurement-conditioned animation as a GIF at 5 frames per second
# (the target directory ../results must already exist).
anim.save("../results/fem_probsolve_data.gif", animation.PillowWriter(fps=5))

In [None]:
# Compare the sorted spectra of A and of A Sigma A, with Sigma the
# measurement-conditioned prior covariance. A Sigma A coincides with the matrix
# A Sigma A^T defining the search-direction inner product if A is symmetric
# (TODO confirm for the FEM system).
plt.semilogy(np.sort(linsys.A.eigvals()))
plt.semilogy(np.sort((linsys.A @ prior_cond_meas.cov @ linsys.A).eigvals()))

In [None]:
# Extreme eigenvalues of A Sigma A and their ratio (its condition number if
# the matrix is symmetric positive definite).
spec = np.sort((linsys.A @ prior_cond_meas.cov @ linsys.A).eigvals())
spec.min(), spec.max(), spec.max() / spec.min()

In [None]:
# Extreme eigenvalues of the system matrix A, for comparison.
spec = np.sort(linsys.A.eigvals())
spec.min(), spec.max()