In [None]:
from __future__ import annotations

from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
from example_models import get_example1

from modelbase2 import Cache, Simulator, npe
from modelbase2.distributions import LogNormal, sample


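# Inspect the example model

Simulate the example model from `get_example1()` once and plot its fluxes to get a feel for its behaviour before using it to generate training data.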

In [None]:
# Quick look at the example model: simulate it and plot the resulting fluxes.
# The plot result is bound to `_` only so the cell does not echo it as output.
example_model = get_example1()
_ = (
    Simulator(example_model)
    .simulate(10)
    .get_fluxes()
    .plot()
)

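# Create training data

Draw 10,000 parameter sets from log-normal priors (the targets), then compute the model's fluxes for each parameter set with `npe.create_ss_flux_data` (the features).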

In [None]:
# Log-normal priors for the parameters the estimator should learn to recover.
priors = {
    "x1": LogNormal(mean=1.0, sigma=0.3),
    "ATP": LogNormal(mean=0.7, sigma=0.1),
    "NADPH": LogNormal(mean=0.3, sigma=0.2),
}
# NOTE(review): no seed / rng is passed to `sample` here. If the distributions
# are not seeded internally, every re-run draws new targets — confirm that the
# on-disk cache below cannot then return features computed for an earlier draw.
targets = sample(priors, n=10_000)

# Fluxes for every sampled parameter set, cached on disk between runs.
flux_cache_dir = Path(".cache") / "dist"
features = npe.create_ss_flux_data(
    get_example1(),
    parameters=targets,
    cache=Cache(flux_cache_dir),
)

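# Train NPE

Fit a torch estimator that maps the flux features back to the sampled parameters, then check that the training loss has converged.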

In [None]:
# Number of training epochs passed to the estimator; lower it for quick iterations.
N_EPOCHS = 10_000

# Learn the mapping features (fluxes) -> targets (sampled parameters).
estimator, losses = npe.train_torch_estimator(
    features=features,
    targets=targets,
    epochs=N_EPOCHS,
)

# `losses` is a pandas object (it has `.plot()`), so this returns a matplotlib Axes.
# Give the figure a title/label and show it explicitly instead of echoing the Axes repr.
loss_ax = losses.plot()
loss_ax.set_title("Training loss")
loss_ax.set_ylabel("loss")
plt.show()

In [None]:
# Compare the sampled prior with the estimator's predictions, side by side on
# shared axes.
# NOTE(review): predictions are made on the same `features` the estimator was
# trained on, so the right panel is the distribution of predictions over the
# training set, not a posterior for a single new observation.
posterior = estimator.predict(features.to_numpy())

fig, (ax1, ax2) = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(8, 3),
    layout="constrained",
    sharex=True,
    sharey=True,
)

for ax, data, title in (
    (ax1, targets, "Prior"),
    (ax2, posterior, "Posterior"),
):
    sns.kdeplot(data, fill=True, ax=ax)
    ax.set_title(title)

plt.show()