In [None]:
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import Markdown
from matplotlib.figure import Figure
from mxlpy import (
    Derived,
    Model,
    Simulator,
    compare,
    fns,
    plot,
    report,
    surrogates,
    unwrap,
)

from models import get_sir, get_sird

# Bind the path to TMP_DIR and create the directory (including missing
# parents); an already existing directory is not an error.
(TMP_DIR := Path("temp")).mkdir(exist_ok=True, parents=True)


def plot_difference(r_old: pd.DataFrame, r_new: pd.DataFrame) -> Figure:
    """Plot ``r_new`` over a faded ``r_old`` and thicken the most-changed lines.

    The relative difference ``(r_new - r_old) / r_old`` is averaged per column
    (absolute values); up to three columns of ``r_new`` with the largest mean
    are drawn with a wider line.

    Args:
        r_old: Reference data, drawn with ``alpha=0.25`` and without legend.
        r_new: Data to compare; one line per column.

    Returns:
        The figure holding both sets of lines.
    """
    rel_diff = (r_new - r_old) / r_old
    # DataFrame arithmetic aligns on the union of columns, but only columns of
    # r_new get a line that can be emphasised. Ranking just those avoids a
    # KeyError below for columns that exist only in r_old.
    largest_diff = (
        rel_diff.loc[:, r_new.columns]
        .abs()
        .mean()
        .fillna(0)  # columns without any valid ratio rank as 0
        .sort_values()
        .tail(n=3)
    )

    fig, ax = plot.one_axes()
    plot.lines(r_new, ax=ax)
    lines = dict(zip(r_new.columns, ax.lines, strict=True))
    # Iterates from the largest difference downwards; the width factor starts
    # at 2 and grows by one per step.
    # NOTE(review): the biggest change therefore gets the smallest factor —
    # confirm that this is intended.
    for f, i in enumerate(reversed(largest_diff.index), start=2):
        line = lines[i]  # type: ignore
        line.set_linewidth(line.get_linewidth() * f)

    # Presumably restarts the colour cycle so that r_old reuses the colours of
    # r_new — confirm against plot.reset_prop_cycle.
    plot.reset_prop_cycle(ax)
    plot.lines(r_old, ax=ax, alpha=0.25, legend=False)
    ax.set(xlabel="Time / a.u.", ylabel="Relative Population")
    return fig


def remove_labels(axs: plot.Axs) -> None:
    """Set the x and y label of every axes in ``axs`` to an empty string."""
    for current in axs:
        current.set(xlabel="", ylabel="")


def grid_labels(
    axs: plot.Axs,
    xlabel: str | None = None,
    ylabel: str | None = None,
) -> None:
    """Label only the outer axes of a grid.

    All labels are cleared first. The first column then receives ``ylabel``
    and the last row receives ``xlabel`` up to the first invisible axes. From
    that column onwards the second-to-last row is labelled instead.

    Args:
        axs: Grid of axes, indexed as ``axs[row, col]``.
        xlabel: Label for the bottom-most visible axes of each column.
        ylabel: Label for the axes in the first column.
    """
    remove_labels(axs)

    for left_ax in axs[:, 0]:
        left_ax.set_ylabel(ylabel)

    # Column index of the first hidden axes in the last row, if there is one.
    first_hidden: int | None = None
    for idx, bottom_ax in enumerate(axs[-1, :]):
        if not bottom_ax.get_visible():
            first_hidden = idx
            break
        bottom_ax.set_xlabel(xlabel)

    if first_hidden is None:
        return

    # The remaining columns have no visible axes in the last row, so the row
    # above carries the x label there.
    for upper_ax in axs[-2, first_hidden:]:
        upper_ax.set_xlabel(xlabel)


# MxlPy & MxlBricks workshop

<p align="center">
    <img 
        src="https://raw.githubusercontent.com/Computational-Biology-Aachen/MxlPy/refs/heads/main/docs/assets/logo-diagram.png"
        style="width: 350px; max-width: 45%"
        alt='mxlpy-logo'
    >
    <img 
        src="https://raw.githubusercontent.com/Computational-Biology-Aachen/mxl-bricks/refs/heads/main/docs/assets/logo.png"
        style="width: 350px; max-width: 45%"
        alt='mxlbricks-logo'
    >
</p>

Today we are going to talk about some of the new `MxlPy` features.  

This includes:

- All the ways in which model components can be derived from each other
- Mechanistic Learning techniques such as surrogates and reaction carousels
- The purpose and design of the `MxlBricks` library

Optionally, if time allows

- some code anti-patterns and why they are considered that way


## All things derived

Loads of values can be derived from each other.  

Since all of these values depend on something, you can obtain a `pandas.Series` with every possible argument (excluding data) and the calculated values themselves with `get_dependent`.  


This includes
- parameters
- derived parameters
- variables
- derived variables
- rates
- surrogate outputs
- (readouts)

This is different from `get_args`, which only contains

- variables
- derived variables
- (readouts)


We can discuss later whether it makes sense to combine these functions into one.

### Parameters

In [None]:
# Model with a single parameter and one value derived from it.
(
    Model()
    .add_parameter("p1", 1.0)
    .add_derived("d1", fns.twice, args=["p1"])  # derive from parameter p1
    .get_dependent()
)

### Variables

In [None]:
# Model with a single variable and one value derived from it.
(
    Model()
    .add_variable("v1", 1.0)
    .add_derived("d1", fns.twice, args=["v1"])  # derive from variable v1
    .get_dependent()
)

### Derived variables

In [None]:
# Chain of derived values: d2 takes d1 as its argument, d1 takes p1.
(
    Model()
    .add_parameter("p1", 1.0)
    .add_derived("d1", fns.twice, args=["p1"])
    .add_derived("d2", fns.twice, args=["d1"])  # derive from derived d1
    .get_dependent()
)

### Rates

> Note: does **not** include the stoichiometry, just the rate

In [None]:
# d1 takes the rate r1 as its argument and is itself the argument of the
# second reaction r2; both reactions consume v1.
(
    Model()
    .add_variable("v1", 1.0)
    .add_reaction("r1", fns.twice, args=["v1"], stoichiometry={"v1": -1})
    .add_derived("d1", fns.twice, args=["r1"])  # derived from rate of r1
    .add_reaction("r2", fns.twice, args=["d1"], stoichiometry={"v1": -1})  # use d1!
    .get_dependent()
)

### Stoichiometries

Derive stoichiometry from other model components

> Hint: if you need the raw (as in not-calculated) stoichiometry of a variable, you can use `Model.get_raw_stoichiometries_of_variable(variable)`

In [None]:
# The stoichiometric coefficient of v1 in r1 is not a number but a Derived
# that takes the parameter p1 as its argument.
(
    Model()
    .add_parameter("p1", 1.0)
    .add_variable("v1", 1.0)
    .add_reaction(
        "r1",
        fns.twice,
        args=["v1"],
        stoichiometry={"v1": Derived(fn=fns.twice, args=["p1"])},
    )
    .get_stoichiometries()
)

### Initial conditions

> Note: this just derives the value **once**.  
> This is **not** the same as a derived variable

In [None]:
# v1 has a fixed initial value; the initial value of v2 is given as a
# Derived with v1 as its argument.
(
    Model()
    .add_variables(
        {
            "v1": 1.0,
            "v2": Derived(fn=fns.twice, args=["v1"]),  # derive initial condition
        }
    )
    .get_initial_conditions()
)

In [None]:
# Same two variables as in the previous cell, but run through the simulator
# first; get_new_y0() is then called on the unwrapped result.
unwrap(
    Simulator(
        Model().add_variables(
            {
                "v1": 1.0,
                "v2": Derived(fn=fns.twice, args=["v1"]),  # derive initial condition
            }
        )
    )
    .simulate(1)
    .get_result()
).get_new_y0()

## Data references

In [None]:
def average(light: pd.Series) -> float:
    """Return the mean over all entries of ``light``."""
    mean_value = light.mean()
    return mean_value


# Three float values, indexed by labels such as "400nm".
lights = pd.Series(
    data={"400nm": 200, "500nm": 300, "600nm": 400},
    dtype=float,
)


# The series is registered as data under the name "light" and that name is
# used as the argument of `average`.
(
    Model()
    .add_data("light", lights)
    .add_derived("average_light", average, args=["light"])
    .get_dependent()
)

## Comparisons

In [None]:
# Steady-state comparison of the two models built by get_sir() and get_sird().
ssc = compare.steady_states(
    get_sir(),
    get_sird(),
)

In [None]:
# The return values are bound to `_`, i.e. intentionally unused.
_ = ssc.plot_variables()
_ = ssc.plot_fluxes()

In [None]:
# Time-course comparison of both models on 101 evenly spaced points
# from 0 to 100.
pc = compare.time_courses(
    get_sir(),
    get_sird(),
    time_points=np.linspace(0, 100, 101),
)

# The return values are bound to `_`, i.e. intentionally unused.
_ = pc.plot_variables_relative_difference()
_ = pc.plot_fluxes_relative_difference()

## Reports

In [None]:
# Build the report for both models; the result is handed to IPython's
# Markdown for display.
md = report.markdown(
    get_sir(),
    get_sird(),
)

# IPython Display
Markdown(md)

In [None]:
def analyse_concentrations(m1: Model, m2: Model, img_dir: Path) -> tuple[str, Path]:
    """Simulate both models and save a plot comparing their variables.

    Args:
        m1: Model whose variables are drawn as the faded reference.
        m2: Model whose variables are drawn on top.
        img_dir: Directory that receives ``concentration.png``.

    Returns:
        A heading string and the path of the saved image.
    """
    old_result = unwrap(Simulator(m1).simulate(100).get_result())
    new_result = unwrap(Simulator(m2).simulate(100).get_result())

    fig = plot_difference(old_result.variables, new_result.variables)
    path = img_dir / "concentration.png"
    fig.savefig(path, dpi=300)
    # Close the figure after saving it.
    plt.close(fig)
    return "## Comparison of largest changing", path


# Same report call as before, now with a custom analysis function.
# NOTE(review): img_path is presumably what the analysis receives as
# `img_dir` — confirm against report.markdown.
md = report.markdown(
    get_sir(),
    get_sird(),
    analyses=[analyse_concentrations],
    img_path=TMP_DIR,
)

# IPython Display
Markdown(md)

## Metaprogramming

In [None]:
from mxlpy.meta import (
    generate_latex_code,
    generate_model_code_py,
)


In [None]:
# Print the output of generate_model_code_py for the model from get_sir().
print(generate_model_code_py(get_sir()))

In [None]:
# Print the output of generate_latex_code for the model from get_sir().
print(generate_latex_code(get_sir()))

## Discussion: units

See https://github.com/Computational-Biology-Aachen/MxlPy/issues/26

# Mechanistic learning

## Surrogates

What **is** a surrogate?  

I will define it here as a replacement / approximation for another system / model.  
These *might* be learned from data, but don't necessarily need to be.  


Examples of surrogates

- quasi-steady-states
- polynomials
- machine-learning models (e.g. torch)

You need to check the **validity** of doing these replacements yourself.  
One common criterion would be that the surrogated system is working at a much faster time scale.  
Then one can assume it to be in steady-state instantaneously relative to the model time.  

Surrogates in `MxlPy` can have

- one or multiple inputs
- one or multiple outputs
- one or multiple stoichiometries (factor x output)

### Quasi-steady-state

In [None]:
def distribute(s: float) -> tuple[float, float]:
    """Split ``s`` into one third and two thirds, returned in that order."""
    one_third = s / 3
    two_thirds = s * 2 / 3
    return one_third, two_thirds


# This creates two derived variables, but has no stoichiometries
# `distribute` returns a 2-tuple; its two entries are named a1 and a2 here.
(
    Model()
    .add_variables({"a": 1.0})
    .add_surrogate(
        "distribute",
        surrogates.qss.Surrogate(
            model=distribute,
            args=["a"],
            outputs=["a1", "a2"],
        ),
    )
    .get_dependent()
)

### Polynomial

In [None]:
from numpy.polynomial.polynomial import Polynomial

# Polynomial(coef=[2]) has only a constant term, so the polynomial itself
# evaluates to 2 for every input. z is then derived from x and the surrogate
# output y.
(
    Model()
    .add_variable("x", 1.0)
    .add_surrogate(
        "surrogate",
        surrogates.poly.Surrogate(
            model=Polynomial(coef=[2]),
            args=["x"],
            outputs=["y"],
        ),
    )
    .add_derived("z", fns.add, args=["x", "y"])
    .get_dependent()
)

### Build your own!

In [None]:
from mxlpy.types import AbstractSurrogate


@dataclass(kw_only=True)
class DoubleSurrogate(AbstractSurrogate):
    """Surrogate whose outputs are the doubled inputs, paired by position."""

    def predict(
        self, args: dict[str, float | pd.Series | pd.DataFrame]
    ) -> dict[str, float]:
        """Map each output name to twice the value of the matching argument.

        ``self.outputs`` and ``self.args`` are paired in order; because of
        ``strict=True`` a length mismatch raises ``ValueError``.
        """
        # Kept lazy so that values are looked up only as zip consumes them.
        doubled = (args[name] * 2 for name in self.args)
        return dict(zip(self.outputs, doubled, strict=True))  # type: ignore


# Use the custom surrogate with v1 as its only input and s1 as its only output.
(
    Model()
    .add_variable("v1", 1.0)
    .add_surrogate(
        "surrogate",
        DoubleSurrogate(
            args=["v1"],
            outputs=["s1"],
        ),
    )
    .get_dependent()
)

## Carousels

In [None]:
from mxlpy import fit
from mxlpy.carousel import Carousel, ReactionTemplate

# For each of the reactions "infection" and "recovery" two alternative
# templates are offered. The second template of each pair needs parameters
# beyond beta / gamma, which are supplied via `additional_parameters`.
carousel = Carousel(
    get_sir(),
    {
        "infection": [
            ReactionTemplate(fn=fns.mass_action_2s, args=["s", "i", "beta"]),
            ReactionTemplate(
                fn=fns.michaelis_menten_2s,
                args=["s", "i", "beta", "km_bs", "km_bi"],
                additional_parameters={"km_bs": 0.1, "km_bi": 1.0},
            ),
        ],
        "recovery": [
            ReactionTemplate(fn=fns.mass_action_1s, args=["i", "gamma"]),
            ReactionTemplate(
                fn=fns.michaelis_menten_1s,
                args=["i", "gamma", "km_gi"],
                additional_parameters={"km_gi": 0.1},
            ),
        ],
    },
)

In [None]:
# Run the carousel on 101 evenly spaced time points from 0 to 100.
carousel_time_course = carousel.time_course(np.linspace(0, 100, 101))
variables_by_model = carousel_time_course.get_variables_by_model()

fig, ax = plot.one_axes()
# One line_mean_std call per variable, labelled with the variable name.
for variable in ("s", "i", "r"):
    plot.line_mean_std(
        variables_by_model[variable].unstack().T,
        label=variable,
        ax=ax,
    )
ax.legend()
plot.show()


In [None]:
# Reference data for the fit below: simulate the model from get_sir() with
# beta=0.3 and gamma=0.15 and keep only the variables of the result.
data = unwrap(
    Simulator(get_sir().update_parameters({"beta": 0.3, "gamma": 0.15}))
    .simulate(100, steps=11)
    .get_result()
).variables

data.head()


In [None]:
# Fit the carousel to the reference data, starting from beta=0.1 and
# gamma=0.1 (the values used to generate the data were 0.3 and 0.15).
res = fit.carousel_time_course(
    carousel,
    p0={
        "beta": 0.1,
        "gamma": 0.1,
        # specific to reaction templates
        # "km_bi": 1.0,
    },
    data=data,
)

best = res.get_best_fit().model

# Simulation of the best model first, then the reference data drawn dashed
# on the same axes after resetting the property cycle.
fig, ax = plot.one_axes()
plot.lines(
    unwrap(Simulator(best).simulate(100).get_result()).variables,
    ax=ax,
)
plot.reset_prop_cycle(ax=ax)
plot.lines(data, linestyle="dashed", ax=ax, legend=False)
plot.show()

In [None]:
# Print the parameters of the best fit and the names of the rate functions
# used by the reactions of its model.
best_fit = res.get_best_fit()

print(best_fit.best_pars)
print([rxn.fn.__name__ for rxn in best_fit.model.reactions.values()])


In [None]:
# Loss of every entry in res.fits, keyed by its position.
{i: v.loss for i, v in enumerate(res.fits)}

**Discussion**: 

- fit did not return the intended reactions
- but reactions do fit the data well


What mechanisms should we use to fit in the future?

# MxlBricks

- [Repo](https://github.com/Computational-Biology-Aachen/mxl-bricks)
- [Documentation](https://computational-biology-aachen.github.io/mxl-bricks/0.2.0/)

In [None]:
from mxlbricks import names as n
from mxlbricks.enzymes import (
    add_catalase,
    add_glycine_decarboxylase_yokota,
    add_glycine_transaminase_yokota,
    add_glycolate_oxidase_yokota,
    add_hpa_outflux,
    add_phosphoglycolate_influx,
    add_serine_glyoxylate_transaminase_irreversible,
)
from mxlbricks.utils import static


def get_yokota1985() -> Model:
    """Assemble the `yokota1985` model from MxlBricks building blocks.

    Six variables are registered with fixed initial values, then the enzyme
    and flux bricks are applied to the same model one after another.

    Returns:
        The assembled model.
    """
    model = Model()
    initial_values = {
        n.glycolate(): 0.09,
        n.glyoxylate(): 0.7964601770483386,
        n.glycine(): 8.999999999424611,
        n.serine(): 2.5385608670239126,
        n.hydroxypyruvate(): 0.009782608695111009,
        n.h2o2(): 0.010880542843616855,
    }
    model.add_variables(initial_values)

    for add_brick in (
        add_phosphoglycolate_influx,
        add_glycolate_oxidase_yokota,
        add_glycine_transaminase_yokota,
    ):
        add_brick(model)

    # e0 of glycine decarboxylase is supplied through `static` with value 0.5.
    gdc_e0 = static(model, n.e0(n.glycine_decarboxylase()), 0.5)
    add_glycine_decarboxylase_yokota(model, e0=gdc_e0)

    for add_brick in (
        add_serine_glyoxylate_transaminase_irreversible,
        add_hpa_outflux,
        add_catalase,
    ):
        add_brick(model)
    return model

## Create new model that actually has oxygen concentration

In [None]:
from mxlbricks.enzymes import add_glycolate_oxidase


def get_photorespiration() -> Model:
    """Assemble the variant of `get_yokota1985` that has an o2 parameter.

    Differences to `get_yokota1985`: a parameter named by ``n.o2()`` is added
    and ``add_glycolate_oxidase`` replaces ``add_glycolate_oxidase_yokota``.

    Returns:
        The assembled model.
    """
    model = Model()
    initial_values = {
        n.glycolate(): 0.09,
        n.glyoxylate(): 0.7964601770483386,
        n.glycine(): 8.999999999424611,
        n.serine(): 2.5385608670239126,
        n.hydroxypyruvate(): 0.009782608695111009,
        n.h2o2(): 0.010880542843616855,
    }
    model.add_variables(initial_values)
    model.add_parameter(n.o2(), 0.2)  # changed here

    for add_brick in (
        add_phosphoglycolate_influx,
        add_glycolate_oxidase,  # changed here
        add_glycine_transaminase_yokota,
    ):
        add_brick(model)

    # e0 of glycine decarboxylase is supplied through `static` with value 0.5.
    gdc_e0 = static(model, n.e0(n.glycine_decarboxylase()), 0.5)
    add_glycine_decarboxylase_yokota(model, e0=gdc_e0)

    for add_brick in (
        add_serine_glyoxylate_transaminase_irreversible,
        add_hpa_outflux,
        add_catalase,
    ):
        add_brick(model)
    return model

In [None]:
# Result of get_dependent() for the new model, stored in `d`; assigning it
# means the cell shows no output.
d = get_photorespiration().get_dependent()

In [None]:
# Report for the original model and the variant with the o2 parameter,
# displayed through IPython's Markdown.
Markdown(
    report.markdown(
        get_yokota1985(),
        get_photorespiration(),
    )
)

In [None]:
# Time-course comparison of both models on 101 evenly spaced points
# from 0 to 10.
tcc = compare.time_courses(
    get_yokota1985(),
    get_photorespiration(),
    time_points=np.linspace(0, 10, 101, dtype=float),
)

fig, axs = tcc.plot_variables_relative_difference()
# grid_labels() clears every axis label itself before labelling the outer
# axes, so no separate remove_labels() call is needed here.
grid_labels(axs, xlabel="Time / h", ylabel="Relative difference")
plt.show()

## Discussion: those are just mechanistic bricks, where is the `L`?  

- Should we provide learned surrogates of these models in a package like this?  
- They can get very large, where do we store them?
  - Downloader like `torch` does it? 

# Anti-patterns

## Don't use global variables

```python
kf: float = 1.0

def mass_action(x: float) -> float:
    return x * kf
```

**Why is this a bad idea?**

This makes it really hard from the outside to read what the dependencies of the model **actually** are

```python
from mxlpy import Model

def get_model() -> Model:
    return (
        Model()
        .add_variable("x", 1.0)
        .add_reaction(
            "v1",
            fn=mass_action,  # no notion of kf
            args=["x"],  # no notion of kf
            stoichiometry={"x": -1},
        )
    )
```

**What to do instead**

Make **all** your function inputs actual inputs passed by `args`.

```python
from mxlpy import Model

def mass_action(x: float, kf: float) -> float:
    return x * kf

def get_model() -> Model:
    return (
        Model()
        .add_variable("x", 1.0)
        .add_parameter("kf", 1.0)
        .add_reaction(
            "v1",
            fn=mass_action,
            args=["x", "kf"],
            stoichiometry={"x": -1},
        )
    )
```

If for some reason you **cannot** pass an argument via `args`, use a partially applied function and pass the value to the `get_model` function


```python
from mxlpy import Model
from functools import partial


def mass_action(x: float, kf: float) -> float:
    return x * kf


def get_model(kf: float = 1) -> Model:
    fn = partial(mass_action, kf=kf)

    return (
        Model()
        .add_variable("x", 1.0)
        # kf is bound into fn via partial, so it is not added as a model parameter
        .add_reaction(
            "v1",
            fn=fn,
            args=["x"],
            stoichiometry={"x": -1},
        )
    )
```


## Don't use import side effects

```python
import pandas as pd

DATA: pd.Series = pd.read_csv(Path("data.csv"))

kf: float = 1.0

def mass_action(x: float) -> float:
    return x * DATA['kf']
```

**Why is this a bad idea?**

This makes it really hard from the outside to read what the dependencies of the model **actually** are.

**What to do instead**

Read required data files in your **main** file, so it is easy to see which data actually needs to be loaded.
Pass a reference to that data to your `get_model` function.

```python
import pandas as pd

from mxlpy import Model

def mass_action(x: float, data: pd.Series) -> float:
    return x * data['kf']

def get_model(data: pd.Series) -> Model:
    return (
        Model()
        .add_variable("x", 1.0)
        .add_data("data", data)
        .add_reaction(
            "v1",
            fn=mass_action,
            args=["x", "data"],
            stoichiometry={"x": -1},
        )
    )
```

In case `.add_data` does not work for you, you can always create a partially applied function.
In that case, it is still clear from the signature of `get_model` that the model depends on the data.

```python
from functools import partial

import pandas as pd

from mxlpy import Model

def mass_action(x: float, data: pd.Series) -> float:
    return x * data['kf']

# Clear that model depends on data
def get_model(data: pd.Series) -> Model:
    return (
        Model()
        .add_variable("x", 1.0)
        .add_reaction(
            "v1",
            fn=partial(mass_action, data=data),  # clear where data is used
            args=["x"],  # data is already bound by partial
            stoichiometry={"x": -1},
        )
    )
```


## Don't create intermediate parameters or variables

```python
from mxlpy import Model


def mass_action(x: float, kf: float) -> float:
    return x * kf


def wrapped() -> float:
    x = 1.0  # BAD: Don't create variables here
    kf = 1.0  # BAD: Don't create parameters here
    return mass_action(x, kf)


def get_model() -> Model:
    return (
        Model()
        .add_reaction(
            "v1",
            fn=wrapped,
            args=[],  # BAD: what about x and kf?
            stoichiometry={"x": -1},
        )
    )
```

**Why is this bad?**

This makes it really hard from the outside to read what the dependencies of the model **actually** are.
Also, since none of the intermediate parameters or variables can be 'seen' by `mxlpy`, you have no way of reading out their values.
Thus, if there is an error in one of them, it is really hard to actually find that error.

**What to do instead**

Make **all** your function inputs actual inputs passed by `args`.
If you have parameters or variables that depend on other parameters or variables, use `add_derived`.


```python
from mxlpy import Model


def mass_action(x: float, kf: float) -> float:
    return x * kf


def get_model() -> Model:
    return (
        Model()
        .add_variable("x", 1.0)
        .add_parameter("kf", 1.0)
        .add_reaction(
            "v1",
            fn=mass_action,
            args=["x", "kf"],
            stoichiometry={"x": -1},
        )
    )
```


## Do: wrap model construction in function

```python
from mxlpy import Model

def get_model() -> Model:
    return (
        Model()
        .add_variables(...)
        .add_parameters(...)
        .add_reaction(...)
    )
```

**Why is this a good idea?**

Quite often you have analyses that will change some component of the model, e.g. the value of a parameter:

```python
def analysis1(model: Model) -> None:
    model.update_parameter(...)
    ...

def analysis2(model: Model) -> None:
    model.update_variable(...)
    ...
```

If you don't keep track of **reverting** all of these changes, you will introduce subtle bugs in your analyses, where the results depend on previous results

```python
model = ...

analysis1(model)
analysis2(model)  # BAD: changes of analysis1 are still in effect
```

By re-creating the model every time, you make all analyses independent and avoid that situation

```python
analysis1(get_model())
analysis2(get_model())  # GOOD: analysis2 is independent of analysis1
```
