In [1]:
import numpy as np
import xsimlab as xs

# A simple phytoplankton chemostat model in 2 Dimensions

Building on the first model prototype, I extended the model to run in a higher-dimensional (2D) setting using xsimlab.

Check the bottom of this jupyter notebook for 2D model output!

# Physical Environment
A process that defines the model dimensions and passes them to other processes

In [2]:
@xs.process
class PhysicalEnvironment:
    """
    This physical environment provides a base dimension (0D) that is used by other components,
    so that all components can be grouped at the grid points of a larger grid.

    It can be extended to higher dimensions via another grid process that defines 'grid_dims'.

    The full model shape exposed as 'dims' is ``grid_dims + (Env_dim, 1)``; the trailing
    axis of length 1 is the one the setup process tiles component arrays along.
    """
    dim_label = xs.variable(default='Env')
    
    Env = xs.index(dims='Env')
    
    # shape of the surrounding grid (a tuple); written by a grid process through a foreign variable
    grid_dims = xs.variable(intent='inout')
    
    # Output: full model shape, read by the other processes as 'Model_dims'
    dims = xs.variable(intent='out')
    # Input: length of the 'Env' dimension
    Env_dim = xs.variable(default=1)
    
    def initialize(self):
        """Create the 'Env' index and assemble the full model shape."""
        # The index must have as many labels as the Env axis of 'dims'; it was
        # previously hard-coded to a single label. Labels start at 1, so the
        # default Env_dim=1 still yields the original index [1].
        self.Env = np.arange(1, self.Env_dim + 1)

        self.dims = self.grid_dims + (self.Env_dim, 1)

## Grid

In [3]:
@xs.process
class GridXY:
    """
    This process supplies the grid dimensions to the Physical Environment.

    It creates the 'x' and 'y' indices, computes the regular grid spacing in
    both directions and writes the grid shape ``(x_dim, y_dim)`` to
    ``PhysicalEnvironment.grid_dims``.
    """
    # Dimension labels and indices
    x_label = xs.variable(default='x')
    y_label = xs.variable(default='y')
    
    x = xs.index(dims='x')
    y = xs.index(dims='y')
    
    # Input
    x_dim = xs.variable(intent='in', description='length of dimension, x direction')
    y_dim = xs.variable(intent='in', description='length of dimension, y direction')
    
    # Output
    dx = xs.variable(intent='out', description='grid distance in regular grid, x direction')
    dy = xs.variable(intent='out', description='grid distance in regular grid, y direction')
    
    grid_dims = xs.foreign(PhysicalEnvironment, 'grid_dims', intent='out')
    
    def initialize(self):
        """Compute grid spacing, build the x/y indices and publish the grid shape."""
        # Spacing is 10 divided by the number of cells in each direction.
        # dy must be derived from y_dim (it was previously computed from x_dim,
        # which is only correct for square grids).
        self.dx, self.dy = 10/np.array([self.x_dim, self.y_dim], dtype='float64')
        
        self.x = np.arange(self.x_dim)
        self.y = np.arange(self.y_dim)
        
        self.grid_dims = (self.x_dim, self.y_dim)

# Components
Processes defining and tracking the state variables of our model

In [4]:
@xs.process
class Component:
    """
    Common base of all components: integrates the collected fluxes into the state.

    Subclasses are expected to declare the actual ``state`` variable and the
    ``fluxes`` group (plus any parameters); this class only defines how they
    are combined at every time step.
    """
    @xs.runtime(args="step_delta")
    def run_step(self, dt):
        """Sum all fluxes of this component and scale the total by the step length ``dt``."""
        total_flux = sum(self.fluxes)
        self.delta = total_flux * dt

    def finalize_step(self):
        """Apply the change computed in ``run_step`` to the state."""
        self.state += self.delta

    
@xs.process
class SingularComp(Component):
    """Component whose own dimension always has length 1 (e.g. the nutrient N)."""
    # length of the component's own dimension; set to 1 in initialize()
    dim = xs.variable(intent='out')
    # e.g. N
    
    def initialize(self):
        """Fix the length of the component dimension to 1."""
        self.dim = 1

        
@xs.process
class Nutrient(SingularComp):
    """Nutrient component: a singular component with its own 'N' dimension."""
    # label and index of the nutrient's own dimension
    dim_label = xs.variable(default='N')
    N = xs.index(dims='N')

    # nutrient state; accepted with or without the Env / grid dimensions
    state = xs.variable(
        dims=['N', ('Env', 'N'), ('x', 'y', 'Env', 'N')],
        intent='inout',
    )
    # every variable registered in the 'N_flux' group contributes to this component
    fluxes = xs.group('N_flux')

    def initialize(self):
        """Set ``dim`` through the parent class, then build the 'N' index from it."""
        super(Nutrient, self).initialize()
        self.N = np.arange(self.dim)
    
    
@xs.process
class MultiComp(Component):
    """Component whose own dimension has a length supplied from outside this process (e.g. P, Z)."""
    # e.g. P, Z
    # length of the component's own dimension ('inout': not computed in this class)
    dim = xs.variable(intent='inout')
    
    
@xs.process
class Phytoplankton(MultiComp):
    """Phytoplankton component with its own 'P' dimension and per-type parameters."""
    # label and index of the phytoplankton's own dimension
    dim_label = xs.variable(default='P')
    P = xs.index(dims='P')

    # phytoplankton state; accepted with or without the Env / grid dimensions
    state = xs.variable(
        dims=['P', ('Env', 'P'), ('x', 'y', 'Env', 'P')],
        intent='inout',
    )
    # every variable registered in the 'P_flux' group contributes to this component
    fluxes = xs.group('P_flux')

    # parameters, declared with the same dimension options as the state
    halfsat = xs.variable(
        dims=['P', ('Env', 'P'), ('x', 'y', 'Env', 'P')],
        intent='inout',
    )
    mortality_rate = xs.variable(
        dims=['P', ('Env', 'P'), ('x', 'y', 'Env', 'P')],
        intent='inout',
    )

    def initialize(self):
        """Build the 'P' index with one entry per phytoplankton type (``dim`` entries)."""
        self.P = np.arange(self.dim)

# Fluxes
Processes affecting the state variables. Each flux process is a term in the system of differential equations that underlies this model.

In [5]:
@xs.process
class NutrientUptake:
    """
    Nutrient-limited phytoplankton growth: an example of a MultiComp (P)
    interacting with a SingularComp (N).

    Publishes a positive flux into the 'P_flux' group and the matching
    negative flux into the 'N_flux' group.
    """
    Model_dims = xs.foreign(PhysicalEnvironment, 'dims')

    # states and parameter read from the components
    N = xs.foreign(Nutrient, 'state')
    P = xs.foreign(Phytoplankton, 'state')
    P_halfsat = xs.foreign(Phytoplankton, 'halfsat')

    # fluxes handed to the components through their flux groups
    N_uptake = xs.variable(dims=('x','y','Env','N'), intent='out', groups='N_flux')
    P_growth = xs.variable(dims=('x','y','Env','P'), intent='out', groups='P_flux')

    # diagnostic: nutrient limitation term computed in the current step
    NutLim = xs.variable(intent='out')

    @property
    def NutrientLimitation(self):
        """Saturating (Monod-type) limitation term N / (halfsat + N)."""
        return self.N / (self.P_halfsat + self.N)

    def initialize(self):
        """Start with zero fluxes shaped like the respective states."""
        self.N_uptake = np.zeros_like(self.N)
        self.P_growth = np.zeros_like(self.P)

    def run_step(self):
        """Compute growth of P and the corresponding uptake of N."""
        self.NutLim = np.array(self.NutrientLimitation, dtype='float64')

        growth = self.NutLim * self.P
        self.P_growth = growth

        # there is only a single N, so the growth of all P is collapsed along
        # the last axis (kept with length 1); the sign makes it a loss for N
        self.N_uptake = -growth.sum(axis=-1, keepdims=True)


@xs.process
class PhytoplanktonMortality:
    """Quadratic mortality: a loss flux for P equal to ``mortality_rate * P**2``."""
    Model_dims = xs.foreign(PhysicalEnvironment, 'dims')

    # state and parameter read from the phytoplankton component
    P = xs.foreign(Phytoplankton, 'state')
    P_mortality_rate = xs.foreign(Phytoplankton, 'mortality_rate')

    # flux handed to the phytoplankton through the 'P_flux' group
    P_mortality = xs.variable(dims=('x','y','Env','P'), intent='out', groups='P_flux')

    def initialize(self):
        """Start with a zero flux shaped like the phytoplankton state."""
        self.P_mortality = np.zeros_like(self.P)

    def run_step(self):
        """Update the mortality flux from the current state (negative: a loss)."""
        loss = self.P_mortality_rate * self.P ** 2
        self.P_mortality = -np.array(loss, dtype='float64')

# Forcing

Processes supplying the model forcing to Forcing Fluxes

In [6]:
@xs.process
class ChemostatForcing:
    """
    Initialises the spatially defined nutrient input forcing ``N_0``:
    zero everywhere except for a block of cells around the grid centre.
    """
    Model_dims = xs.foreign(PhysicalEnvironment, 'dims')

    # NOTE(review): declared with three dims, but the array below is built from
    # Model_dims, which PhysicalEnvironment assembles as grid_dims + (Env_dim, 1),
    # i.e. with one more axis -- confirm the intended dims.
    N_0 = xs.variable(dims=('x','y','Env'), intent='out', static=True)

    def initialize(self):
        """Build the forcing field and publish it as ``N_0``."""
        # start from an all-zero field with the full model shape
        field = np.tile(np.array(0., dtype='float64'), self.Model_dims)

        # locate the grid centre; the half-width of the source block is
        # 1/20 of the summed x and y extents (truncated to an integer)
        n_x, n_y = field.shape[0], field.shape[1]
        mid_x = int(n_x / 2)
        mid_y = int(n_y / 2)
        half_width = int((n_x + n_y) / 20)

        # nutrient input of concentration 5 in the central block of cells
        field[mid_x - half_width:mid_x + half_width,
              mid_y - half_width:mid_y + half_width, :] = np.array(5, dtype='float64')

        self._N_0 = field
        self.N_0 = self._N_0

## Forcing flux

In [7]:
@xs.process
class Mixing:
    """Forcing flux: supplies nutrient to N at ``flowrate * N_0``."""
    Model_dims = xs.foreign(PhysicalEnvironment, 'dims')

    # forcing field and nutrient state read from other processes
    N_0 = xs.foreign(ChemostatForcing, 'N_0')
    N = xs.foreign(Nutrient, 'state')

    # input
    flowrate = xs.variable(intent='in')

    # flux handed to the nutrient through the 'N_flux' group
    N_input = xs.variable(dims=('x','y','Env','N'), intent='out', groups='N_flux')

    def initialize(self):
        """Start with a zero flux shaped like the nutrient state."""
        self.N_input = np.zeros_like(self.N)

    def run_step(self):
        """Scale the forcing field by the flow rate."""
        self.N_input = self.N_0 * self.flowrate

## Grid Fluxes

Diffusion between grid points

In [8]:
@xs.process
class GridExchange:
    """
    Exchange of N and P between adjacent grid points.

    The exchange term is a finite-difference second derivative in x and y,
    adapted from
    https://scipython.com/book/chapter-7-matplotlib/examples/the-two-dimensional-diffusion-equation/
    """
    Model_dims = xs.foreign(PhysicalEnvironment, 'dims')

    # grid spacing
    dx = xs.foreign(GridXY, 'dx')
    dy = xs.foreign(GridXY, 'dy')

    # component states
    N = xs.foreign(Nutrient, 'state')
    P = xs.foreign(Phytoplankton, 'state')

    # input
    exchange_rate = xs.variable(intent='in')

    # fluxes handed to the components through their flux groups
    N_advected = xs.variable(dims=('x','y','Env','N'), intent='out', groups='N_flux')
    P_advected = xs.variable(dims=('x','y','Env','P'), intent='out', groups='P_flux')

    def initialize(self):
        """Start with zero fluxes and cache the squared grid spacings."""
        self.N_advected = np.zeros_like(self.N)
        self.P_advected = np.zeros_like(self.P)

        self.dx2 = self.dx**2
        self.dy2 = self.dy**2

    def advection(self, state, dt):
        """
        Exchange term for the interior cells of ``state``.

        Forward difference in time, central difference in space. The result
        covers only the interior (first and last row/column of the two leading
        axes are excluded).

        NOTE(review): the value returned here already contains the factor
        ``dt``, and Component.run_step multiplies the summed group fluxes by
        ``dt`` again -- confirm whether this double scaling is intended before
        changing it, since it affects the effective step of the explicit scheme.
        """
        interior = state[1:-1, 1:-1]
        second_diff_x = (state[2:, 1:-1] - 2*interior + state[:-2, 1:-1]) / self.dx2
        second_diff_y = (state[1:-1, 2:] - 2*interior + state[1:-1, :-2]) / self.dy2
        return self.exchange_rate * dt * (second_diff_x + second_diff_y)

    @xs.runtime(args="step_delta")
    def run_step(self, dt):
        """Update the exchange fluxes of N and P for the interior cells."""
        # only the interior is written, so the boundaries are not affected by
        # the exchange (i.e. a highly simplified boundary condition -> the
        # nutrient source is placed in the center)
        self.N_advected[1:-1, 1:-1] = self.advection(self.N, dt)
        self.P_advected[1:-1, 1:-1] = self.advection(self.P, dt)

# Model initialisation
Processes that initialize model parameters and state variables from user input, to simplify the user interface

In [9]:
@xs.process
class ChemostatGridXYSetup:
    """ 
    This crucial process supplies the initial values to the components;
    more complicated parameter setup of MultiComps can be done here.

    All arrays are built by tiling a 1D per-component vector over
    ``Model_dims`` (the full model shape supplied by PhysicalEnvironment).
    """
    Model_dims = xs.foreign(PhysicalEnvironment, 'dims')
    
    # Input
    N_initval = xs.variable(intent='in', dims=[(),('N')])
    P_initval = xs.variable(intent='in', dims=[(),('P')])
    P_num = xs.variable(intent='in', dims=())
    
    # Initializes:
    P_halfsat = xs.foreign(Phytoplankton, 'halfsat', intent='out')
    P_mortality_rate = xs.foreign(Phytoplankton, 'mortality_rate', intent='out')
    
    N_state = xs.foreign(Nutrient, 'state', intent='out')
    P_state = xs.foreign(Phytoplankton, 'state', intent='out')
    P_dims = xs.foreign(Phytoplankton, 'dim', intent='out')
    
    
    def initialize(self):
        """Set the number of P types and fill states and parameters with their initial values."""
        self.P_dims = self.P_num
        
        # initialize the state variables in the correct dimensions;
        # the initial P value is split evenly across the P_num types
        self.N_state = np.tile(np.array([self.N_initval], dtype='float64'),self.Model_dims)
        self.P_state = np.tile(np.array([self.P_initval/self.P_num for i in range(self.P_num)], dtype='float64'), self.Model_dims)
        
        # initialize the model parameters, one value per P type.
        # halfsat is declared with a 'P' dimension on Phytoplankton, so it is
        # built with P_num entries like mortality_rate (it previously always
        # had a single entry, which only matched the declaration for P_num == 1).
        self.P_halfsat = np.tile(np.array([1.5 for i in range(self.P_num)], dtype='float64'), self.Model_dims)
        self.P_mortality_rate = np.tile(np.array([0.1 for i in range(self.P_num)], dtype='float64'), self.Model_dims)
        

# xsimlab model setup

In [10]:
# Assemble the model: each key is the process name used as prefix for its
# variables (e.g. 'Grid__x_dim'), each value the process class defined above.
DimModel = xs.Model({
    # grid and base environment
    'Grid': GridXY,
    'Env': PhysicalEnvironment,
    # components
    'N': Nutrient,
    'P': Phytoplankton,
    # biological fluxes
    'NP_uptake': NutrientUptake,
    'P_Mortality': PhytoplanktonMortality,
    # forcing and forcing flux
    'FX': ChemostatForcing,
    'Mix': Mixing,
    # exchange between grid points
    'GX': GridExchange,
    # initial values and parameters
    'MS': ChemostatGridXYSetup,
})

In [11]:
# Display the model summary: its processes and the input variables of each.
DimModel

<xsimlab.Model (10 processes, 13 inputs)>
Grid
    y_label           [in]
    x_dim             [in] length of dimension, x direction
    x_label           [in]
    y_dim             [in] length of dimension, y direction
Env
    dim_label         [in]
    Env_dim           [in]
MS
    P_num             [in]
    P_initval         [in] () or ('P',) 
    N_initval         [in] () or ('N',) 
GX
    exchange_rate     [in]
FX
Mix
    flowrate          [in]
NP_uptake
N
    dim_label         [in]
P_Mortality
P
    dim_label         [in]

# supply time, parameters and output

This model allows any number of phytoplankton types (or size classes) to be initialized within each cell. For simplicity, "P_num" is 1 here:

In [12]:
# Build the input dataset for the run: time axis, parameter values and the
# variables to store as output.
DimModel_in = xs.create_setup(
    model=DimModel,
    clocks={   
        # 1000 time points between 0 and 30
        'time': np.linspace(0,30,1000),
    },
    # 'main_clock' replaces the deprecated 'master_clock' argument
    main_clock='time',
    input_vars={
        # grid size (number of cells per direction)
        'Grid__x_dim':100,
        'Grid__y_dim':100,
        
        # initial values and number of phytoplankton types
        'MS':{
            'N_initval':0.01,
            'P_num':1,
            'P_initval':.1
        },
        
        'Mix__flowrate':1.,
        
        'GX__exchange_rate':1.5
        
    },
    output_vars={
        # state of components as output
        'N__state':'time',
        'P__state':'time',
        
        # fluxes stored for diagnostic purposes
        'Mix__N_input':'time'
    }
)

# Model Run

In [13]:
from xsimlab.monitoring import ProgressBar

# Run the simulation with DimModel as the active model, reporting progress on the console.
with ProgressBar(frontend='console'), DimModel:
    DimModel_out = DimModel_in.xsimlab.run()

██████████ 100% | Simulation finished in 00:37 


In [14]:
# Display the simulation output returned by the run above.
DimModel_out

# Model output

Below, the 2D grid of our model is plotted; the displayed 'time' can be controlled via the slider next to the plot.

In [19]:
import hvplot.xarray
import holoviews as hv

import matplotlib.pyplot as plt

## Nutrient

The nutrient has a small square influx in the middle of the grid. Over time, the first influx of nutrients is taken up by the phytoplankton that consume it, and a steady state is reached quickly.

In [20]:
# Nutrient state on the x/y grid, one image per time step (colour scale fixed to 0-5).
N_out = DimModel_out['N__state'].hvplot.image(
    x='x',
    y='y',
    groupby='time',
    clim=(0, 5),
    cmap=plt.cm.viridis,
    width=550,
    height=550,
)

N_out

## Phytoplankton

A small concentration of nutrient and phytoplankton is initialized in each cell, which fuels a short period of growth across the grid. Additional nutrients flowing in at the middle are advected towards the edges, creating a gradient in phytoplankton biomass.

In [21]:
# Phytoplankton state on the x/y grid, one image per time step (colour scale fixed to 0-5).
P_out = DimModel_out['P__state'].hvplot.image(
    x='x',
    y='y',
    groupby='time',
    clim=(0, 5),
    cmap=plt.cm.viridis,
    width=550,
    height=550,
)

P_out

## Nutrient input (as a model diagnostic)

In [22]:
# Nutrient input flux (diagnostic) on the x/y grid, one image per time step (colour scale fixed to 0-5).
N_in = DimModel_out['Mix__N_input'].hvplot.image(
    x='x',
    y='y',
    groupby='time',
    clim=(0, 5),
    cmap=plt.cm.viridis,
    width=550,
    height=550,
)

N_in