# Run batches of FastScape simulations

This is still a work in progress, but the idea is to leverage the ``dask`` library (https://docs.dask.org/en/latest/) and its integration with ``xarray`` to easily run, analyse and visualize batches of model runs, e.g., for sensitivity analyses or inversions.

Let's import some packages first:

In [None]:
import numpy as np
import xarray as xr
import xsimlab as xs
import fastscape

In [None]:
print('xarray-simlab version: ', xs.__version__)
print('fastscape version: ', fastscape.__version__)

You won't need to run the cell below once support for running batches of simulations is added to ``xarray-simlab``.

In [None]:
import patch_xsimlab

## Setting up a dask cluster

The dask cluster takes care of distributing tasks (such as running models, post-processing, or even visualization) to workers that execute them in parallel.

Below we set up a local cluster of 15 workers and connect to it.

In [None]:
from distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=15)

client = Client(cluster)

client
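The worker count of 15 used above suits the machine this notebook was written on. As a minimal sketch for other machines, the count could be capped by the number of CPU cores and by the number of runs in the batch. Note that `pick_n_workers` and its `reserve` parameter are hypothetical names introduced here for illustration; they are not part of dask.

```python
import os


def pick_n_workers(n_runs, reserve=1):
    """Hypothetical helper: suggest a worker count for a local cluster.

    Uses at most one worker per run, leaves `reserve` cores free,
    and always returns at least 1.
    """
    n_cores = os.cpu_count() or 1  # os.cpu_count() may return None
    return max(1, min(n_runs, n_cores - reserve))


n_workers = pick_n_workers(n_runs=20)
print(n_workers)
```

The result could then be passed as ``LocalCluster(n_workers=n_workers)``.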

## Import the model and the model base setup

Let's import the basic (standard) FastScape model:

In [None]:
from fastscape.models import basic_model

The setup created in ``run_basic_model`` is reused here as the base setup:

In [None]:
in_ds = xr.load_dataset('basic_input.nc')

## Example 1: run models for different values of $K$ (stream power law)

We just need to set the corresponding variable with different values along a new ``batch`` dimension, and then call ``.xsimlab.run_model_batch``:

In [None]:
in_vars = {'spl__k_coef': ('batch', np.linspace(1e-5, 1e-4, 20))}
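The values above are evenly spaced between the two end points. If the parameter range spanned a larger ratio, a logarithmic spacing might sample it more evenly. The sketch below is an assumption about study design rather than something this notebook does; `k_log` and `in_vars_log` are names introduced here for illustration.

```python
import numpy as np

n_runs = 20

# Linear spacing, as in the cell above
k_lin = np.linspace(1e-5, 1e-4, n_runs)

# Logarithmic spacing between the same end points:
# consecutive values have a constant ratio instead of a constant difference
k_log = np.logspace(-5, -4, n_runs)

# Same structure as `in_vars`: (dimension name, values)
in_vars_log = {'spl__k_coef': ('batch', k_log)}
```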

In [None]:
with basic_model:
    out_ds = (
        in_ds.xsimlab.update_vars(input_vars=in_vars)
             .xsimlab.run_model_batch('batch')
             .rename(batch='spl_k')
             .set_index(x='grid__x', y='grid__y', spl_k='spl__k_coef')
    )
    
out_ds

Plotting using ``hvplot`` is just as easy as with single model runs. We can easily explore the parameter space.

In [None]:
import hvplot.xarray
import matplotlib.pyplot as plt

out_ds.topography__elevation.hvplot.image(
    x='x', y='y', cmap=plt.cm.viridis, groupby=['spl_k', 'out'])

## Example 2: run models with different (random) initial conditions

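The trick below could be nicer (e.g., by explicitly setting random seeds): we repeat the same value of $K$ for every run in the batch, so that the runs differ only in their random initial conditions.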

In [None]:
# 20 runs sharing the same K value (1e-5)
in_vars = {'spl__k_coef': ('batch', np.full(20, 1e-5))}
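One way to set the random seeds explicitly, as suggested above, is to derive one seed per run from a single root seed so that the whole batch is reproducible. This is only a sketch: the variable name `init_topography__seed` is an assumption about the model's inputs and should be checked against the model before use, and the root value `42` is arbitrary. Whether the patched batch runner accepts two batch variables at once is also untested here.

```python
import numpy as np

n_runs = 20

# Derive one independent 32-bit integer seed per run from a single root seed
root = np.random.SeedSequence(42)
seeds = [int(child.generate_state(1)[0]) for child in root.spawn(n_runs)]

in_vars_seeded = {
    # same K value for every run
    'spl__k_coef': ('batch', np.full(n_runs, 1e-5)),
    # ASSUMPTION: the model exposes its random seed under this name
    'init_topography__seed': ('batch', seeds),
}
```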

In [None]:
with basic_model:
    out_ds2 = (
        in_ds.xsimlab.update_vars(input_vars=in_vars)
             .xsimlab.run_model_batch('batch')
             .set_index(x='grid__x', y='grid__y')
    )
    
out_ds2

Extracting statistics along the batch dimension is very easy (and it's executed in parallel):

In [None]:
avg = out_ds2.topography__elevation.mean(dim='batch')

avg.compute()
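To make the reduction concrete, here is a plain NumPy sketch of what a mean (or any other statistic) along the batch dimension does, using a small synthetic array in place of the model output. The array shape and the names `elevation`, `avg_np` and `spread_np` are made up for illustration; with ``xarray`` the axis is selected by name (``dim='batch'``) rather than by position.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for elevation with dimensions (batch, y, x)
n_batch, ny, nx = 20, 4, 5
elevation = rng.random((n_batch, ny, nx))

# Reduce along the batch axis (axis 0): one value per grid node
avg_np = elevation.mean(axis=0)
spread_np = elevation.std(axis=0)

print(avg_np.shape, spread_np.shape)
```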

Let's compare one cross section of the result above with a swath profile extracted from one simulation:

In [None]:
avg.isel(out=-1).sel(x=10000).plot();

In [None]:
out_ds2.topography__elevation.isel(batch=0, out=-1).mean(dim='x').plot();
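The two plots above extract their profiles differently: the cross section takes the values along a single grid column (one value of `x`), whereas the swath profile averages over all columns. A small NumPy sketch on a synthetic grid (a ramp in `y` plus noise, invented here for illustration) shows the difference:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic elevation grid with dimensions (y, x): a ramp in y plus noise in [0, 1)
ny, nx = 6, 8
ramp = np.linspace(0.0, 100.0, ny)[:, None]
topo = ramp + rng.random((ny, nx))

# Cross section: a single column, comparable to selecting one x value
cross_section = topo[:, nx // 2]

# Swath profile: average over all columns, comparable to .mean(dim='x')
swath = topo.mean(axis=1)
```

Both profiles have one value per `y` node; the swath profile is smoother because the noise is averaged out along `x`.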