# In [28]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as random
import optimistix as optx
from functools import partial
import equinox as eqx
key = random.key(0)
import argparse

#################PARAMETERS####################

η = 6.0 # coefficient of the nonlinear term (f below evolves u_t = -η u u_x - γ^2 u_xxx)
γ = 1.0 # coefficient of the dispersive term
P = 20 # period (and end of the domain)
M = 511 # number of spatial points: M+1 equally spaced points on [x0, x0+P], minus the last one (periodic copy of x0)
N = 511 # number of time steps, i.e. N+1 time points

t0 = 0.0 # initial time
t_final = 2.0 # end time

dt = (t_final - t0) / N # time step; equals the spacing of the time grid t below for any t0
dx = P / M # space step

x0 = 0.0 # initial position
x_final = x0 + P - dx # final position (x0 + P is excluded: it is the periodic image of x0)

x = jnp.linspace(x0, x_final, M) # domain: M points with spacing dx
t = jnp.linspace(t0, t_final, N+1) # time domain: N+1 points with spacing dt

args = {"η" : η, "γ": γ, "dx" : dx}

###############INITIAL CONDITION####################

def sech(x):
    """Hyperbolic secant sech(x) = 1 / cosh(x), evaluated elementwise with jax.numpy."""
    cosh_x = jnp.cosh(x)
    return 1/cosh_x

def initial_condition_kdv(x, key, η=6., P=20):
    """
    Sample a random, P-periodic two-pulse initial condition for the KdV equation.

    Two pulses of the form (6/η) * 2 c^2 sech^2(c s) are superposed. Here c is
    drawn uniformly from [0.5, 2). s is the distance from the pulse centre
    P*d, with d drawn uniformly from [0, 1), wrapped into [-P/2, P/2) so that
    each pulse is periodic with period P.

    Parameters:
        x (float or array-like): A single point or array in the spatial domain.
        key (jax.random.PRNGKey): Random key from which c and d are drawn.
        η (float, optional): Coefficient of the nonlinear term; scales the amplitude as 6/η. Default is 6.
        P (float, optional): The period of the spatial domain. Default is 20.
    Returns:
        array-like: The initial condition evaluated at x (same shape as x).
    """
    key_width, key_shift = random.split(key, 2)
    c1, c2 = random.uniform(key_width, minval=0.5, maxval=2, shape=(2,))
    d1, d2 = random.uniform(key_shift, minval=0, maxval=1, shape=(2,))

    # 6/η rescales the η = 6 pulse amplitude to the chosen nonlinearity.
    scale = (6. / η) * 2

    def pulse(c, d):
        # Signed distance to the centre P*d, wrapped into [-P/2, P/2).
        offset = (x + P/2 - P*d) % P - P/2
        return scale * c**2 * sech(c * offset)**2

    return pulse(c1, d1) + pulse(c2, d2)

#############SPATIAL DERIVATIVES####################
def Dx(y, dx, order=6, axis=-1):
    """Central finite-difference approximation of the first derivative dy/dx.

    Assumes periodic boundary conditions: the stencil wraps around along
    ``axis``.

    Args:
        y: Samples on a uniform periodic grid. Any rank is accepted; the grid
            runs along ``axis``.
        dx: Grid spacing.
        order: Accuracy order of the stencil: 2, 4 or 6.
        axis: Axis that holds the spatial grid (default: last axis). The roll
            is done along this axis explicitly. ``jnp.roll`` without an axis
            flattens the array, so for batched input the wrap-around would
            leak from one row into the next.

    Returns:
        Array of the same shape as ``y``.

    Raises:
        ValueError: If ``order`` is not 2, 4 or 6.
    """
    if order not in (2, 4, 6):
        raise ValueError("Only 2nd, 4th and 6th order accurate first derivatives are implemented")

    def shifted(k):
        # shifted(k)[..., i] == y[..., (i + k) mod n] along `axis`
        return jnp.roll(y, shift=-k, axis=axis)

    y_p_1, y_m_1 = shifted(1), shifted(-1)
    if order == 2:
        return (y_p_1 - y_m_1) / (2 * dx)
    y_p_2, y_m_2 = shifted(2), shifted(-2)
    if order == 4:
        return (-y_p_2 + 8*y_p_1 - 8*y_m_1 + y_m_2)/(12*dx)
    y_p_3, y_m_3 = shifted(3), shifted(-3)
    return (y_p_3 - 9*y_p_2 + 45*y_p_1 - 45*y_m_1 + 9*y_m_2 - y_m_3)/(60*dx)
        

def Dxx(y, dx, order=6, axis=-1):
    """Central finite-difference approximation of the second derivative d^2y/dx^2.

    Assumes periodic boundary conditions: the stencil wraps around along
    ``axis``.

    Args:
        y: Samples on a uniform periodic grid. Any rank is accepted; the grid
            runs along ``axis``.
        dx: Grid spacing.
        order: Accuracy order of the stencil: 2, 4 or 6.
        axis: Axis that holds the spatial grid (default: last axis). The roll
            is done along this axis explicitly. ``jnp.roll`` without an axis
            flattens the array, so for batched input the wrap-around would
            leak from one row into the next.

    Returns:
        Array of the same shape as ``y``.

    Raises:
        ValueError: If ``order`` is not 2, 4 or 6.
    """
    if order not in (2, 4, 6):
        raise ValueError("Only 2nd, 4th and 6th order accurate second derivatives are implemented")

    def shifted(k):
        # shifted(k)[..., i] == y[..., (i + k) mod n] along `axis`
        return jnp.roll(y, shift=-k, axis=axis)

    y_p_1, y_m_1 = shifted(1), shifted(-1)
    if order == 2:
        return (y_p_1 - 2 * y + y_m_1) / dx**2
    y_p_2, y_m_2 = shifted(2), shifted(-2)
    if order == 4:
        return (-y_p_2+16*y_p_1-30*y+16*y_m_1-y_m_2)/(12*dx**2)
    y_p_3, y_m_3 = shifted(3), shifted(-3)
    return (270*y_m_1 - 27*y_m_2 + 2*y_m_3 + 270*y_p_1 - 27*y_p_2 + 2*y_p_3 - 490*y) / (180*dx**2)

#############SOLVERS####################

@partial(jax.jit, static_argnums=(0,))
def implicit_midpoint(f, u0, dt, t, args, rtol, atol):
    """
    Integrate du/dt = f(t, u, args) with the implicit midpoint rule.

    Each step solves  u_{n+1} = u_n + dt * f(t_n + dt/2, (u_n + u_{n+1})/2, args)
    for u_{n+1} as a fixed-point problem with an optimistix Newton solver.
    The solver starts from an explicit Euler prediction.

    Args:
      f: Right-hand side f(t, u, args); static under jit.
      u0: Initial state.
      dt: Time step.
      t: Times t_n scanned over (one step per entry).
      args: Extra arguments forwarded to f.
      rtol: Relative tolerance for the Newton solver.
      atol: Absolute tolerance for the Newton solver.

    Returns:
      Array whose k-th row is the state after k steps (nominally at t[k]);
      row 0 is u0.
    """
    def advance(state, time):
        u_curr, step_size = state

        def midpoint_map(u_new, params):
            # The next state is a fixed point of this map: u_new == midpoint_map(u_new).
            return u_curr + step_size * f(time+0.5*step_size, 0.5*(u_curr+u_new), params)

        # Explicit Euler prediction, used as the solver's starting point.
        predictor = u_curr + step_size * f(time, u_curr, args)

        newton = optx.Newton(rtol, atol)
        u_new = optx.fixed_point(midpoint_map, newton, predictor, args).value
        # Emit the state at the *start* of the step, so row 0 is u0.
        return (u_new, step_size), u_curr

    _, trajectory = jax.lax.scan(advance, (u0, dt), t)
    return trajectory


@partial(jax.jit, static_argnums=(0,))
def gauss_legendre_4(f, u0, dt, t, args, rtol, atol):
    """
    Integrates the ODE system using the 2-stage Gauss-Legendre method of order 4.
    Implementation follows "IV.8 Implementation of Implicit Runge-Kutta Methods" in 
    "Solving Ordinary Differential Equations II" by Hairer and Wanner.
    The solver works with the stage increments z_i = Y_i - u_n and updates
    u_{n+1} = u_n + sum_i d_i z_i with d = b^T A^{-1}. The starting guess for
    the next step is extrapolated from the collocation polynomial.

    Args:
      f: The right-hand side function of the ODE system, f(t, u, args).
      u0: Initial condition.
      dt: Time step.
      t: Array of time points (one step per entry).
      args: Additional arguments to pass to f.
      rtol: Relative tolerance for the nonlinear (Newton) solver.
      atol: Absolute tolerance for the nonlinear (Newton) solver. The solver
        runs with throw=False and at most 50 steps, so a non-converged
        iterate is accepted silently.

    Returns:
      Array whose k-th row is the state after k steps (nominally at t[k]);
      row 0 is u0.
    """
    sqrt3 = jnp.sqrt(3)
    c = jnp.array([0.5 - sqrt3/6, 0.5 + sqrt3/6])
    A = jnp.array([[0.25, 0.25 - sqrt3/6],
                   [0.25 + sqrt3/6, 0.25]])
    # d = b^T A^{-1} with b = (1/2, 1/2)
    d = jnp.array([-sqrt3, sqrt3])

    def step(carry, tn):
        un, z_guess = carry

        def eq(z, args):
            # Stage equations z_i = dt * sum_j A_ij f(tn + c_j dt, un + z_j, args).
            # The stage increments are a fixed point of this map. Each stage
            # derivative is evaluated once and shared by both rows.
            f1 = f(tn + c[0]*dt, un + z[0], args)
            f2 = f(tn + c[1]*dt, un + z[1], args)
            z1 = dt*(A[0,0]*f1 + A[0,1]*f2)
            z2 = dt*(A[1,0]*f1 + A[1,1]*f2)
            return jnp.array([z1, z2])

        solver = optx.Newton(rtol, atol)
        z_next = optx.fixed_point(eq, solver, z_guess, args, throw=False, max_steps=50).value
        u_next = un + jnp.dot(d, z_next)

        def q(x):
            # Quadratic interpolant with q(0) = 0 and q(c_i) = z_next[i].
            return z_next[0]*(x-c[1])/(c[0]-c[1])*x/c[0] + z_next[1]*(x-c[0])/(c[1]-c[0])*x/c[1]

        # Starting guess for the next step: extrapolate to 1 + c_i and
        # re-reference the increments to u_next.
        z_guess = jnp.array([q(1+c[0])+un-u_next, q(1+c[1])+un-u_next])
        return (u_next, z_guess), un

    # Use u0's dtype so the scan carry keeps a consistent dtype.
    z_guess = jnp.zeros((2, u0.shape[0]), dtype=u0.dtype)
    _, u_arr = jax.lax.scan(step, (u0, z_guess), t)
    return u_arr

@partial(jax.jit, static_argnums=(0,))
def gauss_legendre_6(f, u0, dt, t, args, rtol, atol):
    """
    Integrates the ODE system using the 3-stage Gauss-Legendre method of order 6.
    Implementation follows "IV.8 Implementation of Implicit Runge-Kutta Methods" in 
    "Solving Ordinary Differential Equations II" by Hairer and Wanner.
    The solver works with the stage increments z_i = Y_i - u_n and updates
    u_{n+1} = u_n + sum_i d_i z_i with d = b^T A^{-1}. The starting guess for
    the next step is extrapolated from the collocation polynomial.

    Args:
      f: The right-hand side function of the ODE system, f(t, u, args), written
        for a single state vector u.
      u0: Initial condition (1-D).
      dt: Time step.
      t: Array of time points (one step per entry).
      args: Additional arguments to pass to f.
      rtol: Relative tolerance for the nonlinear (chord) solver.
      atol: Absolute tolerance for the nonlinear (chord) solver. The solver
        runs with throw=False and at most 10 steps, so a non-converged
        iterate is accepted silently.

    Returns:
      The trajectory subsampled by a factor 4 in both time and space,
      i.e. u[::4, ::4]. Row k of the full trajectory u is the state after
      k steps (nominally at t[k]), and row 0 is u0.
    """
    sqrt15 = jnp.sqrt(15)
    c = jnp.array([0.5 - sqrt15/10, 0.5, 0.5 + sqrt15/10])
    A = jnp.array([[5/36, 2/9-sqrt15/15, 5/36-1/30*sqrt15],
                   [5/36+1/24*sqrt15, 2/9, 5/36-1/24*sqrt15],
                   [5/36+1/30*sqrt15, 2/9+sqrt15/15, 5/36]])
    # d = b^T A^{-1} for the 3-stage Gauss method
    d = jnp.array([5/3, -4/3, 5/3])

    # Evaluate f one stage at a time. f expects a single state vector, and
    # vmapping over the stage axis guarantees each stage is treated as an
    # independent periodic grid. This holds whatever f does with a stacked
    # (3, M) array; e.g. jnp.roll without an axis would wrap across stages.
    stage_rhs = jax.vmap(f, in_axes=(0, 0, None))

    def q(x, z_next):
        # Cubic Lagrange interpolant with q(0) = 0 and q(c_i) = z_next[i].
        z_guess = z_next[0]*(x-c[1])/(c[0]-c[1])*x/c[0]*(x-c[2])/(c[0]-c[2])
        z_guess += z_next[1]*(x-c[0])/(c[1]-c[0])*x/c[1]*(x-c[2])/(c[1]-c[2])
        z_guess += z_next[2]*(x-c[0])/(c[2]-c[0])*(x-c[1])/(c[2]-c[1])*x/c[2]
        return z_guess

    def step(carry, tn):
        un, z_guess = carry

        def eq(z, args):
            # Residual of the stage equations z = dt * A f(tn + c dt, un + z).
            return z - A @ (dt*stage_rhs(tn + c*dt, un + z, args))

        solver = optx.Chord(rtol, atol)
        z_next = optx.root_find(eq, solver, z_guess, args, throw=False, max_steps=10).value
        u_next = un + jnp.dot(d, z_next)

        # Starting guess for the next step: extrapolate to 1 + c_i and
        # re-reference the increments to u_next.
        z_guess = q(1+c[:,None], z_next)+un-u_next
        return (u_next, z_guess), un

    # Use u0's dtype so the scan carry keeps a consistent dtype.
    z_guess = jnp.zeros((3, u0.shape[0]), dtype=u0.dtype)
    _, u_arr = jax.lax.scan(step, (u0, z_guess), t)
    # Subsample the stored trajectory by 4 in time and space.
    return u_arr[::4, ::4]

#############EQUATION####################
def f(t, u, args):
    """
    Right-hand side of the semi-discretised KdV equation in conservation form,

        u_t = d/dx( -η/2 u^2 - γ^2 u_xx ),

    using the default periodic stencils of Dx and Dxx.

    Args:
        t: Time (unused; present to match the f(t, u, args) solver signature).
        u: State on the periodic spatial grid.
        args: Dict with keys "η", "γ" and "dx".
    """
    eta, gamma, dx = (args[name] for name in ("η", "γ", "dx"))
    flux = -eta/2*u**2 - gamma**2 * Dxx(u, dx)
    return Dx(flux, dx)

#############HAMILTONIANS################
def H_energy(u, args):
    """
    Discrete energy (Hamiltonian) of the KdV state u.

    Approximates  H[u] = ∫ ( -η/6 u^3 + γ^2/2 u_x^2 ) dx  as dx times the sum
    over the periodic grid. u_x is taken from Dx with its default stencil.

    Args:
        u: State on the periodic spatial grid.
        args: Dict with keys "η", "γ" and "dx".
    """
    eta = args["η"]
    gamma = args["γ"]
    dx = args["dx"]
    u_x = Dx(u, dx)
    density = -eta/6*u**3 + 0.5*gamma**2*u_x**2
    return dx*jnp.sum(density)

def H_mass(u, args):
    """
    Discrete mass  ∫ u dx  of a state on the periodic grid.

    Computed as dx times the sum of all entries of u. The grid spacing is
    read from args["dx"], like H_energy does, rather than from the
    module-level dx. This keeps the result consistent with the parameters
    passed in.

    Args:
        u: State on the periodic spatial grid.
        args: Parameter dict containing the grid spacing under "dx".
    """
    return args["dx"] * jnp.sum(u)

def H_momentum(u, args):
    """
    Discrete quadratic invariant  ∫ u^2 dx  of a state on the periodic grid.

    Computed as dx times the sum of u**2. The grid spacing is read from
    args["dx"], like H_energy does, rather than from the module-level dx.

    Args:
        u: State on the periodic spatial grid.
        args: Parameter dict containing the grid spacing under "dx".
    """
    return args["dx"] * jnp.sum(u**2)

# In [29]:
#############SOLVING####################
# Draw NUM_SAMPLES random initial conditions on the grid x and integrate each
# one with the 6th-order Gauss-Legendre solver.
NUM_SAMPLES = 100
atol, rtol = 1e-12, 1e-12

keys = random.split(random.key(1), NUM_SAMPLES)

# The spatial grid is shared; only the random key varies across the batch.
a = jax.vmap(initial_condition_kdv, in_axes=(None, 0))(x, keys)

# Batch over the initial conditions (argument 1). The right-hand side, time
# step, time grid, parameters and tolerances are shared by every sample.
solver_in_axes = (None, 0, None, None, None, None, None)
data = jax.vmap(gauss_legendre_6, in_axes=solver_in_axes)(f, a, dt, t, args, rtol, atol)