# Third-order Lax-Wendroff
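Derivations and summary results are presented in the file LW_ImEx.lyx. This notebook derives the explicit, implicit and ImEx third-order Lax-Wendroff (LW3) schemes with sympy, plots their stability, and tests the implicit scheme on a periodic step profile.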

In [None]:
#For the LaTeX equations (such as eqnarray) in this document to work, include the following in file
#~/.jupyter/_config.yml
#
#parse:
#  myst_enable_extensions:  # default extensions to enable in the myst parser. See https://myst-parser.readthedocs.io/en/latest/using/syntax-optional.html
#     - amsmath
#
#(the default ~/.jupyter/_config.yml will have amsmath commented out)
#
#This notebook uses sympy and sparse linear algebra
import numpy as np
from numpy import exp
import sympy as sy
from sympy.matrices import Matrix, MatrixSymbol
from fractions import Fraction as Fr
import matplotlib.pyplot as plt
from matplotlib import colors
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve
import matplotlib.pyplot as plt
import numpy as np
# Convention is that sympy symbols start with capital letters

In [None]:
# Derivation of spatial gradients from a cubic polynomial
# p(x) = a*x**3 + b*x**2 + c*x + d fitted through the grid points x = -2, -1, 0, 1
# (in units of dx), i.e. points i-2, i-1, i and i+1. At x = 0 (point i):
# p'(0) = c, p''(0) = 2*b and p'''(0) = 6*a.
Psi = sy.Matrix(sy.symarray("Psi", (4,)))               # Grid point values at i-2, i-1, i and i+1
polyM = Matrix([[x**3, x**2, x, 1] for x in (-2, -1, 0, 1)])  # Row [x^3, x^2, x, 1] for each grid point
PolyCoeffs = polyM.solve(Psi)                           # Coefficients a,b,c,d of the polynomial
Ddx = PolyCoeffs[2]
D2dx2 = 2*PolyCoeffs[1]
D3dx3 = 6*PolyCoeffs[0]
print('d/dx =', Ddx, '\nd2/dx2 =', D2dx2, '\nd3/dx3 =', D3dx3)

In [None]:
# The explicit LW3e increment: a third-order Taylor expansion in time in which each time
# derivative is replaced by -C times a spatial derivative (grid units, C = Courant number).
# Chi2 and Chi3 scale (limit) the 2nd- and 3rd-order temporal terms.
C, Chi2, Chi3 = sy.symbols("C, Chi2, Chi3")   # The Courant number and the HO limiters
LW3e = - C*Ddx + Chi2*C**2/2*D2dx2 - Chi3*C**3/6*D3dx3
print('LW3e increment is', LW3e)
print('LW3e increment as coefficients of grid points is\n',
      sy.collect(sy.expand(LW3e), Psi))

In [None]:
# The implicit LW3i scheme: the Taylor series taken backwards in time from the new level,
# so that phi^n = (1 - LW3i) phi^{n+1}. The same limiters Chi2 and Chi3 as in LW3e scale the
# 2nd- and 3rd-order temporal terms, consistent with LW3iMatrix and A_LW3_denom below
# (without them the "2nd/1st order in time" implicit plots would still be 3rd order).
LW3i = - C*Ddx - Chi2*C**2/2*D2dx2 - Chi3*C**3/6*D3dx3
print('LW3i increment is ', LW3i)
print('LW3i increment as coefficients of grid points is\n', sy.collect(sy.expand(LW3i), Psi))

# The amplification factor for LW3_AdImEx: a fraction Alp of the increment is treated
# implicitly and 1-Alp explicitly. Substituting psi = exp(i*j*Kdx) for the point i+j gives A(Kdx).
[Kdx, Alp] = sy.symbols("Kdx, Alp")
A_LW3ImEx = (1 + (1-Alp)*LW3e)/(1 - Alp*LW3i)
A_LW3ImEx = A_LW3ImEx.subs({Psi[0] : sy.E**(-2*sy.I*Kdx), Psi[1] : sy.E**(-sy.I*Kdx), Psi[2] : 1, Psi[3] : sy.E**(sy.I*Kdx)})
print('Amplification factor for LW3 ImEx is \n', A_LW3ImEx)

In [None]:
# Stability Analysis of LW3 AdImEx

def alpha(co):
    """Off-centering a = 1 - 1/c used by AdImEx, never negative.

    co: Courant number (scalar or numpy array). It is clipped below at 0.1 before
    dividing so that c = 0 is safe; the result is 0 for every c <= 1."""
    coSafe = np.maximum(co, 0.1)
    return np.maximum(1 - 1/coSafe, 0)

def chi(co):
    """Courant-number-dependent limiters on the higher-order temporal corrections.

    chi2 = min(1, 1/(2c)) scales the 2nd-order term and chi3 = min(1, 1/(4c^2)) the
    3rd-order term. Both corrections are applied in full for c <= 1/2 and damped for larger
    Courant numbers, so that chi2*c^2 grows only like c/2 and chi3*c^3 like c/4.
    (Using maximum here would make the limiters >= 1, never damping at large c.)
    co: Courant number (scalar or numpy array); c and c^2 are floored at 0.1 before
        dividing to avoid division by zero.
    Returns a dict with keys "chi2" and "chi3", suitable for A_LW3(..., **chi(co))."""
    return {"chi2": np.minimum(1, 0.5/np.maximum(co, 0.1)),
            "chi3": np.minimum(1, 0.25/np.maximum(co**2, 0.1))}

# Numerical versions of A_LW3ImEx, keyed by the sympy expression itself. The symbolic
# expression is converted once with sympy.lambdify instead of being substituted into
# (sympy subs) for every (c, a, kdx) sample; re-deriving A_LW3ImEx gives a new key.
_A_LW3_numeric = {}

def A_LW3(co, a, kdx, chi2 = 1, chi3 = 1):
    """Amplification factor for LW3 for Courant number co, off-centering a, wave kdx (= k*dx)
       chi2 is the limiter on the 2nd order temporal correction and chi3 on the 3rd order.
       Evaluates the symbolic A_LW3ImEx numerically and returns it as a Python complex."""
    numeric = _A_LW3_numeric.get(A_LW3ImEx)
    if numeric is None:
        numeric = sy.lambdify((C, Alp, Kdx, Chi2, Chi3), A_LW3ImEx, 'numpy')
        _A_LW3_numeric[A_LW3ImEx] = numeric
    return complex(numeric(co, a, kdx, chi2, chi3))

# Off-centering for the three panels of each figure: explicit, a = 1-1/c (alpha) and implicit
a =     [0, alpha, 1]
# Temporal limiters, one figure each: 3rd, 2nd and 1st order in time, then Courant-dependent (chi)
chis =  [{"chi2":1, "chi3":1}, {"chi2":1, "chi3":0}, {"chi2":0, "chi3":0}, chi]
titles = [["3rd order Explicit", "3rd order a=1-1/c", "3rd order Implicit"],
          ["2nd order in time", "2nd order in time", "2nd order in time"],
          ["1st order in time", "1st order in time", "1st order in time"],
          ["chi2 = 1/2c, chi3 = 1/4c^2", "chi2 = 1/2c, chi3 = 1/4c^2", "chi2 = 1/2c, chi3 = 1/4c^2"]]

kdxs = np.linspace(1e-6, 2*np.pi, 37)
cs = np.arange(0, 5.1, 0.1) #10**(np.linspace(-1, 1, 81))
magA = np.zeros([len(kdxs), len(cs)])   # |A| on the (kdx, c) grid, refilled for every panel

for ich, chiSpec in enumerate(chis):
    fig, axs = plt.subplots(1, len(a), figsize=(12,4), layout='constrained')
    if ich == 0:
        fig.suptitle("Lax-Wendroff Amplification Factor Magnitudes")
    for i, (aSpec, ax) in enumerate(zip(a, axs)):
        # Fill |A| one Courant number (column) at a time
        for ic, co in enumerate(cs):
            ai = aSpec(co) if callable(aSpec) else aSpec
            ch = chiSpec(co) if callable(chiSpec) else chiSpec
            for ik, kdx in enumerate(kdxs):
                magA[ik, ic] = abs(A_LW3(co, ai, kdx, **ch))
        axplot = ax.contourf(cs, kdxs, magA, np.arange(0, 2.1, 0.1))
        # Dotted guide lines at c = 1 and c = 2
        for cGuide in (1, 2):
            ax.axvline(x=cGuide, color="black", linestyle=":")
        fig.colorbar(axplot, ax=ax, orientation='horizontal')
        # Black contours at |A| = 0 and just above |A| = 1 (the stability limit)
        ax.contour(cs, kdxs, magA, [0, 1+1e-6], colors=['k', 'k'])
        ax.set(xlabel=r'$c$', ylabel=r'$k\Delta x$', title=titles[ich][i])

    plt.show()

In [None]:
# Finding the root of the problem of LW32i:
# examine the LW3 AdImEx amplification factor symbolically at the 2*dx wave (k*dx = pi).
# NOTE(review): kdx is not read by A_LW3_denom or A_LW3_pi below (they hard-code k*dx = pi),
# and this rebinds the global a (the off-centering list of the stability plots) to a sympy
# symbol; re-run the earlier cell before reusing that name.
kdx = np.pi
chi2, chi3, a, c = sy.symbols("chi2, chi3, a, c")

def A_LW3_denom(c, a, chi2, chi3):
    """Denominator 1 - a*LW3i of the LW3 AdImEx amplification factor at k*dx = pi.

    At k*dx = pi the grid values at j-2, j-1, j and j+1 are +1, -1, +1, -1, so the
    implicit increment reduces to the alternating sum of its four stencil weights.
    c: Courant number, a: off-centering,
    chi2/chi3: fractions of the 2nd/3rd-order in time terms.
    Works with plain numbers or sympy symbols."""
    cSqr = chi2*c**2
    cCub = chi3*c**3
    # Weights of the implicit increment on the points j-2, j-1, j and j+1
    wJm2 = cCub/6 - c/6
    wJm1 = -cCub/2 - cSqr/2 + c
    wJ = cCub/2 + cSqr - c/2
    wJp1 = -cCub/6 - cSqr/2 - c/3
    return 1 - a*(wJm2 - wJm1 + wJ - wJp1)

def A_LW3_pi(c, a, chi2, chi3):
    """Amplification factor for LW3 AdImEx at the 2*dx wave, k*dx = pi.

    c: Courant number, a: off-centering (fraction treated implicitly),
    chi2/chi3: limiters on the 2nd/3rd-order temporal corrections.
    At k*dx = pi the grid values at j-2, j-1, j and j+1 are +1, -1, +1, -1, so the explicit
    increment is the alternating sum of its stencil weights. Works with numbers or sympy symbols."""
    cSqr = chi2*c**2
    cCub = chi3*c**3
    # Weights of the explicit increment on the points j-2, j-1, j and j+1
    eJm2 = cCub/6 - c/6
    eJm1 = -cCub/2 + cSqr/2 + c
    eJ = cCub/2 - cSqr - c/2
    eJp1 = -cCub/6 + cSqr/2 - c/3
    numerator = 1 + (1-a)*(eJm2 - eJm1 + eJ - eJp1)
    return numerator/A_LW3_denom(c, a, chi2, chi3)

# Print the symbolic denominator at k*dx = pi in terms of the sympy symbols c, a, chi2, chi3.
# NOTE(review): presumably the "problem" is where this vanishes: e.g. with a = 1, chi2 = 1,
# chi3 = 0 (as in the LW32i runs below) it equals 1 + 4c/3 - 2c^2, which is zero near
# c = 1.12 and negative beyond — confirm.
print(A_LW3_denom(c, a, chi2, chi3))


In [None]:
def LW3iMatrix(nx, c, a, chi2, chi3):
    """The matrix for the implicit part of LW3_ImEx, with periodic boundary conditions.

    Row j holds the weights of 1 - a*LW3i on the points j-2, j-1, j and j+1
    (indices taken modulo nx).
    nx: nx by nx matrix; must be at least 4 so that the periodic 4-point stencil
        does not wrap onto itself
    c: Courant number
    a: Off-centering
    chi2: fraction of 2nd-order in time terms
    chi3: fraction of 3rd-order in time terms
    Returns an nx by nx scipy.sparse matrix in CSR format.
    Raises ValueError if nx < 4."""
    if nx < 4:
        raise ValueError('LW3iMatrix needs nx >= 4 for its periodic 4-point stencil, got nx = '
                         + str(nx))
    # Every off-identity weight carries an overall factor a*c, so the limited
    # 2nd- and 3rd-order terms appear here divided by c
    cSqr = chi2*c
    cCub = chi3*c**2
    ac = a*c
    # Weights of the points j-2, j-1, j and j+1 in row j
    wJm2 = ac/6*(1 - cCub)
    wJm1 = -0.5*ac*(2 - cSqr - cCub)
    wJ = 1 + 0.5*ac*(1 - 2*cSqr - cCub)
    wJp1 = ac/6*(2 + 3*cSqr + cCub)
    M = diags([wJp1,                  # bottom-left corner: row nx-1, point j+1 wraps to column 0
               wJm2*np.ones(nx-2),    # the diagonal for j-2
               wJm1*np.ones(nx-1),    # the diagonal for j-1
               wJ*np.ones(nx),        # the main diagonal
               wJp1*np.ones(nx-1),    # the diagonal for j+1
               wJm2*np.ones(2),       # rows 0 and 1: point j-2 wraps to columns nx-2 and nx-1
               wJm1],                 # top-right corner: row 0, point j-1 wraps to column nx-1
              [-nx+1, -2, -1, 0, 1, nx-2, nx-1], # the locations of each of the diagonals
              shape=(nx,nx), format = 'csr')
    return M


def LW32i(phi, c, M=None):
    """Lax Wendroff advection of profile phi with Courant number c for one time step
       third-order in space, second-order in time, implicit. Periodic boundary conditions.

       phi: grid-point values at the current time (1D array)
       c: Courant number
       M: optional precomputed implicit matrix, LW3iMatrix(len(phi), c, 1, 1, 0).
          Pass it when taking many steps with the same c to avoid rebuilding it;
          if omitted it is built here from c.
       Returns phi at the next time step."""
    if M is None:
        # a = 1 (fully implicit), chi2 = 1, chi3 = 0 (second order in time)
        M = LW3iMatrix(len(phi), c, 1, 1, 0)
    return spsolve(M, phi)

# Parameters for some revolutions of the periodic domain [0, 1) at unit velocity
nRevs = 10
nt = 24*nRevs
plotFreq = 24 #6
dt = nRevs/nt
nxs = np.array([10, 20, 40, 80, 160])

for nx in nxs:
    c = nRevs*nx/nt          # Courant number dt/dx with unit velocity and dx = 1/nx
    # Implicit matrix: a = 1 (fully implicit), chi2 = 1, chi3 = 0 (second order in time)
    M = LW3iMatrix(nx, c, 1, 1, 0)
    dx = 1/nx
    # Exactly nx grid points: np.arange(0, 1, dx) can return nx+1 points when 1/dx rounds
    # above nx, which would no longer match the nx by nx matrix.
    x = np.arange(nx)*dx
    phi0 = np.where(x<0.5, 1., 0.)
    phi = phi0.copy()
    plt.plot(x, phi, 'k', label = 't=0')
    for it in range(nt):
        phi = LW32i(phi, c)
        if (it+1)%plotFreq == 0:
            plt.plot(x, phi, label = 't='+str(round((it+1)*dt, 2)))
    plt.legend()
    plt.title('c = '+str(round(c,2))+', dx = '+str(round(dx,2))+', dt = '+str(round(dt,2))
                                                  +', nt = '+str(nt)+', nx = '+str(nx))
    plt.show()

In [None]:
# Interactive lookup of the sympy.lambdify documentation.
# NOTE(review): exploratory cell that does not affect any result; consider removing it
# from the finished notebook.
help(sy.lambdify)