# Xarray-simlab showcase

In [None]:
import numpy as np
import xsimlab as xs

Let's import an xarray-simlab model from the fastscape library (https://github.com/fastscape-lem/fastscape).

In [None]:
from fastscape.models import sediment_model

## Model inspection

In [None]:
# Display the imported model object (its notebook representation).
sediment_model

In [None]:
# Render a graphical view of the model; presumably its process graph
# (NOTE(review): this likely needs graphviz installed — confirm).
sediment_model.visualize()

## Model customization (replacing and dropping processes)

In [None]:
from fastscape.processes import Escarpment

# Derive a new model from `sediment_model` in two steps. First, use
# `Escarpment` as the 'init_topography' process. Then remove the
# 'diffusion' and 'uplift' processes. Each call returns a new model
# object, so `model` is simply rebound.
model = sediment_model.update_processes({'init_topography': Escarpment})
model = model.drop_processes(['diffusion', 'uplift'])

In [None]:
# Display the customized model.
model

## Simulation setup

In [None]:
# Load the xarray-simlab IPython extension, which provides the
# `%create_setup` magic referenced in the setup cells.
%load_ext xsimlab.ipython

In [None]:
# %create_setup model -v -d
# NOTE: this appears to be the unfilled template generated by the
# `%create_setup` magic named on the line above. Entries such as
# `'grid__shape': ,` are deliberately left blank for the user to complete,
# so this cell raises a SyntaxError if executed as-is. A version with
# every value filled in (plus clocks and output variables) follows.
import xsimlab as xs

ds_in = xs.create_setup(
    model=model,
    clocks={},
    input_vars={
        # nb. of grid nodes in (y, x)
        'grid__shape': ,
        # total grid length in (y, x)
        'grid__length': ,
        # node status at borders
        'boundary__status': 'fixed_value',
        # location of the scarp's left limit on the x-axis
        'init_topography__x_left': ,
        # location of the scarp's right limit on the x-axis
        'init_topography__x_right': ,
        # elevation on the left side of the scarp
        'init_topography__elevation_left': ,
        # elevation on the right side of the scarp
        'init_topography__elevation_right': ,
        # MFD partitioner slope exponent
        'flow__slope_exp': 0.0,
        # drainage area exponent
        'spl__area_exp': 0.4,
        # slope exponent
        'spl__slope_exp': 1,
        # bedrock channel incision coefficient
        'spl__k_coef_bedrock': ,
        # soil (sediment) channel incision coefficient
        'spl__k_coef_soil': ,
        # detached bedrock transport/deposition coefficient
        'spl__g_coef_bedrock': ,
        # soil (sediment) transport/deposition coefficient
        'spl__g_coef_soil': ,
    },
    output_vars={}
)


In [None]:
# %create_setup model -v -d
# Complete simulation setup for `model`: the 'time' clock drives the run,
# the 'out' clock selects when output variables are recorded, and every
# model input receives a concrete value.
import xsimlab as xs

ds_in = xs.create_setup(
    model=model,
    clocks={
        'time': np.linspace(0, 4e5, 201),  # main clock: 201 times from 0 to 4e5
        'out': np.linspace(0, 4e5, 101),  # output clock: 101 times, same span
    },
    master_clock='time',
    input_vars={
        'grid__shape': [101, 201],  # nb. of grid nodes in (y, x)
        'grid__length': [1e4, 2e4],  # total grid length in (y, x)
        # node status at borders (NOTE(review): presumably ordered
        # left, right, top, bottom — confirm against fastscape docs)
        'boundary__status': ['fixed_value', 'core', 'looped', 'looped'],
        'init_topography__x_left': 1e4,  # scarp's left limit on the x-axis
        'init_topography__x_right': 1e4,  # scarp's right limit on the x-axis
        'init_topography__elevation_left': 0.,  # elevation left of the scarp
        'init_topography__elevation_right': 1e3,  # elevation right of the scarp
        'flow__slope_exp': 1.0,  # MFD partitioner slope exponent
        'spl__area_exp': 0.4,  # drainage area exponent
        'spl__slope_exp': 1,  # slope exponent
        'spl__k_coef_bedrock': 1e-4,  # bedrock channel incision coefficient
        'spl__k_coef_soil': 1e-4,  # soil (sediment) channel incision coefficient
        'spl__g_coef_bedrock': 0.5,  # detached bedrock transport/deposition coefficient
        'spl__g_coef_soil': 0.5,  # soil (sediment) transport/deposition coefficient
    },
    output_vars={
        # each variable below is recorded at every 'out' clock time
        'topography__elevation': 'out',
        'erosion__rate': 'out',
        'drainage__area': 'out',
    },
)


In [None]:
# Inspect the input dataset built by `xs.create_setup`.
ds_in

## Model run

In [None]:
# Run the simulation with a progress bar. The model is supplied through
# its context manager rather than the `model=` argument, which is the same
# pattern used for the batch and climate-change runs.
with model, xs.monitoring.ProgressBar():
    ds_out = ds_in.xsimlab.run()

In [None]:
# Inspect the simulation output dataset.
ds_out

## Model output visualization

In [None]:
# Facet grid of the topography at output indices 0, 5, ..., 40 (9 panels,
# 3 per row). The trailing semicolon suppresses the textual repr of the
# returned plot object.
(
    ds_out['topography__elevation']
    .isel(out=range(0, 45, 5))
    .plot.pcolormesh(
        x='x',
        y='y',
        col='out',
        col_wrap=3,
        aspect=2,
        cmap='cividis',
    )
);

In [None]:
import hvplot.xarray

In [None]:
# Interactive topography image, grouped by the 'out' dimension so a widget
# selects the output step.
ds_out.topography__elevation.hvplot.image(
    x='x',
    y='y',
    groupby='out',
    cmap='cividis',
    data_aspect=1,
)

In [None]:
from ipyfastscape import TopoViz3d

In [None]:
# Create the ipyfastscape TopoViz3d viewer on the simulation output,
# using 'out' as its time dimension, then configure it before showing it.
app = TopoViz3d(ds_out, time_dim='out')
# Color by drainage area using the 'Blues' colormap on a log scale.
app.components['coloring'].set_color_var('drainage__area')
app.components['coloring'].set_colormap('Blues')
app.components['coloring'].set_color_scale(log=True)
# Vertical exaggeration factor of 3.
app.components['vertical_exaggeration'].set_factor(3)
app.show()

In [None]:
# Close the viewer's widget once done with it.
app.widget.close()

## Sensitivity analysis

In [None]:
from dask.diagnostics import ProgressBar

In [None]:
# Batch run over five values of 'flow__slope_exp' (one simulation per
# 'batch' entry). Runs execute in parallel with the 'processes' scheduler,
# results are stored in 'sensitivity.zarr', and 'batch' is then swapped
# for the slope-exponent values so results are indexed by exponent.
with model, ProgressBar():
    ds_sensitivity = (
        ds_in
        .xsimlab.update_vars(
            input_vars={
                'flow__slope_exp': ('batch', [0, 0.5, 1, 3, 5]),
            }
        )
        .xsimlab.run(
            batch_dim='batch',
            parallel=True,
            store='sensitivity.zarr',
            scheduler='processes',
        )
        .swap_dims({'batch': 'flow__slope_exp'})
    )

In [None]:
# Inspect the batch results, indexed by 'flow__slope_exp'.
ds_sensitivity

In [None]:
# Natural log of drainage area (pipe applies np.log to the DataArray),
# browsable by slope exponent and output step.
log_area = ds_sensitivity.drainage__area.pipe(np.log)

log_area.hvplot.image(
    x='x',
    y='y',
    groupby=['flow__slope_exp', 'out'],
    cmap='Blues',
    data_aspect=1,
)

## Model customization (adding a new process)

In [None]:
from fastscape.processes import DifferentialStreamPowerChannelTD


@xs.process
class ClimateChange:
    """Sudden change in erosion efficiency.

    Drives both channel incision coefficients (soil and bedrock) of
    ``DifferentialStreamPowerChannelTD``. They start at 3e-5 and are
    switched once to 1e-4 at the first step whose start time is greater
    than or equal to ``time``.
    """

    time = xs.variable(description='time when the change occurs')

    # Written by this process, so the incision process's coefficients are
    # computed here rather than given as model inputs.
    k_soil = xs.foreign(DifferentialStreamPowerChannelTD, 'k_coef_soil', intent='out')
    k_bedrock = xs.foreign(DifferentialStreamPowerChannelTD, 'k_coef_bedrock', intent='out')

    def initialize(self):
        # Begin in the pre-change regime; the flag ensures a single switch.
        self.changed = False
        self.k_soil = 3e-5
        self.k_bedrock = 3e-5

    @xs.runtime(args='step_start')
    def run_step(self, t):
        # Once the switch has happened there is nothing left to do.
        if self.changed:
            return
        if t >= self.time:
            self.k_soil = 1e-4
            self.k_bedrock = 1e-4
            self.changed = True



In [None]:
# Extend the customized model with the ClimateChange process, registered
# under the name 'climate_change' (so its input is 'climate_change__time').
new_model = model.update_processes({'climate_change': ClimateChange})

In [None]:
# Re-run from the same inputs with the extended model, triggering the
# climate change at time 2e5 (midway through the 0 to 4e5 'time' clock).
with new_model, xs.monitoring.ProgressBar():
    ds_clim_change = ds_in.xsimlab.update_vars(
        input_vars={'climate_change__time': 2e5}
    ).xsimlab.run()

In [None]:
# Erosion rate per output step, browsed with a scrubber widget placed at
# the bottom. The color limits are fixed and symmetric, so all frames share
# one scale.
ds_clim_change.erosion__rate.hvplot.image(
    x='x',
    y='y',
    groupby='out',
    cmap='RdYlGn',
    data_aspect=1,
    widget_type='scrubber',
    widget_location='bottom',
    clim=(-5e-3, 5e-3),
)