In [None]:
import warnings
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np


In [None]:
def HJ_exact(
    x: np.ndarray,
    t: float,
    u0: Callable[[np.ndarray], np.ndarray],
    u0_x: Callable[[np.ndarray], np.ndarray],
    u0_xx: Callable[[np.ndarray], np.ndarray],
    Hp: Callable,
    dHdp: Callable,
    dH2d2p: Callable,
    init: Callable,
    ep: float,
    max_iter: int = 100000,
) -> np.ndarray:
    """
    u(x,t): solution of the Hamilton-Jacobi equation u_t + H(u_x) = 0 by characteristics.

    For every x, Newton's method solves the characteristic relation

        f(xi) = xi + H'(u0'(xi)) * t - x = 0

    for the foot point xi, and the solution is then

        u(x, t) = u0(xi) + t * L(u0'(xi)),   L(p) = p * H'(p) - H(p).

    Newton converges to a root near ``init(x, t)``; once characteristics
    cross, which root (branch) is returned depends on that initial guess.

    Parameters
    ----------
    x : evaluation points.
    t : time, must be >= 0.
    u0, u0_x, u0_xx : initial datum and its first and second derivatives.
    Hp, dHdp, dH2d2p : Hamiltonian H(p) and its first and second derivatives.
    init : callable ``(x, t) -> xi0`` giving the Newton starting point.
    ep : tolerance on the largest Newton step |dxi|; values <= 0 are
        replaced by 1e-6.
    max_iter : maximum number of Newton iterations (default 100000).

    Returns
    -------
    u(x, t) evaluated at ``x``.

    Raises
    ------
    ValueError : if ``t < 0`` or ``max_iter < 1``.

    Warns
    -----
    RuntimeWarning : if the Newton steps did not drop below ``ep`` within
        ``max_iter`` iterations, or if the step became non-finite (inf/NaN)
        at some points.
    """

    if t < 0:
        raise ValueError(f"{t=}<0")
    if max_iter < 1:
        raise ValueError(f"{max_iter=}<1")
    if ep <= 0:
        ep = 1e-6

    def L(p):
        # Value carried along a characteristic per unit time: p H'(p) - H(p).
        return p * dHdp(p) - Hp(p)

    xi = init(x, t)

    for _ in range(max_iter):
        # Newton step for f(xi) = xi + H'(u0'(xi)) t - x, whose derivative is
        # f'(xi) = 1 + H''(u0'(xi)) u0''(xi) t.
        fprime = 1 + dH2d2p(u0_x(xi)) * u0_xx(xi) * t
        dxi = (xi - x + dHdp(u0_x(xi)) * t) / fprime
        xi = xi - dxi

        # A non-finite step means that point has diverged (e.g. f'(xi) hit 0).
        # np.max would propagate NaN and never allow the break, forcing all
        # max_iter iterations, so convergence is judged on finite steps only.
        step = np.abs(np.atleast_1d(dxi))
        finite = np.isfinite(step)
        if not finite.any() or np.max(step[finite]) < ep:
            break
    else:
        warnings.warn(
            f"HJ_exact: Newton iteration did not converge to {ep=} "
            f"within {max_iter=} iterations",
            RuntimeWarning,
            stacklevel=2,
        )

    if not finite.all():
        n_bad = int(finite.size - np.count_nonzero(finite))
        warnings.warn(
            f"HJ_exact: Newton step was non-finite at {n_bad} of {finite.size} "
            "points (e.g. f'(xi) vanished); results there are unreliable",
            RuntimeWarning,
            stacklevel=2,
        )

    return u0(xi) + L(u0_x(xi)) * t

In [None]:
def nearest_even(x):
    """
    Round ``x`` elementwise to the nearest even integer.

    Values whose ``np.round`` is odd are moved to whichever neighbouring even
    integer is closer to ``x``. An exact tie (``x`` itself an odd integer)
    goes to the lower one. ``np.round`` rounds halves to even, so e.g.
    2.5 -> 2.0.

    Parameters
    ----------
    x : scalar or array_like

    Returns
    -------
    A Python scalar for scalar input. Otherwise an ndarray with the shape
    of ``x``; a 0-d ndarray input also returns an ndarray.
    """
    # Decide the return type from the caller's input *before* asarray,
    # which would otherwise turn every input into an ndarray.
    return_array = isinstance(x, np.ndarray) or np.ndim(x) > 0

    arr = np.asarray(x)
    rounded = np.round(arr)
    lower_even = rounded - 1
    upper_even = rounded + 1
    closer_even = np.where(
        np.abs(arr - lower_even) <= np.abs(arr - upper_even), lower_even, upper_even
    )
    nearest = np.where(rounded % 2 != 0, closer_even, rounded)
    return nearest if return_array else nearest.item()


# Experiment 1: H(p) = (p + 1)^2 / 2 with initial datum u0(x) = -cos(pi x).
# Here H'' = 1 and u0'' = pi^2 cos(pi x), so the Newton derivative
# 1 + pi^2 t cos(pi xi) used in HJ_exact can vanish once t > 1/pi^2
# (characteristics cross). t = 1.5/pi^2 is beyond that point, so the root
# found depends on the initial guess `init`.
# At even integers u0' = 0 and H'(0) = 1, so those characteristics move at
# unit speed; hence Newton starts from the even integer nearest to x - t.
x = np.linspace(-1, 1, 200)
u = HJ_exact(
    x=x,
    t=1.5 / np.pi**2,
    u0=lambda s: -np.cos(np.pi * s),
    u0_x=lambda s: np.pi * np.sin(np.pi * s),
    u0_xx=lambda s: np.pi**2 * np.cos(np.pi * s),
    Hp=lambda p: ((p + 1) ** 2) / 2,
    dHdp=lambda p: p + 1,
    dH2d2p=lambda p: 1,
    init=lambda x, t: nearest_even(x - t),
    ep=1e-8,
)

fig, ax = plt.subplots()
ax.plot(x, u)
ax.set(
    xlabel="x",
    ylabel="u(x, t)",
    title=r"$H(p)=(p+1)^2/2$, $u_0(x)=-\cos(\pi x)$, $t=1.5/\pi^2$",
)
plt.show()

In [None]:
def start_point_via_scan(x, t, n_grid: int = 40, speed=None):
    """
    Initial guess for the characteristic foot point xi, for use as ``init`` in HJ_exact.

    A grid of ``n_grid`` foot points on one period [0, 2] is scanned. For
    each x, the result is the first grid point whose characteristic
    ``xi + speed(xi) * t`` has reached x; if none has, the first grid point
    is used. x is first wrapped into [0, 2), and the chosen foot is shifted
    back by the same multiple of 2. This assumes ``speed`` is 2-periodic.

    Parameters
    ----------
    x : evaluation points (scalar or array of any shape).
    t : time.
    n_grid : number of grid points in the scan (default 40).
    speed : callable ``xi -> characteristic speed``. The default is
        sin(1 + pi*sin(pi*xi)), i.e. H'(u0'(xi)) for H(p) = -cos(p + 1) and
        u0'(s) = pi*sin(pi*s).

    Returns
    -------
    Foot-point guesses with the shape of ``x``.
    """
    def default_speed(xi):
        return np.sin(1 + np.pi * np.sin(np.pi * xi))

    speed_fn = default_speed if speed is None else speed

    x = np.asarray(x, dtype=float)
    # Wrap into one period [0, 2); `shift` is the multiple of 2 that was removed.
    x_wrapped = np.mod(x, 2.0)
    shift = x - x_wrapped

    xi_grid = np.linspace(0.0, 2.0, n_grid)
    # Where each grid characteristic has moved to at time t.
    reach = xi_grid + speed_fn(xi_grid) * t

    # condition[i, j]: the characteristic from xi_grid[i] has reached x_wrapped[j].
    condition = reach[:, None] >= x_wrapped.reshape(1, -1)
    # argmax returns the first True in each column (0 when a column has none).
    idx = np.argmax(condition, axis=0).reshape(x.shape)

    # Return the foot *position* (not its speed), mapped back to x's period.
    return xi_grid[idx] + shift


# Experiment 2: non-convex Hamiltonian H(p) = -cos(p + 1) (H'' = cos(p + 1)
# changes sign) with the same initial datum u0(x) = -cos(pi x) and time
# t = 1.5/pi^2. Newton starts from the grid-scan foot-point guess.
x = np.linspace(-1, 1, 200)
u = HJ_exact(
    x=x,
    t=1.5 / np.pi**2,
    u0=lambda s: -np.cos(np.pi * s),
    u0_x=lambda s: np.pi * np.sin(np.pi * s),
    u0_xx=lambda s: np.pi**2 * np.cos(np.pi * s),
    Hp=lambda p: -np.cos(p + 1),
    dHdp=lambda p: np.sin(p + 1),
    dH2d2p=lambda p: np.cos(p + 1),
    init=start_point_via_scan,
    ep=1e-9,
)

fig, ax = plt.subplots()
ax.plot(x, u)
ax.set(
    xlabel="x",
    ylabel="u(x, t)",
    title=r"$H(p)=-\cos(p+1)$, $u_0(x)=-\cos(\pi x)$, $t=1.5/\pi^2$",
)
plt.show()