In [None]:
from typing import Optional

import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import linpde_gp

In [None]:
import notebook_utils
from notebook_utils import config

# Identifier of this notebook; presumably used by notebook_utils to name
# output artifacts such as saved figures -- TODO confirm.
config.notebook_name = "0010_poisson_1d_inverse_rhs"
# Publication target; presumably selects the style that
# config.tueplot_bundle() returns -- TODO confirm.
config.target = "jmlr"
# NOTE(review): debug mode is left enabled; check what it toggles in
# notebook_utils before producing final figures.
config.debug_mode = True

In [None]:
%matplotlib inline

In [None]:
# Apply the publication style bundle globally, then extend the LaTeX preamble
# with amsfonts (provides blackboard-bold symbols such as \mathbb).
plt.rcParams.update(config.tueplot_bundle())
plt.rcParams["text.latex.preamble"] = (
    plt.rcParams["text.latex.preamble"] + "\n\\usepackage{amsfonts}"
)

# Problem Definition

In [None]:
from linpde_gp.problems.pde import domains, poisson_1d_bvp

In [None]:
# Domain of the 1D boundary value problem: the interval [-1, 1].
domain = domains.asdomain((-1.0, 1.0))

# Center and width of the Gaussian bump used as manufactured solution.
mu, sigma = 0.4, 0.3


def u_true(x):
    """Ground-truth solution: an unnormalized Gaussian bump centered at ``mu``.

    ``x`` is indexed as ``x[..., 0]``, i.e. the last axis holds the single
    spatial coordinate; the result has the shape of ``x`` without that axis.
    """
    return np.exp(-0.5 / sigma ** 2 * (x[..., 0] - mu) ** 2)


def f_true(x):
    """Right-hand side matching ``u_true``, i.e. ``f = -u''``.

    Differentiating the Gaussian bump twice gives
    ``-u''(x) = (1 - ((x - mu) / sigma)^2) / sigma^2 * u(x)``.
    Uses the same input convention as ``u_true``.
    """
    return (1.0 - ((x[..., 0] - mu) / sigma) ** 2) / sigma ** 2 * u_true(x)


# Dirichlet data: the true solution evaluated at both interval endpoints.
# ``[None]`` adds the axis that ``u_true`` strips again via ``x[..., 0]``.
boundary_values = (u_true(domain[0][None]), u_true(domain[1][None]))

bvp = poisson_1d_bvp(
    domain=domain,
    rhs=f_true,
    boundary_values=boundary_values,
    solution=u_true,
)

In [None]:
# Plotting

# Shared evaluation grid for all figures: 100 equally spaced points
# spanning the domain.
plt_grid = np.linspace(*domain, num=100)

def plot_belief(
    u: pn.randprocs.GaussianProcess,
    f: pn.randprocs.GaussianProcess,
    bc: bool = False,
    u_meas: Optional[tuple[np.ndarray, np.ndarray, pn.randvars.Normal]] = None,
    pde_meas: Optional[tuple[np.ndarray, pn.randvars.Normal]] = None,
):
    """Plot the beliefs over the solution ``u`` and the right-hand side ``f``.

    Creates a two-panel figure. The left panel shows the belief ``u`` together
    with the true solution, the right panel shows the belief ``f`` together
    with the true right-hand side. Observations are overlaid on request.

    The function reads the notebook-level globals ``plt_grid``, ``bvp``,
    ``domain``, ``boundary_values`` and ``config``. The figure is left as the
    current matplotlib figure and nothing is returned.

    Parameters
    ----------
    u
        Belief over the solution, plotted in the left panel.
    f
        Belief over the right-hand side, plotted in the right panel.
    bc
        If ``True``, mark the boundary values at the domain endpoints in the
        left panel.
    u_meas
        Optional triple ``(X_meas, Y_meas, noise)``. The measurements are drawn
        in the left panel with error bars of ``1.96 * noise.std``.
    pde_meas
        Optional pair ``(X_pde, Lu_X_pde)``. ``Lu_X_pde.mean`` is drawn in the
        right panel with error bars of ``1.96 * Lu_X_pde.std``.
    """
    # Column-vector view of the plotting grid, as expected by the callables
    # stored on the boundary value problem.
    X_plot = plt_grid[:, None]

    with plt.rc_context(config.tueplot_bundle(ncols=2)):
        _, (ax_u, ax_f) = plt.subplots(ncols=2)

        # --- Left panel: solution u ---
        u.plot(
            ax_u,
            plt_grid,
            num_samples=10,
            rng=np.random.default_rng(24),
            label="$u$"
        )

        ax_u.plot(
            plt_grid,
            bvp.solution(X_plot),
            label="$u^*$",
        )

        if bc:
            # Boundary values are treated as exact, hence zero-length error bars.
            ax_u.errorbar(
                list(domain),
                boundary_values,
                yerr=0,
                fmt="+",
                capsize=2,
                label=r"$u \vert_{\partial \Omega}$"
            )

        if u_meas is not None:
            X_meas, Y_meas, yerr_meas = u_meas

            # 1.96 standard deviations: 95% interval of a Gaussian.
            ax_u.errorbar(
                X_meas,
                Y_meas,
                yerr=1.96 * yerr_meas.std,
                fmt="+",
                capsize=2,
                label=r"$u(X_\mathrm{meas})$"
            )

        ax_u.legend()

        # --- Right panel: right-hand side f ---
        f.plot(
            ax_f,
            plt_grid,
            num_samples=10,
            rng=np.random.default_rng(24),
            label="$f$"
        )

        ax_f.plot(
            plt_grid,
            bvp.rhs(X_plot),
            label="$f^*$"
        )

        # Disabled overlay of the mean of the differential operator applied
        # to u; kept for reference.
        # ax_f.plot(
        #     plt_grid,
        #     bvp.diffop(u).mean(plt_grid[:, None]),
        #     label=r"$\mathbb{E}[-\Delta u]$"
        # )

        # Explicit None check, consistent with the ``u_meas`` branch above.
        if pde_meas is not None:
            X_pde, Lu_X_pde = pde_meas

            ax_f.errorbar(
                X_pde,
                Lu_X_pde.mean,
                yerr=1.96 * Lu_X_pde.std,
                fmt="+",
                capsize=2,
                label=r"$-\Delta u(X_\mathrm{PDE})$"
            )

        ax_f.legend()

# Priors

In [None]:
def _zero_mean_expquad_gp(
    lengthscales: float, output_scale: float
) -> pn.randprocs.GaussianProcess:
    """Build a zero-mean GP on 1D inputs with an ``ExpQuad`` covariance."""
    return pn.randprocs.GaussianProcess(
        mean=linpde_gp.randprocs.mean_fns.Zero(),
        cov=linpde_gp.randprocs.kernels.ExpQuad(
            input_dim=1,
            lengthscales=lengthscales,
            output_scale=output_scale,
        ),
    )


# Prior over the solution u.
u_prior = _zero_mean_expquad_gp(lengthscales=0.5, output_scale=1.0)

# Prior over the right-hand side f: shorter lengthscale and larger output
# scale than the prior over u.
f_prior = _zero_mean_expquad_gp(lengthscales=0.25, output_scale=10.0)

In [None]:
# Visualize both priors before any observation has been incorporated.
plot_belief(u=u_prior, f=f_prior)

# Observations

## Boundary Conditions 

In [None]:
# Condition the solution prior on the boundary values. Only the mean of each
# boundary condition's values is used as the observed value.
u_bc = u_prior.condition_on_observations(
    X=np.hstack(bvp.domain.boundary)[:, None],
    fX=np.stack([bc.values.mean for bc in bvp.boundary_conditions]),
)

In [None]:
# Belief over u after the boundary conditions; f is still at its prior.
plot_belief(u=u_bc, f=f_prior, bc=True)

## Empirical Measurements

In [None]:
# Ten equally spaced interior measurement locations: a 12-point grid over the
# domain with both endpoints dropped, stored as a column vector.
X_meas = np.linspace(*bvp.domain, num=10 + 2)[1:-1, None]

# Measured values are the exact solution values (no noise is added here).
Y_meas = bvp.solution(X_meas)

# Independent zero-mean Gaussian measurement noise with standard deviation 0.1.
err_meas = pn.randvars.Normal(
    mean=np.zeros_like(X_meas[:, 0]),
    cov=np.diag(np.full_like(X_meas[:, 0], 0.1 ** 2)),
)

u_bc_meas = u_bc.condition_on_observations(X=X_meas, fX=Y_meas, noise_model=err_meas)

In [None]:
# Belief over u after boundary conditions and measurements; f is still at
# its prior.
plot_belief(
    u=u_bc_meas, f=f_prior, bc=True, u_meas=(X_meas, Y_meas, err_meas)
)

# PDE

In [None]:
# Collocation points at which the PDE is enforced: ten equally spaced interior
# points. They currently coincide with X_meas, but are a separate quantity.
X_pde = np.linspace(*bvp.domain, 10 + 2)[1:-1, None]

# Belief about the differential operator applied to u at the collocation
# points, computed from the belief *before* PDE conditioning (u_bc_meas).
Lu_X_pde = bvp.diffop(u_bc_meas)(X_pde)

# Condition u on the PDE residual being zero at the collocation points.
# Presumably the noise model is added to the observed quantity, so that
# observing 0 with noise -f(X_pde) encodes L[u](X_pde) - f(X_pde) = 0
# -- TODO confirm against condition_on_linop_observations.
# The collocation points X_pde are used here (previously X_meas, which only
# worked because both grids happen to be identical).
u_post = u_bc_meas.condition_on_linop_observations(
    L=bvp.diffop,
    X=X_pde,
    LfX=np.zeros_like(X_pde[..., 0]),
    noise_model=-f_prior(X_pde),
)

# Symmetric update for f: observing 0 with noise -L[u](X_pde) presumably
# encodes f(X_pde) - L[u](X_pde) = 0 -- TODO confirm.
f_post = f_prior.condition_on_observations(
    X=X_pde,
    fX=np.zeros_like(X_pde[..., 0]),
    noise_model=-Lu_X_pde,
)

In [None]:
# Final figure: posterior beliefs over u and f with all observations overlaid.
plot_belief(
    u_post,
    f_post,
    bc=True,
    u_meas=(X_meas, Y_meas, err_meas),
    pde_meas=(X_pde, Lu_X_pde),
)

# Presumably writes the figure created above to disk under this name
# -- confirm in notebook_utils.
notebook_utils.savefig("u_f_posterior")