In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import animation
from IPython.display import HTML
from scipy.signal import argrelextrema
# Short alias: the animation cells use it to turn a single Fourier
# coefficient back into one mode in physical space.
ifft = np.fft.ifft
# Larger axis-label font for the tall multi-panel figures below.
plt.rcParams.update({'axes.labelsize': 16})

In [None]:
# --- Spatial grid: m points on the periodic interval [-L/2, L/2) ---
m = 4000                          # number of grid points in space
L = 2 * np.pi                     # width of the spatial domain
x = (L/m) * np.arange(-m/2, m/2)  # grid points
dx = x[1] - x[0]                  # grid spacing

# --- Output times: t = k, 2k, ..., N*k (~ tmax) ---
tmax = 0.1                        # final time
N = 101                           # number of output steps after t = 0
k = tmax/N                        # interval between output times
xi = np.fft.fftfreq(m)*m*2*np.pi/L  # angular wavenumbers, in FFT ordering

# --- Initial data: Gaussian bump centred at x = 2 ---
# (an alternative tried earlier was u = sin(2x)**2 restricted to x < -L/4)
u = np.exp(-50*(x-2.)**2)
uhat0 = np.fft.fft(u)

epsilon = 0.01                    # dispersion coefficient

# --- Exact solution in Fourier space ---
# Each coefficient only rotates in phase: uhat(t) = exp(i*epsilon*xi**3*t) * uhat0,
# which is the exact solution of u_t + epsilon*u_xxx = 0, so no time stepping
# is involved; the phase rate does not depend on t and is computed once.
phase_rate = 1j * epsilon * xi**3

frames = [u.copy()]               # physical-space snapshots, t = 0 first
fframes = [uhat0]                 # matching Fourier-space snapshots
for n in range(1, N + 1):
    t = n * k
    uhat = np.exp(phase_rate * t) * uhat0
    u = np.fft.ifft(uhat).real
    frames.append(u.copy())
    fframes.append(uhat)

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(12, 12))
# Middle panel: the full solution u(x, t).
uline, = axes[1].plot([], [], '-k', lw=3)
# Top panel: the low Fourier modes i1, i2, i3 shown individually in physical space.
uhline1, = axes[0].plot([], [], lw=2)
uhline2, = axes[0].plot([], [], lw=2)
uhline3, = axes[0].plot([], [], lw=1)
# Bottom panel: the higher modes j1, j2, j3 (plot_frame multiplies j2 and j3
# by fixed scale factors so they are visible).
uhline4, = axes[2].plot([], [], lw=2)
uhline5, = axes[2].plot([], [], lw=2)
uhline6, = axes[2].plot([], [], lw=1)

for ax in axes:
    ax.set_xlim((x[0], x[-1]))
axes[0].set_ylim((-0.1, 0.1))
axes[1].set_ylim((-0.5, 1.))
# The bottom panel needs its own y-limits (previously axes[0] was set twice and
# axes[2] was left at its default range); it uses the same range as the top panel.
axes[2].set_ylim((-0.1, 0.1))
axes[0].set_ylabel('low modes')
axes[1].set_ylabel('u')
axes[2].set_ylabel('high modes (j2, j3 rescaled)')
axes[2].set_xlabel('x')
# Suppress the static (empty) figure; the animation cell renders it instead.
plt.close(fig)

# Work arrays for plot_frame: each one receives a single nonzero Fourier
# coefficient, so its inverse FFT is one Fourier mode in physical space.
freq1 = np.zeros_like(uhat0)
freq2 = np.zeros_like(uhat0)
freq3 = np.zeros_like(uhat0)
freq4 = np.zeros_like(uhat0)
freq5 = np.zeros_like(uhat0)
freq6 = np.zeros_like(uhat0)
# FFT indices of the modes in the top panel (with L = 2*pi, index n has xi = n) ...
i1 = 1
i2 = 3
i3 = 5
# ... and of the modes in the bottom panel.
j1 = 10
j2 = 30
j3 = 50

def plot_frame(i):
    """Animation callback: redraw every line artist for stored time level ``i``.

    Middle panel: the solution ``frames[i]``.
    Top panel: Fourier modes ``i1, i2, i3`` of ``fframes[i]``, each reconstructed
    on its own in physical space.
    Bottom panel: modes ``j1, j2, j3``, with ``j2`` multiplied by 50 and ``j3`` by
    200000 (hard-coded factors, presumably chosen so that these tiny
    high-wavenumber amplitudes are visible within the panel's y-limits).

    Returns the figure. FuncAnimation only uses the return value when blitting,
    and the animation cell does not enable blitting.
    """
    uline.set_data(x, frames[i])
    uhat_i = fframes[i]

    # Top panel: low modes. The freq* arrays keep exactly one nonzero entry.
    freq1[i1] = uhat_i[i1]
    uhline1.set_data(x, np.real(ifft(freq1)))
    freq2[i2] = uhat_i[i2]
    uhline2.set_data(x, np.real(ifft(freq2)))
    freq3[i3] = uhat_i[i3]
    uhline3.set_data(x, np.real(ifft(freq3)))

    # Bottom panel: higher modes, two of them rescaled for visibility.
    freq4[j1] = uhat_i[j1]
    uhline4.set_data(x, np.real(ifft(freq4)))
    freq5[j2] = uhat_i[j2]
    uhline5.set_data(x, 50*np.real(ifft(freq5)))
    freq6[j3] = uhat_i[j3]
    uhline6.set_data(x, 200000*np.real(ifft(freq6)))

    # Format the time so the title doesn't show float noise such as 0.0009900990099009901.
    axes[0].set_title(f't={i*k:.4g}')
    # No explicit canvas.draw(): FuncAnimation renders each frame after this callback.
    return fig

# Build the animation over every stored time level and embed it as an inline
# JavaScript player: 200 ms between frames, played once (no repeat).
anim = matplotlib.animation.FuncAnimation(
    fig,
    plot_frame,
    frames=len(frames),
    interval=200,
    repeat=False,
)
HTML(anim.to_jshtml())

In [None]:
# Spatial grid: ten times more points than the first experiment, because the
# initial pulse below is much narrower.
m = 40000                         # Number of grid points in space
L = 2 * np.pi                     # Width of spatial domain
x = np.arange(-m/2, m/2)*(L/m)    # Grid points on [-L/2, L/2)
dx = x[1] - x[0]                  # Grid spacing (also converts peak spacing to wavenumber later)

# Temporal grid
tmax = 0.005                      # Final time
N = 100                           # Number of output steps after t = 0 (N+1 frames in total)
k = tmax/N                        # Interval between output times
xi = np.fft.fftfreq(m)*m*2*np.pi/L  # Angular wavenumbers, in FFT ordering

# Initial data: a narrow Gaussian centred at x = 2.9.
# (An alternative tried earlier was u = sin(2x)**2 restricted to x < -L/4.)
u = np.exp(-5000*(x-2.9)**2)
uhat0 = np.fft.fft(u)

# Show the initial condition with labelled axes so the figure stands on its own.
fig0, ax0 = plt.subplots()
ax0.plot(x, u)
ax0.set_xlabel('x')
ax0.set_ylabel('u(x, 0)')
ax0.set_title('Initial data');

In [None]:
epsilon = 0.01                    # dispersion coefficient

# Exact solution in Fourier space: each coefficient is multiplied by
# exp(i*epsilon*xi**3*t), the solution operator of u_t + epsilon*u_xxx = 0.
# The phase rate does not depend on t, so it is computed once.
phase_rate = 1j * epsilon * xi**3

frames = [u.copy()]               # physical-space snapshots, t = 0 first
fframes = [uhat0]                 # matching Fourier-space snapshots
for n in range(1, N + 1):
    t = n * k
    uhat = np.exp(phase_rate * t) * uhat0
    u = np.fft.ifft(uhat).real
    frames.append(u.copy())
    fframes.append(uhat)

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(12, 12))

# Middle panel: the solution (thin black line) and red dots on its local maxima.
uline, = axes[1].plot([], [], '-k', lw=1)
xline, = axes[1].plot([], [], '.r')
# Bottom panel: local wavenumber derived from the spacing of those maxima.
kline, = axes[2].plot([], [], '-b')
# Top panel: three single Fourier modes; the lowest mode gets the thickest line.
uhline1, = axes[0].plot([], [], lw=3)
uhline2, = axes[0].plot([], [], lw=2)
uhline3, = axes[0].plot([], [], lw=1)

# Shared x-range; per-panel y-ranges listed top to bottom.
panel_ylims = [(-0.01, 0.01), (-0.5, 1.), (-0.01, 150)]
for ax, ylim in zip(axes, panel_ylims):
    ax.set_xlim((x[0], x[-1]))
    ax.set_ylim(ylim)
axes[2].set_ylabel('local wavenumber')
axes[2].set_xlabel('x')
plt.close()

# Work arrays that plot_frame fills with one Fourier coefficient each.
freq1, freq2, freq3 = (np.zeros_like(uhat0) for _ in range(3))
# FFT indices of the displayed modes (with L = 2*pi, index n has xi = n).
i1, i2, i3 = 10, 80, 150

def plot_frame(i):
    """Animation callback: redraw every line artist for stored time level ``i``.

    Middle panel: the solution ``frames[i]``, plus red markers at height 0.8 on
    each of its strict local maxima.
    Bottom panel: local angular wavenumber ``2*pi / (spacing of consecutive
    maxima)``, plotted at the right-hand maximum of each pair.
    Top panel: Fourier modes ``i1, i2, i3`` of ``fframes[i]``, each reconstructed
    on its own in physical space.

    Returns the figure. FuncAnimation only uses the return value when blitting,
    and the animation cell does not enable blitting.
    """
    uline.set_data(x, frames[i])

    # argrelextrema returns a 1-tuple of index arrays for 1-D input. Unpack it so
    # the arithmetic below works on a plain 1-D integer array; using the tuple
    # directly produced 2-D arrays.
    (peaks,) = argrelextrema(frames[i], np.greater)
    xline.set_data(x[peaks], np.full(peaks.shape, 0.8))

    # The distance between neighbouring maxima is the local wavelength. Use the
    # angular wavenumber 2*pi/wavelength, the same convention as the solver's xi
    # grid and the mode indices above (and matching this panel's 0..150 y-range).
    # The previous 1/wavelength was a factor 2*pi too small.
    kline.set_data(x[peaks][1:], 2*np.pi/(np.diff(peaks)*dx))

    # Top panel: the freq* arrays keep exactly one nonzero entry each.
    uhat_i = fframes[i]
    freq1[i1] = uhat_i[i1]
    uhline1.set_data(x, np.real(ifft(freq1)))
    freq2[i2] = uhat_i[i2]
    uhline2.set_data(x, np.real(ifft(freq2)))
    freq3[i3] = uhat_i[i3]
    uhline3.set_data(x, np.real(ifft(freq3)))

    # Format the time so the title doesn't show float noise.
    axes[0].set_title(f't={i*k:.4g}')
    # No explicit canvas.draw(): FuncAnimation renders each frame after this callback.
    return fig

# Build the animation over every stored time level and embed it as an inline
# JavaScript player: 200 ms between frames, played once (no repeat).
anim = matplotlib.animation.FuncAnimation(
    fig,
    plot_frame,
    frames=len(frames),
    interval=200,
    repeat=False,
)
HTML(anim.to_jshtml())