# The KdV equation

In this notebook, we consider the one-dimensional forced Korteweg–de Vries (KdV) equation for waves on shallow water. It is given by
\begin{equation}
u_t + \eta u u_x + \gamma^2 u_{xxx} = g(x,t),
\end{equation}
where $\eta, \gamma \in \mathbb{R}$. If $g(x,t) = 0$, we have the standard unforced KdV equation.

For the standard KdV equation, the energy given by
$$
\begin{align*}
\mathcal{H}[u] &= \int_\mathbb{R} \left(-\frac{\eta}{6} u^3 + \frac{\gamma^2}{2}u_x^2 \right)\, dx
\end{align*}
$$
is conserved, i.e. constant over time.

In this notebook I will use a neural operator network, in an attempt to learn the transformation from initial conditions (at $t=0$) to solutions at a later timepoints, $t>0$.

I will also enfore hard constraints on the network, to make it preserve the Hamiltonian.

## Import libraries

In [47]:
from typing import Callable

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping
import jax.random as random
import optimistix as optx
from functools import partial

jax.config.update("jax_enable_x64", True)
key = random.PRNGKey(0)

from discretization import SpatialDiscretization, SpatioTemporalDiscretization, central_difference_1, central_difference_2

## Parameters

In [48]:
η = 6.0
γ = 1.0
P = 20 # period (and end of the domain)
M = 100 # M+1 equally spaced points in the domain, but we exclude the last one due to periodicity
N = 100 # N+1 time points 

t0 = 0.0 # initial time
t_final = 2.0 # end time

δt = t_final / N # time step
δx = P / M # space step

x0 = 0.0 # initial position
x_final = P-δx # final position (excluding the last point due to periodicity)

x = jnp.linspace(x0, x_final, M) # domain
t = jnp.linspace(t0, t_final, N+1) # time domain

args = {"η" : η, "γ": γ}

## Generate initial conditions

In [49]:
def sech(x): return 1/jnp.cosh(x)

def initial_condition_kdv(x, key, η=6., P=20):
    """
    Generate the initial condition for the Korteweg-de Vries (KdV) equation.
    Parameters:
        x (float or array-like) : A single point or array in the spatial domain. 
        key (jax.random.PRNGKey): The random key for generating random numbers.
        η (float, optional): The coefficient for the KdV equation. Default is 6.
        P (float, optional): The period of the spatial domain. Default is 20.
    Returns:
        array-like: The initial condition for the KdV equation.
    """
    
    key_cs, key_ds = random.split(key, 2)
    c1, c2 = random.uniform(key_cs, minval=0.5, maxval=2, shape=(2,))
    d1, d2 = random.uniform(key_ds, minval=0, maxval=1, shape=(2,))
    
    u0 = (-6./-η)*2 * c1**2 * sech(c1 * ((x+P/2-P*d1) % P - P/2))**2
    u0 += (-6./-η)*2 * c2**2 * sech(c2 * ((x+P/2-P*d2) % P - P/2))**2
    return u0

In [50]:
def newton(f, j, guess, args, tol):
    f_prime = jax.grad(f)
    def q(x):
        return x - f(x) / f_prime(x)

    error = tol + 1
    x = guess
    while error > tol:
        y = q(x)
        error = jnp.linalg.norm(F(un))
        x = y
        
    return x

@partial(jax.jit, static_argnums=(0,))
def implicit_midpoint_step(f, tn, un, dt, args, rtol, atol):
    fn = f(tn, un, args)

    # The update should satisfy y1 = eq(y1), i.e. y1 is a fixed point of fn
    def eq(u, args):
        return un + dt * f(tn+0.5*dt, 0.5*(un+u), args)
    
    u_next_euler = un + dt * fn # Euler step as guess

    solver = optx.Newton(rtol, atol) # Chord is Newtons' method, but reuses the jacobian of the first guess
    u_next = optx.fixed_point(eq, solver, u_next_euler, args).value  # satisfies y1 == fn(y1)
    
    return u_next

@partial(jax.jit, static_argnums=(0,))
def implicit_midpoint(f, u0, t, args, rtol, atol):
    dt = t[1] - t[0]
    @jax.jit
    def scan_body(carry, t):
        u, dt = carry
        u_next = implicit_midpoint_step(f, t, u, dt, args, rtol, atol)
        return (u_next, dt), u
    
    _, u_arr = jax.lax.scan(scan_body, (u0, dt), t)
    return SpatioTemporalDiscretization(x0, x_final, t0, t_final, u_arr.vals)

In [51]:
def vector_field(t, u, args):
    η = 6.0# args["η"]
    γ = 1.0#args["γ"]
    return -1.*central_difference_1(0.5*η*u*u + γ**2 * central_difference_2(u))

In [52]:
u0 = SpatialDiscretization(x0, x_final, initial_condition_kdv(x, random.PRNGKey(0)))

atol, rtol = 1e-12, 1e-12

u_arr = implicit_midpoint(vector_field, u0, t, args, rtol, atol)

In [53]:
%timeit jax.block_until_ready(implicit_midpoint(vector_field, u0, t, args, rtol, atol))

56.2 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [54]:
def H_energy(u : SpatialDiscretization) -> Float:
    η, γ = 6.0, 1.0
    ux = central_difference_1(u)
    integrand = (-η/6*u*u*u + 0.5*γ**2*ux*ux).vals
    return u.δx*jnp.sum(integrand)

In [55]:
def H_mass(u : SpatialDiscretization) -> Float:
    return u.δx * jnp.sum(u.vals)

def H_momentum(u : SpatialDiscretization) -> Float:
    return u.δx * jnp.sum((u*u).vals)

In [56]:
u_arr

SpatioTemporalDiscretization(
  x0=0.0,
  x_final=19.8,
  t0=0.0,
  t_final=2.0,
  vals=f64[101,100]
)

In [57]:
#Hs = jax.vmap(H_energy)(u_arr)

#fig = plt.figure(figsize=(7,4))
#lw = 2
#plt.plot(t, jnp.abs(Hs-Hs[0]), 'k', linewidth=lw)
#plt.xlabel(r'$t$', fontsize=12)
#plt.ylabel(r'$|\mathcal{H}(t) - \mathcal{H}(t_0)|$', fontsize=12)
#plt.title('Energy error', fontsize=14)
#plt.show()

## Solve many KdVs at once!

We want to store the solutions in an array with shape (Samples,) where each element is a SpatioTemporalDiscretization

In [58]:
u0 = SpatialDiscretization.discretize_fn(x0, x_final, M, lambda x: initial_condition_kdv(x, random.PRNGKey(0)))

atol, rtol = 1e-12, 1e-12

u_sol = implicit_midpoint(vector_field, u0, t, args, rtol, atol)


In [59]:
num_samples = 8000

keys = random.split(random.PRNGKey(0), num_samples)

initial_conditions = jax.vmap(jax.vmap(initial_condition_kdv, [0, None]), [None, 0])(x, keys)

In [61]:
solutions = [implicit_midpoint(vector_field, SpatialDiscretization(x0, x_final, ic), t, args, rtol, atol) for ic in initial_conditions]

KeyboardInterrupt: 

In [22]:
#eqx.tree_serialise_leaves("kdv_solutions.eqx", solutions)

In [23]:
#solutions[0].vals.shape

(101, 100)

In [24]:
#import json

In [26]:
"""
all_args = {"η": η, "γ": γ, "M": M, "N": N, "t0": t0, "t_final": t_final, "x0": x0, "x_final": x_final, "num_samples": num_samples}

def save(filename, all_args, data):
    with open(filename, "wb") as f:
        hyperparam_str = json.dumps(all_args)
        f.write((hyperparam_str + "\n").encode())
        eqx.tree_serialise_leaves(f, data)


save("kdv_solutions.eqx", all_args, solutions)
"""

In [None]:
#jnp.save("kdv_solutions.npy", solutions)

In [None]:
#solutions.shape