In [None]:
import numpy as np
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
from IPython.display import HTML
# Default font size (in points) for all matplotlib text in this notebook.
font = dict(size=15)
matplotlib.rc('font', **font)

In [None]:
def rk3(u, xi, rhs, dt=None, **rhs_kwargs):
    """Advance ``u`` by one step of the 3-stage SSP Runge-Kutta method (SSPRK3).

    Parameters
    ----------
    u : ndarray
        Current solution values.
    xi : object
        Passed through unchanged to ``rhs`` (the wavenumbers for ``rhs`` below).
    rhs : callable
        Right-hand side, called as ``rhs(u, xi, **rhs_kwargs)``; returns du/dt.
    dt : float, optional
        Time step.  If omitted, the notebook-global ``dt`` is used, which is
        how earlier versions of this function behaved.
    **rhs_kwargs
        Extra keyword arguments forwarded to every ``rhs`` call (e.g. ``epsilon``).

    Returns
    -------
    ndarray
        A new array with the solution after one step; ``u`` is not modified.

    Raises
    ------
    NameError
        If ``dt`` is omitted and no global ``dt`` is defined.
    """
    if dt is None:
        # Backward compatibility: callers that omit dt rely on a global `dt`.
        try:
            dt = globals()['dt']
        except KeyError:
            raise NameError("rk3() needs a time step: pass dt=... or define a global dt") from None
    y2 = u + dt*rhs(u, xi, **rhs_kwargs)
    y3 = 0.75*u + 0.25*(y2 + dt*rhs(y2, xi, **rhs_kwargs))
    return 1./3*u + 2./3*(y3 + dt*rhs(y3, xi, **rhs_kwargs))


def rhs(u, xi, epsilon=1.0):
    """Right-hand side of the KdV equation: u_t = -u u_x - epsilon u_xxx.

    Derivatives are spectral: differentiation multiplies the FFT coefficients
    by (i xi) for u_x and by (i xi)**3 = -i xi**3 for u_xxx.

    Parameters
    ----------
    u : ndarray
        Real solution values on a uniform periodic grid.
    xi : ndarray
        Wavenumbers in ``np.fft.fft`` ordering (e.g. built from ``np.fft.fftfreq``).
    epsilon : float, default 1.0
        Coefficient of the dispersive term.
    """
    uhat = np.fft.fft(u)
    u_x = np.real(np.fft.ifft(1j*xi*uhat))
    u_xxx = np.real(np.fft.ifft(-1j*xi**3*uhat))
    return -u*u_x - epsilon*u_xxx


def solve_KdV(u0, tmax=1., m=256, epsilon=1.0, ylims=(-100,300)):
    """Solve u_t + u u_x + epsilon u_xxx = 0 on the periodic domain [-pi, pi).

    Uses Fourier spectral collocation in space and SSPRK3 in time, and returns
    an HTML animation of the solution.

    Parameters
    ----------
    u0 : callable
        Initial condition, evaluated once as ``u0(x)`` on the grid.
    tmax : float
        Final time.
    m : int
        Number of grid points.
    epsilon : float
        Dispersion coefficient, forwarded to ``rhs``.
    ylims : tuple
        Fixed y-axis limits of the animation.

    Returns
    -------
    IPython.display.HTML
        JavaScript/HTML animation with roughly 400 frames.
    """
    # Grid: m equispaced points on [-pi, pi) and the matching wavenumbers.
    L = 2*np.pi
    x = np.arange(-m/2, m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    # 1.73 is presumably ~sqrt(3), SSPRK3's stability limit on the imaginary
    # axis, divided by the largest dispersive frequency (m/2)**3.
    # NOTE(review): dt does not scale with epsilon, so epsilon > 1 may be unstable.
    dt = 1.73/((m/2)**3)
    nmax = int(round(tmax/dt))

    # Store every nplt-th step to get about num_plots frames.  Clamp to 1 so a
    # very short tmax cannot make nplt zero (which would save no frames at all).
    num_plots = 400
    nplt = max(1, int((tmax/num_plots)/dt))

    u = np.asarray(u0(x), dtype=float)
    frames = [u.copy()]
    tt = [0]
    for n in range(1, nmax+1):
        # dt and epsilon are passed explicitly; rk3 returns a fresh array.
        u = rk3(u, xi, rhs, dt, epsilon=epsilon)
        if n % nplt == 0:
            frames.append(u.copy())
            tt.append(n*dt)

    fig, axes = plt.subplots(figsize=(12,8))
    line, = axes.plot(x, frames[0], lw=3)
    axes.set_xlabel(r'$x$', fontsize=30)
    axes.set_xlim((-np.pi, np.pi))
    axes.set_ylim(ylims)
    # Close the static figure so the notebook shows only the animation.
    plt.close(fig)

    def plot_frame(i):
        """Draw saved frame i."""
        line.set_data(x, frames[i])
        axes.set_title('t= %.2e' % tt[i])
        return (line,)

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                              frames=len(frames), interval=100)

    return HTML(anim.to_jshtml())

## Initial sinusoid

Here we set up an experiment similar to the Fermi–Pasta–Ulam–Tsingou (FPUT) experiment, with a single low-frequency Fourier mode as the initial condition on a periodic domain. Notice how, at certain later times, the solution returns close to its initial condition (a near-recurrence).

In [None]:
def u0(x):
    """Single low-frequency Fourier mode of amplitude 100."""
    amplitude = 100
    return amplitude*np.sin(x)
solve_KdV(u0)

## Formation of a soliton train from an initial positive pulse.

In [None]:
def u0(x):
    """Tall, narrow Gaussian pulse (height 2000) centred at x = -2."""
    peak, sharpness, center = 2000, 10, -2
    return peak*np.exp(-sharpness*(x - center)**2)
solve_KdV(u0, tmax=0.005, ylims=(-100, 3000))

## Interaction of two solitons

In [None]:
A, B = 25, 16
def u0(x):
    """Two solitons: height 3*A**2 centred at x = -2 and height 3*B**2 at x = -1."""
    def soliton(c, x0):
        return 3*c**2/np.cosh(0.5*(c*(x - x0)))**2
    return soliton(A, -2.) + soliton(B, -1)
solve_KdV(u0, tmax=0.006, ylims=(-10, 3000))

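The next simulation compares the propagation of a single soliton with the interaction of two solitons. The taller soliton is evolved both on its own and together with a shorter one, so the effect of the collision on its position can be seen.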

In [None]:
# Grid: same discretization as solve_KdV with m = 256.
m = 256
L = 2*np.pi
x = np.arange(-m/2,m/2)*(L/m)
xi = np.fft.fftfreq(m)*m*2*np.pi/L

# rk3 is called below without an explicit time step, so it uses this global dt.
dt = 1.73/((m/2)**3)

# u: both solitons; v: the taller (A) soliton on its own, for comparison.
A, B = 25, 16
u = 3*A**2/np.cosh(0.5*(A*(x+2.)))**2 + 3*B**2/np.cosh(0.5*(B*(x+1)))**2
v = 3*A**2/np.cosh(0.5*(A*(x+2.)))**2

tmax = 0.006
num_plots = 400
# Store every nplt-th step; clamp to 1 so a short tmax cannot give nplt = 0.
nplt = max(1, int((tmax/num_plots)/dt))
nmax = int(round(tmax/dt))

frames = [u.copy()]
vframes = [v.copy()]
tt = [0]

for n in range(1,nmax+1):
    # rk3 returns new arrays, so u and v can be rebound without copying.
    u = rk3(u,xi,rhs)
    v = rk3(v,xi,rhs)
    if n % nplt == 0:
        frames.append(u.copy())
        vframes.append(v.copy())
        tt.append(n*dt)

fig, axes = plt.subplots(figsize=(12,8))
line, = axes.plot(x, frames[0], lw=3, label='two solitons')
line2, = axes.plot(x, vframes[0], lw=3, label='single soliton')
axes.set_xlabel(r'$x$',fontsize=30)
axes.set_xlim((-np.pi,np.pi))
axes.set_ylim((-10,3000))
axes.legend(loc='upper right')
# Close the static figure so only the animation below is displayed.
plt.close(fig)

def plot_frame(i):
    """Draw saved frame i of both simulations."""
    line.set_data(x,frames[i])
    line2.set_data(x,vframes[i])
    axes.set_title('t= %.2e' % tt[i])
    return line, line2

anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                          frames=len(frames), interval=100)

HTML(anim.to_jshtml())

## Formation of a dispersive shockwave

In [None]:
def u0(x):
    """Negative Gaussian well of depth 500 centred at x = 2."""
    depth, sharpness, center = 500, 10, 2
    return -depth*np.exp(-sharpness*(x - center)**2)
solve_KdV(u0, tmax=0.005, epsilon=0.1, ylims=(-600, 300))