In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from dataclasses import dataclass
from math import pi
from typing import Any

import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from celluloid import Camera
from IPython.display import HTML
from jax import Array
from jaxtyping import Float

# Burgers Equation

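We use the notation $u_x$ as shorthand for $\frac{\partial u}{\partial x}$ (and likewise $u_t$ for $\frac{\partial u}{\partial t}$). Also, $u_{xx} = \frac{\partial^2 u}{\partial x^2}$, and so on. Burgers' equation is a 1+1-D (one time dimension plus one spatial dimension) nonlinear partial differential equation (PDE):
$$u_t + uu_x = \nu u_{xx}$$
where $\nu \ge 0$ is the _viscosity_ of the system. For $\nu > 0$ the equation is parabolic (the diffusion term smooths steep gradients); for $\nu = 0$ it reduces to the hyperbolic inviscid Burgers equation, whose solutions can develop shocks.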


We are going to _simulate_ this system by first setting some ground rules:
1. We'll consider the dynamics of this PDE on a spatial domain of length $L$, written here as $x \in [0, L]$ (the Burgers example below uses $x \in [-1, 1]$), over an arbitrary time span $t \in [0, T]$.
2. We'll use periodic boundary conditions, i.e., $u(0, t) = u(L, t),\; \forall t \in [0, T]$.

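We will simulate this system using a pseudospectral method: $u$ is represented by its Fourier coefficients, spatial derivatives become multiplications by $ik$ in Fourier space, the nonlinear product $uu_x$ is evaluated pointwise on the grid, and the resulting system of ODEs for the Fourier coefficients is integrated in time with `diffrax`.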

In [None]:
@dataclass
class PsuedoSpectralSolver1D:
    """Base class for 1-D pseudospectral solvers on a periodic domain.

    The solution state is the real FFT (``rfft``) of ``u`` sampled on ``N``
    equispaced grid points, flattened as ``[real parts, imaginary parts]`` so it
    is a real vector that ``diffrax`` can integrate. Subclasses implement
    ``__call__(t, uk, args)`` returning the time derivative of that vector.

    :param N: number of grid points in the spatial domain
    :param bounds: (left, right) edges of the periodic domain; the right edge is
        not part of the grid because it is identified with the left edge
    """

    N: int
    bounds: tuple[float, float] = (-1, 1)

    @property
    def L(self) -> float:
        """Length of the periodic domain."""
        return self.bounds[1] - self.bounds[0]

    @property
    def dimension(self) -> int:
        """Length of the real/imag split state vector.

        ``rfft`` of ``N`` samples yields ``N // 2 + 1`` complex coefficients, so the
        split state holds twice that many reals (``N + 2`` for even ``N``, ``N + 1``
        for odd ``N``).
        """
        return 2 * (self.N // 2 + 1)

    def __post_init__(self) -> None:
        # Angular wavenumbers k = 2*pi*m/L in the same order as the rfft output.
        self.ks = 2 * pi * jnp.fft.rfftfreq(self.N, self.L / self.N)
        # Periodic grid: the right boundary coincides with the left one, so it is excluded.
        self.domain = jnp.linspace(*self.bounds, self.N, endpoint=False)

    def __call__(self, t: float, uk: Float[Array, " dim"], args: Any | None = None) -> Float[Array, " dim"]:
        """Right-hand side d(uk)/dt of the spectral ODE system; subclasses must override this."""
        raise NotImplementedError(f"{type(self).__name__} must implement __call__(t, uk, args)")

    def to_fourier(self, u: Float[Array, " N"], N: int | None = None) -> Float[Array, " dimension"]:
        """FFT of a real signal ``u``, returned in real/imag split format.

        :param u: input signal; the FFT is taken along the last axis
        :param N: FFT length; the signal is cropped or zero-padded to this length
            (defaults to ``self.N``)

        :return: Fourier coefficients ``[real parts, imag parts]`` with last-axis length
            ``2 * (N // 2 + 1)`` (equal to ``self.dimension`` when ``N == self.N``)
        """
        if N is None:
            N = self.N

        # shape: (..., N//2 + 1)
        uk = jnp.fft.rfft(u, n=N)

        return jnp.concatenate([uk.real, uk.imag], axis=-1)

    def to_spatial(self, uk: Float[Array, " dimension"], N: int | None = None) -> Float[Array, " N"]:
        """Inverse FFT of split Fourier coefficients to get u(x) on a grid.

        :param uk: flattened Fourier coefficients ``[real parts, imag parts]`` along the
            last axis; may have leading batch dimensions (e.g. one row per saved time)
        :param N: grid resolution of the returned signal (defaults to ``self.N``)

        :return: solution in the spatial domain, last-axis length ``N``
        """
        if N is None:
            N = self.N

        # Split at the midpoint of the actual input (not self.dimension) so coefficients
        # produced at a different resolution via to_fourier(u, N=M) are paired correctly.
        n_modes = uk.shape[-1] // 2
        coeffs = uk[..., :n_modes] + 1j * uk[..., n_modes:]

        return jnp.fft.irfft(coeffs, n=N)

    def integrate(
        self,
        init_cond: Float[Array, " dim"],
        tspan: tuple[float, float] | list[float],
        args: Any | None = None,
        num_save_pts: int | None = None,
        method: str = "Tsit5",
        max_dt: float = 1e-3,
        atol: float = 1e-8,
        rtol: float = 1e-8,
        pid: tuple[float, float, float] = (0.3, 0.3, 0.0),
        max_steps: int | None = None,
    ) -> tuple[Float[Array, " T"], Float[Array, "T dim"]]:
        """Integrate the dynamical system with diffrax and return the saved trajectory.

        :param init_cond: initial state in the real/imag split Fourier format used by ``__call__``
        :param tspan: integration interval (t_init, t_final)
        :param args: optional extra arguments forwarded to ``__call__``
        :param num_save_pts: number of equispaced save times over ``tspan`` (endpoints included);
            if None, only the final state is saved and no progress bar is shown
        :param method: name of a diffrax solver class, e.g. "Tsit5"
        :param max_dt: initial step size and upper bound on the adaptive step size
        :param atol: absolute error tolerance of the adaptive step-size controller
        :param rtol: relative error tolerance of the adaptive step-size controller
        :param pid: (proportional, integral, derivative) coefficients of the PID step-size controller
        :param max_steps: forwarded to ``diffrax.diffeqsolve`` (None = no step limit)

        :returns: tuple of save times and solution states (one row per save time)
        """
        t0, t1 = tspan
        if num_save_pts is None:
            save_pts = diffrax.SaveAt(t1=True)
            progress_bar = diffrax.NoProgressMeter()
        else:
            save_pts = diffrax.SaveAt(ts=jnp.linspace(t0, t1, num_save_pts))
            progress_bar = diffrax.TqdmProgressMeter(num_save_pts)

        stepsize_controller = diffrax.PIDController(
            rtol=rtol,  # relative error tolerance of solution
            atol=atol,  # absolute error tolerance of solution
            pcoeff=pid[0],  # proportional strength for PID stepsize controller
            icoeff=pid[1],  # integral strength for PID stepsize controller
            dcoeff=pid[2],  # derivative strength for PID stepsize controller
            dtmax=max_dt,  # max step size
        )
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(self.__call__),
            getattr(diffrax, method)(),
            t0,
            t1,
            y0=init_cond,
            args=args,
            dt0=max_dt,
            saveat=save_pts,
            max_steps=max_steps,
            stepsize_controller=stepsize_controller,
            progress_meter=progress_bar,
        )
        return jnp.asarray(sol.ts), jnp.asarray(sol.ys)


In [None]:
@dataclass
class BurgersSolver(PsuedoSpectralSolver1D):
    """Pseudospectral right-hand side for the viscous Burgers equation.

    Evolves u_t = -u u_x + nu u_xx on the periodic grid of the base class. The
    state is the real/imag split rfft of u produced by ``to_fourier``.
    """

    nu: float = 0.1  # viscosity coefficient

    def __call__(self, t: float, uk: Float[Array, " dim"], args: Any | None = None) -> Float[Array, " dim"]:
        """Time derivative of the split Fourier state (``t`` and ``args`` are unused)."""
        half = self.dimension // 2
        u_hat = uk[:half] + 1j * uk[half:]

        # Viscous term nu * u_xx is diagonal in Fourier space: d²/dx² -> -k².
        viscous_hat = -self.nu * self.ks**2 * u_hat

        # Advection -u * u_x: differentiate spectrally (d/dx -> i k), multiply
        # pointwise on the grid, then transform the product back.
        ik = 1j * self.ks
        u_grid = jnp.fft.irfft(u_hat, n=self.N)
        du_dx_grid = jnp.fft.irfft(ik * u_hat, n=self.N)
        advection_hat = jnp.fft.rfft(-u_grid * du_dx_grid)

        rhs = viscous_hat + advection_hat
        return jnp.concatenate([rhs.real, rhs.imag])


In [None]:
solver = BurgersSolver(N=256, nu=1e-3)

# Width parameter of the envelope: u0(x) = -x * exp(-x**2 / twosigma)
twosigma = 0.1
ic = -solver.domain * jnp.exp(-(solver.domain**2) / twosigma)

ic_k = solver.to_fourier(ic)
ic_recon = solver.to_spatial(ic_k)
# The rfft -> irfft round trip must reproduce the grid values up to float round-off.
assert jnp.allclose(ic, ic_recon, atol=1e-5), "Fourier round trip does not reproduce the IC"

fig, ax = plt.subplots()
ax.plot(solver.domain, ic, label="original")
ax.plot(solver.domain, ic_recon, "--", label="reconstruction")
ax.set_title("Initial Condition")
ax.set_xlabel("x")
ax.set_ylabel("u(x, 0)")
ax.legend();

In [None]:
# Integrate the system and map the saved Fourier states back to the grid
tspan = [0, 5]
ts, uks = solver.integrate(ic_k, tspan, num_save_pts=100)

u = solver.to_spatial(uks)  # one row of grid values per saved time
assert u.shape == (ts.shape[0], solver.N)
assert jnp.isfinite(u).all(), "integration produced non-finite values"
u.shape

In [None]:
# Space-time heat map: time on the horizontal axis, position on the vertical axis.
extent = [tspan[0], tspan[1], solver.bounds[0], solver.bounds[1]]
fig, ax = plt.subplots(figsize=(10, 5))
heatmap = ax.imshow(u.T, aspect="auto", origin="lower", extent=extent, cmap="magma")
ax.set(xlabel="t", ylabel="x", title="Burgers Equation Wave Front")
fig.colorbar(heatmap, ax=ax, label="u");

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
camera = Camera(fig)

# Axis limits and labels persist across frames, so set them once.
pad = u.std()
ax.set_ylim(u.min() - pad, u.max() + pad)
ax.set_xlabel("x")
ax.set_ylabel("u(x, t)")

for t_frame, u_frame in zip(ts, u):
    ax.plot(solver.domain, u_frame, color="r")
    # celluloid records only the artists created since the previous snap, so the
    # reference line and the time stamp are redrawn every frame. ax.set_title would
    # not animate: the title is not a per-frame artist and shows only its last value.
    ax.axvline(color="k", linestyle="-.")
    ax.text(0.5, 1.01, f"Wavefront at t={t_frame:.2f}", transform=ax.transAxes, ha="center")
    camera.snap()

animation = camera.animate(interval=50)
plt.close(fig)
HTML(animation.to_jshtml())


In [None]:
@dataclass
class KuramotoShivashinskySolver(PsuedoSpectralSolver1D):
    """Pseudospectral right-hand side for the Kuramoto-Sivashinsky equation.

    Evolves the standard form u_t = -u u_x - u_xx - u_xxxx, written with the
    nonlinearity in conservative form as u_t = -(u**2 / 2)_x - u_xx - u_xxxx, on
    the periodic grid of the base class. The state is the real/imag split rfft
    of u produced by ``to_fourier``.
    """

    def __call__(self, t: float, uk: Float[Array, " dim"], args: Any | None = None) -> Float[Array, " dim"]:
        """Time derivative of the split Fourier state (``t`` and ``args`` are unused)."""
        n_half = self.dimension // 2
        coeffs = uk[:n_half] + 1j * uk[n_half:]

        # Linear term -u_xx - u_xxxx -> (k² - k⁴) û: modes with 0 < k < 1 grow,
        # higher wavenumbers are damped by the fourth-order term.
        linear_term = -(self.ks**4 - self.ks**2) * coeffs

        # Nonlinear term -u u_x = -d/dx(0.5 * u**2): the square is formed on the grid,
        # then differentiated spectrally (d/dx -> i k). The minus sign matches the
        # standard KS form and the -u u_x convention of the Burgers solver.
        u = jnp.fft.irfft(coeffs, n=self.N)
        nonlinear_term = -1j * self.ks * jnp.fft.rfft(0.5 * u**2, n=self.N)

        flow = linear_term + nonlinear_term
        return jnp.concatenate([jnp.real(flow), jnp.imag(flow)])


In [None]:
solver = KuramotoShivashinskySolver(N=256, bounds=(0, 100))

# Random Fourier series as initial condition. The spectral solver assumes a periodic
# domain, so every component must be an integer harmonic of 2*pi/L; arbitrary
# frequencies or a non-periodic envelope would put a jump at the x = L -> 0 seam
# and produce Gibbs ringing in the spectrum.
num_sines = 5
max_wavenumber = 2.0  # upper bound on the angular wavenumber of each component
max_mode = int(max_wavenumber * solver.L / (2 * pi))
rng = jax.random.key(0)
amp_rng, mode_rng = jax.random.split(rng, 2)
amplitudes = jax.random.normal(amp_rng, shape=(num_sines,))
modes = jax.random.randint(mode_rng, shape=(num_sines,), minval=1, maxval=max_mode + 1)
wavenumbers = 2 * pi * modes / solver.L
ic = amplitudes @ jnp.sin(wavenumbers[:, jnp.newaxis] * solver.domain[jnp.newaxis, :])

ic_k = solver.to_fourier(ic)
ic_recon = solver.to_spatial(ic_k)
assert jnp.allclose(ic, ic_recon, atol=1e-4), "Fourier round trip does not reproduce the IC"

fig, ax = plt.subplots()
ax.plot(solver.domain, ic, label="original")
ax.plot(solver.domain, ic_recon, "--", label="reconstruction")
ax.set(title="KS initial condition", xlabel="x", ylabel="u(x, 0)")
ax.legend();

In [None]:
# Long run over t in [0, 100], saved at 300 equispaced frames
ts, uks = solver.integrate(
    init_cond=ic_k,
    tspan=(0, 100),
    num_save_pts=300,
)
us = solver.to_spatial(uks)  # back to grid values, one row per saved time

In [None]:
# Take the time axis from this run's saved times rather than a shared `tspan` variable.
extent = [float(ts[0]), float(ts[-1]), solver.bounds[0], solver.bounds[1]]
plt.figure(figsize=(10, 5))
plt.imshow(us.T, aspect="auto", origin="lower", extent=extent, cmap="magma")
plt.xlabel("t")
plt.ylabel("x")
plt.title("KS Equation")
plt.colorbar(label="u");

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
camera = Camera(fig)

# Axis limits and labels persist across frames, so set them once.
pad = us.std()
ax.set_ylim(us.min() - pad, us.max() + pad)
ax.set_xlabel("x")
ax.set_ylabel("u(x, t)")

for t_frame, u_frame in zip(ts, us):
    ax.plot(solver.domain, u_frame, color="r")
    # A text artist is captured per frame by celluloid (an axes title is not).
    ax.text(0.5, 1.01, f"KS solution at t={t_frame:.2f}", transform=ax.transAxes, ha="center")
    camera.snap()

animation = camera.animate(interval=50)
plt.close(fig)
HTML(animation.to_jshtml())