# Full-waveform inversion (FWI) solver using automated differentiation

Full-waveform inversion (FWI) is a technique generally employed to estimate the physical parameters of a subsurface region. It is a wave-equation-based method that seeks an optimal match between observed and computed data. The former are recorded by receivers in practice; the latter are obtained by numerically solving a wave equation with a forcing term that represents a source of wave emission.

## Functional to be minimised
A traditional way to measure the misfit between the observed and computed data is the least-squares functional (Tarantola, 1984):
$$
    I(u, u^{obs}) = \frac{1}{2}\sum_{r=0}^{N_r-1} \int_{\tau} \int_{\Omega_0} \left(u(c,\mathbf{x},t)- u^{obs}(\mathbf{x},t)\right)^2 \delta(\mathbf{x} - \mathbf{x}_r) \, d\mathbf{x} \, dt \tag{1}
$$
where $u = u(c, \mathbf{x},t)$ and $u^{obs} = u^{obs}(\mathbf{x},t)$ are, respectively, the computed and observed data, both recorded at a finite number $N_r$ of receivers located at the points $\mathbf{x}_r \in \Omega_{0}$, $r = 0, \dots, N_r-1$, over the time interval $\tau\equiv[t_0, t_f]\subset \mathbb{R}$, where $t_0$ is the initial time and $t_f$ is the final time. The spatial domain of interest is $\Omega_{0}$. The factor $1/2$ is a common convention and matches the implementation below.

An FWI problem consists of finding the optimal parameter $c$ that minimizes the functional $I(u, u^{obs})$.
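Once the data are sampled in time, the delta functions reduce (1) to samples at each receiver and the time integral becomes a quadrature over the time steps. A minimal NumPy sketch, assuming the traces are stored as arrays of shape `(num_steps, num_receivers)` and a rectangle rule with step `dt` (the helper name `discrete_misfit` is hypothetical, not part of Firedrake):

```python
import numpy as np


def discrete_misfit(u_computed, u_obs, dt):
    """Approximate I = 0.5 * sum_r int_tau (u - u_obs)^2 dt at the receivers.

    u_computed, u_obs : arrays of shape (num_steps, num_receivers) holding the
        computed and observed traces (assumed layout).
    dt : time step of the traces; the time integral uses a rectangle rule.
    """
    u_computed = np.asarray(u_computed, dtype=float)
    u_obs = np.asarray(u_obs, dtype=float)
    if u_computed.shape != u_obs.shape:
        raise ValueError("computed and observed data must have the same shape")
    residual = u_computed - u_obs
    return 0.5 * dt * np.sum(residual ** 2)


# Two time steps, three receivers: a single sample differs by 2.0.
obs_traces = np.zeros((2, 3))
cmp_traces = np.zeros((2, 3))
cmp_traces[1, 2] = 2.0
print(discrete_misfit(cmp_traces, obs_traces, dt=0.5))  # 0.5 * 0.5 * 4.0 = 1.0
```

A rectangle rule is the simplest choice; a trapezoidal rule would weight the first and last samples by one half instead.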

### Functional implementation
In Firedrake, the misfit between computed and observed data at the receivers, evaluated at a single time step, is implemented as follows:

In [None]:
from firedrake import *


def functional(u_obs, u_computed):
    """Compute the misfit 0.5 * ||u_computed - u_obs||^2 at one time step.

    Parameters
    ----------
    u_obs : firedrake.Function
        The observed data at the receivers.
    u_computed : firedrake.Function
        The computed data at the receivers.

    Returns
    -------
    float
        The value of the functional.
    """
    # Assuming both functions live on a vertex-only mesh of receivers,
    # integrating over dx sums the squared misfit over the receiver points.
    misfit = u_computed - u_obs
    return 0.5 * assemble(inner(misfit, misfit) * dx)
    

## Wave equation
To obtain the computed data, we first need to solve a wave equation. In this example, we consider the scalar acoustic wave equation. A damping term is commonly added to attenuate the reflections originating at the boundaries of the domain, but it is not included in the formulation below.
$$
    \frac{\partial^2 u}{\partial t^2}(\mathbf{x},t) - c^2(\mathbf{x})\,\nabla^2 u(\mathbf{x},t) = f(\mathbf{x}_s,t) \tag{2}
$$
where $c(\mathbf{x}):\Omega_{0}\rightarrow \mathbb{R}$ is the pressure-wave ($P$-wave) velocity, assumed here to be a positive, piecewise-constant function. The source term $f(\mathbf{x}_s,t)$ represents the wave emission at the source position $\mathbf{x}_s$; here, its time dependence is given by a Ricker wavelet.
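For reference, a common closed form of the Ricker wavelet with peak frequency $f_p$ (the `frequency_peak` entry of the model dictionary defined later) is given below. The delay $t_0$ is a free choice; taking $t_0 = 1/f_p$, as assumed here, makes the wavelet start close to zero at $t = 0$:
$$
    r(t) = \left(1 - 2\pi^2 f_p^2 (t - t_0)^2\right) e^{-\pi^2 f_p^2 (t - t_0)^2}, \qquad t_0 = \frac{1}{f_p}.
$$
The wavelet peaks at $r(t_0) = 1$ and crosses zero at $t = t_0 \pm 1/(\sqrt{2}\,\pi f_p)$, where $1 - 2\pi^2 f_p^2 (t - t_0)^2 = 0$.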

The acoustic wave equation (2) is subject to the initial conditions $u(\mathbf{x}, 0) = 0$ and $u_t(\mathbf{x}, 0) = 0$, and to the boundary conditions
$$
    \begin{aligned}
    u(\mathbf{x},t) &= 0, && \mathbf{x} \in \partial\Omega_{i},\ i = 1,2,3, \quad t\in\tau,\\
    \nabla u(\mathbf{x},t) \cdot \mathbf{n} &= 0, && \mathbf{x} \in \partial\Omega_{4}, \quad t\in\tau.
    \end{aligned}
$$
The domain $\Omega_{0}$ is illustrated in the figure. The boundaries $\partial\Omega_{i}$, $i = 1,2,3$, are referred to here as truncated boundaries and satisfy a null Dirichlet condition. The boundary $\partial\Omega_{4}$ satisfies a null Neumann (free-surface) condition, where $\mathbf{n}$ is the outward unit normal vector to $\partial \Omega_4$.

To solve the wave equation, we consider the following weak form over the domain $\Omega_{0}$:
$$
    \int_{\Omega_{0}} \left(\frac{\partial^2 u}{\partial t^2}v + c^2\nabla u \cdot \nabla v\right) \, dx = \int_{\Omega_{0}} f v \, dx,
$$
for every test function $v\in V$, where $V$ is a suitable finite element space. The weak form is implemented by the following Firedrake code:

In [None]:
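import finat
from firedrake import *


def wave_equation_solver(c, source_function, dt, mesh):
    """Set up one explicit (central-difference) time step of the acoustic wave equation.

    `source_function` is a scalar UFL expression (e.g. a Function) describing the
    source at the current time; the caller updates it before each `solver.solve()`.
    """
    V = FunctionSpace(mesh, "KMV", 3)
    u = TrialFunction(V)
    v = TestFunction(V)

    u_np1 = Function(V)  # timestep n+1
    u_n = Function(V)  # timestep n
    u_nm1 = Function(V)  # timestep n-1

    # Quadrature rule collocated with the KMV nodes, giving a lumped (diagonal) mass matrix.
    quad_rule = finat.quadrature.make_quadrature(V.finat_element.cell, V.ufl_element().degree(), "KMV")
    # Time discretisation (central difference) with the lumped mass matrix.
    m = (u - 2.0 * u_n + u_nm1) / Constant(dt * dt) * v * dx(scheme=quad_rule)
    # Stiffness term, evaluated explicitly at timestep n.
    a = c * c * dot(grad(u_n), grad(v)) * dx
    # Wave source.
    f = source_function * v * dx(scheme=quad_rule)
    F = m + a - f

    # No DirichletBC is applied, so every boundary behaves as a natural (null Neumann) boundary.
    lin_var = LinearVariationalProblem(lhs(F), rhs(F), u_np1)
    # The lumped mass matrix is diagonal, so a single Jacobi application solves the system exactly.
    solver = LinearVariationalSolver(lin_var, solver_parameters={"ksp_type": "preonly", "pc_type": "jacobi"})
    return solver, u_np1, u_n, u_nm1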
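In matrix form, the step assembled by `wave_equation_solver` reads as follows, where $M$ is the lumped mass matrix, $K$ the stiffness matrix containing $c^2$, $\mathbf{f}^n$ the source vector and $\mathbf{u}^n$ the vector of degrees of freedom at $t_n = n\,\Delta t$:
$$
    M\,\frac{\mathbf{u}^{n+1} - 2\mathbf{u}^{n} + \mathbf{u}^{n-1}}{\Delta t^2} + K\,\mathbf{u}^{n} = \mathbf{f}^{n}
    \quad\Longrightarrow\quad
    \mathbf{u}^{n+1} = 2\mathbf{u}^{n} - \mathbf{u}^{n-1} + \Delta t^2\, M^{-1}\left(\mathbf{f}^{n} - K\,\mathbf{u}^{n}\right).
$$
Because $M$ is diagonal, applying $M^{-1}$ is an entrywise division, which is why a single Jacobi application (`"ksp_type": "preonly"`, `"pc_type": "jacobi"`) solves each step exactly.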

In [None]:
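import numpy as np


def source_interpolator(mesh, wavelet, source_location, function_space):
    """Return an interpolator onto the source point and the point data holding `wavelet`."""
    # VertexOnlyMesh expects an array of points with shape (num_points, dim).
    points = np.asarray(source_location, dtype=float).reshape(-1, mesh.geometric_dimension())
    vom = VertexOnlyMesh(mesh, points, redundant=False)
    f_vom = FunctionSpace(vom, "DG", 0)
    # Assign the wavelet value(s) in the same order as the input points.
    f_vom_input_ordering = FunctionSpace(vom.input_ordering, "DG", 0)
    f_point_data_input_ordering = Function(f_vom_input_ordering)
    f_point_data_input_ordering.dat.data_wo[:] = wavelet
    f_vom_wavelet = interpolate(f_point_data_input_ordering, f_vom)
    # Interpolation operator from `function_space` onto the source point.
    return Interpolator(TestFunction(function_space), f_vom), f_vom_wavelet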

## Executing the wave equation solver
In this example, we consider a two-dimensional domain in which we want to estimate the parameter $c$. The physical domain is $1$ km long ($L_x = 1$ km) and $1$ km deep ($L_z = 1$ km).

Below we create a dictionary containing the parameters necessary to solve the wave equation.

In [None]:
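import numpy as np

Lx, Lz = 1.0, 1.0  # domain length and depth (km)
num_receivers = 10
model = {
    "source_location": [0.5, 0.5],
    # Receivers equally spaced along z = 0.1 km, from x = 0.3 km to x = 0.9 km.
    "receiver_locations": np.linspace((0.3, 0.1), (0.9, 0.1), num_receivers),
    # RectangleMesh(nx, ny, Lx, Ly): 50 x 50 cells over the Lx x Lz domain.
    "mesh": RectangleMesh(50, 50, Lx, Lz),
    "dt": 0.001,  # time step (s)
    "final_time": 1.0,  # final time (s)
    "frequency_peak": 7.0,  # Ricker wavelet peak frequency (Hz)
    "element_model": {"method": "KMV", "degree": 3, "quadrature": "KMV"},
}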
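The central-difference time stepping is explicit and therefore only conditionally stable, so it can be useful to check the Courant number $C = c_{\max}\,\Delta t / h$ before running. A minimal sketch with the values of the dictionary above, assuming $c_{\max} = 3.5$ km/s (the largest velocity of the true model defined next) and a cell size $h = L_x / 50$; `courant_number` is a hypothetical helper, and since the text gives no stability limit for degree-3 KMV elements, only the number itself is computed:

```python
def courant_number(c_max, dt, h):
    """Dimensionless Courant number C = c_max * dt / h."""
    return c_max * dt / h


cell_size = 1.0 / 50  # Lx / nx, in km
courant = courant_number(c_max=3.5, dt=0.001, h=cell_size)  # c in km/s, dt in s
print(f"Courant number: {courant:.3f}")  # Courant number: 0.175
```

Halving `dt` halves the Courant number, at the cost of twice as many time steps.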

In [None]:
from firedrake.pyplot import tricontourf
import matplotlib.pyplot as plt


def velocity_model(c_computed=False, plot_c=False):
    """Acoustic velocity model (km/s).

    If `c_computed` is True, return the constant initial guess c = 1.5; otherwise
    return the true model, a circular inclusion of radius 0.125 centred at (0.5, 0.5).
    """
    V = FunctionSpace(model["mesh"], model["element_model"]["method"], model["element_model"]["degree"])
    x, z = SpatialCoordinate(model["mesh"])
    if c_computed:
        c = Function(V).interpolate(1.5 + 0.0 * x)
    else:
        c = Function(V).interpolate(2.5 + 1 * tanh(100 * (0.125 - sqrt((x - 0.5) ** 2 + (z - 0.5) ** 2))))
    if plot_c:
        fig, axes = plt.subplots()
        # Contour levels spanning the velocity range of both models (about 1.5 to 3.5 km/s).
        levels = np.linspace(1.4, 3.6, 51)
        contours = tricontourf(c, levels=levels, axes=axes, cmap="inferno")
        axes.set_aspect("equal")
        fig.colorbar(contours)
        plt.show()
    return c


c = velocity_model(plot_c=True)
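The true model is a circular inclusion of radius $0.125$ km centred at $(0.5, 0.5)$; the steep `tanh` makes $c$ change from about $1.5$ km/s outside to about $3.5$ km/s inside over a thin transition layer. A NumPy sketch that evaluates the same expression on a regular grid, independently of Firedrake (the name `true_velocity` is hypothetical), can be used to confirm that range, for example when choosing contour levels:

```python
import numpy as np


def true_velocity(x, z):
    """Same expression as the true model: 2.5 + tanh(100 * (0.125 - r))."""
    r = np.sqrt((x - 0.5) ** 2 + (z - 0.5) ** 2)  # distance to the inclusion centre
    return 2.5 + np.tanh(100.0 * (0.125 - r))


grid = np.linspace(0.0, 1.0, 101)
X, Z = np.meshgrid(grid, grid, indexing="ij")
c_grid = true_velocity(X, Z)
print(round(float(c_grid.min()), 3), round(float(c_grid.max()), 3))
```

At $r = 0.125$ km the `tanh` argument vanishes, so the velocity on the inclusion edge is exactly the mean value $2.5$ km/s.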

We then simulate the wave equation using the parameters defined in the dictionary as follows:

In [None]:
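def ricker_wavelet(t, freq):
    """Ricker wavelet with peak frequency `freq`, delayed by 1/freq."""
    a = (np.pi * freq * (t - 1.0 / freq)) ** 2
    return (1.0 - 2.0 * a) * np.exp(-a)

# Assumption: the point source is approximated by a narrow Gaussian at the source
# location, with a time-dependent amplitude updated at every step.
x, z = SpatialCoordinate(model["mesh"])
xs, zs = model["source_location"]
amplitude = Constant(0.0)
source_function = amplitude * exp(-((x - xs) ** 2 + (z - zs) ** 2) / 0.03 ** 2)
solver, u_np1, u_n, u_nm1 = wave_equation_solver(c, source_function, model["dt"], model["mesh"])
num_steps = int(round(model["final_time"] / model["dt"]))
for step in range(num_steps):
    amplitude.assign(ricker_wavelet(step * model["dt"], model["frequency_peak"]))
    solver.solve()  # compute u at timestep n+1
    # Exchange the solution at the two time-stepping levels.
    u_nm1.assign(u_n)
    u_n.assign(u_np1)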

In [None]:
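# Reuses `ricker_wavelet`, `amplitude` and `source_function` from the previous cell.
# A vertex-only mesh holds the receiver points; a P0DG function on it stores one value per receiver.
receiver_mesh = VertexOnlyMesh(model["mesh"], model["receiver_locations"])
P = FunctionSpace(receiver_mesh, "DG", 0)

rec_data = []
solver, u_np1, u_n, u_nm1 = wave_equation_solver(c, source_function, model["dt"], model["mesh"])
num_steps = int(round(model["final_time"] / model["dt"]))
for step in range(num_steps):
    amplitude.assign(ricker_wavelet(step * model["dt"], model["frequency_peak"]))
    # Call the solver object.
    solver.solve()

    # Exchange the solution at the two time-stepping levels.
    u_nm1.assign(u_n)
    u_n.assign(u_np1)

    # Interpolate the current solution onto the receiver points and store it.
    rec = Function(P, name="rec").interpolate(u_np1)
    rec_data.append(rec)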
