In this notebook we numerically examine the formation of a viscous or dispersive shock wave from an initial step function.  These shocks arise as solutions of Burgers' equation

$$
    u_t + u u_x = \epsilon u_{xx}
$$

and the Korteweg-de Vries equation

$$
    u_t + u u_x = -\epsilon u_{xxx}
$$

respectively.  We use a Fourier pseudospectral method in space and the three-stage, third-order strong-stability-preserving Runge–Kutta method (SSP-RK3, in Shu–Osher form) in time.

The dispersive shock is not well resolved here, for two reasons.  First, because the solution is highly oscillatory, we would need a much larger number of modes to represent it accurately.  Second, because of the periodic boundary condition, the initial data also has a jump at the domain boundary $x=\pm\pi$ (from $u\approx 0$ back up to $u\approx 1$), which generates a second oscillatory wave there that is not part of the single-step problem.

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
from IPython.display import HTML
# Enlarge the default matplotlib font so axis labels and titles stay legible
# in the large animation figure.
font = dict(size=15)
matplotlib.rc('font', **font)

In [None]:
def rk3(u, xi, rhs, *, step=None):
    """Advance ``u`` by one step of the three-stage, third-order SSP Runge-Kutta method.

    This is the Shu-Osher form: each stage is a convex combination of the
    current solution and a forward-Euler update, which is what makes the
    scheme strong-stability-preserving.

    Parameters
    ----------
    u : numpy.ndarray
        Current solution values.
    xi : numpy.ndarray
        Wavenumbers; passed through unchanged to ``rhs``.
    rhs : callable
        ``rhs(u, xi)`` returning the time derivative of ``u``.
    step : float, optional
        Time step. When omitted, the module-level ``dt`` is used, so existing
        calls ``rk3(u, xi, rhs)`` behave exactly as before.

    Returns
    -------
    numpy.ndarray
        The solution after one time step (a newly allocated array).
    """
    # Only look up the global dt when no explicit step is supplied, so the
    # integrator can be used without that global being defined.
    h = dt if step is None else step
    y2 = u + h*rhs(u, xi)
    y3 = 0.75*u + 0.25*(y2 + h*rhs(y2, xi))
    return 1./3 * u + 2./3 * (y3 + h*rhs(y3, xi))

In [None]:
equation = 'Burgers'   # Change this line to determine which kind of shock is simulated ('Burgers' or 'KdV')
def rhs(u, xi, eps=0.02):
    """Return the pseudospectral time derivative u_t for the selected equation.

    Spatial derivatives are computed in Fourier space: the FFT of ``u`` is
    multiplied by (i*xi)**k to obtain the k-th derivative and transformed
    back, keeping only the real part. The equation is chosen by the
    module-level ``equation`` string:

    - 'Burgers': u_t = -u u_x + eps u_xx
    - 'KdV':     u_t = -u u_x - eps u_xxx

    Parameters
    ----------
    u : numpy.ndarray
        Real solution samples on a periodic grid.
    xi : numpy.ndarray
        Wavenumbers in the ordering returned by ``np.fft.fft``
        (e.g. built from ``np.fft.fftfreq``).
    eps : float, optional
        Coefficient of the viscous (Burgers) or dispersive (KdV) term.

    Returns
    -------
    numpy.ndarray
        The time derivative of ``u``.

    Raises
    ------
    ValueError
        If ``equation`` is not 'Burgers' or 'KdV'.
    """
    if equation not in ('Burgers', 'KdV'):
        raise ValueError(f"Unknown equation {equation!r}; expected 'Burgers' or 'KdV'")
    uhat = np.fft.fft(u)
    # The advective term -u*u_x is shared by both equations.
    u_x = np.real(np.fft.ifft(1j*xi*uhat))
    if equation == 'Burgers':
        # (i*xi)**2 = -xi**2  ->  second derivative
        return -u*u_x + eps*np.real(np.fft.ifft(-xi**2*uhat))
    # (i*xi)**3 = -i*xi**3  ->  third derivative
    return -u*u_x - eps*np.real(np.fft.ifft(-1j*xi**3*uhat))

In [None]:
# Grid: m equally spaced points on the periodic interval [-L/2, L/2)
m = 512
L = 2*np.pi
x = np.arange(-m/2,m/2)*(L/m)
# Wavenumbers in the ordering returned by np.fft.fft
xi = np.fft.fftfreq(m)*m*2*np.pi/L

# The time step scales like (max wavenumber)**(-k), where k is the order of the
# highest spatial derivative (3 for KdV, 2 for Burgers).
# NOTE(review): 1.73 is presumably sqrt(3), the imaginary-axis stability bound
# of third-order Runge-Kutta -- confirm.
if equation == 'KdV':
    dt = 1.73/((m/2)**3)
    tmax = 0.005

elif equation == 'Burgers':
    dt = 1.73/((m/2)**2)
    tmax = 2.5

else:
    raise ValueError(f"Unknown equation {equation!r}; expected 'Burgers' or 'KdV'")

# Smoothed step: u ~ 1 for x < 0 and u ~ 0 for x > 0.
# (A sharp step, 1*(x<0), can be substituted here.)
u = -0.5*np.tanh(100*x)+0.5

uhat2 = np.abs(np.fft.fft(u))

num_plots = 50
# Number of time steps between stored snapshots; clamped to at least 1 so the
# modulus test in the loop below is never taken with respect to zero.
nplt = max(1, int((tmax/num_plots)/dt))
nmax = int(round(tmax/dt))

fig = plt.figure(figsize=(12,8))
axes = fig.add_subplot(111)
line, = axes.plot(x,u,lw=3)
axes.set_xlabel(r'$x$',fontsize=30)
plt.tight_layout()
# Presumably closed so the notebook shows only the animation, not this static figure.
plt.close()

# Snapshots (solution, time, Fourier magnitude) consumed by the animation.
frames = [u.copy()]
tt = [0]
uuhat = [uhat2]

for n in range(1,nmax+1):
    u_new = rk3(u,xi,rhs)
    # rk3 builds a new array from arithmetic, so no defensive copy is needed.
    u = u_new
    t = n*dt
    # Store a snapshot every nplt steps
    if n % nplt == 0:
        frames.append(u.copy())
        tt.append(t)
        uhat2 = np.abs(np.fft.fft(u))
        uuhat.append(uhat2)
        
def plot_frame(i):
    """Draw stored snapshot ``i`` into the animation figure.

    Called by ``FuncAnimation`` with the frame index. Replaces the line's data
    with ``frames[i]``, shows the snapshot time ``tt[i]`` in the title, and
    fixes the axis limits to the periodic domain and the expected range of u.

    Returns
    -------
    tuple
        The updated artist, as FuncAnimation expects when blitting.
    """
    line.set_data(x,frames[i])
    # NOTE(review): the spectra collected in uuhat are not plotted here.
    axes.set_title('t= %.2e' % tt[i])
    axes.set_xlim((-np.pi,np.pi))
    axes.set_ylim((-1,2))
    return line,
    
# Assemble the animation from the stored snapshots (one animation frame per
# snapshot, 100 ms apart) and render it inline as a JavaScript widget.
anim = matplotlib.animation.FuncAnimation(
    fig,
    plot_frame,
    frames=len(frames),
    interval=100,
)

HTML(anim.to_jshtml())