In [None]:
from IPython.display import HTML
css_file = './custom.css'
with open(css_file, "r") as f:
    styles = f.read()
HTML(styles)

###### Content provided under a Creative Commons Attribution license, CC-BY 4.0; code under MIT License. (c)2015-2023 [David I. Ketcheson](http://davidketcheson.info)

##### Version 0.4 - March 2023

In [None]:
import numpy as np
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
from IPython.display import HTML
font = {'size'   : 15}
matplotlib.rc('font', **font)

fft = np.fft.fft
ifft = np.fft.ifft

# Time discretization for pseudospectral methods

In previous notebooks we have touched briefly on the topic of time discretization.  So far, we have used explicit Runge-Kutta methods and a simple heuristic for determining a stable time step size.

## Stiff semilinear problems

Many important applications of pseudospectral methods involve evolution PDEs of the form

$$
    u_t = f(u) + L(u)
$$

where $f$ is a non-stiff, nonlinear operator and $L$ is a stiff, linear operator.  Most often, $f$ involves at most first-order derivatives while $L$ involves higher-order derivatives.  An example that we have already dealt with is the KdV equation

$$
    u_t + uu_x + u_{xxx} = 0
$$

in which $f(u)=-uu_x$ and $L(u)=-u_{xxx}$.  Other applications with the same structure include many other dispersive wave models, the Navier-Stokes equations, and the Kuramoto-Sivashinsky equation, among others.
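A minimal sketch of this decomposition for KdV: it evaluates $f(u)=-uu_x$ and $L(u)=-u_{xxx}$ separately by spectral differentiation on the same $2\pi$-periodic grid used by the solvers below, and compares them with the exact values for $u=\sin x$. The function names `f_nonstiff` and `L_stiff` are illustrative only.

```python
import numpy as np

m = 32
x = np.arange(-m/2, m/2)*(2*np.pi/m)     # periodic grid on [-pi, pi), as in the solvers below
xi = np.fft.fftfreq(m)*m                 # integer wavenumbers for a 2*pi-periodic domain

def f_nonstiff(u):
    # f(u) = -u*u_x, with u_x computed by spectral differentiation
    return -u*np.real(np.fft.ifft(1j*xi*np.fft.fft(u)))

def L_stiff(u):
    # L(u) = -u_xxx; the Fourier symbol of d^3/dx^3 is (i*xi)**3 = -i*xi**3
    return -np.real(np.fft.ifft((1j*xi)**3*np.fft.fft(u)))

u = np.sin(x)
# Exact values for u = sin(x): -u*u_x = -sin(x)*cos(x) and -u_xxx = cos(x)
err_f = np.max(np.abs(f_nonstiff(u) + np.sin(x)*np.cos(x)))
err_L = np.max(np.abs(L_stiff(u) - np.cos(x)))
print(err_f, err_L)
```

Note that $L$ multiplies mode $\xi$ by a factor of size $|\xi|^3$, while $f$ involves only one power of $\xi$. This is the source of the stiffness.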

Applying an explicit Runge-Kutta method to such problems requires that the time step satisfy a condition of the form

$$
    \Delta t \le C (\Delta x)^j
$$

where $j$ is the order of the highest derivative in $L(u)$ and $C$ is a constant depending on the spectrum of (the discretization of) $L$ and the stability region of the RK method.  This restriction is inefficient, since discretizations based on spectral differentiation in space and high-order RK in time can usually give a reasonable local error with $\Delta t \approx \Delta x$.  The computational cost of an explicit time discretization becomes especially noticeable if a large number of spatial grid points is required to resolve the solution.
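To see where $C$ and $j$ come from, consider only the linear term. The nonlinear term also shifts the spectrum, but this sketch ignores it. Spectral differentiation turns $L(u)=-u_{xxx}$ into multiplication by $i\xi^3$, so the eigenvalues are purely imaginary with magnitude up to $(m/2)^3$. The step size is therefore limited by how far the RK stability region extends along the imaginary axis. The sketch assumes the third-order SSP Runge-Kutta method used later in this notebook. It finds that limit numerically ($\sqrt{3}\approx 1.73$, the factor in `dt = 1.73/((m/2)**3)` below). It also checks that the resulting $\Delta t_{\max}$ scales like $(\Delta x)^3$, i.e. $j=3$ and $C=\sqrt{3}/\pi^3$ on a $2\pi$-periodic grid.

```python
import numpy as np

def rk3_stability_function(z):
    # Stability polynomial shared by all 3-stage, 3rd-order RK methods (including SSP RK3)
    return 1 + z + z**2/2 + z**3/6

# Scan the imaginary axis for the first point where |R(iy)| > 1
y = np.linspace(0.0, 3.0, 300001)
unstable = np.abs(rk3_stability_function(1j*y)) > 1 + 1e-12
beta = y[np.argmax(unstable)]
print(beta, np.sqrt(3))   # the exact limit is sqrt(3)

def max_stable_dt(m):
    # Linear term only: for L(u) = -u_xxx on a 2*pi-periodic grid with m points,
    # the eigenvalues are i*xi**3 with |xi| <= m/2, so we need dt*(m/2)**3 <= beta.
    return beta/(m/2)**3

for m in (128, 256, 512):
    dx = 2*np.pi/m
    # dt_max/dx**3 stays constant (= beta/pi**3): the j = 3 scaling
    print(m, max_stable_dt(m), max_stable_dt(m)/dx**3)
```

Halving $\Delta x$ reduces the allowed step by a factor of 8, which is why the explicit run below needs tens of thousands of steps.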

One way to overcome this is to use a fully implicit time discretization with unconditional stability.  However, this incurs the substantial cost of solving a nonlinear algebraic system of equations at each step.

A number of specialized classes of time discretizations have been developed to efficiently solve problems in this class.  In this notebook we will briefly examine each of the following:

- Simple operator splitting (fractional-step methods)
- ImEx additive Runge-Kutta methods
- Exponential integrators

In the examples here, we consider only stiff semilinear problems.  For some problems of interest, the stiff operator $L$ is also nonlinear.  In such problems, similar approaches can be employed, but the cost per step will be noticeably higher.

For an excellent (though now somewhat out-of-date) study of these and other methods, see [the 2005 paper of Kassam and Trefethen](https://epubs.siam.org/doi/epdf/10.1137/S1064827502410633).

## Explicit integration
Let's time our earlier implementation, which applies the explicit third-order SSP Runge-Kutta method to the full right-hand side.

In [None]:
def rk3(u,xi,rhs,dt,filtr):
    y2 = u + dt*rhs(u,xi,filtr)
    y3 = 0.75*u + 0.25*(y2 + dt*rhs(y2,xi,filtr))
    u_new = 1./3 * u + 2./3 * (y3 + dt*rhs(y3,xi,filtr))
    return u_new

def rhs(u, xi, filtr):
    uhat = np.fft.fft(u)
    return -u*np.real(np.fft.ifft(1j*xi*uhat)) - \
            np.real(np.fft.ifft(-1j*xi**3*uhat))
    
def solve_KdV_ERK(m,dt):
    L = 2*np.pi
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    u = 1500*np.exp(-10*(x+2)**2)
    tmax = 0.005

    uhat2 = np.abs(np.fft.fft(u))

    num_plots = 50
    nplt = np.floor((tmax/num_plots)/dt)
    nmax = int(round(tmax/dt))

    frames = [u.copy()]
    tt = [0]
    uuhat = [uhat2]
    filtr = np.ones_like(u)
    
    for n in range(1,nmax+1):
        u_new = rk3(u,xi,rhs,dt,filtr)

        u = u_new.copy()
        t = n*dt
        # Plotting
        if np.mod(n,nplt) == 0:
            frames.append(u.copy())
            tt.append(t)
            uhat2 = np.abs(np.fft.fft(u))
            uuhat.append(uhat2)
    return frames, uuhat, x, tt, xi
    


def plot_solution(frames, uuhat, x, tt, xi):
    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(211)
    axes2 = fig.add_subplot(212)
    line, = axes.plot(x,frames[0],lw=3)
    line2, = axes2.semilogy(xi,np.abs(np.fft.fft(frames[0])))
    xi_max = np.max(np.abs(xi))
    # Dashed red lines mark |xi| = xi_max/2
    axes2.semilogy([xi_max/2.,xi_max/2.],[1.e-8,4e10],'--r')
    axes2.semilogy([-xi_max/2.,-xi_max/2.],[1.e-8,4e10],'--r')
    axes.set_xlabel(r'$x$',fontsize=30)
    axes2.set_xlabel(r'$\xi$',fontsize=30)
    plt.tight_layout()
    plt.close()

    def plot_frame(i):
        line.set_data(x,frames[i])
        power_spectrum = np.abs(uuhat[i])**2
        line2.set_data(np.sort(xi),power_spectrum[np.argsort(xi)])
        axes.set_title('t= %.2e' % tt[i])
        axes.set_xlim((-np.pi,np.pi))
        axes.set_ylim((-200,3000))

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                              frames=len(frames), interval=100,
                                              repeat=False)
    return HTML(anim.to_jshtml())

In [None]:
%%timeit

m = 512
dt = 1.73/((m/2)**3)

frames, uuhat, x, tt, xi = solve_KdV_ERK(m,dt)

# Operator splitting

Operator splitting consists of alternately solving the evolution equations

\begin{align}
    u_t & = f(u) \\
    u_t & = L(u).
\end{align}

In the simplest operator splitting approach, known in different contexts as Lie-Trotter splitting or Godunov splitting, one time step simply consists of a full time step on each equation.  For the substeps, one may use any desired time integration method.  For instance, if the explicit Euler method is used, the discretization takes the form

\begin{align}
    u^* & = u^n + \Delta t f(u^n) \\
    u^{n+1} & = u^* + \Delta t L(u^*).
\end{align}

We can write this more abstractly as

$$
    u^{n+1} = \exp(\Delta t L) \exp(\Delta t f) u^n,
$$

where $\exp(\Delta t L)$ denotes the solution operator (or an approximation to it) that advances $u_t = L(u)$ by a time $\Delta t$, and $\exp(\Delta t f)$ is defined analogously for $u_t = f(u)$.

## Accuracy

Clearly, splitting methods involve two sources of discretization error:

 - Errors due to the discretization of the substeps
 - Errors due to the splitting itself
 
In the case of Lie-Trotter splitting, even if the substeps are solved exactly, the splitting error results in a first-order accurate method.  Second-order accuracy can be achieved using **Strang splitting**:

$$
    u^{n+1} = \exp((\Delta t/2) L) \exp(\Delta t f) \exp((\Delta t/2) L) u^n.
$$

Although this seems to require 50% more substeps, the extra cost is negligible in practice. The half-step with $L$ at the end of one step and the half-step at the beginning of the next can be combined into a single full step. As a result, one only needs to take a half-step at the beginning and a half-step at the end of the simulation (and whenever the solution is output).  However, one often sees relatively little difference in accuracy between Lie-Trotter and Strang splitting.
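These claims can be checked on a toy linear system $u'=(A+B)u$. Here both substeps are exact matrix exponentials, so only the splitting error remains. In this minimal sketch, $A$ plays the role of $f$ and $B$ the role of $L$. The matrices, the helper `expm_diag`, and the other function names are arbitrary illustrative choices. The sketch estimates convergence rates from runs with $N$ and $2N$ steps. It also confirms that merging adjacent $L$ half-steps reproduces Strang splitting up to rounding.

```python
import numpy as np

def expm_diag(M, t):
    # exp(t*M) via eigendecomposition; assumes M is diagonalizable (true for the matrices below)
    lam, V = np.linalg.eig(M)
    return np.real(V @ np.diag(np.exp(t*lam)) @ np.linalg.inv(V))

# Toy problem u' = (A + B)u: A plays the role of f, B the role of L (they do not commute)
A = np.array([[0.0, 1.0], [-1.0, 0.0]])
B = np.array([[-1.0, 0.0], [0.0, -2.0]])
u0 = np.array([1.0, 0.5])
T = 1.0
exact = expm_diag(A + B, T) @ u0

def lie_trotter(N):
    h = T/N
    step = expm_diag(B, h) @ expm_diag(A, h)                        # exp(hL) exp(hf)
    return np.linalg.matrix_power(step, N) @ u0

def strang(N):
    h = T/N
    step = expm_diag(B, h/2) @ expm_diag(A, h) @ expm_diag(B, h/2)  # exp(hL/2) exp(hf) exp(hL/2)
    return np.linalg.matrix_power(step, N) @ u0

def strang_fused(N):
    # Same method with adjacent L half-steps merged: only the first and last L-steps are halves
    h = T/N
    u = expm_diag(B, h/2) @ u0
    u = np.linalg.matrix_power(expm_diag(B, h) @ expm_diag(A, h), N - 1) @ u
    return expm_diag(B, h/2) @ (expm_diag(A, h) @ u)

def observed_order(method, N):
    # Convergence rate estimated from the errors with N and 2N steps
    e1 = np.linalg.norm(method(N) - exact)
    e2 = np.linalg.norm(method(2*N) - exact)
    return np.log2(e1/e2)

p_lie = observed_order(lie_trotter, 20)
p_strang = observed_order(strang, 20)
fused_diff = np.linalg.norm(strang_fused(20) - strang(20))
print(p_lie, p_strang, fused_diff)
```

The observed rates should be close to 1 and 2. The fused version needs $N+1$ applications of $\exp(\cdot\,B)$ instead of $2N$.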

## Stability and time step size

For stiff semilinear problems (such as the KdV equation), $L$ is a linear differential operator with constant coefficients. One can therefore use a pure spectral discretization to solve $u_t = L(u)$ exactly (and cheaply), as discussed in the first notebook of this course.  With this approach, the $L$ substep is also unconditionally stable, so the step size can be chosen based entirely on the properties of $f$ (or on the desired accuracy).
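For KdV, $u_t=-u_{xxx}$ becomes $\hat u_t = i\xi^3 \hat u$ in Fourier space, so the exact $L$ substep is a pointwise multiplication by $e^{i\xi^3\Delta t}$. The minimal sketch below is a standalone version of the `substep_g` function defined below; the name `exact_linear_substep` is illustrative. It checks the substep against the exact plane-wave solution $\cos(kx+k^3t)$. It also illustrates that the multiplier has modulus one, so the substep cannot amplify the solution even for a very large $\Delta t$.

```python
import numpy as np

m = 64
x = np.arange(-m/2, m/2)*(2*np.pi/m)      # same 2*pi-periodic grid as the solvers in this notebook
xi = np.fft.fftfreq(m)*m

def exact_linear_substep(u, dt):
    # Solve u_t = -u_xxx exactly for a time dt: u_hat is multiplied by exp(i*xi**3*dt)
    return np.real(np.fft.ifft(np.exp(1j*xi**3*dt)*np.fft.fft(u)))

# Plane-wave check: cos(k*x + k**3*t) solves u_t = -u_xxx exactly
k, dt = 3, 0.1
err = np.max(np.abs(exact_linear_substep(np.cos(k*x), dt) - np.cos(k*x + k**3*dt)))

# The multiplier has modulus 1, so even a huge dt causes no growth (unconditional stability)
u0 = np.exp(-10*x**2)
norm_ratio = np.linalg.norm(exact_linear_substep(u0, 1e3))/np.linalg.norm(u0)
print(err, norm_ratio)
```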

For some problems, such as the nonlinear Schrödinger equation, operator splitting allows both substeps to be solved exactly.
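For NLS, a common choice (often called the split-step Fourier method) separates the dispersive term from the cubic term. The linear part is solved exactly in Fourier space. The nonlinear part is an ODE at each grid point that conserves $|u|$, so its exact solution is a pointwise phase rotation. This minimal sketch assumes the normalization $iu_t + u_{xx} + 2|u|^2u = 0$, whose soliton $u=\operatorname{sech}(x)\,e^{it}$ serves as a reference solution. It uses Strang splitting, and the domain size, resolution, and step size are arbitrary choices.

```python
import numpy as np

# Assumed normalization: focusing NLS  i u_t + u_xx + 2|u|^2 u = 0,
# which has the exact soliton u(x, t) = sech(x) exp(i t).
m, half_width = 256, 20.0
x = np.arange(-m/2, m/2)*(2*half_width/m)          # periodic grid on [-20, 20)
xi = np.fft.fftfreq(m)*m*np.pi/half_width          # wavenumbers for a domain of length 40
dx = x[1] - x[0]

def linear_substep(u, dt):
    # u_t = i u_xx  =>  each Fourier mode is multiplied by exp(-i xi^2 dt)  (exact)
    return np.fft.ifft(np.exp(-1j*xi**2*dt)*np.fft.fft(u))

def nonlinear_substep(u, dt):
    # u_t = 2i|u|^2 u leaves |u| unchanged, so u(dt) = u exp(2i|u|^2 dt)  (exact)
    return u*np.exp(2j*np.abs(u)**2*dt)

def strang_step(u, dt):
    u = linear_substep(u, dt/2)
    u = nonlinear_substep(u, dt)
    return linear_substep(u, dt/2)

u = 1/np.cosh(x) + 0j
mass0 = np.sum(np.abs(u)**2)*dx
dt, nsteps = 0.005, 200                            # integrate to t = 1
for _ in range(nsteps):
    u = strang_step(u, dt)

t = dt*nsteps
err = np.max(np.abs(u - np.exp(1j*t)/np.cosh(x)))
mass_drift = abs(np.sum(np.abs(u)**2)*dx - mass0)
print(err, mass_drift)
```

Since both substeps are exact, the only time discretization error is the splitting error. Both substeps also preserve the discrete $L^2$ norm, so the mass drifts only at the level of rounding error.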

In [None]:
def rhs_f(u, xi, filtr):
    # Evaluate only the non-stiff nonlinear term
    uhat = np.fft.fft(u)
    return -u*np.real(np.fft.ifft(1j*xi*uhat*filtr))

# Alternative (conservative) form of the same term, using u*u_x = (u**2/2)_x:
#def rhs_f(u, xi, filtr):
#    return -0.5*np.real(ifft(1j*xi*fft(u**2)*filtr))

def substep_g(u, xi, dt):
    # Advance the solution using only the stiff linear term
    uhat = np.fft.fft(u)
    return np.real(np.fft.ifft(np.exp(1j*xi**3*dt)*uhat))

def solve_KdV_Lie_Trotter(m,dt,use_filter=True):
    L = 2*np.pi
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    u = 1500*np.exp(-10*(x+2)**2)
    tmax = 0.005

    uhat2 = np.abs(np.fft.fft(u))

    num_plots = 50
    nplt = np.floor((tmax/num_plots)/dt)
    nmax = int(round(tmax/dt))

    frames = [u.copy()]
    tt = [0]
    uuhat = [uhat2]

    filtr = np.ones_like(xi)
    
    if use_filter:
        xi_max = np.max(np.abs(xi))
        filtr[np.where(np.abs(xi)>xi_max*2./3)] = 0.

    for n in range(1,nmax+1):
        u_star = rk3(u,xi,rhs_f,dt,filtr)
        u_new = substep_g(u_star,xi,dt)
            
        u = u_new.copy()
        t = n*dt
        # Plotting
        if np.mod(n,nplt) == 0:
            frames.append(u.copy())
            tt.append(t)
            uhat2 = np.abs(np.fft.fft(u))
            uuhat.append(uhat2)
    return frames, uuhat, x, tt, xi

In [None]:
m = 2048
umax = 3000
dt = 1.73/(m/2)/umax
use_filter=True

frames, uuhat, x, tt, xi = solve_KdV_Lie_Trotter(m,dt,use_filter)

In [None]:
plot_solution(frames, uuhat, x, tt, xi)

Notes:
- Instability usually appears even with a fine spatial mesh, unless filtering is applied.  This appears to be due to the first-order operator splitting itself exciting an aliasing instability.
- Thanks to the improved linear stability, we can take a drastically larger time step.  Even with a highly refined mesh ($m=2048$), the simulation above runs in less than one second.
- A more efficient implementation could use one fewer FFT per step.
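One way to save an FFT is sketched below; the original notebook does not say which saving it has in mind. The $L$ substep already computes the Fourier coefficients of $u^{n+1}$. They can be kept and reused in the first RK3 stage of the next step instead of recomputing `fft(u)`. This minimal sketch compares such a fused step with a direct transcription of the Lie-Trotter step above. All function names are hypothetical. The two should agree to rounding when the filter removes the Nyquist mode, as the 2/3 rule does. Otherwise, taking `np.real` after the inverse FFT makes the reused coefficients differ from `fft(u)` in that single mode.

```python
import numpy as np

m, dt = 64, 1e-3
x = np.arange(-m/2, m/2)*(2*np.pi/m)
xi = np.fft.fftfreq(m)*m
filtr = (np.abs(xi) <= np.max(np.abs(xi))*2./3).astype(float)   # 2/3 rule; zeros the Nyquist mode

def rhs_f_from_hat(u, uhat):
    # -u*u_x, given both u and its Fourier coefficients (one inverse FFT)
    return -u*np.real(np.fft.ifft(1j*xi*uhat*filtr))

def rhs_f(u):
    return rhs_f_from_hat(u, np.fft.fft(u))

def lie_trotter_step(u):
    # Direct transcription: SSP RK3 for f, then the exact L substep (8 FFTs per step)
    y2 = u + dt*rhs_f(u)
    y3 = 0.75*u + 0.25*(y2 + dt*rhs_f(y2))
    u_star = u/3 + 2./3*(y3 + dt*rhs_f(y3))
    return np.real(np.fft.ifft(np.exp(1j*xi**3*dt)*np.fft.fft(u_star)))

def lie_trotter_step_fused(u, uhat):
    # Same step, but the first RK3 stage reuses uhat instead of computing fft(u) (7 FFTs)
    y2 = u + dt*rhs_f_from_hat(u, uhat)
    y3 = 0.75*u + 0.25*(y2 + dt*rhs_f(y2))
    u_star = u/3 + 2./3*(y3 + dt*rhs_f(y3))
    uhat_new = np.exp(1j*xi**3*dt)*np.fft.fft(u_star)
    return np.real(np.fft.ifft(uhat_new)), uhat_new

u_ref = np.exp(-2*x**2)
u_fused, uhat = u_ref.copy(), np.fft.fft(u_ref)
for _ in range(100):
    u_ref = lie_trotter_step(u_ref)
    u_fused, uhat = lie_trotter_step_fused(u_fused, uhat)
diff = np.max(np.abs(u_fused - u_ref))
print(diff)
```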

## A high-order operator splitting method

Many higher-order operator splitting methods have been developed; a large collection is maintained at https://www.asc.tuwien.ac.at/~winfried/splitting/index.php.  Here we implement a 4th-order method of the form

$$
u^{n+1} = e^{b_5\Delta t f}e^{a_5\Delta t L}\, e^{b_4\Delta t f}e^{a_4\Delta t L} \cdots e^{b_1\Delta t f}e^{a_1\Delta t L}\, u^n,
$$

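i.e., starting with the stiff part, we alternate between the stiff and non-stiff parts five times each (in fact only four times for the non-stiff part, since $b_5=0$).  This particular scheme is symmetric, meaning that $a_j = a_{6-j}$ for $j=1,\dots,5$ and $b_j=b_{5-j}$ for $j=1,\dots,4$, so the full sequence of substeps reads the same forwards and backwards.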
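As a sanity check on the coefficients, this minimal sketch first verifies that $\sum_j a_j = \sum_j b_j = 1$ and checks the stated symmetry. It then applies the same sequence of substeps to a toy linear system whose substeps are exact matrix exponentials, so that only the splitting error remains. The sequence is an $L$ substep with $a_j\Delta t$, then an $f$ substep with $b_j\Delta t$, as in `solve_KdV_AK4`. The matrices and the helper `expm_diag` are arbitrary illustrative choices. The observed convergence rate should be close to 4.

```python
import numpy as np

a = np.array([0.267171359000977615, -0.0338279096695056672, 0.5333131013370561044,
              -0.0338279096695056672, 0.267171359000977615])
b = np.array([-0.361837907604416033, 0.861837907604416033,
              0.861837907604416033, -0.361837907604416033, 0.])

# Consistency (coefficients sum to 1) and symmetry of the coefficients
print(a.sum(), b.sum(), np.allclose(a, a[::-1]), np.allclose(b[:4], b[3::-1]))

def expm_diag(M, t):
    # exp(t*M) via eigendecomposition; assumes M is diagonalizable (true for the matrices below)
    lam, V = np.linalg.eig(M)
    return np.real(V @ np.diag(np.exp(t*lam)) @ np.linalg.inv(V))

# Toy problem u' = (F + S)u: F plays the role of f, S the role of L (they do not commute)
F = np.array([[0.0, 1.0], [-1.0, 0.0]])
S = np.array([[-1.0, 0.0], [0.0, -2.0]])
u0 = np.array([1.0, 0.5])
T = 1.0
exact = expm_diag(F + S, T) @ u0

def ak4(N):
    h = T/N
    step = np.eye(2)
    for aj, bj in zip(a, b):
        # L substep with a_j*h first, then f substep with b_j*h (later factors act on the left)
        step = expm_diag(F, bj*h) @ expm_diag(S, aj*h) @ step
    return np.linalg.matrix_power(step, N) @ u0

errors = [np.linalg.norm(ak4(N) - exact) for N in (16, 32)]
order = np.log2(errors[0]/errors[1])
print(errors, order)
```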

In [None]:
# Coefficients of the symmetric 4th-order splitting method
a = np.array([0.267171359000977615, -0.0338279096695056672,
              0.5333131013370561044, -0.0338279096695056672,
              0.267171359000977615])
b = np.array([-0.361837907604416033, 0.861837907604416033,
              0.861837907604416033, -0.361837907604416033, 0.])

def rhs_f(u, xi, filtr):
    # Evaluate only the non-stiff nonlinear term
    uhat = np.fft.fft(u)
    return -u*np.real(np.fft.ifft(1j*xi*uhat*filtr))

def substep_L(u, xi, dt):
    # Advance the solution using only the stiff linear term
    uhat = np.fft.fft(u)
    return np.real(np.fft.ifft(np.exp(1j*xi**3*dt)*uhat))

def solve_KdV_AK4(m,dt,use_filter=True):
    L = 2*np.pi
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    u = 1500*np.exp(-10*(x+2)**2)
    tmax = 0.005

    uhat2 = np.abs(np.fft.fft(u))

    num_plots = 50
    nplt = np.floor((tmax/num_plots)/dt)
    nmax = int(round(tmax/dt))

    frames = [u.copy()]
    tt = [0]
    uuhat = [uhat2]

    filtr = np.ones_like(xi)
    
    if use_filter:
        xi_max = np.max(np.abs(xi))
        filtr[np.where(np.abs(xi)>xi_max*2./3)] = 0.

    for n in range(1,nmax+1):
        u_star = u.copy()
        for j in range(5):
            u_star = substep_L(u_star,xi,a[j]*dt)
            u_star = rk3(u_star,xi,rhs_f,b[j]*dt,filtr)
            
        u = u_star.copy()
        t = n*dt
        # Plotting
        if np.mod(n,nplt) == 0:
            frames.append(u.copy())
            tt.append(t)
            uhat2 = np.abs(np.fft.fft(u))
            uuhat.append(uhat2)
    return frames, uuhat, x, tt, xi

In [None]:
m = 2048
umax = 3000
dt = 1.73/(m/2)/umax
use_filter=True

frames, uuhat, x, tt, xi = solve_KdV_AK4(m,dt,use_filter)

In [None]:
plot_solution(frames, uuhat, x, tt, xi)

As with Lie-Trotter splitting, filtering appears to be necessary with this method regardless of the spatial resolution.

## Comparison

In [None]:
m = 1024
umax = 3000
dt = 1.73/(m/2)/umax
use_filter=True

frames1, uuhat, x, tt, xi = solve_KdV_Lie_Trotter(m,dt,use_filter)
frames4, uuhat, x, tt, xi = solve_KdV_AK4(m,dt,use_filter)

plt.figure(figsize=(12,8))
plt.plot(x,frames1[-1])
plt.plot(x,frames4[-1])
plt.legend(['1st-order','4th-order']);

For this problem, we see no visible difference between the 4th-order and 1st-order solutions at the final time.  This suggests that the temporal splitting error is not the dominant part of the numerical error.