# modelbase 2.0

modelbase is a Python package for metabolic modeling and analysis.

In [None]:
from __future__ import annotations

import itertools as it
from pathlib import Path
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.optimize import minimize

import modelbase2 as mb2
from example_models import (
    get_example1,
    get_lin_chain_two_circles,
    get_linear_chain_2v,
    get_upper_glycolysis,
)
from modelbase2 import (
    Cache,
    Derived,
    LabelMapper,
    LinearLabelMapper,
    Model,
    Simulator,
    fit,
    make_protocol,
    mc,
    mca,
    npe,
    plot,
    sbml,
    scan,
)
from modelbase2.distributions import LogNormal, Uniform, sample
from modelbase2.surrogates import train_torch_surrogate
from modelbase2.types import unwrap, unwrap2

if TYPE_CHECKING:
    from modelbase2.fit import ResidualFn


def randomise_parameters(
    m: Model,
    seed: int = 42,
    *,
    sigma: float = 0.125,
) -> Model:
    """Perturb every parameter of ``m`` with additive, zero-mean Gaussian noise.

    The model is modified in place via ``update_parameters`` and also returned,
    so the call can be used inline.

    Args:
        m: model whose parameters are perturbed.
        seed: seed for ``numpy.random.default_rng``; the same seed (and the same
            parameter order) reproduces the same perturbation.
        sigma: standard deviation of the normal noise added to each parameter.
            Defaults to 0.125, the value this helper always used.

    Returns:
        The same model instance ``m``.

    Note:
        The noise is additive, not multiplicative, so small parameter values
        can become negative.
    """
    rng = np.random.default_rng(seed=seed)
    # Read once: `.parameters` returns a fresh copy on every access.
    params = pd.Series(m.parameters)
    noise = rng.normal(0, sigma, len(params))
    m.update_parameters((params + noise).to_dict())
    return m

## Building your first model

Let's say you want to model the following chemical network of a linear chain of reactions

$$ \Large \varnothing \xrightarrow{v_0} S \xrightarrow{v_1} P \xrightarrow{v_2} \varnothing $$

We can translate this into a system of ordinary differential equations (ODEs)

$$\begin{align*}
\frac{dS}{dt} &= v_0 - v_1     \\
\frac{dP}{dt} &= v_1 - v_2 \\
\end{align*}
$$

and then choose a rate equation for each reaction to get the flux vector $v$

$$\begin{align*}
    v_0 &= k_0 \\
    v_1 &= k_1 * S \\
    v_2 &= k_2 * P \\
\end{align*}$$

Let's begin by defining rate functions $\textbf{v}$.  
Note that these should be **general** and **re-usable** whenever possible, to make your model clear to people reading it.  
Try to give these functions names that are meaningful to your audience, e.g. a rate function `k * s` could be named **proportional** or **mass-action**.

In [None]:
def constant(k: float) -> float:
    """Rate that depends on no variable: hand back ``k`` unchanged."""
    rate = k
    return rate


def proportional(k: float, s: float) -> float:
    """First-order (mass-action) rate: ``k`` multiplied by ``s``."""
    rate = k * s
    return rate

Next, we create our model.  
Note that we use a single function that returns the model instead of defining it globally.  
This allows us to quickly re-create the model whenever we need a fresh version of it.  

Let's step through the code below:

We first add parameters to the model using `.add_parameters({name: value})`.  

Next, we add variables using `.add_variables({name: initial_value})`.  

Finally, we add the three reactions by using 

```python
.add_reaction(
    name,              # the internal name for the reaction
    fn=...,            # a python function to be evaluated
    args=[name, ...],  # the arguments passed to the python function
    stoichiometry={    # a mapping encoding how much the variable `name`
        name: value    # is changed by the reaction
    },
)
```

In [None]:
def linear_chain_2cpds() -> Model:
    """Build the linear chain  (nothing) -v0-> S -v1-> P -v2-> (nothing).

    All rate constants are 1 and both compounds start at 0. A new model is
    constructed on every call, so callers always get a fresh instance.
    """
    rate_constants = {"k_in": 1, "k_1": 1, "k_out": 1}
    initial_concentrations = {"S": 0, "P": 0}
    return (
        Model()
        .add_parameters(rate_constants)
        .add_variables(initial_concentrations)
        # v0: constant influx producing S
        .add_reaction(
            "v0",
            fn=constant,
            args=["k_in"],
            stoichiometry={"S": 1},
        )
        # v1: first-order conversion of S into P
        .add_reaction(
            "v1",
            fn=proportional,
            args=["k_1", "S"],
            stoichiometry={"S": -1, "P": 1},
        )
        # v2: first-order outflux of P
        .add_reaction(
            "v2",
            fn=proportional,
            args=["k_out", "P"],
            stoichiometry={"P": -1},
        )
    )

We can then simulate the model by passing it to a `Simulator` and simulate a time series using `.simulate(t_end)`.  
Finally, we can obtain the concentrations and fluxes using `get_concs_and_fluxes`.  

While you can directly plot the `pd.DataFrame`s, modelbase supplies a variety of plots in the `plot` namespace that are worth checking out.  

In [None]:
c, v = (
    Simulator(linear_chain_2cpds())  # initialise the simulator
    .simulate(5)  # simulate until t_end = 5
    .get_concs_and_fluxes()  # return pd.DataFrames for concentrations and fluxes
)

# The results are None if the simulation failed, so guard before plotting
if c is not None and v is not None:
    fig, (ax1, ax2) = plot.two_axes(figsize=(6, 3))
    _ = plot.lines(c, ax=ax1)
    _ = plot.lines(v, ax=ax2)

    # Never forget to label your axes :)
    ax1.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
    ax2.set(xlabel="time / a.u.", ylabel="flux / a.u.")
    plt.show()

Note that we checked whether the results were `None` in case the simulation failed.  
Explicitly checking using an `if` clause is the preferred error handling mechanism.  

If you are **sure** the simulation won't fail, and still want your code to be type-safe, you can use `unwrap` and `unwrap2`.  

```python
c = unwrap(Simulator(model).simulate(10).get_concs())
c, v = unwrap2(Simulator(model).simulate(10).get_concs_and_fluxes())
```

Note that these functions will raise an error if the values are `None`, which might crash your program.

## Derived quantities

Frequently it makes sense to derive one quantity in a model from other quantities.  
This can be done for

- parameters derived from other parameters
- variables derived from parameters or other variables
- stoichiometries derived from parameters or variables (more on this later)

modelbase chooses automatically whether the quantity is a derived parameter or variable based on the `args`.  

In [None]:
def moiety_1(x1: float, total: float) -> float:
    """Return the complementary part of a conserved total, ``total - x1``.

    Used in this notebook to derive ADP from ATP and ATP_total.
    """
    remainder = total - x1
    return remainder


def model_derived() -> Model:
    """Build a one-variable ATPase model that showcases derived quantities.

    ``k_atp`` only depends on parameters and therefore becomes a derived
    parameter, while ``ADP`` depends on the variable ``ATP`` and therefore
    becomes a derived variable.
    """
    initial_conditions = {"ATP": 1.0}
    parameters = {"ATP_total": 1.0, "k_base": 1.0, "e0_atpase": 1.0}
    model = (
        Model()
        .add_variables(initial_conditions)
        .add_parameters(parameters)
        # k_atp = k_base * e0_atpase -> derived parameter (all args are parameters)
        .add_derived("k_atp", proportional, ["k_base", "e0_atpase"])
        # ADP = ATP_total - ATP -> derived variable (one arg is a variable)
        .add_derived("ADP", moiety_1, ["ATP", "ATP_total"])
        .add_reaction(
            "ATPase",
            proportional,
            args=["k_atp", "ATP"],
            stoichiometry={"ATP": -1},
        )
    )
    return model


c, v = Simulator(model_derived()).simulate(10).get_full_concs_and_fluxes()
if c is not None:
    fig, ax = plot.lines(c)
    ax.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
    plt.show()

## Introspection

If the simulation didn't show the expected results, it is usually a good idea to try to pinpoint the error.  
`modelbase` offers a variety of methods to access intermediate results.  

The first is to check whether all **derived parameters** were calculated correctly.  
For this, you can use the `get_args` method, which is named consistently with the `args` argument in all methods like `add_reaction`.

In [None]:
m = linear_chain_2cpds()
m.get_args({"S": 1.0, "P": 0.5})

If the `args` look fine, the next step is usually to check whether the rate equations are looking as expected

In [None]:
m = linear_chain_2cpds()
m.get_fluxes({"S": 1.0, "P": 0.5})

and whether the stoichiometries are assigned correctly

In [None]:
m = linear_chain_2cpds()
m.get_stoichiometries()

Lastly, you can check the generated right hand side

In [None]:
m = linear_chain_2cpds()
m.get_right_hand_side({"S": 1.0, "P": 0.5})

If any of the quantities above were unexpected, you can check the model interactively by accessing the various collections.  

> Note: the returned quantities are **copies** of the internal data, modifying these won't have any effect on the model

In case your model contains derived quantities, you can access them using `.derived`.  
Note that this returns a **copy** of the derived quantities, so editing it won't have any effect on the model.  

In [None]:
m = model_derived()

print(m.derived)


If you are interested in which category modelbase has placed the derived quantities, you can access `.derived_parameters` and `.derived_variables` as well. 

In [None]:
m = model_derived()
print(m.derived_parameters)
print(m.derived_variables)

## CRUD

The model has a complete **c**reate, **r**ead, **u**pdate, **d**elete API for all its elements.  
The methods and attributes are named consistently, with `add` instead of `create` and `get` instead of `read`.  
Note that the elements themselves are accessible as `properties`, e.g. `.parameters`, which return **copies** of the data.  
Only use the supplied methods to change the internal state of the model.

Here are some example methods and attributes for parameters

| Functionality | Parameters                                                                              |
| ------------- | --------------------------------------------------------------------------------------- |
| Create        | `.add_parameter()`, `.add_parameters()`                                                 |
| Read          | `.parameters`, `.get_parameter_names()`                                                 |
| Update        | `.update_parameter()`, `.update_parameters()`, `.scale_parameter()`, `.scale_parameters()` |
| Delete        | `.remove_parameter()`, `.remove_parameters()`                                           |

and variables

| Functionality | Variables                                                         |
| ------------- | ----------------------------------------------------------------- |
| Create        | `.add_variable()`, `.add_variables()`                             |
| Read          | `.variables`, `.get_variable_names()`, `.get_initial_conditions()` |
| Update        | `.update_variable()`, `.update_variables()`                        |
| Delete        | `.remove_variable()`, `.remove_variables()`                        |


In [None]:
# Get model
m = linear_chain_2cpds()

# Calculate fluxes
print("Before update", m.get_fluxes({"S": 1.0, "P": 0.5}), sep="\n", end="\n\n")

# Update parameters
m.update_parameters({"k_in": 2.0})

# Calculate fluxes again
print("After update", m.get_fluxes({"S": 1.0, "P": 0.5}), sep="\n")

## Derived stoichiometries

To define derived stoichiometries you need to use the `Derived` class as a value in the stoichiometries.  

In [None]:
c, v = unwrap2(
    Simulator(
        Model()
        .add_parameters({"stoich": -1.0, "k": 1.0})
        .add_variables({"x": 1.0})
        .add_reaction(
            "name",
            proportional,
            args=["x", "k"],
            # Define derived stoichiometry here
            stoichiometry={"x": Derived(constant, ["stoich"])},
        )
    )
    .simulate(1)
    # Update parameter the derived stoichiometry depends on
    .update_parameter("stoich", -4.0)
    # Continue simulation
    .simulate(5)
    .get_concs_and_fluxes()
)

_, ax = plot.lines(c)
ax.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
plt.show()

## Advanced simulations

By default, the `Simulator` is initialised with the initial concentrations set in the `Model`.  
Optionally, you can overwrite the initial conditions using the `y0` argument.  

```python
Simulator(model, y0={name: value, ...})
```

In [None]:
c, v = unwrap2(
    Simulator(linear_chain_2cpds(), y0={"S": 2.0, "P": 0.0})  # initialise the simulator
    .simulate(10)  # simulate until t_end = 10
    .get_concs_and_fluxes()  # return pd.DataFrames for concentrations and fluxes
)

fig, (ax1, ax2) = plot.two_axes(figsize=(6, 3))
_ = plot.lines(c, ax=ax1)
_ = plot.lines(v, ax=ax2)

ax1.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="flux / a.u.")
plt.show()

### Time course

You can obtain the time course of integration using the `simulate` method.  
There are three ways how you can define the time points this function returns.  

1. supply the end time `t_end`
2. supply both end time and number of steps with `steps`
3. supply the exact time points to be returned using `time_points`

```python
simulate(t_end=10)
simulate(t_end=10, steps=10)
simulate(time_points=np.linspace(0, 10, 11))
```

> Note that these settings don't change the integration itself (e.g. tolerances)!

In [None]:
concs, fluxes = unwrap2(
    Simulator(get_linear_chain_2v())  # optionally supply initial conditions
    .simulate(t_end=10)
    .get_concs_and_fluxes()
)

fig, ax = plot.lines(concs)
ax.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
plt.show()

### Protocol time course

Protocols are used to make parameter changes discrete in time.  
This is useful e.g. for reproducing experimental time courses where a parameter was changed at fixed time points.  

In [None]:
protocol = make_protocol(
    {
        1: {"k1": 1},
        2: {"k1": 2},
        3: {"k1": 1},
    }
)
concs, fluxes = unwrap2(
    Simulator(get_linear_chain_2v())
    .simulate_over_protocol(protocol)
    .get_concs_and_fluxes()
)

fig, ax = plt.subplots()
plot.lines(concs, ax=ax)
plot.shade_protocol(protocol["k1"], ax=ax, alpha=0.1)
ax.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
plt.show()

### Steady-state

You can simulate until the model reaches a steady-state using the `simulate_to_steady_state` method.  


In [None]:
concs, fluxes = unwrap2(
    Simulator(get_linear_chain_2v())  # optionally supply initial conditions
    .simulate_to_steady_state()
    .get_concs_and_fluxes()
)

fig, ax = plot.bars(concs)
ax.set(xlabel="Variable / a.u.", ylabel="Concentration / a.u.")
plt.show()

## Parameter scans

`modelbase` has a variety of parameter scans available in the `scan` module.  
They differ by the kind of data returned by them, e.g. steady-states or time course of concentration and fluxes.

In all cases, the parameters which you want to scan over are passed as a `pandas.DataFrame`.

### Steady-state

The steady-state scan takes a `pandas.DataFrame` of parameters to be scanned as an input and returns the steady-states at the respective parameter values.  

In [None]:
res = scan.steady_state(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(1, 3, 11)}),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 3))
plot.lines(res.concs, ax=ax1)  # access concentrations by name
plot.lines(res.fluxes, ax=ax2)  # access fluxes by name

ax1.set(ylabel="Concentration / a.u.")
ax2.set(ylabel="Flux / a.u.")
plt.show()

All scans return some kind of result object, which allow multiple access patterns for convenience. 

Namely, the concentrations and fluxes can be accessed by name, unpacked or combined into a single dataframe.

In [None]:
# Access by name
_ = res.concs
_ = res.fluxes

# scan can be unpacked
concs, fluxes = res

# combine concs and fluxes as single dataframe
_ = res.results

#### Combinations

Often you want to scan over multiple parameters at the same time.  
The recommended way to do this is to use the `cartesian_product` function, which takes a `parameter_name: values` mapping and creates a `pandas.DataFrame` of their combinations from it (think nested for loop).  

In the case the parameters `DataFrame` contains more than one column, the returned `pandas.DataFrame` will contain a `pandas.MultiIndex`.  

In [None]:
res = scan.steady_state(
    get_linear_chain_2v(),
    parameters=mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 3),
            "k2": np.linspace(1, 2, 4),
        }
    ),
)

res.results.head()

You can plot the results of a **single variable** of this scan using a heatmap

In [None]:
plot.heatmap_from_2d_idx(res.concs, variable="x")
plt.show()

Or create heatmaps of all passed variables at once.  

In [None]:
plot.heatmaps_from_2d_idx(res.concs)
plt.show()

You can also combine more than two parameters, however, visualisation then becomes challenging.  

In [None]:
res = scan.steady_state(
    get_linear_chain_2v(),
    mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 3),
            "k2": np.linspace(1, 2, 4),
            "k3": np.linspace(1, 2, 4),
        }
    ),
)
res.results.head()

### Time courses

You can perform a time course for each of the parameter values.  
The index now also contains the time, so even for one parameter a `pandas.MultiIndex` is used.

In [None]:
tss = scan.time_course(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(1, 2, 11)}),
    time_points=np.linspace(0, 1, 11),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tss.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(tss.fluxes, ax=ax2)

ax1.set(xlabel="time / a.u.", ylabel="Concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="Flux / a.u.")
plt.show()

Again, this works for an arbitrary number of parameters.

In [None]:
tss = scan.time_course(
    get_linear_chain_2v(),
    parameters=mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 11),
            "k2": np.linspace(1, 2, 4),
        }
    ),
    time_points=np.linspace(0, 1, 11),
)


fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tss.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(tss.fluxes, ax=ax2)
ax1.set(xlabel="time / a.u.", ylabel="Concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="Flux / a.u.")
plt.show()

The scan object returned has a `pandas.MultiIndex` of `n x time`, where `n` is an index that references parameter combinations.  
You can access the referenced parameters using `.parameters`

In [None]:
tss.parameters.head()

You can also easily access common aggregates using `get_agg_per_time`.  

In [None]:
tss.get_agg_per_time("std").head()

### Protocol

The same can be done for protocols.  

In [None]:
res = scan.time_course_over_protocol(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k2": np.linspace(1, 2, 11)}),
    protocol=make_protocol(
        {
            1: {"k1": 1},
            2: {"k1": 2},
            3: {"k1": 1},
        }
    ),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(res.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(res.fluxes, ax=ax2)
ax1.set(xlabel="time / a.u.", ylabel="Concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="Flux / a.u.")

for ax in (ax1, ax2):
    plot.shade_protocol(res.protocol["k1"], ax=ax, alpha=0.2)
plt.show()

## Metabolic control analysis

`modelbase` supports both elasticities (arbitrary state) and response coefficients (steady-state) measurements from metabolic control analysis.  
They can all be found in the `modelbase2.mca` namespace.  

### Variable elasticities

Variable elasticities are the sensitivity of reactions to a small change in the concentration of a variable.  
They are **not** a steady-state measurement and can be calculated for any arbitrary state.  

Both the `concs` and `variables` arguments are optional.  
If `concs` is not supplied, the routine will use the initial conditions from the model.  
If `variables` is not supplied, the elasticities will be calculated for all variables.  

In [None]:
elas = mca.variable_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    variables=["GLC", "F6P"],
)

_ = plot.heatmap(elas)
plt.show()

### Parameter elasticities

Parameter elasticities are the sensitivity of reactions to a small change in the value of a parameter.  
They are **not** a steady-state measurement and can be calculated for any arbitrary state.  

Both the `concs` and `parameters` arguments are optional.  
If `concs` is not supplied, the routine will use the initial conditions from the model.  
If `parameters` is not supplied, the elasticities will be calculated for all parameters.  

In [None]:
elas = mca.parameter_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    parameters=["k1", "k2"],
)

_ = plot.heatmap(elas)
plt.show()

### Response coefficients

Response coefficients show the sensitivity of variables and reactions **at steady-state** to a small change in a parameter.  

If the parameter is proportional to the rate, they are also called **control coefficients**.

In [None]:
crcs, frcs = mca.response_coefficients(
    get_upper_glycolysis(),
    parameters=["k1", "k2", "k3", "k4", "k5", "k6", "k7"],
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
_ = plot.heatmap(crcs, ax=ax1)
_ = plot.heatmap(frcs, ax=ax2)
plt.show()

## Fitting

Models can be fitted to both steady-state and time course data.  
You can both use concentrations and fluxes for this.  
Let's create some data first to fit our model to:

In [None]:
model_fn = get_linear_chain_2v
p_true = {"k1": 1.0, "k2": 2.0, "k3": 1.0}
res = unwrap(
    Simulator(model_fn())
    .update_parameters(p_true)
    .simulate(time_points=np.linspace(0, 10, 101))
    .get_results()
)

fig, ax = plot.lines(res)
ax.set(xlabel="time / a.u.", ylabel="Conc. & Flux / a.u.")
plt.show()

### Steady-states

For the steady-state fit we need one `pandas.Series` as an input.  
The fitting routine will compare all data contained in that series to the model output.  

In [None]:
fit.steady_state(
    model_fn(),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=res.iloc[-1],
)

If only some of the data is required, you can use a subset of it.  
The fitting routine will only try to fit concentrations and fluxes contained in that series.

In [None]:
fit.steady_state(
    model_fn(),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=res.iloc[-1].loc[["x", "y"]],
)

### Time course

For the time course fit we need a `pandas.DataFrame` as an input.  
Other than that, the same rules of the steady-state fitting apply.  

In [None]:
fit.time_course(
    model_fn(),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=res,
)

### Customisation

You can use dependency injection to overwrite the minimisation function as well as the residual function and the integrator.  
Here we create a custom minimization function.  

In [None]:
def nealder_mead(
    residual_fn: ResidualFn,
    p0: dict[str, float],
) -> dict[str, float]:
    res = minimize(
        residual_fn,
        x0=list(p0.values()),
        method="Nelder-Mead",
    )
    if res.success:
        return dict(
            zip(
                p0,
                res.x,
                strict=True,
            )
        )
    return dict(zip(p0, np.full(len(p0), np.nan, dtype=float), strict=True))


fit.time_course(
    model_fn(),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=res,
    minimize_fn=nealder_mead,
)

## Monte-carlo scans

Most model parameters in the natural sciences are typically best described using **distributions**.  
Thus, to get the distribution of realistic behaviour, `modelbase` supplies monte-carlo methods of all other analyses.

For that, you supply a `pandas.DataFrame` of parameter values randomly drawn from different distributions.  
You can use the `sample` function and distributions supplied by modelbase.  

In [None]:
sample(
    {
        "k2": Uniform(1.0, 2.0),
        "k3": LogNormal(mean=1.0, sigma=1.0),
    },
    n=5,
)

If you want to create custom distributions, all you need to do is create a class that follows the `Distribution` protocol, i.e. implements a `sample` method.  

```python
class MyOwnDistribution:
    def sample(self, num: int) -> Array:
        ...  # implement here
```

and it can be used in the `sample` function as well. 

### monte-carlo steady-states

Using `mc.steady_state` you can calculate the steady-state distribution given the monte-carlo parameters.  

In [None]:
ss = mc.steady_state(
    get_linear_chain_2v(),
    mc_parameters=sample(
        {
            "k1": Uniform(0.9, 1.1),
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4), sharex=False)
plot.violins(ss.concs, ax=ax1)
plot.violins(ss.fluxes, ax=ax2)
ax1.set(xlabel="Variables", ylabel="Concentration / a.u.")
ax2.set(xlabel="Reactions", ylabel="Flux / a.u.")
plt.show()

### monte-carlo time series

Using `mc.time_course` you can calculate time courses for sampled parameters.  

The `pandas.DataFrame`s for concentrations and fluxes have a `n x time` `pandas.MultiIndex`.  
The corresponding parameters can be found in `tc.parameters`.

In [None]:
tc = mc.time_course(
    get_linear_chain_2v(),
    time_points=np.linspace(0, 1, 11),
    mc_parameters=sample(
        {
            "k1": Uniform(0.9, 1.1),
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tc.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(tc.fluxes, ax=ax2)
ax1.set(xlabel="Time / a.u", ylabel="Concentration / a.u.")
ax2.set(xlabel="Time / a.u", ylabel="Flux / a.u.")
plt.show()

### mc protocol


Using `mc.time_course_over_protocol` you can calculate time courses for sampled parameters given a discrete protocol.  

The `pandas.DataFrame`s for concentrations and fluxes have a `n x time` `pandas.MultiIndex`.  
The corresponding parameters can be found in `tc.parameters`.

In [None]:
tc = mc.time_course_over_protocol(
    get_linear_chain_2v(),
    time_points_per_step=10,
    protocol=make_protocol(
        {
            1: {"k1": 1},
            2: {"k1": 2},
            3: {"k1": 1},
        }
    ),
    mc_parameters=sample(
        {
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tc.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(tc.fluxes, ax=ax2)
for ax in (ax1, ax2):
    plot.shade_protocol(tc.protocol["k1"], ax=ax, alpha=0.1)

ax1.set(xlabel="Time / a.u", ylabel="Concentration / a.u.")
ax2.set(xlabel="Time / a.u", ylabel="Flux / a.u.")
plt.show()

### mc metabolic control analysis

#### Variable elasticities

The returned `pandas.DataFrame` has a `pd.MultiIndex` of shape `n x reaction`.  

In [None]:
mc_elas = mc.variable_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    variables=["GLC", "F6P"],
    mc_parameters=sample(
        {
            # "k1": LogNormal(mean=np.log(0.25), sigma=1.0),
            # "k2": LogNormal(mean=np.log(1.0), sigma=1.0),
            "k3": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k4": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k5": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k6": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k7": LogNormal(mean=np.log(2.5), sigma=1.0),
        },
        n=5,
    ),
)

_ = plot.violins_from_2d_idx(mc_elas)
plt.show()

#### Parameter elasticities

In [None]:
elas = mc.parameter_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    parameters=["k1", "k2", "k3"],
    mc_parameters=sample(
        {
            "k3": LogNormal(mean=np.log(0.25), sigma=1.0),
        },
        n=5,
    ),
)

_ = plot.violins_from_2d_idx(elas)
plt.show()

#### Response coefficients

In [None]:
# Compare with "normal" control coefficients
rc = mca.response_coefficients(
    get_lin_chain_two_circles(),
    parameters=["vmax_1", "vmax_2", "vmax_3", "vmax_5", "vmax_6"],
)
_ = plot.heatmap(rc.concs)

mrc = mc.response_coefficients(
    get_lin_chain_two_circles(),
    parameters=["vmax_1", "vmax_2", "vmax_3", "vmax_5", "vmax_6"],
    mc_parameters=sample(
        {
            "k0": LogNormal(np.log(1.0), 1.0),
            "k4": LogNormal(np.log(0.5), 1.0),
        },
        n=10,
    ),
)

_ = plot.violins_from_2d_idx(mrc.concs, n_cols=len(mrc.concs.columns))

### mc steady-state parameter scan

Vary **both** monte carlo parameters as well as systematically scan for other parameters

In [None]:
mcss = mc.scan_steady_state(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(0, 1, 3)}),
    mc_parameters=sample(
        {
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

plot.violins_from_2d_idx(mcss.concs)
plt.show()

In [None]:
# FIXME: no idea how to plot this yet. Ridge plots?
# Maybe it's just a bit much :D

mcss = mc.scan_steady_state(
    get_linear_chain_2v(),
    parameters=mb2.cartesian_product(
        {
            "k1": np.linspace(0, 1, 3),
            "k2": np.linspace(0, 1, 3),
        }
    ),
    mc_parameters=sample(
        {
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

mcss.concs.head()

## Steady-state surrogates

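**Surrogates** allow replacing part of a system of ODEs with a surrogate model, e.g. a trained neural network.  
**Steady-state** surrogates are trained to predict the steady-state fluxes of the replaced part of the system from its inputs.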

In [None]:
# Example plot of this models behaviour
_ = unwrap(Simulator(get_example1()).simulate(10).get_fluxes()).plot()

### Create data

In [None]:
features = pd.DataFrame(
    it.product(
        np.linspace(0, 2.0, 21),
        np.linspace(0, 2.0, 21),
        np.linspace(0, 2.0, 21),
    ),
    columns=["x1", "ATP", "NADPH"],
)

targets = (
    scan.steady_state(
        get_example1(),
        parameters=features,
        cache=Cache(Path(".cache") / "linear"),
    )
    .fluxes.loc[:, ["x2_out", "x3_out"]]
    .fillna(0)
)


fig, ax = plot.violins(targets)
ax.set(title="Targets", ylabel="Flux / a.u.")

### Train Surrogate

In [None]:
surrogate, loss = train_torch_surrogate(
    features=features,
    targets=targets,
    epochs=2000,
    surrogate_inputs=["x1", "ATP", "NADPH"],
    surrogate_stoichiometries={
        "v2": {"x1": -1, "x2": 1, "ATP": -1},
        "v3": {"x1": -1, "x3": 1, "NADPH": -1},
    },
)

loss.plot()

### Get predictions of rates

In [None]:
# FIXME: why is the prediction for 0, 0, 0 not 0?
print(surrogate.predict(np.array([0.0, 0.0, 0.0])))
print(surrogate.predict(np.array([1.0, 0.0, 0.0])))
print(surrogate.predict(np.array([0.0, 1.0, 0.0])))
print(surrogate.predict(np.array([0.0, 0.0, 1.0])))
print(surrogate.predict(np.array([1.0, 0.0, 1.0])))
print(surrogate.predict(np.array([1.0, 1.0, 1.0])))

### Insert surrogate into model

In [None]:
def get_tpi_model() -> Model:
    """Build a model whose only reactions are supplied by ``surrogate``.

    Reads the module-level ``surrogate`` returned by ``train_torch_surrogate``.
    """
    initial_conditions = {
        "x1": 1.0,
        "x2": 0.0,
        "x3": 0.0,
        "ATP": 2.0,
        "NADPH": 0.1,
    }
    m = Model()
    m.add_variables(initial_conditions)

    # The surrogate provides every reaction here. Ordinary reactions could be
    # added next to it, but none are needed for this example.
    m.add_surrogate("surrogate", surrogate)
    return m


# FIXME: NADPH gets negative during this simulation.
# At least the rates seem to drop to 0 around the time when x1 reaches 0.
c, v = unwrap2(Simulator(get_tpi_model()).simulate(1.0).get_full_concs_and_fluxes())

# Concentrations on the left axis, fluxes on the right
fig, (ax1, ax2) = plot.two_axes(figsize=(10, 4))
plot.lines(c, ax=ax1)
plot.lines(v, ax=ax2)
ax1.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="flux / a.u.")
plt.show()

## Neural posterior estimation

### Create data

In [None]:
# Example plot of this models behaviour
if (fluxes := Simulator(get_example1()).simulate(10).get_fluxes()) is not None:
    fluxes.plot()
    plt.show()

In [None]:
targets = sample(
    {
        "x1": LogNormal(mean=1.0, sigma=0.3),
        "ATP": LogNormal(mean=0.7, sigma=0.1),
        "NADPH": LogNormal(mean=0.3, sigma=0.2),
    },
    n=10_000,
)

time_points = np.linspace(0, 10, 11)

ss_data = scan.steady_state(
    get_example1(),
    parameters=targets,
    cache=Cache(Path(".cache") / "npe-ss"),
).results


ts_data = scan.time_course(
    get_example1(),
    parameters=targets,
    time_points=time_points,
    cache=Cache(Path(".cache") / "npe-ts"),
).results

### Train NPE on steady-state data


In [None]:
features = ss_data.loc[:, ["x2", "x3"]]

estimator, losses = npe.train_torch_ss_estimator(
    features=features,
    targets=targets,
    epochs=2_500,
)

ax = losses.plot()
ax.set_ylim(0, None)

In [None]:
fig, (ax1, ax2) = plt.subplots(
    1,
    2,
    figsize=(8, 3),
    layout="constrained",
    sharex=True,
    sharey=True,
)

ax = sns.kdeplot(targets, fill=True, ax=ax1)
ax.set_title("Prior")

posterior = estimator.predict(features)

ax = sns.kdeplot(posterior, fill=True, ax=ax2)
ax.set_title("Posterior")
plt.show()

### Train NPE on time series data

In [None]:
features = ts_data.loc[:, ["x2", "x3"]]


estimator, losses = npe.train_torch_time_course_estimator(
    features=features,
    targets=targets,
    epochs=2_500,
)

losses.plot()

In [None]:
fig, (ax1, ax2) = plt.subplots(
    1,
    2,
    figsize=(8, 3),
    layout="constrained",
    sharex=True,
    sharey=True,
)

ax = sns.kdeplot(targets, fill=True, ax=ax1)
ax.set_title("Prior")

posterior = estimator.predict(features)

ax = sns.kdeplot(posterior, fill=True, ax=ax2)
ax.set_title("Posterior")
plt.show()

## Label models

In [None]:
def mass_action_1(kf: float, s: float) -> float:
    """Return the first-order mass-action rate ``kf * s``.

    Args:
        kf: rate constant.
        s: substrate concentration.
    """
    rate = kf * s
    return rate


def mass_action_2(kf: float, s1: float, s2: float) -> float:
    """Return the second-order mass-action rate ``kf * s1 * s2``.

    Args:
        kf: rate constant.
        s1: concentration of the first substrate.
        s2: concentration of the second substrate.
    """
    partial = kf * s1
    return partial * s2


def get_tpi_model() -> Model:
    """Build a mass-action model of triose phosphate isomerase (TPI) and aldolase.

    TPI interconverts GAP and DHAP; aldolase condenses DHAP and GAP into FBP.
    Forward and reverse directions are separate reactions, and each reverse
    rate constant is derived as ``kf / Keq``. The initial concentrations are
    chosen so that both reaction pairs start at equilibrium.

    Returns:
        Model with variables GAP, DHAP, FBP and reactions TPIf, TPIr, ALDf, ALDr.
    """
    kf_tpi, keq_tpi = 1.0, 21.0
    kf_ald, keq_ald = 2000.0, 7000.0
    parameters = {
        "kf_TPI": kf_tpi,
        "Keq_TPI": keq_tpi,
        "kf_Ald": kf_ald,
        "Keq_Ald": keq_ald,
        # Reverse constants follow from Keq = kf / kr
        "kr_TPI": kf_tpi / keq_tpi,
        "kr_Ald": kf_ald / keq_ald,
    }

    # Equilibrium start: DHAP / GAP = Keq_TPI and FBP / (GAP * DHAP) = Keq_Ald,
    # so forward and reverse fluxes of each pair are initially balanced
    gap0 = 2.5e-5
    dhap0 = gap0 * keq_tpi
    initial_concs = {
        "GAP": gap0,
        "DHAP": dhap0,
        "FBP": gap0 * dhap0 * keq_ald,
    }

    # (name, rate function, rate-function args, stoichiometry) per half-reaction
    reactions = (
        ("TPIf", mass_action_1, ["kf_TPI", "GAP"], {"GAP": -1, "DHAP": 1}),
        ("TPIr", mass_action_1, ["kr_TPI", "DHAP"], {"DHAP": -1, "GAP": 1}),
        (
            "ALDf",
            mass_action_2,
            ["kf_Ald", "DHAP", "GAP"],
            {"DHAP": -1, "GAP": -1, "FBP": 1},
        ),
        (
            "ALDr",
            mass_action_1,
            ["kr_Ald", "FBP"],
            {"FBP": -1, "DHAP": 1, "GAP": 1},
        ),
    )

    model = Model().add_variables(initial_concs).add_parameters(parameters)
    for name, rate_fn, rate_args, stoich in reactions:
        model = model.add_reaction(
            name,
            rate_fn,
            args=rate_args,
            stoichiometry=stoich,
        )
    return model

### Label mapper

Labelled models explicitly track the transitions between the isotopomers of each variable.  
*modelbase* includes a `LabelMapper` that takes

- a model
- a dictionary mapping each variable to its number of label positions
- a dictionary mapping each reaction to its label transition map

and auto-generates all `2^n` isotopomers of every variable with `n` label positions, together with the reactions that convert them into each other.  

As an example let's take triose phosphate isomerase, which catalyzes the interconversion of glyceraldehyde 3-phosphate (GAP) and dihydroxyacetone phosphate (DHAP).  
As illustrated below, in the case of the forward reaction the first and last carbon atoms are swapped

<img src="assets/carbon-maps.png" style="max-width: 250px">

So DHAP(1) is built from GAP(3), DHAP(2) from GAP(2) and DHAP(3) from GAP(1).  
We write this down using **0-based indexing**, listing for each product position the substrate position it originates from:

```python
label_maps = {"TPIf": [2, 1, 0]}
```

In [None]:
# Expand the TPI/aldolase model into all label variants:
# GAP and DHAP carry 3 label positions each, FBP carries 6
mapper = LabelMapper(
    get_tpi_model(),
    label_variables={"GAP": 3, "DHAP": 3, "FBP": 6},
    label_maps={
        # TPI swaps the first and last position, in both directions
        "TPIf": [2, 1, 0],
        "TPIr": [2, 1, 0],
        # Aldolase keeps the DHAP + GAP positions in order when forming FBP
        "ALDf": [0, 1, 2, 3, 4, 5],
        "ALDr": [0, 1, 2, 3, 4, 5],
    },
)

# Simulate the labelled model; initial_labels={"GAP": 0} presumably places
# the label on position 0 of GAP (confirm against LabelMapper.build_model)
concs = (
    Simulator(mapper.build_model(initial_labels={"GAP": 0}))
    .simulate(20)
    .get_full_concs()
)
if concs is not None:
    plot.relative_label_distribution(mapper, concs, n_cols=3)

### Linear label mapper


The `LabelMapper` makes no assumptions about the state of the model, so it has to simulate every isotopomer explicitly, which quickly becomes complex.  
At steady state, however, concentrations and fluxes are constant, so the labelling dynamics can be described by a set of linear differential equations.  
See [Sokol and Portais 2015](https://doi.org/10.1371/journal.pone.0144652) for the theory of dynamic label propagation under the stationary assumption.


In [None]:
m = get_tpi_model()

# Simulate the unlabelled model first. Its last time point supplies the
# (assumed steady-state) concentrations and fluxes the linear mapper needs.
concs, fluxes = unwrap2(Simulator(m).simulate(20).get_concs_and_fluxes())

mapper = LinearLabelMapper(
    m,
    label_variables={"GAP": 3, "DHAP": 3, "FBP": 6},
    label_maps={
        # TPI swaps the first and last position, in both directions
        "TPIf": [2, 1, 0],
        "TPIr": [2, 1, 0],
        # Aldolase keeps the DHAP + GAP positions in order when forming FBP
        "ALDf": [0, 1, 2, 3, 4, 5],
        "ALDr": [0, 1, 2, 3, 4, 5],
    },
)

# NOTE: `concs` is deliberately rebound to the label concentrations here; the
# steady-state values are read (via .iloc[-1]) before the rebinding happens.
concs = (
    Simulator(
        mapper.build_model(
            concs=concs.iloc[-1],
            fluxes=fluxes.iloc[-1],
            initial_labels={"GAP": 0},
        )
    )
    .simulate(20)
    .get_full_concs()
)
if concs is not None:
    plot.relative_label_distribution(mapper, concs, n_cols=3)

## SBML

`modelbase` supports reading and writing **SBML** models using the `sbml.read` and `sbml.write` functions.

In [None]:
# Load an SBML model shipped with the docs, simulate it and plot the
# concentration time courses
model = sbml.read(Path("assets", "00001-sbml-l3v2.xml"))
c, v = unwrap2(Simulator(model).simulate(10).get_concs_and_fluxes())
_ = plot.lines(c)

When exporting a model, you can supply additional meta-information such as units and compartmentalisation.  
See the [official SBML documentation](https://sbml.org/documents/) for more information on legal values.

In [None]:
# Export a model to SBML together with unit meta-information.
# NOTE(review): `linear_chain_2cpds` is neither imported nor defined in the
# visible part of this notebook; confirm an earlier cell defines it.
# NOTE(review): this assumes the `.cache` directory already exists; confirm
# that `sbml.write` (or the earlier `Cache` objects) create it.
sbml.write(
    linear_chain_2cpds(),
    file=Path(".cache", "model.xml"),
    substance_units="mole",
    extent_units="mole",
    time_units="second",
)