In [None]:
from collections.abc import Iterable, Mapping

import numpy as np
import pandas as pd
from example_models import get_linear_chain_2v, get_poolman2000, get_upper_glycolysis
from matplotlib import pyplot as plt

import modelbase2 as mb2
from modelbase2 import Simulator, mc, mca, plot
from modelbase2.distributions import LogNormal, Uniform, sample


def make_protocol(
    steps: "Mapping[float, dict[str, float]] | Iterable[tuple[float, dict[str, float]]]",
) -> pd.DataFrame:
    """Build a protocol table from consecutive step durations.

    Every step is a ``duration -> parameters`` pair. Durations are given in
    seconds and accumulated, so each row of the result is indexed by the
    ``pandas.Timedelta`` at which its step ends.

    Args:
        steps: Either a mapping of step duration to the parameter values of
            that step (iterated in insertion order), or an iterable of
            ``(duration, parameters)`` pairs. Use the pair form when several
            steps share the same duration; a mapping cannot express that
            because its keys are unique.

    Returns:
        A DataFrame with one row per cumulative end time and one column per
        parameter name. A step with a duration of zero ends at the same time
        as its predecessor and therefore replaces that row.

    """
    # A mapping yields only its keys when iterated, so go through .items().
    pairs = steps.items() if isinstance(steps, Mapping) else steps

    data = {}
    elapsed = pd.Timedelta(0)
    for duration, pars in pairs:
        elapsed += pd.Timedelta(seconds=duration)
        data[elapsed] = pars
    # The dict is keyed by end time, i.e. column-wise; transpose to get one
    # row per step.
    return pd.DataFrame(data).T

## Simulation

### Steady-state

You can simulate until the model reaches a steady-state using the `simulate_to_steady_state` method.  


In [None]:
# Run the linear-chain example model until it reaches a steady state and
# fetch the resulting concentrations and fluxes.
concs, fluxes = (
    Simulator(get_linear_chain_2v())  # optionally supply initial conditions
    .simulate_to_steady_state()
    .get_concs_and_fluxes()
)
# Last expression of the cell, so the notebook renders it as the output.
concs

### Time series

You can obtain the time series of integration using the `simulate` method.  
There are three ways to define the time points this function returns.  

1. supply the end time `t_end`
2. supply both end time and number of steps with `steps`
3. supply the exact time points to be returned using `time_points`

```python
simulate(t_end=10)
simulate(t_end=10, steps=10)
simulate(time_points=np.linspace(0, 10, 11))
```

> Note that these settings don't change the integration itself (e.g. tolerances)!

In [None]:
# Variant 1 of the list above: only the end time is supplied.
concs, fluxes = (
    Simulator(get_linear_chain_2v())  # optionally supply initial conditions
    .simulate(t_end=10)
    .get_concs_and_fluxes()
)

# NOTE(review): the result may be None, presumably when the integration
# failed - confirm against the Simulator API.
if concs is not None:
    _ = plot.lines(concs)
    plt.show()

### Protocol time series

Protocols are used to make parameter changes discrete in time.  
This is useful e.g. for reproducing experimental time courses where a parameter was changed at fixed time points.  

In [None]:
# Three consecutive steps lasting 1 s, 2 s and 3 s; k1 is switched 1 -> 2 -> 1.
protocol = make_protocol(
    {
        1: {"k1": 1},
        2: {"k1": 2},
        3: {"k1": 1},
    }
)
concs, fluxes = (
    Simulator(get_linear_chain_2v())
    .simulate_over_protocol(protocol)
    .get_concs_and_fluxes()
)

if concs is not None:
    fig, ax = plt.subplots()
    plot.lines(concs, ax=ax)
    # Overlay the k1 column of the protocol on the same axes.
    plot.shade_protocol(protocol["k1"], ax=ax, alpha=0.1)
    # Render explicitly, like every other plotting cell in this notebook.
    plt.show()


## Parameter scans

`modelbase` has a variety of parameter scans available.  
They differ by the kind of data returned by them, e.g. steady-states or time series of concentration and fluxes.

In all cases, the parameters which you want to scan over are passed as a `pandas.DataFrame`.

### Steady-state

In [None]:
# Steady-state scan over 11 evenly spaced k1 values between 1 and 2.
scan = mb2.parameter_scan_ss(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(1, 2, 11)}),
)

# Concentrations on the left axes, fluxes on the right axes.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 3))
plot.lines(scan.concs, ax=ax1)
plot.lines(scan.fluxes, ax=ax2)
plt.show()

All scans return some kind of result object, which allows multiple access patterns for convenience. 

Namely, the concentrations and fluxes can be accessed by name, unpacked or combined into a single dataframe.

In [None]:
# Access by attribute name
_ = scan.concs
_ = scan.fluxes

# The result object can be unpacked into concentrations and fluxes
concs, fluxes = scan

# Concentrations and fluxes combined into a single dataframe
_ = scan.results


#### Combinations

Often you want to scan over multiple parameters at the same time.  
The recommended way to do this is to use the `cartesian_product` function, which takes a `parameter_name: values` mapping and creates a `pandas.DataFrame` of their combinations from it (think nested for loop).  

In the case of more than one parameter, the returned `pandas.DataFrame` contains a `pandas.MultiIndex`.  

In [None]:
# Scan every combination of 3 k1 values and 4 k2 values.
scan = mb2.parameter_scan_ss(
    get_linear_chain_2v(),
    mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 3),
            "k2": np.linspace(1, 2, 4),
        }
    ),
)

# Preview the first rows of the combined concentration/flux table.
scan.results.head()


You can plot the results of this scan using a heatmap

In [None]:
# Heatmap of the single variable "x" from the two-parameter scan above.
plot.heatmap_from_2d_idx(scan.concs, variable="x")
plt.show()

Or create heatmaps of all passed variables at once.  

In [None]:
# One heatmap per variable contained in the scanned concentrations.
plot.heatmaps_from_2d_idx(scan.concs)
plt.show()

You can also combine more than two parameters, however, visualisation becomes challenging.  

In [None]:
# Three scanned parameters: 3 x 4 x 4 value combinations.
scan = mb2.parameter_scan_ss(
    get_linear_chain_2v(),
    mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 3),
            "k2": np.linspace(1, 2, 4),
            "k3": np.linspace(1, 2, 4),
        }
    ),
)
# Only previewed as a table; no plot is attempted for three parameters.
scan.results.head()

### Time-series

You can perform a time series for each of the parameter values.  
The index now also contains the time, so even for one parameter a `pandas.MultiIndex` is used.

In [None]:
# One time series per k1 value (11 values), each returned at 11 time points
# between 0 and 1.
tss = mb2.parameter_scan_time_series(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(1, 2, 11)}),
    time_points=np.linspace(0, 1, 11),
)

# Concentrations on the left axes, fluxes on the right axes.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tss.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(tss.fluxes, ax=ax2)
plt.show()

Again, this works for an arbitrary number of parameters.

In [None]:
# Time series for every combination of 11 k1 values and 4 k2 values.
tss = mb2.parameter_scan_time_series(
    get_linear_chain_2v(),
    parameters=mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 11),
            "k2": np.linspace(1, 2, 4),
        }
    ),
    time_points=np.linspace(0, 1, 11),
)


# Same plotting calls as for the single-parameter time-series scan.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tss.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(tss.fluxes, ax=ax2)
plt.show()

The scan object returned has a `pandas.MultiIndex` of `n x time`, where `n` is an index that references parameter combinations.  
You can access the referenced parameters using `.parameters`

In [None]:
# Preview the parameter combinations that the `n` level of the index refers to.
tss.parameters.head()

You can also easily access common aggregates using `get_agg_per_time`.  

In [None]:
# Preview the "std" aggregate returned by get_agg_per_time.
tss.get_agg_per_time("std").head()

### Protocol

The same can be done for protocols.  

In [None]:
# k2 is scanned over 11 values, while the protocol switches k1 (1 -> 2 -> 1)
# in consecutive steps of 1 s, 2 s and 3 s.
scan = mb2.parameter_scan_protocol(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k2": np.linspace(1, 2, 11)}),
    protocol=make_protocol(
        {
            1: {"k1": 1},
            2: {"k1": 2},
            3: {"k1": 1},
        }
    ),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(scan.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(scan.fluxes, ax=ax2)

# Overlay the k1 protocol on both axes.
for ax in (ax1, ax2):
    plot.shade_protocol(scan.protocol["k1"], ax=ax, alpha=0.2)
plt.show()

## Metabolic control analysis

`modelbase` supports both elasticities (arbitrary state) and response coefficients (steady-state) measurements from metabolic control analysis.  
They can all be found in the `modelbase2.mca` namespace.  

### Variable elasticities

Variable elasticities are the sensitivity of reactions to a small change in the concentration of a variable.  
They are **not** a steady-state measurement and can be calculated for any arbitrary state.  

Both the `concs` and `variables` arguments are optional.  
If `concs` is not supplied, the routine will use the initial conditions from the model.  
If `variables` is not supplied, the elasticities will be calculated for all variables.  

In [None]:
# Elasticities are evaluated at the explicitly supplied state `concs` and
# only with respect to the two variables listed in `variables`.
elas = mca.variable_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    variables=["GLC", "F6P"],
)

_ = plot.heatmap(elas)
plt.show()

### Parameter elasticities

Parameter elasticities are the sensitivity of reactions to a small change in the value of a parameter.  
They are **not** a steady-state measurement and can be calculated for any arbitrary state.  

Both the `concs` and `parameters` arguments are optional.  
If `concs` is not supplied, the routine will use the initial conditions from the model.  
If `parameters` is not supplied, the elasticities will be calculated for all parameters.  

In [None]:
# Elasticities are evaluated at the explicitly supplied state `concs` and
# only with respect to the two parameters listed in `parameters`.
elas = mca.parameter_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    parameters=["k1", "k2"],
)

_ = plot.heatmap(elas)
plt.show()

### Response coefficients

Response coefficients show the sensitivity of variables and reactions **at steady-state** to a small change in a parameter.  

If the parameter is proportional to the rate, they are also called **control coefficients**.

In [None]:
# Response coefficients with respect to seven parameters.
# NOTE(review): the names suggest crcs = concentration and frcs = flux
# response coefficients - confirm the return order against the mca API.
crcs, frcs = mca.response_coefficients(
    get_upper_glycolysis(),
    parameters=["k1", "k2", "k3", "k4", "k5", "k6", "k7"],
)

# One heatmap per returned table, side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
_ = plot.heatmap(crcs, ax=ax1)
_ = plot.heatmap(frcs, ax=ax2)
plt.show()

## Monte-carlo scans

Most model parameters in the natural sciences are best described using **distributions**.  
Thus, to get the distribution of realistic behaviour, `modelbase` supplies monte-carlo versions of all other analyses.

For that, you supply a `pandas.DataFrame` of parameter values randomly drawn from different distributions.  
You can use the `sample` function and distributions supplied by modelbase.  

In [None]:
# Draw parameter values for k2 (uniform) and k3 (log-normal).
# `n` sets the requested number of samples.
sample(
    {
        "k2": Uniform(1.0, 2.0),
        "k3": LogNormal(mean=1.0, sigma=1.0),
    },
    n=5,
)

If you want to create custom distributions, all you need to do is to create a class that follows the `Distribution` protocol, i.e. implements a `sample` method.  

```python
class MyOwnDistribution:
    def sample(self, num: int) -> Array:
        # implement here
        ...
```

and it can be used in the `sample` function as well. 

### monte-carlo steady-states

Using `mc.steady_state` you can calculate the steady-state distribution given the monte-carlo parameters.  

In [None]:
# Steady states for 10 sampled parameter sets (k1, k2 uniform; k3 log-normal).
scan = mc.steady_state(
    get_linear_chain_2v(),
    mc_parameters=sample(
        {
            "k1": Uniform(0.9, 1.1),
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Concentrations and fluxes get separate x axes (sharex=False).
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4), sharex=False)
plot.violins(scan.concs, ax=ax1)
plot.violins(scan.fluxes, ax=ax2)
plt.show()

### monte-carlo time series

Using `mc.time_course` you can calculate time courses for sampled parameters.  

The `pandas.DataFrame`s for concentrations and fluxes have a `n x time` `pandas.MultiIndex`.  
The corresponding parameters can be found in `scan.parameters`

In [None]:
# Time courses at 11 time points between 0 and 1 for 10 sampled parameter sets.
scan = mc.time_course(
    get_linear_chain_2v(),
    time_points=np.linspace(0, 1, 11),
    mc_parameters=sample(
        {
            "k1": Uniform(0.9, 1.1),
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Concentrations on the left axes, fluxes on the right axes.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(scan.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(scan.fluxes, ax=ax2)
plt.show()

### mc protocol


Using `mc.time_course_over_protocol` you can calculate time courses for sampled parameters given a discrete protocol.  

The `pandas.DataFrame`s for concentrations and fluxes have a `n x time` `pandas.MultiIndex`.  
The corresponding parameters can be found in `scan.parameters`

In [None]:
# k1 follows the protocol (1 -> 2 -> 1 in steps of 1 s, 2 s and 3 s), while
# k2 and k3 are sampled.
# NOTE(review): time_points_per_step presumably sets how many time points are
# returned for each protocol step - confirm against the mc API.
scan = mc.time_course_over_protocol(
    get_linear_chain_2v(),
    time_points_per_step=10,
    protocol=make_protocol(
        {
            1: {"k1": 1},
            2: {"k1": 2},
            3: {"k1": 1},
        }
    ),
    mc_parameters=sample(
        {
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(scan.concs, ax=ax1)
plot.lines_mean_std_from_2d_idx(scan.fluxes, ax=ax2)
# Overlay the k1 protocol on both axes.
for ax in (ax1, ax2):
    plot.shade_protocol(scan.protocol["k1"], ax=ax, alpha=0.1)

plt.show()

### mc metabolic control analysis

#### Compound elasticities

The returned `pandas.DataFrame` has a `pd.MultiIndex` of shape `n x reaction`.  

In [None]:
# Elasticities with respect to GLC and F6P at the supplied state, repeated
# for 5 sampled values of k3.
elas = mc.compound_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    variables=["GLC", "F6P"],
    mc_parameters=sample(
        {
            # Only k3 is sampled; the entries for the remaining parameters
            # are kept below but disabled.
            # "k1": LogNormal(mean=np.log(0.25), sigma=1.0),
            # "k2": LogNormal(mean=np.log(1.0), sigma=1.0),
            "k3": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k4": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k5": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k6": LogNormal(mean=np.log(1.0), sigma=1.0),
            # "k7": LogNormal(mean=np.log(2.5), sigma=1.0),
        },
        n=5,
    ),
)

_ = plot.violins_from_2d_idx(elas)
plt.show()

#### Parameter elasticities

In [None]:
# Elasticities with respect to k1, k2 and k3 at the supplied state, repeated
# for 5 sampled values of k3.
elas = mc.parameter_elasticities(
    get_upper_glycolysis(),
    concs={
        "GLC": 0.3,
        "G6P": 0.4,
        "F6P": 0.5,
        "FBP": 0.6,
        "ATP": 0.4,
        "ADP": 0.6,
    },
    parameters=["k1", "k2", "k3"],
    mc_parameters=sample(
        {
            "k3": LogNormal(mean=np.log(0.25), sigma=1.0),
        },
        n=5,
    ),
)

_ = plot.violins_from_2d_idx(elas)
plt.show()

#### Response coefficients

In [None]:
# Response coefficients of the Poolman model with respect to five Vmax
# parameters; the same five parameters are also the sampled ones.
resp = mc.response_coefficients(
    get_poolman2000(),
    parameters=["Vmax_1", "Vmax_6", "Vmax_9", "Vmax_13", "Vmax_16"],
    mc_parameters=sample(
        {
            "Vmax_1": LogNormal(np.log(2.72), sigma=0.1),
            "Vmax_6": LogNormal(np.log(1.6), sigma=0.1),
            "Vmax_9": LogNormal(np.log(0.32), sigma=0.1),
            "Vmax_13": LogNormal(np.log(8.0), sigma=0.1),
            "Vmax_16": LogNormal(np.log(2.8), sigma=0.1),
        },
        n=5,
    ),
)

# Restrict the plot to three selected variables.
_ = plot.violins_from_2d_idx(resp.concs.loc[:, ["PGA", "GAP", "SBP"]])
plt.show()

### mc steady-state parameter scan

Vary monte-carlo parameters **and** systematically scan over other parameters at the same time.

In [None]:
# k1 is scanned systematically (3 values), while k2 and k3 are sampled
# (10 parameter sets).
mcss = mc.parameter_scan_ss(
    get_linear_chain_2v(),
    parameters=pd.DataFrame({"k1": np.linspace(0, 1, 3)}),
    mc_parameters=sample(
        {
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

plot.violins_from_2d_idx(mcss.concs)
plt.show()

In [None]:
# FIXME: no idea how to plot this yet. Ridge plots?
# Maybe it's just a bit much :D

# k1 and k2 are scanned systematically (3 x 3 combinations), while k3 is
# sampled (10 values).
mcss = mc.parameter_scan_ss(
    get_linear_chain_2v(),
    parameters=mb2.cartesian_product(
        {
            "k1": np.linspace(0, 1, 3),
            "k2": np.linspace(0, 1, 3),
        }
    ),
    mc_parameters=sample(
        {
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Only previewed as a table until a suitable plot exists.
mcss.concs.head()