# MPI for Sonar simulation

## Imports

In [None]:
from mpi4py import MPI
# Communicator spanning every launched process, plus this process's 0-based
# rank and the total number of ranks (both used by run_beam_mpi below).
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np
from devito import configuration
from devito import Eq, Operator, TimeFunction, solve
from examples.seismic import Model, TimeAxis, WaveletSource, Receiver
# Global Devito settings, set right after the imports so that the Model and
# Operator built in later cells see them.
# Enable Devito's MPI support (ranks come from COMM_WORLD set up above).
configuration['mpi'] = True
# Generate plain C code for the operators (presumably no OpenMP threading —
# confirm against Devito's `language` option docs).
configuration['language'] = 'C'

In [None]:
class GaborSource(WaveletSource):
    """Point source emitting a Gabor wavelet (Gaussian-windowed cosine).

    The wavelet is built from ``f0``, ``t0``, ``a`` and ``time_values``, which
    are presumably set up by the ``WaveletSource`` base class.
    """

    def __init_finalize__(self, *args, **kwargs):
        super().__init_finalize__(*args, **kwargs)

    @property
    def wavelet(self):
        """Return the Gabor wavelet sampled at ``self.time_values``.

        ``agauss = f0 / 2`` scales time. The envelope centre ``tcut`` is
        ``t0``, or ``5 / agauss`` when ``t0`` is falsy (``None`` or 0). The
        amplitude is ``a``, or 1 when ``a`` is falsy.
        """
        assert self.f0 is not None
        agauss = 0.5 * self.f0
        tcut = self.t0 or 5 / agauss
        s = (self.time_values - tcut) * agauss
        # The amplitude comes from the source's own ``a`` attribute. The
        # former ``a = a or 1`` read an unassigned local variable and raised
        # UnboundLocalError on every access to ``wavelet``.
        a = self.a or 1
        return a * np.exp(-0.5 * s**2) * np.cos(2 * np.pi * s)

## Initialization

In [None]:
# Physical setup. Units presumably follow the Devito seismic examples'
# convention (km/s, kHz, ms) — confirm.
domain_size = (60, 30)  # extent of the domain along (x, y)
v_env = 1.5  # background propagation speed
ns = 128  # number of array elements (sources / receivers)
source_distance = 0.002  # pitch between neighbouring array elements
f0 = 50  # source peak frequency
space_order = 8
# Grid spacing: one third of the wavelength v_env / f0, rounded to 6 decimals.
spatial_dist = round(v_env / f0 / 3, 6)
dt = spatial_dist / 20

# Grid points per axis. np.full and Model(shape=...) need integer sizes, and
# the float quotient can land just below the intended integer, so round
# (not truncate) to the nearest int.
domain_dims = (
    round(domain_size[0] / spatial_dist),
    round(domain_size[1] / spatial_dist),
)
vp = np.full(domain_dims, v_env, dtype=np.float32)
# First y index of the "wall" region: the deeper of 80% of the y extent and
# 5 length units before the far y edge.
y_wall = max(int(domain_dims[1] * 0.8), round(domain_dims[1] - 5 / spatial_dist))
# NOTE(review): the wall speed equals v_env (1.5), so the wall is acoustically
# invisible and produces no reflection — confirm the intended wall velocity.
vp[:, y_wall:] = 1.5

In [None]:
# Velocity model on a uniform grid with a damping ("damp") absorbing boundary.
model = Model(
    vp=vp,
    origin=(0.0, 0.0),
    shape=domain_dims,
    spacing=(spatial_dist, spatial_dist),
    space_order=space_order,
    # Width of the absorbing layer in grid points; it must be an integer. The
    # float quotient can fall a hair short of the intended value (e.g.
    # 253.999...), so round explicitly instead of passing a float through.
    # NOTE(review): this divides a length (half the array aperture) by a
    # time step — confirm the intended formula.
    nbl=int(round((ns - 1) / 2 * source_distance / dt)),
    bcs="damp",
    dt=dt,
    dtype=np.float64,
)

In [None]:
# Simulated duration: twice (presumably out and back) the distance from the
# point (domain_size[0] / 2, 0) to the corner (0, domain_size[1]) travelled at
# v_env, plus a fixed margin of 5 time units.
tn = (
    2 * math.sqrt((domain_size[0] / 2) ** 2 + domain_size[1] ** 2) / v_env
) + 5
time_range = TimeAxis(step=dt, start=0, stop=tn)

In [None]:
# Linear array of ns elements along x, centred on x = domain_size[0] / 2, all
# at y = cy + source_distance.
# run_beam / run_beam_mpi delay element i by i * source_distance / v_env, so
# the element pitch must be exactly source_distance. The former
# linspace(0, ns * source_distance, num=ns) produced a pitch of
# ns / (ns - 1) * source_distance, so the beam would not point at alpha.
cy = (ns - 1) / 2 * source_distance  # half of the array aperture
# NOTE(review): confirm that placing the array at y = cy + source_distance
# (half the aperture plus one pitch) is intended.
coordinates = np.column_stack(
    (
        (domain_size[0] - (ns - 1) * source_distance) / 2
        + np.arange(ns) * source_distance,
        np.full(ns, cy + source_distance),
    )
)

# The same positions act as both emitters and receivers.
src = GaborSource(
    name="src",
    grid=model.grid,
    npoint=ns,
    f0=f0,
    time_range=time_range,
    coordinates=coordinates,
)
rec = Receiver(
    name="rec",
    grid=model.grid,
    time_range=time_range,
    npoint=ns,
    coordinates=coordinates,
)

In [None]:
# Wavefield: second order in time, `space_order` accuracy in space.
u = TimeFunction(name="u", grid=model.grid, time_order=2, space_order=space_order)
# Damped acoustic wave equation m * u_tt - laplace(u) + damp * u_t = 0
# (model.m is presumably the squared slowness 1 / vp**2 — confirm).
pde = model.m * u.dt2 - u.laplace + model.damp * u.dt
# Explicit update: solve the PDE for the next time level u.forward.
stencil = Eq(u.forward, solve(pde, u.forward))
# Add the source signal into u.forward, scaled by dt**2 / m.
src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m)
# Sample u at the receiver positions as the time loop runs.
rec_term = rec.interpolate(expr=u)
# Compile the update, injection and interpolation into one operator;
# spacing_map presumably substitutes the grid-spacing symbols with numbers.
op = Operator([stencil] + src_term + rec_term, subs=model.spacing_map)

## Run simulation

In [None]:
def run_beam(src, rec, op, u, time_range, dt, alpha, v_env, pitch=None):
    """Fire the array as a beam steered to ``alpha`` and return the recording.

    Every source column is set to the source wavelet delayed by a per-element
    amount that depends on ``alpha``. The wavefield is then reset and the
    operator is run over the time axis.

    Parameters
    ----------
    src : source with ``coordinates.data``, ``data`` (time x element) and
        ``wavelet`` (the undelayed 1-D signal).
    rec : receiver whose ``data`` the operator fills.
    op : compiled operator, called as ``op(time=..., dt=...)``.
    u : wavefield whose ``data`` is zeroed before the run.
    time_range : time axis; only ``num`` is used.
    dt : time step, used both for the delays and for the operator.
    alpha : steering angle in degrees.
    v_env : propagation speed used to compute the delays.
    pitch : spacing between neighbouring elements; defaults to the
        module-level ``source_distance``.

    Returns
    -------
    numpy.ndarray
        A copy of ``rec.data`` after the run.
    """
    if pitch is None:
        pitch = source_distance
    ns = src.coordinates.data.shape[0]
    cos_alpha = np.cos(np.deg2rad(alpha))
    # Constant offset (in samples) added to every delay.
    if alpha <= 90:
        max_latency = cos_alpha * ((ns - 1) * pitch / v_env) / dt
    else:
        # NOTE(review): for alpha > 90 this offset is negative, so element 0
        # gets a small negative shift (np.roll wraps its first samples to the
        # end). Kept as before — confirm intent.
        max_latency = cos_alpha * (pitch / v_env) / dt
    # Roll the pristine wavelet, not the current src.data. Rolling src.data in
    # place made successive calls accumulate the delays of every previous
    # angle.
    wavelet = np.asarray(src.wavelet)
    for i in range(ns):
        latency = -cos_alpha * (i * pitch / v_env)
        src.data[:, i] = np.roll(wavelet, int(latency / dt + max_latency))
    u.data.fill(0)
    op(time=time_range.num - 2, dt=dt)
    # rec.data is the operator's output buffer and is overwritten by the next
    # run. Return a copy so results from different angles do not alias.
    return np.array(rec.data, copy=True)

In [None]:
def run_beam_mpi(src, op, u, time_range, dt, alpha, v_env, receiver=None, pitch=None):
    """MPI variant of ``run_beam`` for sources whose data is split across ranks.

    Rank 0 builds the delayed wavelet for every array element (global count).
    It splits the result by the number of source columns each rank holds and
    scatters the pieces. Every rank copies its piece into ``src.data``, resets
    the wavefield and runs the operator.

    Parameters
    ----------
    src : source with ``data`` (time x local elements) and ``wavelet`` (the
        undelayed 1-D signal).
    op : compiled operator, called as ``op(time=..., dt=...)`` on every rank.
    u : wavefield whose ``data`` is zeroed before the run.
    time_range : time axis; only ``num`` is used.
    dt : time step, used both for the delays and for the operator.
    alpha : steering angle in degrees.
    v_env : propagation speed used to compute the delays.
    receiver : receiver whose data is returned; defaults to the module-level
        ``rec``.
    pitch : spacing between neighbouring elements; defaults to the
        module-level ``source_distance``.

    Returns
    -------
    numpy.ndarray
        A copy of this rank's ``receiver.data`` after the run.
    """
    if receiver is None:
        receiver = rec
    if pitch is None:
        pitch = source_distance

    # Number of source columns held by each rank, collected on rank 0.
    # NOTE(review): assumes each rank's src.data holds a contiguous block of
    # elements in rank order, as the previous gather/concatenate design did.
    counts = comm.gather(src.data.shape[1], root=0)

    if rank == 0:
        ns = sum(counts)  # global number of array elements
        cos_alpha = np.cos(np.deg2rad(alpha))
        if alpha <= 90:
            max_latency = cos_alpha * ((ns - 1) * pitch / v_env) / dt
        else:
            max_latency = cos_alpha * (pitch / v_env) / dt
        # Build the steered signals from the pristine wavelet, so repeated
        # calls do not accumulate delays. Previously the shifts were applied
        # to rank 0's local src.data and then discarded: the unshifted gathered
        # data was what got scattered.
        wavelet = np.asarray(src.wavelet)
        full_data = np.empty((wavelet.shape[0], ns), dtype=src.data.dtype)
        for i in range(ns):
            latency = -cos_alpha * (i * pitch / v_env)
            full_data[:, i] = np.roll(wavelet, int(latency / dt + max_latency))
        # Split by each rank's real column count (not an even split), so every
        # piece matches the shape of the receiving rank's src.data.
        divided_data = np.split(full_data, np.cumsum(counts)[:-1], axis=1)
    else:
        divided_data = None

    new_data = comm.scatter(divided_data, root=0)
    np.copyto(src.data, new_data)
    u.data.fill(0)
    op(time=time_range.num - 2, dt=dt)
    # Copy so that results from successive runs do not alias the same buffer.
    return np.array(receiver.data, copy=True)