# Full-waveform inversion using automatic differentiation

## Cost function
In essence, full-waveform inversion (FWI) is a local optimisation problem whose objective is to minimise the difference between observed and predicted seismogram data. Following Tarantola (1984), this misfit can be measured in the $L^2$ norm, which in continuous space reads
$$
    I(m)\equiv \frac{1}{2} \int_{\tau} \int_{\Omega_{0}} \left(u(m,\mathbf{x},t)- u_{obs}(\mathbf{x},t)\right)^2 \delta(\mathbf{x}- \check{\mathbf{x}})\, \text{d} \mathcal{V}\, \text{d} t .
$$

The functions $u = u(m,\check{\mathbf{x}},t)$ and $u_{obs} = u_{obs}(\check{\mathbf{x}},t)$ are, respectively, the predicted and observed data. Both are recorded at a finite set of receivers located at the points $\check{\mathbf{x}} \in \Omega_{0}$, over the time interval $\tau\equiv[t_0, t_f]\subset \mathbb{R}$, where $t_0$ is the initial time and $t_f$ is the final time. Only the predicted data depend on the model parameter $m$. The Dirac delta function $\delta(\mathbf{x}- \check{\mathbf{x}})$ models the point receiver positions, and $\Omega_{0}$ is the spatial domain of interest.
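
To make the integral concrete, here is a minimal sketch, assuming the predicted and observed traces have already been sampled at the receivers on a uniform time grid. The function name `discrete_misfit`, the array layout (time steps by receivers) and the rectangle-rule quadrature in time are illustrative assumptions, not part of the Firedrake workflow used in this notebook.

```python
import numpy as np


def discrete_misfit(u_pred, u_obs, dt):
    """Approximate I(m) = 1/2 * int sum_r (u - u_obs)^2 dt with a rectangle rule.

    u_pred, u_obs: arrays of shape (n_steps, n_receivers).
    dt: time-step size (assumed uniform).
    """
    residual = np.asarray(u_pred) - np.asarray(u_obs)
    return 0.5 * dt * np.sum(residual ** 2)


# Hypothetical example: 3 time samples recorded at 2 receivers.
u_pred = np.array([[0.0, 0.1], [0.2, 0.3], [0.1, 0.0]])
u_obs = np.array([[0.0, 0.0], [0.1, 0.3], [0.1, 0.2]])
misfit = discrete_misfit(u_pred, u_obs, dt=0.001)
```

The sum over receivers plays the role of the Dirac delta in the continuous expression: only the values of the wavefield at the receiver points contribute.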

In the discrete setting, using the finite element method, the cost function is evaluated by the function `cost_function` below, which relies on the `firedrake` package:

In [None]:
import firedrake as fd


def cost_function(u, u_obs):
    """Return 0.5 * integral of (u - u_obs)^2 over the domain of u."""
    return fd.assemble(0.5 * fd.inner(u - u_obs, u - u_obs) * fd.dx)

## Wave equation
The predicted data, $u = u(m,\mathbf{x},t)$, are modelled here by the acoustic wave equation
$$
    m(\mathbf{x})\frac{\partial^2 u}{\partial t^2}(\mathbf{x},t)-\frac{\partial^2 u}{\partial \mathbf{x}^2}(\mathbf{x},t) = f(\mathbf{x},t).
$$

The variable coefficient $m:\Omega_{0}\rightarrow \mathbb{R}$ is given by $m(\mathbf{x})= \displaystyle\frac{1}{c^2(\mathbf{x})}$, where $c:\Omega_{0}\rightarrow \mathbb{R}$ is the pressure-wave ($P$-wave) velocity, assumed to be piecewise constant and positive. The external force term $f:\Omega_{0}\times\tau\rightarrow \mathbb{R}$ models the wave source and is usually described by a Ricker wavelet \citep{ricker1940form}.

The Ricker wavelet is implemented by the following code:


In [None]:
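import numpy as np


def ricker_wavelet(t, freq, amp=1.0):
    """Ricker wavelet with peak frequency `freq`, evaluated at time `t`."""
    # Shift in time so the entire wavelet is injected
    t = t - (np.sqrt(6.0) / (np.pi * freq))
    arg = (np.pi * freq * t) ** 2
    # Polynomial factor times the Gaussian envelope that makes the pulse decay
    return amp * (1.0 - 2.0 * arg) * np.exp(-arg)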
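
As a quick sanity check on the source term, the sketch below evaluates the textbook form of the Ricker wavelet, $r(t) = A\,(1 - 2\pi^2 f^2 (t-t_0)^2)\,e^{-\pi^2 f^2 (t-t_0)^2}$ with delay $t_0 = \sqrt{6}/(\pi f)$. The name `ricker_reference` and the sampling grid are assumptions made for illustration only.

```python
import numpy as np


def ricker_reference(t, freq, amp=1.0):
    """Textbook Ricker wavelet with peak frequency `freq`, delayed by t0."""
    t0 = np.sqrt(6.0) / (np.pi * freq)  # time delay of the peak
    arg = (np.pi * freq * (t - t0)) ** 2
    return amp * (1.0 - 2.0 * arg) * np.exp(-arg)


freq = 7.0
t0 = np.sqrt(6.0) / (np.pi * freq)
times = np.linspace(0.0, 1.0, 1001)
samples = ricker_reference(times, freq)

peak_value = ricker_reference(t0, freq)         # maximum of the wavelet
peak_time = times[np.argmax(samples)]           # grid point closest to t0
tail_value = abs(ricker_reference(1.0, freq))   # amplitude at the final time
```

The peak equals the amplitude `amp` and is reached at $t_0$, and the wavelet has decayed to a negligible value well before the final time, so the whole pulse is injected during the simulation.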

Before writing the acoustic wave equation, let us define the time step, the final time, the mesh, and the receiver and source positions.

In this example, we consider a two-dimensional domain, the unit square.

In [None]:

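mesh = fd.UnitSquareMesh(50, 50)
T = 1.0      # final time
dt = 0.001   # time-step size
t = 0.0      # current time
step = 0     # time-step counter
freq = 7     # peak frequency of the Ricker wavelet
c = fd.Constant(1.5)  # P-wave velocity
# Receiver coordinates; x = 0.1 is an assumed position, chosen so that the
# points lie inside the unit square mesh.
receiver_coords = np.linspace((0.1, 0.2), (0.1, 0.8), 1)
# Point cloud mesh on which the wavefield is sampled.
receivers = fd.VertexOnlyMesh(mesh, receiver_coords)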
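
The parameters above can be checked with a few lines of arithmetic. The sketch below computes the number of time steps, the Courant number $c\,\Delta t/h$ and the number of mesh cells per wavelength at the peak frequency. It is only a rough indicator: the admissible Courant number depends on the element type and degree, and the variable names are introduced here for illustration.

```python
# Values copied from the setup cell.
n_cells = 50     # cells per side of the unit square
T = 1.0          # final time
dt = 0.001       # time-step size
c_max = 1.5      # largest wave speed in the model
freq = 7.0       # peak frequency of the source

h = 1.0 / n_cells                        # cell edge length along each axis
n_steps = round(T / dt)                  # number of time steps
courant = c_max * dt / h                 # Courant number c * dt / h
cells_per_wavelength = (c_max / freq) / h  # wavelength divided by cell size

print(n_steps, courant, cells_per_wavelength)
```

With these values the run takes 1000 time steps, and the Courant number is about 0.075.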

In [None]:
import finat

V = fd.FunctionSpace(mesh, "KMV", 2)
u = fd.TrialFunction(V)
v = fd.TestFunction(V)

u_np1 = fd.Function(V)  # timestep n+1
u_n = fd.Function(V)    # timestep n
u_nm1 = fd.Function(V)  # timestep n-1
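# KMV quadrature rule: its points coincide with the element nodes, which
# yields a diagonal (lumped) mass matrix.
quad_rule = finat.quadrature.make_quadrature(
    V.finat_element.cell, V.ufl_element().degree(), "KMV"
)
dxlump = fd.dx(scheme=quad_rule)
# Mass term: second-order central difference in time.
m = (u - 2.0 * u_n + u_nm1) / fd.Constant(dt * dt) * v * dxlump
# Stiffness term, evaluated at the current time level.
a = c * c * fd.dot(fd.grad(u_n), fd.grad(v)) * dxlump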

x, y = fd.SpatialCoordinate(mesh)
source_pos = fd.Constant([0.5, 0.5])  # source position
ricker = fd.Constant(0.0)  # source amplitude, updated at every time step
# Narrow Gaussian centred at the source, standing in for a point source
delta = fd.exp(-2000 * ((x - source_pos[0]) ** 2 + (y - source_pos[1]) ** 2))
f = delta * ricker * v * dxlump
F = m + a - f

lhs_ = fd.lhs(F)
rhs_ = fd.rhs(F)

lin_var = fd.LinearVariationalProblem(lhs_, rhs_, u_np1)
solver = fd.LinearVariationalSolver(lin_var, 
                                    solver_parameters=
                                    {"ksp_type": "preonly", "pc_type": "jacobi"})
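
As far as can be read from the code, the forms `m`, `a` and `f` amount to an explicit central finite-difference discretisation in time of the weak form of the wave equation, scaled by $c^2$. The notation below is introduced here for illustration: $u^{n}$ is the solution at time step $n$, $\Delta t$ is the time step, and $(\cdot,\cdot)$ is the $L^2$ inner product evaluated with the KMV quadrature rule.

$$
    \left(\frac{u^{n+1} - 2u^{n} + u^{n-1}}{\Delta t^{2}},\, v\right) + \left(c^{2}\,\nabla u^{n},\, \nabla v\right) = \left(f^{n},\, v\right) \quad \forall\, v \in V .
$$

Only the first term involves the unknown $u^{n+1}$, so each step requires inverting the mass matrix alone. With the lumped quadrature this matrix is diagonal, which is why a single Jacobi application (`"ksp_type": "preonly"`, `"pc_type": "jacobi"`) is sufficient. No boundary term appears in the form, which corresponds to a homogeneous Neumann condition on the boundary of the domain.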

In [None]:
# Scalar DG0 space on the receiver (vertex-only) mesh
P = fd.FunctionSpace(receivers, "DG", 0)
rec_interpolator = fd.Interpolator(u_np1, P)
rec_data = []
t = 0.0
while t < T:
    ricker.assign(ricker_wavelet(t, freq))
    # Call the solver object.
    solver.solve()

    # Exchange the solution at the two time-stepping levels.
    u_nm1.assign(u_n)
    u_n.assign(u_np1)

    # Record the wavefield at the receivers.
    rec = fd.Function(P, name="rec")
    rec_interpolator.interpolate(output=rec)
    rec_data.append(rec)

    # Advance in time.
    t += dt

In [None]:
def forward_wave_equation(c):
    V = fd.FunctionSpace(mesh, "KMV", 2)

    u = fd.TrialFunction(V)
    v = fd.TestFunction(V)

    u_np1 = fd.Function(V)  # timestep n+1
    u_n = fd.Function(V)    # timestep n
    u_nm1 = fd.Function(V)  # timestep n-1
    quad_rule = finat.quadrature.make_quadrature(V.finat_element.cell, V.ufl_element().degree(), "KMV")
    dxlump=fd.dx(scheme=quad_rule)
    m = (u - 2.0 * u_n + u_nm1) / fd.Constant(dt * dt) * v * dxlump
    a = c * c * fd.dot(fd.grad(u_n), fd.grad(v)) * dxlump

    x, y = fd.SpatialCoordinate(mesh)
    source_pos = fd.Constant([0.5, 0.5]) # source position
    ricker = fd.Constant(0.0)
    delta = fd.exp(-2000 * ((x - source_pos[0]) ** 2 + (y - source_pos[1]) ** 2))
    f = delta * ricker * v * dxlump
    F = m + a - f

    lhs_ = fd.lhs(F)
    rhs_ = fd.rhs(F)

    lin_var = fd.LinearVariationalProblem(lhs_, rhs_, u_np1)
    solver = fd.LinearVariationalSolver(lin_var, 
                                        solver_parameters=
                                        {"ksp_type": "preonly", "pc_type": "jacobi"})

    return solver, u_n, u_nm1, u_np1, ricker

In [None]:
def true_data(c):
    solver, u_n, u_nm1, u_np1, ricker = forward_wave_equation(c)
    t = 0.0  # local time, so the global `t` is left untouched
    step = 0
    rec_data = []
    # Scalar DG0 space on the receiver (vertex-only) mesh
    P = fd.FunctionSpace(receivers, "DG", 0)
    rec_interpolator = fd.Interpolator(u_np1, P)
    while t < T:
        ricker.assign(ricker_wavelet(t, freq))
        # Call the solver object.
        solver.solve()

        # Exchange the solution at the two time-stepping levels.
        u_nm1.assign(u_n)
        u_n.assign(u_np1)

        rec = fd.Function(P, name="rec")
        rec_interpolator.interpolate(output=rec)
        rec_data.append(rec)
        
        # Increment the time and print the elapsed time every 10 steps.
        t += dt
        step += 1
        if step % 10 == 0:
            print("Elapsed time is: "+str(t))
        
    return rec_data

true_rec_data = true_data(c)

## Constrained optimisation

As mentioned at the start of this section, the goal in FWI is to minimise the misfit, measured here by the cost function $I(m)$ defined above. This minimisation is typically carried out with a local optimisation method, which requires the gradient $\nabla_{m} I(m)$. The gradient can be computed efficiently by the adjoint method \citep{Plessix:2006}.
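
The idea behind the adjoint gradient can be illustrated on a toy problem, assuming a linear forward operator in place of the wave equation. In the sketch below the matrix `G`, the data `d_obs` and all function names are hypothetical. The gradient of the least-squares misfit is obtained by applying the adjoint (transpose) of `G` to the residual, and a Taylor test checks it: the remainder should shrink at second order as the perturbation size is halved.

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.standard_normal((20, 5))   # toy linear forward operator (assumption)
d_obs = rng.standard_normal(20)    # toy observed data (assumption)


def misfit(m):
    """J(m) = 1/2 * ||G m - d_obs||^2."""
    r = G @ m - d_obs
    return 0.5 * r @ r


def gradient(m):
    """Gradient of J: the adjoint (transpose) of G applied to the residual."""
    return G.T @ (G @ m - d_obs)


m0 = rng.standard_normal(5)   # point at which the gradient is checked
dm = rng.standard_normal(5)   # perturbation direction
steps = np.array([1e-1, 5e-2, 2.5e-2, 1.25e-2])

# Taylor remainder |J(m0 + h dm) - J(m0) - h * grad . dm| for each step h
remainders = np.array(
    [abs(misfit(m0 + h * dm) - misfit(m0) - h * gradient(m0) @ dm) for h in steps]
)
# Observed convergence rates between successive halvings of h
rates = np.log(remainders[:-1] / remainders[1:]) / np.log(2.0)
```

A rate close to 2 indicates that the gradient is consistent with the cost function. The same kind of check can be applied to a gradient obtained by automatic differentiation of the wave solver.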

In [None]:
import firedrake.adjoint as adj


adj.continue_annotation()
solver, u_n, u_nm1, u_np1, ricker = forward_wave_equation(c)
# Scalar DG0 space on the receiver (vertex-only) mesh
P = fd.FunctionSpace(receivers, "DG", 0)
rec_interpolator = fd.Interpolator(u_np1, P)

t = 0.0
step = 0
J = 0.0
rec_diff = []
while t < T:
    ricker.assign(ricker_wavelet(t, freq))
    # Call the solver object.
    solver.solve()

    # Exchange the solution at the two time-stepping levels.
    u_nm1.assign(u_n)
    u_n.assign(u_np1)

    # Misfit between predicted and observed receiver data at this time step.
    guess_rec = fd.Function(P, name="rec")
    rec_interpolator.interpolate(output=guess_rec)
    rec_diff.append(guess_rec - true_rec_data[step])
    J += cost_function(guess_rec, true_rec_data[step])

    # Increment the time and print the elapsed time every 10 steps.
    t += dt
    step += 1
    if step % 10 == 0:
        print("Elapsed time is: "+str(t))

# Reduced functional: J as a function of the control variable c.
Jhat = adj.ReducedFunctional(J, adj.Control(c))