In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import probnum_galerkin as pngal

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

from matplotlib_inline.backend_inline import set_matplotlib_formats
# Ask the inline backend to emit figures in the "pdf" and "svg" formats.
set_matplotlib_formats("pdf", "svg")

In [None]:
# Switch on JAX's "jax_enable_x64" configuration flag before any arrays are created.
jax.config.update("jax_enable_x64", True)

In [None]:
import enum
from typing import Any, Callable, Optional


class Backend(enum.Enum):
    """Array backends that a `BackendDispatcher` can route a call to."""

    NUMPY = "numpy"
    JAX = "jax"


# Backend currently in effect; `BackendDispatcher.__call__` looks this global up on every call.
BACKEND = Backend.JAX
    

class BackendDispatcher:
    """Callable that forwards to the implementation registered for the active backend.

    Parameters
    ----------
    numpy_impl
        Implementation used when ``BACKEND`` is ``Backend.NUMPY``; ``None`` registers nothing.
    jax_impl
        Implementation used when ``BACKEND`` is ``Backend.JAX``; ``None`` registers nothing.

    Calling the dispatcher while no implementation is registered for the
    current ``BACKEND`` fails with the ``KeyError`` of the internal lookup.
    """

    def __init__(
        self,
        numpy_impl: Optional[Callable[..., Any]],
        jax_impl: Optional[Callable[..., Any]]=None,
    ):
        candidates = (
            (Backend.NUMPY, numpy_impl),
            (Backend.JAX, jax_impl),
        )
        # Only backends that were actually given an implementation are registered.
        self._impl = {
            backend: impl for backend, impl in candidates if impl is not None
        }

    def __call__(self, *args, **kwargs) -> Any:
        # The global BACKEND is read at call time, not at construction time.
        implementation = self._impl[BACKEND]
        return implementation(*args, **kwargs)

In [None]:
class JaxMean:
    """Mean function backed by a JAX callable, usable from either backend.

    Parameters
    ----------
    m
        The JAX mean function.
    vectorize
        If true, `m` is wrapped with `jax.numpy.vectorize` using the signature
        ``(d)->()``, i.e. it is treated as mapping one input vector to a scalar.
        If false, `m` is used as given.
    """

    def __init__(self, m, vectorize=True):
        mean_fn = jax.numpy.vectorize(m, signature="(d)->()") if vectorize else m

        def _as_numpy(x):
            # NumPy backend: evaluate with JAX, then convert the result.
            return np.asarray(mean_fn(x))

        self._dispatcher = BackendDispatcher(numpy_impl=_as_numpy, jax_impl=mean_fn)

    def __call__(self, x):
        """Evaluate the mean at `x` through the backend dispatcher."""
        return self._dispatcher(x)


class JaxKernel(pn.kernels.Kernel):
    """probnum kernel whose values come from a JAX function of two input vectors.

    Parameters
    ----------
    k
        JAX callable; it is wrapped with `jax.numpy.vectorize` using the
        signature ``(d),(d)->()``.
    input_dim
        Forwarded unchanged to `pn.kernels.Kernel.__init__`.
    """

    def __init__(self, k, input_dim):
        self._k = jax.numpy.vectorize(k, signature="(d),(d)->()")
        super().__init__(input_dim)

    def _evaluate(self, x0, x1) -> np.ndarray:
        """Evaluate the kernel and return a NumPy array; `x1=None` means `x1 = x0`."""
        second = x0 if x1 is None else x1
        return np.array(self._k(x0, second))

    def jax(self, x0, x1):
        """Evaluate the vectorized kernel and return the JAX result unconverted."""
        return self._k(x0, x1)
    
def laplace(f, argnum=0):
    """Return a jitted function computing the Laplacian of `f` in one argument.

    Parameters
    ----------
    f
        Function to differentiate.
    argnum
        Position of the argument with respect to which the Hessian is taken.

    Returns
    -------
    callable
        Takes the same positional arguments as `f` and returns the trace of the
        Hessian of `f` with respect to argument `argnum`.
    """
    hessian_fn = jax.hessian(f, argnum)

    def _hessian_trace(*args):
        # `atleast_2d` lets a scalar Hessian (scalar argument) pass through `trace`.
        hess = jnp.atleast_2d(hessian_fn(*args))
        return jnp.trace(hess)

    return jax.jit(_hessian_trace)

In [None]:
# Seeded NumPy random generator; passed as `rng=` to the `.plot(...)` calls below.
rng = np.random.default_rng(24)

In [None]:
# 100 evenly spaced points on [-1, 1]; the x-axis of every plot below.
grid = np.linspace(-1, 1, 100)

## Problem Definition

In [None]:
# Problem variant with a constant right-hand side.
# NOTE: every name assigned in this cell (f, boundary, g, u, X, fX_std) is
# reassigned by the following cell, so on a top-to-bottom run the values
# defined here are not the ones used further down.

# PDE RHS
f = lambda x: np.full_like(x, -2.0)

# Probabilistic Boundary Conditions
# Boundary points and a Gaussian over the boundary values; the covariance is
# a diagonal of zeros (0.0 ** 2), i.e. noise-free boundary values.
boundary = np.array([-1.0, 1.0])
g = pn.randvars.Normal(
    mean=np.array([0.0, 1.0]),
    cov=np.diag(np.full_like(boundary, 0.0 ** 2))
)

# True Solution
# Reads `g` and `boundary` from the global scope at call time.
u = lambda x: -x ** 2 + (g.mean[1] - g.mean[0]) / (boundary[1] - boundary[0]) * (x - boundary[0]) + 1.0

# PDE Measurements
# Three measurement locations; the fill value is already squared (0.0 ** 2).
X = np.linspace(-0.8, 0.8, 3)
fX_std = np.full_like(X, 0.0 ** 2)

In [None]:
# Problem variant with a sinusoidal right-hand side. This cell reassigns
# f, boundary, g, u, X and fX_std from the previous cell.

# PDE RHS
f = lambda x: np.pi ** 2 * np.sin(np.pi * x)

# Probabilistic Boundary Conditions
# Zero-mean boundary values with an all-zero covariance (0.0 ** 2 on the diagonal).
boundary = np.array([-1.0, 1.0])
g = pn.randvars.Normal(
    mean=np.array([0.0, 0.0]),
    cov=np.diag(np.full_like(boundary, 0.0 ** 2))
)

# True Solution
# Its second derivative is pi^2 * sin(pi * x), which equals `f` above.
u = lambda x: -jnp.sin(jnp.pi * x)

# PDE Measurements
# NOTE: despite its name, `fX_std` holds the *squared* value (`... ** 2`).
X = np.linspace(-0.8, 0.8, 3)
fX_std = np.full_like(X, 0.0) ** 2

In [None]:
# Number of PDE measurement locations; interpolated into the plot labels below.
N = X.size

# Gaussian observations of the right-hand side at X.
# NOTE: `fX_std` is placed directly on the covariance diagonal, i.e. it is used
# as a variance here (it was already squared where it was defined).
fX = pn.randvars.Normal(mean=f(X), cov=np.diag(fX_std))

## Prior

In [None]:
# Hyperparameters of the prior covariance; read as globals inside `prior_cov`.
lengthscale = 1.0
output_scale = 2.0

@jax.jit
def prior_mean(x):
    """Zero prior mean.

    `x` is indexed as ``x[..., 0]``, so the result is an array of zeros with
    the shape of `x` minus its last axis.
    """
    # Alternative prior means that were tried in this notebook:
    #     -0.5 * x[..., 0] ** 2 + 0.5
    #     jnp.sin(jnp.pi * x)
    first_coordinate = x[..., 0]
    return jnp.zeros_like(first_coordinate)

@jax.jit
def prior_cov(x0, x1):
    """Squared-exponential covariance between `x0` and `x1`.

    Computes ``output_scale**2 * exp(-||x0 - x1||^2 / (2 * lengthscale**2))``,
    where the squared norm sums over the last axis. `output_scale` and
    `lengthscale` are read from the global scope.
    """
    difference = x0 - x1
    sq_distance = jnp.sum(difference ** 2, axis=-1)
    inverse_scale = 1.0 / (2.0 * lengthscale ** 2)
    amplitude = output_scale ** 2

    return amplitude * jnp.exp(-inverse_scale * sq_distance)

# Prior over the solution u. `prior_mean` already handles batched inputs via
# `x[..., 0]`, so it is passed through without the extra vectorize wrapper.
prior_gp = pn.randprocs.GaussianProcess(
    mean=JaxMean(prior_mean, vectorize=False),
    cov=JaxKernel(prior_cov, input_dim=1),
)

In [None]:
# Prior process with samples, plotted against the true solution.
# The label is a raw string: "\s" and "\m" are not valid Python escape
# sequences and trigger a SyntaxWarning in a normal string literal.
prior_gp.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$u \sim \mathcal{GP}(m, k)$",
)

plt.plot(
    grid,
    u(grid),
    label="$u^*$",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_00_prior.pdf", dpi=300)
plt.show()

## Posterior (Boundary Values First)

### Conditioning on Boundary Conditions

In [None]:
def condition_gp_on_observations(gp: pn.randprocs.GaussianProcess, X: np.ndarray, fX: pn.randvars.Normal):
    """Condition `gp` on Gaussian observations of its values at the points `X`.

    Parameters
    ----------
    gp
        Process to condition. Its `_meanfun` and `_covfun.jax` are evaluated
        directly, so the covariance must provide a `jax` method (as `JaxKernel` does).
    X
        Observation locations. Indexed as ``X[:, None, :]``, so it must be
        two-dimensional (presumably ``(num_points, input_dim)``).
    fX
        Observations: `fX.mean` holds the observed values and `fX.cov` is added
        to the Gram matrix as observation noise covariance.

    Returns
    -------
    pn.randprocs.GaussianProcess
        Process whose mean and covariance are the conditional mean and
        covariance given the observations, wrapped in `JaxMean` / `JaxKernel`.
    """
    mX = gp._meanfun(X)
    kXX = gp._covfun.jax(X[:, None, :], X[None, :, :]) + fX.cov
    L_kXX = jax.scipy.linalg.cho_factor(kXX)
    # These weights do not depend on the evaluation point, so the linear
    # system is solved once here instead of inside every mean evaluation.
    weights = jax.scipy.linalg.cho_solve(L_kXX, (fX.mean - mX))

    @jax.jit
    def cond_mean(x):
        mx = gp._meanfun(x)
        kxX = gp._covfun.jax(x, X)
        return mx + kxX @ weights

    @jax.jit
    def cond_cov(x0, x1):
        kxx = gp._covfun.jax(x0, x1)
        kxX = gp._covfun.jax(x0, X)
        kXx = gp._covfun.jax(X, x1)
        return kxx - kxX @ jax.scipy.linalg.cho_solve(L_kXX, kXx)

    cond_gp = pn.randprocs.GaussianProcess(
        mean=JaxMean(cond_mean),
        cov=JaxKernel(
            cond_cov,
            # Inherit the input dimension from `gp` instead of hard-coding 1,
            # matching `condition_gp_on_predictive_gp`.
            input_dim=gp.input_dim,
        ),
    )

    return cond_gp

# Attach the function to probnum's GaussianProcess class (monkey patch) so it
# can be called as `gp.condition_on_observations(X, fX)`.
pn.randprocs.GaussianProcess.condition_on_observations = condition_gp_on_observations

In [None]:
# Condition the prior on the boundary values `g`; `boundary[:, None]` turns
# the two boundary points into a column of one-dimensional inputs.
u_bc = prior_gp.condition_on_observations(boundary[:, None], g)

In [None]:
# Process conditioned on the boundary values, with samples.
u_bc.plot(plt.gca(), grid, num_samples=10, rng=rng, label=r"$u \mid u(\partial \Omega) = g(\partial \Omega)$")

# True solution for comparison.
plt.plot(grid, u(grid), label=r"$u^*$")

# Boundary observations with their standard deviations as error bars.
plt.errorbar(boundary, g.mean, yerr=g.std, fmt="+", capsize=2, color="C2", label=r"$g(\partial \Omega)$")

plt.legend()
plt.savefig("../results/0006_poisson_rbf_bcfirst_01_cond_bc.pdf", dpi=300)
plt.show()

### Predictive Induced by $\Delta$

In [None]:
def laplace_gp(gp: pn.randprocs.GaussianProcess):
    """Build the process obtained by applying the Laplacian to `gp`.

    Parameters
    ----------
    gp
        Process whose `_meanfun` and `_covfun.jax` are differentiated with JAX.

    Returns
    -------
    lapf : pn.randprocs.GaussianProcess
        Process whose mean is the Laplacian of `gp`'s mean and whose covariance
        is the Laplacian of `gp`'s covariance in both arguments.
    crosscov : callable
        Laplacian of `gp`'s covariance with respect to its second argument
        only, vectorized with signature ``(d),(d)->()``.
    """
    mean = laplace(gp._meanfun)
    # Differentiate in the second argument first, then in the first one.
    cov = laplace(laplace(gp._covfun.jax, 1), 0)
    crosscov = jnp.vectorize(laplace(gp._covfun.jax, 1), signature="(d),(d)->()")

    lapf = pn.randprocs.GaussianProcess(
        mean=JaxMean(mean),
        # Inherit the input dimension from `gp` instead of hard-coding 1,
        # matching `condition_gp_on_predictive_gp`.
        cov=JaxKernel(cov, input_dim=gp.input_dim),
    )

    return lapf, crosscov

# Attach the function to probnum's GaussianProcess class (monkey patch) so it
# can be called as `gp.laplace()`.
pn.randprocs.GaussianProcess.laplace = laplace_gp

In [None]:
# Laplacian of the boundary-conditioned process, plus the cross-covariance
# function that `laplace_gp` returns alongside it.
laplace_u_bc, laplace_u_bc_crosscov = u_bc.laplace()

In [None]:
# Laplacian of the boundary-conditioned process, compared with the right-hand side f.
# The label is a raw string: "\D", "\m", "\p" and "\O" are not valid Python
# escape sequences and trigger a SyntaxWarning in a normal string literal.
laplace_u_bc.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u \mid u(\partial \Omega) = g(\partial \Omega)$"
)

plt.plot(
    grid,
    f(grid),
    color="C3",
    label="f",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_bcfirst_02_cond_bc_laplace.pdf", dpi=300)
plt.show()

### Conditioning on the PDE

In [None]:
def condition_gp_on_predictive_gp(f: pn.randprocs.GaussianProcess, Lf: pn.randprocs.GaussianProcess, jax_crosscov, X: np.ndarray, LfX: pn.randvars.Normal):
    """Condition `f` on Gaussian observations of the related process `Lf` at `X`.

    Parameters
    ----------
    f
        Process to condition; its `_meanfun` and `_covfun.jax` are evaluated directly.
    Lf
        Process that is observed (in this notebook, the Laplacian of `f` as
        returned by `laplace_gp`).
    jax_crosscov
        Callable ``(x, X) -> cross-covariance`` between `f` at `x` and `Lf` at
        the rows of `X`. It is called with a single point and the full `X`, so
        it must broadcast over the leading axis of `X`.
    X
        Observation locations. Indexed as ``X[:, None, :]``, so it must be
        two-dimensional (presumably ``(num_points, input_dim)``).
    LfX
        Observations of `Lf` at `X`: `LfX.mean` holds the values and `LfX.cov`
        is added to the Gram matrix as observation noise covariance.

    Returns
    -------
    pn.randprocs.GaussianProcess
        `f` conditioned on the observations, with the same `input_dim` as `f`.
    """
    LmX = Lf._meanfun(X)
    gramXX = Lf._covfun.jax(X[:, None, :], X[None, :, :]) + LfX.cov
    gramXX_cho = jax.scipy.linalg.cho_factor(gramXX)
    # These weights do not depend on the evaluation point, so the linear
    # system is solved once here instead of inside every mean evaluation.
    weights = jax.scipy.linalg.cho_solve(gramXX_cho, (LfX.mean - LmX))

    @jax.jit
    def pred_cond_mean(x):
        mx = f._meanfun(x)
        # Called with `x` directly, exactly as in `pred_cond_cov` below.
        kLxX = jax_crosscov(x, X)
        return mx + kLxX @ weights

    @jax.jit
    def pred_cond_cov(x0, x1):
        kxx = f._covfun.jax(x0, x1)
        kLxX = jax_crosscov(x0, X)
        LkXx = jax_crosscov(x1, X).T
        return kxx - kLxX @ jax.scipy.linalg.cho_solve(gramXX_cho, LkXx)

    cond_gp = pn.randprocs.GaussianProcess(
        mean=JaxMean(pred_cond_mean),
        cov=JaxKernel(pred_cond_cov, input_dim=f.input_dim),
    )

    return cond_gp

# Attach the function to probnum's GaussianProcess class (monkey patch) so it
# can be called as `gp.condition_on_predictive_gp(Lf, crosscov, X, LfX)`.
pn.randprocs.GaussianProcess.condition_on_predictive_gp = condition_gp_on_predictive_gp

In [None]:
# Condition the boundary-conditioned process on the right-hand-side
# observations `fX` of its Laplacian at the measurement points X.
u_bc_pde = u_bc.condition_on_predictive_gp(laplace_u_bc, laplace_u_bc_crosscov, X[:, None], fX)

In [None]:
# Process conditioned on boundary values and PDE observations.
# Labels containing LaTeX backslashes are raw strings: "\m", "\p", "\O", "\D"
# and "\d" are not valid Python escape sequences and trigger a SyntaxWarning
# in a normal string literal.
u_bc_pde.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$u \mid u(\partial \Omega) = g(\partial \Omega), \Delta u(x_i) = f(x_i)$",
)

plt.plot(
    grid,
    u(grid),
    color="C1",
    label="$u^*$",
)

plt.errorbar(
    boundary,
    g.mean,
    yerr=g.std,
    fmt="+",
    capsize=2,
    color="C2",
    label=r"$g(\partial \Omega)$"
)

pngal.plotting.plot_local_curvature(
    plt.gca(),
    xs=X,
    f_xs=u_bc_pde.mean(X[:, None]),
    ddf_xs=fX,
    # Gradient of the conditioned mean at each measurement point, first component.
    df_xs=jnp.vectorize(jax.grad(u_bc_pde._meanfun), signature="(d)->(d)")(X[:, None])[:, 0],
    color="C3",
    # Raw f-string: `{N}` is still interpolated, "\dots" is kept literally.
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_bcfirst_03_00_cond_bc_pde.pdf", dpi=300)
plt.show()

### Posterior Predictive

In [None]:
# Laplacian of the fully conditioned process; the cross-covariance is discarded.
# NOTE: the variable name is spelled `lalace_...` (sic); the plotting cell
# that follows refers to it by this spelling.
lalace_u_bc_pde, _ = u_bc_pde.laplace()

In [None]:
# Laplacian of the fully conditioned process, with the PDE observations and f.
# Labels containing LaTeX backslashes are raw strings: "\D", "\m", "\p", "\O"
# and "\d" are not valid Python escape sequences and trigger a SyntaxWarning
# in a normal string literal.
lalace_u_bc_pde.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u \mid u(\partial \Omega) = g(\partial \Omega), \Delta u(x_i) = f(x_i)$"
)


plt.errorbar(
    X,
    fX.mean,
    yerr=fX.std,
    fmt="+",
    capsize=2,
    c="C3",
    # Raw f-string: `{N}` is still interpolated, "\dots" is kept literally.
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.plot(
    grid,
    f(grid),
    c="C3",
    label="f",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_bcfirst_03_01_cond_bc_pde_laplace.pdf", dpi=300)
plt.show()

## Posterior (PDE First)

### Predictive Induced by $\Delta$

In [None]:
# Laplacian of the (unconditioned) prior process, plus its cross-covariance function.
laplace_u, laplace_u_crosscov = prior_gp.laplace()

In [None]:
# Laplacian of the prior process, compared with the right-hand side f.
# `laplace_u` comes from the unconditioned prior, so the label must not claim
# conditioning on the boundary values. It is a raw string because "\D" is not
# a valid Python escape sequence.
laplace_u.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u$"
)

plt.plot(
    grid,
    f(grid),
    color="C3",
    label="f",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_pdefirst_01_prior_laplace.pdf", dpi=300)
plt.show()

### Conditioning on the PDE

In [None]:
# Condition the prior directly on the right-hand-side observations `fX` of its
# Laplacian at the measurement points X (no boundary values yet).
u_pde = prior_gp.condition_on_predictive_gp(laplace_u, laplace_u_crosscov, X[:, None], fX)

In [None]:
# Process conditioned on the PDE observations only.
# Labels containing LaTeX backslashes are raw strings: "\m", "\D" and "\d" are
# not valid Python escape sequences and trigger a SyntaxWarning in a normal
# string literal.
u_pde.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$u \mid \Delta u(x_i) = f(x_i)$"
)

plt.plot(
    grid,
    u(grid),
    color="C1",
    label="$u^*$",
)

pngal.plotting.plot_local_curvature(
    plt.gca(),
    xs=X,
    f_xs=u_pde.mean(X[:, None]),
    ddf_xs=fX,
    # Gradient of the conditioned mean at each measurement point, first component.
    df_xs=jnp.vectorize(jax.grad(u_pde._meanfun), signature="(d)->(d)")(X[:, None])[:, 0],
    color="C3",
    # Raw f-string: `{N}` is still interpolated, "\dots" is kept literally.
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_pdefirst_02_00_cond_pde.pdf", dpi=300)
plt.show()

### Posterior Predictive

In [None]:
# Laplacian of the PDE-conditioned process; the cross-covariance is discarded.
laplace_u_pde, _ = u_pde.laplace()

In [None]:
# Laplacian of the PDE-conditioned process, with the PDE observations and f.
# Labels containing LaTeX backslashes are raw strings: "\D", "\m" and "\d" are
# not valid Python escape sequences and trigger a SyntaxWarning in a normal
# string literal.
laplace_u_pde.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u \mid \Delta u(x_i) = f(x_i)$"
)


plt.errorbar(
    X,
    fX.mean,
    yerr=fX.std,
    fmt="+",
    capsize=2,
    c="C3",
    # Raw f-string: `{N}` is still interpolated, "\dots" is kept literally.
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.plot(
    grid,
    f(grid),
    c="C3",
    label="f",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_pdefirst_02_01_cond_pde_laplace.pdf", dpi=300)
plt.show()

### Conditioning on the Boundary Conditions

In [None]:
# Condition the PDE-conditioned process on the boundary values `g`.
u_pde_bc = u_pde.condition_on_observations(boundary[:, None], g)

In [None]:
# Process conditioned on the PDE observations first, then on the boundary values.
u_pde_bc.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$u \mid \Delta u(X) = f(X), u(\partial \Omega) = g(\partial \Omega)$"
)

plt.plot(
    grid,
    u(grid),
    color="C1",
    label="$u^*$",
)

plt.errorbar(
    boundary,
    g.mean,
    yerr=g.std,
    fmt="+",
    capsize=2,
    color="C2",
    label=r"$g(\partial \Omega)$"
)

pngal.plotting.plot_local_curvature(
    plt.gca(),
    xs=X,
    f_xs=u_pde_bc.mean(X[:, None]),
    ddf_xs=fX,
    # Gradient of the conditioned mean at each measurement point, first component.
    df_xs=jnp.vectorize(jax.grad(u_pde_bc._meanfun), signature="(d)->(d)")(X[:, None])[:, 0],
    color="C3",
    # Raw f-string, like the other labels in this cell: "\d" is not a valid
    # Python escape sequence. `{N}` is still interpolated.
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_pdefirst_03_00_cond_pde_bc.pdf", dpi=300)
plt.show()

### Posterior Predictive with PDE and Boundary Conditions

In [None]:
# Laplacian of the process conditioned on both the PDE and the boundary
# values; the cross-covariance is discarded.
laplace_u_pde_bc, _ = u_pde_bc.laplace()

In [None]:
# Laplacian of the process conditioned on both the PDE and the boundary values.
# Plot `laplace_u_pde_bc` (computed in the previous cell), which is what the
# label and the output file name describe; `laplace_u_pde` lacks the boundary
# conditioning. Labels containing LaTeX backslashes are raw strings because
# "\D", "\m", "\p", "\O" and "\d" are not valid Python escape sequences.
laplace_u_pde_bc.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u \mid \Delta u(x_i) = f(x_i), u(\partial \Omega) = g(\partial \Omega)$"
)

plt.errorbar(
    X,
    fX.mean,
    yerr=fX.std,
    fmt="+",
    capsize=2,
    c="C3",
    # Raw f-string: `{N}` is still interpolated, "\dots" is kept literally.
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.plot(
    grid,
    f(grid),
    c="C3",
    label="f",
)

plt.legend()
plt.savefig("../results/0006_poisson_rbf_pdefirst_03_01_cond_pde_bc_laplace.pdf", dpi=300)
plt.show()