<table align="left">
    <tr>
        <td style="vertical-align: middle; padding-left: 0px; padding-right: 0px;">
            <a href="https://creativecommons.org/licenses/by/4.0/">
                <img src="https://licensebuttons.net/l/by/4.0/80x15.png" />
            </a>
        </td>
        <td style="vertical-align: middle; padding-left: 5px; padding-right: 0px;">
            <a href="https://opensource.org/licenses/MIT">
                <img src="https://img.shields.io/badge/License-MIT-green.svg" />
            </a>
        </td>
        <td style="vertical-align: middle; padding-left: 15px;">
            &copy; Guillaume Rongier
        </td>
    </tr>
</table>

# Basic StratigraPy example

This notebook shows a basic example of simulating some deltaic deposits following the principles defined by [Granjeon (1996)](https://theses.hal.science/tel-00648827v1) using StratigraPy.

### Imports

Let's first import all the required packages and components:

In [None]:
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import cmocean
import pyvista as pv

from landlab.components import FlowDirectorMFD, FlowAccumulator

from stratigrapy import RasterModelGrid
from stratigrapy.components import SeaLevelCalculator, WaterDrivenRouter
from stratigrapy.plot import extract_tie_centered_layers

## 1. Setup and simulation

StratigraPy is built on [Landlab](https://landlab.csdms.io/#) and follows the same principles: we need to set up a [grid](https://landlab.csdms.io/user_guide/grid.html) whose fields are iteratively modified by [components](https://landlab.csdms.io/user_guide/components.html). A key element that StratigraPy adds to the grid is a StackedLayers, which records the stratigraphy and its changes. All the components of StratigraPy rely on this StackedLayers, and components from Landlab might not be compatible with StratigraPy because they can't access it.

Let's start by defining our grid. We'll run sediment erosion, transport, and deposition over 500$\,$000 years with a time step of 100 years:

In [None]:
timestep = 100.
runtime = 500000.
n_iterations = int(runtime/timestep)

This will lead to 5$\,$000 iterations. We use this information to pre-allocate the StackedLayers when creating a raster grid using the parameter `initial_allocation`:

In [None]:
grid = RasterModelGrid((25, 30),
                       xy_spacing=(2500., 2500.),
                       number_of_classes=2,
                       initial_allocation=n_iterations//100 + 100,
                       number_of_layers_to_fuse=100,
                       number_of_top_layers=100)

By default, some components from StratigraPy add a new layer to the stack at each iteration, which quickly becomes computationally expensive. Instead, we'll fuse iterations together to only have 50 layers in our stack using the parameters `number_of_layers_to_fuse` and `number_of_top_layers` (for more details about this process, see [the second notebook](./2_managing-stratigraphy)). We will simulate two classes of sediments, which we need to declare up front in the StackedLayers using `number_of_classes`.
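
The layer counts quoted above follow from simple arithmetic on the values used in this notebook. The sketch below only reproduces that arithmetic in plain Python, assuming one fused layer per group of `number_of_layers_to_fuse` iterations; it does not call StratigraPy, and the names `n_fused_layers` and `allocation` are illustrative.

```python
# Values taken from the cells above
timestep = 100.0        # yr
runtime = 500000.0      # yr
number_of_layers_to_fuse = 100

n_iterations = int(runtime / timestep)
# Assumption: one fused layer per group of `number_of_layers_to_fuse` iterations
n_fused_layers = n_iterations // number_of_layers_to_fuse
# Same expression as the `initial_allocation` argument passed to the grid
allocation = n_iterations // 100 + 100

print(n_iterations, n_fused_layers, allocation)  # 5000 50 150
```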

Now that we have a grid, we can define the boundary conditions, leaving only the bottom boundary open:

In [None]:
grid.set_closed_boundaries_at_grid_edges(True, True, True, False)

Adding fields to the grid is exactly the same as [in Landlab](https://landlab.csdms.io/user_guide/grid.html#adding-data-to-a-landlab-grid-element-using-fields). We first add the initial topography as a sloped surface:

In [None]:
elevation = grid.add_zeros('topographic__elevation', at='node', clobber=True)
elevation += 0.003*(grid.y_of_node - 50000.)

And visualize it:

In [None]:
grid.imshow('topographic__elevation', var_name='Elevation', var_units='m', grid_units=['m', 'm'])
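
To get a feel for the numbers behind this surface, the plain-Python sketch below evaluates the same ramp along one column of nodes. It assumes that node rows sit at `y = row * 2500` m starting from 0, which is Landlab's default raster layout; `n_rows`, `dy`, and `ramp` are names introduced here for illustration.

```python
n_rows, dy = 25, 2500.0   # grid shape and spacing used above

def ramp(y):
    """Initial elevation (m) at distance y (m), as in the cell above."""
    return 0.003 * (y - 50000.0)

y_rows = [row * dy for row in range(n_rows)]
elevation_rows = [ramp(y) for y in y_rows]

# Lowest and highest nodes, and the row where the ramp crosses 0 m
print(min(elevation_rows), max(elevation_rows))
zero_row = min(range(n_rows), key=lambda row: abs(elevation_rows[row]))
print(zero_row, y_rows[zero_row])
```

The surface spans roughly 150 m below to 30 m above the 0 m level, which it crosses at y = 50 000 m (row 20).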

We can add a source of water at the top of the grid, spanning two cells:

In [None]:
yx_source = [(23, 23), (14, 15)]
idx = np.ravel_multi_index(yx_source, grid.shape)
water_influx = grid.add_zeros('water__unit_flux_in', at='node', clobber=True)
water_influx[idx] = 5000. # m/yr

This water influx field will be multiplied by the drainage area to get the discharge, so it is in m/yr instead of m$^3$/yr. Let's visualize it:

In [None]:
grid.imshow('water__unit_flux_in', var_name='Water influx', var_units='m/yr', grid_units=['m', 'm'], cmap='Blues')
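
`np.ravel_multi_index` converts (row, column) pairs into the flat node indices that grid fields use. For a row-major grid, this amounts to `row * n_columns + column`, which the following plain-Python sketch reproduces for the two source nodes (the helper `flat_index` is hypothetical, introduced here for illustration only):

```python
def flat_index(row, col, shape):
    """Row-major flat index of (row, col) in a 2D grid of the given shape."""
    n_rows, n_cols = shape
    if not (0 <= row < n_rows and 0 <= col < n_cols):
        raise ValueError("index out of bounds")
    return row * n_cols + col

shape = (25, 30)                 # same shape as the grid above
rows, cols = (23, 23), (14, 15)  # same layout as `yx_source`
source_nodes = [flat_index(r, c, shape) for r, c in zip(rows, cols)]
print(source_nodes)  # [704, 705]
```

Both nodes lie on row 23, the last row of nodes below the closed top boundary (row 24).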

Our source of water is also a source of sediments, with 70% of the first class and 30% of the second:

In [None]:
sediment_influx = grid.add_field('sediment__unit_flux_in',
                                 np.zeros((grid.number_of_nodes, 2)),
                                 clobber=True)
sediment_influx[idx] = [0.7*50000., 0.3*50000.] # m3/yr

Contrary to the water, we directly have a flux of sediment in m$^3$/yr.
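
As a rough order-of-magnitude check, the sketch below converts the water influx into a volumetric discharge at each source node and compares it with the sediment supply. It assumes that the influx is multiplied by the area of a single 2500 m by 2500 m cell at the source; downstream accumulation is ignored, and all variable names are illustrative.

```python
cell_area = 2500.0 * 2500.0          # m2, from `xy_spacing`
water_influx = 5000.0                # m/yr, per source node
sediment_supply = 50000.0            # m3/yr, per source node
fractions = [0.7, 0.3]               # first and second sediment classes

water_discharge = water_influx * cell_area               # m3/yr
class_supply = [f * sediment_supply for f in fractions]  # m3/yr per class
concentration = sediment_supply / water_discharge        # volume ratio

print(f"{water_discharge:.3e} m3/yr of water per source node")
print(class_supply)
print(f"sediment-to-water volume ratio: {concentration:.1e}")
```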

Now we can move to defining the components that will update those fields and the stratigraphy. We start with the sea level, which is defined as two sine curves with different wavelengths and amplitudes added together:

In [None]:
slc = SeaLevelCalculator(grid, wavelength=[100000., 10000.], amplitude=[25., 2.5])
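
To make the shape of this curve concrete, here is a minimal sketch of two superimposed sine waves with the same wavelengths and amplitudes. It assumes a zero phase and a mean sea level of 0 m, which may not match what `SeaLevelCalculator` does internally; the function `two_sine_sea_level` is hypothetical.

```python
import math

def two_sine_sea_level(t, wavelengths=(100000.0, 10000.0), amplitudes=(25.0, 2.5)):
    """Sum of sine waves at time t (yr); zero phase and zero mean are assumed."""
    return sum(a * math.sin(2.0 * math.pi * t / w)
               for a, w in zip(amplitudes, wavelengths))

# Sample the curve every 100 years over the 500 000 years of the run
samples = [two_sine_sea_level(100.0 * i) for i in range(1, 5001)]
print(f"range: {min(samples):.1f} m to {max(samples):.1f} m")
```

Under these assumptions, the long wave goes through five 100 000-year cycles over the run, while the short wave adds a 10 000-year ripple with an amplitude ten times smaller, so the curve never departs from its mean by more than 27.5 m.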

We also need components to compute the discharge, which are just the regular components from Landlab:

In [None]:
fd = FlowDirectorMFD(grid, partition_method='slope', diagonals=True)

In [None]:
fa = FlowAccumulator(grid, flow_director=fd)

Finally, we need a component for sediment transport, which is based on the stream power law following [Granjeon (1996)](https://theses.hal.science/tel-00648827v1):

In [None]:
wdd = WaterDrivenRouter(grid,
                        transportability_cont=[1e-8, 1e-8],
                        transportability_mar=[4e-10, 2e-10],
                        wave_base=15.,
                        max_erosion_rate_sed=1e-2,
                        max_erosion_rate_br=1e-12,
                        bedrock_composition=[0.7, 0.3],
                        fields_to_track='bathymetric__depth')

This component also tracks the field `bathymetric__depth` from the grid so that it is recorded in the stratigraphy.
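
For intuition about the transportability coefficients, water-driven transport laws of the kind introduced by Granjeon (1996) are commonly written as a sediment flux proportional to a coefficient, a power of the water discharge, and a power of the local slope. The sketch below uses that generic form with both exponents set to 1. This is an assumption made for illustration, not StratigraPy's implementation, and units are deliberately left out because the normalization used by `WaterDrivenRouter` is not described here; `transport_capacity` is a hypothetical helper.

```python
def transport_capacity(transportability, water_discharge, slope, m=1.0, n=1.0):
    """Generic water-driven transport law: K * q_w**m * S**n (illustrative)."""
    return transportability * water_discharge**m * slope**n

k_cont = [1e-8, 1e-8]     # `transportability_cont` from the cell above
k_mar = [4e-10, 2e-10]    # `transportability_mar` from the cell above

# For the same discharge and slope, compare continental and marine transport
q_w, slope = 1.0, 0.003   # arbitrary discharge; slope of the initial ramp
ratios = [transport_capacity(kc, q_w, slope) / transport_capacity(km, q_w, slope)
          for kc, km in zip(k_cont, k_mar)]
print(ratios)
```

Under this assumed law, and all else being equal, transport in the marine domain is about 25 times less efficient than on land for the first class and about 50 times less efficient for the second.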

Similarly to Landlab, each component from StratigraPy has a function `run_one_step` to call at each iteration to run the simulation, while the function `fuse` of the StackedLayers fuses the iterations together so that they fit in a reduced number of stratigraphic layers:

In [None]:
time = np.empty(n_iterations)
sea_level = np.empty(n_iterations)
for i in tqdm(range(n_iterations)):
    slc.run_one_step(timestep)
    time[i] = slc._time
    sea_level[i] = grid.at_grid['sea_level__elevation']
    fa.run_one_step()
    wdd.run_one_step(timestep)
    grid.stacked_layers.fuse(time=np.mean, bathymetric__depth=np.mean)
grid.stacked_layers.fuse(finalize=True, time=np.mean, bathymetric__depth=np.mean)
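
Conceptually, passing `time=np.mean` asks that the tracked values of the fused iterations be reduced to their mean. The plain-Python sketch below mimics that reduction on the time values alone, assuming that consecutive groups of 100 iterations are merged and that time starts at 0; it is not StratigraPy's implementation, and `fuse_by_mean` is a hypothetical helper.

```python
from statistics import mean

def fuse_by_mean(values, group_size):
    """Reduce consecutive groups of `group_size` values to their mean."""
    return [mean(values[i:i + group_size])
            for i in range(0, len(values), group_size)]

# Time at the end of each of the 5000 iterations (yr), assuming a start at 0
times = [100.0 * (i + 1) for i in range(5000)]
fused_times = fuse_by_mean(times, 100)
print(len(fused_times), fused_times[0], fused_times[-1])  # 50 5050.0 495050.0
```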

In that loop, we've also recorded the sea level variations, which we can now plot:

In [None]:
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(time, sea_level)
ax.set(xlabel='Time (yr)', ylabel='Sea level (m)');

Now let's have a look at the final topography:

In [None]:
fig, ax = plt.subplots()

raster_x = grid.x_of_node[grid.core_nodes].reshape(grid.cell_grid_shape)
raster_y = grid.y_of_node[grid.core_nodes].reshape(grid.cell_grid_shape)
raster_z = grid.at_node['topographic__elevation'][grid.core_nodes].reshape(grid.cell_grid_shape)

pc = ax.pcolormesh(raster_x, raster_y, raster_z, cmap=cmocean.cm.topo,
                   norm=mcolors.CenteredNorm(grid.at_grid['sea_level__elevation']))
fig.colorbar(pc, ax=ax, label='Elevation (m)')

ax.set(xlabel='x (m)', ylabel='y (m)', aspect='equal');

## 2. 2D visualization

The raster grid in StratigraPy has a function `plot_layers`, which we can use to visualize cross-sections through the stratigraphy using [matplotlib](https://matplotlib.org/), for instance the fraction of the second sediment class in the middle of the grid along the y-axis:

In [None]:
fig, ax = plt.subplots(figsize=(10, 3.5))

raster_y = grid.y_of_node[grid.core_nodes].reshape(grid.cell_grid_shape)[:, 14]
raster_z = grid.at_node['topographic__elevation'][grid.core_nodes].reshape(grid.cell_grid_shape)[:, 14]
# Sea level
fill_sea = ax.fill_between(raster_y, raster_z, grid.at_grid['sea_level__elevation'],
                           color='#c6dbef', zorder=0)
# Bedrock
ax.fill_between(raster_y, raster_z, raster_z.min(), color='#d9d9d9')

# Sediments
pc = grid.plot_layers(ax, 'composition', i_class=1, mask_wedges=True, cmap='pink')
fig.colorbar(pc[0], ax=ax, label='Fraction of the second sediment class')

ax.set(xlabel='y (m)', ylabel='z (m)');
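
The indices used to extract the topography along the section follow from the grid shape: removing the boundary nodes from the 25 by 30 nodes leaves 23 by 28 core nodes, and the middle column is then at index 14. Here is a small sketch of that bookkeeping, assuming a one-node-wide ring of boundary nodes as in Landlab's raster grids (variable names are illustrative):

```python
node_shape = (25, 30)   # rows, columns of nodes, as passed to RasterModelGrid
# Core nodes exclude the outer ring of boundary nodes
core_shape = (node_shape[0] - 2, node_shape[1] - 2)

middle_row = core_shape[0] // 2      # index for a section along the x-axis
middle_column = core_shape[1] // 2   # index for a section along the y-axis
print(core_shape, middle_row, middle_column)  # (23, 28) 11 14
```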

And we can do the same along the x-axis:

In [None]:
fig, ax = plt.subplots(figsize=(10, 3.5))

raster_x = grid.x_of_node[grid.core_nodes].reshape(grid.cell_grid_shape)[11]
raster_z = grid.at_node['topographic__elevation'][grid.core_nodes].reshape(grid.cell_grid_shape)[11]
# Sea level
fill_sea = ax.fill_between(raster_x, raster_z, grid.at_grid['sea_level__elevation'],
                           color='#c6dbef', zorder=0)

# Sediments
pc = grid.plot_layers(ax, 'composition', i_x=None, i_y='middle', i_class=1, mask_wedges=True, cmap='pink')
fig.colorbar(pc[0], ax=ax, label='Fraction of the second sediment class')

ax.set(xlabel='x (m)', ylabel='z (m)');

Or plot the fraction of the first sediment class at the surface:

In [None]:
fig, ax = plt.subplots()

pc = grid.plot_layers(ax, 'composition', i_x=None, i_layer='top', i_class=0,
                      shading='nearest', cmap='pink')
fig.colorbar(pc, ax=ax, label='Fraction of the first sediment class')

ax.set(xlabel='x (m)', ylabel='y (m)', aspect='equal');

We can also plot any field tracked by the components. Time is always tracked by default:

In [None]:
fig, ax = plt.subplots(figsize=(10, 3.5))

raster_y = grid.y_of_node[grid.core_nodes].reshape(grid.cell_grid_shape)[:, 14]
raster_z = grid.at_node['topographic__elevation'][grid.core_nodes].reshape(grid.cell_grid_shape)[:, 14]
# Sea level
fill_sea = ax.fill_between(raster_y, raster_z, grid.at_grid['sea_level__elevation'],
                           color='#c6dbef', zorder=0)
# Bedrock
ax.fill_between(raster_y, raster_z, raster_z.min(), color='#d9d9d9')

# Sediments
pc = grid.plot_layers(ax, 'time')
fig.colorbar(pc[0], ax=ax, label='Deposition time (yr)')

ax.set(xlabel='y (m)', ylabel='z (m)');

And here we also asked to track the bathymetry:

In [None]:
fig, ax = plt.subplots(figsize=(10, 3.5))

raster_y = grid.y_of_node[grid.core_nodes].reshape(grid.cell_grid_shape)[:, 14]
raster_z = grid.at_node['topographic__elevation'][grid.core_nodes].reshape(grid.cell_grid_shape)[:, 14]
# Sea level
fill_sea = ax.fill_between(raster_y, raster_z, grid.at_grid['sea_level__elevation'],
                           color='#c6dbef', zorder=0)
# Bedrock
ax.fill_between(raster_y, raster_z, raster_z.min(), color='#d9d9d9')

# Sediments
pc = grid.plot_layers(ax, 'bathymetric__depth', cmap=cmocean.cm.deep)
fig.colorbar(pc[0], ax=ax, label='Water depth (m)')

ax.set(xlabel='y (m)', ylabel='z (m)');

StratigraPy's visualization abilities are still rough (they don't render erosion quite properly, for instance), but they can still be used to analyze simulation results.

## 3. 3D visualization

We can also visualize the stratigraphy in 3D using PyVista, which takes a more indirect path. We first need to call the function `extract_tie_centered_layers`, which prepares the right inputs for PyVista's class `StructuredGrid`:

In [None]:
x, y, z, layers = extract_tie_centered_layers(grid, 'bathymetric__depth', axis=2)

Then we can visualize the full stratigraphy:

In [None]:
p = pv.Plotter()

for l in range(len(layers)):
    mesh = pv.StructuredGrid(x[l:l + 2], y[l:l + 2], z[l:l + 2])
    mesh['Water depth (m)'] = np.tile(layers[l], (2, 1, 1)).T.ravel()
    p.add_mesh(mesh, scalars='Water depth (m)', cmap=cmocean.cm.deep, show_edges=True)

p.set_scale(zscale=150)
p.show()
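
The `.T.ravel()` idiom in the loop flattens the tiled array in column-major order, meaning that the first axis varies fastest, which should match the order in which `pv.StructuredGrid` stores the points built from `x`, `y`, and `z`. The plain-Python sketch below reproduces that ordering on a toy layer; `ravel_column_major` is a hypothetical helper and the values are made up.

```python
def ravel_column_major(array3d):
    """Flatten a nested list of shape (n0, n1, n2) with the first axis fastest."""
    n0, n1, n2 = len(array3d), len(array3d[0]), len(array3d[0][0])
    return [array3d[i][j][k]
            for k in range(n2) for j in range(n1) for i in range(n0)]

layer = [[10, 11, 12],
         [20, 21, 22]]        # toy values for one layer, shape (2, 3)
tiled = [layer, layer]        # same as np.tile(layer, (2, 1, 1)): shape (2, 2, 3)
flat = ravel_column_major(tiled)
print(flat)  # [10, 10, 20, 20, 11, 11, 21, 21, 12, 12, 22, 22]
```

Each value appears twice in a row because the same layer value is assigned to the points at the bottom and at the top of the layer.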

Or sections through it:

In [None]:
p = pv.Plotter()

for l in range(len(layers)):
    mesh = pv.StructuredGrid(x[l:l + 2], y[l:l + 2], z[l:l + 2])
    mesh['Water depth (m)'] = np.tile(layers[l], (2, 1, 1)).T.ravel()
    p.add_mesh(mesh.slice_along_axis(n=10, axis='x'), scalars='Water depth (m)',
               clim=[np.nanmin(layers), np.nanmax(layers)], cmap=cmocean.cm.deep)
    p.add_mesh(mesh.slice_along_axis(n=10, axis='y'), scalars='Water depth (m)',
               clim=[np.nanmin(layers), np.nanmax(layers)], cmap=cmocean.cm.deep)

p.set_scale(zscale=150)
p.show()