In [None]:
from __future__ import annotations

import itertools as it
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm
from torch import nn, optim

from modelbase2.mc import parallelise
from modelbase2.mc._scan import _empty_flux_series
from modelbase2 import Model, Simulator, TorchSurrogate
from modelbase2.mc import Cache


def filter_stoichiometry(
    model: Model,
    stoichiometry: dict[str, float],
) -> dict[str, float]:
    """Restrict a stoichiometry to the entries that are model variables.

    Entries whose key is known to the model but is not a variable (e.g. a
    parameter) are silently dropped. Entries whose key is unknown to the model
    raise.

    Args:
        model: the model whose ``_variables`` / ``_ids`` are consulted.
        stoichiometry: mapping of component name to stoichiometric coefficient.

    Returns:
        A new dict containing only the variable entries, in input order.

    Raises:
        KeyError: if a key is neither a variable nor in ``model._ids``.
    """
    kept: dict[str, float] = {}
    for name, coeff in stoichiometry.items():
        is_variable = name in model._variables  # noqa: SLF001
        if not is_variable and name not in model._ids:  # noqa: SLF001
            msg = f"Missing component {name}"
            raise KeyError(msg)
        if is_variable:
            kept[name] = coeff
    return kept


def constant(x: float) -> float:
    """Rate function that returns its single argument ``x`` unchanged."""
    value = x
    return value


def michaelis_menten_2s(
    s1: float,
    s2: float,
    vmax: float,
    km1: float,
    km2: float,
    ki1: float,
) -> float:
    """Two-substrate Michaelis-Menten rate.

    Computes ``vmax * s1 * s2 / (ki1*km2 + km2*s1 + km1*s2 + s1*s2)``.

    Args:
        s1, s2: concentrations of the two substrates.
        vmax: maximal rate.
        km1, km2: Michaelis constants for ``s1`` and ``s2``.
        ki1: constant for ``s1`` appearing in the ``ki1 * km2`` term.
    """
    numerator = vmax * s1 * s2
    denominator = ki1 * km2 + km2 * s1 + km1 * s2 + s1 * s2
    return numerator / denominator

# Create model

In [None]:
def model_to_be_replaced() -> Model:
    """Build the reference kinetic model that the surrogate will replace.

    x1, ATP and NADPH are added as parameters rather than variables. They
    therefore stay fixed during a simulation and can be scanned as the
    surrogate's inputs. Because they are not variables, ``filter_stoichiometry``
    drops them from the stoichiometries, so only x2 and x3 change over time.

    Reactions:
        v2: x1 + ATP -> x2, two-substrate Michaelis-Menten rate in (x1, ATP)
        v3: x1 + NADPH -> x3, two-substrate Michaelis-Menten rate in (x1, NADPH)
        x2_out / x3_out: outflows whose rate equals the concentration of x2 / x3

    Returns:
        The assembled ``Model``.
    """
    model = Model()
    model.add_variables({"x2": 0.0, "x3": 0.0})
    model.add_parameters(
        {
            # Clamped inputs: these must stay static so they can be varied
            # as inputs when generating training data for the surrogate
            "x1": 1.0,
            "ATP": 1.0,
            "NADPH": 1.0,
            # v2
            "vmax_v2": 2.0,
            "km_v2_1": 0.1,
            "km_v2_2": 0.1,
            "ki_v2": 0.1,
            # v3
            "vmax_v3": 2.0,
            "km_v3_1": 0.2,
            "km_v3_2": 0.2,
            "ki_v3": 0.2,
        }
    )

    model.add_reaction(
        "v2",
        michaelis_menten_2s,
        filter_stoichiometry(model, {"x1": -1, "ATP": -1, "x2": 1}),
        ["x1", "ATP", "vmax_v2", "km_v2_1", "km_v2_2", "ki_v2"],
    )
    model.add_reaction(
        "v3",
        michaelis_menten_2s,
        filter_stoichiometry(model, {"x1": -1, "NADPH": -1, "x3": 1}),
        # v3 consumes NADPH, so its second substrate must be NADPH (not ATP)
        ["x1", "NADPH", "vmax_v3", "km_v3_1", "km_v3_2", "ki_v3"],
    )
    # First-order outflows: `constant` returns its argument, so the rate is x2 / x3
    model.add_reaction("x2_out", constant, {"x2": -1}, ["x2"])
    model.add_reaction("x3_out", constant, {"x3": -1}, ["x3"])

    return model


# Example plot of this model's behaviour
_ = (
    Simulator(model_to_be_replaced())
    .simulate_and(10)
    .get_fluxes()
    .plot(
        xlabel="time / s",
        ylabel="flux / (mM / s)",
        title="Reference model fluxes",
    )
)

# Create data

In [None]:
def ss_flux(
    params: pd.Series,
    model: Model,
) -> pd.Series:
    """Return the steady-state fluxes of ``model`` for one parameter set.

    ``params`` is applied with ``model.update_parameters`` and the result is
    simulated to steady state. If the simulator yields no fluxes, the result
    of ``_empty_flux_series(model)`` is returned. Otherwise the flux row of
    the last time point is returned.

    NOTE(review): ``update_parameters`` is presumably applied to the shared
    ``model`` instance in place. That is harmless here because every call sets
    the same keys (x1, ATP, NADPH), but confirm it is safe under ``parallelise``.
    """
    updated_model = model.update_parameters(params.to_dict())
    fluxes = Simulator(updated_model).simulate_to_steady_state_and().get_fluxes()
    return _empty_flux_series(model) if fluxes is None else fluxes.iloc[-1]


# Scan every combination of the three clamped inputs on a 21-point grid in [0, 2]
inp = pd.DataFrame(
    it.product(
        np.linspace(0, 2.0, 21),
        np.linspace(0, 2.0, 21),
        np.linspace(0, 2.0, 21),
    ),
    columns=["x1", "ATP", "NADPH"],
)


res = (
    pd.concat(
        parallelise(
            partial(
                ss_flux,
                model=model_to_be_replaced(),
            ),
            inputs=list(inp.iterrows()),
            cache=Cache(),
        )
    )
    .unstack()
    # NOTE(review): parameter sets without a simulation result presumably give
    # NaN fluxes, which are turned into zero flux here. The surrogate then
    # learns them as real data; consider dropping those rows instead.
    .fillna(0)
)

# At steady state dx2/dt = v2 - x2_out = 0 (likewise for x3), so the outflows
# equal v2 and v3 and serve as the surrogate's training targets.
target = res.loc[:, ["x2_out", "x3_out"]]
# Pair every target row with the input row it was computed from (by label).
# Features and target are later converted to tensors positionally, so they
# must line up regardless of the order in which `parallelise` returns results.
features = inp.loc[target.index]

# Train surrogate

In [None]:
class Approximator(nn.Module):
    """Fully connected network with two ReLU hidden layers.

    Maps ``n_inputs`` features to ``n_outputs`` values. All linear weights are
    initialised from a normal distribution (mean 0, std 0.1) and all biases
    start at zero. The output layer has no activation, so predictions are
    unbounded and may be negative.

    Args:
        n_inputs: number of input features.
        n_outputs: number of predicted outputs.
        n_hidden: width of each of the two hidden layers (default 50).
    """

    def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int = 50) -> None:
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(n_inputs, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_outputs),
        )

        # Small-variance normal weights and zero biases for every linear layer
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x``, whose last dimension must be ``n_inputs``."""
        return self.net(x)


# Seed torch so the random weight initialisation, and therefore the trained
# surrogate, is reproducible on Restart & Run All.
torch.manual_seed(0)

N_EPOCHS = 2000

device = torch.device("cpu")
aprox = Approximator(
    n_inputs=len(features.columns),
    n_outputs=len(target.columns),
).to(device)
optimizer = optim.RMSprop(aprox.parameters(), lr=1e-3)

# `torch.tensor` is the recommended factory. The legacy `torch.Tensor(...)`
# constructor only accepts CPU devices, so it would break if `device` changed.
X = torch.tensor(features.to_numpy(), dtype=torch.float32, device=device)
Y = torch.tensor(target.to_numpy(), dtype=torch.float32, device=device)

# TODO: batch the training
losses = {}
for epoch in tqdm.trange(N_EPOCHS):
    optimizer.zero_grad()
    # Mean absolute error over all samples and outputs
    loss = torch.mean(torch.abs(aprox(X) - Y))
    loss.backward()
    optimizer.step()
    # `.item()` yields a plain Python float and works on any device
    losses[epoch] = loss.item()

_ = pd.Series(losses, dtype=float).plot(xlabel="Epoch", ylabel="Loss")


# Insert surrogate into model

In [None]:
def get_model() -> Model:
    """Build a model whose dynamics come solely from the trained surrogate.

    The surrogate wraps the network ``aprox`` (a global defined elsewhere in
    this notebook). It reads x1, ATP and NADPH and produces two fluxes:
    v2 (x1 + ATP -> x2) and v3 (x1 + NADPH -> x3).
    """
    model = Model()
    initial_values = {"x1": 1.0, "x2": 0.0, "x3": 0.0, "ATP": 2.0, "NADPH": 0.1}
    model.add_variables(initial_values)

    surrogate = TorchSurrogate(
        model=aprox,
        inputs=["x1", "ATP", "NADPH"],
        stoichiometries={
            "v2": {"x1": -1, "x2": 1, "ATP": -1},
            "v3": {"x1": -1, "x3": 1, "NADPH": -1},
        },
    )
    # Apart from the surrogate no reactions are defined, although ordinary
    # reactions could be added alongside it.
    model.add_surrogate("surrogate", surrogate)
    return model


c, v = Simulator(get_model()).simulate_and(2).get_full_concs_and_fluxes()

# FIXME: NADPH becomes negative during the simulation.
# A likely cause is that the surrogate's outputs (a plain linear output layer)
# are not constrained to vanish when a substrate is exhausted.
# NOTE(review): also check that the v3 training data actually depends on NADPH.
# The rates do appear to drop to ~0 around the time x1 reaches 0.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
c.plot(ax=ax1, xlabel="time / s", ylabel="concentration / mM", title="Concentrations")
v.plot(ax=ax2, xlabel="time / s", ylabel="flux / (mM / s)", title="Fluxes")
fig.tight_layout()
plt.show()

# Manual sanity checks of the surrogate's predictions

In [None]:
surrogate = TorchSurrogate(
    model=aprox,
    inputs=["x1", "ATP", "NADPH"],
    stoichiometries={
        "v2": {"x1": -1, "x2": 1, "ATP": -1},
        "v3": {"x1": -1, "x3": 1, "NADPH": -1},
    },
)

# Probe points in (x1, ATP, NADPH) order, matching `inputs` above
probe_points = [
    (0.0, 0.0, 0.0),
    (1.0, 0.0, 0.0),
    (0.0, 1.0, 0.0),
    (0.0, 0.0, 1.0),
    (1.0, 0.0, 1.0),
    (1.0, 1.0, 1.0),
]
for point in probe_points:
    print(surrogate.predict(np.array(point)))