In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import animation
from clawpack.visclaw.JSAnimation import IPython_display

# Advection-diffusion-dispersion
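Solution by DFT + exact time integration: we solve $u_t = a_1 u_x + a_2 u_{xx} + a_3 u_{xxx}$ on a periodic domain by transforming to Fourier space, where each mode evolves independently as $\hat{u}(\xi,t) = e^{(i\xi a_1 - a_2\xi^2 - i\xi^3 a_3)t}\,\hat{u}(\xi,0)$.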
This code is taken from notebook 1 of https://github.com/ketch/PseudoSpectralPython.

In [None]:
# --- Spatial discretization ---
m = 64                                   # how many points sample the domain
L = 2 * np.pi                            # length of the spatial domain
x = (L / m) * np.arange(-m / 2, m / 2)   # uniform nodes covering [-L/2, L/2)
dx = x[1] - x[0]                         # distance between neighbouring nodes

# --- Output times ---
tmax = 4.0       # last time at which the solution is evaluated
N = 25           # number of output times after t = 0
k = tmax / N     # spacing between consecutive output times

# --- Fourier wavenumbers ---
# np.fft.fftfreq lists the frequencies in the same ordering that numpy's FFT
# uses for its output; scaling by m*2*pi/L converts them to angular
# wavenumbers for a domain of length L.
xi = np.fft.fftfreq(m) * m * 2 * np.pi / L

In [None]:
# Initial condition: sin^2(2x) where x < -L/4 and zero elsewhere,
# together with its discrete Fourier transform.
u = (x < -L / 4) * np.sin(2 * x) ** 2
uhat0 = np.fft.fft(u)

# PDE coefficients; the solver multiplies them by i*xi, -xi**2 and -i*xi**3.
a1 = -1.0   # advection (first-derivative) coefficient
a2 = 0.0    # diffusion (second-derivative) coefficient
a3 = 0.0    # dispersion (third-derivative) coefficient

In [None]:
# Snapshots of the solution, beginning with the initial condition
frames = [u.copy()]

# Fourier symbol of the spatial operator; it does not depend on time,
# so it is formed once and reused for every output time.
symbol = 1.j*xi*a1 - a2*xi**2 - 1.j*xi**3*a3

# Evaluate the exact Fourier-space solution at each output time
for n in range(1, N+1):
    t = n*k
    uhat = np.exp(symbol*t) * uhat0
    u = np.real(np.fft.ifft(uhat))
    frames.append(u.copy())

# Figure with a single (initially empty) line that the animation updates
fig, axes = plt.subplots(figsize=(9, 4))
line, = axes.plot([], [], lw=3)
axes.set_xlim(x[0], x[-1])
axes.set_ylim(-0.1, 1.)


def plot_frame(i):
    """Draw snapshot number ``i`` on the shared line and title it with its time."""
    line.set_data(x, frames[i])
    axes.set_title('t='+str(i*k))
    fig.canvas.draw()
    return fig


# Animate the stored snapshots (kept as the last expression so it is displayed)
matplotlib.animation.FuncAnimation(
    fig, plot_frame,
    frames=len(frames),
    interval=200,
    repeat=False,
)

# Advection with centered differences

In [None]:
# Spatial grid
m = 64                            # Number of grid points in space
L = 2 * np.pi                     # Width of spatial domain
x = np.arange(-m/2, m/2)*(L/m)    # Grid points
dx = x[1] - x[0]                  # Grid spacing

# Temporal grid
tmax = 4.0     # Final time
N = 25         # number grid points in time
k = tmax/N     # interval between output times

# Wavenumber "grid", in the order numpy's FFT gives the frequencies.
# It is recomputed here because the solver in this section uses xi: without
# this line xi would be whatever an earlier cell left behind, and would no
# longer match m and L if they are changed in this cell.
xi = np.fft.fftfreq(m)*m*2*np.pi/L

In [None]:
# Initial condition (sin^2(2x) where x < -L/4, zero elsewhere) and its DFT
u = (x < -L / 4) * np.sin(2 * x) ** 2
uhat0 = np.fft.fft(u)

In [None]:
# Store both solutions in lists for plotting later; each starts from the
# initial condition.
frames_true = [u.copy()]
frames_cd = [u.copy()]

# Now we solve the problem
for n in range(1,N+1):
    t = n*k

    # Exact advection: every Fourier mode is multiplied by exp(-i*xi*t)
    uhat_true = np.exp((-1.j*xi)*t) * uhat0
    u_true = np.real(np.fft.ifft(uhat_true))
    frames_true.append(u_true.copy())

    # Centered differences: same formula with xi replaced by sin(xi*dx)/dx
    uhat_cd = np.exp(-1.j*np.sin(xi*dx)*t/dx) * uhat0
    u_cd = np.real(np.fft.ifft(uhat_cd))
    # Stored as a copy, consistent with how the other snapshots are stored
    frames_cd.append(u_cd.copy())

# Set up plotting: two initially empty lines that the animation updates
fig = plt.figure(figsize=(9,4)); axes = fig.add_subplot(111)
line1, = axes.plot([],[],lw=3)
line2, = axes.plot([],[],lw=3)
axes.legend(['Exact','Centered differences'])

axes.set_xlim((x[0],x[-1])); axes.set_ylim((-0.1,1.))

def plot_frame(i):
    """Draw snapshot number ``i`` of both solutions and title it with its time."""
    line1.set_data(x,frames_true[i])
    line2.set_data(x,frames_cd[i])
    axes.set_title('t='+str(i*k))
    fig.canvas.draw()
    return fig

# Animate the solution. The frame count comes from this section's own list
# (frames_true and frames_cd have the same length), not from the `frames`
# list of another cell.
matplotlib.animation.FuncAnimation(fig, plot_frame,
                                   frames=len(frames_true),
                                   interval=200,
                                   repeat=False)