# Probabilistic Numerical Solution of the Linear System
\begin{equation}
    \renewcommand{\vec}[1]{\mathbf{#1}}
    \newcommand{\inprod}[2]{\left\langle #1, #2 \right\rangle}
    \newcommand{\norm}[1]{\left\lVert #1 \right\rVert}
\end{equation}

We now solve the linear system $A x = b$ using a solution-based probabilistic linear solver.
This essentially boils down to online inference in a linear Gaussian model.
We assume that we have a prior $$p(x) = \mathcal{N}(x; \mu_0, \Sigma_0)$$ over the solution of the linear system.
In each iteration of the solver, we collect information about the solution by projecting the linear system onto the one-dimensional subspace spanned by the search direction $s_i$: $$y_i := s_i^T b = (s_i^T A) x.$$
Let $S_m := \begin{pmatrix} s_1, \dots, s_m \end{pmatrix}$ and $\vec{y}_m := \begin{pmatrix} y_1, \dots, y_m \end{pmatrix}^T$.
We can now infer $x$ using the Dirac likelihood $$p(\vec{y}_m \mid x) = \delta(\vec{y}_m - S_m^T A x).$$
Since this is an observation of a linear function of $x$, we can perform inference in closed form.
The posterior after step $m$ is then given by $$p(x \mid \vec{y}_m) = \mathcal{N}(x; \mu_m, \Sigma_m),$$ with
\begin{align}
    \mu_m & := \mu_0 + \Sigma_0 A^T S_m \Lambda_m^{-1} S_m^T r_0 \\
    \Sigma_m & := \Sigma_0 - \Sigma_0 A^T S_m \Lambda_m^{-1} S_m^T A \Sigma_0,
\end{align}
where $r_0 := b - A \mu_0$ is the initial residual and $\Lambda_m := S_m^T A \Sigma_0 A^T S_m$.

To make the posterior cheap to update from one iteration to the next, we choose $S_m$ such that $\Lambda_m$ is diagonal, which avoids solving a dense $m \times m$ system in every step.
This can be achieved by using $A \Sigma_0 A^T$-orthonormal search directions, i.e., $\inprod{s_i}{s_j}_{A \Sigma_0 A^T} = \delta_{ij}$. In this case, we have $(\Lambda_m)_{i,j} = \delta_{ij}$, i.e., $\Lambda_m$ is the identity, and thus
\begin{align}
    \mu_m
    & = \mu_0 + \Sigma_0 A^T S_m S_m^T r_0 \\
    & = \mu_0 + \Sigma_0 A^T \sum_{i = 1}^m s_i s_i^T r_0 \\
    & = \mu_0 + \Sigma_0 A^T \sum_{i = 1}^{m - 1} s_i s_i^T r_0 + \Sigma_0 A^T s_m s_m^T r_0 \\
    & = \mu_0 + \Sigma_0 A^T S_{m - 1} S_{m - 1}^T r_0 + \Sigma_0 A^T s_m (s_m^T b - s_m^T A \mu_0) \\
    & = \mu_{m - 1} + \Sigma_0 A^T s_m (s_m^T b - s_m^T A \mu_0 - \underbrace{s_m^T A \Sigma_0 A^T S_{m - 1}}_{= 0} S_{m - 1}^T r_0) \\
    & = \mu_{m - 1} + \Sigma_0 A^T s_m s_m^T (b - A (\mu_0 + \Sigma_0 A^T S_{m - 1} S_{m - 1}^T r_0)) \\
    & = \mu_{m - 1} + \Sigma_0 A^T s_m s_m^T \underbrace{(b - A \mu_{m - 1})}_{=: r_{m - 1}}
\end{align}

and

\begin{align}
    \Sigma_m
    & = \Sigma_0 - \Sigma_0 A^T S_m S_m^T A \Sigma_0 \\
    & = \Sigma_0 - \Sigma_0 A^T \sum_{i = 1}^m s_i s_i^T A \Sigma_0 \\
    & = \Sigma_0 - \Sigma_0 A^T \sum_{i = 1}^{m - 1} s_i s_i^T A \Sigma_0 - \Sigma_0 A^T s_m s_m^T A \Sigma_0 \\
    & = \Sigma_0 - \Sigma_0 A^T S_{m - 1} S_{m - 1}^T A \Sigma_0 - \Sigma_0 A^T s_m (\Sigma_0 A^T s_m)^T \\
    & = \Sigma_{m - 1} - \Sigma_0 A^T s_m (\Sigma_0 A^T s_m)^T.
\end{align}

We can generate $A \Sigma_0 A^T$-orthonormal search directions $s_m$ on-the-fly as in CG, i.e. $s_m = \frac{\tilde{s}_m}{\norm{\tilde{s}_m}_{A \Sigma_0 A^T}}$, $\tilde{s}_1 = r_0$ and

\begin{align}
    \tilde{s}_m = r_{m - 1} - \inprod{s_{m - 1}}{r_{m - 1}}_{A \Sigma_0 A^T} s_{m - 1}
\end{align}

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import linpde_gp

In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline

from matplotlib_inline.backend_inline import set_matplotlib_formats
# Produce inline figures in vector formats (PDF and SVG).
set_matplotlib_formats("pdf", "svg")

In [None]:
# 1D Poisson boundary value problem on the interval [-1, 1] with constant
# right-hand side 2.0 and zero Dirichlet boundary values at both endpoints.
domain = (-1.0, 1.0)

bvp = linpde_gp.problems.pde.PoissonEquationDirichletProblem(
    domain=domain,
    rhs=linpde_gp.function.Constant(input_shape=(), value=2.0),
    boundary_values=(0.0, 0.0),
)

In [None]:
# Galerkin discretization of `bvp` with a zero-boundary finite element basis of
# 103 elements; `linsys.A` and `linsys.b` form the linear system solved below.
basis = linpde_gp.galerkin.bases.ZeroBoundaryFiniteElementBasis(domain, num_elements=103)
linsys = linpde_gp.galerkin.project(bvp, basis)

In [None]:
def bayescg(A, b, x0=None, maxiter=None, atol=1e-5, rtol=1e-5, callback=None):
    """Solve ``A x = b`` with the Bayesian conjugate gradient method (BayesCG).

    Starting from a Gaussian prior over the solution, every iteration
    conditions the belief on the projection of the system onto a new search
    direction ``s_m``. Search directions follow the CG recurrence
    ``s_{m+1} = r_m + (||r_m||^2 / ||r_{m-1}||^2) s_m``.

    Parameters
    ----------
    A
        System matrix (array or linear operator supporting ``@``, ``.T``,
        ``.shape`` and ``.dtype``).
    b
        Right-hand side vector.
    x0
        Prior over the solution: a ``pn.randvars.Normal`` (its mean and
        covariance are used), an ``np.ndarray`` (used as the prior mean with
        covariance ``A^{-1}``), or ``None`` (zero mean, covariance ``A^{-1}``).
    maxiter
        Maximum number of iterations. Defaults to ``10 * x0.size``.
    atol, rtol
        Iteration stops once ``||b - A x||_2 < max(rtol * ||b||_2, atol)``.
    callback
        Called with keyword arguments once for the prior (``iteration=0``) and
        once after every iteration.

    Returns
    -------
    pn.randvars.Normal
        Belief over the solution. After ``m >= 1`` iterations its covariance is
        the posterior covariance scaled by ``nu / m``, where ``nu`` is the sum
        of the step sizes. If no iteration is performed (the prior mean already
        satisfies the tolerance, or ``maxiter == 0``), the prior is returned.

    Raises
    ------
    TypeError
        If ``x0`` is neither a ``pn.randvars.Normal``, an ``np.ndarray`` nor
        ``None``.
    """
    # Prior construction
    x_dtype = np.result_type(A.dtype, b.dtype)

    if isinstance(x0, pn.randvars.Normal):
        x0, cov0 = x0.mean, x0.cov

        x0 = x0.astype(x_dtype, copy=True)
        cov0 = cov0.astype(x_dtype, copy=True)
    else:
        if isinstance(x0, np.ndarray):
            x0 = x0.astype(x_dtype, copy=True)
        elif x0 is None:
            x0 = np.zeros(A.shape[1], x_dtype)
        else:
            raise TypeError(
                "`x0` must be a `pn.randvars.Normal`, an `np.ndarray` or `None`."
            )

        cov0 = pn.linops.aslinop(A).inv()

    # Stopping criteria
    if maxiter is None:
        maxiter = 10 * x0.size

    res_norm_thresh = np.maximum(rtol * np.linalg.norm(b), atol)

    if callback is None:
        callback = lambda **kwargs: None

    # Initialization. `x` aliases `x0`, which is a private copy, so in-place
    # updates of the mean are safe. The covariance is updated out of place
    # below so that `cov0` (used by every update and handed to the first
    # callback) is never modified.
    x = x0
    cov = cov0
    nu = 0.0
    num_iters = 0

    residual = b - A @ x
    residual_norm_sq = residual.T @ residual

    s = residual

    # Check if initialization meets stopping criteria
    stop = np.linalg.norm(residual, ord=2) < res_norm_thresh

    callback(
        iteration=0,
        x=pn.randvars.Normal(mean=x0.copy(), cov=cov0),
        residual=residual.copy(),
        stop=stop,
        action=None,
    )

    if not stop:
        for m in range(1, maxiter + 1):
            num_iters = m

            # Sigma_0 A^T s_m is needed for the step length, the mean and the
            # covariance update; compute it once as a matrix-vector product
            # (never forming the matrix product Sigma_0 A^T).
            ATs = A.T @ s
            cov0_ATs = cov0 @ ATs
            E_sq = ATs @ cov0_ATs  # ||s_m||^2 in the A Sigma_0 A^T inner product
            alpha = residual_norm_sq / E_sq

            x += alpha * cov0_ATs
            cov = cov - np.outer(cov0_ATs, cov0_ATs) / E_sq
            nu += alpha

            # Update residual
            prev_residual_norm_sq = residual_norm_sq
            residual = b - A @ x
            residual_norm_sq = residual.T @ residual

            stop = np.linalg.norm(residual, ord=2) < res_norm_thresh

            cov_scale = nu / m

            callback(
                iteration=m,
                x=pn.randvars.Normal(mean=x.copy(), cov=cov_scale * cov),
                cov_scale=cov_scale,
                residual=residual,
                stop=stop,
                action=s,
            )

            if stop:
                break

            # CG search direction update
            beta = residual_norm_sq / prev_residual_norm_sq
            s = residual + beta * s

    if num_iters == 0:
        # No observation was made: the prior is the belief (unscaled, as in the
        # iteration-0 callback).
        return pn.randvars.Normal(mean=x, cov=cov)

    return pn.randvars.Normal(
        mean=x,
        cov=(nu / num_iters) * cov,
    )

In [None]:
from typing import Callable, Optional, Union


def problinsolve(
    A: pn.linops.LinearOperatorLike,
    b: np.ndarray,
    prior_x: Optional[Union[pn.randvars.Normal, np.ndarray]] = None,
    auto_cov_type: str = "cg",
    max_num_steps: Optional[int] = None,
    rtol: float = 1e-5,
    atol: float = 1e-5,
    callback: Optional[Callable[..., None]] = None,
) -> pn.randvars.Normal:
    """Solve ``A x = b`` with a solution-based probabilistic linear solver.

    Each step observes ``s^T A x = s^T b`` for the current action ``s`` and
    conditions the current Gaussian belief on it (using the current covariance,
    not the prior one). Actions follow the CG recurrence
    ``s_{k+1} = r_k + (||r_k||^2 / ||r_{k-1}||^2) s_k``.

    Parameters
    ----------
    A
        System matrix; converted with ``pn.linops.aslinop``.
    b
        Right-hand side vector.
    prior_x
        Prior over the solution: a ``pn.randvars.Normal`` (mean and covariance
        are used; ``auto_cov_type`` is ignored), an ``np.ndarray`` (prior mean)
        or ``None`` (zero prior mean).
    auto_cov_type
        Prior covariance used when ``prior_x`` is not a ``Normal``: ``"id"``
        for the identity, ``"cg"`` for ``A^{-1}``.
    max_num_steps
        Maximum number of steps. Defaults to ``10 * x0.size``.
    rtol, atol
        Iteration stops once ``||b - A x||_2 < max(rtol * ||b||_2, atol)``.
    callback
        Called with keyword arguments ``step_idx``, ``x``, ``residual``,
        ``residual_norm_sq``, ``residual_norm``, ``stop``, ``action``,
        ``observation``, ``stepdir`` and ``stepsize`` once for the prior
        (``step_idx=0``, step quantities ``None``) and after every step.

    Returns
    -------
    pn.randvars.Normal
        Posterior belief over the solution.

    Raises
    ------
    TypeError
        If ``prior_x`` is neither a ``Normal``, an ``np.ndarray`` nor ``None``.
    ValueError
        If ``auto_cov_type`` is needed and is neither ``"id"`` nor ``"cg"``.
    """
    A = pn.linops.aslinop(A)
    x_dtype = np.result_type(A.dtype, b.dtype)

    # Prior construction
    if isinstance(prior_x, pn.randvars.Normal):
        x0 = prior_x.mean.astype(x_dtype, copy=True)
        cov0 = prior_x.cov.astype(x0.dtype, copy=True)
    else:
        if isinstance(prior_x, np.ndarray):
            x0 = prior_x.astype(x_dtype, subok=True, copy=True)
        elif prior_x is None:
            x0 = np.zeros(A.shape[1], x_dtype)
        else:
            raise TypeError(
                "`prior_x` must be a `pn.randvars.Normal`, an `np.ndarray` or `None`."
            )

        if auto_cov_type == "id":
            cov0 = np.eye(x0.size, dtype=x0.dtype)
        elif auto_cov_type == "cg":
            cov0 = A.inv()
        else:
            raise ValueError(
                f"Unknown `auto_cov_type` {auto_cov_type!r}; expected 'id' or 'cg'."
            )

    # Stopping criteria
    if max_num_steps is None:
        max_num_steps = 10 * x0.size

    res_norm_thresh = np.maximum(rtol * np.linalg.norm(b), atol)

    if callback is None:
        callback = lambda **kwargs: None

    # Initialization. `x` is a private copy of the prior mean, so it may be
    # updated in place. `cov` is replaced out of place in every step so that
    # covariances already handed to the callback are never modified.
    step_idx = 0

    x = x0
    cov = cov0

    residual = b - A @ x
    residual_norm_sq = np.inner(residual, residual)
    residual_norm = np.sqrt(residual_norm_sq)

    stop = step_idx >= max_num_steps or residual_norm < res_norm_thresh

    callback(
        step_idx=step_idx,
        x=pn.randvars.Normal(mean=x.copy(), cov=cov),
        residual=residual.copy(),
        residual_norm_sq=residual_norm_sq,
        residual_norm=residual_norm,
        stop=stop,
        action=None,
        observation=None,
        stepdir=None,
        stepsize=None,
    )

    # Iteration
    action = residual

    while not stop:
        # Observed projection of the current residual: s^T (b - A x)
        observation = np.inner(action, residual)

        # Gaussian conditioning of the current belief on s^T A x = s^T b
        observation_operator = action.T @ A  # s.T @ A
        cov_xy = cov @ observation_operator.T  # Sigma @ A.T @ s
        gram = observation_operator @ cov_xy  # s.T @ A @ Sigma @ A.T @ s
        # NOTE(review): no pseudo-inverse guard; gram == 0 yields inf/nan.
        gram_pinv = 1.0 / gram

        stepdir = cov_xy
        stepsize = gram_pinv * observation

        x += stepdir * stepsize
        cov = cov - np.outer(cov_xy, cov_xy) * gram_pinv

        # Update residual
        prev_residual_norm_sq = residual_norm_sq

        residual = b - A @ x
        residual_norm_sq = np.inner(residual, residual)
        residual_norm = np.sqrt(residual_norm_sq)

        # Check stopping criteria
        step_idx += 1

        stop = step_idx >= max_num_steps or residual_norm < res_norm_thresh

        callback(
            step_idx=step_idx,
            x=pn.randvars.Normal(mean=x.copy(), cov=cov),
            residual=residual,
            residual_norm_sq=residual_norm_sq,
            residual_norm=residual_norm,
            stop=stop,
            action=action,
            observation=observation,
            stepdir=stepdir,
            stepsize=stepsize,
        )

        if stop:
            break

        # CG-style action update
        action = residual + (residual_norm_sq / prev_residual_norm_sq) * action

    return pn.randvars.Normal(
        mean=x,
        cov=cov,
    )

In [None]:
# Use the BayesCG implementation shipped with `linpde_gp` (not the local
# `bayescg` defined in this notebook) for the experiments below.
solver = linpde_gp.linalg.solvers.bayescg

In [None]:
# Small random symmetric positive definite test system: M M^T is positive
# semidefinite, and adding the identity makes every eigenvalue >= 1.
# A seeded generator keeps the notebook's results reproducible across runs.
rng = np.random.default_rng(seed=0)
M = rng.standard_normal((10, 10))
A = M @ M.T + np.eye(10)
b = rng.standard_normal(10)

In [None]:
# Reference solution from a dense direct solver, for comparison with the
# probabilistic solver's posterior mean below.
np.linalg.solve(A, b)

In [None]:
# Run the probabilistic solver on the small test system and display the
# posterior mean and variance.
sol = solver(A, b)
sol.mean, sol.var

In [None]:
# Solve the FEM system with at most 40 solver iterations, then turn the belief
# over basis coefficients into a function via `basis.coords2fn`.
u_fem_coords = solver(linsys.A, linsys.b, maxiter=40)
u_fem = basis.coords2fn(u_fem_coords)

In [None]:
# Compare the exact solution with the (probabilistic) FEM solution on a
# 200-point grid over the domain.
xs_plot = np.linspace(*domain, 200)

plt.plot(xs_plot, bvp.solution(xs_plot), label="Exact Solution")
u_fem.plot(plt.gca(), xs_plot, color="C1", label="FEM Solution")
# Without a legend call the curve labels are never displayed.
plt.legend()
plt.show()

In [None]:
from matplotlib import animation


def animate_probsolve_poisson_1d(
    basis,
    linsys=None,
    solver=linpde_gp.linalg.solvers.bayescg,
    **solver_kwargs
):
    """Animate a probabilistic linear solver on the 1D Poisson FEM system.

    Runs ``solver`` once on the Galerkin system and records, through the
    solver callback, the solution belief, residual norm and action of every
    step. Returns an animation with one frame per recorded step showing the
    current solution estimate against the exact solution, the residual-norm
    history, and the normalized ``A``-inner products between the actions
    taken so far.

    Note: reads the notebook-level globals ``bvp`` and ``domain``.

    Parameters
    ----------
    basis
        Galerkin basis; ``len(basis)`` is shown in the title and
        ``basis.coords2fn`` turns a coefficient belief into a function.
    linsys
        Discretized system with attributes ``A`` and ``b``. Built with
        ``linpde_gp.galerkin.project(bvp, basis)`` if omitted.
    solver
        Called as ``solver(A, b, callback=..., **solver_kwargs)``; its
        callback must supply the keywords ``x``, ``residual`` and ``action``.
    **solver_kwargs
        Forwarded to ``solver``.

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    n = len(basis)

    if linsys is None:
        # Same Galerkin projection as used for `linsys` elsewhere in the notebook
        linsys = linpde_gp.galerkin.project(bvp, basis)

    # Run the algorithm and log step statistics
    step_xs = []
    step_residual_norms = []
    step_actions = []

    def _callback(x: pn.randvars.Normal, residual: np.ndarray, action: np.ndarray, **kwargs):
        step_xs.append(x)
        step_residual_norms.append(np.linalg.norm(residual, ord=2))

        # The prior step (no action yet) is reported with `action=None`
        if action is not None:
            step_actions.append(action.copy())

    solver(
        linsys.A,
        linsys.b,
        callback=_callback,
        **solver_kwargs,
    )

    fig, ax = plt.subplots(ncols=3, figsize=(22, 6))
    # Close the figure right away, presumably so the notebook does not also
    # render it as a static image; FuncAnimation still draws into `fig`.
    plt.close()

    xs_plot = np.linspace(*domain, 200)

    # Only needed (and only well defined) once at least one action was taken
    action_inprods = (
        linpde_gp.linalg.pairwise_inprods(
            step_actions,
            inprod=linsys.A,
            normalize=True,
        )
        if step_actions
        else None
    )

    if isinstance(basis, linpde_gp.galerkin.bases.ZeroBoundaryFiniteElementBasis):
        basis_str = "Zero Boundary FEM"
    elif isinstance(basis, linpde_gp.galerkin.bases.FiniteElementBasis):
        basis_str = "FEM"
    else:
        basis_str = "Unknown Basis"

    if solver is linpde_gp.linalg.solvers.bayescg:
        solver_str = "BayesCG"
    elif solver is linpde_gp.linalg.solvers.problinsolve:
        solver_str = "problinsolve"
    else:
        solver_str = "Unknown Solver"

    def animate(step_idx):
        for axis in ax:
            axis.cla()

        u = basis.coords2fn(coords=step_xs[step_idx])

        fig.suptitle(f"1D Poisson - {basis_str} (N = {n}) - {solver_str} - Iteration {step_idx:03d}")

        ax[0].set_title("Solution")
        ax[0].plot(xs_plot, bvp.solution(xs_plot), label="Exact Solution")
        u.plot(ax[0], xs_plot, color="C1", label="FEM Solution")
        ax[0].legend()

        ax[1].set_title("Residual Norm")
        ax[1].semilogy(step_residual_norms[:step_idx + 1], "C0")
        ax[1].set_xlabel("Iterations")

        ax[2].set_title("Normalized $A$-inner products between the actions")

        # Frame `step_idx` shows the `step_idx` actions taken up to that step
        if step_idx > 0:
            ax[2].imshow(
                action_inprods[:step_idx, :step_idx],
                cmap="bwr",
                vmin=-1.0,
                vmax=1.0,
            )

    return animation.FuncAnimation(
        fig,
        func=animate,
        frames=len(step_xs),
        interval=200,
        repeat_delay=4000,
        blit=False,
    )

In [None]:
from IPython.display import HTML

# Animate `solver` on the FEM system, allowing at most one iteration per basis
# function; `reorthogonalize=True` is forwarded to `solver` via
# `**solver_kwargs`.
anim = animate_probsolve_poisson_1d(
    basis,
    linsys=linsys,
    solver=solver,
    maxiter=len(basis),
    reorthogonalize=True,
)

# Display the animation as an embedded JavaScript/HTML player.
HTML(anim.to_jshtml())

In [None]:
from pathlib import Path

# Save the animation as a GIF at 5 frames per second. Create the output
# directory first, since saving fails if it does not exist yet.
results_dir = Path("../results")
results_dir.mkdir(parents=True, exist_ok=True)
anim.save(str(results_dir / "fem_probsolve.gif"), animation.PillowWriter(fps=5))