In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import linpde_gp

In [None]:
import notebook_utils

# Notebook-wide settings for the project helper module. `notebook_name` is presumably
# used by `notebook_utils.savefig` to name/locate saved figures (TODO confirm).
_nb_cfg = notebook_utils.config
_nb_cfg.notebook_name = "0006_poisson_nonparametric"
_nb_cfg.debug_mode = True
del _nb_cfg

In [None]:
# Render matplotlib figures inline below each cell (default in recent Jupyter).
%matplotlib inline

In [None]:
# Seeded generator so the GP sample paths drawn in the plots below are reproducible.
RNG_SEED = 24
rng = np.random.default_rng(seed=RNG_SEED)

## Problem Definition

In [None]:
from linpde_gp.problems.pde import BoundaryValueProblem, DirichletBoundaryCondition, diffops, domains

In [None]:
# Spatial domain Ω = [-1, 1].
domain = domains.Interval(-1.0, 1.0)

# Dirichlet data g on ∂Ω: one value per boundary point of `domain`
# (presumably ordered left -> right, i.e. g(-1) = 0 and g(1) = 1 — TODO confirm).
bc_values_mean = np.array([0.0, 1.0])
bc_values_std = 0.0  # zero variance: boundary values are observed exactly
dirichlet_bc = DirichletBoundaryCondition(
    domain.boundary,
    values=pn.randvars.Normal(
        mean=bc_values_mean,
        cov=np.diag(np.full(2, bc_values_std ** 2)),
    ),
)

In [None]:
# PDE RHS
f = lambda x: np.full_like(x, -2.0)

# True Solution
u = lambda x: -x ** 2 + (dirichlet_bc.values.mean[1] - dirichlet_bc.values.mean[0]) / (domain[1] - domain[0]) * (x - domain[0]) + 1.0

# PDE Measurements
X = np.linspace(-0.8, 0.8, 3)
fX_std = np.full_like(X, 0.0 ** 2)

In [None]:
# Active test problem: Poisson equation Δu = f on `domain` with Dirichlet data `dirichlet_bc`.


def f(x):
    """PDE right-hand side f(x) = π² sin(πx)."""
    return np.pi ** 2 * np.sin(np.pi * x)


def u(x):
    """True solution of Δu = f that also satisfies the Dirichlet data in `dirichlet_bc`.

    -sin(πx) alone solves Δu = f but vanishes at x = ±1, which contradicts the
    boundary data. Adding the linear (hence harmonic) interpolant of the boundary
    values fixes the boundary without changing Δu.
    """
    a, b = domain[0], domain[1]
    g_a, g_b = dirichlet_bc.values.mean
    return -jnp.sin(jnp.pi * x) + g_a + (g_b - g_a) / (b - a) * (x - a)


# PDE measurement locations
X = np.linspace(-0.8, 0.8, 3)
# NOTE: despite its name, `fX_std` holds *variances* (std ** 2). It is passed directly
# as the diagonal of the covariance of the PDE observations `fX`.
fX_std = np.full_like(X, 0.0) ** 2

In [None]:
# Boundary value problem Δu = f with the Dirichlet data defined above.
# Reuse `dirichlet_bc` (the same object used for conditioning later) instead of
# re-declaring a hard-coded copy that could silently drift out of sync.
bvp = BoundaryValueProblem(
    diffop=diffops.laplace,
    rhs=f,
    boundary_conditions=[dirichlet_bc],
)

In [None]:
# Number of PDE collocation points (used in plot labels).
N = X.size

# Observations of the RHS at X. `fX_std` is used directly as the covariance
# diagonal, i.e. its entries are variances.
fX = pn.randvars.Normal(
    mean=f(X),
    cov=np.diag(fX_std),
)

In [None]:
# 100 evenly spaced evaluation points spanning the domain (endpoints included).
plt_grid = np.linspace(*domain, num=100)

## Prior

In [None]:
# Hyperparameters of the squared-exponential prior covariance.
# NOTE: the jitted functions below read these globals when they are traced, so after
# changing them re-run this cell (and everything downstream) to rebuild the functions.
lengthscale = 1.0
output_scale = 2.0


@jax.jit
def prior_mean(x):
    """Zero prior mean, evaluated over the leading axes of `x` (last axis = input dim)."""
    # Other mean functions that were tried:
    #   -0.5 * x[..., 0] ** 2 + 0.5
    #   jnp.sin(jnp.pi * x[..., 0])
    return jnp.zeros_like(x[..., 0])


@jax.jit
def prior_cov(x0, x1):
    """Squared-exponential kernel output_scale² · exp(-‖x0 - x1‖² / (2 · lengthscale²))."""
    diff = x0 - x1
    sq_dist = jnp.sum(diff * diff, axis=-1)
    inv_two_l_sq = 1.0 / (2.0 * lengthscale ** 2)
    return output_scale ** 2 * jnp.exp(-inv_two_l_sq * sq_dist)

# GP prior u ~ GP(m, k): wrap the JAX mean/kernel functions for linpde_gp.
# Both use vectorize=False because the JAX functions above already broadcast over
# leading axes.
prior_gp = pn.randprocs.GaussianProcess(
    mean=linpde_gp.randprocs.mean_fns.JaxMean(
        prior_mean,
        vectorize=False,
    ),
    cov=linpde_gp.randprocs.kernels.JaxKernel(
        prior_cov,
        input_dim=1,
        vectorize=False,
    ),
)

In [None]:
# Prior GP over u (with 10 sample paths) against the true solution u*.
# LaTeX labels are raw strings: "\s", "\m", ... are invalid escape sequences in a
# normal literal (SyntaxWarning since Python 3.12).
prior_gp.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$u \sim \mathcal{GP}(m, k)$",
)

plt.plot(
    plt_grid,
    u(plt_grid),
    label="$u^*$",
)

plt.legend()

notebook_utils.savefig("00_prior")
plt.show()

## Posterior (Boundary Values First)

### Conditioning on Boundary Conditions

In [None]:
# Boundary points of Ω as a column of inputs (one row per point).
bc_points = np.array(list(dirichlet_bc.boundary))[:, None]

# Condition the prior on the Dirichlet data observed at those points.
u_bc = prior_gp.condition_on_observations_jax(bc_points, dirichlet_bc.values)

In [None]:
# GP conditioned on the boundary data, with the true solution and the boundary data.
ax = plt.gca()

u_bc.plot(
    ax,
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$u \mid u(\partial \Omega) = g(\partial \Omega)$"
)
ax.plot(plt_grid, u(plt_grid), label=r"$u^*$")
ax.errorbar(
    dirichlet_bc.boundary,
    dirichlet_bc.values.mean,
    yerr=dirichlet_bc.values.std,
    fmt="+",
    capsize=2,
    label=r"$g(\partial \Omega)$"
)
ax.legend()

notebook_utils.savefig("01_cond_bc")
plt.show()

### Predictive Induced by $\Delta$

In [None]:
# Push the BC-conditioned GP through the PDE operator (Δ). The second result
# (presumably the cross-covariance between u and Δu) is consumed by the PDE
# conditioning step below.
(laplace_u_bc, laplace_u_bc_crosscov) = u_bc.apply_jax_linop(
    bvp.diffop
)

In [None]:
# Δu under the BC-conditioned GP, compared with the RHS f.
# Raw-string label: "\D", "\m", "\p" are invalid escape sequences in a normal literal.
laplace_u_bc.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u \mid u(\partial \Omega) = g(\partial \Omega)$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    label="f",
)

plt.legend()

notebook_utils.savefig("02_pred_cond_bc")
plt.show()

### Conditioning on the PDE

In [None]:
# Condition on the PDE: Δu observed at the collocation points X with values fX,
# using the Δ-predictive GP and its companion output computed above.
u_bc_pde = u_bc.condition_on_predictive_gp_observations_jax(
    laplace_u_bc,
    laplace_u_bc_crosscov,
    X[:, None],
    fX,
)

In [None]:
# Posterior of u given the boundary data and the PDE observations Δu(X) = f(X).
# LaTeX labels are raw strings (r"..." / rf"...") so backslashes are not parsed as
# invalid escape sequences (SyntaxWarning since Python 3.12).
u_bc_pde.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$u \mid u(\partial \Omega) = g(\partial \Omega), \Delta u(x_i) = f(x_i)$",
)

plt.plot(
    plt_grid,
    u(plt_grid),
    color="C1",
    label="$u^*$",
)

plt.errorbar(
    dirichlet_bc.boundary,
    dirichlet_bc.values.mean,
    yerr=dirichlet_bc.values.std,
    fmt="+",
    capsize=2,
    color="C2",
    label=r"$g(\partial \Omega)$",
)

# Curvature glyphs at X from the posterior mean (f_xs), its derivative (df_xs)
# and the observed second derivatives (ddf_xs).
linpde_gp.plotting.plot_local_curvature(
    plt.gca(),
    xs=X,
    f_xs=u_bc_pde.mean(X[:, None]),
    ddf_xs=fX,
    # Derivative of the posterior mean via JAX autodiff.
    # NOTE(review): relies on the private `_meanfun` attribute.
    df_xs=jnp.vectorize(jax.grad(u_bc_pde._meanfun), signature="(d)->(d)")(X[:, None])[:, 0],
    color="C3",
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()

notebook_utils.savefig("03_cond_bc_pde")
plt.show()

### Posterior Predictive

In [None]:
# Posterior predictive for Δu; the operator's second output is not needed here.
# NOTE(review): "lalace" is a typo for "laplace"; the name is kept because the
# plotting cell below refers to it.
(lalace_u_bc_pde, _) = u_bc_pde.apply_jax_linop(
    bvp.diffop
)

In [None]:
# Posterior predictive for Δu after conditioning on boundary data and PDE observations.
# LaTeX labels are raw strings so "\D", "\m", "\d", ... are not invalid escapes.
lalace_u_bc_pde.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u \mid u(\partial \Omega) = g(\partial \Omega), \Delta u(x_i) = f(x_i)$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    label="$f$",
)

plt.errorbar(
    X,
    fX.mean,
    yerr=fX.std,
    fmt="+",
    capsize=2,
    c="C3",
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()

notebook_utils.savefig("04_pred_cond_bc_pde")
plt.show()

## Posterior (PDE First)

### Predictive Induced by $\Delta$

In [None]:
# PDE-first route: push the *prior* GP through Δ. The second output (presumably the
# cross-covariance between u and Δu) is consumed by the PDE conditioning step below.
(laplace_u, laplace_u_crosscov) = prior_gp.apply_jax_linop(
    bvp.diffop
)

In [None]:
# Δu under the prior GP, compared with the RHS f.
# Raw-string label: "\D" is an invalid escape sequence in a normal literal.
laplace_u.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    label="f",
)

plt.legend()

notebook_utils.savefig("pdefirst_01_prior_pred")
plt.show()

### Conditioning on the PDE

In [None]:
# Condition the prior directly on the PDE observations Δu(X) = fX.
u_pde = prior_gp.condition_on_predictive_gp_observations_jax(
    laplace_u,
    laplace_u_crosscov,
    X[:, None],
    fX,
)

In [None]:
# Posterior of u given only the PDE observations (no boundary data yet).
# LaTeX labels are raw strings (r"..." / rf"...") to avoid invalid escape sequences.
u_pde.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$u \mid \Delta u(x_i) = f(x_i)$",
)

plt.plot(
    plt_grid,
    u(plt_grid),
    label="$u^*$",
)

# Curvature glyphs at X from the posterior mean (f_xs), its derivative (df_xs)
# and the observed second derivatives (ddf_xs).
linpde_gp.plotting.plot_local_curvature(
    plt.gca(),
    xs=X,
    f_xs=u_pde.mean(X[:, None]),
    ddf_xs=fX,
    # NOTE(review): relies on the private `_meanfun` attribute.
    df_xs=jnp.vectorize(jax.grad(u_pde._meanfun), signature="(d)->(d)")(X[:, None])[:, 0],
    color="C3",
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()

notebook_utils.savefig("pdefirst_02_cond_pde")
plt.show()

### Posterior Predictive

In [None]:
# Posterior predictive for Δu given only the PDE observations (second output unused).
(laplace_u_pde, _) = u_pde.apply_jax_linop(
    bvp.diffop
)

In [None]:
# Posterior predictive for Δu given the PDE observations, with the observed values.
# LaTeX labels are raw strings so "\D", "\m", "\d" are not invalid escapes.
laplace_u_pde.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u \mid \Delta u(x_i) = f(x_i)$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    label="f",
)

plt.errorbar(
    X,
    fX.mean,
    yerr=fX.std,
    fmt="+",
    capsize=2,
    c="C3",
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()

notebook_utils.savefig("pdefirst_03_pred_cond_pde")
plt.show()

### Conditioning on the Boundary Conditions

In [None]:
# Second step of the PDE-first route: additionally condition on the Dirichlet data.
u_pde_bc = u_pde.condition_on_observations_jax(
    np.array(list(dirichlet_bc.boundary))[:, None],  # boundary points, one per row
    dirichlet_bc.values,
)

In [None]:
# Posterior of u given PDE observations and then boundary data.
# LaTeX labels are raw strings (r"..." / rf"...") to avoid invalid escape sequences.
u_pde_bc.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$u \mid \Delta u(X) = f(X), u(\partial \Omega) = g(\partial \Omega)$",
)

plt.plot(
    plt_grid,
    u(plt_grid),
    label="$u^*$",
)

plt.errorbar(
    dirichlet_bc.boundary,
    dirichlet_bc.values.mean,
    yerr=dirichlet_bc.values.std,
    fmt="+",
    capsize=2,
    label=r"$g(\partial \Omega)$",
)

# Curvature glyphs at X from the posterior mean (f_xs), its derivative (df_xs)
# and the observed second derivatives (ddf_xs).
linpde_gp.plotting.plot_local_curvature(
    plt.gca(),
    xs=X,
    f_xs=u_pde_bc.mean(X[:, None]),
    ddf_xs=fX,
    # NOTE(review): relies on the private `_meanfun` attribute.
    df_xs=jnp.vectorize(jax.grad(u_pde_bc._meanfun), signature="(d)->(d)")(X[:, None])[:, 0],
    color="C3",
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()

notebook_utils.savefig("pdefirst_04_cond_pde_bc")
plt.show()

### Posterior Predictive with PDE and Boundary Conditions

In [None]:
# Posterior predictive for Δu given PDE observations and boundary data
# (second output unused).
(laplace_u_pde_bc, _) = u_pde_bc.apply_jax_linop(
    bvp.diffop
)

In [None]:
# Posterior predictive for Δu given PDE observations and boundary data.
# LaTeX labels are raw strings so "\D", "\m", "\p", "\d" are not invalid escapes.
laplace_u_pde_bc.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    label=r"$\Delta u \mid \Delta u(x_i) = f(x_i), u(\partial \Omega) = g(\partial \Omega)$",
)

plt.plot(
    plt_grid,
    f(plt_grid),
    label="f",
)

plt.errorbar(
    X,
    fX.mean,
    yerr=fX.std,
    fmt="+",
    capsize=2,
    c="C3",
    label=rf"$(f(x_1), \dots, f(x_{N}))$",
)

plt.legend()

notebook_utils.savefig("pdefirst_05_pred_cond_pde_bc")
plt.show()