# Gaussian Process Solution to the Dirichlet Problem for the 1D Poisson Equation

In the following, we will solve the **Poisson equation** subject to **Dirichlet boundary conditions**, i.e. we want to find a function $u: \Omega \subset \mathbb{R}^d \to \mathbb{R}$, which fulfills

\begin{equation}
    \begin{cases}
        -\Delta u(x) = f(x) & \text{if } x \in \operatorname{int} \Omega \\
        u(x) = g(x)         & \text{if } x \in \partial \Omega,
    \end{cases}
\end{equation}

where $$\Delta := \sum_{i = 1}^d \frac{\partial^2}{\partial x_i^2}$$ is the **Laplace operator**.
For simplicity, we set $d = 1$ and $\Omega = [l, r] \subset \mathbb{R}$, which means that the problem reduces to

$$
    \begin{cases}
        -u''(x) = f(x) & \text{for } x \in (l, r) \\
        u(x) = g(x) & \text{for } x \in \{l, r\}
    \end{cases}
$$

In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn
import scipy.linalg

import linpde_gp

In [None]:
from experiment_utils import config

# Experiment configuration consumed by `experiment_utils`. The exact meaning of
# each field is defined there — NOTE(review): presumably used to name/locate
# output artifacts and select the figure style target; confirm.
config.experiment_name = "0000_poisson_dirichlet_1d_naive"
config.target = "jmlr"
config.debug_mode = True

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Apply the global matplotlib style bundle returned by the experiment config.
plt.rcParams.update(config.tueplots_bundle())

## Problem Definition

In [None]:
# Endpoints (l, r) of the interval Ω = [l, r] on which the problem is posed.
domain = -1.0, 1.0

In [None]:
def diffop(f: callable, /, argnum: int = 0) -> callable:
    """Build the negative Laplacian ``-Δ`` of ``f`` with respect to one argument.

    Args:
        f: A JAX-differentiable function.
        argnum: Index of the positional argument of ``f`` the operator acts on.

    Returns:
        A jitted function with the same call signature as ``f`` that returns
        ``-trace(∇² f)``, the Hessian being taken in argument ``argnum``.
    """
    hessian_of_f = jax.hessian(f, argnums=argnum)

    def negative_laplacian(*args, **kwargs) -> jnp.ndarray:
        # A scalar argument yields a 0-d Hessian; promote it to a 1x1 matrix so
        # the trace is well defined for every input dimension.
        hess = jnp.atleast_2d(hessian_of_f(*args, **kwargs))
        return -jnp.trace(hess)

    return jax.jit(negative_laplacian)

In [None]:
# Boundary Values
g = np.asarray((0.0, 1.0))

# RHS
_f_const = 2.0
f = lambda x: np.full_like(x, _f_const)

# True Solution
_a = -(_f_const / 2.0)
_b = (g[1] - g[0]) / (domain[1] - domain[0])
_c = g[0]
u_star = lambda x: (_a * (x - domain[1]) + _b) * (x - domain[0]) + _c

# PDE Measurement Locations
X_pde = np.linspace(-0.8, 0.8, 3)

In [None]:
# Dense, evenly spaced evaluation grid over the domain, used for plotting.
plt_grid = np.linspace(domain[0], domain[1], num=100)

## Prior

In [None]:
# Zero-mean GP prior over u with an exponentiated-quadratic kernel
# (lengthscale 1.0), scaled by the output variance 2.0 ** 2.
_prior_mean = linpde_gp.functions.Zero(input_shape=())
_prior_kernel = linpde_gp.randprocs.kernels.ExpQuad(
    input_shape=(),
    lengthscales=1.0,
)
u_prior = pn.randprocs.GaussianProcess(
    mean=_prior_mean,
    cov=2.0 ** 2 * _prior_kernel,
)

In [None]:
# Prior samples of u compared against the true solution u_star.
# LaTeX labels use raw strings: sequences such as `\s` are invalid Python
# escapes and trigger a SyntaxWarning on Python 3.12+.
u_prior.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=np.random.default_rng(24),
    color="C0",
    label="$u$",
)

plt.plot(
    plt_grid,
    u_star(plt_grid),
    color="C1",
    label=r"$u^\star$",
)

plt.legend()
plt.show()

## Prior Predictive

In [None]:
def apply_lindiffop_to_gp(
    gp: pn.randprocs.GaussianProcess,
    lindiffop: callable,
) -> pn.randprocs.GaussianProcess:
    """Push a Gaussian process through a linear differential operator ``L``.

    If ``u ~ GP(m, k)``, then ``L[u] ~ GP(L[m], L k L')``, where ``L'`` denotes
    the operator applied to the second argument of the kernel.

    Args:
        gp: The process ``u``. Its mean and covariance must expose JAX-traceable
            callables via their ``.jax`` attribute.
        lindiffop: Operator factory called as ``lindiffop(fn, argnum=i)``,
            returning ``L`` applied to positional argument ``i`` of ``fn``.

    Returns:
        The Gaussian process ``L[u]`` with scalar output.
    """
    shared_kwargs = dict(
        input_shape=gp.input_shape,
        output_shape=(),
        vectorize=True,
    )

    mean_fn = lindiffop(gp.mean.jax, argnum=0)  # L[m]
    crosscov_fn = lindiffop(gp.cov.jax, argnum=1)  # k(x0, L' x1)
    cov_fn = lindiffop(crosscov_fn, argnum=0)  # L k L'

    pushed_mean = linpde_gp.functions.JaxLambdaFunction(mean_fn, **shared_kwargs)
    pushed_cov = linpde_gp.randprocs.kernels.JaxLambdaKernel(cov_fn, **shared_kwargs)

    return pn.randprocs.GaussianProcess(mean=pushed_mean, cov=pushed_cov)

In [None]:
# Prior over D[u] = -u'' implied by the prior over u.
Du_prior = apply_lindiffop_to_gp(gp=u_prior, lindiffop=diffop)

In [None]:
# Samples of the prior over D[u] compared against the right-hand side f.
# Raw string: `\m` in the LaTeX label is an invalid Python escape sequence
# (SyntaxWarning on Python 3.12+).
Du_prior.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=np.random.default_rng(24),
    color="C0",
    label=r"$\mathcal{D}[u]$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    color="C1",
    label="$f$",
)

plt.legend()
plt.show()

## Conditioning on the PDE

In [None]:
def condition_gp_on_observations(
    gp: pn.randprocs.GaussianProcess,
    X: np.ndarray,
    Y: np.ndarray,
    lindiffop: callable = None,
) -> pn.randprocs.GaussianProcess:
    """Condition a Gaussian process on exact (noise-free) linear observations.

    Computes the posterior of ``u ~ GP(m, k)`` given ``L[u](X) = Y``, where
    ``L`` is ``lindiffop`` (the identity if ``lindiffop`` is ``None``)::

        m_post(x)      = m(x) + k_L(x, X) G^{-1} (Y - L[m](X))
        k_post(x0, x1) = k(x0, x1) - k_L(x0, X) G^{-1} k_L(x1, X)^T

    with cross-covariance ``k_L(x, x') = Cov[u(x), L[u](x')]`` (operator
    applied to the second kernel argument) and Gram matrix
    ``G = Cov[L[u](X), L[u](X)]``.

    Args:
        gp: Prior process ``u``. Its mean and covariance must expose
            JAX-traceable callables via their ``.jax`` attribute.
        X: Observation locations, indexed along the first axis.
        Y: Observed values of ``L[u]`` at ``X``; one entry per location.
        lindiffop: Optional operator factory called as
            ``lindiffop(fn, argnum=i)``. If ``None``, ``u`` itself is observed.

    Returns:
        The posterior Gaussian process.

    Raises:
        ValueError: If ``X`` and ``Y`` contain different numbers of entries.
        numpy.linalg.LinAlgError: If the Gram matrix is not numerically
            positive definite (raised by ``scipy.linalg.cho_factor``).
    """
    X = np.asarray(X)
    Y = np.asarray(Y)

    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            "`X` and `Y` must contain the same number of observations, "
            f"got {X.shape[0]} and {Y.shape[0]}."
        )

    if lindiffop is None:
        gp_pred = gp
        crosscov = gp.cov
    else:
        gp_pred = apply_lindiffop_to_gp(gp, lindiffop)
        # Cov[u(x), L[u](x')]: the operator acts on the second kernel argument.
        crosscov = linpde_gp.randprocs.kernels.JaxLambdaKernel(
            lindiffop(gp.cov.jax, argnum=1),
            input_shape=gp.input_shape,
            output_shape=gp.output_shape,
            vectorize=True,
        )

    gramXX = gp_pred.cov(X[:, None], X[None, :])
    gramXX_cho = scipy.linalg.cho_factor(gramXX)

    # The representer weights G^{-1} (Y - L[m](X)) do not depend on the
    # evaluation point, so solve for them once here rather than on every call
    # of the posterior mean.
    representer_weights = scipy.linalg.cho_solve(
        gramXX_cho, Y - gp_pred.mean(X)
    )

    @jax.jit
    def cond_mean(x: jnp.ndarray) -> jnp.ndarray:
        return gp.mean.jax(x) + crosscov.jax(x, X) @ representer_weights

    @jax.jit
    def cond_cov(x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray:
        gp_cov_x0_x1 = gp.cov.jax(x0, x1)
        crosscov_x0_X = crosscov.jax(x0, X)
        # Cov[L[u](X), u(x1)] is the transpose of Cov[u(x1), L[u](X)].
        crosscov_X_x1 = crosscov.jax(x1, X).T
        return gp_cov_x0_x1 - crosscov_x0_X @ jax.scipy.linalg.cho_solve(
            gramXX_cho, crosscov_X_x1
        )

    return pn.randprocs.GaussianProcess(
        mean=linpde_gp.functions.JaxLambdaFunction(
            cond_mean,
            input_shape=gp.input_shape,
            output_shape=gp.output_shape,
            vectorize=True,
        ),
        cov=linpde_gp.randprocs.kernels.JaxLambdaKernel(
            cond_cov,
            input_shape=gp.input_shape,
            output_shape=gp.output_shape,
            vectorize=True,
        ),
    )

In [None]:
# Exact observations of D[u] = -u'' at the interior collocation points X_pde.
Y_pde = f(X_pde)

u_cond_pde = condition_gp_on_observations(u_prior, X_pde, Y_pde, lindiffop=diffop)

In [None]:
# Posterior after conditioning on the PDE only, compared to the true solution.
# The PDE constraints are drawn with `plot_local_curvature` from the posterior
# mean value and slope at X_pde and the prescribed second derivative u'' = -f.
# Raw string: `\m` in the LaTeX label is an invalid Python escape sequence
# (SyntaxWarning on Python 3.12+).
u_cond_pde.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=np.random.default_rng(24),
    color="C0",
    label=r"$u \mid \mathcal{D}[u] = f(X_{pde})$",
)

plt.plot(
    plt_grid,
    u_star(plt_grid),
    color="C1",
    label=r"$u^\star$",
)

linpde_gp.utils.plotting.plot_local_curvature(
    ax=plt.gca(),
    xs=X_pde,
    f_xs=u_cond_pde.mean(X_pde),
    ddf_xs=-f(X_pde),
    df_xs=jnp.vectorize(jax.grad(u_cond_pde.mean.jax))(X_pde),
    color="C3",
    label="$f(X_{pde})$",
)

plt.legend()
plt.show()

## Posterior Predictive

In [None]:
# Posterior predictive over D[u] = -u'' given the PDE observations.
Du_cond_pde = apply_lindiffop_to_gp(gp=u_cond_pde, lindiffop=diffop)

In [None]:
# Posterior predictive of D[u] compared to f, with the PDE observations shown
# as zero-width error bars.
# Raw string: `\m` in the LaTeX label is an invalid Python escape sequence
# (SyntaxWarning on Python 3.12+).
Du_cond_pde.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=np.random.default_rng(24),
    color="C0",
    label=r"$\mathcal{D}[u] \mid \mathcal{D}[u] = f(X_{pde})$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    color="C1",
    label="$f$",
)

plt.errorbar(
    X_pde,
    Y_pde,
    yerr=0,
    fmt="+",
    capsize=2,
    color="C3",
    label="$f(X_{pde})$",
)

plt.legend()
plt.show()

## Conditioning on the Boundary Values

In [None]:
# Additionally condition on the Dirichlet data u(domain[0]) = g[0],
# u(domain[1]) = g[1], observed directly (no differential operator).
u_cond_pde_bv = condition_gp_on_observations(
    gp=u_cond_pde,
    X=np.asarray(domain),
    Y=g,
)

In [None]:
# Final posterior (PDE + boundary values) compared to the true solution. The
# PDE constraints are drawn with `plot_local_curvature` (prescribed second
# derivative u'' = -f at X_pde); the boundary values are zero-width error bars.
# Raw string: `\m`, `\p` and `\O` in the LaTeX label are invalid Python escape
# sequences (SyntaxWarning on Python 3.12+).
u_cond_pde_bv.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=np.random.default_rng(24),
    color="C0",
    label=r"$u \mid \mathcal{D}[u] = f(X_{pde}), u|_{\partial \Omega} = g$",
)

plt.plot(
    plt_grid,
    u_star(plt_grid),
    color="C1",
    label=r"$u^\star$",
)

linpde_gp.utils.plotting.plot_local_curvature(
    ax=plt.gca(),
    xs=X_pde,
    f_xs=u_cond_pde_bv.mean(X_pde),
    ddf_xs=-f(X_pde),
    df_xs=jnp.vectorize(jax.grad(u_cond_pde_bv.mean.jax))(X_pde),
    color="C3",
    label="$f(X_{pde})$",
)

plt.errorbar(
    domain,
    g,
    yerr=0,
    fmt="+",
    capsize=2,
    color="C2",
    label="$g$",
)

plt.legend()
plt.show()