In [None]:
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from example_models import get_example1

from modelbase2 import Cache, Simulator, npe, scans
from modelbase2.distributions import LogNormal, sample

if TYPE_CHECKING:
    from modelbase2.types import Array, ModelProtocol


def create_steady_state_data(
    model: ModelProtocol,
    parameters: pd.DataFrame,
    cache: Cache | None,
    y0: dict[str, float] | None = None,
) -> pd.DataFrame:
    """Simulate ``model`` to steady state once per row of ``parameters``.

    Args:
        model: Model to scan.
        parameters: One parameter set per row; forwarded unchanged to
            ``scans.parameter_scan_ss``.
        cache: Optional on-disk cache forwarded to the scan.
        y0: Optional initial conditions forwarded to the scan.

    Returns:
        The pieces returned by the scan joined side by side (``axis=1``),
        re-indexed with a plain ``RangeIndex``.

    """
    # NOTE(review): presumably the scan yields (concentrations, fluxes) frames
    # that share the parameter-row index -- confirm against modelbase2.scans.
    scan_parts = scans.parameter_scan_ss(
        model=model,
        parameters=parameters,
        y0=y0,
        cache=cache,
    )
    joined = pd.concat(scan_parts, axis=1)
    # Discard the scan's own index so rows are simply numbered 0..n-1.
    return joined.reset_index(drop=True)


def create_time_series_data(
    model: ModelProtocol,
    parameters: pd.DataFrame,
    time_points: Array,
    cache: Cache | None,
    y0: dict[str, float] | None = None,
) -> pd.DataFrame:
    """Simulate ``model`` over ``time_points`` once per row of ``parameters``.

    Args:
        model: Model to scan.
        parameters: One parameter set per row; forwarded unchanged to
            ``scans.parameter_scan_time_series``.
        time_points: Times at which each simulation is reported.
        cache: Optional on-disk cache forwarded to the scan.
        y0: Optional initial conditions forwarded to the scan.

    Returns:
        The stacked second and third scan outputs joined column-wise, with
        the index levels named ``"n"`` and ``"time"``.

    """
    scan_result = scans.parameter_scan_time_series(
        model=model,
        parameters=parameters,
        time_points=time_points,
        y0=y0,
        cache=cache,
    )
    # The first scan output is not needed here.  The other two are presumably
    # per-run concentrations and fluxes -- confirm against modelbase2.scans.
    _, concs_per_run, fluxes_per_run = scan_result
    stacked = (pd.concat(concs_per_run), pd.concat(fluxes_per_run))
    result = pd.concat(stacked, axis=1)
    # Name the two index levels produced by stacking the per-run frames.
    result.index.names = ["n", "time"]
    return result


# Create model

In [None]:
# Example plot of this model's behaviour
example_fluxes = Simulator(get_example1()).simulate(10).get_fluxes()
# NOTE(review): assumes get_fluxes() may return None when the simulation fails
# (confirm against modelbase2); fail with a clear message rather than an
# opaque AttributeError on `.plot()`.
if example_fluxes is None:
    msg = "Simulating the example model failed; no fluxes to plot."
    raise RuntimeError(msg)
example_fluxes.plot(title="Example model: fluxes", ylabel="flux")
plt.show()

# Create data

In [None]:
# Configuration for data generation.
N_SAMPLES = 10_000
CACHE_DIR = Path(".cache")

# Prior samples of the parameters to infer; each row is one parameter set
# that is simulated below and later used as the NPE training target.
# NOTE(review): no explicit seed/rng is passed to `sample`; confirm draws are
# reproducible across kernel restarts, otherwise cached results below may not
# correspond to the current `targets`.
targets = sample(
    {
        "x1": LogNormal(mean=1.0, sigma=0.3),
        "ATP": LogNormal(mean=0.7, sigma=0.1),
        "NADPH": LogNormal(mean=0.3, sigma=0.2),
    },
    n=N_SAMPLES,
)

# 11 evenly spaced reporting times on [0, 10] for the time-series scan.
time_points = np.linspace(0, 10, 11)

# NOTE(review): simulation results are persisted under CACHE_DIR. If cache
# entries are keyed by parameter-row position rather than by parameter values,
# clear the corresponding directory after changing the priors or N_SAMPLES.
ss_data = create_steady_state_data(
    get_example1(),
    parameters=targets,
    cache=Cache(CACHE_DIR / "npe-ss"),
)

ts_data = create_time_series_data(
    get_example1(),
    parameters=targets,
    time_points=time_points,
    cache=Cache(CACHE_DIR / "npe-ts"),
)

# Train NPE on steady-state data


In [None]:
# Steady-state values of x2 and x3 are the observable features from which the
# sampled parameters (`targets`) are to be inferred.
features = ss_data.loc[:, ["x2", "x3"]]

estimator, losses = npe.train_torch_ss_estimator(
    features=features,
    targets=targets,
    epochs=5_000,
)

# Label the loss curve; plt.show() also keeps the `<Axes>` repr out of the output.
losses.plot(title="NPE training loss (steady state)", ylabel="loss")
plt.show()

In [None]:
fig, (ax1, ax2) = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(8, 3),
    layout="constrained",
    sharex=True,
    sharey=True,
)

# Estimator output for the same steady-state features it was trained on.
posterior = estimator.predict(features)

# Left panel: sampled prior; right panel: estimator prediction.
for ax, samples, title in ((ax1, targets, "Prior"), (ax2, posterior, "Posterior")):
    sns.kdeplot(samples, fill=True, ax=ax)
    ax.set_title(title)
plt.show()

# Train NPE on time-series data

In [None]:
# Simulated x2 and x3 trajectories are the features for the time-series
# estimator. Note this rebinds `features`, `estimator` and `losses`.
features = ts_data.loc[:, ["x2", "x3"]]

estimator, losses = npe.train_torch_time_series_estimator(
    features=features,
    targets=targets,
    epochs=5_000,
)

# Label the loss curve; plt.show() also keeps the `<Axes>` repr out of the output.
losses.plot(title="NPE training loss (time series)", ylabel="loss")
plt.show()

In [None]:
fig, (ax1, ax2) = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(8, 3),
    layout="constrained",
    sharex=True,
    sharey=True,
)

# Estimator output for the same time-series features it was trained on.
posterior = estimator.predict(features)

# Left panel: sampled prior; right panel: estimator prediction.
for ax, samples, title in ((ax1, targets, "Prior"), (ax2, posterior, "Posterior")):
    sns.kdeplot(samples, fill=True, ax=ax)
    ax.set_title(title)
plt.show()