# Heat equation with natural gradient descent (NGD)
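Consider the following PDE on $\Omega=[-1,1]^2$, for $t\in[0,1]$ and a diffusion coefficient $D>0$:
$$
\begin{cases}
\frac{\partial}{\partial t}u(t,x,y)-D \Delta u(t,x,y) = 0, & (x,y)\in\Omega,\\
u(t,x,y) = 0, & (x,y)\in\partial\Omega,\\
u(0, x, y)=\sin\pi x\sin\pi y, & (x,y)\in\Omega.
\end{cases}
$$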

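Since $\Delta(\sin\pi x\sin\pi y) = -2\pi^2\sin\pi x\sin\pi y$, the exact solution is $u^\star(t,x,y) = e^{-2D\pi^2 t}\sin\pi x\sin\pi y$, which also vanishes on $\partial\Omega$. This is checked numerically below by evaluating the PDE residual of $u^\star$.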

In [None]:
import jinns

In [None]:
import jax
from jax import random, vmap
import jax.numpy as jnp
import equinox as eqx
import matplotlib.pyplot as plt

# Seed the PRNG (seed 2) and split off a subkey for sampling the training data.
key, subkey = random.split(random.PRNGKey(2))

In [None]:
# --- Collocation-point budget ------------------------------------------------
# n, ni and nb (interior, initial-time and border points, judging by the
# matching batch-size names) all share the same budget.
n = 10000
ni = nb = n

# Mini-batch sizes; the border batch is a quarter of the domain batch.
domain_batch_size = 400
initial_batch_size = domain_batch_size
border_batch_size = domain_batch_size // 4

# --- Space-time domain: [xmin, xmax]^dim x [tmin, tmax] ----------------------
dim = 2
xmin, xmax = -1, 1
tmin, tmax = 0, 1
Tmax = 1  # NOTE(review): not referenced elsewhere in this notebook
method = "uniform"  # point-sampling method passed to the mesh

train_data = jinns.data.CubicMeshPDENonStatio(
    key=subkey,
    dim=dim,
    min_pts=(xmin,) * dim,
    max_pts=(xmax,) * dim,
    tmin=tmin,
    tmax=tmax,
    n=n,
    ni=ni,
    nb=nb,
    domain_batch_size=domain_batch_size,
    initial_batch_size=initial_batch_size,
    border_batch_size=border_batch_size,
    method=method,
)

In [None]:
# MLP mapping (t, x, y) -> u. The input width is 1 (time) + dim (space), so it
# stays consistent with the spatial dimension declared for the training mesh
# instead of hard-coding 3.
eqx_list = (
    (eqx.nn.Linear, 1 + dim, 25),
    (jax.nn.tanh,),
    (eqx.nn.Linear, 25, 25),
    (jax.nn.tanh,),
    (eqx.nn.Linear, 25, 1),  # scalar output u(t, x, y)
)

key, subkey = random.split(key)
u, init_sol_nn_params = jinns.nn.PINN_MLP.create(
    key=subkey, eqx_list=eqx_list, eq_type="PDENonStatio"
)

In [None]:
# Diffusion coefficient of the heat equation, bundled with the network
# weights in a single jinns Params container.
D = jnp.array(0.2)
init_params = jinns.parameters.Params(
    eq_params={"D": D},
    nn_params=init_sol_nn_params,
)

In [None]:
from jinns.loss import PDENonStatio


class HeatEquation(PDENonStatio):
    """Residual of the heat equation: du/dt - D * Laplacian(u).

    Inputs ``t_x`` are ordered as (t, x, y), with time at index 0. The
    Laplacian is delegated to ``jinns.loss.laplacian_rev`` with
    ``eq_type="PDENonStatio"``. The diffusion coefficient is read from
    ``params.eq_params.D``.
    """

    def equation(self, t_x, u, params):
        def scalar_u(inputs):
            # jax.grad needs a scalar-valued function of the inputs
            return u(inputs, params).squeeze()

        # Gradient w.r.t. all inputs; keep the time component as a (1,) array
        du_dt = jax.grad(scalar_u)(t_x)[:1]
        diffusion = params.eq_params.D * jinns.loss.laplacian_rev(
            t_x, u, params, eq_type="PDENonStatio"
        )
        return du_dt - diffusion


dyn_loss_heat = HeatEquation()

In [None]:
boundary_condition = jinns.loss.Dirichlet()

# Unit weight for every loss term; the boundary term only gets a weight when
# a boundary condition is actually set.
loss_weights = jinns.loss.LossWeightsPDENonStatio(
    dyn_loss=jnp.array(1.0),
    initial_condition=jnp.array(1.0),
    boundary_loss=(
        jnp.array(1.0) if boundary_condition is not None else None
    ),
)

In [None]:
def u0(x):
    """Initial condition u(0, x, y) = sin(pi * x) * sin(pi * y).

    ``x`` holds the spatial coordinates; only ``x[0]`` and ``x[1]`` are read.
    """
    first, second = x[0], x[1]
    return jnp.sin(jnp.pi * first) * jnp.sin(jnp.pi * second)

In [None]:
# Full PINN loss: PDE residual + initial condition + boundary condition,
# each term weighted according to `loss_weights`.
loss = jinns.loss.LossPDENonStatio(
    u=u,
    params=init_params,
    dynamic_loss=dyn_loss_heat,
    initial_condition_fun=u0,
    boundary_condition=boundary_condition,
    loss_weights=loss_weights,
)

In [None]:
# Smoke test of the loss: evaluate it and its gradient w.r.t. the parameters
# (argument 0 of `loss.evaluate`) on one batch of collocation points.
# With has_aux=True the result is ((value, aux), grads); index 1 keeps grads.
losses_and_grad = jax.value_and_grad(loss.evaluate, argnums=0, has_aux=True)
_, colloc_batch = train_data.get_batch()

std_grad = losses_and_grad(init_params, batch=colloc_batch)[1]

# True solution

In [None]:
def u_true(t_x):
    """Exact solution u*(t, x, y) = exp(-2 * D * pi^2 * t) * u0(x, y).

    ``t_x`` is ordered as (t, x, y). Relies on the module-level diffusion
    coefficient ``D`` and initial condition ``u0``.
    """
    time_decay = jnp.exp(-2 * D * t_x[0] * jnp.pi**2)
    return time_decay * u0(t_x[1:])


# Sanity check of the closed-form solution: the PDE residual of u_true must
# vanish (up to float32 round-off) at arbitrary space-time points.
# Split a fresh subkey instead of consuming `key` directly: `key` is split
# again later, and a JAX key must not be reused for two purposes.
key, subkey = random.split(key)
# Sample (t, x, y) over the training domain [tmin, tmax] x [xmin, xmax]^2
# (minval inclusive, maxval exclusive), not just the unit cube.
txy = jax.random.uniform(
    subkey,
    shape=(100, 3),
    minval=jnp.array([tmin, xmin, xmin], dtype=float),
    maxval=jnp.array([tmax, xmax, xmax], dtype=float),
)
true_res = vmap(dyn_loss_heat.equation, (0, None, None))(
    txy, lambda tx, p: u_true(tx), init_params
)
# zeros_like matches the residual's exact shape (it may carry a trailing
# axis of size 1), avoiding an implicit broadcast in the comparison.
assert jnp.allclose(true_res, jnp.zeros_like(true_res), atol=1e-6)

In [None]:
figsize = (7, 7)
nx = 200
# Two identical 1-D grids of nx points over [xmin, xmax], for x and for y
val_xydata = tuple(jnp.linspace(xmin, xmax, nx) for _ in range(2))
# Snapshot times for the plots, given in the rescaled time scale
# (here tmin = 0 and tmax = 1)
times = [0, 0.2, 0.6, 1]

jinns.plot.plot2d(
    u_true,
    xy_data=val_xydata,
    times=times,
    figsize=figsize,
    cmap="viridis",
    vmin_vmax=(-1, 1),
    title=r"Ground truth : $u^\star(t, x)$",
)
plt.suptitle(r"Ground truth : $u^\star(t, x)$")

# Run natural gradient descent

In [None]:
import optax
from jinns.optimizers import vanilla_ngd

n_iter = 100
# Unit-rate SGD step followed by a backtracking line search (at most 15
# backtracking steps, verbose output) that adjusts the step size.
ngd_optim = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_backtracking_linesearch(
        max_backtracking_steps=15,
        verbose=True,
    ),
)
# jinns' custom wrapper: tells `jinns.solve` to run natural gradient descent
tx = vanilla_ngd(ngd_optim)

In [None]:
# The NGD run starts from the shared initial parameters `init_params`.
ngd_params = init_params

In [None]:
key, subkey = random.split(key, 2)  # NOTE(review): `subkey` is not passed to jinns.solve
# jinns.solve returns a long tuple; only the first five entries (trained
# params, loss histories, updated data and loss) are used here. The rest are
# collected into `_`, so this does not break if the tuple length changes.
(
    ngd_params,
    total_loss_list,
    loss_by_term_dict,
    train_data,
    loss,
    *_,
) = jinns.solve(
    init_params=ngd_params,
    data=train_data,
    optimizer=tx,
    loss=loss,
    n_iter=n_iter,
    # Never request a printing period of 0 (n_iter // 10 == 0 when n_iter < 10)
    print_loss_every=max(1, n_iter // 10),
)

In [None]:
# log10 curves of each loss term, then of the total loss, for the NGD run
curves = [*loss_by_term_dict.items(), ("total loss", total_loss_list)]
for label, history in curves:
    plt.plot(jnp.log10(history), label=label)
plt.legend()
plt.title("Loss evolution during NGD")
plt.show()

In [None]:
def plot_pinn(u, est_params, val_xydata, times, figsize, plot_residuals=False):
    """Plot a PINN solution, its absolute error and, optionally, its PDE residual.

    u: network evaluated as ``u(t_x, params)``; entry ``[0]`` of its output
        is plotted.
    est_params: parameters passed to ``u`` and to the residual.
    val_xydata: pair of 1-D grids forwarded to ``jinns.plot.plot2d``.
    times: snapshot times forwarded to ``jinns.plot.plot2d``.
    figsize: figure size of the prediction and error plots.
    plot_residuals: if True, also plot ``dyn_loss_heat.equation`` on a fixed
        (10, 10) figure.

    The error plot compares against the module-level ``u_true``.
    """

    def pinn_scalar(t_x):
        return u(t_x, est_params)[0]

    def abs_error(t_x):
        return jnp.abs(u_true(t_x) - pinn_scalar(t_x))

    def residual(t_x):
        return dyn_loss_heat.equation(t_x, u, est_params)

    shared_kwargs = dict(xy_data=val_xydata, times=times, cmap="viridis")

    # PINN prediction, on the same colour range as the ground truth
    jinns.plot.plot2d(pinn_scalar, figsize=figsize, vmin_vmax=(-1, 1), **shared_kwargs)
    plt.suptitle("PINN : u(t, x)")

    # Pointwise absolute error
    jinns.plot.plot2d(abs_error, figsize=figsize, **shared_kwargs)
    plt.suptitle(r"Absolute difference with ground truth")

    if not plot_residuals:
        return

    print("Equation residuals : N[u](t, x)")
    jinns.plot.plot2d(residual, figsize=(10, 10), **shared_kwargs)

In [None]:
# NGD solution: prediction and absolute error against the exact solution
plot_pinn(
    u,
    ngd_params,
    val_xydata=val_xydata,
    times=times,
    figsize=figsize,
    plot_residuals=False,
)

# Comparison with a first-order optimizer (Adam, 100× more iterations)

In [None]:
# First-order baseline: Adam with a fixed learning rate of 1e-3
n_iter = 10_000
tx = optax.adam(1e-3)

In [None]:
# The Adam run starts from the same initial parameters as the NGD run.
# NOTE: named "sgd_params" although the optimizer used is Adam.
sgd_params = init_params

In [None]:
key, subkey = random.split(key, 2)  # NOTE(review): `subkey` is not passed to jinns.solve
# jinns.solve returns a long tuple; only the first five entries (trained
# params, loss histories, updated data and loss) are used here. The rest are
# collected into `_`, so this does not break if the tuple length changes.
(
    sgd_params,
    total_loss_list,
    loss_by_term_dict,
    train_data,
    loss,
    *_,
) = jinns.solve(
    init_params=sgd_params,
    data=train_data,
    optimizer=tx,
    loss=loss,
    n_iter=n_iter,
    # Never request a printing period of 0 (n_iter // 10 == 0 when n_iter < 10)
    print_loss_every=max(1, n_iter // 10),
)

In [None]:
# log10 curves of each loss term, then of the total loss, for the Adam run
curves = [*loss_by_term_dict.items(), ("total loss", total_loss_list)]
for label, history in curves:
    plt.plot(jnp.log10(history), label=label)
plt.legend()
plt.title("Loss evolution during Adam")
plt.show()

In [None]:
# Adam solution: prediction and absolute error against the exact solution
plot_pinn(
    u,
    sgd_params,
    val_xydata=val_xydata,
    times=times,
    figsize=figsize,
    plot_residuals=False,
)