In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import probnum_galerkin as pngal

In [None]:
# Render figures inline, using vector output formats (PDF and SVG).
%matplotlib inline

from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats("pdf", "svg")

In [None]:
# Enable 64-bit (double precision) arrays in JAX.
# NOTE(review): this flag only affects arrays created after it is set; confirm that
# none of the modules imported above create JAX arrays at import time.
jax.config.update("jax_enable_x64", True)

In [None]:
# Seeded generator shared by all sampling/plotting cells below, for reproducible draws.
rng = np.random.default_rng(seed=24)

## Problem Definition

In [None]:
# Domain: endpoints of the 1-D interval [-1, 1]
boundary = np.asarray([-1.0, 1.0])

# Evaluation grid (100 points spanning the domain) used for all plots
plt_grid = np.linspace(boundary[0], boundary[-1], num=100)

# Differential Operator
def diffop(f, a, argnum: int = 0):
    """Build the jitted operator ``x -> a * laplace(f)(x)``.

    Args:
        f: Function whose Laplacian is taken via ``pngal.diffops.laplace``.
        a: Coefficient multiplying the Laplacian. It is captured by closure, so
            it is baked into the jitted function (or, when ``diffop`` is called
            under a JAX transformation such as ``jax.grad``, captured as a tracer).
        argnum: Forwarded to ``pngal.diffops.laplace`` (presumably the index of
            the argument of ``f`` to differentiate; TODO confirm).

    Returns:
        A ``jax.jit``-compiled callable that forwards all its arguments to the
        Laplacian of ``f`` and scales the result by ``a``.
    """
    laplacian = pngal.diffops.laplace(f, argnum=argnum)

    def _scaled_laplace(*args, **kwargs):
        return a * laplacian(*args, **kwargs)

    return jax.jit(_scaled_laplace)

In [None]:
# PDE Diffop Parameter
a = 2.0

# PDE RHS
_f_const = -2.0
f = lambda x: np.full_like(x, _f_const)

# Boundary Conditions: Dirichlet values u(-1) = -0.1 and u(1) = 0.9, modeled as
# a Gaussian with zero covariance, i.e. exact observations.
g = pn.randvars.Normal(
    mean=np.asarray([-0.1, 0.9]),
    cov=np.diag(np.full_like(boundary, 0.0 ** 2)),
)

# True Solution of a * u'' = f with u(boundary) = g.mean: since f is constant,
# u is the quadratic passing through both boundary values with u'' = f / a.
_aff_slope = (g.mean[1] - g.mean[0]) / (boundary[1] - boundary[0])


def u(x):
    """Evaluate the reference (true) solution at ``x``."""
    left, right = boundary
    curvature = 0.5 * (_f_const / a)
    return g.mean[0] + (_aff_slope + curvature * (x - right)) * (x - left)

# Priors
# RBF kernel hyperparameters. NOTE: prior_u_cov reads these globals when jax.jit
# traces it, so changing them later does not affect already-compiled calls.
lengthscale = 1.0
output_scale = 1.0


@jax.jit
def prior_u_mean(x):
    """Zero prior mean for u.

    Drops the last (input-dimension) axis of ``x`` and returns zeros of the
    remaining shape, with the dtype of ``x``.
    """
    # An alternative (oracle) prior mean would be the true solution: u(x[..., 0]).
    return jnp.zeros_like(x[..., 0])


@jax.jit
def prior_u_cov(x0, x1):
    """Squared-exponential (RBF) covariance between ``x0`` and ``x1``.

    k(x0, x1) = output_scale**2 * exp(-||x0 - x1||^2 / (2 * lengthscale**2)),
    where the squared distance is summed over the last axis and the leading
    axes broadcast.
    """
    diff = x0 - x1
    sq_dist = jnp.sum(diff ** 2, axis=-1)
    inv_two_ell_sq = 1.0 / (2.0 * lengthscale ** 2)
    return output_scale ** 2 * jnp.exp(-inv_two_ell_sq * sq_dist)

# Gaussian prior over the coefficient a multiplying the Laplacian in diffop
# (mean 1.8, standard deviation 1.0).
prior_a = pn.randvars.Normal(mean=1.8, cov=1.0 ** 2)

# PDE Measurements: exact (zero-variance) observations of the right-hand side f
# at 11 equispaced points strictly inside the domain.
X = np.linspace(-0.8, 0.8, num=11)
fX = pn.randvars.Normal(
    mean=f(X),
    cov=np.diag(np.full_like(X, 0.0 ** 2)),
)

In [None]:
# Number of PDE measurement points.
N = np.size(X)

## Prior

In [None]:
# GP prior over the PDE solution u, wrapping the JAX mean and RBF kernel.
prior_u = pn.randprocs.GaussianProcess(
    mean=pngal.randprocs.mean_fns.JaxMean(
        prior_u_mean,
        vectorize=False,
    ),
    cov=pngal.randprocs.kernels.JaxKernel(
        prior_u_cov,
        input_dim=1,
        vectorize=False,
    ),
)

In [None]:
# Samples from the GP prior over u, compared with the true solution u*.
prior_u.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    # Raw string so the LaTeX backslashes are kept verbatim instead of being
    # parsed as (invalid) escape sequences, which warn on Python >= 3.12.
    label=r"$u \sim \mathcal{GP}(m, k)$",
)

plt.plot(
    plt_grid,
    u(plt_grid),
    label="$u^*$",
)

plt.legend()
plt.savefig("../results/0007_poisson_rbf_00_prior.pdf", dpi=300)
plt.show()

In [None]:
# Prior density of the coefficient a, with the true value a* marked.
prior_a.plot(
    plt.gca()
)

plt.axvline(a, c="C1", label="$a^*$")

# Legend + explicit show (as in the other plotting cells) also avoids printing
# the Line2D repr returned by axvline as the cell output.
plt.legend()
plt.show()

## Conditioning on Boundary Values

In [None]:
# Condition the GP prior on the boundary values g. Inputs are passed as a
# column: one row per boundary point.
u_bc = prior_u.condition_on_observations_jax(np.expand_dims(boundary, axis=1), g)

In [None]:
# Boundary-conditioned GP vs. the true solution and the boundary data.
u_bc.plot(
    plt.gca(),
    plt_grid,
    rng=rng,
    num_samples=10,
    label=r"$u \mid u(\partial \Omega) = g(\partial \Omega)$",
)

plt.plot(plt_grid, u(plt_grid), label=r"$u^*$")

# Boundary data with error bars of one standard deviation of g.
plt.errorbar(
    boundary,
    g.mean,
    yerr=g.std,
    fmt="+",
    color="C2",
    capsize=2,
    label=r"$g(\partial \Omega)$",
)

plt.legend()
plt.savefig("../results/0007_poisson_rbf_bcfirst_01_cond_bc.pdf", dpi=300)
plt.show()

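## Linearized Predictive

The coefficient $a$ is uncertain, so $a \Delta u$ is not a Gaussian process. We linearize it around $\bar{a} = \mathbb{E}[a]$. The mean is that of $\bar{a} \Delta u$. The covariance is that of $\bar{a} \Delta u$ plus $\mathrm{Var}[a] \, \partial_a (a \Delta m)(x_0) \, \partial_a (a \Delta m)(x_1)$, where $m$ is the mean of $u$ after conditioning on the boundary values.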

In [None]:
# Apply diffop, with a fixed at its prior mean, to the boundary-conditioned GP.
# The second output is presumably the cross-covariance between u_bc and the
# transformed process (inferred from its name and its later use in
# conditioning). TODO confirm against apply_jax_linop.
(
    laplace_u_bc,
    laplace_u_bc_crosscov,
) = u_bc.apply_jax_linop(diffop, a=prior_a.mean)

In [None]:
# diffop applied to the boundary-conditioned GP (a fixed at its prior mean),
# compared with the PDE right-hand side f.
laplace_u_bc.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    # Raw string: in a plain literal the LaTeX backslashes form invalid escape
    # sequences, which warn on Python >= 3.12.
    label=r"$\Delta u \mid u(\partial \Omega) = g(\partial \Omega)$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    color="C3",
    label="f",
)

plt.legend()
plt.savefig("../results/0007_poisson_rbf_bcfirst_02_cond_bc_laplace.pdf", dpi=300)
plt.show()

In [None]:
def Lu_grad_a(x):
    """Derivative of ``diffop(m, a)(x)`` w.r.t. ``a``, evaluated at ``a = E[a]``.

    ``m`` is the mean function of the boundary-conditioned process ``u_bc``.
    """
    def _operator_on_mean(x_eval, a_eval):
        return diffop(u_bc._meanfun, a_eval)(x_eval)

    return jax.grad(_operator_on_mean, argnums=1)(x, prior_a.mean)


@jax.jit
def lin_pred_cov(x0, x1):
    """Covariance of the PDE operator output, linearized in ``a``.

    Covariance of ``laplace_u_bc`` (``a`` fixed at its prior mean) plus
    ``Var[a] * Lu_grad_a(x0) * Lu_grad_a(x1)``.
    """
    cov_fixed_a = laplace_u_bc._covfun.jax(x0, x1)
    var_a = prior_a.cov.item()
    return cov_fixed_a + var_a * Lu_grad_a(x0) * Lu_grad_a(x1)


# Linearized predictive process: same mean as laplace_u_bc, inflated covariance.
u_pred = pn.randprocs.GaussianProcess(
    mean=laplace_u_bc._meanfun,
    cov=pngal.randprocs.kernels.JaxKernel(lin_pred_cov, input_dim=1, vectorize=True),
)

In [None]:
# Linearized predictive of the PDE operator (uncertainty in a included),
# compared with the PDE right-hand side f.
u_pred.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    # Raw string: in a plain literal the LaTeX backslashes form invalid escape
    # sequences, which warn on Python >= 3.12.
    # NOTE(review): same label as the fixed-a figure; this process additionally
    # accounts for the uncertainty in a. Consider a distinguishing label.
    label=r"$\Delta u \mid u(\partial \Omega) = g(\partial \Omega)$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    color="C3",
    label="f",
)

plt.legend()
# Must differ from the fixed-a figure's filename
# ("..._02_cond_bc_laplace.pdf"); otherwise that figure is overwritten.
plt.savefig("../results/0007_poisson_rbf_bcfirst_03_lin_pred.pdf", dpi=300)
plt.show()

## Posterior Process

In [None]:
# Condition u_bc on the PDE observations fX at the points X, using the
# linearized predictive u_pred together with the cross-covariance computed
# with a fixed at its prior mean.
u_bc_pde = u_bc.condition_on_predictive_gp_observations_jax(
    u_pred,
    laplace_u_bc_crosscov,
    np.expand_dims(X, axis=1),
    fX,
)

In [None]:
# Posterior over u after conditioning on boundary values and PDE observations,
# compared with the true solution and the boundary data.
u_bc_pde.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    # Raw string: in a plain literal the LaTeX backslashes form invalid escape
    # sequences, which warn on Python >= 3.12.
    label=r"$u \mid u(\partial \Omega) = g(\partial \Omega), \Delta u(x_i) = f(x_i)$",
)

plt.plot(
    plt_grid,
    u(plt_grid),
    color="C1",
    label="$u^*$",
)

plt.errorbar(
    boundary,
    g.mean,
    yerr=g.std,
    fmt="+",
    capsize=2,
    color="C2",
    label=r"$g(\partial \Omega)$"
)

plt.legend()
plt.show()