# modelbase 2.0

In [None]:
from __future__ import annotations

import itertools as it
from pathlib import Path
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from example_models import (
    get_example1,
    get_linear_chain_2v,
    get_poolman2000,
    get_upper_glycolysis,
)

import modelbase2 as mb2
from modelbase2 import (
    Cache,
    Derived,
    LabelMapper,
    LinearLabelMapper,
    Model,
    Simulator,
    fit,
    mc,
    mca,
    npe,
    plot,
    scans,
)
from modelbase2.distributions import LogNormal, Uniform, sample
from modelbase2.surrogates import create_ss_flux_data, train_torch_surrogate
from modelbase2.types import unwrap, unwrap2

if TYPE_CHECKING:
    from modelbase2.types import Array


def create_steady_state_data(
    model: Model,
    parameters: pd.DataFrame,
    cache: Cache | None,
    y0: dict[str, float] | None = None,
) -> pd.DataFrame:
    """Run a steady-state parameter scan and return it as one flat table.

    Args:
        model: Model handed to the scan.
        parameters: Parameter table handed to the scan.
        cache: Cache handed to the scan, or ``None``.
        y0: Initial values handed to the scan, or ``None``.

    Returns:
        The scan result concatenated column-wise, with its index replaced by
        a fresh ``0..n-1`` range.
    """
    scan = scans.parameter_scan_ss(
        model=model,
        parameters=parameters,
        y0=y0,
        cache=cache,
    )
    # The scan result is iterable, so it can be fed to pd.concat directly.
    combined = pd.concat(scan, axis=1)
    return combined.reset_index(drop=True)


def create_time_series_data(
    model: Model,
    parameters: pd.DataFrame,
    time_points: Array,
    cache: Cache | None,
    y0: dict[str, float] | None = None,
) -> pd.DataFrame:
    """Run a time-series parameter scan and return its combined results.

    Args:
        model: Model handed to the scan.
        parameters: Parameter table handed to the scan.
        time_points: Time points handed to the scan.
        cache: Cache handed to the scan, or ``None``.
        y0: Initial values handed to the scan, or ``None``.

    Returns:
        The ``results`` attribute of the scan object.
    """
    scan = scans.parameter_scan_time_series(
        model=model,
        parameters=parameters,
        time_points=time_points,
        y0=y0,
        cache=cache,
    )
    return scan.results


def randomise_parameters(m: Model, seed: int = 42) -> Model:
    """Perturb every parameter of ``m`` in place with Gaussian noise.

    One value is drawn per parameter from a normal distribution with mean 0
    and standard deviation 0.125 and added to the current parameter value.

    Args:
        m: Model whose parameters are updated in place.
        seed: Seed for the random number generator.

    Returns:
        The same model object that was passed in.
    """
    rng = np.random.default_rng(seed=seed)
    current = pd.Series(m.parameters)
    noise = rng.normal(0, 0.125, len(m.parameters))
    m.update_parameters((current + noise).to_dict())
    return m


def make_protocol(steps: dict[float, dict[str, float]]) -> pd.DataFrame:
    """Turn a ``{duration_in_seconds: parameter_values}`` mapping into a table.

    Durations are accumulated in iteration order, so each row is labelled with
    the elapsed time at which its step ends.

    Args:
        steps: Maps the duration of each step (seconds) to the parameter
            values that apply during that step.

    Returns:
        DataFrame with one row per step (indexed by cumulative
        ``pd.Timedelta``) and one column per parameter name.
    """
    durations = (pd.Timedelta(seconds=seconds) for seconds in steps)
    end_times = it.accumulate(durations)
    return pd.DataFrame(dict(zip(end_times, steps.values()))).T


## Building your first model

Let's say you want to model the following chemical network

$$ \Large \varnothing \xrightarrow{v_0} S \xrightarrow{v_1} P \xrightarrow{v_2} \varnothing $$

which translates into

$$\begin{align*}
\frac{dS}{dt} &= v_0 - v_1     \\
\frac{dP}{dt} &= v_1 - v_2 \\
\end{align*}
$$

We then choose rate equations for each rate to get the flux vector $v$

$$\begin{align*}
    v_0 &= k_0 \\
    v_1 &= k_1 * S \\
    v_2 &= k_2 * P \\
\end{align*}$$

<!-- $$v = \left\{ 
    \begin{align*}
    & k_0 \\
    & k_1 * S \\
    & k_2 * P \\
    \end{align*} 
\right.$$ -->

Then the system of ODEs is given by 


$$\begin{align*}
\frac{dS}{dt} &= k_0 - k_1 * S     \\
\frac{dP}{dt} &= k_1 * S - k_2 * P \\
\end{align*}$$

Let's begin by defining rate functions.  
Note that these should be **general** and **re-usable** whenever possible, to make your model clear to people reading it.  
Try to give these functions names that are meaningful to your audience, e.g. a rate function `k * s` could be named **proportional** or **mass-action**.

In [None]:
def constant(k: float) -> float:
    """Rate law that does not depend on anything but ``k``: returns ``k``."""
    return k


def proportional(k: float, s: float) -> float:
    """Rate law proportional to ``s`` with factor ``k``: returns ``k * s``."""
    rate = k * s
    return rate

Next, we create our model.  
Note, that we use a single function that returns the model instead of defining it globally.  
This allows us to quickly re-create the model whenever we need a fresh version of it.  

Let's step through the code below:

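We first add parameters to the model using `.add_parameters({name: value})`; in the code below `k_in`, `k_1` and `k_out` correspond to $k_0$, $k_1$ and $k_2$ in the equations above.  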

Next, we add variables using `.add_variables({name: initial_value})`.  

Finally, we add the three reactions by using 

```python
.add_reaction(
    name,              # the internal name for the reaction
    fn=...,            # a python function to be evaluated
    args=[name, ...],  # the arguments passed to the python function
    stoichiometry={    
        name: value    
    },                 # a mapping encoding how much the variable `name`
                       # is changed by the reaction
)
```

In [None]:
def linear_chain_2cpds() -> Model:
    """Build a fresh two-variable linear chain model ``-> S -> P ->``.

    Rates are ``v0 = k_in``, ``v1 = k_1 * S`` and ``v2 = k_out * P``; all
    three parameters are set to 1 and both variables start at 0.

    Returns:
        A newly constructed model on every call.
    """
    model = Model()
    model = model.add_parameters({"k_in": 1, "k_1": 1, "k_out": 1})
    model = model.add_variables({"S": 0, "P": 0})

    # Influx of S
    model = model.add_reaction(
        "v0",
        fn=constant,
        stoichiometry={"S": 1},
        args=["k_in"],
    )
    # Conversion of S into P
    model = model.add_reaction(
        "v1",
        fn=proportional,
        stoichiometry={"S": -1, "P": 1},
        args=["k_1", "S"],
    )
    # Efflux of P
    model = model.add_reaction(
        "v2",
        fn=proportional,
        stoichiometry={"P": -1},
        args=["k_out", "P"],
    )
    return model

We can then simulate the model by passing it to a `Simulator` and simulate a time series using `.simulate(t_end)`.  
Finally, we can obtain the concentrations and fluxes using `get_concs_and_fluxes`.  

> Note that `get_concs_and_fluxes` returns `tuple[pd.DataFrame | None, pd.DataFrame | None]`, indicating that it will return `None` in case the simulation fails.  
> Thus, it is good practice to always check for failure.  

While you can directly plot the `pd.DataFrame`s, modelbase supplies a variety of plots in the `plot` namespace that are worth checking out.  

In [None]:
# Simulate the model and split the result into concentrations and fluxes.
# unwrap2 is presumably the failure check for the (concs, fluxes) pair —
# TODO confirm against modelbase2.types.
c, v = unwrap2(
    Simulator(linear_chain_2cpds())  # initialise the simulator
    .simulate(10)  # simulate until t_end = 10
    .get_concs_and_fluxes()  # return pd.DataFrames for concentrations and fluxes
)

# Concentrations go on the first axis, fluxes on the second.
fig, (ax1, ax2) = plot.two_axes(figsize=(6, 3))
_ = plot.lines(c, ax=ax1)
_ = plot.lines(v, ax=ax2)

By default, the `Simulator` is initialised with the initial concentrations set in the `Model`.  
Optionally, you can overwrite the initial conditions using the `y0` argument.  

```python
Simulator(model, y0={name: value, ...})
```

## Derived quantities

Frequently it makes sense to derive one quantity in a model from other quantities.  
In `modelbase2` this is done by using the `Derived` class.
This can be done for

- parameters derived from other parameters
- variables derived from parameters or other variables
- stoichiometries derived from parameters or variables 

In [None]:
def moiety_1(x1: float, total: float) -> float:
    """Return the part of ``total`` that is not ``x1``, i.e. ``total - x1``."""
    remainder = total - x1
    return remainder


# ATP is the only variable; ADP is derived from it via `moiety_1`, using the
# parameter ATP_total.
m = (
    Model()
    .add_variables({"ATP": 1.0})
    .add_parameters({"ATP_total": 1.0})
    .add_derived("ADP", moiety_1, ["ATP", "ATP_total"])
    # NOTE(review): `constant` only returns its argument, so passing the
    # variable "ATP" makes this rate equal to the current ATP value.
    .add_reaction("ATPase", constant, {"ATP": -1}, ["ATP"])
)

# The "full" getter is used here; presumably it also contains derived
# quantities such as ADP — TODO confirm.
c, v = Simulator(m).simulate(10).get_full_concs_and_fluxes()
# Only plot when the simulation returned concentrations.
if c is not None:
    plot.lines(c)

## Simulation

### Steady-state

You can simulate until the model reaches a steady-state using the `simulate_to_steady_state` method.  


In [None]:
# Integrate until a steady state is reached, then fetch both result tables.
concs, fluxes = (
    Simulator(get_linear_chain_2v())  # optionally supply initial conditions
    .simulate_to_steady_state()
    .get_concs_and_fluxes()
)
# Bare last expression: shown as the cell output in the notebook.
# It may be None if the simulation failed.
concs

### Time series

You can obtain the time series of integration using the `simulate` method.  
There are three ways to define the time points this function returns.  

1. supply the end time `t_end`
2. supply both end time and number of steps with `steps`
3. supply the exact time points to be returned using `time_points`

```python
simulate(t_end=10)
simulate(t_end=10, steps=10)
simulate(time_points=np.linspace(0, 10, 11))
```

> Note that these settings don't change the integration itself (e.g. tolerances)!

In [None]:
# Integrate up to t_end = 10 and fetch both result tables.
concs, fluxes = (
    Simulator(get_linear_chain_2v())  # optionally supply initial conditions
    .simulate(t_end=10)
    .get_concs_and_fluxes()
)

# Only plot when the simulation returned concentrations.
if concs is not None:
    _ = plot.lines(concs)
    plt.show()

### Protocol time series

Protocols are used to make parameter changes discrete in time.  
This is useful e.g. for reproducing experimental time courses where a parameter was changed at fixed time points.  

In [None]:
# Keys are step durations in seconds; make_protocol accumulates them, so the
# three steps end at 1 s, 3 s and 6 s.
protocol = make_protocol(
    {
        1: {"k1": 1},
        2: {"k1": 2},
        3: {"k1": 1},
    }
)
# Simulate across all protocol steps and fetch both result tables.
concs, fluxes = (
    Simulator(get_linear_chain_2v())
    .simulate_over_protocol(protocol)
    .get_concs_and_fluxes()
)

# Only plot when the simulation returned concentrations.
if concs is not None:
    fig, ax = plt.subplots()
    plot.lines(concs, ax=ax)
    # Shade the plot according to the k1 column of the protocol.
    plot.shade_protocol(protocol["k1"], ax=ax, alpha=0.1)


## Parameter scans

`modelbase` has a variety of parameter scans available.  
They differ by the kind of data returned by them, e.g. steady-states or time series of concentration and fluxes.

In all cases, the parameters which you want to scan over are passed as a `pandas.DataFrame`.

### Steady-state

In [None]:
# Steady-state scan over 11 evenly spaced values of k1 between 1 and 2.
scan = mb2.parameter_scan_ss(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(1, 2, 11)}),
)

# Concentrations go on the first axis, fluxes on the second.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 3))
plot.lines(scan.concs, ax=ax1)
plot.lines(scan.fluxes, ax=ax2)
plt.show()

All scans return some kind of result object, which allows multiple access patterns for convenience.

Namely, the concentrations and fluxes can be accessed by name, unpacked or combined into a single dataframe.

In [None]:
# `scan` is the result object bound by the previous scan cell.
# Assignments to `_` only demonstrate the access; the values are discarded.

# Access by name
_ = scan.concs
_ = scan.fluxes

# scan can be unpacked
concs, fluxes = scan

# combine concs and fluxes as single dataframe
_ = scan.results


#### Combinations

Often you want to scan over multiple parameters at the same time.  
The recommended way to do this is to use the `cartesian_product` function, which takes a `parameter_name: values` mapping and creates a `pandas.DataFrame` of their combinations from it (think nested for loop).  

In the case of more than one parameter, the returned `pandas.DataFrame` contains a `pandas.MultiIndex`.  

In [None]:
# Scan over every combination of 3 values of k1 and 4 values of k2.
scan = mb2.parameter_scan_ss(
    get_linear_chain_2v(),
    mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 3),
            "k2": np.linspace(1, 2, 4),
        }
    ),
)

# Bare last expression: the first rows are shown as the cell output.
scan.results.head()


You can plot the results of this scan using a heatmap

In [None]:
# Heatmap of the single variable "x" from the two-parameter scan.
plot.heatmap_from_2d_idx(scan.concs, variable="x")
plt.show()

Or create heatmaps of all passed variables at once.  

In [None]:
# Heatmaps for the whole concentration table of the two-parameter scan.
plot.heatmaps_from_2d_idx(scan.concs)
plt.show()

You can also combine more than two parameters, however, visualisation becomes challenging.  

In [None]:
# Scan over every combination of 3 values of k1, 4 of k2 and 4 of k3.
scan = mb2.parameter_scan_ss(
    get_linear_chain_2v(),
    mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 3),
            "k2": np.linspace(1, 2, 4),
            "k3": np.linspace(1, 2, 4),
        }
    ),
)
# Bare last expression: the first rows are shown as the cell output.
scan.results.head()

### Time-series

You can perform a time series for each of the parameter values.  
The index now also contains the time, so even for one parameter a `pandas.MultiIndex` is used.

In [None]:
# One time series (11 time points between 0 and 1) for each of the
# 11 values of k1.
tss = mb2.parameter_scan_time_series(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(1, 2, 11)}),
    time_points=np.linspace(0, 1, 11),
)

# Concentrations go on the first axis, fluxes on the second.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tss.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(tss.fluxes, ax=ax2)
plt.show()

Again, this works for an arbitrary number of parameters.

In [None]:
# One time series for every combination of 11 values of k1 and 4 of k2.
tss = mb2.parameter_scan_time_series(
    get_linear_chain_2v(),
    parameters=mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 11),
            "k2": np.linspace(1, 2, 4),
        }
    ),
    time_points=np.linspace(0, 1, 11),
)


# Concentrations go on the first axis, fluxes on the second.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tss.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(tss.fluxes, ax=ax2)
plt.show()

The scan object returned has a `pandas.MultiIndex` of `n x time`, where `n` is an index that references parameter combinations.  
You can access the referenced parameters using `.parameters`

In [None]:
# First rows of the parameter table stored on the scan object.
tss.parameters.head()

You can also easily access common aggregates using `get_agg_per_time`.  

In [None]:
# First rows of the "std" aggregate computed per time point.
tss.get_agg_per_time("std").head()

### Protocol

The same can be done for protocols.  

In [None]:
# Scan 11 values of k2 while k1 follows the protocol. The protocol keys are
# step durations in seconds, so the steps end at 1 s, 3 s and 6 s.
scan = mb2.parameter_scan_protocol(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k2": np.linspace(1, 2, 11)}),
    protocol=make_protocol(
        {
            1: {"k1": 1},
            2: {"k1": 2},
            3: {"k1": 1},
        }
    ),
)

# Concentrations go on the first axis, fluxes on the second.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(scan.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(scan.fluxes, ax=ax2)

# Shade both axes according to the k1 column of the protocol.
for ax in (ax1, ax2):
    plot.shade_protocol(scan.protocol["k1"], ax=ax, alpha=0.2)
plt.show()

## Metabolic control analysis

`modelbase` supports both elasticities (arbitrary state) and response coefficients (steady-state) measurements from metabolic control analysis.  
They can all be found in the `modelbase2.mca` namespace.  

### Variable elasticities

Variable elasticities are the sensitivity of reactions to a small change in the concentration of a variable.  
They are **not** a steady-state measurement and can be calculated for any arbitrary state.  

Both the `concs` and `variables` arguments are optional.  
If `concs` is not supplied, the routine will use the initial conditions from the model.  
If `variables` is not supplied, the elasticities will be calculated for all variables.  

In [None]:
# Elasticities with respect to GLC and F6P, evaluated at the explicitly
# supplied concentrations.
elas = mca.variable_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    variables=["GLC", "F6P"],
)

_ = plot.heatmap(elas)
plt.show()

### Parameter elasticities

Parameter elasticities are the sensitivity of reactions to a small change in the value of a parameter.  
They are **not** a steady-state measurement and can be calculated for any arbitrary state.  

Both the `concs` and `parameters` arguments are optional.  
If `concs` is not supplied, the routine will use the initial conditions from the model.  
If `parameters` is not supplied, the elasticities will be calculated for all parameters.  

In [None]:
# Elasticities with respect to the parameters k1 and k2, evaluated at the
# explicitly supplied concentrations.
elas = mca.parameter_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    parameters=["k1", "k2"],
)

_ = plot.heatmap(elas)
plt.show()

### Response coefficients

Response coefficients show the sensitivity of variables and reactions **at steady-state** to a small change in a parameter.  

If the parameter is proportional to the rate, they are also called **control coefficients**.

In [None]:
# Response coefficients for k1..k7; the result unpacks into two tables.
# Presumably crcs = concentration and frcs = flux response coefficients —
# TODO confirm against modelbase2.mca.
crcs, frcs = mca.response_coefficients(
    get_upper_glycolysis(),
    parameters=["k1", "k2", "k3", "k4", "k5", "k6", "k7"],
)

# One heatmap per table, side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
_ = plot.heatmap(crcs, ax=ax1)
_ = plot.heatmap(frcs, ax=ax2)
plt.show()

## Monte-carlo scans

Most model parameters in the natural sciences are typically best described using **distributions**.  
Thus, to get the distribution of realistic behaviour, `modelbase` supplies Monte Carlo versions of all the other analyses.

For that, you supply a `pandas.DataFrame` of parameter values randomly drawn from different distributions.  
You can use the `sample` function and distributions supplied by modelbase.  

In [None]:
# Draw 5 values each for k2 and k3 from the given distributions.
# Bare expression: the sampled table is shown as the cell output.
sample(
    {
        "k2": Uniform(1.0, 2.0),
        "k3": LogNormal(mean=1.0, sigma=1.0),
    },
    n=5,
)

If you want to create custom distributions, all you need to do is to create a class that follows the `Distribution` protocol, i.e. implements a `sample` method.  

```python
class MyOwnDistribution:
    def sample(self, num: int) -> Array:
        # implement here
        ...
```

and it can be used in the `sample` function as well. 

### monte-carlo steady-states

Using `mc.steady_state` you can calculate the steady-state distribution given the monte-carlo parameters.  

In [None]:
# Steady states for 10 sampled parameter sets of k1, k2 and k3.
scan = mc.steady_state(
    get_linear_chain_2v(),
    mc_parameters=sample(
        {
            "k1": Uniform(0.9, 1.1),
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Concentrations go on the first axis, fluxes on the second.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4), sharex=False)
plot.violins(scan.concs, ax=ax1)
plot.violins(scan.fluxes, ax=ax2)
plt.show()

### monte-carlo time series

Using `mc.time_course` you can calculate time courses for sampled parameters.  

The `pandas.DataFrame`s for concentrations and fluxes have a `n x time` `pandas.MultiIndex`.  
The corresponding parameters can be found in `scan.parameters`

In [None]:
# Time courses (11 time points between 0 and 1) for 10 sampled parameter
# sets of k1, k2 and k3.
scan = mc.time_course(
    get_linear_chain_2v(),
    time_points=np.linspace(0, 1, 11),
    mc_parameters=sample(
        {
            "k1": Uniform(0.9, 1.1),
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Concentrations go on the first axis, fluxes on the second.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(scan.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(scan.fluxes, ax=ax2)
plt.show()

### mc protocol


Using `mc.time_course_over_protocol` you can calculate time courses for sampled parameters given a discrete protocol.  

The `pandas.DataFrame`s for concentrations and fluxes have a `n x time` `pandas.MultiIndex`.  
The corresponding parameters can be found in `scan.parameters`

In [None]:
# k1 follows the protocol (steps end at 1 s, 3 s and 6 s) while k2 and k3
# are sampled 10 times.
scan = mc.time_course_over_protocol(
    get_linear_chain_2v(),
    time_points_per_step=10,
    protocol=make_protocol(
        {
            1: {"k1": 1},
            2: {"k1": 2},
            3: {"k1": 1},
        }
    ),
    mc_parameters=sample(
        {
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Concentrations go on the first axis, fluxes on the second.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(scan.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(scan.fluxes, ax=ax2)
# Shade both axes according to the k1 column of the protocol.
for ax in (ax1, ax2):
    plot.shade_protocol(scan.protocol["k1"], ax=ax, alpha=0.1)

plt.show()

### mc metabolic control analysis

#### Compound elasticities

The returned `pandas.DataFrame` has a `pd.MultiIndex` of shape `n x reaction`.  

In [None]:
# Elasticities with respect to GLC and F6P at fixed concentrations, repeated
# for 5 sampled values of k3. The other parameters are commented out, so
# they are not sampled.
elas = mc.compound_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    variables=["GLC", "F6P"],
    mc_parameters=sample(
        {
            # "k1": LogNormal(mean=np.log(0.25), sigma=1.0),
            # "k2": LogNormal(mean=np.log(1.0), sigma=1.0),
            "k3": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k4": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k5": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k6": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k7": LogNormal(mean=np.log(2.5), sigma=1.0),
        },
        n=5,
    ),
)

_ = plot.violins_from_2d_idx(elas)
plt.show()

#### Parameter elasticities

In [None]:
# Elasticities with respect to k1, k2 and k3 at fixed concentrations,
# repeated for 5 sampled values of k3.
elas = mc.parameter_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    parameters=["k1", "k2", "k3"],
    mc_parameters=sample(
        {
            "k3": LogNormal(mean=np.log(0.25), sigma=1.0),
        },
        n=5,
    ),
)

_ = plot.violins_from_2d_idx(elas)
plt.show()

#### Response coefficients

In [None]:
# Response coefficients for five Vmax parameters, repeated for 5 sampled
# parameter sets (log-normal, sigma 0.1, mean given as the log of a value).
resp = mc.response_coefficients(
    get_poolman2000(),
    parameters=["Vmax_1", "Vmax_6", "Vmax_9", "Vmax_13", "Vmax_16"],
    mc_parameters=sample(
        {
            "Vmax_1": LogNormal(np.log(2.72), sigma=0.1),
            "Vmax_6": LogNormal(np.log(1.6), sigma=0.1),
            "Vmax_9": LogNormal(np.log(0.32), sigma=0.1),
            "Vmax_13": LogNormal(np.log(8.0), sigma=0.1),
            "Vmax_16": LogNormal(np.log(2.8), sigma=0.1),
        },
        n=5,
    ),
)

# Only the PGA, GAP and SBP columns of the concentration results are plotted.
_ = plot.violins_from_2d_idx(resp.concs.loc[:, ["PGA", "GAP", "SBP"]])
plt.show()

### mc steady-state parameter scan

Vary **both** monte carlo parameters as well as systematically scan for other parameters

In [None]:
# Systematic scan over 3 values of k1, combined with 10 sampled parameter
# sets of k2 and k3.
mcss = mc.parameter_scan_ss(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(0, 1, 3)}),
    mc_parameters=sample(
        {
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

plot.violins_from_2d_idx(mcss.concs)
plt.show()

In [None]:
# FIXME: no idea how to plot this yet. Ridge plots?
# Maybe it's just a bit much :D

# Systematic scan over every combination of 3 values of k1 and 3 of k2,
# combined with 10 sampled values of k3.
mcss = mc.parameter_scan_ss(
    get_linear_chain_2v(),
    parameters=mb2.cartesian_product(
        {
            "k1": np.linspace(0, 1, 3),
            "k2": np.linspace(0, 1, 3),
        }
    ),
    mc_parameters=sample(
        {
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Bare last expression: the first rows are shown as the cell output.
mcss.concs.head()

## Fitting

In [None]:
# Generate reference data from the unmodified model.
m = get_linear_chain_2v()
c, v = (
    Simulator(m).simulate(time_points=np.linspace(0, 1, 11)).get_full_concs_and_fluxes()
)

# Fail loudly if the simulation failed: pd.concat silently drops None
# entries, which would leave the fits below with incomplete data.
if c is None or v is None:
    msg = "Simulation failed"
    raise ValueError(msg)

# One table with concentrations and fluxes side by side.
res = pd.concat((c, v), axis=1)
res.head()

# Fit against the last simulated time point, starting from a model whose
# parameters were perturbed by randomise_parameters.
fit.steady_state(
    randomise_parameters(get_linear_chain_2v()),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=res.iloc[-1].loc[["x", "v1"]],
)
# Fit against the whole time series of "x" and "v1".
fit.time_series(
    randomise_parameters(get_linear_chain_2v()),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=res.loc[:, ["x", "v1"]],
)

## Steady-state surrogates

In [None]:
# Example plot of this models behaviour
# Simulates until t_end = 10 and plots the fluxes; `unwrap` is presumably
# the failure check for the optional result — TODO confirm.
_ = unwrap(Simulator(get_example1()).simulate(10).get_fluxes()).plot()


### Create data

In [None]:
# Full grid of 21 x 21 x 21 points in [0, 2] for x1, ATP and NADPH.
features = pd.DataFrame(
    it.product(
        np.linspace(0, 2.0, 21),
        np.linspace(0, 2.0, 21),
        np.linspace(0, 2.0, 21),
    ),
    columns=["x1", "ATP", "NADPH"],
)

# Steady-state flux data for every grid point, cached under .cache/linear;
# only the x2_out and x3_out columns are kept as training targets.
targets = create_ss_flux_data(
    get_example1(),
    features,
    cache=Cache(Path(".cache") / "linear"),
).loc[:, ["x2_out", "x3_out"]]

### Train Surrogate

In [None]:
# Train the surrogate on the grid data for 2000 epochs. The stoichiometries
# define how the two surrogate outputs v2 and v3 act on the model variables.
surrogate, loss = train_torch_surrogate(
    features=features,
    targets=targets,
    epochs=2000,
    surrogate_inputs=["x1", "ATP", "NADPH"],
    surrogate_stoichiometries={
        "v2": {"x1": -1, "x2": 1, "ATP": -1},
        "v3": {"x1": -1, "x3": 1, "NADPH": -1},
    },
)

# Plot the returned training loss.
loss.plot()

### Get predictions of rates

In [None]:
# Probe the surrogate at a few hand-picked inputs.
# Presumably ordered as the surrogate inputs (x1, ATP, NADPH) — TODO confirm.
print(surrogate.predict(np.array([0.0, 0.0, 0.0])))
print(surrogate.predict(np.array([1.0, 0.0, 0.0])))
print(surrogate.predict(np.array([0.0, 1.0, 0.0])))
print(surrogate.predict(np.array([0.0, 0.0, 1.0])))
print(surrogate.predict(np.array([1.0, 0.0, 1.0])))
print(surrogate.predict(np.array([1.0, 1.0, 1.0])))

### Insert surrogate into model

In [None]:
def get_model() -> Model:
    """Build a model whose only dynamics come from the trained surrogate.

    Reads the notebook-level ``surrogate`` object, so the training cell has to
    run before this function is called.

    Returns:
        A new model with five variables and the surrogate attached.
    """
    initial_values = {
        "x1": 1.0,
        "x2": 0.0,
        "x3": 0.0,
        "ATP": 2.0,
        "NADPH": 0.1,
    }

    model = Model()
    model.add_variables(initial_values)

    # Adding the surrogate
    model.add_surrogate("surrogate", surrogate)

    # Note that besides the surrogate we haven't defined any other reaction!
    # We could have though
    return model


# Simulate the surrogate-only model until t_end = 0.8.
c, v = Simulator(get_model()).simulate(0.8).get_full_concs_and_fluxes()

# FIXME: note that NADPH gets negative
# At least the rates seem to get 0 around the time when x1 is 0
if c is None or v is None:
    msg = "Simulation failed"
    raise ValueError(msg)

# Concentrations on the left subplot, fluxes on the right subplot.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
c.plot(ax=ax1, xlabel="time / s", ylabel="concentration / mM")
v.plot(ax=ax2, xlabel="time / s", ylabel="flux / (mM / s)")
plt.show()

## Neural posterior estimation

In [None]:
# Example plot of this models behaviour
# Simulate until t_end = 10 and plot the fluxes, but only if the simulation
# returned any.
fluxes = Simulator(get_example1()).simulate(10).get_fluxes()
if fluxes is not None:
    fluxes.plot()
    plt.show()

### Create data

In [None]:
# 10,000 sampled value sets for x1, ATP and NADPH; these are the quantities
# the estimators below are trained to recover.
targets = sample(
    {
        "x1": LogNormal(mean=1.0, sigma=0.3),
        "ATP": LogNormal(mean=0.7, sigma=0.1),
        "NADPH": LogNormal(mean=0.3, sigma=0.2),
    },
    n=10_000,
)

time_points = np.linspace(0, 10, 11)

# Steady-state data for every sampled set, cached under .cache/npe-ss.
ss_data = create_steady_state_data(
    get_example1(),
    parameters=targets,
    cache=Cache(Path(".cache") / "npe-ss"),
)

# Time-series data for every sampled set, cached under .cache/npe-ts.
ts_data = create_time_series_data(
    get_example1(),
    parameters=targets,
    time_points=time_points,
    cache=Cache(Path(".cache") / "npe-ts"),
)

### Train NPE on steady-state data


In [None]:
# Only the x2 and x3 columns of the steady-state data are used as features.
features = ss_data.loc[:, ["x2", "x3"]]

# Train for 5,000 epochs; returns the estimator and its training losses.
estimator, losses = npe.train_torch_ss_estimator(
    features=features,
    targets=targets,
    epochs=5_000,
)

losses.plot()

In [None]:
# Two subplots with shared axes so prior and posterior are directly comparable.
fig, (ax1, ax2) = plt.subplots(
    1,
    2,
    figsize=(8, 3),
    layout="constrained",
    sharex=True,
    sharey=True,
)

# Left: density of the sampled targets.
ax = sns.kdeplot(targets, fill=True, ax=ax1)
ax.set_title("Prior")

# Right: density of the estimator's predictions for the training features.
posterior = estimator.predict(features)

ax = sns.kdeplot(posterior, fill=True, ax=ax2)
ax.set_title("Posterior")
plt.show()

### Train NPE on time series data

In [None]:
# Only the x2 and x3 columns of the time-series data are used as features.
features = ts_data.loc[:, ["x2", "x3"]]


# Train for 5,000 epochs; returns the estimator and its training losses.
# Note that this rebinds `estimator` and `losses` from the steady-state cell.
estimator, losses = npe.train_torch_time_series_estimator(
    features=features,
    targets=targets,
    epochs=5_000,
)

losses.plot()

In [None]:
# Two subplots with shared axes so prior and posterior are directly comparable.
fig, (ax1, ax2) = plt.subplots(
    1,
    2,
    figsize=(8, 3),
    layout="constrained",
    sharex=True,
    sharey=True,
)

# Left: density of the sampled targets.
ax = sns.kdeplot(targets, fill=True, ax=ax1)
ax.set_title("Prior")

# Right: density of the estimator's predictions for the training features.
posterior = estimator.predict(features)

ax = sns.kdeplot(posterior, fill=True, ax=ax2)
ax.set_title("Posterior")
plt.show()

## Label models

In [None]:
def mass_action_1(kf: float, s: float) -> float:
    """Mass-action rate with one substrate: ``kf * s``."""
    rate = kf * s
    return rate


def mass_action_2(kf: float, s1: float, s2: float) -> float:
    """Mass-action rate with two substrates: ``kf * s1 * s2``."""
    rate = kf * s1
    return rate * s2


def get_model() -> Model:
    """Build the GAP / DHAP / FBP model with split forward and reverse steps.

    Reverse rate constants are derived as ``kf / Keq``. DHAP and FBP start at
    values computed from the GAP start value and the two ``Keq`` constants.

    Returns:
        A new model with four mass-action reactions (TPIf, TPIr, ALDf, ALDr).
    """
    kf_tpi, keq_tpi = 1.0, 21.0
    kf_ald, keq_ald = 2000.0, 7000.0
    parameters = {
        "kf_TPI": kf_tpi,
        "Keq_TPI": keq_tpi,
        "kf_Ald": kf_ald,
        "Keq_Ald": keq_ald,
        "kr_TPI": kf_tpi / keq_tpi,
        "kr_Ald": kf_ald / keq_ald,
    }

    gap_start = 2.5e-5
    dhap_start = gap_start * keq_tpi
    fbp_start = gap_start * dhap_start * keq_ald
    initial_values = {"GAP": gap_start, "DHAP": dhap_start, "FBP": fbp_start}

    model = Model()
    model = model.add_variables(initial_values)
    model = model.add_parameters(parameters)

    # GAP <-> DHAP, written as two irreversible reactions
    model = model.add_reaction(
        "TPIf",
        mass_action_1,
        {"GAP": -1, "DHAP": 1},
        ["kf_TPI", "GAP"],
    )
    model = model.add_reaction(
        "TPIr",
        mass_action_1,
        {"DHAP": -1, "GAP": 1},
        ["kr_TPI", "DHAP"],
    )

    # DHAP + GAP <-> FBP, written as two irreversible reactions
    model = model.add_reaction(
        "ALDf",
        mass_action_2,
        {"DHAP": -1, "GAP": -1, "FBP": 1},
        ["kf_Ald", "DHAP", "GAP"],
    )
    model = model.add_reaction(
        "ALDr",
        mass_action_1,
        {"FBP": -1, "DHAP": 1, "GAP": 1},
        ["kr_Ald", "FBP"],
    )
    return model

### Label mapper

In [None]:
# Wrap the model in a label mapper: GAP and DHAP carry 3 label positions,
# FBP carries 6, and each reaction gets a label map.
mapper = LabelMapper(
    get_model(),
    label_variables={"GAP": 3, "DHAP": 3, "FBP": 6},
    label_maps={
        "TPIf": [2, 1, 0],
        "TPIr": [2, 1, 0],
        "ALDf": [0, 1, 2, 3, 4, 5],
        "ALDr": [0, 1, 2, 3, 4, 5],
    },
)

# Simulate the label model until t_end = 20 and plot only if it succeeded.
concs = (
    Simulator(mapper.build_model(initial_labels={"GAP": 0}))
    .simulate(20)
    .get_full_concs()
)
if concs is not None:
    plot.relative_label_distribution(mapper, concs, n_cols=3)


### Linear label mapper

In [None]:
m = get_model()

# Simulate the unlabelled model first: its last concentrations and fluxes are
# handed to the linear label mapper below.
concs, fluxes = Simulator(m).simulate(20).get_concs_and_fluxes()
if concs is None or fluxes is None:
    # Give the failure a readable message instead of a bare ValueError.
    msg = "Simulation failed"
    raise ValueError(msg)

mapper = LinearLabelMapper(
    m,
    label_variables={"GAP": 3, "DHAP": 3, "FBP": 6},
    label_maps={
        "TPIf": [2, 1, 0],
        "TPIr": [2, 1, 0],
        "ALDf": [0, 1, 2, 3, 4, 5],
        "ALDr": [0, 1, 2, 3, 4, 5],
    },
)

# Build the label model from the last simulated row and simulate it until
# t_end = 20. This rebinds `concs` to the label result; plot only on success.
if (
    concs := (
        Simulator(
            mapper.build_model(
                concs=concs.iloc[-1],
                fluxes=fluxes.iloc[-1],
                initial_labels={"GAP": 0},
            )
        )
        .simulate(20)
        .get_full_concs()
    )
) is not None:
    plot.relative_label_distribution(mapper, concs, n_cols=3)
