In [None]:
from __future__ import annotations

import itertools as it
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from example_models import get_example1

from modelbase2 import Cache, Model, Simulator
from modelbase2.surrogates import create_ss_flux_data, train_torch_surrogate


# Create model

In [None]:
# Example plot of this model's behaviour: simulate the example model and
# plot the resulting fluxes. The return value of `.plot()` is bound to `_`
# so the cell does not echo a repr.
_ = (
    Simulator(get_example1())
    .simulate(10)
    .get_fluxes()
    .plot()
)


# Create data

In [None]:
# Regular grid over the three surrogate inputs: 21 evenly spaced values on
# [0, 2] per axis, combined as a full Cartesian product (21**3 rows).
grid_axis = np.linspace(0, 2.0, 21)
features = pd.DataFrame(
    it.product(grid_axis, repeat=3),
    columns=["x1", "ATP", "NADPH"],
)

# Flux data for every feature row, computed from the example model and cached
# under .cache/linear (presumably so re-runs can skip the computation —
# confirm against `Cache`). Only the two output fluxes are used as targets.
steady_state_fluxes = create_ss_flux_data(
    get_example1(),
    features,
    cache=Cache(Path(".cache") / "linear"),
)
targets = steady_state_fluxes.loc[:, ["x2_out", "x3_out"]]

# Train Surrogate

In [None]:
# Seed torch's global RNG before training. Without a seed every
# "Restart & Run All" yields a different surrogate, so the predictions and
# the simulation below are not reproducible.
# NOTE(review): assumes train_torch_surrogate draws its randomness (weight
# initialisation etc.) from torch's global generator — confirm.
torch.manual_seed(0)

# Fit the surrogate on the feature grid -> flux targets computed above.
# `surrogate_inputs` lists the variables the surrogate reads, and
# `surrogate_stoichiometries` gives, per surrogate reaction, the
# stoichiometric coefficient of each variable it changes.
surrogate, loss = train_torch_surrogate(
    features=features,
    targets=targets,
    epochs=2000,
    surrogate_inputs=["x1", "ATP", "NADPH"],
    surrogate_stoichiometries={
        "v2": {"x1": -1, "x2": 1, "ATP": -1},
        "v3": {"x1": -1, "x3": 1, "NADPH": -1},
    },
)

# Training loss returned by train_torch_surrogate
loss.plot()

## Get predictions of rates

In [None]:
# Probe the trained surrogate at a handful of hand-picked input points.
# Each point has three entries, presumably ordered like the surrogate inputs
# (x1, ATP, NADPH) — confirm.
probe_points = [
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 0.0, 1.0],
    [1.0, 1.0, 1.0],
]
for probe in probe_points:
    print(surrogate.predict(np.array(probe)))

# Insert surrogate into model

In [None]:
def get_model() -> Model:
    """Build a model whose only reactions come from the trained surrogate.

    Reads the notebook-level ``surrogate`` (the training cell must have run)
    and registers it under the name ``"surrogate"``.

    Returns:
        A fresh ``Model`` with the variables x1, x2, x3, ATP and NADPH set to
        their initial values and the surrogate attached.
    """
    initial_values = {
        "x1": 1.0,
        "x2": 0.0,
        "x3": 0.0,
        "ATP": 2.0,
        "NADPH": 0.1,
    }

    surrogate_model = Model()
    surrogate_model.add_variables(initial_values)

    # The surrogate is the only source of reactions here; further explicit
    # reactions could be added alongside it, but none are.
    surrogate_model.add_surrogate("surrogate", surrogate)
    return surrogate_model


# Simulate the surrogate-only model and collect concentrations (c) and
# fluxes (v).
c, v = (
    Simulator(get_model())
    .simulate(0.8)
    .get_full_concs_and_fluxes()
)

# FIXME: note that NADPH gets negative
# At least the rates seem to get 0 around the time when x1 is 0
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
panels = (
    (c, ax1, "concentration / mM"),
    (v, ax2, "flux / (mM / s)"),
)
# Left panel: concentrations; right panel: fluxes. Both share the x label.
for frame, axis, y_label in panels:
    frame.plot(ax=axis, xlabel="time / s", ylabel=y_label)
plt.show()