In [None]:
from functools import partial

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from modelbase2 import LinearLabelMapper, Model, Simulator, npe, plot, scan
from modelbase2.distributions import Normal, Uniform, sample
from modelbase2.fns import (
    michaelis_menten_1s,
)
from modelbase2.npe import AbstractEstimator
from modelbase2.parallel import parallelise
from modelbase2.types import unwrap, unwrap2


def _worker(
    x: tuple[tuple[int, pd.Series], tuple[int, pd.Series]],
    mapper: LinearLabelMapper,
    time: float,
    initial_labels: dict[str, int | list[int]],
) -> pd.Series:
    """Run one label simulation starting from a single steady state.

    Args:
        x: A pair of ``(index, row)`` items, as produced by zipping
            ``ss_concs.iterrows()`` with ``ss_fluxes.iterrows()``. The row
            indices are discarded.
        mapper: Label mapper used to build the labelled model.
        time: End time passed to ``Simulator.simulate``.
        initial_labels: Initial label assignment forwarded to
            ``mapper.build_model``.

    Returns:
        The last row of the simulated concentrations (presumably the state
        reached at ``time``).
    """
    (_conc_index, conc_row), (_flux_index, flux_row) = x
    labelled_model = mapper.build_model(
        conc_row, flux_row, initial_labels=initial_labels
    )
    concentrations = unwrap(Simulator(labelled_model).simulate(time).get_concs())
    return concentrations.iloc[-1]


def get_label_distribution_at_time(
    model: Model,
    label_variables: dict[str, int],
    label_maps: dict[str, list[int]],
    time: float,
    initial_labels: dict[str, int | list[int]],
    ss_concs: pd.DataFrame,
    ss_fluxes: pd.DataFrame,
) -> pd.DataFrame:
    """Simulate label propagation from every steady state and collect the end points.

    A ``LinearLabelMapper`` is built once from ``model``. For each pair of
    rows of ``ss_concs`` / ``ss_fluxes``, a labelled model is simulated up to
    ``time`` via ``_worker``, and the runs are dispatched through
    ``parallelise`` with caching disabled. The two frames must have the same
    number of rows; ``zip(strict=True)`` raises ``ValueError`` otherwise.

    Returns:
        The parallel results as a DataFrame, transposed. Presumably this gives
        one row per steady state (keyed by position 0..n-1 rather than
        ``ss_concs.index``) and one column per labelled-model variable.
        TODO confirm against ``parallelise``'s return type.
    """
    label_mapper = LinearLabelMapper(
        model,
        label_variables=label_variables,
        label_maps=label_maps,
    )
    simulate_one = partial(
        _worker,
        mapper=label_mapper,
        time=time,
        initial_labels=initial_labels,
    )
    steady_state_pairs = zip(ss_concs.iterrows(), ss_fluxes.iterrows(), strict=True)
    indexed_inputs = [(position, pair) for position, pair in enumerate(steady_state_pairs)]
    results = parallelise(simulate_one, inputs=indexed_inputs, cache=None)
    return pd.DataFrame(results).transpose()


def inverse_parameter_elasticity(
    estimator: AbstractEstimator,
    datum: pd.Series,
    *,
    normalized: bool = True,
    displacement: float = 1e-4,
) -> pd.DataFrame:
    """Finite-difference sensitivity of an estimator's prediction to each feature.

    Each feature of ``datum`` is perturbed with a relative central difference,
    ``value * (1 +/- displacement)``. The change in the estimator's prediction
    (first row of ``estimator.predict``) is divided by the actual step width.

    Args:
        estimator: Trained estimator. ``predict`` is called with a single
            feature Series, and the first row of its output is used.
        datum: One feature vector (index = feature names).
        normalized: If True, scale each derivative to an elasticity
            ``d target / d feature * feature / target``.
        displacement: Relative step size. For features whose value is 0, it
            is used as an absolute step instead, to avoid a zero-width step.

    Returns:
        DataFrame with one row per predicted target and one column per feature.
    """

    def _predict_at(name, new_value) -> pd.Series:
        # Re-predict with a single feature replaced; keep the first row.
        perturbed = pd.Series(datum.to_dict() | {name: new_value})
        return estimator.predict(perturbed).iloc[0, :]

    ref = estimator.predict(datum).iloc[0, :]

    derivatives = {}
    for name, value in datum.items():
        # Relative step: value + step == value * (1 + displacement). The
        # original `value * 1 + displacement` was an absolute step, which did
        # not match the `value`-scaled denominator. A zero feature would give
        # a zero-width step (0 / 0), so fall back to an absolute step there.
        step = displacement * value if value != 0 else displacement
        up = _predict_at(name, value + step)
        down = _predict_at(name, value - step)
        derivatives[name] = (up - down) / (2 * step)

    coefs = pd.DataFrame(derivatives)
    if normalized:
        # elasticity[target, feature] = d target / d feature * feature / target.
        # Align datum with the columns (features) and ref with the rows
        # (targets) explicitly, so this also holds for several targets.
        coefs = coefs.mul(datum, axis="columns").div(ref, axis="index")

    return coefs

In [None]:
def get_closed_cycle() -> tuple[Model, dict[str, int], dict[str, list[int]]]:
    """Build a three-metabolite closed cycle with a branched x2 -> x3 step.

    All reactions use ``michaelis_menten_1s``. The x2 -> x3 conversion runs
    through two parallel reactions, ``v2a`` and ``v2b``, which share ``km_2``
    but differ in label map (``v2b`` uses the reversed map ``[1, 0]``).

    | Reaction       | Labelmap |
    | -------------- | -------- |
    | x1 ->[v1] x2   | [0, 1]   |
    | x2 ->[v2a] x3  | [0, 1]   |
    | x2 ->[v2b] x3  | [1, 0]   |
    | x3 ->[v3] x1   | [0, 1]   |

    Returns:
        The model, the label-variable sizes (2 for each of x1, x2, x3), and
        the label map of each reaction.
    """
    parameters = {
        "vmax_1": 1.0,
        "km_1": 0.5,
        "vmax_2a": 1.0,
        "vmax_2b": 1.0,
        "km_2": 0.5,
        "vmax_3": 1.0,
        "km_3": 0.5,
    }
    initial_concentrations = {"x1": 1.0, "x2": 0.0, "x3": 0.0}

    # (reaction, substrate, product, vmax parameter, km parameter, label map)
    reaction_specs = (
        ("v1", "x1", "x2", "vmax_1", "km_1", [0, 1]),
        ("v2a", "x2", "x3", "vmax_2a", "km_2", [0, 1]),
        ("v2b", "x2", "x3", "vmax_2b", "km_2", [1, 0]),
        ("v3", "x3", "x1", "vmax_3", "km_3", [0, 1]),
    )

    model = Model().add_parameters(parameters).add_variables(initial_concentrations)
    label_maps: dict[str, list[int]] = {}
    for reaction, substrate, product, vmax, km, label_map in reaction_specs:
        model = model.add_reaction(
            reaction,
            michaelis_menten_1s,
            stoichiometry={substrate: -1, product: 1},
            args=[substrate, vmax, km],
        )
        label_maps[reaction] = label_map

    label_variables: dict[str, int] = dict.fromkeys(initial_concentrations, 2)
    return model, label_variables, label_maps

In [None]:
model, label_variables, label_maps = get_closed_cycle()

# Reference steady state with vmax_2b set to half of vmax_2a.
ss_concs, ss_fluxes = unwrap2(
    Simulator(model)
    .update_parameters({"vmax_2a": 1.0, "vmax_2b": 0.5})
    .simulate_to_steady_state()
    .get_full_concs_and_fluxes()
)
mapper = LinearLabelMapper(model, label_variables=label_variables, label_maps=label_maps)

# Label time course starting from that steady state, with initial label
# {"x1": 0} (presumably position 0 of x1), simulated for 5 time units.
label_model = mapper.build_model(
    ss_concs.iloc[-1], ss_fluxes.iloc[-1], initial_labels={"x1": 0}
)
label_time_course = unwrap(Simulator(label_model).simulate(5).get_concs())

_, axs = plot.relative_label_distribution(
    mapper, label_time_course, sharey=True, n_cols=3
)
axs[0].set_ylabel("Relative label distribution")
axs[1].set_xlabel("Time / s")
plt.show()

In [None]:
# Prior over the unknown parameter: vmax_2b ~ Normal(0.5, 0.1), clipped at 0
# so that no sampled maximal rate is negative.
# TODO: seed the sampler (if `sample` accepts an RNG) so the prior is reproducible.
targets = sample(
    {
        "vmax_2b": Normal(0.5, 0.1),
    },
    n=1000,
).clip(lower=0)

ax = sns.kdeplot(targets, fill=True)
ax.set_title("Prior")
plt.show()

In [None]:
# Steady states for every sampled prior value (presumably one row per sample).
# Note: this rebinds ss_concs / ss_fluxes, replacing the single reference steady state.
(ss_concs, ss_fluxes) = scan.steady_state(model, parameters=targets)

In [None]:
# Distribution of steady-state concentrations (left) and fluxes (right) across the prior samples.
fig, (conc_ax, flux_ax) = plt.subplots(1, 2, figsize=(12, 4))
_, ax = plot.violins(ss_concs, ax=conc_ax)
ax.set_ylabel("Concentration / a.u.")
_, ax = plot.violins(ss_fluxes, ax=flux_ax)
ax.set_ylabel("Flux / a.u.")

In [None]:
# Label distributions at t = 5 for every sampled steady state, starting from
# the initial label {"x1": 0}. All resulting columns are used as NPE features.
features = get_label_distribution_at_time(
    model=model,
    label_variables=label_variables,
    label_maps=label_maps,
    time=5,
    ss_concs=ss_concs,
    ss_fluxes=ss_fluxes,
    initial_labels={"x1": 0},
)
_, ax = plot.violins(features)
ax.set_ylabel("Relative label distribution")
plt.show()

In [None]:
# Train the neural posterior estimator mapping label features to the sampled targets.
estimator, losses = npe.train_torch_ss_estimator(
    features=features, targets=targets, epochs=2_000
)

loss_ax = losses.plot()
loss_ax.set_ylim(0, None)

In [None]:
fig, (prior_ax, posterior_ax) = plt.subplots(
    1, 2, figsize=(8, 3), layout="constrained", sharex=True, sharey=False
)

sns.kdeplot(targets, fill=True, ax=prior_ax).set_title("Prior")

# Posterior estimate of the targets for every simulated feature row.
posterior = estimator.predict(features)

sns.kdeplot(posterior, fill=True, ax=posterior_ax).set_title("Posterior")
# NOTE(review): a posterior narrower than the prior has a taller density peak,
# so copying the prior's y-limits may clip it. Confirm this is intended.
posterior_ax.set_ylim(*prior_ax.get_ylim())
plt.show()

# Inverse parameter sensitivity

In [None]:
# Elasticities of the estimator's prediction w.r.t. each feature, for the first feature row.
_ = plot.heatmap(
    inverse_parameter_elasticity(estimator, features.iloc[0])
)

In [None]:
# vmax_2b row of the elasticity matrix for every feature row; one output row per sample.
elasticities = pd.DataFrame(
    {
        sample_id: inverse_parameter_elasticity(estimator, feature_row).loc["vmax_2b"]
        for sample_id, feature_row in features.iterrows()
    }
).transpose()

_ = plot.violins(elasticities)