In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import probnum_galerkin as pn_gal

In [None]:
%matplotlib inline

# Ask the inline backend to render figures as PDF and SVG.
# NOTE(review): newer IPython releases reportedly deprecate importing this helper from
# `IPython.display` (it moved to `matplotlib_inline.backend_inline`) — confirm against
# the installed version.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats("pdf", "svg")

In [None]:
# Switch on JAX's "jax_enable_x64" flag; presumably so that the Cholesky-based
# conditioning further down runs in double precision.
jax.config.update("jax_enable_x64", True)

In [None]:
class JaxMean:
    """Mean function backed by a JAX callable.

    Calling the object returns a NumPy array; :meth:`jax` returns whatever the
    wrapped JAX callable returns, so it can be used inside further JAX code.

    Parameters
    ----------
    m
        The JAX mean function.
    vectorize
        If true (the default), ``m`` is treated as a map from one input vector
        (signature ``"(d)->()"``) and is wrapped with ``jax.numpy.vectorize``.
        If false, ``m`` is stored and called unchanged.
    """

    def __init__(self, m, vectorize=True):
        if vectorize:
            m = jax.numpy.vectorize(m, signature="(d)->()")
        self._m = m

    def __call__(self, x):
        """Evaluate the mean at ``x`` and convert the result with ``np.asarray``."""
        values = self._m(x)
        return np.asarray(values)

    def jax(self, x):
        """Evaluate the mean at ``x`` without converting the result."""
        return self._m(x)


class JaxKernel(pn.kernels.Kernel):
    """ProbNum kernel backed by a JAX callable.

    Parameters
    ----------
    k
        Covariance function of two input vectors (signature ``"(d),(d)->()"``);
        it is wrapped with ``jax.numpy.vectorize``.
    input_dim
        Input dimension forwarded to ``pn.kernels.Kernel``.
    """

    def __init__(self, k, input_dim):
        # The vectorized callable is stored before the base class is initialized,
        # matching the original order of operations.
        self._k = jax.numpy.vectorize(k, signature="(d),(d)->()")
        super().__init__(input_dim)

    def _evaluate(self, x0, x1) -> np.ndarray:
        """Evaluate the kernel; a missing ``x1`` means ``x1 = x0``."""
        second = x0 if x1 is None else x1
        return np.array(self._k(x0, second))

    def jax(self, x0, x1):
        """Evaluate the vectorized JAX kernel without converting the result."""
        return self._k(x0, x1)
    
def laplace(f, argnum=0):
    """Return a jitted function computing the trace of the Hessian of ``f``.

    Parameters
    ----------
    f
        Function to differentiate.
    argnum
        Which positional argument of ``f`` the Hessian is taken with respect
        to; passed to ``jax.hessian``.

    Returns
    -------
    callable
        Takes the same positional arguments as ``f`` and returns the trace of
        the Hessian at that point, i.e. the Laplacian in argument ``argnum``.
    """
    hessian_of_f = jax.hessian(f, argnum)

    def _hessian_trace(*args):
        # atleast_2d turns the scalar Hessian of a scalar argument into a
        # 1x1 matrix so that the trace is defined.
        hess = jnp.atleast_2d(hessian_of_f(*args))
        return jnp.trace(hess)

    return jax.jit(_hessian_trace)

In [None]:
# Seeded NumPy generator; it is handed to the `.plot(..., rng=rng)` calls below.
rng = np.random.default_rng(24)

In [None]:
# 100 equally spaced points on [-1, 1]; used as the x-axis of every plot below.
grid = np.linspace(-1, 1, 100)

## Problem Definition

In [None]:
# Poisson problem with constant right-hand side: u'' = f = -2 on (-1, 1).
def f(x):
    """Right-hand side of the PDE: -2 everywhere, shaped like ``x``."""
    return np.full_like(x, -2.0)


boundary_points = np.array([-1.0, 1.0])
g = np.array([0.0, 1.0])  # values prescribed at `boundary_points`
# g_std = np.array([0.2, 0.2])
g_std = np.zeros(2)  # standard deviation of the boundary values (noise-free here)


def u(x):
    """Closed-form solution of u'' = -2 with u(boundary_points) = g.

    The solution is ``1 - x**2`` (which vanishes at -1 and 1) plus the straight
    line through the boundary values. The ``g[0]`` offset was missing before,
    which made ``u(-1)`` equal to 0 regardless of the left boundary value.
    The formula still assumes the boundary points are -1 and 1.
    """
    slope = (g[1] - g[0]) / (boundary_points[1] - boundary_points[0])
    return -x ** 2 + g[0] + slope * (x - boundary_points[0]) + 1.0

In [None]:
# Poisson problem u'' = f with a sinusoidal right-hand side.
# This cell rebinds f, boundary_points, g, g_std and u, so it replaces the
# constant right-hand-side problem defined in the previous cell.
def f(x):
    """Right-hand side of the PDE: pi^2 * sin(pi * x), evaluated with NumPy."""
    return np.pi ** 2 * np.sin(np.pi * x)


boundary_points = np.array([-1.0, 1.0])
g = np.zeros(2)  # values prescribed at `boundary_points`
g_std = np.full(2, 0.1)  # standard deviation of the boundary values


def u(x):
    """Reference solution -sin(pi * x), evaluated with JAX."""
    return -jnp.sin(jnp.pi * x)

## Prior

In [None]:
# Hyperparameters of the squared-exponential prior covariance.
lengthscale = 1.0
output_scale = 3.0


@jax.jit
def prior_mean(x):
    """Zero prior mean: one zero per input point (the last axis of ``x`` is dropped)."""
    # Alternative means kept for reference (disabled):
    #     return -0.5 * x[..., 0] ** 2 + 0.5
    #     return jnp.sin(jnp.pi * x)
    return jnp.zeros_like(x[..., 0])


@jax.jit
def prior_cov(x0, x1):
    """Squared-exponential covariance between ``x0`` and ``x1``.

    The squared distance is summed over the last axis and scaled by
    ``1 / (2 * lengthscale**2)``; the result is multiplied by ``output_scale**2``.
    """
    inv_two_ell_sq = 1.0 / (2.0 * lengthscale ** 2)
    sq_dist = jnp.sum(jnp.square(x0 - x1), axis=-1)

    return output_scale ** 2 * jnp.exp(-inv_two_ell_sq * sq_dist)

# Prior over the solution u. `prior_mean` already handles batched inputs through
# its `x[..., 0]` indexing, so it is passed with vectorize=False.
prior_gp = pn.randprocs.GaussianProcess(
    mean=JaxMean(prior_mean, vectorize=False),
    cov=JaxKernel(prior_cov, input_dim=1),
)

In [None]:
# Plot the prior process with 10 samples on the evaluation grid.
prior_gp.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    # Raw string: the LaTeX backslashes (\s, \m) are not valid Python escape
    # sequences and trigger a warning in a normal string literal.
    label=r"$u \sim \mathcal{GP}(m, k)$",
)

# Reference solution for comparison.
plt.plot(
    grid,
    u(grid),
    label="$u^*$",
)

plt.legend()
plt.title("Prior")
plt.savefig("../results/0006_01_poisson_rbf_prior.png", dpi=300)
plt.show()

## Posterior (Boundary Values First)

### Conditioning on Boundary Conditions

In [None]:
def condition_gp_on_observations(gp, X, Y, Lambda):
    """Condition a Gaussian process on (noisy) point evaluations.

    Parameters
    ----------
    gp
        Prior process. Its ``_meanfun`` and ``_covfun`` must provide a ``.jax``
        method, as ``JaxMean`` and ``JaxKernel`` do.
    X
        Observation locations. Indexed as ``X[:, None, :]``, so it must be
        two-dimensional (one row per location).
    Y
        Observed values at ``X``.
    Lambda
        Matrix added to the prior Gram matrix ``k(X, X)``, i.e. the covariance
        of the observation noise (variances, not standard deviations).

    Returns
    -------
    pn.randprocs.GaussianProcess
        Process whose mean and covariance are the conditional mean and
        covariance given the observations.
    """
    mX = gp._meanfun.jax(X)
    # Pairwise Gram matrix through broadcasting, plus the noise covariance.
    kXX = gp._covfun.jax(X[:, None, :], X[None, :, :]) + Lambda
    # Factorize once; both closures below reuse the Cholesky factor.
    L_kXX = jax.scipy.linalg.cho_factor(kXX)

    @jax.jit
    def cond_mean(x):
        mx = gp._meanfun.jax(x)
        kxX = gp._covfun.jax(x, X)
        return mx + kxX @ jax.scipy.linalg.cho_solve(L_kXX, (Y - mX))

    @jax.jit
    def cond_cov(x0, x1):
        kxx = gp._covfun.jax(x0, x1)
        kxX = gp._covfun.jax(x0, X)
        kXx = gp._covfun.jax(X, x1)
        return kxx - kxX @ jax.scipy.linalg.cho_solve(L_kXX, kXx)

    cond_gp = pn.randprocs.GaussianProcess(
        mean=JaxMean(cond_mean),
        cov=JaxKernel(
            cond_cov,
            # Inherit the input dimension from the prior instead of hard-coding 1,
            # as condition_gp_on_predictive_gp does.
            input_dim=gp.input_dim,
        ),
    )

    return cond_gp

# Attach the function as a method, so it can be called as `gp.condition_on_observations(X, Y, Lambda)`.
pn.randprocs.GaussianProcess.condition_on_observations = condition_gp_on_observations

In [None]:
# Condition the prior on the boundary values g observed at the boundary points.
prior_bc = prior_gp.condition_on_observations(
    boundary_points[:, None],  # column of observation locations
    g,
    jnp.diag(g_std) ** 2,  # noise covariance: squared standard deviations on the diagonal
)

In [None]:
# Plot the boundary-conditioned prior with 10 samples.
prior_bc.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    # Raw string: \m, \p and \O are LaTeX commands, not Python escape sequences.
    label=r"$u \mid u(\partial \Omega) = g(\partial \Omega)$"
)

# Reference solution for comparison.
plt.plot(
    grid,
    u(grid),
    label="$u^*$",
)

# Boundary values with their standard deviations as error bars.
plt.errorbar(
    boundary_points,
    g,
    yerr=g_std,
    fmt="+",
    capsize=2,
)

plt.title("Prior Conditioned on Boundary Values")
plt.legend()
plt.savefig("../results/0006_02_poisson_rbf_prior_bc.png", dpi=300)
plt.show()

### Predictive Induced by $\Delta$

In [None]:
def laplace_gp(gp):
    """Apply the Laplacian to a Gaussian process.

    Parameters
    ----------
    gp
        Process whose ``_meanfun`` and ``_covfun`` provide a ``.jax`` method,
        as ``JaxMean`` and ``JaxKernel`` do.

    Returns
    -------
    lapf : pn.randprocs.GaussianProcess
        Process whose mean is the Laplacian of the mean of ``gp`` and whose
        covariance is the Laplacian of the covariance of ``gp`` taken in both
        arguments.
    crosscov : callable
        Vectorized function ``(x0, x1)`` returning the Laplacian of the
        covariance of ``gp`` with respect to its second argument.
    """
    mean = laplace(gp._meanfun.jax)
    # Laplacian in the second kernel argument; built once and reused for both
    # the covariance and the cross-covariance.
    lap_second_arg = laplace(gp._covfun.jax, 1)
    cov = laplace(lap_second_arg, 0)
    crosscov = jnp.vectorize(lap_second_arg, signature="(d),(d)->()")

    lapf = pn.randprocs.GaussianProcess(
        mean=JaxMean(mean),
        # Inherit the input dimension from `gp` instead of hard-coding 1,
        # as condition_gp_on_predictive_gp does.
        cov=JaxKernel(cov, input_dim=gp.input_dim),
    )

    return lapf, crosscov

# Attach the function as a method, so it can be called as `gp.laplace()`.
pn.randprocs.GaussianProcess.laplace = laplace_gp

In [None]:
# Laplacian of the boundary-conditioned prior, plus the cross-covariance function
# needed later to condition on observations of the Laplacian.
pred_bc, pred_bc_crosscov = prior_bc.laplace()

In [None]:
# Plot the Laplacian of the boundary-conditioned prior with 10 samples.
pred_bc.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    # Raw string: \D, \m, \p and \O are LaTeX commands, not Python escape sequences.
    label=r"$\Delta u \mid u(\partial \Omega) = g(\partial \Omega)$"
)

# Right-hand side of the PDE for comparison.
plt.plot(
    grid,
    f(grid),
    label="f",
)

plt.legend()
plt.title("Prior Mapped Through Differential Operator")
plt.savefig("../results/0006_03_poisson_rbf_prior_predictive.png", dpi=300)
plt.show()

### Conditioning on the PDE

In [None]:
# Number of points at which the PDE right-hand side is observed.
N = 3

X = np.linspace(-0.8, 0.8, num=N)
# Observations of f at X as a Gaussian; the per-point standard deviation is zero,
# so the diagonal covariance is all zeros.
fX = pn.randvars.Normal(
    mean=f(X),
    cov=np.diag(np.zeros_like(X) ** 2),
)

In [None]:
def condition_gp_on_predictive_gp(f, Lf, jax_crosscov, X, LfX):
    """Condition the process ``f`` on Gaussian observations of ``Lf`` at ``X``.

    Parameters
    ----------
    f
        Process to be conditioned; its ``_meanfun`` and ``_covfun`` must
        provide a ``.jax`` method.
    Lf
        Process that is observed; its ``_meanfun`` and ``_covfun`` must
        provide a ``.jax`` method.
    jax_crosscov
        Callable ``(x, X)`` giving the cross-covariance terms used in the update.
    X
        Observation locations. Indexed as ``X[:, None, :]``, so it must be
        two-dimensional (one row per location).
    LfX
        Observations; ``LfX.mean`` holds the observed values and ``LfX.cov``
        is added to the Gram matrix of ``Lf`` at ``X``.

    Returns
    -------
    pn.randprocs.GaussianProcess
        Process with the conditional mean and covariance; its input dimension
        is ``f.input_dim``.
    """
    base_mean = f._meanfun.jax
    base_cov = f._covfun.jax

    # Difference between the observed values and the mean of Lf at X.
    residual = LfX.mean - Lf._meanfun.jax(X)
    # Gram matrix of Lf at X plus the observation covariance, factorized once.
    gram = Lf._covfun.jax(X[:, None, :], X[None, :, :]) + LfX.cov
    gram_cho = jax.scipy.linalg.cho_factor(gram)

    def pred_cond_mean(x):
        cross = jax_crosscov(x[None], X)
        weights = jax.scipy.linalg.cho_solve(gram_cho, residual)
        return base_mean(x) + cross @ weights

    def pred_cond_cov(x0, x1):
        cross_left = jax_crosscov(x0, X)
        cross_right = jax_crosscov(x1, X).T
        correction = cross_left @ jax.scipy.linalg.cho_solve(gram_cho, cross_right)
        return base_cov(x0, x1) - correction

    return pn.randprocs.GaussianProcess(
        mean=JaxMean(jax.jit(pred_cond_mean)),
        cov=JaxKernel(jax.jit(pred_cond_cov), input_dim=f.input_dim),
    )

# Attach the function as a method, so it can be called as
# `gp.condition_on_predictive_gp(Lf, jax_crosscov, X, LfX)`.
pn.randprocs.GaussianProcess.condition_on_predictive_gp = condition_gp_on_predictive_gp

In [None]:
# Condition the boundary-conditioned prior on the observations fX of its Laplacian at X.
post_bcfirst = prior_bc.condition_on_predictive_gp(pred_bc, pred_bc_crosscov, X[:, None], fX)

In [None]:
# Plot the posterior (boundary values first, then the PDE) with 10 samples.
post_bcfirst.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    # Raw string: \m and \D are LaTeX commands, not Python escape sequences.
    label=r"$u \mid \Delta u(X) = f(X)$"
)

# Mark the PDE observation locations on the posterior mean.
plt.scatter(
    X,
    post_bcfirst.mean(X[:, None]),
    marker="|",
    s=80,
    c="C1",
    # Raw f-string; the extra braces make LaTeX treat a multi-digit N as one subscript.
    label=rf"$X = (x_1, \dots, x_{{{N}}})$",
)

# Reference solution for comparison.
plt.plot(
    grid,
    u(grid),
    label="$u^*$",
)

plt.legend()
plt.title("Posterior")
plt.savefig("../results/0006_04_poisson_rbf_posterior.png", dpi=300)
plt.show()

### Posterior Predictive

In [None]:
# Laplacian of the posterior; the cross-covariance function is not needed here.
post_pred_bcfirst, _ = post_bcfirst.laplace()

In [None]:
# Plot the Laplacian of the posterior with 10 samples.
post_pred_bcfirst.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    # Raw string: \D and \m are LaTeX commands, not Python escape sequences.
    label=r"$\Delta u \mid \Delta u(X) = f(X)$"
)


# Observed right-hand-side values with their standard deviations as error bars.
plt.errorbar(
    X,
    fX.mean,
    yerr=fX.std,
    fmt="+",
    capsize=2,
    c="C1",
    # Raw f-string; the extra braces make LaTeX treat a multi-digit N as one subscript.
    label=rf"$f(X) = (f(x_1), \dots, f(x_{{{N}}}))$",
)

# Right-hand side of the PDE for comparison.
plt.plot(
    grid,
    f(grid),
    c="C2",
    label="f",
)

plt.legend()
plt.title("Posterior Mapped Through Differential Operator")
plt.savefig("../results/0006_05_poisson_rbf_posterior_predictive.png", dpi=300)
plt.show()

## Posterior (PDE First)

### Predictive Induced by $\Delta$

In [None]:
# Laplacian of the prior that has NOT been conditioned on the boundary values,
# plus the matching cross-covariance function.
pred_nobc, pred_crosscov_nobc = prior_gp.laplace()

In [None]:
# Plot the Laplacian of the unconditioned prior with 10 samples.
pred_nobc.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    # `pred_nobc` comes from `prior_gp`, which has not seen the boundary values,
    # so the label must not claim conditioning on them. Raw string for LaTeX.
    label=r"$\Delta u$"
)

# Right-hand side of the PDE for comparison.
plt.plot(
    grid,
    f(grid),
    label="f",
)

plt.legend()
plt.title("Prior Mapped Through Differential Operator")
# Distinct file name (same "1" suffix as the PDE-first posterior figure) so the
# boundary-first figure 0006_03_poisson_rbf_prior_predictive.png is not overwritten.
plt.savefig("../results/0006_03_poisson_rbf_prior_predictive1.png", dpi=300)
plt.show()

### Conditioning on the PDE

In [None]:
# PDE observations for the PDE-first ordering: N points at which the right-hand
# side is observed.
N = 3

X = np.linspace(-0.8, 0.8, num=N)
# Observations of f at X as a Gaussian; the per-point standard deviation is zero,
# so the diagonal covariance is all zeros.
fX = pn.randvars.Normal(
    mean=f(X),
    cov=np.diag(np.zeros_like(X) ** 2),
)

In [None]:
# Condition the unconditioned prior on the observations fX of its Laplacian at X.
post_nobc = prior_gp.condition_on_predictive_gp(pred_nobc, pred_crosscov_nobc, X[:, None], fX)

In [None]:
# Plot the process conditioned on the PDE only (no boundary values) with 10 samples.
post_nobc.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    # Raw string: \m and \D are LaTeX commands, not Python escape sequences.
    label=r"$u \mid \Delta u(X) = f(X)$"
)

# Mark the PDE observation locations on the conditional mean.
plt.scatter(
    X,
    post_nobc.mean(X[:, None]),
    marker="|",
    s=80,
    c="C1",
    # Raw f-string; the extra braces make LaTeX treat a multi-digit N as one subscript.
    label=rf"$X = (x_1, \dots, x_{{{N}}})$",
)

# Reference solution for comparison.
plt.plot(
    grid,
    u(grid),
    label="$u^*$",
)

plt.legend()
plt.title("Posterior")
# Saving is disabled here; the path below is the one used by the boundary-first posterior figure.
#plt.savefig("../results/0006_04_poisson_rbf_posterior.png", dpi=300)
plt.show()

### Conditioning on the Boundary Conditions

In [None]:
# Condition the PDE-conditioned process on the boundary values.
post_pdefirst = post_nobc.condition_on_observations(
    boundary_points[:, None],
    g,
    # `g_std` holds standard deviations, but this argument is added to the Gram
    # matrix as a covariance, so it must be squared (as in the boundary-first path).
    jnp.diag(g_std) ** 2,
)

In [None]:
# Plot the posterior (PDE first, then the boundary values) with 10 samples.
post_pdefirst.plot(
    plt.gca(),
    grid,
    num_samples=10,
    rng=rng,
    # Raw string: \m and \D are LaTeX commands, not Python escape sequences.
    label=r"$u \mid \Delta u(X) = f(X)$"
)

# Mark the PDE observation locations on the posterior mean.
plt.scatter(
    X,
    post_pdefirst.mean(X[:, None]),
    marker="|",
    s=80,
    c="C1",
    # Raw f-string; the extra braces make LaTeX treat a multi-digit N as one subscript.
    label=rf"$X = (x_1, \dots, x_{{{N}}})$",
)

# Reference solution for comparison.
plt.plot(
    grid,
    u(grid),
    label="$u^*$",
)

plt.legend()
plt.title("Posterior")
plt.savefig("../results/0006_04_poisson_rbf_posterior1.png", dpi=300)
plt.show()