In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn

import linpde_gp

In [None]:
import notebook_utils

# Register this notebook with the shared helper module. The name is presumably
# used to derive per-notebook output locations (cf.
# `notebook_utils.config.notebook_results_path` when exporting the animation) — TODO confirm.
notebook_utils.config.notebook_name = "0008_heat_1d"
# NOTE(review): debug mode is left enabled; its effect lives in `notebook_utils`
# and is not visible here — confirm before publishing results.
notebook_utils.config.debug_mode = True

In [None]:
%matplotlib inline

In [None]:
# Seeded random generator for reproducibility.
# NOTE(review): `rng` is not referenced by any later cell of this notebook —
# confirm it is still needed.
rng = np.random.default_rng(24)

## Problem Definition

In [None]:
from linpde_gp.problems.pde import domains, heat_1d_bvp

In [None]:
# Space-time domain: axis 0 is time t in [0, 1], axis 1 is space x in [-1, 1].
domain = domains.Box([[0.0, 1.0], [-1.0, 1.0]])

In [None]:
def initial(x):
    """Initial temperature profile ``1 - x**2`` used for the heat BVP.

    It vanishes at x = -1 and x = 1 (the spatial end points of ``domain``) and
    peaks at 1 for x = 0. Works elementwise on scalars and NumPy/JAX arrays.
    """
    return 1.0 - x ** 2

In [None]:
# 1D heat-equation boundary value problem on `domain` with the initial profile above.
bvp = heat_1d_bvp(domain, initial)

In [None]:
# Dense 50 x 50 evaluation grid over the space-time domain for plotting.
# `indexing="ij"` makes axis 0 the time axis, so `plt_grid[i]` holds every
# spatial point at the i-th plotting time (the animation relies on this).
plt_grid = np.stack(
    np.meshgrid(
        np.linspace(*domain[0], 50),
        np.linspace(*domain[1], 50),
        indexing="ij",
    ),
    axis=-1,
)
plt_grid_t = plt_grid[..., 0]
plt_grid_x = plt_grid[..., 1]

## Prior

In [None]:
@jax.jit
def matern_32(x0, x1, lengthscale=1.0):
    """Matérn-3/2 correlation between (broadcastable) inputs ``x0`` and ``x1``.

    k(r) = (1 + sqrt(3) r) * exp(-sqrt(3) r), with r = |(x0 - x1) / lengthscale|.
    """
    r = jnp.abs((x0 - x1) / lengthscale)
    sqrt3_r = jnp.sqrt(3) * r
    return (1 + sqrt3_r) * jnp.exp(-sqrt3_r)


@jax.jit
def matern_52(x0, x1, lengthscale=1.0):
    """Matérn-5/2 correlation between (broadcastable) inputs ``x0`` and ``x1``.

    k(r) = (1 + sqrt(5) r + 5/3 r**2) * exp(-sqrt(5) r), with r = |(x0 - x1) / lengthscale|.
    """
    r = jnp.abs((x0 - x1) / lengthscale)
    sqrt5_r = jnp.sqrt(5) * r
    polynomial = 1 + sqrt5_r + (5. / 3.) * r ** 2
    return polynomial * jnp.exp(-sqrt5_r)

In [None]:
lengthscale_t = 0.5
lengthscale_x = 1.0
output_scale = 1.0

# NOTE: both functions below are `jax.jit`-compiled and read the three
# hyperparameters above as Python globals. JAX captures such globals when it
# traces a function, so changing a hyperparameter alone may be ignored by the
# already-compiled functions — re-run this whole cell after editing them.


@jax.jit
def prior_mean(x):
    """Zero prior mean: one 0 per input point (the last axis of ``x`` holds (t, x)).

    Alternatives tried earlier: ``-0.5 * x[..., 0] ** 2 + 0.5`` and
    ``jnp.sin(jnp.pi * x)``.
    """
    return jnp.zeros_like(x[..., 0])


@jax.jit
def prior_cov(tx0, tx1):
    """Squared-exponential covariance over space-time inputs.

    The last axis of ``tx0`` / ``tx1`` holds ``(t, x_1, ..., x_d)``. Time is
    scaled by ``lengthscale_t``, space by ``lengthscale_x``, and the result by
    ``output_scale ** 2``. The two inputs broadcast against each other.

    Alternative kept for reference: a product of a Matérn-3/2 kernel in time and
    a Matérn-5/2 kernel in space (using ``t = tx[..., 0]`` and ``x = tx[..., 1]``):
    ``output_scale ** 2 * matern_32(t0, t1, lengthscale=lengthscale_t)
    * matern_52(x0, x1, lengthscale=lengthscale_x)``.
    """
    diff = tx0 - tx1
    sq_dist = (diff[..., 0] / lengthscale_t) ** 2 + jnp.sum(
        (diff[..., 1:] / lengthscale_x) ** 2, axis=-1
    )
    return output_scale ** 2 * jnp.exp(-0.5 * sq_dist)

# Prior GP over u(t, x), built from the JAX mean and covariance functions above
# (`input_dim=2` matches the (t, x) inputs).
# NOTE(review): the meaning of the `vectorize` flags is defined in linpde_gp and
# not visible here — confirm why the mean uses False and the kernel uses True.
prior_gp = pn.randprocs.GaussianProcess(
    mean=linpde_gp.randprocs.mean_fns.JaxMean(prior_mean, vectorize=False),
    cov=linpde_gp.randprocs.kernels.JaxKernel(prior_cov, input_dim=2, vectorize=True),
)

In [None]:
# Prior mean over the space-time plotting grid, labelled like the final posterior figure.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, prior_mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Temperature (°C)")
ax.set_title("Prior mean");

### Dirichlet Problem

In [None]:
# Observation locations for the joint Dirichlet conditioning below, as (t, x) rows.
# - Initial condition: `n_ic_locs` equispaced points in x at the start time,
#   with values taken from the BVP's first boundary condition.
# - Boundary conditions: `n_dirichlet_bc_locs` equispaced times at each spatial
#   end point, observed to be 0. The start time is skipped, presumably because
#   the IC locations already contain both corner points (duplicate noise-free
#   observations would make the Gram matrix singular).
n_ic_locs = 3
n_dirichlet_bc_locs = 10

ic_obs_locs = np.stack(
    (
        np.full(n_ic_locs, domain[0][0]),
        np.linspace(*domain[1], n_ic_locs),
    ),
    axis=-1,
)
ic_obs = bvp.boundary_conditions[0].values(ic_obs_locs[..., 1])

# Shared by both spatial boundaries, so the two counts cannot drift apart.
bc_obs_times = np.linspace(*domain[0], n_dirichlet_bc_locs + 1)[1:]

left_bc_obs_locs = np.stack(
    (bc_obs_times, np.full(n_dirichlet_bc_locs, domain[1][0])),
    axis=-1,
)
left_bc_obs = np.zeros_like(left_bc_obs_locs[..., 0])

right_bc_obs_locs = np.stack(
    (bc_obs_times, np.full(n_dirichlet_bc_locs, domain[1][1])),
    axis=-1,
)
right_bc_obs = np.zeros_like(right_bc_obs_locs[..., 0])

In [None]:
# Stack the initial- and boundary-condition observations and condition the
# prior on all of them jointly. The observation covariance is all zeros, so the
# observations are treated as exact (noise-free).
bc_obs_locs = np.concatenate((ic_obs_locs, left_bc_obs_locs, right_bc_obs_locs), axis=0)
bc_obs = np.concatenate((ic_obs, left_bc_obs, right_bc_obs))

u_bc = prior_gp.condition_on_observations_jax(
    bc_obs_locs,
    pn.randvars.Normal(
        mean=bc_obs,
        cov=np.diag(np.zeros_like(bc_obs)),
    ),
)

In [None]:
# u_bc = linpde_gp.randprocs.PosteriorGaussianProcess.from_measurements(prior_gp, ic_obs_locs, ic_obs)

In [None]:
# u_bc = u_bc.condition_on_observations(left_bc_obs_locs, left_bc_obs)

In [None]:
# u_bc = u_bc.condition_on_observations(right_bc_obs_locs, right_bc_obs)

In [None]:
# Posterior mean after jointly conditioning on the initial and Dirichlet boundary observations.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, u_bc.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Temperature (°C)")
ax.set_title("Posterior mean given initial and Dirichlet boundary conditions");

### Cauchy Problem

#### Conditioning on Initial Conditions

In [None]:
# Initial-condition observations for the Cauchy problem: `n_ic_locs` equispaced
# points in x at the start time. These are recomputed here (same values as in the
# Dirichlet section) so that this section does not depend on that section's cells.
n_ic_locs = 3

ic_obs_locs = np.stack(
    (
        np.full(n_ic_locs, domain[0][0]),
        np.linspace(*domain[1], n_ic_locs),
    ),
    axis=-1,
)
ic_obs = bvp.boundary_conditions[0].values(ic_obs_locs[..., 1])

In [None]:
# Condition the prior on the initial-condition observations only. The zero
# observation covariance treats them as exact (noise-free).
u_ic = prior_gp.condition_on_observations_jax(
    ic_obs_locs,
    pn.randvars.Normal(
        mean=ic_obs,
        cov=np.diag(np.zeros_like(ic_obs)),
    ),
)

In [None]:
# u_ic = linpde_gp.randprocs.PosteriorGaussianProcess.from_measurements(prior_gp, ic_obs_locs, ic_obs)

In [None]:
# Posterior mean after conditioning on the initial condition only.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, u_ic.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Temperature (°C)")
ax.set_title("Posterior mean given initial conditions");

#### Conditioning on Spatial Boundary Conditions

In [None]:
# Apply a directional derivative to the IC posterior. This yields the derived
# process `du_ic` and (presumably) its cross-covariance with `u_ic`; both are
# passed to the boundary conditioning step below.
# NOTE(review): the direction is the scalar `1.` on 2-D (t, x) inputs — confirm
# whether this differentiates along x only or along the (1, 1) diagonal.
du_ic, du_ic_crosscov = u_ic.apply_jax_linop(linpde_gp.problems.pde.diffops.DirectionalDerivative(1.))

In [None]:
# Mean of the directional-derivative process of the IC posterior.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, du_ic.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Directional derivative of u")
ax.set_title("Directional derivative of the IC posterior mean");

In [None]:
# Spatial boundary observations for the Cauchy problem: `n_bc_locs` equispaced
# times at each spatial end point, each observed to be 0. The start time is
# skipped, presumably because the IC locations already contain the corners.
#
# NOTE(review): the observations are placed on the derivative process `du_ic`,
# not on `u_ic` itself. If `DirectionalDerivative(1.)` is the x-derivative, this
# imposes a Neumann-type (zero-flux) condition, whereas the Dirichlet section
# above pins the value of u at the boundary — confirm this is intended.
n_bc_locs = 5

bc_times = np.linspace(*domain[0], n_bc_locs + 1)[1:]
left_bc_obs_locs, right_bc_obs_locs = (
    np.stack((bc_times, np.full(n_bc_locs, x_boundary)), axis=-1)
    for x_boundary in (domain[1][0], domain[1][1])
)
left_bc_obs = np.zeros_like(left_bc_obs_locs[..., 0])
right_bc_obs = np.zeros_like(right_bc_obs_locs[..., 0])

bc_obs_locs = np.concatenate((left_bc_obs_locs, right_bc_obs_locs), axis=0)
bc_obs = np.concatenate((left_bc_obs, right_bc_obs), axis=0)

# Condition `u_ic` on observations of the derived process `du_ic`, linked to
# `u_ic` through their cross-covariance.
u_bc = u_ic.condition_on_predictive_gp_observations_jax(
    du_ic,
    du_ic_crosscov,
    bc_obs_locs,
    pn.asrandvar(bc_obs),
)

In [None]:
# Posterior mean after additionally conditioning on the spatial boundary observations.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, u_bc.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Temperature (°C)")
ax.set_title("Posterior mean given initial and spatial boundary conditions");

In [None]:
# Same directional derivative as for `du_ic`, now of the boundary-conditioned
# posterior. It is only used for the diagnostic plot below, so the
# cross-covariance is discarded.
du_bc, _ = u_bc.apply_jax_linop(linpde_gp.problems.pde.diffops.DirectionalDerivative(1.))

In [None]:
# Mean of the directional-derivative process after boundary conditioning.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, du_bc.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel("Directional derivative of u")
ax.set_title("Directional derivative of the boundary-conditioned posterior mean");

### Predictive Process Induced by the Differential Operator $\mathcal{L}$

In [None]:
# Push the current posterior through the BVP's differential operator L.
# `Lu_bc` is the induced process L[u]. `Lu_bc_crosscov` is (presumably) its
# cross-covariance with `u_bc`, which is needed to condition on the PDE below.
Lu_bc, Lu_bc_crosscov = u_bc.apply_jax_linop(bvp.diffop)

In [None]:
# Mean of L[u] before conditioning on the PDE, i.e. how far the current mean is
# from satisfying the (homogeneous) PDE.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, Lu_bc.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel(r"$\mathcal{L}u$")
ax.set_title(r"Mean of $\mathcal{L}u$ before conditioning on the PDE");

### Conditioning on the PDE

In [None]:
# PDE collocation points as (t, x) rows on a regular grid:
# - space: interior points only; the end points (which carry boundary
#   observations) are dropped;
# - time: starts just after the initial time (which carries the IC
#   observations) and ends at the end of the time domain. Only the start is
#   offset, so no collocation point falls outside the domain (shifting the
#   whole axis would put the last time at 1.01).
n_pde_locs_t = 5
n_pde_locs_x = 5
pde_t_offset = 0.01

pde_loc = np.stack(
    np.meshgrid(
        np.linspace(domain[0][0] + pde_t_offset, domain[0][1], n_pde_locs_t),
        np.linspace(*domain[1], n_pde_locs_x + 2)[1:-1],
    ),
    axis=-1,
).reshape(-1, 2)

# L[u] is observed to be exactly 0 at every collocation point (homogeneous PDE).
u_bc_pde = u_bc.condition_on_predictive_gp_observations_jax(
    Lu_bc,
    Lu_bc_crosscov,
    pde_loc,
    pn.asrandvar(np.zeros(len(pde_loc))),
)

In [None]:
# Final posterior mean (initial, boundary and PDE observations); this figure is saved to disk.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, u_bc_pde.mean(plt_grid))
ax.set(
    xlabel="Time (s)",
    ylabel="Location (cm)",
    zlabel="Temperature (°C)",
)

notebook_utils.savefig("heat_posterior")

In [None]:
# Diagnostic: L[u] under the final posterior. Since L[u] was observed as 0
# without noise, its mean should be numerically close to 0 at the collocation points.
Lu_bc_pde, _ = u_bc_pde.apply_jax_linop(bvp.diffop)

In [None]:
# Mean of L[u] after conditioning on the PDE (compare with the plot before conditioning).
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

ax.plot_surface(plt_grid_t, plt_grid_x, Lu_bc_pde.mean(plt_grid))
ax.set_xlabel("Time (s)")
ax.set_ylabel("Location (cm)")
ax.set_zlabel(r"$\mathcal{L}u$")
ax.set_title(r"Mean of $\mathcal{L}u$ after conditioning on the PDE");

### Generate Animation

In [None]:
from matplotlib import animation

fig, ax = plt.subplots()


def draw_heat_frame(frame_idx):
    """Redraw ``ax`` with the final posterior at the ``frame_idx``-th plotting time.

    Shows the posterior mean of ``u_bc_pde`` over space together with a 95%
    credible band (mean +/- 1.96 standard deviations).
    """
    # Rows of (t, x) that all share the frame's time (axis 0 of plt_grid is time).
    frame_tx = plt_grid[frame_idx]
    locations = frame_tx[:, 1]

    ax.cla()

    mean = u_bc_pde.mean(frame_tx)
    std = u_bc_pde.std(frame_tx)

    ax.plot(locations, mean)
    ax.fill_between(locations, mean - 1.96 * std, mean + 1.96 * std, alpha=.3)
    ax.set_ylim(-0.01, 1.2)
    ax.set_xlabel("Location (cm)")
    ax.set_ylabel("Temperature (°C)")
    ax.set_title(f"$t = {plt_grid[frame_idx, 0, 0]:.2f} s$")


# One frame per plotting time step.
anim = animation.FuncAnimation(
    fig,
    draw_heat_frame,
    frames=len(plt_grid_t),
    interval=200,
    repeat_delay=4000,
    blit=False,
)


In [None]:
from IPython.display import HTML

# Render the animation inline as a JavaScript/HTML player.
HTML(anim.to_jshtml())

In [None]:
import shutil

# Export every animation frame as its own PDF into a fresh `heat_anim` directory
# under the notebook's results path. The `{}` in the file name is presumably
# filled with the frame index by `PDFWriter` — TODO confirm.
anim_path = notebook_utils.config.notebook_results_path / "heat_anim"

# Remove frames from earlier runs so stale files (e.g. from a run with more
# frames) cannot survive next to the new ones.
if anim_path.is_dir():
    shutil.rmtree(anim_path)

# `parents=True` also creates the results directory itself if it does not exist yet.
anim_path.mkdir(parents=True, exist_ok=True)

anim.save(anim_path / "{}.pdf", linpde_gp.plotting.PDFWriter())

In [None]:
# anim.save(notebook_utils.config.notebook_results_path / "heat_anim.gif", animation.PillowWriter(fps=5))