In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import linpde_gp

In [None]:
import experiment_utils
from experiment_utils import config

# Experiment configuration consumed by `experiment_utils` (presumably used for
# output naming and the figure style of the target venue -- confirm in experiment_utils).
config.experiment_name = "0006_poisson_nonparametric"
config.target = "jmlr"
# NOTE(review): debug mode is left enabled; confirm this is intended for final figures.
config.debug_mode = True

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Apply the global matplotlib style returned by the experiment config
# (presumably derived from `config.target` -- confirm in experiment_utils).
plt.rcParams.update(config.tueplots_bundle())

## Problem Definition

In [None]:
from linpde_gp.problems.pde import domains, poisson_1d_bvp

In [None]:
# Spatial domain of the 1D boundary value problem: the interval from -1 to 1.
domain = domains.Interval(-1.0, 1.0)

In [None]:
# NOTE: alternative test problem. Every name defined here is reassigned by the
# next cell, so this problem is not the one solved below.
boundary_values = (0.0, 1.0)


def f(x):
    """PDE right-hand side: the constant 2, broadcast to the shape of ``x``."""
    return np.full_like(x, 2.0)


def u(x):
    """True solution: ``1 - x**2`` plus a linear term in the boundary-value difference.

    Reads the globals ``boundary_values`` and ``domain`` at call time.
    """
    slope = (boundary_values[1] - boundary_values[0]) / (domain[1] - domain[0])
    return 1.0 - x ** 2 + slope * (x - domain[0])


# PDE measurement locations and the (zero) noise level of each measurement
X = np.linspace(-0.8, 0.8, 3)
fX_std = np.full_like(X, 0.0 ** 2)

In [None]:
boundary_values = (0.0, 0.0)


# PDE RHS
def f(x):
    """Right-hand side ``f(x) = pi**2 * sin(pi * x)`` of ``-Δu = f``."""
    return np.pi ** 2 * np.sin(np.pi * x)


# True Solution
def u(x):
    """Exact solution ``u(x) = sin(pi * x)`` of ``-Δu = f`` with zero boundary values.

    ``-(d/dx)**2 sin(pi x) = pi**2 sin(pi x) = f(x)``, and ``sin(±pi) = 0``
    matches ``boundary_values`` on the interval from -1 to 1. Uses ``jnp`` so
    the function is differentiable with JAX.
    """
    return jnp.sin(jnp.pi * x)


# PDE Measurements
X = np.linspace(-0.8, 0.8, 3)
# NOTE: despite its name, `fX_std` holds the *variance* (std ** 2) of the
# measurement noise; it is used directly as the diagonal of the covariance of `fX`.
fX_std = np.full_like(X, 0.0) ** 2

In [None]:
# Assemble the boundary value problem. Each of the two boundary values is wrapped
# in a zero-variance (noise-free) Gaussian.
bvp = poisson_1d_bvp(
    domain=domain,
    rhs=f,
    boundary_values=tuple(
        pn.randvars.Normal(boundary_values[side], 0.0 ** 2) for side in (0, 1)
    ),
    solution=u,
)

In [None]:
# Joint Gaussian over all boundary values: stacked means and a diagonal
# covariance built from the per-condition variances (boundary conditions are
# treated as independent).
g = pn.randvars.Normal(
    mean=np.stack([boundary_condition.values.mean for boundary_condition in bvp.boundary_conditions]),
    cov=np.diag([boundary_condition.values.var for boundary_condition in bvp.boundary_conditions]),
)

In [None]:
# Number of PDE collocation points.
N = np.size(X)

# Observations of the PDE right-hand side at X. `fX_std` is placed directly on
# the covariance diagonal, i.e. it is interpreted as a variance.
fX = pn.randvars.Normal(
    mean=f(X),
    cov=np.diag(fX_std),
)

In [None]:
plt_grid = np.linspace(*domain, 100)


def _conditioning_label(quantity, conditioning_keys):
    """Build the LaTeX legend label ``$<quantity> | <conditions>$``.

    Args:
        quantity: LaTeX (without ``$``) of the plotted quantity, e.g. ``"u"``.
        conditioning_keys: Keyword names passed to the plotting functions. The
            keys ``"g"`` and ``"X_fX"`` become their conditioning statements,
            in the given order. Other keys are ignored.

    Returns:
        ``$<quantity>$`` if nothing is conditioned on, otherwise
        ``$<quantity> \\mid <cond_1>, <cond_2>, ...$``.
    """
    condition_strs = {
        "g": r"u\vert_{\partial \Omega} = g",
        "X_fX": r"-\Delta u(X) = f(X)",
    }

    conditions = [condition_strs[key] for key in conditioning_keys if key in condition_strs]

    if not conditions:
        return f"${quantity}$"

    return rf"${quantity} \mid {', '.join(conditions)}$"


def _measurements_label(num_measurements):
    """LaTeX legend label ``(f(x_1), ..., f(x_N))`` for the PDE measurements."""
    # Tripled braces emit a literal LaTeX group around N, so N >= 10 is one subscript.
    return rf"$(f(x_1), \dots, f(x_{{{num_measurements}}}))$"


def plot_belief(ax, u, **kwargs):
    """Plot a belief over the PDE solution together with the true solution ``u*``.

    Args:
        ax: Matplotlib axes to draw on.
        u: Process belief over the solution. It is drawn via ``u.plot`` on
            ``plt_grid`` with 10 samples, and its mean is differentiated with JAX.
        **kwargs: Observations ``u`` has been conditioned on, in conditioning
            order (the order determines the legend label):
            ``g=<Normal>`` draws the boundary values with 95% (±1.96 std) error bars;
            ``X_fX=(X, fX)`` draws the local curvature implied by ``-Δu(X) = f(X)``.
    """
    u.plot(
        ax,
        plt_grid,
        num_samples=10,
        rng=np.random.default_rng(24),
        color="C0",
        label=_conditioning_label("u", kwargs),
    )

    ax.plot(
        plt_grid,
        bvp.solution(plt_grid),
        color="C1",
        label="$u^*$",
    )

    for key, value in kwargs.items():
        if key == "g":
            ax.errorbar(
                bvp.domain.boundary,
                value.mean,
                yerr=1.96 * value.std,
                fmt="+",
                capsize=2,
                color="C2",
                label=r"$g$",
            )
        elif key == "X_fX":
            X_obs, fX_obs = value

            linpde_gp.plotting.plot_local_curvature(
                ax,
                xs=X_obs,
                f_xs=u.mean(X_obs),
                # -Δu(x) = f(x)  implies  u''(x) = -f(x)
                ddf_xs=-fX_obs,
                df_xs=jnp.vectorize(jax.grad(u.mean.jax))(X_obs),
                color="C3",
                label=_measurements_label(np.size(X_obs)),
            )

    ax.legend()


def plot_pred_belief(ax, Lu, **kwargs):
    """Plot a belief over ``-Δu`` together with the PDE right-hand side ``f``.

    Args:
        ax: Matplotlib axes to draw on.
        Lu: Process belief over ``-Δu``. It is drawn via ``Lu.plot`` on
            ``plt_grid`` with 10 samples.
        **kwargs: Observations the underlying ``u`` has been conditioned on, in
            conditioning order. Both ``g`` and ``X_fX`` enter the legend label.
            Only ``X_fX=(X, fX)`` is drawn, as 95% (±1.96 std) error bars,
            matching the boundary-value error bars in ``plot_belief``.
    """
    Lu.plot(
        ax,
        plt_grid,
        num_samples=10,
        rng=np.random.default_rng(24),
        color="C0",
        label=_conditioning_label(r"-\Delta u", kwargs),
    )

    ax.plot(
        plt_grid,
        f(plt_grid),
        color="C1",
        label="$f$",
    )

    if "X_fX" in kwargs:
        X_obs, fX_obs = kwargs["X_fX"]

        ax.errorbar(
            X_obs,
            fX_obs.mean,
            yerr=1.96 * fX_obs.std,
            fmt="+",
            capsize=2,
            color="C3",
            label=_measurements_label(np.size(X_obs)),
        )

    ax.legend()

## Prior

In [None]:
# GP prior over the solution u: zero mean, and an ExpQuad kernel with
# lengthscale 1.0 scaled by 2.0 ** 2. If ExpQuad has unit variance (assumed --
# confirm in linpde_gp), the prior marginal standard deviation is 2.
prior_gp = pn.randprocs.GaussianProcess(
    mean=linpde_gp.randprocs.mean_fns.Zero(input_shape=()),
    cov=2.0 ** 2 * linpde_gp.randprocs.kernels.ExpQuad(
        input_shape=(),
        lengthscales=1.0,
    ),
)

In [None]:
# Prior belief over u on the current axes, saved as "prior_00a_u".
plot_belief(
    ax=plt.gca(),
    u=prior_gp,
)

experiment_utils.savefig("prior_00a_u")

## Prior Predictive

In [None]:
# Apply the BVP's differential operator (labelled -Δ in the plots) to the prior.
# `Lu` is the induced process over the operator applied to u. `Lu_crosscov`
# applies the operator to the second argument of the prior covariance --
# presumably the u/Lu cross-covariance needed to condition on PDE observations.
Lu = bvp.diffop(prior_gp)
Lu_crosscov = bvp.diffop(prior_gp.cov, argnum=1)

In [None]:
# Prior predictive over -Δu against the true right-hand side f, saved as "prior_00b_Lu".
plot_pred_belief(
    ax=plt.gca(),
    Lu=Lu,
)

experiment_utils.savefig("prior_00b_Lu")

## Posterior (PDE First)

### Conditioning on the PDE

In [None]:
# Condition the prior on the PDE observations fX at X, using the operator-induced
# process `Lu` and the cross-covariance `Lu_crosscov`.
u_pde = prior_gp.condition_on_predictive_gp_observations_jax(Lu, Lu_crosscov, X, fX)

In [None]:
# Belief over u after conditioning on the PDE only, saved as "pdefirst_01a_u_cond_pde".
plot_belief(
    ax=plt.gca(),
    u=u_pde,
    X_fX=(X, fX),
)

experiment_utils.savefig("pdefirst_01a_u_cond_pde")

### Posterior Predictive

In [None]:
# Operator applied to the PDE-conditioned posterior: posterior predictive over -Δu.
Lu_pde = bvp.diffop(u_pde)

In [None]:
# Posterior predictive over -Δu after conditioning on the PDE, saved as "pdefirst_01b_Lu_cond_pde".
plot_pred_belief(
    ax=plt.gca(),
    Lu=Lu_pde,
    X_fX=(X, fX),
)

experiment_utils.savefig("pdefirst_01b_Lu_cond_pde")

### Conditioning on the Boundary Conditions

In [None]:
# Additionally condition the PDE posterior on the boundary values g, observed at
# the boundary points of the domain.
u_pde_bc = u_pde.condition_on_observations_jax(
    np.hstack(bvp.domain.boundary),
    g,
)

In [None]:
# Belief over u after conditioning on the PDE and then on the boundary values,
# saved as "pdefirst_02a_u_cond_pde_bc".
plot_belief(
    ax=plt.gca(),
    u=u_pde_bc,
    X_fX=(X, fX),
    g=g,
)

experiment_utils.savefig("pdefirst_02a_u_cond_pde_bc")

### Posterior Predictive with PDE and Boundary Conditions

In [None]:
# Posterior predictive over -Δu given both the PDE observations and the boundary values.
Lu_pde_bc = bvp.diffop(u_pde_bc)

In [None]:
# Posterior predictive over -Δu given PDE and boundary values, saved as
# "pdefirst_02b_Lu_cond_pde_bc". Here `g` only affects the legend label.
plot_pred_belief(
    ax=plt.gca(),
    Lu=Lu_pde_bc,
    X_fX=(X, fX),
    g=g,
)

experiment_utils.savefig("pdefirst_02b_Lu_cond_pde_bc")

### Complete Plot

In [None]:
# Overview figures for the "PDE first" ordering. Left column: beliefs over u;
# right column: beliefs over -Δu. Rows: prior (a, b), conditioned on the PDE
# (c, d), and -- only when `include_bc` -- additionally on the boundary values (e, f).
for include_bc in [False, True]:
    nrows = 3 if include_bc else 2

    rc = config.tueplots_bundle(nrows=nrows, ncols=2)
    rc.update({"lines.linewidth": 1})

    with plt.rc_context(rc):
        fig, ax = plt.subplots(nrows=nrows, ncols=2)

        ax[0, 0].set_title("(a)")
        plot_belief(
            ax=ax[0, 0],
            u=prior_gp,
        )

        ax[0, 1].set_title("(b)")
        plot_pred_belief(
            ax=ax[0, 1],
            Lu=Lu,
        )

        ax[1, 0].set_title("(c)")
        plot_belief(
            ax=ax[1, 0],
            u=u_pde,
            X_fX=(X, fX),
        )

        ax[1, 1].set_title("(d)")
        plot_pred_belief(
            ax=ax[1, 1],
            Lu=Lu_pde,
            X_fX=(X, fX),
        )

        if include_bc:
            ax[2, 0].set_title("(e)")
            plot_belief(
                ax=ax[2, 0],
                u=u_pde_bc,
                X_fX=(X, fX),
                g=g,
            )

            ax[2, 1].set_title("(f)")
            plot_pred_belief(
                ax=ax[2, 1],
                Lu=Lu_pde_bc,
                X_fX=(X, fX),
                g=g,
            )

        # Save while the rc context is still active, so that save-time rcParams
        # in the per-figure bundle (e.g. any `savefig.*` entries) apply to this
        # figure instead of the global defaults.
        experiment_utils.savefig("pdefirst" + ("" if include_bc else "_nobc"))

## Posterior (Boundary Values First)

### Conditioning on Boundary Conditions

In [None]:
# Condition the prior on the boundary values g at the boundary points of the domain.
u_bc = prior_gp.condition_on_observations_jax(
    np.hstack(bvp.domain.boundary),
    g,
)

In [None]:
# Belief over u after conditioning on the boundary values only, saved as "bcfirst_01a_u_cond_bc".
plot_belief(
    ax=plt.gca(),
    u=u_bc,
    g=g,
)

experiment_utils.savefig("bcfirst_01a_u_cond_bc")

### Predictive Induced by $-\Delta$

In [None]:
# Operator applied to the boundary-conditioned belief, plus the operator applied
# to the second argument of its covariance (presumably the u/Lu cross-covariance
# used below to condition on the PDE).
Lu_bc = bvp.diffop(u_bc)
Lu_bc_crosscov = bvp.diffop(u_bc.cov, argnum=1)

In [None]:
# Predictive over -Δu given only the boundary values, saved as "bcfirst_01b_Lu_cond_bc".
plot_pred_belief(
    ax=plt.gca(),
    Lu=Lu_bc,
    g=g,
)

experiment_utils.savefig("bcfirst_01b_Lu_cond_bc")

### Conditioning on the PDE

In [None]:
# Condition the boundary-conditioned belief on the PDE observations fX at X.
u_bc_pde = u_bc.condition_on_predictive_gp_observations_jax(Lu_bc, Lu_bc_crosscov, X, fX)

In [None]:
# Belief over u after conditioning on the boundary values and then on the PDE,
# saved as "bcfirst_02a_u_cond_bc_pde".
plot_belief(
    ax=plt.gca(),
    u=u_bc_pde,
    g=g,
    X_fX=(X, fX),
)

experiment_utils.savefig("bcfirst_02a_u_cond_bc_pde")

### Posterior Predictive

In [None]:
# Posterior predictive over -Δu given the boundary values and the PDE observations.
Lu_bc_pde = bvp.diffop(u_bc_pde)

In [None]:
# Posterior predictive over -Δu given boundary values and PDE, saved as "bcfirst_02b_Lu_cond_bc_pde".
plot_pred_belief(
    ax=plt.gca(),
    Lu=Lu_bc_pde,
    g=g,
    X_fX=(X, fX),
)

experiment_utils.savefig("bcfirst_02b_Lu_cond_bc_pde")

### Complete Plot

In [None]:
# Overview figure for the "boundary values first" ordering. Left column: beliefs
# over u; right column: beliefs over -Δu. Rows: prior (a, b), conditioned on the
# boundary values (c, d), and additionally on the PDE (e, f).
rc = config.tueplots_bundle(nrows=3, ncols=2)
rc.update({"lines.linewidth": 1})

with plt.rc_context(rc):
    fig, ax = plt.subplots(nrows=3, ncols=2)

    ax[0, 0].set_title("(a)")
    plot_belief(
        ax=ax[0, 0],
        u=prior_gp,
    )

    ax[0, 1].set_title("(b)")
    plot_pred_belief(
        ax=ax[0, 1],
        Lu=Lu,
    )

    ax[1, 0].set_title("(c)")
    plot_belief(
        ax=ax[1, 0],
        u=u_bc,
        g=g,
    )

    ax[1, 1].set_title("(d)")
    plot_pred_belief(
        ax=ax[1, 1],
        Lu=Lu_bc,
        g=g,
    )

    ax[2, 0].set_title("(e)")
    plot_belief(
        ax=ax[2, 0],
        u=u_bc_pde,
        g=g,
        X_fX=(X, fX),
    )

    ax[2, 1].set_title("(f)")
    plot_pred_belief(
        ax=ax[2, 1],
        Lu=Lu_bc_pde,
        g=g,
        X_fX=(X, fX),
    )

    # Save while the rc context is still active, so that save-time rcParams in
    # the per-figure bundle (e.g. any `savefig.*` entries) apply to this figure.
    experiment_utils.savefig("bcfirst")