In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import linpde_gp

In [None]:
import experiment_utils
from experiment_utils import config

# Experiment settings consumed by experiment_utils (it later provides the results
# path and the figure style bundle used below).
# NOTE(review): what each setting controls, and whether assignment order matters,
# is defined in experiment_utils — confirm there.
config.experiment_name = "0008_heat_1d"
config.target = "imprs_2022"
# NOTE(review): debug mode is left enabled — confirm this is intended for final runs.
config.debug_mode = True

In [None]:
# Render matplotlib figures as static images in the cell output.
%matplotlib inline

In [None]:
# Apply the figure style bundle provided by the experiment config
# (presumably chosen according to config.target — confirm in experiment_utils).
plt.rcParams.update(config.tueplots_bundle())

In [None]:
# Seeded NumPy random generator, for reproducibility.
# NOTE(review): `rng` is not used anywhere else in the visible notebook.
rng = np.random.default_rng(seed=24)

## Problem Definition

In [None]:
# Rectangular domain over (t, x): row 0 is the time interval, row 1 the spatial interval.
domain = linpde_gp.domains.Box(
    [
        [0.0, 1.0],  # t
        [-1.0, 1.0],  # x
    ]
)

In [None]:
def initial(x):
    """Initial temperature profile ``u(0, x) = 1 - x**2``.

    Evaluated elementwise, so it accepts scalars as well as NumPy arrays. On the
    spatial interval [-1, 1] it vanishes at both endpoints.
    """
    return 1.0 - x ** 2

In [None]:
# 1D heat-equation problem on `domain` with initial profile `initial`; its
# `diffop` and `boundary_conditions` are used further below.
bvp = linpde_gp.problems.pde.heat_1d_bvp(domain, initial)

In [None]:
# 50 x 50 plotting grid over the domain. With indexing="ij", axis 0 runs over
# time and axis 1 over space; plt_grid stacks the (t, x) pairs on the last axis.
plt_grid_t, plt_grid_x = np.meshgrid(
    *(np.linspace(*domain[dim], 50) for dim in (0, 1)),
    indexing="ij",
)
plt_grid = np.stack([plt_grid_t, plt_grid_x], axis=-1)

## Prior

In [None]:
# Prior hyperparameters: one lengthscale per input dimension (time t, space x)
# and an output scale whose square multiplies the prior kernel.
lengthscale_t, lengthscale_x, output_scale = 0.5, 1.0, 1.0

In [None]:
@jax.jit
def matern_32(x0, x1, lengthscale=1.0):
    """Matérn-3/2 kernel ``(1 + sqrt(3) r) * exp(-sqrt(3) r)`` with ``r = |x0 - x1| / lengthscale``.

    Evaluated elementwise with broadcasting over ``x0`` and ``x1``.
    """
    scaled_dist = jnp.sqrt(3) * jnp.abs((x0 - x1) / lengthscale)
    return (1 + scaled_dist) * jnp.exp(-scaled_dist)

@jax.jit
def matern_52(x0, x1, lengthscale=1.0):
    """Matérn-5/2 kernel ``(1 + sqrt(5) r + 5/3 r^2) * exp(-sqrt(5) r)`` with ``r = |x0 - x1| / lengthscale``.

    Evaluated elementwise with broadcasting over ``x0`` and ``x1``.
    """
    r = jnp.abs((x0 - x1) / lengthscale)
    scaled_dist = jnp.sqrt(5) * r
    polynomial = 1 + scaled_dist + (5. / 3.) * r ** 2
    return polynomial * jnp.exp(-scaled_dist)

@jax.jit
def product_matern(tx0, tx1):
    """Separable kernel on (t, x) inputs: Matérn-3/2 in time times Matérn-5/2 in space.

    The last axis of ``tx0``/``tx1`` holds t at index 0 and x at index 1.
    ``lengthscale_t``/``lengthscale_x`` are module globals; jax.jit captures their
    values when the function is first traced, so re-run this cell after changing them.
    """
    k_time = matern_32(tx0[..., 0], tx1[..., 0], lengthscale=lengthscale_t)
    k_space = matern_52(tx0[..., 1], tx1[..., 1], lengthscale=lengthscale_x)
    return k_time * k_space

# Alternative prior covariance: the product Matérn kernel wrapped as a linpde_gp kernel.
# It is bound to its own name because the next cell assigns the ExpQuad kernel to
# `cov`, which would otherwise silently discard this one. To use it as the prior
# covariance, pass `cov_matern` instead of `cov` when building `u_prior`.
cov_matern = linpde_gp.randprocs.kernels.JaxLambdaKernel(
    product_matern, input_shape=(2,), vectorize=True
)

In [None]:
# Squared-exponential (ExpQuad) kernel on (t, x) inputs with one lengthscale per
# input dimension; this is the kernel the prior below is built from.
cov = linpde_gp.randprocs.kernels.ExpQuad(
    lengthscales=[lengthscale_t, lengthscale_x],
    input_shape=(2,),
)

In [None]:
# Zero-mean GP prior over u(t, x); the kernel is scaled by output_scale ** 2.
u_prior = pn.randprocs.GaussianProcess(
    cov=output_scale ** 2 * cov,
    mean=linpde_gp.randprocs.mean_fns.Zero(input_shape=(2,)),
)

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Prior mean over the (t, x) plotting grid, labelled like the exported posterior figure.
ax.plot_surface(plt_grid_t, plt_grid_x, u_prior.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Temperature (°C)")
ax.set_title("Prior mean");

### Dirichlet Problem

In [None]:
# Number of initial-condition observations and of observations per boundary edge.
N_ic, N_bc = 3, 10

#### Initial conditions

In [None]:
# Initial-condition observations: N_ic equispaced locations across the spatial
# interval, all at the initial time domain[0][0] (taken from the domain, like
# the boundary cells below, rather than hard-coded as 0). Target values come from
# the BVP's first boundary condition (presumably the initial condition).
X_ic = np.stack(
    (
        np.full(N_ic, domain[0][0]),
        np.linspace(*domain[1], N_ic),
    ),
    axis=-1,
)
Y_ic = bvp.boundary_conditions[0].values(X_ic[..., 1])

u_ic = u_prior.condition_on_observations(X_ic, Y_ic)

#### Left Boundary Conditions

In [None]:
# Dirichlet observations u = 0 on the left edge x = domain[1][0] at N_bc equispaced
# times. The initial time is dropped because that corner point is already in X_ic.
X_bc_left = np.column_stack(
    [
        np.linspace(*domain[0], N_bc + 1)[1:],
        np.full(N_bc, domain[1][0]),
    ]
)
Y_bc_left = np.zeros(X_bc_left.shape[:-1], dtype=X_bc_left.dtype)

u_ic_bc = u_ic.condition_on_observations(X_bc_left, Y_bc_left)

#### Right Boundary Conditions

In [None]:
# Dirichlet observations u = 0 on the right edge x = domain[1][1] at N_bc equispaced
# times. The initial time is dropped because that corner point is already in X_ic.
X_bc_right = np.column_stack(
    [
        np.linspace(*domain[0], N_bc + 1)[1:],
        np.full(N_bc, domain[1][1]),
    ]
)
Y_bc_right = np.zeros(X_bc_right.shape[:-1], dtype=X_bc_right.dtype)

u_ic_bc = u_ic_bc.condition_on_observations(X_bc_right, Y_bc_right)

#### Prior with Initial and Boundary Conditions

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Mean after conditioning on the initial values and both Dirichlet boundary edges.
ax.plot_surface(plt_grid_t, plt_grid_x, u_ic_bc.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Temperature (°C)")
ax.set_title("Mean conditioned on initial and Dirichlet boundary values");

### Neumann Problem (zero spatial derivative at the boundaries)

In [None]:
# Observation counts for this section (same values as in the Dirichlet section).
N_ic, N_bc = 3, 10

In [None]:
# Directional derivatives along the spatial axis (direction vectors are given in
# (t, x) order), each pointing into the domain from its edge: +x at the left edge,
# -x at the right edge.
left_boundary_op = linpde_gp.linfuncops.diffops.DirectionalDerivative(
    direction=[0., 1.]
)
right_boundary_op = linpde_gp.linfuncops.diffops.DirectionalDerivative(
    direction=[0., -1.]
)

#### Initial Conditions

In [None]:
# Initial-condition observations: N_ic equispaced locations across the spatial
# interval, all at the initial time domain[0][0] (taken from the domain, like
# the boundary cells below, rather than hard-coded as 0). Target values come from
# the BVP's first boundary condition (presumably the initial condition).
X_ic = np.stack(
    (
        np.full(N_ic, domain[0][0]),
        np.linspace(*domain[1], N_ic),
    ),
    axis=-1,
)
Y_ic = bvp.boundary_conditions[0].values(X_ic[..., 1])

u_ic = u_prior.condition_on_observations(X_ic, Y_ic)

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Mean after conditioning on the initial values only.
ax.plot_surface(plt_grid_t, plt_grid_x, u_ic.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Temperature (°C)")
ax.set_title("Mean conditioned on initial values");

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Mean of the right-boundary directional derivative before imposing the boundary conditions.
ax.plot_surface(plt_grid_t, plt_grid_x, right_boundary_op(u_ic).mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_title("Mean of right_boundary_op(u_ic)");

#### Left Boundary Condition

In [None]:
# Observation locations on the left edge x = domain[1][0] at N_bc equispaced times
# (initial time excluded); the target values are all zero.
X_bc_left = np.column_stack(
    [
        np.linspace(*domain[0], N_bc + 1)[1:],
        np.full(N_bc, domain[1][0]),
    ]
)
Y_bc_left = np.zeros(X_bc_left.shape[:-1], dtype=X_bc_left.dtype)

In [None]:
# Condition on left_boundary_op(u) = Y_bc_left at the left-edge locations.
u_ic_bc_left = u_ic.condition_on_observations(
    L=left_boundary_op,
    X=X_bc_left,
    Y=Y_bc_left,
)

#### Right Boundary Condition

In [None]:
# Observation locations on the right edge x = domain[1][1] at N_bc equispaced times
# (initial time excluded); the target values are all zero.
X_bc_right = np.column_stack(
    [
        np.linspace(*domain[0], N_bc + 1)[1:],
        np.full(N_bc, domain[1][1]),
    ]
)
Y_bc_right = np.zeros(X_bc_right.shape[:-1], dtype=X_bc_right.dtype)

In [None]:
# Condition on right_boundary_op(u) = Y_bc_right at the right-edge locations.
u_ic_bc = u_ic_bc_left.condition_on_observations(
    L=right_boundary_op,
    X=X_bc_right,
    Y=Y_bc_right,
)

#### Plots

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Mean after conditioning on the initial values and the zero-derivative boundary observations.
ax.plot_surface(plt_grid_t, plt_grid_x, u_ic_bc.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Temperature (°C)")
ax.set_title("Mean conditioned on initial and boundary-derivative values");

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Mean of the right-boundary directional derivative after imposing the boundary conditions.
ax.plot_surface(plt_grid_t, plt_grid_x, right_boundary_op(u_ic_bc).mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_title("Mean of right_boundary_op(u_ic_bc)");

### Conditioning on the PDE

In [None]:
# Number of PDE collocation points along (t, x).
N_pde = 5, 5

In [None]:
# PDE operator applied to the process before conditioning on the PDE.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, bvp.diffop(u_ic_bc).mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_title("Mean of bvp.diffop(u_ic_bc), before PDE conditioning");

In [None]:
# PDE collocation points on an N_pde[0] x N_pde[1] tensor grid inside the domain.
# Time: the first slice sits 0.01 after the initial time (presumably to keep it off
# the t = domain[0][0] slice where the initial condition is observed), and the last
# slice sits exactly at the final time, so every point lies inside the domain.
# Shifting the whole linspace by +0.01 would push the last slice past domain[0][1].
# Space: N_pde[1] interior points; the edges carry the boundary observations.
# indexing="ij" matches the plotting grid (axis 0 = time, axis 1 = space).
X_pde = np.stack(
    np.meshgrid(
        np.linspace(domain[0][0] + 0.01, domain[0][1], N_pde[0]),
        np.linspace(*domain[1], N_pde[1] + 2)[1:-1],
        indexing="ij",
    ),
    axis=-1,
)
# Zero right-hand side at every collocation point.
Y_pde = np.zeros_like(X_pde[..., 0])

In [None]:
# Condition on bvp.diffop(u) = Y_pde at the collocation points X_pde.
u_ic_bc_pde = u_ic_bc.condition_on_observations(
    L=bvp.diffop,
    X=X_pde,
    Y=Y_pde,
)

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Posterior mean after conditioning on initial, boundary and PDE observations;
# the figure is written out via experiment_utils.savefig as "heat_posterior".
ax.plot_surface(plt_grid_t, plt_grid_x, u_ic_bc_pde.mean(plt_grid))
ax.set(
    xlabel="Time (s)",
    ylabel="Location (cm)",
    zlabel="Temperature (°C)",
)

experiment_utils.savefig("heat_posterior")

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# PDE operator applied to the process after conditioning on the PDE.
ax.plot_surface(plt_grid_t, plt_grid_x, bvp.diffop(u_ic_bc_pde).mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_title("Mean of bvp.diffop(u_ic_bc_pde), after PDE conditioning");

### Generate Animation

In [None]:
# Light font weights for text, axis labels and titles in all figures created
# after this cell (including the animation below).
plt.rcParams.update(
    {
        "font.weight": "light",
        "axes.labelweight": "light",
        "axes.titleweight": "light",
    }
)

In [None]:
import functools
from matplotlib import animation

fig, ax = plt.subplots()

# Initial labels, using exactly the same text and title format as every frame below.
ax.set_xlabel("Location (cm)")
ax.set_ylabel("Temperature (°C)")
ax.set_title(f"t = {plt_grid[0, 0, 0]:.2f} s")

# frames: one frame per time point (axis 0 of the "ij"-indexed plotting grid).
@functools.partial(
    animation.FuncAnimation,
    fig,
    frames=len(plt_grid_t),
    interval=200,
    repeat_delay=4000,
    blit=False,
)
def anim(frame_idx):
    """Redraw ``ax`` for time slice ``frame_idx`` of ``plt_grid``.

    Plots the posterior mean of ``u_ic_bc_pde`` over space together with a band of
    +/- 1.96 standard deviations. Because of the decorator, the name ``anim`` is
    bound to the resulting ``FuncAnimation`` object, not to this function.
    """
    # plt_grid was built with indexing="ij": axis 0 is time, so this is one spatial slice.
    txs = plt_grid[frame_idx, :, :]

    mean = u_ic_bc_pde.mean(txs)
    std = u_ic_bc_pde.std(txs)

    # cla() also clears labels, title and limits, so they are re-applied each frame.
    ax.cla()
    ax.plot(txs[:, 1], mean)
    ax.fill_between(
        txs[:, 1],
        mean - 1.96 * std,
        mean + 1.96 * std,
        alpha=.3,
    )
    ax.set_ylim(-0.01, 1.2)
    ax.set_xlabel("Location (cm)")
    ax.set_ylabel("Temperature (°C)")
    ax.set_title(f"t = {plt_grid[frame_idx, 0, 0]:.2f} s")


In [None]:
from IPython.display import HTML

# Show the animation as an embedded JavaScript/HTML player in the cell output.
HTML(data=anim.to_jshtml())

In [None]:
# Write the animation frames as PDFs into a fresh `heat_anim` directory inside the
# experiment results folder; frames left over from an earlier run are removed first.
anim_path = experiment_utils.config.experiment_results_path / "heat_anim"

if anim_path.is_dir():
    import shutil

    shutil.rmtree(anim_path)

# parents=True so this also works when the results folder itself does not exist yet.
anim_path.mkdir(parents=True, exist_ok=True)

# NOTE(review): the "{}" placeholder is presumably filled per frame by PDFWriter — confirm.
anim.save(anim_path / "{}.pdf", linpde_gp.plotting.PDFWriter())

In [None]:
# Optional: also export the animation as a GIF at 5 fps (uncomment to enable).
# anim.save(experiment_utils.config.experiment_results_path / "heat_anim.gif", animation.PillowWriter(fps=5))