# A simple elastic free-surface implementation

#### Authors: 
#### Ed Caunt -- Devito Codes, 2025
#### Thomas Cullison -- Stanford University, 2025

## Python and Devito System-Level Imports and Verifications

In [None]:
import sys
mypython = sys.executable
print(mypython)

In [None]:
from os import environ as os_env

# Select Devito's OpenMP backend; this is set before devito is imported in a later cell.
os_env['DEVITO_LANGUAGE'] = 'openmp'

# Echo every environment variable whose name contains 'DEVITO'
devito_env = [(key, value) for key, value in os_env.items() if 'DEVITO' in key]
for key, value in devito_env:
    print(f'{key}: {value}')

# OpenMP thread count for the compiled operator
os_env['OMP_NUM_THREADS'] = '32'

#### Show Devito Version

In [None]:
!{mypython} -m pip show devito

In [None]:
!{mypython} --version

## General Imports

In [None]:
import sympy

import devito as dv
import numpy as np
import matplotlib.pyplot as plt

from datetime import timedelta

from devito.tools import memoized_func, as_tuple
from examples.seismic import TimeAxis, RickerSource

## Subsurface Model Parameters

### Model Geometry

In [None]:
# Preferred grid spacing (m). calc_model_by_extent only uses it when it is
# finer than the wavelength-based maximum spacing.
dxyz_try = 10.

# Physical size of the model along x, y, z (m)
ex, ey, ez = 2560., 2560., 1920.
gextent = (ex, ey, ez)

# Coordinates of the first grid point along x, y, z (m)
ox, oy, oz = 0., 0., 0.
gorigin = (ox, oy, oz)

### Subsurface Model Values

In [None]:
# Upper layer: P-wave speed, S-wave speed, density.
# Speeds are in km/s (= m/ms), consistent with spacing in m and time in ms.
vp0, vs0, rh0 = 2.0, 1.25, 1.

# Lower layer: every upper-layer property scaled by one common factor
pscale = 1.5  # (alt: 2.5)
vp1, vs1, rh1 = (pscale*prop for prop in (vp0, vs0, rh0))

### Time and Frequency Parameters

In [None]:
max_time = 600   # ms
max_freq = 0.015 # kHz, i.e. cycles per ms (= 15 Hz); time is in ms throughout

cfl_scale = 0.125 # empirical safety factor on the CFL time step, chosen for stability

## Estimated Operations per Second

In [None]:
# Measured throughput (grid-point updates per second) used for the runtime estimate.
# TAC: empirically determined using 32 cores and OpenMP backend, 1-node
emp_op_p_sec = 3.62644294e8

## Construct Devito Grid

In [None]:
def calc_model_by_extent(extent,v_min,v_max,f_max,dxyz_try=None,cfl_scale=0.25,ppwl=8):
    """
    Choose grid shape, spacing and time step for a model of a given physical extent.

    The requested spacing is the coarsest one that still samples the shortest
    wavelength ``v_min/f_max`` with ``ppwl`` points, or ``dxyz_try`` if that is
    finer. Each axis gets ``ceil(extent/spacing) + 1`` points. A Devito
    ``Grid(shape, extent=extent)`` then uses a spacing of ``extent/(n - 1)`` on
    each axis. That spacing can be smaller than the requested one when the
    extent is not an exact multiple of it. The time step is therefore computed
    from the smallest spacing actually used, so ``dt = cfl_scale*h/v_max``
    holds on every axis.

    Parameters
    ----------
    extent : sequence of float
        Physical size of the model along each axis (any number of axes).
    v_min, v_max : float
        Slowest and fastest wave speeds in the model.
    f_max : float
        Highest frequency to resolve, in cycles per unit time (same time unit
        as the velocities, and hence as the returned ``dt``).
    dxyz_try : float, optional
        Preferred spacing; used only when finer than the wavelength limit.
    cfl_scale : float
        Courant-number-like factor applied to the time step.
    ppwl : int or float
        Grid points per minimum wavelength.

    Returns
    -------
    shape : tuple of int
        Number of grid points per axis.
    dxyz : float
        Requested (maximum) spacing, before fitting it to the extent.
    dt : float
        Time step.
    """
    wavelen_min = v_min/float(f_max)
    dxyz_max = wavelen_min/ppwl
    if dxyz_try is not None and dxyz_try < dxyz_max:
        dxyz_max = dxyz_try

    shape = tuple(int(np.ceil(e/dxyz_max)) + 1 for e in extent)

    # Spacing the grid actually uses per axis; a single-point axis (zero
    # extent) has no spacing and therefore does not constrain dt.
    spacings = [e/(n - 1) for e, n in zip(extent, shape) if n > 1]
    h_min = min(spacings) if spacings else dxyz_max

    dt = cfl_scale*h_min/v_max
    return shape, dxyz_max, dt

In [None]:
# Extreme wave speeds over both layers (P and S); they bound the spacing and dt
speeds = np.array([vp0, vp1, vs0, vs1])
v_min, v_max = speeds.min().item(), speeds.max().item()
print((v_min,v_max))

In [None]:
gshape, dxyz, dt = calc_model_by_extent(gextent,v_min,v_max,max_freq,dxyz_try=dxyz_try,cfl_scale=cfl_scale)

In [None]:
(gshape, dxyz, dt)

In [None]:
nx, ny, nz = gshape

In [None]:
so = 8

grid = dv.Grid(gshape, extent=gextent,origin=gorigin)
dx,dy,dz = grid.spacing

## Devito Vector and Tensors Functions

In [None]:
v = dv.VectorTimeFunction(name='v', grid=grid, space_order=so, time_order=1)
tau = dv.TensorTimeFunction(name='tau', grid=grid, space_order=so, time_order=1)

In [None]:
v

In [None]:
tau

## Devito Subsurface Model Functions

### Where to Start 2nd Layer

In [None]:
h_horizon = (70./151.)*nz #ez/3
ihz = int(np.floor(h_horizon + 0.5))
# ihz = 70
ihz

### Set Devito Model Values

In [None]:
cp = dv.Function(name='cp', grid=grid)
cs = dv.Function(name='cs', grid=grid)
ro = dv.Function(name='ro', grid=grid)

# Two-layer model along the last (z) axis: indices below ihz take the
# upper-layer value, indices from ihz on take the lower-layer value.
for field, upper, lower in ((cp, vp0, vp1), (cs, vs0, vs1), (ro, rh0, rh1)):
    field.data[:, :, :ihz] = upper
    field.data[:, :, ihz:] = lower

### Derived Model Parameters

In [None]:
# Shorthands
mu = cs**2*ro
lam = (cp**2*ro - 2*mu)
b = 1/ro

## Create Source

In [None]:
t0 = 0.
tn = max_time

time_range = TimeAxis(start=t0, stop=tn, step=dt)
# NOTE(review): RickerSource's f0 is the wavelet's peak frequency and its
# spectrum extends above it, so part of the energy lies above the max_freq the
# grid spacing was chosen for -- confirm this is intended.
f0 = max_freq

src = RickerSource(name='src', grid=grid, f0=f0, time_range=time_range, interpolation='sinc')
# Centre the source horizontally, using each axis's own origin and extent
# (previously y reused 0.5*ex, which was only right because ex == ey) ...
src.coordinates.data[:, 0] = ox + 0.5*ex
src.coordinates.data[:, 1] = oy + 0.5*ey
# ... and place it on the top boundary of the grid (z = oz), the free surface.
src.coordinates.data[:, -1] = oz

In [None]:
# The source injection term
src_xx = src.inject(field=tau.forward[0, 0], expr=src)
src_yy = src.inject(field=tau.forward[1, 1], expr=src)
src_zz = src.inject(field=tau.forward[2, 2], expr=src)

## Estimate Runtime on 32 Cores and on a Single Core

In [None]:
# One "operation" per grid point per time step, converted to wall-clock time
# with the throughput measured on 32 cores.
n_steps = len(time_range.time_values)
num_op = np.prod(gshape)*n_steps
est_run_sec = num_op/emp_op_p_sec

# Single-core figure assumes perfect scaling over the 32 measurement cores
est_run_time_str, est_co_run_time_str = (str(timedelta(seconds=secs))
                                         for secs in (est_run_sec, 32*est_run_sec))
print(f'Estimated Runtime 32 Cores: {est_run_time_str}')
print(f'Estimated Single Core Runtime: {est_co_run_time_str}')

## Define Functions for Applying Free-surface BC

In [None]:
# Free surface mirroring scheme
#   * Antisymmetric mirror for stresses to impose txz=tyz=tzz=0
#   * Particle velocities linearly extrapolated across boundary
# Z-derivatives for extrapolation obtained by setting txz.dt, tyz.dt, tzz.dt to zero
# in governing equations and rearranging.

# These equations will take interior values and project them into the halo region
# to create a suitable image imposing the (approximate) free surface boundary condition

@memoized_func
def fs_dim(z, so, z0):
    """
    Return the memoized CustomDimension "zfs" used to index the top halo of ``z``.

    The dimension has parent ``z``, runs symbolically from 1 to ``so + z0`` and
    has symbolic size ``so``. The mirror/flux helpers use it as an offset from
    the surface index. Because of ``memoized_func``, calls with the same
    arguments return the same dimension object.
    """
    upper = so + z0
    return dv.CustomDimension(name="zfs", parent=z,
                              symbolic_min=1, symbolic_max=upper,
                              symbolic_size=so)


def mirror(grid, fields, r_coeff=-1):
    """
    Generate the stencil that mirrors the field to implement a free surface

    For each field, every point in the top halo of the last grid dimension (z)
    is set to ``r_coeff`` times its interior image about the surface index
    z0 = 0. With the default ``r_coeff=-1`` the image is antisymmetric. Fields
    not staggered in z are additionally set to exactly 0 at z0.

    Parameters
    ----------
    grid: Grid
        Computational grid; its last dimension is treated as z, with the
        surface at index 0.
    fields: Function or iterable of Function
        Fields to mirror (e.g. stress components). Entries equal to 0 are skipped.
    r_coeff: number, optional
        Coefficient multiplying the mirrored interior values (default -1).

    Returns
    -------
    list of Eq
        For each field, the halo-filling equation followed (for fields not
        staggered in z) by the zeroing equation at z0.
    """
    eqs = []

    z = grid.dimensions[-1]
    z0 = 0

    for u in as_tuple(fields):
        # 0 placeholders are allowed in the field list and ignored
        if u == 0:
            continue

        # Get left (top) halo width and corresponding mirroring dim
        sow = u.halo[z][0]
        zfs = fs_dim(z, sow, z0)

        # Fields staggered in z are offset half a cell from their index, so the
        # image of halo index z0 - zfs is z0 + zfs - 1 rather than z0 + zfs
        # (assumes Devito's +h/2 staggering convention -- confirm).
        sh = 1 if z in as_tuple(u.staggered) else 0
        # u[z0 - zfs] = r_coeff * u[z0 + zfs - sh] for every zfs in the halo
        eqs.extend([dv.Eq(u._subs(z, z0 - zfs), r_coeff * u._subs(z, z0 + zfs - sh))])

        # Fields with a point exactly on the surface are pinned to zero there
        if z not in as_tuple(u.staggered):
            eqs.append(dv.Eq(u._subs(z, z0), 0))

    return eqs


def flux(grid, fields, gradient=0):
    """
    Generate the stencil that extends the field to implement a flux

    Each field is extrapolated into the top halo of the last grid dimension (z)
    by copying its interior image about the surface index z0 = 0 and adding a
    linear correction from ``gradient``:

        u[z0 - k] = u[z0 + k - sh] + 2*(k - sh/2)*hz*gradient[z0 + k - sh]

    Here k runs over the halo points, hz is the z spacing, and sh = 1 if the
    field is staggered in z, else 0. ``2*(k - sh/2)*hz`` is the distance between
    the halo point and its image.

    NOTE(review): a Taylor expansion about the surface gives
    u[z0 - k] ~= u[z0 + k - sh] - 2*(k - sh/2)*hz*du/dz. The "+" above therefore
    matches only if ``gradient`` is the derivative in the -z direction. Confirm
    the convention used by the callers.

    Parameters
    ----------
    grid: Grid
        Computational grid; its last dimension is treated as z, with the
        surface at index 0.
    fields: Function or iterable of Function
        Fields to extend. Entries equal to 0 are skipped.
    gradient: expression or number, optional
        Normal derivative to impose. A Devito/SymPy expression is evaluated at
        the interior image point. A plain number is used as a constant. The
        default 0 gives a symmetric reflection.

    Returns
    -------
    list of Eq
        One halo-filling equation per field.
    """
    eqs = []

    z = grid.dimensions[-1]
    hz = z.spacing
    z0 = 0

    for u in as_tuple(fields):
        # 0 placeholders are allowed in the field list and ignored
        if u == 0:
            continue

        # Get left (top) halo width and corresponding mirroring dim
        sow = u.halo[z][0]
        zfs = fs_dim(z, sow, z0)

        # Interior image index; fields staggered in z are shifted by one index
        sh = 1 if z in as_tuple(u.staggered) else 0
        image = z0 + zfs - sh

        # Constant gradients (including the default 0) have no ``_subs`` and
        # do not depend on z, so they are used unchanged.
        if hasattr(gradient, '_subs'):
            grad_image = gradient._subs(z, image)
        else:
            grad_image = gradient

        eqs.append(dv.Eq(u._subs(z, z0 - zfs),
                         u._subs(z, image)
                         + 2*(zfs - 0.5*sh)*hz*grad_image))

    return eqs


# Stress images: antisymmetric mirror of the z-related stress components
# (txz, tyz, tzz) across the top surface.
eqs_fs_tau = mirror(grid, (tau[0, 2], tau[1, 2], tau[2, 2]))

# Velocity extrapolation into the top halo. The gradients follow from setting
# the time derivatives of txz, tyz and tzz to zero:
#   vx_z = -vz_x,  vy_z = -vz_y,  vz_z = -lam/(lam+2*mu) * (vx_x + vy_y)
# NOTE(review): flux() fills u[-k] as u[image] + 2*k*hz*gradient, whereas a
# Taylor expansion about the surface gives u[-k] ~= u[k] - 2*k*hz*du/dz.
# Passing du/dz directly (as here) looks sign-flipped -- confirm the convention.
# NOTE(review): the gradients use the current-time v while the fields being
# extrapolated are v.forward -- confirm this one-step lag is intentional.
eqs_fs_v = flux(grid, v[0].forward, gradient=-v[2].dx)
eqs_fs_v += flux(grid, v[1].forward, gradient=-v[2].dy)
eqs_fs_v += flux(grid, v[2].forward, gradient=-(lam/(lam+2*mu))*(v[0].dx + v[1].dy))

In [None]:
eqs_fs_v

## Define PDE and Discrete Equations

In [None]:
# First order elastic wave equation
pde_v = v.dt - ro*dv.div(tau)
pde_tau = tau.dt - lam*dv.diag(dv.div(v.forward)) - mu*(dv.grad(v.forward) + dv.grad(v.forward).transpose(inner=False))

# Time update
u_v = dv.Eq(v.forward, dv.solve(pde_v, v.forward))
u_tau = dv.Eq(tau.forward, dv.solve(pde_tau, tau.forward))

In [None]:
u_v

In [None]:
u_tau

## Construct Devito Operator (order matters)

In [None]:
# Note: v free surface equations inserted before tau update
op = dv.Operator([u_v] + eqs_fs_v + [u_tau] + src_xx + src_yy + src_zz + eqs_fs_tau)

## Show Compiler Code of 'op'

In [None]:
print(op.ccode)

## Run Propagator

In [None]:
%%time
op(dt=dt)

## Plot Wavefields

### $\tau_{xx}$

In [None]:
# y-z section at x index 100 of the last stored time buffer, symmetric colour range
txx_slice = tau[0, 0].data[-1, 100]
vmax = np.amax(np.abs(txx_slice))
plt.imshow(txx_slice.T, vmax=vmax, vmin=-vmax, cmap='seismic')
plt.colorbar()
plt.show()

### $\tau_{yy}$

In [None]:
# y-z section at x index 100 of the last stored time buffer, symmetric colour range
tyy_slice = tau[1, 1].data[-1, 100]
vmax = np.amax(np.abs(tyy_slice))
plt.imshow(tyy_slice.T, vmax=vmax, vmin=-vmax, cmap='seismic')
plt.colorbar()
plt.show()

### $\tau_{zz}$

In [None]:
# y-z section at x index 100 of the last stored time buffer, symmetric colour range
tzz_slice = tau[2, 2].data[-1, 100]
vmax = np.amax(np.abs(tzz_slice))
plt.imshow(tzz_slice.T, aspect='auto', vmax=vmax, vmin=-vmax, cmap='seismic')
plt.colorbar()
plt.show()

### Stress Trace $\tau_{xx} + \tau_{yy} + \tau_{zz}$ (pressure-like field; the physical pressure is $p = -\tfrac{1}{3}(\tau_{xx} + \tau_{yy} + \tau_{zz})$)

In [None]:
# Sum of the normal stresses, over all stored time buffers and grid points
p_field = tau[0, 0].data + tau[1, 1].data + tau[2, 2].data
# Colour range from the whole array, not only the plotted section
vmax = np.amax(np.abs(p_field))
p_section = p_field[-1,100]  # y-z section at x index 100, last time buffer
plt.imshow(p_section.T, aspect='auto', vmax=vmax, vmin=-vmax, cmap='seismic')
plt.colorbar()
plt.show()

### $\tau_{xy}$

In [None]:
# y-z section at x index 100 of the last stored time buffer, symmetric colour range
txy_slice = tau[0, 1].data[-1, 100]
vmax = np.amax(np.abs(txy_slice))
plt.imshow(txy_slice.T, vmax=vmax, vmin=-vmax, cmap='seismic')
plt.colorbar()
plt.show()

### $\tau_{xz}$

In [None]:
# y-z section at x index 100 of the last stored time buffer, symmetric colour range
txz_slice = tau[0, 2].data[-1, 100]
vmax = np.amax(np.abs(txz_slice))
plt.imshow(txz_slice.T, vmax=vmax, vmin=-vmax, cmap='seismic')
plt.colorbar()
plt.show()

### $\tau_{yz}$

In [None]:
# y-z section at x index 100 of the last stored time buffer, symmetric colour range
tyz_slice = tau[1, 2].data[-1, 100]
vmax = np.amax(np.abs(tyz_slice))
plt.imshow(tyz_slice.T, vmax=vmax, vmin=-vmax, cmap='seismic')
plt.colorbar()
plt.show()

## Plotting Code for Butterfly/Slicer/Unfold View Shown in Slides and Report

In [None]:
import holoviews as hv
import panel as pn
from holoviews import streams
hv.extension('bokeh')
pn.extension()
# pn.extension(debug=True)
# pn.extension('bokeh')

def interactive_3d_slicer(volume,
                          overlay = None,
                          src_coords=None, rec_coords=None,
                          ox=0.,oy=0.,oz=0.,
                          dx=1.,dy=1.,dz=1.,
                          xprat=1.,yprat=1.,zprat=1.,
                          isx=None,isy=None,isz=None,
                          xc_init=None, yc_init=None, zc_init=None,
                          xlabel='x',ylabel='y',zlabel='z',
                          xunit=None, yunit=None, zunit=None,
                          npixels=400,
                          cmap='jet',
                          ocmap='jet'):
    """
    Create an interactive 3D slicer using HoloViews, Panel, and Streams with global color scaling.

    Three orthogonal sections (x-y, x-z, y-z) of ``volume`` are laid out in a
    2x2 grid. Sliders move the section planes and set the overlay opacity.
    Yellow lines on each panel mark the positions of the other two planes.

    Parameters:
        volume (np.ndarray): 3D array indexed as volume[ix, iy, iz].
        overlay (np.ndarray, optional): array indexed like ``volume``, drawn on
            top of it with slider-controlled opacity.
        src_coords, rec_coords (np.ndarray, optional): (n, 3) arrays of
            (x, y, z) source/receiver coordinates, drawn as markers.
        ox, oy, oz (float): coordinate of the first sample along each axis.
        dx, dy, dz (float): sample spacing along each axis.
        xprat, yprat, zprat (float): stretch factors applied to each axis when
            sizing the panels.
        isx, isy, isz (int, optional): initial section indices (default: middle).
        xc_init, yc_init, zc_init (float, optional): initial section
            coordinates; take precedence over the corresponding index.
            Coordinates outside the volume are clamped to the nearest edge
            sample when slicing.
        xlabel, ylabel, zlabel (str): axis names.
        xunit, yunit, zunit (str, optional): units appended to the axis labels.
        npixels (int): panel size in pixels along the physically longest axis.
        cmap (str): colormap for ``volume``; its limits are the global min/max.
        ocmap (str): colormap for ``overlay``.

    Returns:
        panel.Row: the sliders next to the section layout.
    """
    fontsize = {'xlabel': 16, 'ylabel': 16, 'ticks': 12}

    # Data volume size-by-index
    x_size, y_size, z_size = volume.shape

    # Physical coordinate of every sample along each axis (shared by all panels)
    x_coords = ox + dx*np.arange(x_size)
    y_coords = oy + dy*np.arange(y_size)
    z_coords = oz + dz*np.arange(z_size)

    # Initial section positions: explicit coordinates win over indices, and
    # indices default to the middle of each axis
    if isx is None:
        isx = x_size // 2
    if isy is None:
        isy = y_size // 2
    if isz is None:
        isz = z_size // 2
    x_initial = xc_init if xc_init is not None else ox + dx*isx
    y_initial = yc_init if yc_init is not None else oy + dy*isy
    z_initial = zc_init if zc_init is not None else oz + dz*isz

    # One colour range for the base volume across every section
    global_min, global_max = volume.min(), volume.max()

    # Float sliders spanning each axis, plus the overlay opacity
    x_slider = pn.widgets.FloatSlider(name=xlabel, start=ox, end=ox+dx*(x_size - 1), value=x_initial, step=dx)
    y_slider = pn.widgets.FloatSlider(name=ylabel, start=oy, end=oy+dy*(y_size - 1), value=y_initial, step=dy)
    z_slider = pn.widgets.FloatSlider(name=zlabel, start=oz, end=oz+dz*(z_size - 1), value=z_initial, step=dz)
    a_slider = pn.widgets.FloatSlider(name='Alpha', start=0, end=1, value=0.5, step=0.01)

    # Streams: tie sliders to DynamicMap
    z_stream = streams.Stream.define('Z', Z=z_initial)()
    y_stream = streams.Stream.define('Y', Y=y_initial)()
    x_stream = streams.Stream.define('X', X=x_initial)()
    a_stream = streams.Stream.define('Alpha', alpha=0.5)()

    # Panel sizes in pixels, proportional to the (stretched) physical ranges,
    # with the longest axis mapped to npixels
    pdims = [dx*xprat*x_size, dy*yprat*y_size, dz*zprat*z_size]
    dscale = 1.0*npixels/max(pdims)

    x_psize = int(dscale*pdims[0] + 0.5)
    y_psize = int(dscale*pdims[1] + 0.5)
    z_psize = int(dscale*pdims[2] + 0.5)

    def coord2index(c, oc, dc, n):
        """Nearest sample index to coordinate ``c``, clamped to [0, n - 1]."""
        # Without the clamp, a coordinate before the origin yields a negative
        # index that numpy silently wraps to the far end of the axis, and one
        # past the end raises IndexError inside the plotting callback.
        return min(max(int(np.floor((c - oc)/dc + 0.5)), 0), n - 1)

    # Axis labels, with units when given
    pxlabel = xlabel if xunit is None else xlabel + f' ({xunit})'
    pylabel = ylabel if yunit is None else ylabel + f' ({yunit})'
    pzlabel = zlabel if zunit is None else zlabel + f' ({zunit})'

    # Sources and receivers (empty placeholders when not given)
    msize = 30
    scolor = 'red'
    slw = 3
    rlw = 3
    rcolor = 'yellow'
    smark = '*'
    rmark = 'inverted_triangle'
    xy_src_markers = hv.Points([])
    xz_src_markers = hv.Points([])
    yz_src_markers = hv.Points([])
    xy_rec_markers = hv.Points([])
    xz_rec_markers = hv.Points([])
    yz_rec_markers = hv.Points([])

    if src_coords is not None:
        xy_src_markers = hv.Points(src_coords[:,:-1]).opts(marker=smark, size=msize, color=scolor,line_width=slw)
        xz_src_markers = hv.Points(src_coords[:,::2]).opts(marker=smark, size=msize, color=scolor,line_width=slw)
        yz_src_markers = hv.Points(src_coords[:,::-1][:,:2]).opts(marker=smark, size=msize, color=scolor,line_width=slw)

    if rec_coords is not None:
        xy_rec_markers = hv.Points(rec_coords[:,:-1]).opts(marker=rmark, size=msize, color=rcolor,line_width=rlw)
        xz_rec_markers = hv.Points(rec_coords[:,::2]).opts(marker=rmark, size=msize, color=rcolor,line_width=rlw)
        yz_rec_markers = hv.Points(rec_coords[:,::-1][:,:2]).opts(marker=rmark, size=msize, color=rcolor,line_width=rlw)

    def add_overlay(base_image, xs, ys, overlay_slice, alpha):
        """Stack ``overlay_slice`` (if any) on ``base_image`` with opacity ``alpha``."""
        if overlay_slice is None:
            return base_image
        overlay_image = hv.Image((xs, ys, overlay_slice)).opts(
            alpha=alpha,  # Set the opacity level
            cmap=ocmap  # Optionally set a different colormap for the overlay
        )
        return base_image * overlay_image

    # XY section (horizontal plane at depth Z)
    def slice_xy(Z, Y, X, alpha):
        iz = coord2index(Z, oz, dz, z_size)
        base_image = hv.Image((x_coords, y_coords, volume[:, :, iz].T)).opts(
            xlabel=pxlabel, ylabel=pylabel, cmap=cmap, invert_yaxis=False, clim=(global_min, global_max),
            toolbar=None, width=x_psize, height=y_psize, xaxis='top',fontsize=fontsize)
        overlay_slice = None if overlay is None else overlay[:, :, iz].T
        composite_image = add_overlay(base_image, x_coords, y_coords, overlay_slice, alpha)

        return (composite_image *
                hv.HLine(Y).opts(color='yellow', line_width=2) *
                hv.VLine(X).opts(color='yellow', line_width=2) *
                xy_src_markers *
                xy_rec_markers)

    # XZ section (vertical plane at Y), depth increasing downwards
    def slice_xz(Y, Z, X, alpha):
        iy = coord2index(Y, oy, dy, y_size)
        base_image = hv.Image((x_coords, z_coords, volume[:, iy, :].T)).opts(
            xlabel=pxlabel, ylabel=pzlabel, cmap=cmap, invert_yaxis=True, clim=(global_min, global_max),
            toolbar=None, width=x_psize, height=z_psize, yaxis='left',fontsize=fontsize)
        overlay_slice = None if overlay is None else overlay[:, iy, :].T
        composite_image = add_overlay(base_image, x_coords, z_coords, overlay_slice, alpha)

        return (composite_image *
                hv.HLine(Z).opts(color='yellow', line_width=2) *
                hv.VLine(X).opts(color='yellow', line_width=2) *
                xz_src_markers *
                xz_rec_markers)

    # YZ section (vertical plane at X), depth along the horizontal axis
    def slice_yz(X, Z, Y, alpha):
        ix = coord2index(X, ox, dx, x_size)
        base_image = hv.Image((z_coords, y_coords, volume[ix, :, :])).opts(
            xlabel=pzlabel, ylabel=pylabel, cmap=cmap, invert_xaxis=False, clim=(global_min, global_max),
            toolbar=None, width=z_psize, height=y_psize, yaxis='right', xaxis='top',fontsize=fontsize)
        overlay_slice = None if overlay is None else overlay[ix, :, :]
        composite_image = add_overlay(base_image, z_coords, y_coords, overlay_slice, alpha)

        return (composite_image *
                hv.VLine(Z).opts(color='yellow', line_width=2) *
                hv.HLine(Y).opts(color='yellow', line_width=2) *
                yz_src_markers *
                yz_rec_markers)

    # Dynamic Maps for each section
    dmap_xy = hv.DynamicMap(slice_xy, streams=[z_stream, y_stream, x_stream, a_stream])
    dmap_xz = hv.DynamicMap(slice_xz, streams=[y_stream, z_stream, x_stream, a_stream])
    dmap_yz = hv.DynamicMap(slice_yz, streams=[x_stream, z_stream, y_stream, a_stream])

    # Blank panel filling the bottom-right cell
    blank_panel = hv.Text(0.5, 0.5, '').opts(width=z_psize, height=z_psize, xaxis=None, yaxis=None, border=0)

    # Arrange the layout in a 2x2 grid
    layout = ((dmap_xy + dmap_yz) + (dmap_xz + blank_panel)).cols(2)
    layout = layout.opts(shared_axes=False)

    # Update Streams from slider values
    def update_z(event):
        z_stream.event(Z=event.new)

    def update_y(event):
        y_stream.event(Y=event.new)

    def update_x(event):
        x_stream.event(X=event.new)

    def update_a(event):
        a_stream.event(alpha=event.new)

    # Attach event watchers
    x_slider.param.watch(update_x, 'value')
    y_slider.param.watch(update_y, 'value')
    z_slider.param.watch(update_z, 'value')
    a_slider.param.watch(update_a, 'value')

    # Reported only once the watchers are actually attached
    print("Setup complete, callbacks should be active.")

    return pn.Row(pn.Column(x_slider, y_slider, z_slider, a_slider), layout)

## Show Plots

**Note**: sliders may not work depending on Python or Jupyter environments (they work in Colab as of May 2025)

### Pressure-Field Overlay on Model

In [None]:
# Last stored stress-trace snapshot, clipped to +/-5% of its peak absolute value
p_last = np.array(p_field[-1])
vmax = 0.05*np.abs(p_last).max()
splot = np.clip(p_last, a_min=-vmax, a_max=vmax)

In [None]:
vp_plot = np.array(cp.data)
vp_plot *= vmax/vp_plot.max()

In [None]:
slicer = interactive_3d_slicer(vp_plot,
                               overlay=splot,
                               src_coords=src.coordinates.data,
                               dx=dx,
                               dy=dy,
                               dz=dz,
                               xunit='m',
                               yunit='m',
                               zunit='m',
                               zc_init=20,
                               npixels=600,
                               cmap='seismic',
                               ocmap='gray')

In [None]:
slicer

### $v_x$ Overlay on Pressure-Field

In [None]:
# Last stored v_x snapshot, normalised to unit peak and then scaled to 3x the
# stress-trace clip level so it is comparable with splot in the overlay.
vx_plot = np.array(v[0].data[-1])
vx_peak = np.abs(vx_plot).max()
# An all-zero field would otherwise become NaN (0/0); leave it as zeros.
if vx_peak > 0:
    vx_plot /= vx_peak
vx_plot *= 3*vmax

In [None]:
slicer = interactive_3d_slicer(splot,
                               overlay=vx_plot,
                               src_coords=src.coordinates.data,
                               dx=dx,
                               dy=dy,
                               dz=dz,
                               xunit='m',
                               yunit='m',
                               zunit='m',
                               zc_init=200,
                               npixels=600,
                               cmap='gray',
                               ocmap='seismic')

In [None]:
slicer