# Parallel computation with Dask for InSAR applications

A practical introduction to parallel computation for DePSI developers.

<img src="figs/netherlands-escience-center-logo-RGB.svg" alt="nlesc-logo" width="50%" />

---

## Why Dask? 

Disclaimer:

- It may not be faster than numpy;
- It may not be more efficient than numpy;

Dask is a solution for **Scaling Up** your computation with **Less pain**.

- Operate on larger-than-memory data;
- Use HPC resources efficiently;
- Make only minimal changes to your numpy code;

---

## Dask in a nutshell

<img src="figs/dask-overview.svg" alt="dask-overview" width="80%" />

---

## Configure scheduler

The scheduler configuration tells Dask how to use the available computing resources.

See the [Dask documentation on scheduling](https://docs.dask.org/en/stable/scheduling.html) for more details.

### Single-machine scheduler

In [None]:
import dask

# Thread-pool scheduler; this is usually the option Dask uses by default.
dask.config.set({"scheduler": "threads"})

In [None]:
# Process-pool scheduler. Dask does not recommend it: use a local cluster instead.
dask.config.set({"scheduler": "processes"})

In [None]:
# Run everything in a single thread; very handy when debugging.
dask.config.set({"scheduler": "synchronous"})

### Distributed

In [None]:
# Start a cluster on the local machine and attach a client to it

from dask.distributed import Client, LocalCluster

# Cap the number of workers and threads per worker: limiting the concurrency
# helps to avoid memory issues.
cluster = LocalCluster(
    n_workers=4,
    threads_per_worker=2,
)
client = Client(cluster)

In [None]:
# Initiate a SLURM dask cluster
# We use this on HPC systems that are managed by SLURM.
# NOTE: SLURMCluster is provided by the separate `dask-jobqueue` package,
# not by `dask.distributed`.

from dask.distributed import Client
from dask_jobqueue import SLURMCluster

cluster = SLURMCluster(
    name="dask-worker",  # Name of the Dask worker
    queue="normal",  # Name of the node partition on your SLURM system
    cores=4,  # Total number of cores per SLURM job
    memory="32 GB",  # Total amount of memory per SLURM job
    processes=1,  # Number of Python processes each job is split into
    walltime="3:00:00",  # Walltime requested for each job (3 hours)
)

# Connect a client so that subsequent Dask computations run on this cluster.
# No SLURM job is submitted until workers are requested, e.g. with
# `cluster.scale(jobs=2)`.
client = Client(cluster)

---

## Example 1: apply_gufunc

[DePSI/slc.py](https://github.com/TUDelftGeodesy/DePSI/blob/main/depsi/slc.py) reconstructs the complex SLC values from the interferograms (ifgs) and the complex mother SLC, element by element, using `apply_gufunc`.

In [2]:
def ifg_to_slc(mother_slc, ifgs):
    """Rebuild the complex SLC values from interferograms and the mother SLC.

    The element-wise kernel ``_slc_complex_recontruct`` is applied lazily
    through ``da.apply_gufunc`` to the ``"complex"`` variables of both inputs.

    Parameters
    ----------
    mother_slc
        Object whose ``"complex"`` entry holds the mother SLC complex values.
    ifgs
        Interferogram dataset with a ``"complex"`` variable. It is copied, not
        modified.

    Returns
    -------
    A copy of ``ifgs`` in which ``"complex"`` is replaced by the reconstructed
    values, labelled with the dimensions ``("azimuth", "range", "time")``.
    """
    result = ifgs.copy()
    reconstructed = da.apply_gufunc(
        _slc_complex_recontruct,
        "(),()->()",  # scalar in, scalar out: a purely element-wise operation
        mother_slc["complex"],
        result["complex"],
        # Zero-length complex64 array passed as the output ``meta``.
        meta=np.array((), dtype=np.complex64),
    )
    dims = ("azimuth", "range", "time")
    return result.assign({"complex": (dims, reconstructed)})


def _slc_complex_recontruct(mother_slc_complex, ifg_complex):
    return ifg_complex / mother_slc_complex

## Example 2: map_blocks

In [1]:
def enrich_from_polygon(self, polygon, fields, xlabel="lon", ylabel="lat"):
    """Attach attribute fields of a (multi-)polygon to the SpaceTimeMatrix (STM).

    Every name in ``fields`` becomes a data variable of the STM along the
    ``space`` dimension. Points that fall inside a polygon receive that
    polygon's value for the field; all other points keep the value None.

    Parameters
    ----------
    polygon : geopandas.GeoDataFrame, str, or pathlib.Path
        Polygon or multi-polygon carrying the contextual information, either
        already loaded or given as a path to a file.
    fields : str or list of str
        Name(s) of the polygon field(s) to copy onto the STM.
    xlabel : str, optional
        Name of the x-coordinates of the STM, by default "lon"
    ylabel : str, optional
        Name of the y-coordinates of the STM, by default "lat"

    Returns
    -------
    xarray.Dataset
        The STM with one extra data variable per requested field.

    Raises
    ------
    ValueError
        If ``fields`` is neither a string nor an iterable, or if a requested
        field is not a column of the polygon.
    NotImplementedError
        If ``polygon`` is not a GeoDataFrame, a string or a Path.

    """
    stm = self._obj
    _validate_coords(stm, xlabel, ylabel)

    # A single field name is wrapped in a list; other non-iterables are rejected.
    if isinstance(fields, str):
        fields = [fields]
    elif not isinstance(fields, Iterable):
        raise ValueError("fields need to be a Iterable or a string")

    # Record how the polygon was supplied and fetch only its first row, which
    # is enough to inspect the available columns.
    if isinstance(polygon, gpd.GeoDataFrame):
        type_polygon, first_row = "GeoDataFrame", polygon.iloc[0:1]
    elif isinstance(polygon, Path | str):
        type_polygon, first_row = "File", gpd.read_file(polygon, rows=1)
    else:
        raise NotImplementedError("Cannot recognize the input polygon.")

    # Every requested field must be a column of the polygon.
    available_columns = first_row.columns
    for field in fields:
        if field not in available_columns:
            raise ValueError(f'Field "{field}" not found in the the input polygon')

    # Add one None-filled placeholder variable per field, chunked with the
    # size of the first "space" chunk of the STM.
    space_chunks = (stm.chunksizes["space"][0],)
    enriched = stm
    for field in fields:
        placeholder = da.from_array(
            np.full(enriched.space.shape, None), chunks=space_chunks
        )
        enriched = enriched.assign({field: (["space"], placeholder)})

    # Fill in the placeholders block by block.
    return xr.map_blocks(
        _enrich_from_polygon_block,
        enriched,
        args=(polygon, fields, xlabel, ylabel, type_polygon),
        template=enriched,
    )

In [None]:
def _enrich_from_polygon_block(ds, polygon, fields, xlabel, ylabel, type_polygon):
    """Fill the polygon fields for one block of the STM.

    Block-wise worker for "enrich_from_polygon". The input block is not
    modified; a deep copy with the field values filled in is returned.
    """
    # Query which points of this block match which polygon rows.
    matches, polygon = _ml_str_query(ds[xlabel], ds[ylabel], polygon, type_polygon)

    enriched = ds.copy(deep=True)

    # Values are only filled in when the match table is two-dimensional.
    if matches.ndim != 2:
        return enriched

    # Column 0 is used as a positional row index into the polygon,
    # column 1 as an index into the field data of this block.
    polygon_ids = matches[:, 0]
    point_ids = matches[:, 1]
    for polygon_id in np.unique(polygon_ids):
        hits = point_ids[polygon_ids == polygon_id]
        for field in fields:
            enriched[field].data[hits] = polygon.iloc[polygon_id][field]

    return enriched

## Example 3: groupby + map + map_blocks

In [7]:
def full_batch_one_group(stm_pnt, slc_quality_ref, h2ph_ref):
    """Resolve the phase ambiguities of one arc and estimate its parameters.

    This function is applied to every ``space`` group by ``full_batch_chunk``.
    It takes the wrapped phase of ``dd_complex``, resolves integer ambiguities
    with ``LAMBDA.main``, and fits four parameters (the columns of ``B1``) to
    the unwrapped phase.

    Besides its arguments, the function reads the module-level names
    ``sigma_offset``, ``sigma_vel``, ``sigma_h``, ``sigma_ther``, ``years``,
    ``wavelength``, ``temp`` and ``LAMBDA``; none of them is defined here.

    Parameters
    ----------
    stm_pnt
        Dataset providing the variables ``slc_quality``, ``h2ph_values`` and
        ``dd_complex`` and the coordinates ``time`` and ``space``.
        NOTE(review): the length-1 arrays assigned along ``space`` below imply
        that the group holds exactly one point -- confirm.
    slc_quality_ref
        Quality of the reference point, combined with ``slc_quality``.
    h2ph_ref
        Height-to-phase values of the reference point, averaged with
        ``h2ph_values``.

    Returns
    -------
    A copy of ``stm_pnt`` with the added variables ``dis_wrapped``,
    ``dis_unw``, ``dis_est``, ``dis_dis``, ``dis_height``, ``dis_ther``,
    ``vel``, ``height`` and ``ther``.
    """
    # stm_pnt = stm_pnt.compute()
    slc_quality_pnts = stm_pnt['slc_quality']
    h2ph_pnts = stm_pnt['h2ph_values']
    dd_arc = stm_pnt['dd_complex']
    # Diagonal of the VCM of the dd phases: root of the summed squared
    # qualities of the reference point and of this point
    Qyy_diagonal = np.sqrt((slc_quality_ref)**2 + (slc_quality_pnts)**2)


    # Variance-covariance matrix of the DD (based on the NMAD for the arc):
    # the squared values above are placed on the diagonal of a
    # (n_time x n_time) matrix
    Qyy = np.identity(len(stm_pnt.time))*Qyy_diagonal.to_numpy()**2

    # Compute 'mean' h2ph value for the arc (which we currently model as the average of the two time series)

    h2ph_arc = (h2ph_ref + h2ph_pnts)/2
    # h2ph_arc = h2ph_arc.to_numpy()
    h2ph_arc = h2ph_arc.squeeze().values

    # Get the wrapped phase (argument of the complex double difference)
    phs_wrapped = np.angle(dd_arc)
    phs_wrapped = phs_wrapped.squeeze()

    # Define y and the corresponding VCM Qyy
    # y = np.append(phs_wrapped, [0, 0, 0, 0])
    Q_phs = Qyy
    # Diagonal matrix built from the four module-level sigmas (squared)
    Q_b = np.diag([sigma_offset**2, sigma_vel**2, sigma_h**2, sigma_ther**2])
    # Qyy = np.block([[Q_phs, np.zeros((Q_phs.shape[0], Q_b.shape[1]))],
    #                 [np.zeros((Q_b.shape[0], Q_phs.shape[1])), Q_b]])

    # Define the design matrices
    # A1 maps the integer ambiguities to phase: -2*pi per cycle
    A1 = np.diag([-2*np.pi] * len(phs_wrapped))
    # B2 = np.diag([1] * 4)
    # C = np.block([[A1, np.zeros((A1.shape[0], B2.shape[1]))],
    #                 [np.zeros((B2.shape[0], A1.shape[1])), B2]])

    # B1 has one column per parameter: a column of ones, then columns scaled
    # from `years`, `h2ph_arc` and `temp`
    # NOTE(review): the /1000 factors suggest parameters in mm with
    # `wavelength` in m -- confirm units
    B1 = np.ones((phs_wrapped.shape[0],4))
    B1[:,1] = years*(-4*np.pi/wavelength/1000)
    B1[:,2] = h2ph_arc*(-4*np.pi/wavelength)
    B1[:,3] = temp*(-4*np.pi/wavelength/1000)
    # C[:len(phs_wrapped), -4:] = B1

    # Float solution: wrapped phase expressed in cycles, with its covariance
    # built from Q_phs and the propagated Q_b
    ahat = phs_wrapped/(-2*np.pi)
    Qahat = 1/(4*((np.pi)**2)) * (Q_phs + B1@Q_b@B1.T)


    # Lambda method - Integer bootstrapping
    # Only `afixed` is used below; the other return values are ignored
    afixed,sqnorm,Ps,Qzhat,Z,nfixed,mu = LAMBDA.main(ahat,Qahat,3)

    # Calculate the unwrapped phase [rad]
    phs_unw  = phs_wrapped - A1@afixed

    # Get the estimated parameters: weighted least squares of the unwrapped
    # phase on B1, with inv(Q_phs) as weight matrix
    b_hat = np.linalg.inv((B1.T@np.linalg.inv(Q_phs)@B1))@B1.T@np.linalg.inv(Q_phs)@phs_unw

    # Get the phase for estimated DD observation, non-thermal displacement, height difference and thermal expansion [rad]
    phs_est = B1@b_hat
    phs_dis =  B1[:,:2]@b_hat[:2]
    phs_height = B1[:,2]*b_hat[2]
    phs_ther = B1[:,3]*b_hat[3]

    # Get the phase for estimated DD observation, non-thermal displacement, height difference and thermal expansion [mm]
    dis_est = phs_est*(-wavelength/(4*np.pi)*1000)
    dis_dis = phs_dis*(-wavelength/(4*np.pi)*1000)
    dis_height = phs_height*(-wavelength/(4*np.pi)*1000)
    dis_ther = phs_ther*(-wavelength/(4*np.pi)*1000)

    # Get the wrapped and unwrapped phase [mm]
    dis_wrapped = phs_wrapped*(-wavelength/(4*np.pi)*1000)
    dis_unw = phs_unw*(-wavelength/(4*np.pi)*1000)

    # Store each time series on a copy of the input; a leading axis of
    # length 1 is added to the 1-D arrays before they are assigned
    ds_out = stm_pnt.copy()
    ds_out = ds_out.assign(dis_wrapped=((ds_out.dims), np.expand_dims(dis_wrapped, axis=0)))
    ds_out = ds_out.assign(dis_unw=((ds_out.dims), np.expand_dims(dis_unw, axis=0)))
    ds_out = ds_out.assign(dis_est=((ds_out.dims), np.expand_dims(dis_est, axis=0)))
    ds_out = ds_out.assign(dis_dis=((ds_out.dims), np.expand_dims(dis_dis, axis=0)))
    ds_out = ds_out.assign(dis_height=((ds_out.dims), np.expand_dims(dis_height, axis=0)))
    ds_out = ds_out.assign(dis_ther=((ds_out.dims), np.expand_dims(dis_ther, axis=0)))

    # Store the estimated parameters b_hat[1:4] as per-point values along 'space'
    ds_out['vel'] = xr.DataArray(np.array([b_hat[1]]), dims=('space'), coords={'space': ds_out.space.values})
    ds_out['height'] = xr.DataArray(np.array([b_hat[2]]), dims=('space'), coords={'space': ds_out.space.values})
    ds_out['ther'] = xr.DataArray(np.array([b_hat[3]]), dims=('space'), coords={'space': ds_out.space.values})

    return ds_out

In [8]:
def full_batch_chunk(ds, sd_complex_ref, slc_quality_ref, h2ph_ref):
    """Form the double differences of one chunk and process each of its points.

    The ``sd_complex`` series of every point in ``ds`` is multiplied by the
    complex conjugate of the reference series, and the product is stored as
    ``dd_complex``. The chunk is then grouped by ``space`` and
    ``full_batch_one_group`` is applied to each group.

    Parameters
    ----------
    ds
        Dataset chunk with an ``sd_complex`` variable and a ``space``
        coordinate. It is not modified.
    sd_complex_ref
        Complex series of the reference point.
    slc_quality_ref
        Quality of the reference point, forwarded to ``full_batch_one_group``.
    h2ph_ref
        Height-to-phase values of the reference point, forwarded to
        ``full_batch_one_group``.

    Returns
    -------
    The result of mapping ``full_batch_one_group`` over the ``space`` groups.
    """
    # Double difference: point series times the conjugated reference series.
    arc_dd = ds['sd_complex'] * sd_complex_ref.conj()
    with_dd = ds.assign(dd_complex=arc_dd)

    # Reference quantities that every per-point call needs.
    reference_kwargs = {
        "slc_quality_ref": slc_quality_ref,
        "h2ph_ref": h2ph_ref,
    }
    return with_dd.groupby("space").map(full_batch_one_group, **reference_kwargs)

In [9]:
# Apply the chunk-level routine to every block of the input STM. The
# reference-point quantities are forwarded unchanged to each block.
stm_new2 = xr.map_blocks(
    full_batch_chunk,
    stm_new_in,
    kwargs={
        "sd_complex_ref": sd_complex_ref_com,
        "slc_quality_ref": slc_quality_ref_com,
        "h2ph_ref": h2ph_ref_com,
    },
    template=stm_new_in,
)
stm_new2