In [None]:
import functools

import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn
import scipy.linalg

import linpde_gp as linpde_gp

from typing import Optional, Tuple

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

from matplotlib_inline.backend_inline import set_matplotlib_formats
# Request PDF and SVG as the output formats of inline figures.
set_matplotlib_formats("pdf", "svg")

In [None]:
# Switch JAX to 64-bit floating point (it uses 32-bit unless this flag is set).
jax.config.update("jax_enable_x64", True)

In [None]:
# Seeded NumPy generator; handed as `rng=` to the sample plots further down.
rng = np.random.default_rng(24)

## Problem Definition

In [None]:
# Domain: the interval [-1, 1], stored as its two endpoints
boundary = np.array([-1.0, 1.0])

# Evaluation grid for all plots: 100 equispaced points spanning the domain
plt_grid = np.linspace(boundary[0], boundary[1], num=100)

# Differential Operator
def diffop(f, a, argnum: int = 0):
    """Build the scaled Laplacian ``a * laplace(f)``.

    Args:
        f: Function to differentiate; forwarded to
            ``linpde_gp.problems.pde.diffops.laplace``.
        a: Factor multiplied onto the value of the Laplacian.
        argnum: Forwarded to ``laplace`` as its ``argnum`` keyword.

    Returns:
        A jitted callable that forwards all of its arguments to the
        Laplacian of ``f`` and multiplies the result by ``a``.
    """
    laplacian = linpde_gp.problems.pde.diffops.laplace(f, argnum=argnum)

    def _scaled_laplace(*args, **kwargs):
        value = laplacian(*args, **kwargs)
        return a * value

    return jax.jit(_scaled_laplace)

In [None]:
# Ground-truth value of the diffop parameter a
a_true = 2.0

# Constant right-hand side of the PDE
_f_const = -2.0


def f(x):
    """Evaluate the constant PDE right-hand side: an array like `x` filled with `_f_const`."""
    return np.full_like(x, _f_const)


# Boundary values at the two domain endpoints, as a Gaussian with zero covariance
g = pn.randvars.Normal(
    mean=np.array([-0.1, 0.9]),
    cov=0.0 ** 2 * np.eye(boundary.size),
)

# Coefficients of the quadratic true solution: value at the left endpoint,
# slope of the chord between the boundary means, and half of f / a
_u_sol_coeffs = [
    g.mean[0],
    (g.mean[1] - g.mean[0]) / (boundary[1] - boundary[0]),
    0.5 * (_f_const / a_true),
]


def u_sol(x):
    """Evaluate the true solution u^* (a quadratic in `x`) at `x`."""
    offset, slope, curvature = _u_sol_coeffs
    return offset + (slope + curvature * (x - boundary[1])) * (x - boundary[0])

# Priors: hyperparameters of the squared-exponential kernel
lengthscale = 1.0
output_scale = 1.0


@jax.jit
def u_prior_mean(x):
    """Zero prior mean: one zero per input point (the last axis of `x` is dropped)."""
    return jnp.zeros_like(x[..., 0])


@jax.jit
def u_prior_cov(x0, x1):
    """Squared-exponential kernel using the module-level `lengthscale` and `output_scale`."""
    diff = x0 - x1
    sqdist = jnp.sum(diff ** 2, axis=-1)
    inv_two_ell_sq = 1.0 / (2.0 * lengthscale ** 2)

    return output_scale ** 2 * jnp.exp(-inv_two_ell_sq * sqdist)

# Gaussian prior over the diffop parameter a
a_prior = pn.randvars.Normal(
    mean=1.8,
    cov=1.0 ** 2,
)

def a_u_crosscov_prior(x):
    """Prior cross-covariance between `a` and `u(x)`: identically zero.

    Returns an array of zeros shaped like `x` without its last axis.
    """
    return jnp.zeros_like(x[..., 0])

# Measurements of the solution u^*: locations and noise standard deviations
u_sol_meas_xs = np.array([0.5])
u_sol_meas_std = np.array([0.1])

# PDE measurements: collocation points and the (common) noise level at each one
collocation_points = np.array([-0.5, 0.0, 0.6, 0.7, -0.5])
collocation_std = np.full_like(collocation_points, 4.0)

In [None]:
# Plotting
def plot_belief(
    u: pn.randprocs.GaussianProcess,
    a: pn.randvars.Normal,
    boundary_values: Optional[pn.randvars.Normal] = None,
    u_measurements: Optional[Tuple[np.ndarray, pn.randvars.Normal]] = None,
    collocation_points: Optional[np.ndarray] = None,
    a_prior: Optional[pn.randvars.Normal] = None,
):
    """Plot the beliefs over the solution `u` (left) and the parameter `a` (right).

    Args:
        u: Belief over the solution; drawn on the left axes with 10 samples,
            next to the true solution `u_sol`.
        a: Belief over the diffop parameter; drawn on the right axes, next to
            the true value `a_true`.
        boundary_values: If given, shown as error bars at the domain endpoints.
        u_measurements: If given, a pair `(locations, values)` of solution
            measurements, shown as error bars.
        collocation_points: If given, every point is marked by a vertical line.
        a_prior: If given, `a` is labelled as a posterior and `a_prior` is
            drawn alongside it. Requires `collocation_points`.

    Returns:
        The created `(fig, ax)` pair; `ax` holds the two axes.
    """
    fig, ax = plt.subplots(ncols=2, figsize=(12, 4))

    # Belief over u
    ax[0].plot(
        plt_grid,
        u_sol(plt_grid),
        c="C1",
        label="$u^*$",
    )

    # LaTeX snippets naming everything the belief has been conditioned on
    u_conditional_strs = []

    if boundary_values is not None:
        u_conditional_strs.append(r"u(\partial \Omega) = g(\partial \Omega)")

    if u_measurements is not None:
        u_conditional_strs.append(r"u(X_u) = Y_u")

    if collocation_points is not None:
        u_conditional_strs.append(r"\mathcal{L}_a u(X_c) = f(X_c)")

    conditions = ", ".join(u_conditional_strs)

    u.plot(
        ax[0],
        plt_grid,
        num_samples=10,
        rng=rng,
        color="C0",
        label=fr"$u \mid {conditions}$" if u_conditional_strs else r"$u$",
    )

    if boundary_values is not None:
        ax[0].errorbar(
            boundary,
            boundary_values.mean,
            yerr=boundary_values.std,
            fmt="+",
            capsize=2,
            color="C2",
            label=r"$g(\partial \Omega)$",
        )

    if u_measurements is not None:
        u_measurements_xs, u_measurements_ys = u_measurements

        ax[0].errorbar(
            u_measurements_xs,
            u_measurements_ys.mean,
            yerr=u_measurements_ys.std,
            fmt="+",
            capsize=2,
            color="C3",
            label=r"$Y_u$",
        )

    if collocation_points is not None:
        for idx, cpoint in enumerate(collocation_points):
            ax[0].axvline(
                cpoint,
                color="C4",
                alpha=0.6,
                # Label only the first line so the legend shows a single,
                # math-rendered entry instead of one entry per point.
                label=r"$X_c$" if idx == 0 else None,
            )

    ax[0].legend()

    # Belief over a
    ax[1].axvline(a_true, c="C1", label=r"$a^*$")

    if a_prior is not None:
        # The posterior label below is built from the conditioning snippets.
        assert collocation_points is not None

        a_prior.plot(
            ax[1],
            label=r"$p(a)$",
            c="C0",
        )

        a.plot(
            ax[1],
            label=fr"$p(a \mid {conditions})$",
            c="C2",
        )
    else:
        a.plot(
            ax[1],
            label=r"$p(a)$",
            c="C0",
        )

    ax[1].legend()

    return fig, ax

## Prior

In [None]:
# Gaussian-process prior over u, wrapping the JAX mean and kernel functions
# `u_prior_mean` and `u_prior_cov`; the kernel takes 1-dimensional inputs.
u_prior = pn.randprocs.GaussianProcess(
    mean=linpde_gp.randprocs.mean_fns.JaxMean(u_prior_mean, vectorize=False),
    cov=linpde_gp.randprocs.kernels.JaxKernel(u_prior_cov, input_dim=1, vectorize=False),
)

In [None]:
# Plot the prior beliefs over u and a, then save and show the figure.
plot_belief(
    u=u_prior,
    a=a_prior,
)
plt.savefig("../results/0007_poisson_rbf_00_prior.pdf", dpi=300)
plt.show()

## Condition on Boundary Values

In [None]:
# Condition the prior on the boundary values g at the two domain endpoints
# (`boundary[:, None]` turns the endpoints into a column of 1-D inputs).
u_bc = u_prior.condition_on_observations_jax(boundary[:, None], g)

In [None]:
# Plot the belief after conditioning on the boundary values; a is still at its prior.
plot_belief(
    u=u_bc,
    a=a_prior,
    boundary_values=g,
)

plt.savefig("../results/0007_poisson_rbf_bcfirst_01_cond_bc.pdf", dpi=300)
plt.show()

## Condition on Measurements of the Solution

In [None]:
# Measure the Solution
# Evaluate the true solution at the measurement locations and perturb it with
# Gaussian noise of standard deviation `u_sol_meas_std`. The noise is drawn
# from the notebook's seeded `rng` (not the unseeded global NumPy state), so
# the measurement is identical on every Restart & Run All.
u_sol_meas_ys = pn.randvars.Normal(
    mean=(
        u_sol(u_sol_meas_xs)
        + u_sol_meas_std * rng.standard_normal(size=u_sol_meas_xs.shape)
    ),
    cov=np.diag(u_sol_meas_std ** 2),
)

In [None]:
# Additionally condition on the noisy solution measurements.
u_bc_obs = u_bc.condition_on_observations_jax(u_sol_meas_xs[:, None], u_sol_meas_ys)

In [None]:
# Plot the belief after conditioning on boundary values and solution measurements.
plot_belief(
    u=u_bc_obs,
    a=a_prior,
    boundary_values=g,
    u_measurements=(u_sol_meas_xs, u_sol_meas_ys),
)

# Saved under its own file name so that the boundary-conditioned figure
# written by the earlier cell is not overwritten.
plt.savefig("../results/0007_poisson_rbf_bcfirst_02_cond_obs.pdf", dpi=300)
plt.show()

## Condition on the PDE Iteratively

In [None]:
# Initialize loop variables
# Index of the collocation point processed by the manual step below
cpoint_idx = 0

# Start from the belief conditioned on boundary values and solution
# measurements, the prior over a, and the (zero) prior cross-covariance.
u_post = u_bc_obs
a_post = a_prior
a_u_crosscov_post = a_u_crosscov_prior

### Linearize the Likelihood

In [None]:
# Define linearization points
# Linearize around the current mean function of u and the current mean of a.
u_hat = u_post._meanfun
a_hat = a_post.mean

In [None]:
# Derivative of (L_a u_hat)(x) with respect to the parameter a (argument 0).
_M_a = jax.grad(lambda a, x: diffop(u_hat, a)(x), argnums=0)

# Evaluates that derivative at a = a_hat; vectorized so that every length-d
# vector along the last axis of `x` is mapped to one scalar.
@functools.partial(jnp.vectorize, signature="(d)->()")
def M_a(x):
    return _M_a(a_hat, x)

In [None]:
# Push the current belief over u through the diffop with a fixed to a_hat.
# The call returns the transformed process and a second object that is used
# below as the cross-covariance between u and L u.
Lu_post, u_Lu_post_crosscov = u_post.apply_jax_linop(diffop, a=a_hat)

In [None]:
# Plot the belief over L_{a_hat} u together with the PDE right-hand side f.
Lu_post.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    # Raw string: the LaTeX backslashes are not valid Python escape sequences.
    label=r"$\mathcal{L}_{\hat{a}} u \mid u(\partial \Omega) = g(\partial \Omega), u(X_u) = Y_u$"
)

plt.plot(
    plt_grid,
    f(plt_grid),
    color="C3",
    label="f",
)

plt.legend()
plt.grid()
plt.show()

In [None]:
# Gradient of the a-u cross-covariance with respect to its input point,
# vectorized so that each length-d input vector yields a length-d gradient.
# NOTE(review): the name suggests the diffop applied to the cross-covariance,
# but the code takes a first derivative (jax.grad) - confirm this is intended.
a_u_crosscov_L_adj = jnp.vectorize(jax.jit(jax.grad(a_u_crosscov_post, argnums=0)), signature="(d)->(d)")

### Compute the Predictive

In [None]:
def u_lin_pred_cov(x0, x1):
    """Covariance function of the linearized PDE predictive.

    Sums four terms: `M_a(x0) * cov(a) * M_a(x1)`, two terms mixing `M_a`
    with `a_u_crosscov_L_adj`, and the covariance function of `Lu_post`.
    """
    M_a_x0 = M_a(x0)
    M_a_x1 = M_a(x1)
    
    return (
        M_a_x0 * a_post.cov * M_a_x1
        + a_u_crosscov_L_adj(x0).T[0] * M_a_x1
        + M_a_x0.T * a_u_crosscov_L_adj(x1)[0]
        + Lu_post._covfun.jax(x0, x1)
    )

# Linearized predictive: it reuses the mean function of `Lu_post` and uses
# the covariance function above as its kernel.
u_lin_pred = pn.randprocs.GaussianProcess(
    mean=Lu_post._meanfun,
    cov=linpde_gp.randprocs.kernels.JaxKernel(
        u_lin_pred_cov,
        input_dim=1,
        vectorize=True
    ),
)

In [None]:
# Plot the linearized predictive together with the PDE right-hand side f.
u_lin_pred.plot(
    plt.gca(),
    plt_grid,
    num_samples=10,
    rng=rng,
    # Raw string: the LaTeX backslashes are not valid Python escape sequences.
    label=r"$\hat{\mathcal{L}}_{a} u \mid u(\partial \Omega) = g(\partial \Omega), u(X_u) = Y_u$"
)

plt.plot(
    plt_grid,
    f(plt_grid),
    color="C3",
    label="f",
)

plt.legend()
plt.grid()
plt.show()

### Compute Posterior Corresponding to Linearized Likelihood

In [None]:
# Measure the PDE RHS
# Select the current collocation point and its noise level; the list index
# keeps both as length-1 arrays.
cpoint = collocation_points[[cpoint_idx]]
cpoint_std = collocation_std[[cpoint_idx]]

f_cpoint = f(cpoint)
fX = pn.randvars.Normal(
    mean=f_cpoint,
    # `cpoint_std` is a standard deviation, so the covariance is its square.
    cov=np.diag(cpoint_std ** 2),
)

In [None]:
# Gram matrix: covariance of the linearized predictive at the collocation
# point plus the observation noise; factorized once and reused below.
gram = u_lin_pred.covmatrix(cpoint[:, None], cpoint[:, None]) + fX.cov
gram_cho = scipy.linalg.cho_factor(gram)

# Infer the parameter
# Cross-covariance between a and the linearized PDE observation
a_pred_u_crosscov = a_post.cov * M_a(cpoint[:, None]) + a_u_crosscov_L_adj(cpoint[:, None])[0]
# Gain: the Gram system solved against that cross-covariance
a_gain = scipy.linalg.cho_solve(gram_cho, a_pred_u_crosscov)

# Gaussian update of a: the mean moves by gain @ residual, the covariance
# shrinks by gain @ cross-covariance.
a_post_pde = pn.randvars.Normal(
    mean=np.asarray(a_post.mean + a_gain @ (fX.mean - u_lin_pred.mean(cpoint[:, None]))),
    cov=np.asarray(a_post.cov - a_gain @ a_pred_u_crosscov.T)
)

# Infer the solution
# Cross-covariance between u(x1) and the linearized PDE observation at x2
def _u_pred_u_crosscov(x1, x2):
    res = a_u_crosscov_L_adj(x1) @ M_a(x2)[..., None] + u_Lu_post_crosscov.jax(x1, x2)
    return res

u_pred_u_crosscov = linpde_gp.randprocs.kernels.JaxKernel(_u_pred_u_crosscov, input_dim=1, vectorize=True)

# Condition u on the PDE observation `fX` at the collocation point
u_post_pde = u_post.condition_on_predictive_gp_observations_jax(
    u_lin_pred,
    u_pred_u_crosscov,
    cpoint[:, None],
    fX,
)

# Posterior cross covariance
# Previous cross-covariance minus the gain applied to the u/observation
# cross-covariance at the collocation point.
def a_u_crosscov_post_pde(x):
    return a_u_crosscov_post(x) - a_gain @ u_pred_u_crosscov.jax(x, cpoint[:, None])

In [None]:
# Plot the beliefs after the manual update on the first collocation point,
# with the prior over a shown for comparison.
plot_belief(
    u=u_post_pde,
    a=a_post_pde,
    boundary_values=g,
    u_measurements=(u_sol_meas_xs, u_sol_meas_ys),
    collocation_points=collocation_points[:1],
    a_prior=a_prior,
)

plt.show()

### Iterated Inference

In [None]:
def infer(
    u,
    a,
    a_u_crosscov,
    cpoints,
    cpoints_std,
):
    """Condition the joint belief over `(u, a)` on PDE observations at `cpoints`.

    The PDE observation is linearized around the current means of `u` and
    `a`; the resulting Gaussian update is applied to `a`, to `u`, and to
    their cross-covariance.

    Args:
        u: Current Gaussian-process belief over the solution.
        a: Current Gaussian belief over the diffop parameter.
        a_u_crosscov: Callable returning the current cross-covariance between
            `a` and `u(x)`.
        cpoints: 1-D array of collocation points.
        cpoints_std: 1-D array holding the noise standard deviation of the
            right-hand-side observation at each collocation point.

    Returns:
        Tuple `(u_post, a_post, a_u_crosscov_post)`; it can be unpacked
        straight into the first three arguments of the next `infer` call.
    """
    # Linearize the likelihood
    u_hat = u._meanfun
    a_hat = a.mean

    # Derivative of (L_a u_hat)(x) with respect to a ...
    _M_a = jax.grad(lambda a, x: diffop(u_hat, a)(x), argnums=0)

    # ... evaluated at a_hat, one scalar per length-d input vector
    @functools.partial(jnp.vectorize, signature="(d)->()")
    @jax.jit
    def M_a(x):
        return _M_a(a_hat, x)

    Lu, u_Lu_crosscov = u.apply_jax_linop(diffop, a=a_hat)

    a_u_crosscov_L_adj = jax.jit(jnp.vectorize(jax.jit(jax.grad(a_u_crosscov, argnums=0)), signature="(d)->(d)"))

    # Compute the predictive
    @jax.jit
    def u_lin_pred_cov(x0, x1):
        M_a_x0 = M_a(x0)
        M_a_x1 = M_a(x1)

        return (
            M_a_x0 * a.cov * M_a_x1
            + a_u_crosscov_L_adj(x0).T[0] * M_a_x1
            + M_a_x0.T * a_u_crosscov_L_adj(x1)[0]
            + Lu._covfun.jax(x0, x1)
        )

    u_lin_pred = pn.randprocs.GaussianProcess(
        mean=Lu._meanfun,
        cov=linpde_gp.randprocs.kernels.JaxKernel(
            u_lin_pred_cov,
            input_dim=1,
            vectorize=True
        ),
    )

    # Measure the PDE RHS
    f_cpoints = f(cpoints)

    fX = pn.randvars.Normal(
        mean=f_cpoints,
        # `cpoints_std` holds standard deviations; the covariance needs variances.
        cov=np.diag(cpoints_std ** 2),
    )

    # Joint inference
    gram = u_lin_pred.covmatrix(cpoints[:, None], cpoints[:, None]) + fX.cov
    gram_cho = scipy.linalg.cho_factor(gram)

    # Infer the parameter
    a_pred_u_crosscov = a.cov * M_a(cpoints[:, None]) + a_u_crosscov_L_adj(cpoints[:, None])[0]
    a_gain = scipy.linalg.cho_solve(gram_cho, a_pred_u_crosscov)

    a_post = pn.randvars.Normal(
        mean=np.asarray(a.mean + a_gain @ (fX.mean - u_lin_pred.mean(cpoints[:, None]))),
        cov=np.asarray(a.cov - a_gain @ a_pred_u_crosscov.T)
    )

    # Infer the solution
    @jax.jit
    def _u_pred_u_crosscov(x1, x2):
        return a_u_crosscov_L_adj(x1) @ M_a(x2)[..., None] + u_Lu_crosscov.jax(x1, x2)

    u_pred_u_crosscov = linpde_gp.randprocs.kernels.JaxKernel(_u_pred_u_crosscov, input_dim=1, vectorize=True)

    u_post = u.condition_on_predictive_gp_observations_jax(
        u_lin_pred,
        u_pred_u_crosscov,
        cpoints[:, None],
        fX,
    )

    # Posterior cross covariance
    def a_u_crosscov_post(x):
        return a_u_crosscov(x) - a_gain @ u_pred_u_crosscov.jax(x, cpoints[:, None])

    return u_post, a_post, a_u_crosscov_post

In [None]:
# First PDE update: start from the belief conditioned on boundary values and
# measurements and use only collocation point 0.
posterior_0 = infer(u_bc_obs, a_prior, a_u_crosscov_prior, collocation_points[[0]], collocation_std[[0]])

In [None]:
# Plot the beliefs over u and a (the first two entries of `posterior_0`).
plot_belief(
    *posterior_0[:2],
    boundary_values=g,
    u_measurements=(u_sol_meas_xs, u_sol_meas_ys),
    collocation_points=collocation_points[:1],
    a_prior=a_prior,
)

plt.show()

In [None]:
# Second PDE update: continue from `posterior_0` with collocation point 1.
posterior_1 = infer(*posterior_0, collocation_points[[1]], collocation_std[[1]])

In [None]:
# Plot the beliefs over u and a after the first two collocation points.
plot_belief(
    *posterior_1[:2],
    boundary_values=g,
    u_measurements=(u_sol_meas_xs, u_sol_meas_ys),
    collocation_points=collocation_points[:2],
    a_prior=a_prior,
)

# Show the figure explicitly, as the other plotting cells do, so the cell
# does not echo the returned (fig, ax) tuple.
plt.show()

In [None]:
# Third PDE update: continue from `posterior_1` with collocation point 2.
posterior_2 = infer(*posterior_1, collocation_points[[2]], collocation_std[[2]])

In [None]:
# Plot the beliefs over u and a after the first three collocation points.
plot_belief(
    *posterior_2[:2],
    boundary_values=g,
    u_measurements=(u_sol_meas_xs, u_sol_meas_ys),
    collocation_points=collocation_points[:3],
    a_prior=a_prior,
)

# Show the figure explicitly, as the other plotting cells do, so the cell
# does not echo the returned (fig, ax) tuple.
plt.show()

In [None]:
# Variance of the final belief over u, evaluated on the plotting grid.
posterior_2[0].var(plt_grid[:, None])