In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
from IPython.display import HTML
# Global plotting style: 10-point font for every figure in the notebook.
font = dict(size=10)
matplotlib.rc('font', **font)

# Short aliases for the forward/inverse FFT used by the spectral right-hand sides.
fft, ifft = np.fft.fft, np.fft.ifft

# Module-level coefficient; the functions below each bind their own local or
# parameter named kappa instead of reading this one.
kappa = 1.0

In [None]:
def rk3(u,xi,rhs,dt,params=None):
    """Advance ``u`` by one step of the three-stage, third-order SSP Runge-Kutta scheme.

    Shu-Osher form: each stage is a forward-Euler step, blended with the
    starting state using weights (3/4, 1/4) and then (1/3, 2/3).

    Parameters
    ----------
    u : array
        Current state.
    xi : array
        Wavenumbers, forwarded unchanged to ``rhs``.
    rhs : callable
        ``rhs(u, xi, params)`` returning du/dt.
    dt : float
        Time step.
    params : optional
        Extra argument forwarded unchanged to ``rhs``.

    Returns
    -------
    array
        The state after one step.
    """
    def euler(state):
        return state + dt*rhs(state, xi, params)

    stage1 = euler(u)
    stage2 = 0.75*u + 0.25*euler(stage1)
    return 1./3*u + 2./3*euler(stage2)

def plot_solution(frames, uuhat, x, tt, xi, xlim=(-100*np.pi, 100*np.pi), ylim=(-0.1, 3)):
    """Return an HTML/JavaScript animation of a sequence of 1-D snapshots.

    Parameters
    ----------
    frames : sequence of 1-D arrays
        Solution values on the grid ``x`` at each saved time.
    uuhat : sequence
        Unused; accepted so existing calls keep working.
    x : array
        Grid points.
    tt : sequence of float
        Time of each frame, shown in the title.
    xi : array
        Unused; accepted so existing calls keep working.
    xlim, ylim : tuple of float
        Axis limits. The defaults reproduce the previously hard-coded values.

    Returns
    -------
    IPython.display.HTML
        The rendered animation.
    """
    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(111)
    line, = axes.plot(x,frames[0],lw=3)
    axes.set_xlabel(r'$x$',fontsize=30)
    # The limits do not change between frames, so set them once, before the
    # layout is computed.
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    plt.tight_layout()
    # Close the static figure so only the animation is displayed.
    plt.close()

    def plot_frame(i):
        line.set_data(x,frames[i])
        axes.set_title('t= %.2e' % tt[i])

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                              frames=len(frames), interval=100,
                                              repeat=False)
    return HTML(anim.to_jshtml())

# Heat Equation

In [None]:
def rhs_heatH(uv, xi, params):
    """Time derivative of the first-order (hyperbolized) heat system.

    Row 0 of ``uv`` is ``u`` and row 1 is the auxiliary flux ``v``:

        u_t = v_x
        v_t = c * (u_x - v)

    so ``v`` relaxes toward ``u_x`` at rate ``c = params['c']``. Spatial
    derivatives are computed spectrally with the wavenumbers ``xi``.

    NOTE(review): ``params['kappa']`` is looked up but not used in either
    equation, so the diffusion coefficient is effectively 1. Confirm intent.
    """
    kappa = params['kappa']
    c = params['c']
    u_x = np.real(ifft(1j*xi*fft(uv[0,:])))
    v_x = np.real(np.fft.ifft(1j*xi*fft(uv[1,:])))
    duvdt = np.zeros_like(uv)
    duvdt[1,:] = c*(u_x - uv[1,:])
    duvdt[0,:] = v_x
    return duvdt
    
    
def solve_heatH(u0,tmax=1.,m=256, ylims=(-100,300),use_filter=True,c=20):
    """Solve the hyperbolized heat system on [-pi, pi) and animate u.

    Space: Fourier pseudo-spectral (``rhs_heatH``). Time: ``rk3``. The flux
    starts relaxed, i.e. v(0) = u0'(x).

    Parameters
    ----------
    u0 : callable
        Initial condition, evaluated on the grid.
    tmax : float
        Final time.
    m : int
        Number of grid points.
    ylims : tuple
        y-axis limits of the animation.
    use_filter : bool
        Unused; accepted for signature compatibility with the other solvers.
    c : float
        Relaxation rate (1/tau) passed to ``rhs_heatH``.

    Returns
    -------
    (HTML, list of (2, m) arrays, list of float)
        The animation, the saved states and their times. Frames are saved at
        exactly t_k = k*tmax/num_plots (k = 0..num_plots). The sampling times
        therefore do not depend on the time step, which differs between values
        of ``c``.
    """
    # Grid
    L = 2*np.pi
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    params = {'c': c, 'kappa': 1.}

    # The characteristic speeds of u_t = v_x, v_t = c (u_x - v) are +/- sqrt(c).
    max_speed = np.sqrt(c)
    dt_max = 1.73*L/(m/2)/max_speed*0.1

    # Take an integer number of steps per saved frame. dt is only ever shrunk
    # below dt_max, so the step stays within the stability bound.
    num_plots = 60
    frame_interval = tmax/num_plots
    nplt = max(1, int(np.ceil(frame_interval/dt_max)))
    dt = frame_interval/nplt
    nmax = nplt*num_plots

    uv = np.zeros((2,len(x)))
    uv[0,:] = u0(x)
    uv[1,:] = np.real(ifft(1j*xi*fft(uv[0,:])))

    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(111)
    line, = axes.plot(x,uv[0,:],lw=3)
    axes.set_xlabel(r'$x$',fontsize=30)
    plt.close()

    frames = [uv.copy()]
    tt = [0]

    for n in range(1,nmax+1):
        # rk3 builds its result from fresh arithmetic, so uv is a new array here.
        uv = rk3(uv,xi,rhs_heatH,dt,params=params)
        if n % nplt == 0:
            frames.append(uv.copy())
            tt.append(n*dt)

    def plot_frame(i):
        line.set_data(x,frames[i][0,:])
        axes.set_title('t= %.2e' % tt[i])
        axes.set_xlim((-1*np.pi,1*np.pi))
        axes.set_ylim(ylims)

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                       frames=len(frames), interval=100)

    return HTML(anim.to_jshtml()), frames, tt

def u0(x):
    """Gaussian initial condition exp(-10 x^2) for the heat-equation runs.

    Other profiles, such as np.sin(4*x), can be substituted here. The NLS
    section below defines its own sech-shaped ``u0``.
    """
    return np.exp(-10*x**2)

def rhs_heat(u, xi, params):
    """Spectral right-hand side of the heat equation u_t = kappa * u_xx.

    Parameters
    ----------
    u : array
        Current state on the periodic grid.
    xi : array
        Wavenumbers in ``np.fft.fft`` ordering.
    params : dict or None
        Optional. ``params['kappa']`` sets the diffusion coefficient
        (defaults to 1).

    Returns
    -------
    array
        kappa * u_xx. For real ``u`` the imaginary part of the inverse FFT is
        only round-off, because multiplying by -xi**2 keeps the spectrum
        Hermitian. It is discarded so a real state stays real. Complex input
        is returned complex.
    """
    kappa = 1. if params is None else params.get('kappa', 1.)
    d2u = ifft(-xi**2*fft(u))
    if np.isrealobj(u):
        d2u = np.real(d2u)
    return kappa*d2u
    
    
def solve_heat(u0,tmax=1.,m=256, ylims=(-100,300)):
    """Solve the heat equation u_t = u_xx on [-pi, pi) and animate u.

    Space: Fourier pseudo-spectral (``rhs_heat``). Time: ``rk3``.

    Parameters
    ----------
    u0 : callable
        Initial condition, evaluated on the grid.
    tmax : float
        Final time.
    m : int
        Number of grid points.
    ylims : tuple
        y-axis limits of the animation.

    Returns
    -------
    (HTML, list of arrays, list of float, array)
        The animation, the saved snapshots of u, their times and the grid x.
        Snapshots are saved at exactly t_k = k*tmax/num_plots (k = 0..60), so
        the sampling times do not depend on the time step.
    """
    # Grid
    L = 2*np.pi
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    # Explicit stepping of a diffusion term: the step bound scales like 1/m^2.
    dt_max = 1.73*L/(m/2)**2 * 0.1

    # Take an integer number of steps per saved frame. dt is only ever shrunk
    # below dt_max, so frames land exactly on multiples of frame_interval.
    num_plots = 60
    frame_interval = tmax/num_plots
    nplt = max(1, int(np.ceil(frame_interval/dt_max)))
    dt = frame_interval/nplt
    nmax = nplt*num_plots

    u = u0(x)

    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(111)
    # The heat equation keeps real data real, so plot the real part.
    # |u| would hide sign changes for initial data that is not positive.
    line, = axes.plot(x,np.real(u),lw=3)
    axes.set_xlabel(r'$x$',fontsize=30)
    plt.close()

    frames = [u.copy()]
    tt = [0]

    for n in range(1,nmax+1):
        u = rk3(u,xi,rhs_heat,dt)
        if n % nplt == 0:
            frames.append(u.copy())
            tt.append(n*dt)

    def plot_frame(i):
        line.set_data(x,np.real(frames[i]))
        axes.set_title('t= %.2e' % tt[i])
        axes.set_xlim((-1*np.pi,1*np.pi))
        axes.set_ylim(ylims)

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                       frames=len(frames), interval=100)

    return HTML(anim.to_jshtml()), frames, tt, x



In [None]:
# Shared settings for every heat-equation run in this cell.
heat_kwargs = dict(tmax=2.0, ylims=(-0.1, 1.1))

# Reference solution of the (parabolic) heat equation.
anim, frames, tt, x = solve_heat(u0, m=512, **heat_kwargs)

# Hyperbolized heat equation for several relaxation rates c = 1/tau.
# The numeric suffixes of frames20/frames200/frames2000 are historical. They
# hold the runs for cvals[0], cvals[1] and cvals[2] respectively.
cvals = [10, 50, 100]
m = 512
anim, frames20, tt20 = solve_heatH(u0, m=m, c=cvals[0], **heat_kwargs)
anim, frames200, tt200 = solve_heatH(u0, m=m, c=cvals[1], **heat_kwargs)
anim, frames2000, tt2000 = solve_heatH(u0, m=m, c=cvals[2], **heat_kwargs)

In [None]:
def plot_comparison(frameslist, x, tt, labels=None, num_frames=41):
    """Animate the heat-equation solution against the hyperbolized solutions.

    Parameters
    ----------
    frameslist : list
        ``frameslist[0]`` holds 1-D heat-equation snapshots. Each later entry
        holds (2, m) hyperbolized states, of which row 0 (u) is plotted.
    x : array
        Grid points.
    tt : sequence of float
        Frame times shown in the title (from the first solution).
    labels : list of str, optional
        Legend entries. Defaults to 'Heat equation' followed by one entry per
        value in the global ``cvals``.
    num_frames : int
        Maximum number of frames to animate (41 by default). The value is
        clamped to the shortest list so indexing never runs past the end.

    Returns
    -------
    IPython.display.HTML
        The rendered animation.
    """
    if labels is None:
        # Raw string: '\l' is not a valid escape sequence in a normal literal.
        labels = ['Heat equation']+[r'$\lambda$='+str(cval) for cval in cvals]

    def profile(frames, i):
        # 1-D snapshots are u itself. Stacked (u, v) states contribute row 0.
        snapshot = frames[i]
        return np.abs(snapshot if np.ndim(snapshot) == 1 else snapshot[0,:])

    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(111)
    lines = [axes.plot(x,profile(frames,0),lw=3)[0] for frames in frameslist]
    axes.set_xlabel(r'$x$',fontsize=30)
    axes.legend(labels)
    plt.close()

    n_anim = min([num_frames, len(tt)] + [len(frames) for frames in frameslist])

    def plot_frame(i):
        for line, frames in zip(lines, frameslist):
            line.set_data(x,profile(frames,i))
        axes.set_title('t= %.2e' % tt[i])
        axes.set_xlim((-1*np.pi,1*np.pi))
        axes.set_ylim((-0.1,1.1))

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                              frames=n_anim,
                                              interval=50,
                                              repeat=False)
    return HTML(anim.to_jshtml())

heat_solutions = [frames, frames20, frames200, frames2000]
plot_comparison(heat_solutions, x, tt)

In [None]:
# Static snapshot of all heat-equation runs at one saved frame (paper figure).
frameslist = [frames, frames20, frames200, frames2000]
i = 2  # index of the saved frame to draw

fig = plt.figure(figsize=(12,8))
axes = fig.add_subplot(111)
axes.plot(x, frames[0], '--k')  # initial data, for reference

lines = []
for j, iframes in enumerate(frameslist):
    # The heat solution (j == 0) is 1-D. Hyperbolized states store u in row 0.
    u_i = iframes[i] if j == 0 else iframes[i][0,:]
    lines.append(axes.plot(x, np.abs(u_i), lw=3)[0])

axes.set_xlabel(r'$x$', fontsize=20)
axes.set_ylabel(r'$u$', fontsize=20)
legend_labels = ['Initial data', 'Heat equation'] + [r'$\tau^{-1}$=' + str(cval) for cval in cvals]
axes.legend(legend_labels, fontsize=15)
plt.savefig('heat_comparison3.pdf')

# Nonlinear Schrödinger (NLS) equation

## Solution of the original NLS equation

In [None]:


def rhs_NLS(u, xi, filtr, kappa=1.):
    """Right-hand side of the cubic NLS equation, u_t = i u_xx + i kappa |u|^2 u.

    ``u_xx`` is computed spectrally from the wavenumbers ``xi``. The third
    positional argument, ``filtr``, matches the three-argument call made by
    ``rk3`` and is not used.
    """
    u_xx = np.fft.ifft(-xi**2*fft(u))
    nonlinear = 1j*kappa*u*np.abs(u)**2
    return 1j*u_xx + nonlinear
    
    
def solve_NLS(u0,tmax=1.,m=256, ylims=(-100,300)):
    """Solve the cubic NLS u_t = i u_xx + i |u|^2 u on [-2 pi, 2 pi) and animate |u|.

    Space: Fourier pseudo-spectral (``rhs_NLS`` with its default kappa = 1).
    Time: ``rk3``.

    Parameters
    ----------
    u0 : callable
        Initial condition, evaluated on the grid (may be complex).
    tmax : float
        Final time.
    m : int
        Number of grid points.
    ylims : tuple
        y-axis limits of the animation.

    Returns
    -------
    (HTML, list of complex arrays, list of float, array)
        The animation, the saved snapshots of u, their times and the grid x.
        Snapshots are saved at exactly t_k = k*tmax/num_plots (k = 0..40), so
        the sampling times do not depend on the time step.
    """
    # Grid
    L = 4*np.pi
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    # Explicit stepping of the dispersive term: the bound scales like 1/m^2.
    dt_max = 1.73*L/(m/2)**2 * 0.1

    # Take an integer number of steps per saved frame. dt is only ever shrunk
    # below dt_max, so frames land exactly on multiples of frame_interval.
    num_plots = 40
    frame_interval = tmax/num_plots
    nplt = max(1, int(np.ceil(frame_interval/dt_max)))
    dt = frame_interval/nplt
    nmax = nplt*num_plots

    u = np.asarray(u0(x), dtype='complex128')

    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(111)
    line, = axes.plot(x,np.abs(u),lw=3)
    axes.set_xlabel(r'$x$',fontsize=30)
    plt.close()

    frames = [u.copy()]
    tt = [0]

    for n in range(1,nmax+1):
        u = rk3(u,xi,rhs_NLS,dt)
        if n % nplt == 0:
            frames.append(u.copy())
            tt.append(n*dt)

    def plot_frame(i):
        line.set_data(x,np.abs(frames[i]))
        axes.set_title('t= %.2e' % tt[i])
        axes.set_xlim((-1*np.pi,1*np.pi))
        axes.set_ylim(ylims)

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                       frames=len(frames), interval=100, repeat=False)

    return HTML(anim.to_jshtml()), frames, tt, x

## Solution of the hyperbolized NLS system

In [None]:
def rhs_NLSH(uv, xi, params):
    """Right-hand side of the hyperbolized NLS system.

    Row 0 of ``uv`` is ``u`` and row 1 is the auxiliary variable ``v``:

        u_t = i v_x + i kappa |u|^2 u
        v_t = i (1/tau) (v - u_x)

    ``params`` must provide ``'kappa'`` and ``'tau'``. Derivatives are
    computed spectrally with the wavenumbers ``xi``.
    """
    kappa = params['kappa']
    c = 1/params['tau']
    u, v = uv[0,:], uv[1,:]
    u_x = ifft(1j*xi*fft(u))
    v_x = np.fft.ifft(1j*xi*fft(v))
    duvdt = np.zeros_like(uv, dtype='complex128')
    duvdt[0,:] = 1j*v_x + 1j*kappa*u*np.abs(u)**2
    duvdt[1,:] = 1j*c*(v - u_x)
    return duvdt
    
    
def solve_NLSH(u0,tmax=1.,m=256, ylims=(-100,300),use_filter=True,tau=1/20):
    """Solve the hyperbolized NLS system on [-2 pi, 2 pi) and animate |u|.

    Space: Fourier pseudo-spectral (``rhs_NLSH``). Time: ``rk3``. The
    auxiliary variable starts relaxed, i.e. v(0) = u0'(x).

    Parameters
    ----------
    u0 : callable
        Initial condition, evaluated on the grid (may be complex).
    tmax : float
        Final time.
    m : int
        Number of grid points.
    ylims : tuple
        y-axis limits of the animation.
    use_filter : bool
        Unused. ``rhs_NLSH`` takes no spectral filter, so the flag has no
        effect. It is kept for signature compatibility.
    tau : float
        Relaxation time passed to ``rhs_NLSH``.

    Returns
    -------
    (HTML, list of (2, m) complex arrays, list of float)
        The animation, the saved states and their times. Frames are saved at
        exactly t_k = k*tmax/num_plots (k = 0..40), so the sampling times do
        not depend on the time step, which varies with ``tau``.
    """
    # Grid
    L = 4*np.pi
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    params = {'tau': tau, 'kappa': 1.}

    u_init = u0(x)
    # Step-size speed estimate: relaxation speed sqrt(1/tau) or a bound
    # based on the initial amplitude, whichever is larger.
    max_speed = max(np.sqrt(1/tau),np.max(np.abs(u_init)**3))
    dt_max = 1.73*L/(m/2)/max_speed

    # Take an integer number of steps per saved frame. dt is only ever shrunk
    # below dt_max, so frames land exactly on multiples of frame_interval.
    num_plots = 40
    frame_interval = tmax/num_plots
    nplt = max(1, int(np.ceil(frame_interval/dt_max)))
    dt = frame_interval/nplt
    nmax = nplt*num_plots

    uv = np.zeros((2,len(x)),dtype='complex128')
    uv[0,:] = u_init
    uv[1,:] = ifft(1j*xi*fft(uv[0,:]))

    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(111)
    line, = axes.plot(x,np.abs(uv[0,:]),lw=3)
    axes.set_xlabel(r'$x$',fontsize=30)
    plt.close()

    frames = [uv.copy()]
    tt = [0]

    for n in range(1,nmax+1):
        uv = rk3(uv,xi,rhs_NLSH,dt,params=params)
        if n % nplt == 0:
            frames.append(uv.copy())
            tt.append(n*dt)

    def plot_frame(i):
        line.set_data(x,np.abs(frames[i][0,:]))
        axes.set_title('t= %.2e' % tt[i])
        axes.set_xlim((-L/2,L/2))
        axes.set_ylim(ylims)

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                       frames=len(frames), interval=100, repeat=False)

    return HTML(anim.to_jshtml()), frames, tt

In [None]:
def u0(x,alpha=20.):
    """Sech-shaped soliton initial condition for the NLS runs.

    Returns sqrt(2 alpha / kappa) * sech(sqrt(alpha) x) * exp(i cc x / 2),
    with kappa = 1 and cc = 1, so the phase factor is exp(i x / 2).
    """
    kappa, cc = 1., 1.
    envelope = np.sqrt(2*alpha/kappa)/np.cosh(np.sqrt(alpha)*x)
    phase = np.exp(1j*x*cc/2.)
    return envelope * phase

In [None]:
m = 512
taui_values = [40, 100, 800]  # values of 1/tau for the hyperbolized runs
lw = 2                        # line width read by plot_comparison below

# Shared settings for the reference and hyperbolized NLS runs.
nls_kwargs = dict(tmax=2.0, ylims=(-1,6.5), m=m)
anim, NLS_soln, tt, x = solve_NLS(u0, **nls_kwargs)
anim, NLSH_soln1, tt1 = solve_NLSH(u0, tau=1/taui_values[0], **nls_kwargs)
anim, NLSH_soln2, tt2 = solve_NLSH(u0, tau=1/taui_values[1], **nls_kwargs)
anim, NLSH_soln3, tt3 = solve_NLSH(u0, tau=1/taui_values[2], **nls_kwargs)

In [None]:
def plot_comparison(frameslist, x, tt, labels=None, num_frames=41):
    """Animate |u| of the NLS solution against the hyperbolized NLS solutions.

    Parameters
    ----------
    frameslist : list
        ``frameslist[0]`` holds 1-D NLS snapshots, drawn in black. Each later
        entry holds (2, m) hyperbolized states, of which row 0 (u) is plotted.
    x : array
        Grid points.
    tt : sequence of float
        Frame times shown in the title (from the first solution).
    labels : list of str, optional
        Legend entries. Defaults to 'NLS' followed by one entry per value in
        the global ``taui_values``.
    num_frames : int
        Maximum number of frames to animate (41 by default). The value is
        clamped to the shortest list so indexing never runs past the end.

    Returns
    -------
    IPython.display.HTML
        The rendered animation. Line widths come from the global ``lw``.
    """
    if labels is None:
        labels = ['NLS'] + [r'$1/\tau$='+str(taui) for taui in taui_values]

    def modulus(frames, i):
        # 1-D snapshots are u itself. Stacked (u, v) states contribute row 0.
        snapshot = frames[i]
        return np.abs(snapshot if np.ndim(snapshot) == 1 else snapshot[0,:])

    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(111)
    lines = []
    for j, frames in enumerate(frameslist):
        # The reference NLS curve is black. The others follow the color cycle.
        fmt = ('k',) if j == 0 else ()
        lines.append(axes.plot(x,modulus(frames,0),*fmt,lw=lw)[0])
    axes.set_xlabel(r'$x$',fontsize=30)
    axes.legend(labels)
    plt.close()

    n_anim = min([num_frames, len(tt)] + [len(frames) for frames in frameslist])

    def plot_frame(i):
        for line, frames in zip(lines, frameslist):
            line.set_data(x,modulus(frames,i))
        axes.set_title('t= %.2e' % tt[i])
        axes.set_xlim((-2*np.pi,2*np.pi))
        axes.set_ylim((-0.1,7))

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                              frames=n_anim,
                                              interval=50,
                                              repeat=False)
    return HTML(anim.to_jshtml())

In [None]:
nls_solutions = [NLS_soln, NLSH_soln1, NLSH_soln2, NLSH_soln3]
plot_comparison(nls_solutions, x, tt)

In [None]:
# This code generates the figure in the paper:
# |u| at saved frame i for NLS (solid black) and the three hyperbolized runs.
i = 40
plt.figure(figsize=(8,4))
plt.plot(x, np.abs(NLS_soln[i]), '-k')
for hyperbolic_soln, dash in zip((NLSH_soln1, NLSH_soln2, NLSH_soln3),
                                 ('dotted', '-.', '--')):
    plt.plot(x, np.abs(hyperbolic_soln[i][0,:]), linestyle=dash)
plt.legend(['NLS'] + [r'$\tau^{-1}$=' + str(taui) for taui in taui_values[:3]])
plt.xlabel('x')
plt.ylabel(r'$|u|$')
plt.savefig('NLS_comparison.pdf', bbox_inches="tight");

# Kuramoto–Sivashinsky (KS) equation

In [None]:
def rk3(u,xi,rhs,dt,filtr=None,params=None):
    """One step of the three-stage, third-order SSP Runge-Kutta scheme (Shu-Osher form).

    This definition replaces the earlier ``rk3`` in this notebook and forwards
    a spectral filter to the right-hand side as ``rhs(u, xi, filtr)``. It also
    accepts the earlier version's ``params`` keyword: when ``filtr`` is None,
    ``params`` is forwarded instead. The heat and NLS cells therefore still
    work if they are re-run after this cell.

    Parameters
    ----------
    u : array
        Current state.
    xi : array
        Wavenumbers, forwarded unchanged to ``rhs``.
    rhs : callable
        ``rhs(u, xi, extra)`` returning du/dt.
    dt : float
        Time step.
    filtr : optional
        Third argument for ``rhs``; takes precedence over ``params``.
    params : optional
        Third argument for ``rhs`` when ``filtr`` is not given.

    Returns
    -------
    array
        The state after one step.
    """
    extra = params if filtr is None else filtr
    y2 = u + dt*rhs(u,xi,extra)
    y3 = 0.75*u + 0.25*(y2 + dt*rhs(y2,xi,extra))
    u_new = 1./3 * u + 2./3 * (y3 + dt*rhs(y3,xi,extra))
    return u_new
    
def plot_pcolor(frames,x=None,t=None):
    """Space-time pseudocolor plot (``plt.pcolormesh``) of saved snapshots of u.

    Each entry of ``frames`` may be a 1-D snapshot of u or a stacked state
    whose row 0 is u (as saved by ``solve_KSH``). Image rows are time levels
    and columns are grid points.

    Parameters
    ----------
    frames : sequence of arrays
        Saved snapshots, all with the same number of grid points.
    x, t : arrays, optional
        Grid points and frame times. They are used as mesh coordinates only
        when both are given.
    """
    rows = [fr if np.ndim(fr) == 1 else fr[0,:] for fr in frames]
    ff = np.zeros((len(rows),len(rows[0])))
    for k, row in enumerate(rows):
        ff[k,:] = row

    plt.figure(figsize=(6,6))
    if (x is not None) and (t is not None):
        pl = plt.pcolormesh(x,t,ff,cmap='viridis',linewidth=0,rasterized=True)
    else:
        pl = plt.pcolormesh(ff,linewidth=0)
    # Edges take the face color, presumably to avoid hairline gaps between
    # cells in the saved vector (PDF) output.
    pl.set_edgecolor('face')

## Solution of the hyperbolized KS system

In [None]:
def solve_KSH(u0,tmax=1.,L=12*np.pi,m=256, ylims=(-100,300),use_filter=True,c1=20,c2=20,c3=20,animate=False):
    """Solve the hyperbolized Kuramoto-Sivashinsky system on [-L/2, L/2).

    The state stacks (u, v, w, z) and evolves as (see ``rhs_KSH``):

        u_t = -z_x - u * (filtered u_x) - v_x
        v_t =  c1 (w_x - z)
        w_t =  c2 (v_x - w)
        z_t = -c3 (u_x - v)

    Space: Fourier pseudo-spectral. Time: ``rk3`` with the spectral filter
    forwarded as the right-hand side's third argument.

    Parameters
    ----------
    u0 : callable
        Initial condition for u. v, w and z start as u0', u0'' and u0'''.
    tmax : float
        Final time.
    L : float
        Domain length.
    m : int
        Number of grid points.
    ylims : tuple
        y-axis limits of the optional animation.
    use_filter : bool
        If True, zero the top third of wavenumbers in the u_x factor of the
        nonlinear term.
    c1, c2, c3 : float
        Relaxation rates.
    animate : bool
        If True, build and return an HTML animation of u.

    Returns
    -------
    (HTML or None, list of (4, m) arrays, list of float)
        The animation (None unless ``animate``), the saved states and their
        times.
    """

    def rhs_KSH(uvwz, xi, filtr):
        u = uvwz[0,:]
        v = uvwz[1,:]
        w = uvwz[2,:]
        z = uvwz[3,:]
        duvwzdt = np.zeros_like(uvwz)
        uhat = fft(u)
        vhat = fft(v)
        what = fft(w)
        zhat = fft(z)
        duvwzdt[0,:] = - np.real(ifft(1j*xi*zhat)) -u*np.real(ifft(1j*xi*filtr*uhat)) - np.real(ifft(1j*xi*vhat))
        duvwzdt[1,:] =   c1*(np.real(ifft(1j*xi*what))-z)
        duvwzdt[2,:] =   c2*(np.real(ifft(1j*xi*vhat))-w)
        duvwzdt[3,:] = - c3*(np.real(ifft(1j*xi*uhat))-v)
        return duvwzdt

    # Grid
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L

    # Speed estimate used to pick the time step.
    u_max = np.max(u0(x))
    max_speed = max(c1,c2,0.5*(u_max+np.sqrt(u_max**2+4*c3)))
    print(max_speed)
    dt = 1.73*L/(m/2)/max_speed * 0.1
    print(dt)

    filtr = np.ones_like(xi)
    if use_filter:
        xi_max = np.max(np.abs(xi))
        filtr[np.abs(xi)>xi_max*2./3] = 0.

    # Auxiliary variables start at the corresponding derivatives of u0.
    uvwz = np.zeros((4,len(x)))
    uvwz[0,:] = u0(x)
    uhat0 = fft(uvwz[0,:])
    uvwz[1,:] = np.real(ifft(1j*xi*uhat0))
    uvwz[2,:] = np.real(ifft(-xi**2*uhat0))
    uvwz[3,:] = np.real(ifft(-1j*xi**3*uhat0))

    num_plots = 400
    # At least one step per saved frame, so the modulus below is never zero.
    nplt = max(1, int(np.floor((tmax/num_plots)/dt)))
    nmax = int(round(tmax/dt))

    frames = [uvwz.copy()]
    tt = [0]

    for n in range(1,nmax+1):
        uvwz = rk3(uvwz,xi,rhs_KSH,dt,filtr=filtr)
        if n % nplt == 0:
            frames.append(uvwz.copy())
            tt.append(n*dt)

    if not animate:
        return None, frames, tt

    fig = plt.figure(figsize=(12,8))
    axes = fig.add_subplot(111)
    line, = axes.plot(x,frames[0][0,:],lw=3)
    axes.set_xlabel(r'$x$',fontsize=30)
    plt.close()

    def plot_frame(i):
        line.set_data(x,frames[i][0,:])
        axes.set_title('t= %.2e' % tt[i])
        axes.set_xlim((-L/2,L/2))
        axes.set_ylim(ylims)

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                       frames=len(frames), interval=100)

    return HTML(anim.to_jshtml()), frames, tt

In [None]:
def u0(x):
    """Initial condition sin(x) for the Kuramoto-Sivashinsky runs."""
    return np.sin(x)

# Hyperbolized KS on [-L/2, L/2) with all three relaxation rates equal to c.
m = 256
L = 12*np.pi
c = 200
relaxation_rates = dict(c1=c, c2=c, c3=c)
anim, frames2, tt2 = solve_KSH(u0, tmax=6.e2, L=L, ylims=(-1.1,2.5), m=m, **relaxation_rates)

In [None]:
# Space-time plot of u from the hyperbolized KS run (same grid as solve_KSH).
x = (L/m)*np.arange(-m/2, m/2)
plot_pcolor(frames2, x, tt2)
plt.xlabel('$x$', fontsize=15)
plt.ylabel('$t$', fontsize=15)
plt.savefig('KSh_m256_tau200.pdf', dpi=300)

## Solution of the original KS equation (Lie–Trotter splitting)

In [None]:
def rhs_f(u, xi, filtr):
    """Nonlinear (non-stiff) part of the KS equation: -u * u_x.

    u_x is computed spectrally. ``filtr`` belongs to the ``rk3``
    right-hand-side signature but is not applied here.
    """
    u_x = np.real(ifft(1j*xi*fft(u)))
    return -u*u_x

def substep_L(u, xi, dt):
    """Integrate the stiff linear KS part, u_t = -u_xx - u_xxxx, exactly over ``dt``.

    In Fourier space this is multiplication by exp((xi^2 - xi^4) dt). The real
    part of the inverse transform is returned.
    """
    propagator = np.exp((xi**2-xi**4)*dt)
    return np.real(np.fft.ifft(propagator*fft(u)))


def solve_KS_Lie_Trotter(m,dt,L,T,u0=None):
    """Solve u_t + u u_x + u_xx + u_xxxx = 0 on [-L/2, L/2) by Lie-Trotter splitting.

    Each step applies the exact linear propagator (``substep_L``) and then
    one ``rk3`` step of the nonlinear term (``rhs_f``).

    Parameters
    ----------
    m : int
        Number of grid points.
    dt : float
        Time step.
    L : float
        Domain length.
    T : float
        Final time.
    u0 : callable, optional
        Initial condition evaluated on the grid. Defaults to sin(x).

    Returns
    -------
    frames, uuhat, x, tt, xi
        Saved snapshots of u, the magnitudes |fft(u)| at the same times, the
        grid, the saved times and the wavenumbers.
    """
    x = np.arange(-m/2,m/2)*(L/m)
    xi = np.fft.fftfreq(m)*m*2*np.pi/L
    # All-pass filter: rhs_f does not use it, but rk3 forwards one.
    filtr = np.ones_like(xi)

    u = np.sin(x) if u0 is None else np.asarray(u0(x), dtype=float)

    num_plots = 1000
    # At least one step per saved frame, so the modulus below is never zero.
    nplt = max(1, int(np.floor((T/num_plots)/dt)))
    nmax = int(round(T/dt))

    frames = [u.copy()]
    tt = [0]
    uuhat = [np.abs(np.fft.fft(u))]

    for n in range(1,nmax+1):
        # substep_L and rk3 both return new arrays, so no copies are needed.
        u = substep_L(u,xi,dt)
        u = rk3(u,xi,rhs_f,dt,filtr)
        if n % nplt == 0:
            frames.append(u.copy())
            tt.append(n*dt)
            uuhat.append(np.abs(np.fft.fft(u)))
    return frames, uuhat, x, tt, xi

def plot_pcolor(frames,x=None,t=None):
    """Space-time pseudocolor plot (``plt.pcolor``) of saved snapshots of u.

    This definition shadows the earlier ``plot_pcolor``. It therefore accepts
    both 1-D snapshots (as from ``solve_KS_Lie_Trotter``) and stacked states
    whose row 0 is u (as from ``solve_KSH``), so earlier cells still work when
    re-run. Image rows are time levels and columns are grid points.

    Parameters
    ----------
    frames : sequence of arrays
        Saved snapshots, all with the same number of grid points.
    x, t : arrays, optional
        Grid points and frame times. They are used as mesh coordinates only
        when both are given.
    """
    rows = [fr if np.ndim(fr) == 1 else fr[0,:] for fr in frames]
    ff = np.zeros((len(rows),len(rows[0])))
    for k, row in enumerate(rows):
        ff[k,:] = row

    plt.figure(figsize=(6,6))
    if (x is not None) and (t is not None):
        plt.pcolor(x,t,ff,linewidth=0,rasterized=True)
    else:
        plt.pcolor(ff)

In [None]:
# Parameters of the reference (non-hyperbolized) KS run.
m = 512
umax = 2                  # not referenced elsewhere in this notebook
L = 12*np.pi              # domain length; the grid is [-L/2, L/2)
T = 1000                  # final time
dt = 1.73/(m/2) * 0.9     # time step

frames, uuhat, x, tt, xi = solve_KS_Lie_Trotter(m=m, dt=dt, L=L, T=T)

In [None]:
# Space-time plot of the reference KS solution.
plot_pcolor(frames, x, tt)
plt.xlabel('$x$', fontsize=15)
plt.ylabel('$t$', fontsize=15)
plt.savefig('KS.pdf', dpi=300)