In [None]:
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import linpde_gp
from linpde_gp.typing import RandomProcessLike

In [None]:
import experiment_utils
from experiment_utils import config

# Experiment settings consumed by `experiment_utils` (presumably used for naming
# outputs and selecting the figure style/target venue — TODO confirm).
config.experiment_name = "0001_cpu_stationary_2d_a_simplified"
config.target = "jmlr"
config.debug_mode = True

## Problem Definition

In [None]:
import cpu

# Spatial domain and the differential operator of the 2D CPU problem, both
# provided by the `cpu` module.
domain, diffop = cpu.domain_2D, cpu.diffop_2D

## Plotting

In [None]:
%matplotlib inline

import matplotlib.axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.cm
import matplotlib.colors

# Apply the project's matplotlib style settings returned by config.tueplots_bundle().
plt.rcParams.update(config.tueplots_bundle())


class BeliefPlotter:
    """Plotting helpers for beliefs about the 2D CPU heat-conduction problem.

    All quantities are evaluated on a regular grid over ``domain`` with
    ``num_plt_points`` points per axis and drawn as 3D surfaces or heatmaps.

    Parameters
    ----------
    domain
        Two-dimensional domain; ``domain[0]`` and ``domain[1]`` are unpacked as
        the ``(lower, upper)`` bounds of the x- and y-axis.
    diffop
        Differential operator applied to the solution belief in the
        ``plot_pred_belief_*`` methods.
    num_plt_points
        Number of grid points per axis (default 50).
    """

    # Multiplier turning a marginal standard deviation into the half-width of a
    # 95% credible interval of a Gaussian (its 97.5% quantile).
    _CRED_INT_FACTOR = 1.96

    def __init__(
        self,
        domain: linpde_gp.domains.Domain,
        diffop: linpde_gp.linfuncops.LinearDifferentialOperator,
        num_plt_points: int = 50,
    ):
        self._domain = domain
        self._diffop = diffop

        self._plt_grid_x = np.linspace(*domain[0], num_plt_points)
        self._plt_grid_y = np.linspace(*domain[1], num_plt_points)

        # Shape (num_y, num_x, 2); entry [i, j] is the point (x[j], y[i])
        # because np.meshgrid defaults to "xy" indexing.
        self._plt_grid = np.stack(
            np.meshgrid(self._plt_grid_x, self._plt_grid_y),
            axis=-1,
        )

        # Heatmap extent (left, right, bottom, top), taken from the very grid
        # the plotted values are evaluated on.
        self._plt_extent = (
            self._plt_grid_x[0],
            self._plt_grid_x[-1],
            self._plt_grid_y[0],
            self._plt_grid_y[-1],
        )

    def _imshow(self, ax, values, **kwargs):
        """Draw grid-evaluated ``values`` as a heatmap aligned with the domain.

        Row ``i`` of ``values`` belongs to ``y[i]`` (increasing with ``i``), so
        the image is drawn with ``origin="lower"``; matplotlib's default
        ``origin="upper"`` would place row 0 at the top and flip the heatmap.
        Remaining keyword arguments are forwarded to ``ax.imshow``.
        """
        return ax.imshow(
            values,
            origin="lower",
            extent=self._plt_extent,
            **kwargs,
        )

    def plot_geometry(
        self,
        ax: matplotlib.axes.Axes,
        q_dot_V: RandomProcessLike | None = None,
    ):
        """Draw the CPU schematic, optionally overlaid with the mean of ``q_dot_V``.

        The heat-source mean uses a diverging colormap centred at 0.
        """
        cpu.plot_schematic(ax)

        if q_dot_V is not None:
            q_dot_V = linpde_gp.randprocs.asrandproc(q_dot_V)

            self._imshow(
                ax,
                q_dot_V.mean(self._plt_grid),
                cmap="coolwarm",
                norm=matplotlib.colors.TwoSlopeNorm(0.0),
            )

    def plot_rhs(
        self,
        ax: matplotlib.axes.Axes,
        q_dot_V: RandomProcessLike,
    ):
        """Draw the mean of ``q_dot_V`` as a 3D surface (``ax`` must be 3D)."""
        q_dot_V = linpde_gp.randprocs.asrandproc(q_dot_V)

        ax.plot_surface(
            self._plt_grid[..., 0],
            self._plt_grid[..., 1],
            q_dot_V.mean(self._plt_grid),
            cmap="coolwarm",
            norm=matplotlib.colors.TwoSlopeNorm(0.0),
        )

        cpu.adjust_xaxis(ax)
        cpu.adjust_yaxis(ax)
        cpu.adjust_q_dot_V_axis(ax.zaxis)

    def plot_rhs_heatmap(
        self,
        ax: matplotlib.axes.Axes,
        q_dot_V: RandomProcessLike,
        colorbar: bool = True,
    ):
        """Draw the mean of ``q_dot_V`` as a heatmap, optionally with a colorbar on top."""
        cpu.adjust_xaxis(ax)
        cpu.adjust_yaxis(ax)

        # Evaluate the mean like the other plot methods do; calling a probnum
        # random process directly returns a random variable, not an array.
        q_dot_V = linpde_gp.randprocs.asrandproc(q_dot_V)

        q_dot_V_im = self._imshow(
            ax,
            q_dot_V.mean(self._plt_grid),
            cmap="coolwarm",
            norm=matplotlib.colors.TwoSlopeNorm(0.0),
            aspect="auto",
        )

        if colorbar:
            q_dot_V_im_cm = _add_top_colorbar(ax, q_dot_V_im)
            cpu.adjust_q_dot_V_axis(q_dot_V_im_cm.xaxis)

    def plot_q_dot_V_A(
        self,
        ax: matplotlib.axes.Axes,
        q_dot_V: RandomProcessLike,
        q_dot_A_north: RandomProcessLike,
        q_dot_A_south: RandomProcessLike,
        q_dot_A_east: RandomProcessLike,
        q_dot_A_west: RandomProcessLike,
    ):
        """Placeholder for a combined volumetric/boundary heat-flux plot.

        TODO: not implemented yet; currently draws nothing.
        """

    def plot_belief_3D(
        self,
        ax: matplotlib.axes.Axes,
        u: pn.randprocs.GaussianProcess,
    ):
        """Draw the mean of the solution belief ``u`` as a 3D surface."""
        ax.plot_surface(
            self._plt_grid[..., 0],
            self._plt_grid[..., 1],
            u.mean(self._plt_grid),
            cmap="inferno",
        )

        cpu.adjust_xaxis(ax)
        cpu.adjust_yaxis(ax)
        cpu.adjust_tempaxis(ax.zaxis)

    def plot_pred_belief_3D(
        self,
        ax: matplotlib.axes.Axes,
        u: pn.randprocs.GaussianProcess,
    ):
        """Draw the mean of ``diffop(u)`` (the PDE left-hand side) as a 3D surface."""
        # Differential Operator Image Belief
        Du = self._diffop(u)

        ax.plot_surface(
            self._plt_grid[..., 0],
            self._plt_grid[..., 1],
            Du.mean(self._plt_grid),
            cmap="coolwarm",
            norm=matplotlib.colors.TwoSlopeNorm(0.0),
        )

        cpu.adjust_xaxis(ax)
        cpu.adjust_yaxis(ax)
        cpu.adjust_q_dot_V_axis(ax.zaxis)

    def plot_belief_heatmap(
        self,
        axs: np.ndarray,
        u: pn.randprocs.GaussianProcess,
    ):
        """Heatmaps of ``u``: mean on ``axs[0]``, 95% credible half-width on ``axs[1]``."""
        u_mean_im = self._imshow(
            axs[0],
            u.mean(self._plt_grid),
            cmap="inferno",
        )

        u_mean_im_cm = _add_top_colorbar(axs[0], u_mean_im)

        cpu.adjust_xaxis(axs[0])
        cpu.adjust_yaxis(axs[0])
        cpu.adjust_tempaxis(u_mean_im_cm.xaxis)

        u_cred_im = self._imshow(
            axs[1],
            self._CRED_INT_FACTOR * u.std(self._plt_grid),
            cmap="inferno",
        )

        u_cred_im_cm = _add_top_colorbar(axs[1], u_cred_im)

        cpu.adjust_xaxis(axs[1])
        cpu.adjust_yaxis(axs[1])
        cpu.adjust_tempaxis(u_cred_im_cm.xaxis)

    def plot_pred_belief_heatmap(
        self,
        axs: np.ndarray,
        u: pn.randprocs.GaussianProcess,
    ):
        """Heatmaps of ``diffop(u)``: mean on ``axs[0]``, 95% credible half-width on ``axs[1]``."""
        # Differential Operator Image Belief
        Du = self._diffop(u)

        Du_mean_im = self._imshow(
            axs[0],
            Du.mean(self._plt_grid),
            cmap="coolwarm",
            norm=matplotlib.colors.TwoSlopeNorm(0.0),
        )

        Du_mean_im_cm = _add_top_colorbar(axs[0], Du_mean_im)

        cpu.adjust_xaxis(axs[0])
        cpu.adjust_yaxis(axs[0])
        cpu.adjust_q_dot_V_axis(Du_mean_im_cm.xaxis)

        Du_cred_im = self._imshow(
            axs[1],
            self._CRED_INT_FACTOR * Du.std(self._plt_grid),
            cmap="inferno",
        )

        Du_cred_im_cm = _add_top_colorbar(axs[1], Du_cred_im)

        cpu.adjust_xaxis(axs[1])
        cpu.adjust_yaxis(axs[1])
        cpu.adjust_q_dot_V_axis(Du_cred_im_cm.xaxis)

def _add_top_colorbar(
    ax: matplotlib.axes.Axes,
    mappable: matplotlib.cm.ScalarMappable,
) -> matplotlib.axes.Axes:
    """Attach a horizontal colorbar above ``ax`` and return the colorbar axes.

    The colorbar axes is split off ``ax`` with ``make_axes_locatable``
    (``size="10%"``, ``pad=0.1``; see ``AxesDivider.append_axes`` for units).

    Parameters
    ----------
    ax
        Axes above which the colorbar is placed.
    mappable
        Image/collection whose colormap and norm the colorbar shows.

    Returns
    -------
    The newly created colorbar axes, with ticks and label on top.
    """
    cax = make_axes_locatable(ax).append_axes("top", size="10%", pad=0.1)

    # Use the figure that owns the axes rather than pyplot's implicit
    # "current figure", which need not be the one being drawn.
    cax.figure.colorbar(mappable, cax=cax, orientation="horizontal")

    # This must be placed after the colorbar has been drawn
    cax.xaxis.tick_top()
    cax.xaxis.set_label_position("top")

    return cax

# Shared plotter for all figures below, bound to the CPU domain and operator.
plotter = BeliefPlotter(domain=domain, diffop=diffop)

## Simplified Model

### Prior

In [None]:
# Solution: GP prior with constant mean 59.0 and a ProductMatern (p=3) kernel
# scaled by 3.0 ** 2, with lengthscales of half the CPU width / height.
u = pn.randprocs.GaussianProcess(
    mean=linpde_gp.functions.Constant(input_shape=(2,), value=59.0),
    cov=3.0 ** 2 * linpde_gp.randprocs.kernels.ProductMatern(
        input_shape=(2,),
        p=3,
        lengthscales=[cpu.width / 2.0, cpu.height / 2.0],
    ),
)

# Volumetric (interior) heat sources/sinks and boundary heat flux, taken from
# the `cpu` module and wrapped as random processes via `asrandproc`.
q_dot_V, q_dot_A = (
    linpde_gp.randprocs.asrandproc(cpu.q_dot_V_2D),
    linpde_gp.randprocs.asrandproc(cpu.q_dot_A_2D),
)

### Visualize Problem Geometry

In [None]:
# Draw on an explicitly created figure, like every other plotting cell, rather
# than on whatever pyplot currently considers the active axes.
fig, ax = plt.subplots()
plotter.plot_geometry(
    ax=ax,
    q_dot_V=q_dot_V,
)

In [None]:
# 3D surface of the heat-source mean.
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_rhs(ax=ax, q_dot_V=q_dot_V)

In [None]:
# 3D surface of the prior mean of the solution.
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_belief_3D(ax=ax, u=u)

In [None]:
# 3D surface of the prior mean pushed through the differential operator.
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_pred_belief_3D(ax=ax, u=u)

In [None]:
# Prior mean (left) and 95% credible half-width (right) as heatmaps.
fig, axs = plt.subplots(1, 2)

plotter.plot_belief_heatmap(axs=axs, u=u)

In [None]:
# Heatmaps of the prior pushed through the differential operator.
fig, axs = plt.subplots(1, 2)

plotter.plot_pred_belief_heatmap(axs=axs, u=u)

### Conditioning on the PDE

In [None]:
# Collocation points for the PDE: an N_pde x N_pde grid over the domain, with
# `inset` set to 3% of the CPU width / height in x / y.
N_pde = 15

X_pde = domain.uniform_grid(
    (N_pde,) * 2,
    inset=tuple(0.03 * length for length in (cpu.width, cpu.height)),
)

In [None]:
# Condition the prior on the PDE at the collocation points: zero observations
# of `diffop` applied to u, with -q_dot_V(X_pde) passed as the offset `b`
# (presumably the observation model is Y = L[u](X) + b — TODO confirm).
u_cond_pde = u.condition_on_observations(
    X=X_pde,
    Y=np.zeros_like(X_pde, shape=X_pde.shape[:-1]),
    L=diffop,
    b=-q_dot_V(X_pde),
)

In [None]:
# 3D surface of the PDE-conditioned solution mean.
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_belief_3D(ax=ax, u=u_cond_pde)

In [None]:
# 3D surface of the PDE-conditioned belief pushed through the operator.
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_pred_belief_3D(ax=ax, u=u_cond_pde)

In [None]:
# Mean and 95% credible half-width of the PDE-conditioned solution.
fig, axs = plt.subplots(1, 2)

plotter.plot_belief_heatmap(axs=axs, u=u_cond_pde)

In [None]:
# Heatmaps of the PDE-conditioned belief pushed through the operator.
fig, axs = plt.subplots(1, 2)

plotter.plot_pred_belief_heatmap(axs=axs, u=u_cond_pde)

### Conditioning on Neumann Boundary Conditions

In [None]:
# Neumann observation counts: (points per north/south edge, points per east/west edge)
N_nbc = 10, 8

In [None]:
# North Boundary
# Starts the boundary-conditioning chain from the PDE-conditioned belief
# u_cond_pde; the following cells keep refining u_cond_pde_nbc.
# Neumann observations at N_nbc[0] points along x on the upper edge
# y = domain[1][1]. DirectionalDerivative([0.0, -1.0]) differentiates along -y,
# i.e. pointing from this edge into the domain.
# NOTE(review): q_dot_A appears to take a scalar boundary parameter; the north
# edge is mapped to x itself — TODO confirm against cpu.q_dot_A_2D.
X_nbc_north = domain[0].uniform_grid(N_nbc[0])

u_cond_pde_nbc = u_cond_pde.condition_on_observations(
    Y=np.zeros_like(X_nbc_north),
    X=np.stack(
        (
            X_nbc_north,
            np.full_like(X_nbc_north, domain[1][1]),
        ),
        axis=-1,
    ),
    L=-cpu.kappa * linpde_gp.linfuncops.diffops.DirectionalDerivative([0.0, -1.0]),
    b=-q_dot_A(X_nbc_north)
)

In [None]:
# East Boundary
# Neumann observations at N_nbc[1] points along y on the right edge
# x = domain[0][1]. DirectionalDerivative([-1.0, 0.0]) differentiates along -x,
# i.e. pointing from this edge into the domain.
# NOTE(review): the q_dot_A argument is offset by cpu.width, presumably placing
# this edge after the north edge in a boundary parameterization — TODO confirm.
X_nbc_east = domain[1].uniform_grid(N_nbc[1])

u_cond_pde_nbc = u_cond_pde_nbc.condition_on_observations(
    Y=np.zeros_like(X_nbc_east),
    X=np.stack(
        (
            np.full_like(X_nbc_east, domain[0][1]),
            X_nbc_east,
        ),
        axis=-1,
    ),
    L=-cpu.kappa * linpde_gp.linfuncops.diffops.DirectionalDerivative([-1.0, 0.0]),
    b=-q_dot_A(cpu.width + X_nbc_east)
)

In [None]:
# South Boundary
# Neumann observations at N_nbc[0] points along x on the lower edge
# y = domain[1][0]. DirectionalDerivative([0.0, 1.0]) differentiates along +y,
# i.e. pointing from this edge into the domain.
# NOTE(review): the q_dot_A argument is offset by cpu.width + cpu.height,
# presumably the combined length of the north and east edges — TODO confirm.
X_nbc_south = domain[0].uniform_grid(N_nbc[0])

u_cond_pde_nbc = u_cond_pde_nbc.condition_on_observations(
    Y=np.zeros_like(X_nbc_south),
    X=np.stack(
        (
            X_nbc_south,
            np.full_like(X_nbc_south, domain[1][0]),
        ),
        axis=-1,
    ),
    L=-cpu.kappa * linpde_gp.linfuncops.diffops.DirectionalDerivative([0.0, 1.0]),
    b=-q_dot_A(cpu.width + cpu.height + X_nbc_south)
)

In [None]:
# West Boundary
# Neumann observations at N_nbc[1] points along y on the left edge
# x = domain[0][0]. The derivative is taken along the inward normal, which on
# the left edge is +x. The north/south edges use opposite directions and so
# must the east (-x) and west (+x) edges.
X_nbc_west = domain[1].uniform_grid(N_nbc[1])

u_cond_pde_nbc = u_cond_pde_nbc.condition_on_observations(
    Y=np.zeros_like(X_nbc_west),
    X=np.stack(
        (
            np.full_like(X_nbc_west, domain[0][0]),
            X_nbc_west,
        ),
        axis=-1,
    ),
    L=-cpu.kappa * linpde_gp.linfuncops.diffops.DirectionalDerivative([1.0, 0.0]),
    b=-q_dot_A(cpu.width + cpu.height + cpu.width + X_nbc_west),
)

In [None]:
from mpl_toolkits import mplot3d

from probnum.typing import ArrayLike


def plot_2d_gp(
    f: pn.randprocs.GaussianProcess,
    ax: mplot3d.Axes3D,
    xy: np.ndarray,
    /,
    *,
    cred_int_slice_xs: ArrayLike | None = None,
    cred_int_slice_ys: ArrayLike | None = None,
    cred_int_slice_axis: str = "x",
    mean_zorder: int = 2,
    cred_int_slice_lower_zorder: int = 1,
    cred_int_slice_upper_zorder: int = 3,
    **kwargs,
):
    """Plot the mean surface of a 2D GP and, optionally, a credible band slice.

    Parameters
    ----------
    f
        Process over 2D inputs; only ``f.mean`` and ``f.std`` are used.
    ax
        3D axes to draw on.
    xy
        Input points with coordinates on the last axis; ``xy[..., 0]`` and
        ``xy[..., 1]`` are passed to ``plot_surface``.
    cred_int_slice_xs, cred_int_slice_ys
        Coordinates (broadcast against each other) of a slice along which the
        95% marginal credible band (mean +- 1.96 std) is drawn. The band is
        drawn only if both are given.
    cred_int_slice_axis
        Forwarded as ``axis`` to ``fill_between_3d``.
    mean_zorder, cred_int_slice_lower_zorder, cred_int_slice_upper_zorder
        zorders of the mean surface and of the lower / upper half of the band.
    **kwargs
        Forwarded to ``ax.plot_surface``.
    """
    ax.plot_surface(
        xy[..., 0],
        xy[..., 1],
        f.mean(xy),
        zorder=mean_zorder,
        **kwargs,
    )

    # The band needs both slice coordinates; otherwise only the mean is shown.
    if cred_int_slice_xs is None or cred_int_slice_ys is None:
        return

    slice_points = np.stack(
        np.broadcast_arrays(cred_int_slice_xs, cred_int_slice_ys),
        axis=-1,
    )
    slice_mean = f.mean(slice_points)
    slice_std = f.std(slice_points)

    # Draw the band in two halves (below / above the mean) so each half gets
    # its own zorder relative to the mean surface.
    half_bands = (
        (slice_mean - 1.96 * slice_std, slice_mean, cred_int_slice_lower_zorder),
        (slice_mean, slice_mean + 1.96 * slice_std, cred_int_slice_upper_zorder),
    )
    for band_bottom, band_top, band_zorder in half_bands:
        linpde_gp.utils.plotting.fill_between_3d(
            ax,
            cred_int_slice_xs,
            cred_int_slice_ys,
            band_bottom,
            band_top,
            axis=cred_int_slice_axis,
            zorder=band_zorder,
            alpha=0.5,
        )


# Posterior after PDE + Neumann conditioning, with a cross-section at
# y = cpu.core_centers_ys[1]. Gray walls fill the z-range outside the 95%
# marginal credible interval along that slice.

# z-range of the plot; also the outer bounds of the gray walls, so both
# stay in sync.
slice_zlim = (52.5, 64.0)

fig, ax = plt.subplots(
    subplot_kw={
        "projection": "3d",
        # Draw artists in the order of their explicit zorder values instead of
        # letting matplotlib sort them by depth.
        "computed_zorder": False,
    }
)

xy = np.stack(
    (
        plotter._plt_grid_x,
        np.full_like(plotter._plt_grid_x, cpu.core_centers_ys[1]),
    ),
    axis=-1,
)

mean = u_cond_pde_nbc.mean(xy)
std = u_cond_pde_nbc.std(xy)

# Wall between the bottom of the z-range and the lower credible bound
linpde_gp.utils.plotting.fill_between_3d(
    ax,
    plotter._plt_grid_x,
    cpu.core_centers_ys[1],
    np.full_like(mean, slice_zlim[0]),
    mean - 1.96 * std,
    alpha=0.8,
    color="#AAAAAA",
)

plot_2d_gp(
    u_cond_pde_nbc,
    ax,
    plotter._plt_grid,
    cmap="inferno",
    cred_int_slice_xs=plotter._plt_grid_x,
    cred_int_slice_ys=cpu.core_centers_ys[1],
)

# Wall between the upper credible bound and the top of the z-range
linpde_gp.utils.plotting.fill_between_3d(
    ax,
    plotter._plt_grid_x,
    cpu.core_centers_ys[1],
    mean + 1.96 * std,
    np.full_like(mean, slice_zlim[1]),
    alpha=0.8,
    color="#AAAAAA",
)

ax.set_zlim(*slice_zlim);

In [None]:
# Mean and 95% credible half-width after PDE + Neumann conditioning.
fig, axs = plt.subplots(1, 2, sharey=True)

plotter.plot_belief_heatmap(axs=axs, u=u_cond_pde_nbc)

### Conditioning on Measurements

In [None]:
# Condition on a single point observation of 60.0 at the first core center.
u_cond_pde_nbc_dts = u_cond_pde_nbc.condition_on_observations(
    Y=60.0,
    X=[cpu.core_centers_xs[0], cpu.core_centers_ys[0]],
)

In [None]:
# 3D surface of the solution mean after also conditioning on the measurement.
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))

plotter.plot_belief_3D(ax=ax, u=u_cond_pde_nbc_dts)

In [None]:
# Mean and 95% credible half-width after also conditioning on the measurement.
fig, axs = plt.subplots(1, 2, sharey=True)

plotter.plot_belief_heatmap(axs=axs, u=u_cond_pde_nbc_dts)

In [None]:
# Draw one posterior sample on the flattened plotting grid (fixed seed for
# reproducibility) and reshape it back to the grid's layout. The shape comes
# from the grid itself rather than a hard-coded (50, 50).
sample = u_cond_pde_nbc_dts.sample(
    np.random.default_rng(1234),
    plotter._plt_grid.reshape(-1, 2),
).reshape(plotter._plt_grid.shape[:-1])

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.plot_surface(
    plotter._plt_grid[..., 0],
    plotter._plt_grid[..., 1],
    sample,
    cmap="inferno",
);

## Full Model

### Prior

In [None]:
# Solution
# Same prior as in the simplified model above.
u = pn.randprocs.GaussianProcess(
    mean=linpde_gp.functions.Constant(input_shape=(2,), value=59.0),
    cov=3.0 ** 2 * linpde_gp.randprocs.kernels.ProductMatern(
        input_shape=(2,),
        p=3,
        lengthscales=[cpu.width / 2.0, cpu.height / 2.0],
    ),
)

# Volumetric (Interior) Heat Sources and Sinks
# Now uncertain: a GP with mean cpu.q_dot_V_2D (the simplified model only
# wrapped it with asrandproc) and a ProductMatern (p=3) covariance scaled by
# 3.0 ** 2.
q_dot_V = pn.randprocs.GaussianProcess(
    mean=cpu.q_dot_V_2D,
    cov=3.0 ** 2 * linpde_gp.randprocs.kernels.ProductMatern(
        input_shape=(2,),
        p=3,
        lengthscales=[cpu.width / 2.0, cpu.height / 2.0],
    ),
)

# Boundary Heat Flux
# GP over a scalar input (input_shape=()), presumably the boundary parameter
# used by the Neumann cells above — TODO confirm. Covariance scaled by
# 0.1 ** 2, lengthscale (cpu.width + cpu.height) / 4.
q_dot_A = pn.randprocs.GaussianProcess(
    mean=cpu.q_dot_A_2D,
    cov=0.1 ** 2 * linpde_gp.randprocs.kernels.Matern(
        input_shape=(),
        p=3,
        lengthscale=(cpu.width + cpu.height) / 2.0 / 2.0,
    )
)

In [None]:
# Prior overview: 3x3 panel grid with the heat-source heatmap in the centre and
# boundary heat-flux priors on the surrounding edge panels; corners removed.
# NOTE(review): the style bundle is sized for nrows=4 while the grid has 3 rows
# — confirm this is intended (e.g. for a taller figure).
with plt.rc_context(config.tueplots_bundle(nrows=4, ncols=3)):
    fig, ax = plt.subplots(
        3, 3,
        sharex="col",
        sharey="row",
        gridspec_kw={
            "width_ratios": (1, 2, 1),
        },
    )

    # Only the "cross" of panels is used; drop the four corners.
    ax[0, 0].remove()
    ax[0, 2].remove()
    ax[2, 0].remove()
    ax[2, 2].remove()

    # With sharex/sharey only the outer panels show tick labels; re-enable them
    # on panels whose labelled neighbour (a corner) was removed.
    ax[0, 1].yaxis.set_tick_params(which="both", labelleft=True)
    ax[2, 1].yaxis.set_tick_params(which="both", labelleft=True)

    ax[1, 0].xaxis.set_tick_params(which="both", labelbottom=True)
    ax[1, 2].xaxis.set_tick_params(which="both", labelbottom=True)

    plotter.plot_rhs_heatmap(ax[1, 1], q_dot_V, colorbar=False)

    # NOTE(review): q_dot_A_north_prior / q_dot_A_south_prior /
    # q_dot_A_west_prior are not defined in any visible cell (q_dot_A above is
    # a single GP), so Restart & Run All raises NameError here unless they are
    # defined elsewhere. TODO: define them or remove these calls.
    q_dot_A_north_prior.plot(ax[0, 1], plotter._plt_grid_x)
    q_dot_A_south_prior.plot(ax[2, 1], plotter._plt_grid_x)
    q_dot_A_west_prior.plot(ax[1, 0], plotter._plt_grid_y, vertical=True)