In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from dataclasses import dataclass
from math import pi
from typing import Any

import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from celluloid import Camera
from IPython.display import HTML
from jax import Array
from jaxtyping import Float

# Switch JAX from its float32 default to float64 so the spectral solvers below can
# meet the tight integrator tolerances they use (atol/rtol default to 1e-8).
jax.config.update("jax_enable_x64", True)

In [None]:
def plot_solution(
    u: Float[Array, "T N"],
    tspan: tuple[float, float],
    bounds: tuple[float, float],
    title: str | None = None,
    cmap: str = "magma",
) -> None:
    """Plot a space-time heatmap of a solution trajectory.

    :param u: solution snapshots of shape (T, N): time along axis 0, space along axis 1
    :param tspan: (t_start, t_end), used as the horizontal extent of the image
    :param bounds: (x_left, x_right), used as the vertical extent of the image
    :param title: optional figure title
    :param cmap: matplotlib colormap name for the heatmap
    """
    extent = [tspan[0], tspan[1], bounds[0], bounds[1]]
    plt.figure(figsize=(10, 5))
    # Transpose so that time runs along the horizontal axis and space along the vertical one.
    # Honour the caller's colormap instead of a hard-coded one.
    plt.imshow(u.T, aspect="auto", origin="lower", extent=extent, cmap=cmap)
    plt.xlabel("$t$")
    plt.ylabel("$x$")
    plt.title(title)
    plt.colorbar(label="$u(x, t)$")


def plot_wavefront_1d(
    ts: Float[Array, " T"],
    xs: Float[Array, " N"],
    u: Float[Array, "T N"],
    bounds: tuple[float, float],
    vline: float | None = None,
    figsize: tuple[int, int] = (8, 4),
) -> HTML:
    """Animate a 1-D solution u(x, t) frame by frame and return it as embeddable HTML.

    :param ts: times of the saved snapshots, one per row of ``u``
    :param xs: spatial grid points, one per column of ``u``
    :param u: solution snapshots of shape (T, N)
    :param bounds: (x_left, x_right) limits of the horizontal axis
    :param vline: if given, x-position of a dash-dot vertical reference line
    :param figsize: matplotlib figure size in inches

    :return: IPython ``HTML`` object wrapping the JavaScript animation
    :raises ValueError: if ``ts`` and ``u`` contain a different number of snapshots
    """
    fig, ax = plt.subplots(figsize=figsize)
    camera = Camera(fig)

    # The vertical range is shared by every frame (padded by one standard deviation),
    # so compute it once instead of once per frame.
    pad = u.std()
    y_lo, y_hi = float(u.min() - pad), float(u.max() + pad)

    # strict=True fails loudly instead of mislabelling frames if the lengths disagree
    for t, u_t in zip(ts, u, strict=True):
        ax.plot(xs, u_t, color="r")
        # The reference line and time label are drawn per frame, alongside the curve,
        # so that each captured snapshot contains them.
        if vline is not None:
            ax.axvline(x=vline, color="k", linestyle="-.")
        ax.set_ylim(y_lo, y_hi)
        ax.set_xlim(bounds[0], bounds[1])
        ax.set_xlabel("$x$")
        ax.set_ylabel("$u(x, t)$")
        ax.text(0.5, 1.02, f"$t = {t:.2f}$", transform=ax.transAxes, ha="center", fontsize=12)
        camera.snap()

    animation = camera.animate(interval=50)
    # Close the static figure so the notebook only shows the animation
    plt.close(fig)
    return HTML(animation.to_jshtml())


# The Pseudospectral Method

Given a PDE operator $\mathcal{R}$ it is often useful to split it into linear and nonlinear operators that satisfy the equation: $$\frac{\partial u}{\partial t} = \mathcal{R}[u] = \mathcal{L}u + \mathcal{N}[u]$$
If the solution/function $u$ satisfies periodic boundary conditions, we can use the Fourier transform to obtain the __Fourier Pseudospectral Method__ (there are, of course, fancier pseudospectral methods, e.g. Chebyshev-based ones, that handle non-periodic functions):
$$\frac{\partial \hat{u}}{\partial t} = \sigma\cdot\hat{u} + \mathcal{F}[\mathcal{N}[\mathcal{F}^{-1}[\hat{u}]]]$$
where $\hat{u} = \mathcal{F}[u]$ and $\mathcal{F}[\mathcal{L}u](k) = \sigma(k)\cdot\hat{u}(k)$ is the Fourier transform of the linear term. The linear operator acts as a simple multiplication in Fourier space because _every continuous, shift-invariant, linear operator is equivalent to a convolution_ [[1]](https://matthewhirn.com/wp-content/uploads/2020/01/math994_spring2020_lecture01-1.pdf), and the Fourier transform turns convolution into multiplication.

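In practice, we use the FFT to implement the Fourier transform. In the equation above, the linear term is treated exactly (it is just a multiplication by $\sigma(k)$), whereas the nonlinear term is only approximated on the discrete grid. The whole equation is written in terms of $\hat{u}$, so we transform to "real space" to evaluate the nonlinear term pointwise and then transform back to Fourier space. This is mostly done for efficiency: a pointwise product costs $O(N)$ (plus $O(N \log N)$ for the FFTs), whereas evaluating the same product directly in Fourier space is an $O(N^2)$ convolution.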


In [None]:
@dataclass
class PsuedoSpectralSolver1D:
    """Base class for 1-D Fourier pseudospectral solvers on a periodic domain.

    The state is the real FFT of ``u`` stored as one flat real vector
    ``[Re(u_hat), Im(u_hat)]`` so that a real-valued ODE integrator can advance it.
    Subclasses implement ``__call__(t, uk, args)`` returning ``d(uk)/dt`` in the
    same split layout; ``integrate`` passes that method to diffrax as the vector field.

    :param N: number of grid points in the spatial domain
    :param bounds: (left, right) endpoints of the periodic domain
    """

    N: int
    bounds: tuple[float, float] = (-1, 1)

    @property
    def L(self) -> float:
        """Length of the periodic domain."""
        return self.bounds[1] - self.bounds[0]

    @property
    def dimension(self) -> int:
        """Length of the flat state vector: real and imaginary parts of the N // 2 + 1 rfft modes.

        Equals N + 2 for even N; written this way it is also correct for odd N.
        """
        return 2 * (self.N // 2 + 1)

    def __post_init__(self) -> None:
        # Angular wavenumbers 2*pi*m/L for the rfft modes m = 0, ..., N // 2
        self.ks = 2 * pi * jnp.fft.rfftfreq(self.N, self.L / self.N)
        # Periodic grid: the right endpoint coincides with the left one, so it is excluded
        self.domain = jnp.linspace(*self.bounds, self.N, endpoint=False)

    def to_fourier(self, u: Float[Array, " N"], N: int | None = None) -> Float[Array, " dimension"]:
        """FFT of a real signal ``u``, returned in real/imag split format.

        :param u: real input signal; the transform is taken over the last axis,
            so leading batch dimensions are allowed
        :param N: FFT length (``u`` is cropped or zero-padded to it); defaults to ``self.N``

        :return: array whose last axis has length ``2 * (N // 2 + 1)`` and holds
            ``[Re(u_hat), Im(u_hat)]``; this equals ``self.dimension`` when ``N == self.N``
        """
        if N is None:
            N = self.N

        # shape: (..., N // 2 + 1), complex
        uk = jnp.fft.rfft(u, n=N)

        return jnp.concatenate([uk.real, uk.imag], axis=-1)

    def to_spatial(self, uk: Float[Array, " dimension"], N: int | None = None) -> Float[Array, " N"]:
        """Inverse FFT of split Fourier coefficients to get u(x) on a grid.

        :param uk: flattened Fourier coefficients ``[Re, Im]`` in this solver's layout
            (last axis of length ``self.dimension``); leading batch dimensions are allowed
        :param N: number of output grid points; defaults to ``self.N``

        :return: solution in the spatial domain, with last axis of length ``N``
        """
        if N is None:
            N = self.N

        n_half = self.dimension // 2
        coeffs = uk[..., :n_half] + 1j * uk[..., n_half:]

        return jnp.fft.irfft(coeffs, n=N)

    def integrate(
        self,
        init_cond: Float[Array, " dim"],
        tspan: tuple[float, float] | list[float],
        args: Any | None = None,
        num_save_pts: int | None = None,
        method: str = "Tsit5",
        max_dt: float = 1e-3,
        atol: float = 1e-8,
        rtol: float = 1e-8,
        pid: tuple[float, float, float] = (0.3, 0.3, 0.0),
        max_steps: int | None = None,
    ) -> tuple[Float[Array, " T"], Float[Array, "T dim"]]:
        """Integrate the dynamical system and save an equispaced trajectory.

        :param init_cond: initial state in the split Fourier layout (see ``to_fourier``)
        :param tspan: integration interval ``(t_init, t_final)``
        :param args: optional extra arguments forwarded to ``__call__``
        :param num_save_pts: number of equispaced save times over ``tspan``; if ``None``,
            only the final state is saved and no progress bar is shown
        :param method: name of a diffrax solver class, e.g. ``"Tsit5"``
        :param max_dt: largest allowed step; also used as the initial step ``dt0``
        :param atol: absolute error tolerance for adaptive stepping
        :param rtol: relative error tolerance for adaptive stepping
        :param pid: (proportional, integral, derivative) gains of the step-size controller
        :param max_steps: step limit passed to diffrax (``None`` lifts the limit)

        :returns: tuple of save times and the solution trajectory (one row per save time)
        """
        t0, t1 = tspan

        if num_save_pts is None:
            save_pts = diffrax.SaveAt(t1=True)
            progress_bar = diffrax.NoProgressMeter()
        else:
            save_pts = diffrax.SaveAt(ts=jnp.linspace(t0, t1, num_save_pts))
            # NOTE(review): this positional argument is TqdmProgressMeter's refresh
            # setting rather than a total count — confirm this is the intended value.
            progress_bar = diffrax.TqdmProgressMeter(num_save_pts)

        stepsize_controller = diffrax.PIDController(
            rtol=rtol,  # relative error tolerance of solution
            atol=atol,  # absolute error tolerance of solution
            pcoeff=pid[0],  # proportional strength for PID stepsize controller
            icoeff=pid[1],  # integral strength for PID stepsize controller
            dcoeff=pid[2],  # derivative strength for PID stepsize controller
            dtmax=max_dt,  # max step size
        )
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(self.__call__),
            getattr(diffrax, method)(),
            t0,
            t1,
            y0=init_cond,
            args=args,
            dt0=max_dt,
            saveat=save_pts,
            max_steps=max_steps,
            stepsize_controller=stepsize_controller,
            progress_meter=progress_bar,
        )
        return jnp.asarray(sol.ts), jnp.asarray(sol.ys)


# Burgers Equation

We use the notation $u_x$ as shorthand for $\frac{\partial{u}}{\partial{x}}$ (the same goes for $u_t$). Also, $u_{xx} = \frac{\partial^2 u}{\partial{x^2}}$ and so on. Burgers' equation is a (1+1)-D (1 dimension for time + 1 dimension for space) nonlinear partial differential equation (PDE); it is parabolic for $\nu > 0$, and its inviscid limit $\nu = 0$ is hyperbolic:
$$u_t + uu_x = \nu u_{xx}$$
Where $\nu$ is the _viscosity_ of the system.

This models _shock formation_ in traveling waves.


We are going to _simulate_ this system by first setting some ground rules:
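1. We'll consider the dynamics of this PDE on a periodic spatial domain of length $L$, written here as $x \in [0, L]$ (the code below uses the solver's default bounds, $x \in [-1, 1)$, so $L = 2$), and an arbitrary time span $t \in [0, T]$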
2. We'll use periodic boundary conditions i.e. $u(0, t) = u(L, t),\; \forall t \in [0, T]$

Under the pseudospectral scheme, we have: $$\frac{d \hat{u}}{d t} = -\nu k^2 \hat{u} - \text{FFT}[\text{iFFT}[\hat{u}]\cdot\text{iFFT}[ik\hat{u}]]$$
Notice that $\hat{u}(k)$ can be expressed as a vector $\hat{u} = [\hat{u}_1, \ldots, \hat{u}_k, \ldots, \hat{u}_M]$, where $M$ is the number of modes that results from discretizing the length-$L$ domain. Thus, each Fourier wavenumber is _fixed_ by the discretization rather than being a continuous variable, i.e. the PDE becomes a system of $M$ coupled ordinary differential equations (ODEs) in time!



In [None]:
@dataclass
class BurgersSolver(PsuedoSpectralSolver1D):
    """Viscous Burgers' equation u_t + u u_x = nu u_xx on a periodic grid.

    In Fourier space: d(u_hat)/dt = -nu k^2 u_hat - FFT[u * u_x].
    """

    nu: float = 0.1

    def __call__(self, t: float, uk: Float[Array, " dim"], args: Any | None = None) -> Float[Array, " dim"]:
        """Right-hand side d(uk)/dt for the split real/imag Fourier state ``uk``."""
        half = self.dimension // 2
        u_hat = uk[:half] + 1j * uk[half:]

        # Viscous damping is diagonal in Fourier space: -nu * k^2 * u_hat
        viscosity = -self.nu * self.ks**2 * u_hat

        # Advection -u * u_x: derivative taken spectrally, product formed on the grid
        u_grid = jnp.fft.irfft(u_hat, n=self.N)
        du_dx = jnp.fft.irfft(1j * self.ks * u_hat, n=self.N)
        advection = jnp.fft.rfft(-(u_grid * du_dx))

        rhs = viscosity + advection
        return jnp.concatenate([rhs.real, rhs.imag])


In [None]:
solver = BurgersSolver(N=256, nu=1e-3)

# Gaussian-modulated ramp u(x, 0) = -x * exp(-x^2 / twosigma), centred on x = 0
# of the default domain [-1, 1)
twosigma = 0.1
ic = -solver.domain * jnp.exp(-(solver.domain**2) / twosigma)

# Round-trip through the split Fourier representation as a sanity check
ic_k = solver.to_fourier(ic)
ic_recon = solver.to_spatial(ic_k)

plt.title("Initial Condition")
# Plot against the physical grid so the horizontal axis matches the "$x$" label
plt.plot(solver.domain, ic, label="original")
plt.plot(solver.domain, ic_recon, label="reconstruction")
plt.xlabel("$x$")
plt.ylabel("$u(x, 0)$")
plt.legend();

In [None]:
# Integrate the system
tspan = [0, 5]
ts, uks = solver.integrate(ic_k, tspan, num_save_pts=100)

u = solver.to_spatial(uks)
print(u.shape)

In [None]:
# Space-time heatmap of the Burgers trajectory
plot_solution(u, tspan, bounds=solver.bounds, title="Burgers Equation Wavefront")

In [None]:
# Animate the Burgers wavefront; the dash-dot line marks x = 0
plot_wavefront_1d(ts, solver.domain, u, bounds=solver.bounds, vline=0.0)

# Kuramoto-Sivashinsky Equation
This is a (1+1)-D spatiotemporal PDE that exhibits _very_ rich dynamical behavior - namely, _spatiotemporal chaos_.
$$u_t + u_{xx} + u_{xxxx} + \frac{1}{2}(u^2)_x = 0$$
This PDE models a wide range of phenomena, such as 1-D flame fronts, liquid films flowing down an inclined surface, etc.

Again, we assume periodic boundary conditions and use the Fourier pseudospectral method:
$$\frac{d\hat{u}}{dt} = (k^2 - k^4)\hat{u} -\frac{1}{2}ik\cdot\text{FFT}[\text{iFFT}[\hat{u}]^2]$$

In [None]:
class KuramotoShivashinskySolver(PsuedoSpectralSolver1D):
    """Kuramoto-Sivashinsky equation u_t + u_xx + u_xxxx + (u^2 / 2)_x = 0 on a periodic grid.

    In Fourier space: d(u_hat)/dt = (k^2 - k^4) u_hat - (i k / 2) FFT[u^2].
    """

    def __call__(self, t: float, uk: Float[Array, " dim"], args: Any | None = None) -> Float[Array, " dim"]:
        """Right-hand side d(uk)/dt for the split real/imag Fourier state ``uk``."""
        half = self.dimension // 2
        u_hat = uk[:half] + 1j * uk[half:]
        k = self.ks

        # -u_xx - u_xxxx -> (k^2 - k^4) u_hat: growth for k < 1, damping for k > 1
        linear = -(k**4 - k**2) * u_hat

        # -(u^2 / 2)_x: square on the grid, then differentiate spectrally
        u_grid = jnp.fft.irfft(u_hat, n=self.N)
        half_u_sq_hat = jnp.fft.rfft(0.5 * u_grid**2, n=self.N)
        nonlinear = -1j * k * half_u_sq_hat

        rhs = linear + nonlinear
        return jnp.concatenate([rhs.real, rhs.imag])


In [None]:
solver = KuramotoShivashinskySolver(N=256, bounds=(0, 100))

# make random fourier series as initial condition: u(x, 0) = sum_j a_j * sin(f_j * x)
num_sines = 5
rng = jax.random.key(0)
amp_rng, freq_rng = jax.random.split(rng, 2)
# jax.random expects the shape as a sequence, not a bare int
amplitudes = jax.random.normal(amp_rng, shape=(num_sines,))
frequencies = jax.random.uniform(freq_rng, shape=(num_sines,), minval=0, maxval=2)
# (num_sines,) @ (num_sines, N) -> (N,)
ic = amplitudes @ jnp.sin(frequencies[:, jnp.newaxis] * solver.domain[jnp.newaxis, :])
# Exponential envelope rising from 1 at x = 0 towards e at x = L
ic *= jnp.exp(solver.domain / solver.L)
# NOTE(review): the frequencies are not integer multiples of 2*pi/L and the envelope is not
# periodic, so this IC has a jump across the periodic boundary — confirm this is intended.

ic_k = solver.to_fourier(ic)
ic_recon = solver.to_spatial(ic_k)

plt.title("Initial Condition")
# Plot against the physical grid so the horizontal axis matches the "$x$" label
plt.plot(solver.domain, ic, label="original")
plt.plot(solver.domain, ic_recon, label="reconstructed")
plt.ylabel("$u(x, 0)$")
plt.xlabel("$x$")
plt.legend();

In [None]:
# Integrate Kuramoto-Sivashinsky over t in [0, 100], saving 300 equispaced snapshots
tspan = (0, 100)
ts, uks = solver.integrate(ic_k, tspan, num_save_pts=300)

# Back to physical space: one row per saved time
us = solver.to_spatial(uks)

In [None]:
# Space-time heatmap of the Kuramoto-Sivashinsky trajectory (title spelling corrected)
plot_solution(us, tspan, bounds=solver.bounds, title="Kuramoto-Sivashinsky Wavefront")

In [None]:
# Animate the Kuramoto-Sivashinsky wavefront
plot_wavefront_1d(ts=ts, xs=solver.domain, u=us, bounds=solver.bounds)

# Question
What happens when you change the domain length $L$? What happens to the shape of the solution and to the solve time, and why?

# Korteweg-De Vries (KdV) Equation
Yet another (1+1)-D nonlinear spatiotemporal PDE, this one models dispersive, non-dissipative waves in shallow water:
$$u_t + u_{xxx} - 6uu_x = 0$$
Mathematically, this PDE is very intriguing. It exhibits the following properties:
1. Soliton solutions [[2]](https://en.wikipedia.org/wiki/Soliton): wave packets that are robust to collisions
2. _Infinitely_ many conserved quantities e.g. mass, momentum, energy, etc
3. Continuum limit of the Fermi-Pasta-Ulam-Tsingou (FPUT) problem, which is said to have birthed the field of scientific computing [[3]](https://people.maths.ox.ac.uk/porterm/papers/fpupop_final.pdf)

With the usual Fourier machinery, we have:
$$\frac{d\hat{u}}{dt} = ik^3\hat{u} + 3ik\cdot\text{FFT}[\text{iFFT}[\hat{u}]^2]$$

In [None]:
class KortewegDeVriesSolver(PsuedoSpectralSolver1D):
    """Korteweg-de Vries equation u_t + u_xxx - 6 u u_x = 0 on a periodic grid (no dealiasing).

    Writing 6 u u_x as 3 (u^2)_x, the Fourier-space right-hand side is
    i k^3 u_hat + 3 i k FFT[u^2].
    """

    def __call__(self, t: float, uk: Float[Array, " dim"], args: Any | None = None) -> Float[Array, " dim"]:
        """Right-hand side d(uk)/dt for the split real/imag Fourier state ``uk``."""
        half = self.dimension // 2
        u_hat = uk[:half] + 1j * uk[half:]
        wavenumbers = self.ks

        # Dispersion: -u_xxx -> -(ik)^3 u_hat = i k^3 u_hat
        dispersion = 1j * wavenumbers**3 * u_hat

        # Nonlinearity 3 (u^2)_x: square on the grid, then differentiate spectrally
        u_grid = jnp.fft.irfft(u_hat, n=self.N)
        steepening = 3j * wavenumbers * jnp.fft.rfft(u_grid**2, n=self.N)

        rhs = dispersion + steepening
        return jnp.concatenate([rhs.real, rhs.imag])

In [None]:
solver = KortewegDeVriesSolver(N=256, bounds=(0, 50))

# Cosine initial condition with 1.25 wavelengths across the domain.
# NOTE(review): cos(2.5*pi*x/L) equals 1 at x = 0 but 0 at x = L, so this IC is not
# periodic and has a jump across the boundary — confirm this is intended for the demo.
ic = jnp.cos(2.5 * pi * solver.domain / solver.L)

# Round-trip through the split Fourier representation as a sanity check
ic_k = solver.to_fourier(ic)
ic_recon = solver.to_spatial(ic_k)

plt.title("Initial Condition")
# Plot against the physical grid so the horizontal axis matches the "$x$" label
plt.plot(solver.domain, ic, label="original")
plt.plot(solver.domain, ic_recon, label="reconstructed")
plt.ylabel("$u(x, 0)$")
plt.xlabel("$x$")
plt.legend();

In [None]:
# Integrate KdV (no dealiasing) over t in [0, 10], saving 100 equispaced snapshots
tspan = (0, 10)
ts, uks = solver.integrate(ic_k, tspan, num_save_pts=100)

# Back to physical space: one row per saved time
us = solver.to_spatial(uks)

In [None]:
# Space-time heatmap of the (aliased) KdV trajectory
plot_solution(us, tspan, bounds=solver.bounds, title="Korteweg-De Vries Wavefront")

In [None]:
# Animate the (aliased) KdV wavefront
plot_wavefront_1d(ts, solver.domain, us, bounds=solver.bounds)

# What Happened?
Why does the solution look so bad? This is a deterministic PDE, so why does the solution look so noisy, almost stochastic?

The noise is actually due to _aliasing error_.

On an equally spaced grid of $N$ points, the maximum resolvable wavenumber from the DFT is $k_{max} = N/2$, measured in units of the fundamental wavenumber $2\pi/L$ (prove this). However, for a quadratic nonlinearity in the PDE, $\text{DFT}[u\cdot v]_k = [\hat{u} * \hat{v}]_k = \sum_{i+j=k}\hat{u}_i\hat{v}_j$. Since $i$ and $j$ _both_ range up to $k_{max}$, the sum $k = i + j$ reaches $2k_{max}$, which exceeds the maximum resolvable wavenumber. Moreover, on the grid the modes $k$ and $k + 2mk_{max} = k + mN$ ($m \in \mathbb{Z}$) take identical values. Every product mode with $|i + j| > k_{max}$ is therefore indistinguishable from (an _alias_ of) a mode inside the resolvable region $|k| \le k_{max}$ (prove this).

Aliasing therefore maps high frequencies outside the resolvable band $-k_{max}, \ldots, 0, \ldots, k_{max}$ _back into_ the resolvable region. Thus, __high-frequency modes of the true solution corrupt the low-frequency modes of the numerical solution__.

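The solution to this issue is _Orszag's 2/3 dealiasing rule_ [[4]](https://journals.ametsoc.org/view/journals/atsc/28/6/1520-0469_1971_028_1074_oteoai_2_0_co_2.xml), which prescribes truncating the Fourier spectrum to the lowest 2/3 of the wavenumbers, $|k| < \tfrac{2}{3}k_{max}$, for the nonlinear terms. The reasoning is as follows. If $\hat{u}$ and $\hat{v}$ vanish for $|k| \ge \tfrac{2}{3}k_{max}$, their product only has modes with $|k| < \tfrac{4}{3}k_{max}$. Those beyond $k_{max}$ alias onto $|k| > \tfrac{2}{3}k_{max}$, which are exactly the modes that are discarded, so the retained modes are free of aliasing error.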

In [None]:
class KortewegDeVriesDealiasedSolver(PsuedoSpectralSolver1D):
    """Korteweg-de Vries solver with Orszag's 2/3-rule dealiasing of the quadratic term.

    Same equation as the plain KdV solver, u_t + u_xxx - 6 u u_x = 0, but only Fourier
    modes with k < (2/3) * k_max enter and leave the nonlinear product.
    """

    def __call__(self, t: float, uk: Float[Array, " dim"], args: Any | None = None) -> Float[Array, " dim"]:
        """Right-hand side d(uk)/dt for the split real/imag Fourier state ``uk``."""
        n_half = self.dimension // 2
        coeffs = uk[:n_half] + 1j * uk[n_half:]

        # Linear term: -u_xxx -> i k^3 u_hat (treated exactly, no dealiasing needed)
        linear_term = 1j * self.ks**3 * coeffs

        # 2/3 rule: band-limit the input to k < (2/3) k_max. The product u^2 then has
        # modes below (4/3) k_max, whose aliases fold onto |k| > (2/3) k_max, i.e. only
        # onto modes discarded by the output mask, so the retained modes are alias-free.
        dealias = self.ks < 2 / 3 * jnp.max(self.ks)

        # Nonlinear term: 3 * (u^2)_x
        u = jnp.fft.irfft(coeffs * dealias, n=self.N)
        nonlinear_term = 3j * self.ks * jnp.fft.rfft(u**2, n=self.N)
        nonlinear_term = nonlinear_term * dealias

        flow = linear_term + nonlinear_term
        return jnp.concatenate([jnp.real(flow), jnp.imag(flow)])

In [None]:
solver = KortewegDeVriesDealiasedSolver(N=256, bounds=(0, 50))

# Same cosine initial condition as the aliased run, for a like-for-like comparison.
# NOTE(review): cos(2.5*pi*x/L) equals 1 at x = 0 but 0 at x = L, so this IC is not
# periodic and has a jump across the boundary — confirm this is intended for the demo.
ic = jnp.cos(2.5 * pi * solver.domain / solver.L)

# Round-trip through the split Fourier representation as a sanity check
ic_k = solver.to_fourier(ic)
ic_recon = solver.to_spatial(ic_k)

plt.title("Initial Condition")
# Plot against the physical grid so the horizontal axis matches the "$x$" label
plt.plot(solver.domain, ic, label="original")
plt.plot(solver.domain, ic_recon, label="reconstructed")
plt.ylabel("$u(x, 0)$")
plt.xlabel("$x$")
plt.legend();

In [None]:
# Integrate the dealiased KdV system over t in [0, 10], saving 100 equispaced snapshots
tspan = (0, 10)
ts, uks = solver.integrate(ic_k, tspspan if False else tspan, num_save_pts=100) if False else solver.integrate(ic_k, tspan, num_save_pts=100)

# Back to physical space: one row per saved time
us = solver.to_spatial(uks)

In [None]:
# Space-time heatmap of the dealiased KdV trajectory; titled distinctly from the aliased run
plot_solution(us, tspan, solver.bounds, title="Korteweg-De Vries Wavefront (Dealiased)")

In [None]:
# Animate the dealiased KdV wavefront
plot_wavefront_1d(ts, xs=solver.domain, u=us, bounds=solver.bounds)