In [None]:
import numpy as np
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
from IPython.display import HTML
# Use a larger font everywhere so axis labels and titles stay legible in the animation.
font = dict(size=15)
matplotlib.rc('font', **font)

# Short aliases for NumPy's forward and inverse discrete Fourier transforms,
# used by the spectral derivatives in the solver.
fft, ifft = np.fft.fft, np.fft.ifft

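We solve the one-dimensional acoustic wave equation with spatially varying bulk modulus $K(x)$ and density $\rho(x)$, written as a first-order system for the pressure $p$ and velocity $u$:
$$
p_t + K(x)\, u_x = 0, \qquad u_t + \frac{1}{\rho(x)}\, p_x = 0 .
$$
The unknowns are collected in the state vector
$$
v = \begin{bmatrix} p \\ u \end{bmatrix}.
$$
Spatial derivatives are computed with the FFT on a periodic grid, and time is advanced with a third-order strong-stability-preserving Runge–Kutta method.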

In [None]:
def rk3(v, xi, rhs, dt):
    """Advance ``v`` by one step of the three-stage SSP Runge-Kutta method.

    Uses the Shu-Osher form: each stage is a forward-Euler step, and the
    stages are combined with the convex weights (3/4, 1/4) and (1/3, 2/3).

    Parameters
    ----------
    v : array_like
        Current state.
    xi : object
        Passed through unchanged as the second argument of ``rhs``.
    rhs : callable
        ``rhs(state, xi)`` returning the time derivative of ``state``.
    dt : float
        Time step.

    Returns
    -------
    ndarray
        The state after one step; ``v`` itself is not modified.
    """
    def euler(state):
        # A single forward-Euler step from ``state``.
        return state + dt * rhs(state, xi)

    stage_one = euler(v)
    stage_two = 0.75 * v + 0.25 * euler(stage_one)
    return 1. / 3 * v + 2. / 3 * euler(stage_two)


    
def solve_wave_equation(m, dt, L=100*np.pi, tmax=600.0, num_plots=50):
    """Integrate the 1-D variable-coefficient acoustic system pseudospectrally.

    Solves ``p_t = -K(x) u_x`` and ``u_t = -p_x / rho(x)`` on the periodic
    grid ``x = (-m/2, ..., m/2 - 1) * L/m``, with ``K = rho = 4 + sin(x)``.
    The initial state is a Gaussian pressure pulse centred at ``x = -2`` and
    zero velocity. Spatial derivatives are taken with the FFT, and time is
    advanced with ``rk3``.

    Parameters
    ----------
    m : int
        Number of grid points.
    dt : float
        Time step.
    L : float, optional
        Domain length (default ``100*pi``).
    tmax : float, optional
        Target final time (default 600.0). ``round(tmax/dt)`` steps are
        taken, so the last time reached is that step count times ``dt``.
    num_plots : int, optional
        Approximate number of snapshots stored after the initial one
        (default 50).

    Returns
    -------
    frames : list of ndarray
        Snapshots of shape ``(2, m)``; row 0 is pressure and row 1 is
        velocity. The first entry is the initial state.
    x : ndarray
        Grid points.
    tt : list
        Time of each entry in ``frames``.
    xi : ndarray
        Wavenumbers in FFT ordering.
    """
    x = np.arange(-m/2, m/2) * (L/m)
    xi = np.fft.fftfreq(m) * m * 2*np.pi / L

    v = np.zeros((2, len(x)))
    v[0, :] = np.exp(-0.01*(x + 2)**2)  # Gaussian pressure pulse
    # v[1, :] (velocity) stays zero: the medium starts at rest.

    K = 4 + np.sin(x)
    rho = 4 + np.sin(x)

    def rhs(v, xi):
        """Time derivative of the state [p; u] using FFT-based x-derivatives."""
        d = np.zeros_like(v)
        p = v[0, :]
        u = v[1, :]
        d[0, :] = -K * np.real(ifft(1j*xi*fft(u)))
        d[1, :] = -1./rho * np.real(ifft(1j*xi*fft(p)))
        return d

    # Steps between stored snapshots. Clamp to >= 1: if dt > tmax/num_plots
    # the floor is 0, and np.mod(n, 0) yields NaN, which silently drops frames.
    nplt = max(1, int(np.floor((tmax/num_plots)/dt)))
    nmax = int(round(tmax/dt))

    frames = [v.copy()]
    tt = [0]

    for n in range(1, nmax + 1):
        # rk3 builds a new array each step, so no defensive copy is needed here.
        v = rk3(v, xi, rhs, dt)
        if n % nplt == 0:
            frames.append(v.copy())
            tt.append(n*dt)
    return frames, x, tt, xi
    


def plot_solution(frames, x, tt, xi, ylim=(-0.1, 1.1)):
    """Animate the pressure row of stored snapshots as an HTML/JavaScript widget.

    Parameters
    ----------
    frames : sequence of ndarray
        Snapshots, e.g. from ``solve_wave_equation``; row 0 (pressure) of
        each is plotted.
    x : ndarray
        Grid points used as the horizontal axis.
    tt : sequence of float
        Time of each snapshot, shown in the title.
    xi : ndarray
        Wavenumbers. Accepted for call compatibility; not used in the plot.
    ylim : tuple of float, optional
        Fixed vertical axis limits for every frame (default ``(-0.1, 1.1)``).

    Returns
    -------
    IPython.display.HTML
        The animation rendered by ``FuncAnimation.to_jshtml()``.
    """
    fig, axes = plt.subplots(figsize=(12, 8))
    line, = axes.plot(x, frames[0][0, :], lw=3)
    axes.set_xlabel(r'$x$', fontsize=30)
    axes.set_ylabel(r'$p$', fontsize=30)
    # Axis limits do not change between frames, so set them once here.
    axes.set_xlim((x[0], x[-1]))
    axes.set_ylim(ylim)
    # Close this figure so the notebook shows only the animation, not a static copy.
    plt.close(fig)

    def plot_frame(i):
        """Update the curve and title to snapshot ``i``."""
        line.set_data(x, frames[i][0, :])
        axes.set_title('t= %.2e' % tt[i])
        return line,

    anim = matplotlib.animation.FuncAnimation(fig, plot_frame,
                                              frames=len(frames), interval=100,
                                              repeat=False)
    return HTML(anim.to_jshtml())

In [None]:
# Grid resolution (number of Fourier modes) and time step for the run.
# NOTE(review): 1.73 is close to sqrt(3), the imaginary-axis stability bound of
# SSP-RK3, and m/2 is the largest wavenumber on a 2*pi-long periodic grid. The
# solver's default domain is 100*pi long, so its largest wavenumber is m/100;
# with wave speed sqrt(K/rho) = 1, this dt is presumably far below the
# stability limit, which means many more steps than necessary. Confirm intent.
m = 512
dt = 1.73 / (0.5 * m)
frames, x, tt, xi = solve_wave_equation(m=m, dt=dt)

In [None]:
plot_solution(frames, x, tt, xi)