In [None]:
%matplotlib ipympl

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numba
import numpy as np

# Define computational precision
DTYPE_NUMBA_REAL = numba.float64
DTYPE_NUMPY_REAL = np.float64
# Type of the solver's step counter. 32 bits: a 16-bit counter wraps silently
# after 65535 steps, which restarts the reported time and the source indexing.
DTYPE_NUMBA_UINT = numba.uint32
DTYPE_NUMPY_UINT = np.uint32

In [None]:
try:
    # numba >= 0.49 provides jitclass here; the top-level ``numba.jitclass``
    # alias was deprecated and later removed.
    from numba.experimental import jitclass as _jitclass
except ImportError:
    # Older numba releases only expose the top-level decorator.
    _jitclass = numba.jitclass


@_jitclass(
    [
        # Model extent along x and z, and the simulated time window.
        ('xmin',    DTYPE_NUMBA_REAL),
        ('xmax',    DTYPE_NUMBA_REAL),
        ('zmin',    DTYPE_NUMBA_REAL),
        ('zmax',    DTYPE_NUMBA_REAL),
        ('tmin',    DTYPE_NUMBA_REAL),
        ('tmax',    DTYPE_NUMBA_REAL),
        # Velocity bounds (vmin sets the grid spacing, vmax the time step)
        # and the maximum frequency the grid must resolve.
        ('vmin',    DTYPE_NUMBA_REAL),
        ('vmax',    DTYPE_NUMBA_REAL),
        ('fmax',    DTYPE_NUMBA_REAL),
        # Courant number used to derive dt.
        ('cfl',     DTYPE_NUMBA_REAL),
        # Source sample times and source wavelet, both of length nt.
        ('tt',      DTYPE_NUMBA_REAL[:]),
        # Velocity model, shape (nx, nz).
        ('vv',      DTYPE_NUMBA_REAL[:,:]),
        ('src_tf',  DTYPE_NUMBA_REAL[:]),
        # Source position (x, z); exposed through the ``src_loc`` property.
        ('_src_loc', DTYPE_NUMBA_REAL[:]),
        # Wavefield time levels (previous, current, next) along the last axis.
        ('psi',     DTYPE_NUMBA_REAL[:,:,:]),
        # Scratch buffers for the second spatial derivatives.
        ('d2px',    DTYPE_NUMBA_REAL[:,:]),
        ('d2pz',    DTYPE_NUMBA_REAL[:,:]),
        # Number of time steps taken since init().
        ('it',      DTYPE_NUMBA_UINT)
    ]
)
class FDSolver2DWaveEQ(object):
    """Explicit finite-difference solver for the 2-D scalar wave equation.

    Usage: set the scalar parameters (extent, time window, velocity bounds,
    fmax, source location), call ``init()`` to allocate the grids and the
    source wavelet, then call ``update(nstep)`` repeatedly.  The wavefield
    lives in ``psi`` with axes (x, z, time level).
    """

    def __init__(self):
        self.cfl      = 0.5
        self.it       = 0
        self._src_loc = np.zeros(2, dtype=DTYPE_NUMPY_REAL)

    @property
    def dt(self):
        """Time step from the CFL condition: ``cfl * dx / vmax``."""
        return self.cfl * self.dx / self.vmax

    @property
    def dx(self):
        """Grid spacing along x: 20 points per minimum wavelength ``vmin / fmax``."""
        return self.vmin / self.fmax / 20

    @property
    def dz(self):
        """Grid spacing along z: 20 points per minimum wavelength ``vmin / fmax``."""
        return self.vmin / self.fmax / 20

    @property
    def nt(self):
        """Number of time samples covering ``[tmin, tmax)`` at step ``dt``."""
        return round((self.tmax - self.tmin) / self.dt)

    @property
    def nx(self):
        """Number of grid points along x."""
        return round((self.xmax - self.xmin) / self.dx)

    @property
    def nz(self):
        """Number of grid points along z."""
        return round((self.zmax - self.zmin) / self.dz)

    @property
    def src_ix(self):
        """Index of the grid point along x nearest to the source."""
        return round((self.src_loc[0] - self.xmin) / self.dx)

    @property
    def src_iz(self):
        """Index of the grid point along z nearest to the source."""
        return round((self.src_loc[1] - self.zmin) / self.dz)

    @property
    def src_loc(self):
        """Source position as a 2-element array (x, z)."""
        return self._src_loc

    @src_loc.setter
    def src_loc(self, value):
        # Copy element-wise so the backing buffer keeps its identity.
        self._src_loc[0] = value[0]
        self._src_loc[1] = value[1]

    @property
    def xx(self):
        """x coordinate of every grid point, spaced exactly ``dx`` from ``xmin``."""
        return (self.xmin + np.arange(self.nx) * self.dx).astype(DTYPE_NUMPY_REAL)

    @property
    def zz(self):
        """z coordinate of every grid point, spaced exactly ``dz`` from ``zmin``."""
        return (self.zmin + np.arange(self.nz) * self.dz).astype(DTYPE_NUMPY_REAL)

    def init(self):
        """Allocate the wavefield, derivative buffers, velocity model and source wavelet.

        Call after every scalar parameter has been set.  The wavefield is
        reset to zero and the step counter to 0, so the source is replayed
        from its first sample.
        """
        self.psi     = np.zeros((self.nx, self.nz, 3), dtype=DTYPE_NUMPY_REAL)
        self.d2px    = np.zeros((self.nx, self.nz), dtype=DTYPE_NUMPY_REAL)
        self.d2pz    = np.zeros((self.nx, self.nz), dtype=DTYPE_NUMPY_REAL)
        # Homogeneous model at the minimum velocity.
        self.vv      = self.vmin * np.ones((self.nx, self.nz), dtype=DTYPE_NUMPY_REAL)
        # Sample on the solver's own step so src_tf[k] is the source at tmin + k*dt.
        self.tt      = (self.tmin + np.arange(self.nt) * self.dt).astype(DTYPE_NUMPY_REAL)
        # Delay the wavelet so it starts near zero amplitude.
        t0           = 4 / self.fmax
        # First time derivative of a Gaussian centred on t0.
        self.src_tf  = (
            -2.*(self.tt-t0) * (self.fmax**2) * (np.exp(-1.0*(self.fmax**2) * (self.tt-t0)**2))
        ).astype(DTYPE_NUMPY_REAL)
        self.it      = 0

    def update(self, nstep=1):
        """Advance the simulation by ``nstep`` time steps (delegates to ``_update``)."""
        self.psi = _update(self, nstep)

@numba.njit(parallel=True)
def _update(solver, nstep):
    """Advance ``solver`` by ``nstep`` explicit time steps, in place.

    Each step computes
    ``psi_next = 2*psi_cur - psi_prev + vv**2 * dt**2 * (d2px + d2pz)``
    where ``d2px``/``d2pz`` are 3-point second differences of ``psi_cur``.
    Their boundary rows/columns are never written, so they keep whatever
    ``init()`` put there (zeros).  The source sample for the current step is
    then added at ``(src_ix, src_iz)``, ``solver.it`` is incremented, and the
    time levels in ``psi[..., 0]`` (previous), ``psi[..., 1]`` (current) and
    ``psi[..., 2]`` (next) are rotated.

    Parameters
    ----------
    solver : FDSolver2DWaveEQ
        Initialised solver (``init()`` already called); mutated in place.
    nstep : int
        Number of time steps to take.

    Returns
    -------
    ndarray
        ``solver.psi`` (the same array, updated in place).

    Raises
    ------
    ValueError
        If the source location falls outside the model grid.
    """
    psi = solver.psi
    d2px = solver.d2px
    d2pz = solver.d2pz
    src_tf = solver.src_tf
    nsrc = src_tf.shape[0]
    ix = solver.src_ix
    iz = solver.src_iz
    # numba does not bounds-check, so an off-grid source would corrupt memory.
    if ix < 0 or ix >= psi.shape[0] or iz < 0 or iz >= psi.shape[1]:
        raise ValueError('source location lies outside the model grid')
    # Loop-invariant factors, evaluated once (otherwise vv**2 * dt**2 would
    # allocate a fresh array on every step).
    dx2 = solver.dx ** 2
    dz2 = solver.dz ** 2
    dt2 = solver.dt ** 2
    vel_dt2 = solver.vv ** 2 * dt2
    cell_area = solver.dx * solver.dz
    for _ in range(nstep):
        d2px[1:-1] = (psi[2:, :, 1] - 2 * psi[1:-1, :, 1] + psi[:-2, :, 1]) / dx2
        d2pz[:, 1:-1] = (psi[:, 2:, 1] - 2 * psi[:, 1:-1, 1] + psi[:, :-2, 1]) / dz2
        psi[:, :, 2] = 2 * psi[:, :, 1] - psi[:, :, 0] + vel_dt2 * (d2px + d2pz)
        # Inject the sample of the current level (read before incrementing),
        # and only while one exists: once src_tf is exhausted the field
        # propagates freely instead of reading past the end of the buffer.
        if solver.it < nsrc:
            psi[ix, iz, 2] += src_tf[solver.it] / cell_area * dt2
        solver.it += 1
        # Rotate time levels: previous <- current, current <- next.
        psi[:, :, 0] = psi[:, :, 1]
        psi[:, :, 1] = psi[:, :, 2]
    return psi

In [None]:
solver = FDSolver2DWaveEQ()
# Model extent along x and z.
solver.xmin = 0
solver.xmax = 1000
solver.zmin = 0
solver.zmax = 1000
# Simulated time window.
solver.tmin = 0
solver.tmax = 2
# Velocity bounds and the maximum frequency to resolve.
solver.vmin = 500
solver.vmax = 600
solver.fmax = 20
# Source position (x, z), written into the solver's 2-element buffer.
solver.src_loc[0] = 250
solver.src_loc[1] = 250
# Allocate grids, velocity model and source wavelet from the values above.
solver.init()

In [None]:
# Coordinate grids of shape (nx, nz), indexed (x, z) like solver.psi[..., k].
xx, zz = np.meshgrid(solver.xx, solver.zz, indexing='ij', copy=True)

In [None]:
# Sanity check: the wavefield is all zeros right after init().
np.min(solver.psi)

In [None]:
# Advance 100 steps and take a quick look at the current wavefield.
solver.update(100)
plt.close('all')
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, aspect=1)
# Symmetric colour limits derived from the peak source amplitude.
clim = np.max(np.abs(solver.src_tf)) * solver.dt**2
qmesh = ax.pcolormesh(
    # psi is indexed (x, z) but pcolormesh maps rows to the vertical axis,
    # so transpose to put x horizontal and z vertical.
    solver.psi[..., 1].T,
    vmin=-clim,
    vmax=clim,
    cmap='seismic',
)
ax.invert_yaxis()  # z increases downward
ax.set_xlabel('x index')
ax.set_ylabel('z index')

In [None]:
def update_qmesh(idx, solver, ax, qmesh, nstep=67):
    """Animation callback: advance the solver and refresh the mesh and time label.

    Parameters
    ----------
    idx : int
        Frame index supplied by ``FuncAnimation`` (unused).
    solver : FDSolver2DWaveEQ
        Solver to advance; mutated in place.
    ax : matplotlib.axes.Axes
        Axes holding the mesh; receives a single time label that is reused
        across frames.
    qmesh : matplotlib.collections.QuadMesh
        Mesh showing the wavefield, created by ``ax.pcolormesh``.
    nstep : int, optional
        Solver steps per animation frame (default 67).

    Returns
    -------
    tuple
        ``(ax, qmesh)``, the artists to redraw when blitting.
    """
    solver.update(nstep)
    field = solver.psi[..., 1]
    # With 'flat' shading pcolormesh colours one cell fewer per axis than
    # there are grid points; with 'nearest'/'auto' shading it colours every
    # point. Match whatever the mesh was created with.
    if qmesh.get_array().size != field.size:
        field = field[:-1, :-1]
    qmesh.set_array(field.flatten())
    # Reuse one label instead of stacking a new Text artist on every frame.
    label = next((t for t in ax.texts if t.get_gid() == 'update_qmesh_time_label'), None)
    if label is None:
        label = ax.text(0.05, 0.95, '', ha='left', va='top', transform=ax.transAxes,
                        bbox=dict(facecolor='w'))
        label.set_gid('update_qmesh_time_label')
    label.set_text(f'{solver.tmin + solver.it * solver.dt: 06.3f} s')
    return (ax, qmesh)


# Where to write the rendered GIF: relative to the notebook's working
# directory rather than a user-specific absolute path.
ANIM_PATH = 'anim.gif'

plt.close('all')
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, aspect=1)
# Symmetric colour limits derived from the peak source amplitude.
clim = np.max(np.abs(solver.src_tf)) * solver.dt**2
qmesh = ax.pcolormesh(
    xx,
    zz,
    solver.psi[..., 1],
    vmin=-clim,
    vmax=clim,
    cmap='seismic',
)
ax.invert_yaxis()  # z increases downward
ax.set_xlabel('x')
ax.set_ylabel('z')
# update_qmesh advances the solver each time a frame is drawn, so the GIF
# shows consecutive chunks of the simulation.
qmesh_ani = animation.FuncAnimation(fig, update_qmesh, fargs=(solver, ax, qmesh),
                                    interval=1, blit=True, frames=60)
qmesh_ani.save(ANIM_PATH, writer='pillow', fps=5)