In [None]:
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import linpde_gp

In [None]:
import experiment_utils
from experiment_utils import config

# Experiment-wide settings read by `experiment_utils`.
for _option, _value in {
    "experiment_name": "0001_cpu_stationary_2d",
    "target": "jmlr",
    "debug_mode": True,
}.items():
    setattr(config, _option, _value)

del _option, _value

In [None]:
import cpu

## Plotting

In [None]:
%matplotlib inline

import matplotlib.axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.cm

plt.rcParams.update(config.tueplots_bundle())


class BeliefPlotter:
    """Visualizes beliefs about the solution ``u`` of a 2D PDE, and their
    image under the PDE's differential operator.

    Every plot evaluates its inputs on a regular grid spanning ``domain``.
    """

    def __init__(
        self,
        domain: linpde_gp.domains.Domain,
        diffop: linpde_gp.linfuncops.LinearDifferentialOperator,
        resolution: int = 100,
    ):
        """
        Args:
            domain: 2D domain; ``domain[0]`` and ``domain[1]`` must unpack into
                the (lower, upper) bounds along the x- and y-axis.
            diffop: Differential operator applied to beliefs in the
                ``plot_pred_*`` methods.
            resolution: Number of grid points along each axis (default 100).
        """
        self._domain = domain
        self._diffop = diffop

        # 1D grids along x and y. (A trailing comma here would silently turn
        # each attribute into a 1-tuple wrapping the array.)
        self._plt_grid_x = np.linspace(*domain[0], resolution)
        self._plt_grid_y = np.linspace(*domain[1], resolution)

        # Shape (resolution, resolution, 2). With meshgrid's default "xy"
        # indexing, entry [i, j] is the point (x[j], y[i]): the row index runs
        # along y, and row 0 holds the *smallest* y-coordinate.
        self._plt_grid = np.stack(
            np.meshgrid(self._plt_grid_x, self._plt_grid_y),
            axis=-1,
        )

        # (left, right, bottom, top) for `imshow`, taken from the grid itself
        # so that images always line up with the points they were evaluated on.
        self._extent = (
            self._plt_grid_x[0],
            self._plt_grid_x[-1],
            self._plt_grid_y[0],
            self._plt_grid_y[-1],
        )

    def _imshow(self, ax, values, cmap, **imshow_kwargs):
        """Draw grid-evaluated ``values`` as an image on ``ax`` and return it.

        ``origin="lower"`` is essential: row 0 of ``values`` belongs to the
        smallest y-coordinate, whereas ``imshow`` by default draws row 0 at the
        top, which would flip every heatmap upside down w.r.t. ``extent``.
        """
        return ax.imshow(
            values,
            cmap=cmap,
            extent=self._extent,
            origin="lower",
            **imshow_kwargs,
        )

    def _plot_surface(self, ax, values, cmap):
        """Draw grid-evaluated ``values`` as a surface on the 3D axes ``ax`` and
        style its x- and y-axis (the caller styles the z-axis)."""
        ax.plot_surface(
            self._plt_grid[..., 0],
            self._plt_grid[..., 1],
            values,
            cmap=cmap,
        )

        cpu.adjust_xaxis(ax)
        cpu.adjust_yaxis(ax)

    def _plot_mean_and_credible_heatmaps(self, axs, process, cmap, adjust_cbar_axis):
        """Heatmap of the mean of ``process`` on ``axs[0]`` and of the half-width
        of its marginal 95% credible interval (1.96 standard deviations) on
        ``axs[1]``, each with a colorbar on top whose axis is styled by
        ``adjust_cbar_axis``."""
        panels = (
            (axs[0], process.mean),
            (axs[1], lambda x: 1.96 * process.std(x)),
        )

        for ax, evaluate in panels:
            im = self._imshow(ax, evaluate(self._plt_grid), cmap)
            cbar_ax = _add_top_colorbar(ax, im)

            cpu.adjust_xaxis(ax)
            cpu.adjust_yaxis(ax)
            adjust_cbar_axis(cbar_ax.xaxis)

    def plot_geometry(
        self,
        ax: matplotlib.axes.Axes,
        q_dot_V: pn.Function | pn.randprocs.RandomProcess,
    ):
        """Draw the CPU schematic on ``ax`` together with a heatmap of
        ``q_dot_V`` evaluated on the plotting grid."""
        cpu.plot_schematic(ax)

        self._imshow(ax, q_dot_V(self._plt_grid), "coolwarm", aspect="auto")

    def plot_rhs(
        self,
        ax: matplotlib.axes.Axes,
        q_dot_V: pn.Function | pn.randprocs.RandomProcess,
    ):
        """Surface plot of the right-hand side ``q_dot_V`` on the 3D axes ``ax``."""
        self._plot_surface(ax, q_dot_V(self._plt_grid), "coolwarm")
        cpu.adjust_q_dot_V_axis(ax.zaxis)

    def plot_belief_3D(
        self,
        ax: matplotlib.axes.Axes,
        u: pn.randprocs.GaussianProcess,
    ):
        """Surface plot of the mean of ``u`` on the 3D axes ``ax``."""
        self._plot_surface(ax, u.mean(self._plt_grid), "inferno")
        cpu.adjust_tempaxis(ax.zaxis)

    def plot_pred_belief_3D(
        self,
        ax: matplotlib.axes.Axes,
        u: pn.randprocs.GaussianProcess,
    ):
        """Surface plot of the mean of the differential-operator image
        ``diffop(u)`` on the 3D axes ``ax``."""
        Du = self._diffop(u)

        self._plot_surface(ax, Du.mean(self._plt_grid), "coolwarm")
        cpu.adjust_q_dot_V_axis(ax.zaxis)

    def plot_belief_heatmap(
        self,
        axs: np.ndarray,
        u: pn.randprocs.GaussianProcess,
    ):
        """Mean (``axs[0]``) and 95% credible half-width (``axs[1]``) of ``u``."""
        self._plot_mean_and_credible_heatmaps(
            axs, u, "inferno", cpu.adjust_tempaxis
        )

    def plot_pred_belief_heatmap(
        self,
        axs: np.ndarray,
        u: pn.randprocs.GaussianProcess,
    ):
        """Mean (``axs[0]``) and 95% credible half-width (``axs[1]``) of the
        differential-operator image ``diffop(u)``."""
        Du = self._diffop(u)

        self._plot_mean_and_credible_heatmaps(
            axs, Du, "coolwarm", cpu.adjust_q_dot_V_axis
        )

def _add_top_colorbar(
    ax: matplotlib.axes.Axes,
    mappable: matplotlib.cm.ScalarMappable,
) -> matplotlib.axes.Axes:
    """Attach a horizontal colorbar for ``mappable`` directly above ``ax``.

    Args:
        ax: Axes on top of which the colorbar axes is appended.
        mappable: Artist (e.g. the image returned by ``imshow``) whose colormap
            and normalization the colorbar displays.

    Returns:
        The newly created colorbar axes, so callers can style its x-axis.
    """
    cax = make_axes_locatable(ax).append_axes(
        "top",
        size="10%",
        pad=0.1,
    )

    # Use the figure that owns `ax` rather than `plt.colorbar`, which targets
    # whatever pyplot currently considers the active figure.
    ax.figure.colorbar(
        mappable,
        cax=cax,
        orientation="horizontal",
    )

    # Move ticks and label to the top so they do not collide with `ax`.
    # This must be placed after the colorbar has been created.
    cax.xaxis.tick_top()
    cax.xaxis.set_label_position("top")

    return cax

## Problem Definition

In [None]:
# Keep only the first two axes of the CPU domain, used as x (domain[0]) and
# y (domain[1]) throughout this 2D experiment.
# NOTE(review): presumably drops the z-axis of a 3D `cpu.domain` -- confirm.
domain = cpu.domain[0:2]

In [None]:
# 2D differential operator and right-hand side (volumetric source term q_dot_V)
diffop, q_dot_V = cpu.diffop_2D, cpu.q_dot_V_2D

In [None]:
# Shared plotter: evaluates every belief on the same grid over `domain`.
plotter = BeliefPlotter(domain=domain, diffop=diffop)

### Visualize Problem Geometry

In [None]:
# Use an explicit figure/axes pair (like every other plotting cell) instead of
# the pyplot state machine's `plt.gca()`.
fig, ax = plt.subplots()
plotter.plot_geometry(
    ax=ax,
    q_dot_V=q_dot_V,
)

In [None]:
# Right-hand side q_dot_V as a 3D surface
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_rhs(ax=ax, q_dot_V=q_dot_V)

## Prior

In [None]:
# Gaussian-process prior over the temperature field u: a constant mean of 59.0
# and a product-Matern covariance (p=3) scaled by 3.0 ** 2, with per-axis
# lengthscales of half the CPU's width and height.
u_prior_mean = linpde_gp.functions.Constant(input_shape=(2,), value=59.0)
u_prior_cov = 3.0 ** 2 * linpde_gp.randprocs.kernels.ProductMatern(
    input_shape=(2,),
    p=3,
    lengthscales=[cpu.width / 2.0, cpu.height / 2.0],
)

u_prior = pn.randprocs.GaussianProcess(mean=u_prior_mean, cov=u_prior_cov)

In [None]:
def _boundary_flux_prior(lengthscale):
    """GP prior with mean `cpu.q_dot_A_2D` and a scalar-input Matern (p=3)
    covariance scaled by 0.1 ** 2 with the given lengthscale."""
    return pn.randprocs.GaussianProcess(
        mean=cpu.q_dot_A_2D,
        cov=0.1 ** 2 * linpde_gp.randprocs.kernels.Matern(
            input_shape=(),
            p=3,
            lengthscale=lengthscale,
        ),
    )


# Priors for the boundary quantity q_dot_A on each edge; opposite edges share
# the same prior object. North/south use half the width as lengthscale,
# east/west half the height.
# NOTE(review): these priors are not referenced by the boundary-condition cells
# in this notebook, which condition on the constant `q_A` instead.
q_dot_A_north_prior = _boundary_flux_prior(cpu.width / 2.0)
q_dot_A_south_prior = q_dot_A_north_prior

q_dot_A_east_prior = _boundary_flux_prior(cpu.height / 2.0)
q_dot_A_west_prior = q_dot_A_east_prior

In [None]:
# Prior mean of u as a 3D surface
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_belief_3D(ax=ax, u=u_prior)

In [None]:
# Prior mean of diffop(u) as a 3D surface
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_pred_belief_3D(ax=ax, u=u_prior)

In [None]:
# Prior mean and 95% credible half-width of u
fig, axs = plt.subplots(1, 2)
plotter.plot_belief_heatmap(axs=axs, u=u_prior)

In [None]:
# Prior mean and 95% credible half-width of diffop(u)
fig, axs = plt.subplots(1, 2)
plotter.plot_pred_belief_heatmap(axs=axs, u=u_prior)

## Conditioning on the PDE

In [None]:
N_pde = 15

# PDE collocation points: an N_pde x N_pde grid, inset from the domain boundary
# by 3% of the CPU's width (along x) and height (along y) on each side.
pde_margin = 0.03

x_pde = np.linspace(
    domain[0][0] + pde_margin * cpu.width,
    domain[0][1] - pde_margin * cpu.width,
    N_pde,
)
y_pde = np.linspace(
    domain[1][0] + pde_margin * cpu.height,
    domain[1][1] - pde_margin * cpu.height,
    N_pde,
)

# Shape (N_pde, N_pde, 2)
X_pde = np.stack(np.meshgrid(x_pde, y_pde), axis=-1)

In [None]:
# Condition the prior on the PDE at the collocation points: observe `diffop`
# applied to u with offset b = -q_dot_V(X_pde) and target 0 at every point
# (presumably L[u](X) + b = Y, i.e. diffop(u) = q_dot_V -- confirm semantics).
u_cond_pde = u_prior.condition_on_observations(
    X=X_pde,
    Y=np.zeros(X_pde.shape[:-1], dtype=X_pde.dtype),
    L=diffop,
    b=-q_dot_V(X_pde),
)

In [None]:
# Posterior mean of u (PDE only) as a 3D surface
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_belief_3D(ax=ax, u=u_cond_pde)

In [None]:
# Posterior mean of diffop(u) (PDE only) as a 3D surface
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_pred_belief_3D(ax=ax, u=u_cond_pde)

In [None]:
# Posterior mean and 95% credible half-width of u (PDE only)
fig, axs = plt.subplots(1, 2)
plotter.plot_belief_heatmap(axs=axs, u=u_cond_pde)

In [None]:
# Posterior mean and 95% credible half-width of diffop(u) (PDE only)
fig, axs = plt.subplots(1, 2)
plotter.plot_pred_belief_heatmap(axs=axs, u=u_cond_pde)

## Conditioning on Neumann Boundary Conditions

In [None]:
# Constant boundary value used as the Neumann observation on every edge below:
# the CPU's TDP divided by `cpu.A_sink`.
# NOTE(review): presumably power per unit heat-sink area -- confirm units.
q_A = cpu.TDP / cpu.A_sink

In [None]:
# Neumann boundary conditions: on each edge of the rectangular domain, observe
#     -kappa * (directional derivative of u along n) = q_A
# at N_bc evenly spaced points, where n is the edge's outward unit normal.
# Edges are conditioned on one after another, in the order listed below.
N_bc = 10

boundary_edges = (
    # ((x-coordinates, y-coordinates), outward normal)
    (  # south edge, y = y_min
        (np.linspace(*domain[0], N_bc), np.full((N_bc,), domain[1][0])),
        [0.0, -1.0],
    ),
    (  # north edge, y = y_max
        (np.linspace(*domain[0], N_bc), np.full((N_bc,), domain[1][1])),
        [0.0, 1.0],
    ),
    (  # west edge, x = x_min
        (np.full((N_bc,), domain[0][0]), np.linspace(*domain[1], N_bc)),
        [-1.0, 0.0],
    ),
    (  # east edge, x = x_max
        (np.full((N_bc,), domain[0][1]), np.linspace(*domain[1], N_bc)),
        [1.0, 0.0],
    ),
)

u_cond_pde_bc = u_cond_pde

for bc_coords, normal in boundary_edges:
    # (N_bc, 2) boundary points of the current edge
    X_bc = np.stack(bc_coords, axis=-1)

    u_cond_pde_bc = u_cond_pde_bc.condition_on_observations(
        Y=np.zeros_like(X_bc, shape=X_bc.shape[:-1]) + q_A,
        X=X_bc,
        L=-cpu.kappa * linpde_gp.linfuncops.diffops.DirectionalDerivative(normal),
    )

In [None]:
# Posterior mean of u (PDE + boundary conditions) as a 3D surface
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plotter.plot_belief_3D(ax=ax, u=u_cond_pde_bc)

In [None]:
# Posterior mean and 95% credible half-width of u (PDE + boundary conditions)
fig, axs = plt.subplots(1, 2, sharey=True)
plotter.plot_belief_heatmap(axs=axs, u=u_cond_pde_bc)

## Conditioning on Measurements

In [None]:
# Condition additionally on a single temperature value of 60.0 at the center of
# the first CPU core (presumably a digital thermal sensor reading -- confirm).
first_core_center = [cpu.core_centers_xs[0], cpu.core_centers_ys[0]]

u_cond_pde_bc_dts = u_cond_pde_bc.condition_on_observations(
    Y=60.0,
    X=first_core_center,
)

In [None]:
# Posterior mean of u after additionally conditioning on the measurement.
# (Previously this plotted `u_cond_pde_bc`, i.e. the belief *before* the
# measurement, duplicating the plot in the boundary-condition section.)
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

plotter.plot_belief_3D(
    ax,
    u=u_cond_pde_bc_dts,
)

In [None]:
# Posterior mean and 95% credible half-width of u after the measurement
fig, axs = plt.subplots(1, 2, sharey=True)
plotter.plot_belief_heatmap(axs=axs, u=u_cond_pde_bc_dts)

In [None]:
# Draw one joint posterior sample on (at most 10,000 points of) the plotting
# grid. The RNG is seeded (arbitrary fixed seed) so the output is reproducible
# across Restart & Run All.
# NOTE(review): a joint draw on this many points presumably requires a
# correspondingly large covariance matrix -- expect this cell to be slow.
sample_rng = np.random.default_rng(42)
u_cond_pde_bc_dts.sample(sample_rng, plotter._plt_grid.reshape(-1, 2)[:10000, :])