# Example notebook using the Pykonal solver for the 3D Eikonal equation

### Import modules

In [None]:
%matplotlib ipympl
import matplotlib.gridspec
import matplotlib.pyplot as plt
import numpy as np
import pykonal

### Define analytical travel-time functions and a function to plot results

In [None]:
def tt_head(xs, ys, zs, xr, yr, zr, y1, v1, v2):
    """Analytical travel time of a head wave refracted along the plane y = y1.

    The source and receiver are assumed to sit in the layer of velocity ``v1``
    on the ``y < y1`` side of the interface; the medium beyond it has velocity
    ``v2``. The ray leaves the source at the critical angle ``arcsin(v1/v2)``,
    runs along the interface at ``v2`` and returns to the receiver at the
    critical angle.

    Parameters
    ----------
    xs, ys, zs : float or array_like
        Source coordinates.
    xr, yr, zr : float or array_like
        Receiver coordinates (broadcast against the source coordinates).
    y1 : float
        Y coordinate of the interface.
    v1, v2 : float
        Velocities on the source side and the far side of the interface;
        ``v1 < v2`` is needed for a finite critical angle.

    Returns
    -------
    float or ndarray
        Head-wave travel time. The critical distance is not checked: for
        receivers closer than it, the along-interface leg becomes negative and
        the returned value is not physical.
    """
    crit = np.arcsin(v1/v2)
    cos_crit = np.cos(crit)
    # Horizontal (X-Z plane) source-receiver offset.
    offset = np.sqrt((xr-xs)**2 + (zr-zs)**2)
    down_leg = (y1 - ys) / cos_crit
    up_leg = (y1 - yr) / cos_crit
    # Offset minus the horizontal distance covered by the two slanted legs.
    along_leg = offset - np.tan(crit) * (2 * y1 - ys - yr)
    return (down_leg + up_leg) / v1 + along_leg / v2

def tt_direct(xs, ys, zs, xr, yr, zr, v1):
    """Analytical travel time of the direct (straight-ray) wave.

    Parameters
    ----------
    xs, ys, zs : float or array_like
        Source coordinates.
    xr, yr, zr : float or array_like
        Receiver coordinates (broadcast against the source coordinates).
    v1 : float
        Constant velocity along the straight ray.

    Returns
    -------
    float or ndarray
        Euclidean source-receiver distance divided by ``v1``.
    """
    distance = np.sqrt((xr-xs)**2 + (yr-ys)**2 + (zr-zs)**2)
    return distance / v1

def _plot_rays(ax, rays, col_h, col_v):
    """Draw every ray in ``rays`` on ``ax`` as a dashed black line.

    ``col_h`` and ``col_v`` select the ray columns (0=X, 1=Y, 2=Z) used as
    the horizontal and vertical plot coordinates. Does nothing when ``rays``
    is None.
    """
    if rays is None:
        return
    for ray in rays:
        ax.plot(ray[:, col_h], ray[:, col_v], 'k--')


def plot(solver, ix=None, iy=None, iz=None, attr='uu', rays=None, cbar_label='Travel-time [s]'):
    """Plot three orthogonal slices of a 3D field stored on ``solver``.

    Layout: XZ slice at ``iy`` (top left), YZ slice at ``ix`` (top right),
    XY slice at ``iz`` (bottom left) and a horizontal colorbar (bottom right).
    In every panel, white lines mark where the other two slices cut it. All
    panels share one colour scale.

    Parameters
    ----------
    solver : object
        Solver exposing ``pgrid`` (with ``npts``, ``min_coords``,
        ``max_coords``, indexable as ``pgrid[i, j, k, c]`` where ``c`` selects
        the X/Y/Z coordinate) and the attribute named by ``attr``.
    ix, iy, iz : int, optional
        Slice indices along X, Y and Z; each defaults to the middle node.
    attr : str, optional
        Name of the solver attribute to plot (default ``'uu'``). It must be a
        3D array with the same node layout as ``pgrid``.
    rays : iterable of array_like, optional
        Ray paths with X, Y, Z in columns 0, 1, 2, overlaid as dashed lines.
    cbar_label : str, optional
        Label for the colorbar.
    """
    grid = solver.pgrid
    # Default to the middle node along each axis.
    if ix is None:
        ix = int(grid.npts[0] // 2)
    if iy is None:
        iy = int(grid.npts[1] // 2)
    if iz is None:
        iz = int(grid.npts[2] // 2)

    data = getattr(solver, attr)
    data_xy = data[:, :, iz]
    data_xz = data[:, iy, :]
    data_yz = data[ix, :, :]
    # One colour scale across all three slices so the panels are comparable.
    values = np.concatenate([data_xy.ravel(), data_xz.ravel(), data_yz.ravel()])
    kwargs = dict(cmap='jet_r', vmin=values.min(), vmax=values.max())

    dx, dy, dz = grid.max_coords - grid.min_coords
    dmax = np.max([dx, dy, dz])
    # Panel widths follow (dx, dy) and heights follow (dz, dy), so the panel
    # grid as a whole is (dx + dy) wide by (dz + dy) tall.
    aspect = (dx + dy) / (dz + dy)
    gs = matplotlib.gridspec.GridSpec(
        2, 2,
        width_ratios=[dx/dmax, dy/dmax],
        height_ratios=[dz/dmax, dy/dmax],
    )
    # Scale only the width so width:height follows ``aspect``; scaling the
    # height by ``aspect`` too would make every figure roughly square.
    fig = plt.figure(figsize=(aspect*8 + 0.3, 8))
    ax_xz = fig.add_subplot(gs[0], aspect=1)
    ax_yz = fig.add_subplot(gs[1], aspect=1)
    ax_xy = fig.add_subplot(gs[2], aspect=1)
    # The colorbar occupies the top 1/11 of the bottom-right cell.
    cbar_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
        2, 1, subplot_spec=gs[3], height_ratios=[1, 10]
    )
    cax = fig.add_subplot(cbar_gs[0])

    # Physical coordinates of the three slices, used for the guide lines.
    x_cut = grid[ix, 0, 0, 0]
    y_cut = grid[0, iy, 0, 1]
    z_cut = grid[0, 0, iz, 2]

    # XZ slice: guide lines mark the XY (z_cut) and YZ (x_cut) slices.
    ax_xz.pcolormesh(grid[:, iy, :, 0], grid[:, iy, :, 2], data_xz, **kwargs)
    ax_xz.axhline(z_cut, color='w')
    ax_xz.axvline(x_cut, color='w')
    _plot_rays(ax_xz, rays, 0, 2)
    ax_xz.xaxis.tick_top()
    ax_xz.xaxis.set_label_position('top')
    ax_xz.set_xlabel('X')
    ax_xz.set_ylabel('Z')

    # YZ slice: guide lines mark the XY (z_cut) and XZ (y_cut) slices.
    ax_yz.pcolormesh(grid[ix, :, :, 1], grid[ix, :, :, 2], data_yz, **kwargs)
    ax_yz.axhline(z_cut, color='w')
    ax_yz.axvline(y_cut, color='w')
    _plot_rays(ax_yz, rays, 1, 2)
    ax_yz.xaxis.tick_top()
    ax_yz.xaxis.set_label_position('top')
    ax_yz.yaxis.tick_right()
    ax_yz.yaxis.set_label_position('right')
    ax_yz.set_xlabel('Y')
    ax_yz.set_ylabel('Z')

    # XY slice: guide lines mark the XZ (y_cut) and YZ (x_cut) slices.
    qmesh = ax_xy.pcolormesh(grid[:, :, iz, 0], grid[:, :, iz, 1], data_xy, **kwargs)
    ax_xy.axhline(y_cut, color='w')
    ax_xy.axvline(x_cut, color='w')
    _plot_rays(ax_xy, rays, 0, 1)
    # Y grows downward here so it lines up with the Y axis of the YZ panel.
    ax_xy.invert_yaxis()
    ax_xy.set_xlabel('X')
    ax_xy.set_ylabel('Y')

    # All meshes share vmin/vmax/cmap, so any of them can drive the colorbar.
    cbar = fig.colorbar(qmesh, cax=cax, orientation='horizontal')
    cbar.set_label(cbar_label)

    fig.tight_layout()

In [None]:
# Build a solver on an 11 x 11 x 11 grid with unit node spacing starting at
# the origin (so it spans 0..10 along each axis). vgrid is presumably the
# velocity-model grid and pgrid the grid the solution is computed on.
solver = pykonal.EikonalSolver()
solver.vgrid.min_coords     = 0, 0, 0
solver.vgrid.node_intervals = 1, 1, 1
solver.vgrid.npts           = 11, 11, 11
# Use a solution grid identical to the vgrid for this first run.
solver.pgrid.min_coords     = solver.vgrid.min_coords
solver.pgrid.node_intervals = solver.vgrid.node_intervals
solver.pgrid.npts           = solver.vgrid.npts
# Uniform value 1 at every vgrid node (presumably a constant velocity model).
solver.vv                   = np.ones(solver.vgrid.npts)
# Point source at the grid corner (the origin).
src = (0, 0, 0)
solver.add_source(src)
solver.solve()

In [None]:
# Close figures left over from earlier runs of the notebook, then plot the
# central slices of the solver's 'uu' field (the plot default, labelled as
# travel time).
plt.close('all')
plot(solver)

In [None]:
# Re-solve on a solution grid refined by a factor of 2: half the node spacing
# and 2n - 1 nodes per axis (21 for n = 11). This covers the same extent as
# the vgrid: 20 * 0.5 = 10.
# NOTE(review): assumes vgrid.node_intervals / vgrid.npts are numpy arrays so
# the arithmetic is element-wise (they were assigned as tuples above) —
# confirm against pykonal.
solver.pgrid.min_coords     = solver.vgrid.min_coords
solver.pgrid.node_intervals = solver.vgrid.node_intervals / 2
solver.pgrid.npts           = solver.vgrid.npts * 2 - 1
solver.solve()
plot(solver)

In [None]:
# Re-solve on a solution grid refined by a factor of 4 (quarter node spacing).
# NOTE(review): with npts = 4n - 5 (39 for n = 11) the last node sits at
# 38 * 0.25 = 9.5, short of the vgrid extent of 10. The full-extent analogue
# of the previous cell would be 4n - 3 (41 nodes). Confirm whether the
# shortfall is intentional.
solver.pgrid.min_coords     = solver.vgrid.min_coords
solver.pgrid.node_intervals = solver.vgrid.node_intervals / 4
solver.pgrid.npts           = solver.vgrid.npts * 4 - 5
solver.solve()
plot(solver)