In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import animation
from JSAnimation import IPython_display # if you don't have this, get it from: https://github.com/jakevdp/JSAnimation

# Dealing with nonlinearity and stiffness

$$
\newcommand{\F}{\mathcal F}
\newcommand{\Finv}{{\mathcal F}^{-1}}
$$
In the last notebook, we saw that discretizing Burgers' equation can be problematic because of the combination of stiffness and nonlinearity.  Fully explicit methods are slow because the stiff term requires small step sizes, while fully implicit methods are slow because of the need to solve a nonlinear system at each step.  Some solutions we know about now are:

- Operator splitting methods: handle the nonlinear part explicitly and the linear part implicitly or exactly
- Integrating factor methods: remove the stiff linear part by a problem-specific transformation

For the KdV equation, we have the same difficulty, but the stiffness is even worse since the linear term involves a third derivative.  In [the first spectral methods notebook](Spectral_methods_1.ipynb), we applied the integrating factor approach to remove the stiffness.  Here is a direct approach, based on a pseudospectral semi-discretization.  We start with KdV:

$$
u_t = -u u_x - u_{xxx}
$$
and compute the spatial derivatives via

\begin{align}
u_x & = \Finv(i\xi \F(u)) \\
u_{xxx} & = \Finv(-i\xi^3 \F(u)).
\end{align}
This gives
$$
u'(t) = -u \Finv(i\xi \F(u)) - \Finv(-i\xi^3 \F(u))
$$

What time step should be stable for this discretization?  Let's construct a rough "linearized" spectrum, using the same code from the last notebook but adapted to the KdV equation:

In [None]:
def F_matrix(m):
    """Return the dense m-by-m DFT matrix.

    Column j is the FFT of the j-th unit vector, so by linearity
    ``F_matrix(m).dot(v)`` equals ``np.fft.fft(v)`` for a length-m vector v.
    """
    dft = np.empty((m, m), dtype=complex)
    # np.eye is symmetric, so its rows are exactly the unit vectors e_j.
    for col, unit_vec in enumerate(np.eye(m)):
        dft[:, col] = np.fft.fft(unit_vec)
    return dft

# Periodic grid on [-pi, pi) and the matching FFT wavenumbers.
L = 2*np.pi
m = 256
x = np.arange(-m/2,m/2)*(L/m)
xi = np.fft.fftfreq(m)*m*2*np.pi/L

# Two-soliton initial condition (the same data is integrated below).
A = 25; B = 16;
u = 3*A**2/np.cosh(0.5*(A*(x+2.)))**2 + 3*B**2/np.cosh(0.5*(B*(x+1)))**2
# To check the claim below that the choice of u hardly matters, try u = np.ones_like(u).

# Dense pseudospectral operators: F is the DFT matrix and D1, D3 hold the
# Fourier symbols i*xi and (i*xi)**3 of d/dx and d^3/dx^3 on the diagonal
# (** is elementwise, so D1**3 simply cubes the diagonal entries).
F = F_matrix(m)
Finv = np.linalg.inv(F)
D1 = np.diag(1j*xi)
D3 = D1**3
ux_matrix = np.dot(Finv,np.dot(D1,F))      # maps u to u_x
uxxx_matrix = np.dot(Finv,np.dot(D3,F))    # maps u to u_xxx

# Frozen-coefficient ("linearized") KdV operator -diag(u) d/dx - d^3/dx^3,
# and its advective part alone for comparison.
rhs2_matrix = np.dot(-np.diag(u),ux_matrix)
rhs_matrix = rhs2_matrix - uxxx_matrix

lamda = np.linalg.eigvals(rhs_matrix)
lamda2 = np.linalg.eigvals(rhs2_matrix)
plt.plot(np.real(lamda),np.imag(lamda),'or',np.real(lamda2),np.imag(lamda2),'ob')
plt.axis('equal')
plt.xlabel(r'Re($\lambda$)')
plt.ylabel(r'Im($\lambda$)')
plt.title('Spectrum of the linearized pseudospectral KdV operator')
plt.legend(['Full RHS','$u u_x$ only']);

As you should have expected, the eigenvalues are purely imaginary for this dispersive wave equation.  The eigenvalues of largest magnitude come from the 3rd-derivative term, and scale like $(m/2)^3$:

In [None]:
# Largest eigenvalue magnitude vs. the (m/2)^3 scaling of the u_xxx symbol.
# Single-argument print(...) and // behave the same on Python 2 and 3.
print(np.max(np.abs(lamda)))
print((m//2)**3)

In constructing the spectrum, I used for $u$ the initial condition that we'll use below.  But it doesn't much matter what function we take for $u$, because the 3rd-derivative term is so stiff.

# Time integration

We'll use a 3rd-order Runge-Kutta method to integrate in time. How large a step size can we take?  Let's see.

In [None]:
from nodepy import rk
# Step size for which SSPRK3 is linearly stable on the eigenvalues of
# rhs_matrix (computed by nodepy).
ssp33 = rk.loadRKM('SSP33')
ts = rk.linearly_stable_step_size(ssp33,rhs_matrix,tol=1.e-2)
print(ts)
# ts times the largest eigenvalue scale should roughly match the
# imaginary stability interval printed next.
print(ts*(m//2)**3)
print(ssp33.imaginary_stability_interval())

The step size we can take is given by the length of the imaginary-axis interval contained in the Runge-Kutta method's stability region, divided by the magnitude of the largest eigenvalue, which is roughly $(m/2)^3$.

# Implementation

In [None]:
def rhs(u, xi, equation='KdV'):
    """Pseudospectral right-hand side du/dt for Burgers or KdV.

    Parameters
    ----------
    u : real array of grid values on a periodic grid.
    xi : wavenumbers matching ``np.fft.fft(u)`` (e.g. built from ``np.fft.fftfreq``).
    equation : 'Burgers' for  -u u_x + u_xx,  or 'KdV' for  -u u_x - u_xxx.

    Returns
    -------
    Real array of the same shape as ``u``.

    Raises
    ------
    ValueError
        If ``equation`` is not 'Burgers' or 'KdV'.
    """
    uhat = np.fft.fft(u)
    # Nonlinear advection term, shared by both equations.
    advection = -u*np.real(np.fft.ifft(1j*xi*uhat))
    if equation == 'Burgers':
        # (i xi)^2 = -xi^2 is the Fourier symbol of u_xx.
        return advection + np.real(np.fft.ifft(-xi**2*uhat))
    elif equation == 'KdV':
        # (i xi)^3 = -i xi^3 is the Fourier symbol of u_xxx.
        return advection - np.real(np.fft.ifft(-1j*xi**3*uhat))
    raise ValueError("equation must be 'Burgers' or 'KdV', got %r" % (equation,))

In [None]:
def rk3(u,xi,rhs,dt=None):
    """Advance u by one step of the 3-stage, 3rd-order SSP Runge-Kutta method.

    Parameters
    ----------
    u : array, current solution.
    xi : wavenumbers, passed through unchanged to ``rhs``.
    rhs : callable ``rhs(u, xi)`` returning du/dt.
    dt : float, optional
        Step size.  If omitted, the notebook-level ``dt`` is used (the
        original behaviour); passing it explicitly avoids depending on
        whichever cell last assigned ``dt``.

    Returns
    -------
    The solution after one step, as a new array (``u`` is not modified).
    """
    if dt is None:
        dt = globals()['dt']

    def euler(v):
        # One forward-Euler substep of size dt starting from v.
        return v + dt*rhs(v, xi)

    y2 = euler(u)
    y3 = 0.75*u + 0.25*euler(y2)
    return 1./3 * u + 2./3 * euler(y3)

In [None]:
# Grid
m = 256
L = 2*np.pi
x = np.arange(-m/2,m/2)*(L/m)
xi = np.fft.fftfreq(m)*m*2*np.pi/L

# Explicit step limited by the stiff u_xxx term: roughly the SSPRK3
# imaginary stability interval divided by the eigenvalue scale (m/2)**3.
dt = 1.75/((m/2)**3)
print('dt = %g' % dt)

# Two-soliton initial condition
A = 25; B = 16;
u = 3*A**2/np.cosh(0.5*(A*(x+2.)))**2 + 3*B**2/np.cosh(0.5*(B*(x+1)))**2
tmax = 0.006

num_plots = 50
# Store a frame every nplt steps; an int >= 1 so the modulus below never divides by 0.
nplt = max(1, int(np.floor((tmax/num_plots)/dt)))
nmax = int(round(tmax/dt))

fig = plt.figure()
axes = fig.add_subplot(111)
line, = axes.plot(x,u,lw=3)

frames = [u.copy()]
tt = [0]

for n in range(1,nmax+1):
    # rk3 builds a new array, so no extra copy is needed here.
    u = rk3(u,xi,rhs)
    t = n*dt
    # Snapshot for the animation
    if n % nplt == 0:
        frames.append(u.copy())
        tt.append(t)
        
def plot_frame(i):
    """Animation callback: draw stored frame i and label it with its time."""
    line.set_data(x,frames[i])
    # tmax is only 0.006, so a '%.2f' format would show 0.00 or 0.01 for every frame.
    axes.set_title('t= %.5f' % tt[i])
    axes.set_xlim((-np.pi,np.pi))
    
# Build the animation from the stored frames; as the cell's last expression
# it is rendered inline (via the JSAnimation display hook imported at the top).
anim = matplotlib.animation.FuncAnimation(fig, plot_frame, frames=len(frames), interval=20)
anim

Now try taking $m=128$ in the code above.  What happens?  Can you say why?

## Alternative discretizations

There are a number of other, relatively *ad hoc*, approaches to developing non-stiff pseudospectral discretizations.  I have not found a general framework for constructing them, so the idea is explained here using KdV as an example.

Let's forget about the nonlinear term for the moment, and consider the equation
$$
u_t = u_{xxx}.
$$
Substituting our usual ansatz $u(x,t) = \exp(i(\xi x - \omega t))$ into the above equation gives the dispersion relation:
$$
\omega(\xi) = \xi^3.
$$

Next let us discretize with the midpoint method in Fourier space:
$$
\frac{\hat{u}^{n+1} - \hat{u}^{n-1}}{2\Delta t} = -i\xi^3\hat{u}^n.
$$

Of course, this is a bit silly since we know that we could solve this linear problem exactly in time.  So the idea is to replace the factor $-i\xi^3$ on the right by some other function $g(\xi)$ so that the midpoint method will give the exact answer:
$$
\frac{\hat{u}^{n+1} - \hat{u}^{n-1}}{2\Delta t} = g(\xi) \hat{u}^n.
$$

If the scheme is to reproduce the exact solution, then (using the exact dispersion relation) we should have

\begin{align}
\hat{u}^{n-1} & = \hat{u}^n e^{i \xi^3 \Delta t} \\
\hat{u}^{n+1} & = \hat{u}^n e^{-i \xi^3 \Delta t}.
\end{align}

Substituting these in the equation with $g$ gives
$$
g(\xi) = - \frac{i}{\Delta t} \sin(\xi^3 \Delta t).
$$

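Based on all of this, we'll now semi-discretize KdV as follows.  Since KdV contains $-u_{xxx}$ (whose Fourier symbol is $+i\xi^3$) rather than $+u_{xxx}$, its dispersive term is replaced by $-g(\xi)$:

$$
u'(t) = -u \F^{-1}(i \xi \F(u)) + \frac{i}{\Delta t}\F^{-1}\left(\sin(\xi^3 \Delta t)\F(u)\right).
$$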

The dispersive part of the spectrum of this discretization is much better behaved, since $|\sin(\xi^3 \Delta t)|/\Delta t \le 1/\Delta t$ for every wavenumber, as we can see:

In [None]:
# Same grid and two-soliton data as before; here L is recovered from x (= 2*pi).
m = 256
x = np.arange(-m/2,m/2)*(2*np.pi/m)
dx = x[1]-x[0]
L = x[-1]-x[0] + dx
A = 25; B = 16;
u = 3*A**2/np.cosh(0.5*(A*(x+2.)))**2 + 3*B**2/np.cosh(0.5*(B*(x+1)))**2

F = F_matrix(m)
Finv = np.linalg.inv(F)
xi = np.fft.fftfreq(m)*m*2*np.pi/L
D1 = np.diag(1j*xi)
# Modified dispersive symbol -i*sin(xi^3 dt)/dt, replacing (i*xi)^3 = -i*xi^3.
# The 1/dt factor matches rhs_mod and the formula above; without it every
# dispersive eigenvalue would be scaled down by a factor of dt.
# NOTE(review): dt is the value left by an earlier cell (1.75/(m/2)**3 under
# Run All), not the dt chosen for the leapfrog simulation below -- confirm intended.
D3 = -1j*np.diag(np.sin(dt*xi**3))/dt
rhs2_mod_matrix = - np.dot(Finv,np.dot(D3,F))    # modified dispersive part only
rhs_mod_matrix = np.dot(-np.diag(u),np.dot(Finv,np.dot(D1,F))) + rhs2_mod_matrix

lamda = np.linalg.eigvals(rhs_mod_matrix)
lamda2 = np.linalg.eigvals(rhs2_mod_matrix)
plt.plot(np.real(lamda),np.imag(lamda),'or',np.real(lamda2),np.imag(lamda2),'ob')
plt.axis('equal')
plt.xlabel(r'Re($\lambda$)')
plt.ylabel(r'Im($\lambda$)')
plt.title('Spectrum with the modified dispersive term')
plt.legend(['Full RHS','modified $u_{xxx}$ term only']);

In [None]:
from nodepy import rk
ssp33 = rk.loadRKM('SSP33')
# Linearly stable SSPRK3 step size for the modified semi-discretization's spectrum.
ts = rk.linearly_stable_step_size(ssp33,rhs_mod_matrix,tol=1.e-2)
print(ts)

In [None]:
def rhs_mod(u, xi, dt):
    """KdV right-hand side with the dt-dependent, non-stiff dispersive term.

    Computes  -u u_x + F^{-1}( i sin(xi^3 dt)/dt * F(u) ),  i.e. KdV with the
    symbol i*xi^3 of -u_xxx replaced by i*sin(xi^3 dt)/dt.

    Parameters
    ----------
    u : real 1-D array of grid values on a periodic grid.
    xi : real-FFT wavenumbers matching ``np.fft.rfft(u)``
        (e.g. ``np.fft.rfftfreq(m)*m*2*np.pi/L``).
    dt : nonzero time step defining the modified dispersion.

    Returns
    -------
    Real array with the same length as ``u``.
    """
    n = len(u)
    uhat = np.fft.rfft(u)
    # Pass n explicitly: irfft's default output length is 2*(len(uhat)-1),
    # which is one short when len(u) is odd.
    ux = np.fft.irfft(1j*xi*uhat, n)
    dispersion = np.fft.irfft(-1j*np.sin(dt*xi**3)*uhat, n)/dt
    return -u*ux - dispersion

In [None]:
def rk3_mod(u,xi,rhs,dt):
    """One SSPRK3 step of size dt for a right-hand side ``rhs(u, xi, dt)``.

    Same three-stage scheme as ``rk3``, but dt is an explicit argument and is
    also forwarded to ``rhs`` (as ``rhs_mod`` requires).
    """
    def euler(v):
        # forward-Euler substep of size dt starting from v
        return v + dt*rhs(v, xi, dt)

    stage2 = euler(u)
    stage3 = 0.75*u + 0.25*euler(stage2)
    return 1./3 * u + 2./3 * euler(stage3)

def midpoint(u_old, u, xi, rhs, dt):
    """Leapfrog (explicit midpoint) step: u^{n+1} = u^{n-1} + 2 dt rhs(u^n)."""
    slope = rhs(u, xi, dt)
    return u_old + 2*dt*slope

In [None]:
# Grid (xi holds the non-negative wavenumbers used with rfft/irfft in rhs_mod)
m = 256
L = 2*np.pi
x = np.arange(-m/2,m/2)*(L/m)
xi = np.fft.rfftfreq(m)*m*2*np.pi/L


# Initial data
A = 25; B = 16;
u = 3*A**2/np.cosh(0.5*(A*(x+2.)))**2 + 3*B**2/np.cosh(0.5*(B*(x+1)))**2
tmax = 0.006
# Chosen so that dt * max(u) * (m/2) = 0.5, i.e. from the advective scale of
# u u_x rather than the (m/2)**3 dispersive scale.
# (Previously tried: dt = 2 * 1.75/((m/2)**3).)
dt = 0.5/(np.max(u)*m/2)

num_plots = 50
# Store a frame every nplt steps; an int >= 1 so the modulus below never divides by 0.
nplt = max(1, int(np.floor((tmax/num_plots)/dt)))
nmax = int(round(tmax/dt))

fig = plt.figure()
axes = fig.add_subplot(111)
line, = axes.plot(x,u,lw=3)

frames = [u.copy()]
tt = [0]

u_old = None
for n in range(1,nmax+1):
    if n == 1:
        # Leapfrog needs two time levels, so take the first step with SSPRK3
        # (rk3_mod forwards dt to rhs_mod; rk3 would not).
        u_new = rk3_mod(u,xi,rhs_mod,dt)
    else:
        u_new = midpoint(u_old, u, xi, rhs_mod, dt)
    # Both steppers return new arrays, so rebinding is enough (no aliasing).
    u_old, u = u, u_new

    t = n*dt    # u now holds the solution after n steps
    # Snapshot for the animation
    if n % nplt == 0:
        frames.append(u.copy())
        tt.append(t)
        
def plot_frame(i):
    """Animation callback: draw stored frame i and label it with its time."""
    line.set_data(x,frames[i])
    # tmax is only 0.006, so a '%.2f' format would show 0.00 or 0.01 for every frame.
    axes.set_title('t= %.5f' % tt[i])
    axes.set_xlim((-np.pi,np.pi))
    
# Build the animation from the stored frames; as the cell's last expression
# it is rendered inline (via the JSAnimation display hook imported at the top).
anim = matplotlib.animation.FuncAnimation(fig, plot_frame, frames=len(frames), interval=20)
anim

In [None]:
# Step size used by the leapfrog run (Python 2/3-compatible print).
print(dt)

In [None]:
# Largest eigenvalue magnitude of the modified operator vs. the advective scale max(u)*(m/2).
# NOTE(review): u now holds the final simulated state, not the initial data used to build lamda.
print(np.max(np.abs(lamda)))
print(np.max(u)*(m/2))

In [None]:
# Ratio of the largest eigenvalue magnitude to the advective scale max(u)*(m/2).
print(np.max(np.abs(lamda))/(np.max(u)*(m/2)))

In [None]:
# Interactive IPython help lookup for irfft (presumably to check how its
# default output length `n` is chosen). The trailing `?` is IPython-only
# syntax, not plain Python; TODO remove before sharing the notebook.
np.fft.irfft?

In [None]:
# Round-trip check rfft -> irfft: report the largest pointwise error instead
# of dumping the whole difference array (n given explicitly for odd lengths).
np.max(np.abs(u - np.fft.irfft(np.fft.rfft(u), len(u))))

In [None]:
# Real-FFT wavenumbers for the current grid (same formula as the leapfrog cell).
freqs = np.fft.rfftfreq(m)
xi = freqs*m*2*np.pi/L


In [None]:
# Summarize the wavenumber array (count and range) instead of printing every entry.
xi.size, xi.min(), xi.max()