# FWI with Devito, Dask, buckets, DFT

In [None]:
#NBVAL_IGNORE_OUTPUT

# Set up inversion parameters.
param = dict(
    t0=0.,                    # First time value, paired with 'tn' below
    tn=1000.,                 # Simulation lasts 1 second (1000 ms)
    f0=0.010,                 # Source peak frequency is 10Hz (0.010 kHz)
    nshots=5,                 # Number of shots to create gradient from
    m_bounds=(0.08, 0.25),    # Min and max slowness (presumably squared slowness, 1/v**2)
    shape=(101, 101),         # Number of grid points (nx, nz).
    spacing=(10., 10.),       # Grid spacing in m. The domain size is now 1km by 1km.
    origin=(0, 0),            # Need origin to define relative source and receiver locations.
    nbpml=40,                 # Thickness of the absorbing boundary layer (nbpml).
)

from pathlib import Path

import numpy as np

import scipy
from scipy import signal, optimize

from devito import Grid

from distributed import Client, LocalCluster, wait

import cloudpickle as pickle

# Import acoustic solver, source and receiver modules.
from examples.seismic import Model, demo_model
from examples.seismic.acoustic import AcousticWaveSolver
from examples.seismic import TimeAxis, PointSource, RickerSource, Receiver

# Import convenience function for plotting results
from examples.seismic import plot_image

In [None]:
def download_file_from_bucket(filename):
    """Download blob `filename` from the 'datasets-proxy' bucket to a local
    file of the same name.

    The data is first written to a uniquely named temporary file in the
    destination directory and only renamed to `filename` once the download
    has completed. This matters because callers in this notebook use
    "does the file exist?" as their cache check: a download that fails
    half-way must not leave a truncated file behind under the final name.

    Raises:
        FileNotFoundError: if the bucket contains no blob named `filename`.
    """
    import os
    import tempfile

    client = storage.Client(project='seg-demo-project-2')
    bucket = client.get_bucket('datasets-proxy')
    blob = bucket.get_blob(filename)
    if blob is None:
        # get_blob() signals a missing object by returning None.
        raise FileNotFoundError(
            "No blob named %r in bucket 'datasets-proxy'" % filename)

    # Unique temp name so that concurrent downloads of the same blob (e.g.
    # several workers sharing a directory) do not write into each other.
    target_dir = os.path.dirname(filename) or '.'
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix='.part')
    try:
        with os.fdopen(fd, 'wb') as f:
            blob.download_to_file(f)
        # Atomic on the same filesystem: readers see either no file or a
        # complete one.
        os.replace(tmp_path, filename)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def upload_file_to_bucket(filename):
    """Upload the local file `filename` to the 'datasets-proxy' bucket,
    storing it under a blob with the same name.
    """
    gcs = storage.Client(project='seg-demo-project-2')
    destination = storage.Blob(filename, gcs.get_bucket('datasets-proxy'))
    with open(filename, 'rb') as stream:
        destination.upload_from_file(stream)


def from_hdf5(filename, **kwargs):
    """Build a `Model` from an HDF5 file.

    The dataset stored under `datakey` is read as `m` and converted to a
    velocity array with ``vp = sqrt(1/m)``; its axes are then permuted with
    ``(1, 2, 0)`` before being handed to `Model`.

    Keyword arguments:
        datakey: name of the dataset holding `m` (required).
        origin: model origin; if omitted it is read from the dataset named
            by `origin_key` (default 'o').
        spacing: grid spacing; if omitted it is read from the dataset named
            by `spacing_key` (default 'd').
        nbpml: forwarded to `Model` (default 20).
        space_order: forwarded to `Model` (default None).
        dtype: dtype the velocity array is cast to, also forwarded to
            `Model` (default None).

    Raises:
        ValueError: if `datakey` is not given. This is checked before the
            file is opened, so a bad call never touches the file.
    """
    datakey = kwargs.pop('datakey', None)
    if datakey is None:
        raise ValueError("Must specify datakey")

    import h5py

    origin = kwargs.pop('origin', None)
    spacing = kwargs.pop('spacing', None)
    nbpml = kwargs.pop('nbpml', 20)
    space_order = kwargs.pop('space_order', None)
    dtype = kwargs.pop('dtype', None)

    # The context manager guarantees the file is closed even if a key is
    # missing or a read fails. `[()]` copies each dataset into memory, so
    # nothing below depends on the file staying open.
    with h5py.File(filename, 'r') as f:
        if origin is None:
            origin_key = kwargs.pop('origin_key', 'o')
            origin = tuple(f[origin_key][()])

        if spacing is None:
            spacing_key = kwargs.pop('spacing_key', 'd')
            spacing = tuple(f[spacing_key][()])

        data_m = f[datakey][()]

    data_vp = np.sqrt(1/data_m).astype(dtype)
    data_vp = np.transpose(data_vp, (1, 2, 0))
    shape = data_vp.shape

    return Model(space_order=space_order, vp=data_vp, origin=origin,
                 shape=shape, dtype=dtype, spacing=spacing, nbpml=nbpml)

def get_initial_model():
    """Return the initial model read from 'overthrust_3D_initial_model.h5'.

    The file is fetched from the bucket only when no local copy exists.
    """
    filename = 'overthrust_3D_initial_model.h5'

    # A local copy acts as a cache; only hit the bucket when it is missing.
    if not Path(filename).is_file():
        download_file_from_bucket(filename)

    hdf5_options = dict(nbpml=param['nbpml'], space_order=8,
                        datakey='m', dtype=np.float32)
    return from_hdf5(filename, **hdf5_options)

In [None]:
import os

# SECURITY: this cell used to embed a Google service-account private key in
# the notebook and write it to a world-readable file under /tmp. Secrets must
# never be committed with the notebook; the key that was published here should
# be treated as compromised and revoked.
#
# Credentials are now taken from the environment instead. Either
#   * export GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json
#     before starting Jupyter, or
#   * leave it unset and let the Google client library fall back to its
#     Application Default Credentials lookup.
filename = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
if filename is not None and not os.path.isfile(filename):
    # Fail here with a clear message rather than deep inside the first
    # storage.Client() call.
    raise FileNotFoundError(
        "GOOGLE_APPLICATION_CREDENTIALS points to %r, which does not exist"
        % filename)

In [None]:
# Get a list of all the shots in a bucket.

from google.cloud import storage

def list_blobs_with_prefix(bucket_name, prefix):
    """Return the names of all blobs in `bucket_name` that start with `prefix`.

    Useful for listing everything inside a "folder", e.g. "public/".
    """
    bucket = storage.Client(project='seg-demo-project-2').get_bucket(bucket_name)

    names = []
    for blob in bucket.list_blobs(prefix=prefix):
        names.append(blob.name)
    return names

# Names of every blob in the 'datasets-proxy' bucket whose name starts with
# "shot"; fwi_gradient draws its random minibatches from this list.
list_of_shots = list_blobs_with_prefix('datasets-proxy', "shot")

In [None]:
def load_model(filename):
    """Return the current model stored in the pickle file `filename`.

    This is used by the worker to get the current model. If no local copy
    of `filename` exists it is first downloaded from the bucket.

    The file must contain a pickled dict with a 'model' entry.
    """
    if not Path(filename).is_file():
        download_file_from_bucket(filename)

    # NOTE(security): unpickling can execute arbitrary code. This is only
    # acceptable because the file is one this notebook wrote itself via
    # dump_model(); do not point this at untrusted data.
    # The `with` block closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(filename, "rb") as f:
        pkl = pickle.load(f)

    return pkl['model']

def dump_model(filename, model):
    """Pickle `model` to the local file `filename` and upload it to the bucket.

    The model is stored as ``{'model': model}``, which is the layout
    load_model() reads back.
    """
    # Close (and therefore flush) the file before uploading; otherwise the
    # upload could read a partially written pickle.
    with open(filename, "wb") as f:
        pickle.dump({'model': model}, f)

    upload_file_to_bucket(filename)

In [None]:
def load_shot_data(filename, dt):
    """Load shot data from bucket, resampling to the model time step.

    If no local copy of `filename` exists it is first downloaded from the
    bucket. The file must contain a pickled dict with 'src' and 'rec'
    entries; both are resampled with ``.resample(dt)``.

    Returns:
        (src, rec): the resampled source and receiver objects.
    """
    if not Path(filename).is_file():
        download_file_from_bucket(filename)

    # NOTE(security): unpickling can execute arbitrary code; only use this
    # with shot files from a bucket you trust.
    # The `with` block closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(filename, "rb") as f:
        pkl = pickle.load(f)

    return pkl['src'].resample(dt), pkl['rec'].resample(dt)

In [None]:
# Define a type to store the functional and gradient.
class fg_pair:
    """Objective-function value `f` and gradient `g`, added component-wise.

    Supporting `+` (including ``0 + pair``) lets a list of pairs be reduced
    with the built-in sum(), whose running total starts at 0.
    """

    def __init__(self, f, g):
        self.f, self.g = f, g

    def __add__(self, other):
        # Component-wise sum; `other` only needs `.f` and `.g` attributes.
        return fg_pair(self.f + other.f, self.g + other.g)

    def __radd__(self, other):
        # sum() evaluates ``0 + first_item``; treat that 0 as the identity.
        return self if other == 0 else self.__add__(other)

In [None]:
from devito import Function

# Create FWI gradient kernel for a single shot
def fwi_gradient_i(param):
    """Compute the FWI objective value and gradient for a single shot.

    `param` is a dict from which two entries are read:
        'model': filename of the pickled current model.
        'shot_filename': filename of the pickled shot data.

    Returns an fg_pair holding the objective value and a NumPy copy of the
    gradient.
    """
    from devito import clear_cache

    # Need to clear the workers cache.
    clear_cache()

    # Load the current model and the shot data for this worker.
    # Note, unlike the serial example the model is not passed in
    # as an argument. Broadcasting large datasets is considered
    # a programming anti-pattern and at the time of writing it
    # only worked reliably with Dask master. Therefore, the
    # model is communicated via a file.
    current_model = load_model(param['model'])
    src, rec = load_shot_data(param['shot_filename'], current_model.critical_dt)

    # Settings shared by the forward and adjoint frequency-domain calls.
    # presumably frequencies are in kHz, like param['f0'] -- TODO confirm
    frequencies = (0.003, 0.004, 0.005)
    factor = 10

    # Set up solver.
    solver = AcousticWaveSolver(current_model, src, rec, space_order=8)

    # Compute simulated data plus the two wavefield outputs (ufr, ufi) that
    # the adjoint call below consumes.
    synthetic, ufr, ufi = solver.forward_freq_modeling(
        freq=frequencies, factor=factor, src=src)

    # Compute the data misfit (residual) and objective function
    residual = Receiver(name='rec', grid=current_model.grid,
                        time_range=rec.time_range,
                        coordinates=rec.coordinates.data)
    residual.data[:] = synthetic.data[:] - rec.data[:]
    objective = .5*np.linalg.norm(residual.data.flatten())**2

    # Compute gradient using the adjoint-state method. Note, this
    # backpropagates the data misfit through the model.
    grad = Function(name="grad", grid=current_model.grid)
    solver.adjoint_freq_born(recin=residual, freq=frequencies, factor=factor,
                             ufr=ufr, ufi=ufi, grad=grad)

    # Copying here to avoid a (probably overzealous) destructor deleting
    # the gradient before Dask has had a chance to communicate it.
    gradient = np.array(grad.data[:])

    # return the objective functional and gradient.
    return fg_pair(objective, gradient)

In [None]:
import random

def fwi_gradient(model, param, iid):
    """Compute the objective value and gradient summed over a random
    minibatch of shots, using the Dask `client`.

    Args:
        model: current model; it is pickled and uploaded so that workers
            can load it by filename.
        param: parameter dict. Side effect: param['model'] is set to the
            filename the model was dumped to.
        iid: integer identifier (the iteration index) used in the model
            filename.

    Returns:
        (f, g): the summed objective value and gradient.

    Raises:
        ValueError: if `list_of_shots` is empty.
    """
    # Dump a copy of the current model for the workers
    # to pick up when they are ready.
    param['model'] = "model_%d.p"%iid
    dump_model(param['model'], model)

    if not list_of_shots:
        # Without this, sum([]) would return the int 0 and the attribute
        # access at the end would fail with a confusing AttributeError.
        raise ValueError("list_of_shots is empty; there is nothing to compute a gradient from")

    # Select a random sample of shots - choosing 28 because that's
    # the size of my test cluster. Clamp to the number of shots available,
    # since random.sample raises when asked for more items than exist.
    batch_size = min(28, len(list_of_shots))
    minibatch = random.sample(list_of_shots, batch_size)

    # Create worklist: one copy of `param` per shot, each with its own
    # 'shot_filename'.
    worklist = [dict(param, shot_filename=shot) for shot in minibatch]

    # Distribute worklist.
    fgi = client.map(fwi_gradient_i, worklist, retries=1)

    # Perform reduction.
    fg = client.submit(sum, fgi).result()

    return fg.f, fg.g

# Define bounding box constraints on the solution.
def apply_box_constraint(m):
    """Clip `m` to the range implied by the 'realistic' velocity limits.

    `m` is compared against 1/v**2, so the largest velocity gives the lower
    bound and the smallest velocity gives the upper bound.
    """
    v_min = 2     # Minimum possible 'realistic' velocity is 2 km/sec
    v_max = 3.5   # Maximum possible 'realistic' velocity is 3.5 km/sec
    lower = 1/v_max**2
    upper = 1/v_min**2
    return np.clip(m, lower, upper)

In [None]:
# Start Dask cluster: a LocalCluster with 5 workers, and a Client connected
# to it. `client` is the global handle fwi_gradient uses to map work onto
# the workers and to run the reduction.
cluster = LocalCluster(n_workers=5, death_timeout=600)
client = Client(cluster)

In [None]:
fwi_iterations = 5
model_i = get_initial_model()

# Run FWI with gradient descent. `history` gets one objective value per
# iteration; `model_i` is updated in place.
history = np.zeros((fwi_iterations, 1))
for i in range(fwi_iterations):
    # Compute the functional value and gradient for the current
    # model estimate. The iteration index is passed so that each iteration
    # dumps the model under its own filename.
    phi, direction = fwi_gradient(model_i, param, i)
    
    # Store the history of the functional values
    history[i] = phi
    
    # Artificial step length for gradient descent.
    # In practice this would be replaced by a line search (Wolfe, ...)
    # that would guarantee functional decrease Phi(m-alpha g) <= epsilon Phi(m)
    # where epsilon is a minimum decrease constant.
    # NOTE(review): this scales by the largest signed gradient entry, so it
    # assumes np.max(direction) > 0 -- confirm that is intended.
    alpha = .005 / np.max(direction)
    
    # Update the model estimate and enforce minimum/maximum values
    model_i.m.data[:] = apply_box_constraint(model_i.m.data - alpha * direction)
    
    # Log the progress made
    print('Objective value is %f at iteration %d' % (phi, i+1))

Having applied our FWI function, we now have a look at the result.

In [None]:
#NBVAL_SKIP

# Show what the update does to the model
from examples.seismic import plot_image, plot_velocity

# The FWI loop updates `model_i` in place; there is no notebook-level
# `model0` (that name only exists inside fwi_gradient_i).
# Strip the absorbing layer (param['nbpml'] points on each side of every
# axis) before converting m back to velocity.
nbpml = param['nbpml']
interior = tuple(slice(nbpml, -nbpml or None) for _ in model_i.m.data.shape)
model_i.vp = np.sqrt(1. / model_i.m.data[interior])
# NOTE(review): the initial model is loaded from a 3D file; plot_velocity may
# need a 2D slice of the model -- confirm.
plot_velocity(model_i)

In [None]:
#NBVAL_SKIP

# Plot percentage error
# The updated model is `model_i`; `model0` is not defined at notebook level.
# NOTE(review): `get_true_model` is not defined anywhere in this notebook; it
# must be provided before this cell can run.
true_vp = get_true_model().vp.data
plot_image(100*np.abs(model_i.vp-true_vp)/true_vp, vmax=15, cmap="hot")

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt

# Plot objective function decrease.
# `history` holds the objective value recorded at each FWI iteration (no
# `relative_error` is computed in this notebook). Iterations are numbered
# from 1 so that the first point is representable on the log axis.
iterations = np.arange(1, len(history) + 1)
plt.figure()
plt.loglog(iterations, history.ravel())
plt.xlabel('Iteration number')
plt.ylabel('Objective value')
plt.title('Convergence')
plt.show()