In [None]:
import libsbml

In [None]:
import os

# Location of the Giordano2020 SBML file. Set the GIORDANO2020_SBML environment
# variable to point elsewhere instead of editing this absolute default.
SBML_PATH = os.environ.get(
    "GIORDANO2020_SBML",
    "/home/lena/master-thesis/repos/optimal-control/data/Giordano2020.xml",
)

sbml_document: libsbml.SBMLDocument = libsbml.readSBMLFromFile(SBML_PATH)

model: libsbml.Model = sbml_document.getModel()
# readSBMLFromFile does not raise on a missing or malformed file; it records the
# problems on the returned document, which then has no model. Fail loudly here
# instead of with an AttributeError on `None` in a later cell.
if model is None:
    error_messages = "\n".join(
        sbml_document.getError(k).getMessage()
        for k in range(sbml_document.getNumErrors())
    )
    raise RuntimeError(
        f"Could not load an SBML model from {SBML_PATH!r}:\n{error_messages}"
    )

model.getName()

In [None]:
# Print every top-level SBML component list of the model, one list per line,
# in a fixed order.
for get_component_list in (
    model.getListOfCompartments,
    model.getListOfCompartmentTypes,
    model.getListOfEvents,
    model.getListOfConstraints,
    model.getListOfFunctionDefinitions,
    model.getListOfInitialAssignments,
    model.getListOfParameters,
    model.getListOfRules,
    model.getListOfSpecies,
    model.getListOfReactions,
    model.getListOfSpeciesTypes,
    model.getListOfUnitDefinitions,
):
    print(list(get_component_list()))

In [None]:
# Plain-text dump of the model: species, global parameters, reactions (with
# their local parameters and rate-law structure) and function definitions.
print("--Species--")
for i in range(model.getNumSpecies()):
    species: libsbml.Species = model.getSpecies(i)
    print(
        f"{species.getName()} {species.getIdAttribute()} {species.getInitialConcentration()}"
    )

print("")
print("--Parameters--")
for i in range(model.getNumParameters()):
    parameter: libsbml.Parameter = model.getParameter(i)
    print(f"{parameter.getName()} {parameter.getIdAttribute()} {parameter.getValue()}")

print("")
print("--Reactions--")
for i in range(model.getNumReactions()):
    reaction: libsbml.Reaction = model.getReaction(i)
    print(f"{reaction.getName()} {reaction.getIdAttribute()}")

    kinetic_law: libsbml.KineticLaw = reaction.getKineticLaw()
    # Kinetic laws are optional in SBML; getKineticLaw() returns None then.
    if kinetic_law is None:
        print("\t<no kinetic law>")
        continue

    for j in range(kinetic_law.getNumParameters()):
        parameter: libsbml.Parameter = kinetic_law.getParameter(j)
        print(
            f"\t{parameter.getName()} {parameter.getIdAttribute()} {parameter.getValue()}"
        )

    # Named `rate_math` (not `math`) so the stdlib module name is not shadowed.
    rate_math: libsbml.ASTNode = kinetic_law.getMath()
    if rate_math is None:
        print("\t<no math>")
        continue

    # Root structure, full formula, then each top-level operand. Iterating over
    # the children avoids assuming the root has at least two of them.
    root_operator = rate_math.getOperatorName() if rate_math.isOperator() else None
    print(f"\t{rate_math.getNumChildren()} {root_operator}")
    print(f"\t{libsbml.formulaToL3String(rate_math)}")
    for k in range(rate_math.getNumChildren()):
        print(f"\t{libsbml.formulaToL3String(rate_math.getChild(k))}")

print("")
print("--Function Definitions--")
for i in range(model.getNumFunctionDefinitions()):
    fdef: libsbml.FunctionDefinition = model.getFunctionDefinition(i)
    print(f"{fdef.getName()} {fdef.getIdAttribute()}")

In [None]:
from typing import List, Union, Callable, Dict
from functools import partial


def species_to_dict(species_list: libsbml.ListOfSpecies) -> Dict[str, float]:
    """Map every species id in ``species_list`` to its initial concentration.

    Args:
        species_list: libSBML list of species, accessed by integer index.

    Returns:
        ``{species_id: initial_concentration}`` in list order; if two entries
        share an id, the later one wins.
    """
    entries = (species_list[idx] for idx in range(len(species_list)))
    return {entry.getIdAttribute(): entry.getInitialConcentration() for entry in entries}


def parameters_to_dict(parameter_list: libsbml.ListOfParameters) -> Dict[str, float]:
    """Map every parameter id in ``parameter_list`` to its value.

    Args:
        parameter_list: libSBML list of (global or kinetic-law) parameters,
            accessed by integer index.

    Returns:
        ``{parameter_id: value}`` in list order; if two entries share an id,
        the later one wins.
    """
    entries = (parameter_list[idx] for idx in range(len(parameter_list)))
    return {entry.getIdAttribute(): entry.getValue() for entry in entries}


def species_references_to_dict(
    species_references_list: libsbml.ListOfSpeciesReferences,
) -> Dict[str, float]:
    """Map every referenced species id to its stoichiometry.

    Args:
        species_references_list: reactant or product references of a reaction,
            accessed by integer index.

    Returns:
        ``{species_id: stoichiometry}``. If the same species is referenced more
        than once, the last reference wins (stoichiometries are not summed).
    """
    references = (
        species_references_list[idx] for idx in range(len(species_references_list))
    )
    return {ref.getSpecies(): ref.getStoichiometry() for ref in references}


def reactions_to_dict(
    reaction_list: libsbml.ListOfReactions,
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """Summarise each reaction's local parameters, reactants and products.

    Args:
        reaction_list: libSBML list of reactions, accessed by integer index.

    Returns:
        ``{reaction_id: {"parameters": {id: value},
                         "reactants": {species_id: stoichiometry},
                         "products": {species_id: stoichiometry}}}``.
        A reaction without a kinetic law gets an empty ``"parameters"`` dict.
    """
    reaction_dict = {}
    for i in range(len(reaction_list)):
        reaction: libsbml.Reaction = reaction_list[i]
        kinetic_law: libsbml.KineticLaw = reaction.getKineticLaw()
        # Kinetic laws are optional in SBML; getKineticLaw() returns None then,
        # which previously crashed with an AttributeError.
        local_parameters = (
            parameters_to_dict(kinetic_law.getListOfParameters())
            if kinetic_law is not None
            else {}
        )
        reaction_dict[reaction.getIdAttribute()] = {
            "parameters": local_parameters,
            "reactants": species_references_to_dict(reaction.getListOfReactants()),
            "products": species_references_to_dict(reaction.getListOfProducts()),
        }

    return reaction_dict


def model_to_dict(model: libsbml.Model) -> Dict[str, Dict[str, Union[float, dict]]]:
    """Summarise a libSBML model as plain nested dicts.

    Args:
        model: the libSBML model to summarise.

    Returns:
        ``{"species": {species_id: initial_concentration},
           "parameters": {parameter_id: value},
           "reactions": {reaction_id: {"parameters": ..., "reactants": ...,
                                       "products": ...}}}``.
        The "reactions" values are nested dicts, hence the widened annotation
        (the previous ``Dict[str, Dict[str, float]]`` did not describe them).
    """
    return {
        "species": species_to_dict(model.getListOfSpecies()),
        "parameters": parameters_to_dict(model.getListOfParameters()),
        "reactions": reactions_to_dict(model.getListOfReactions()),
    }


def eval_node(node: libsbml.ASTNode, values: Dict[str, float]) -> float:
    """Recursively evaluate a libSBML math AST.

    Only Python arithmetic operators (and ``abs``) are applied to operand
    values, so ``values`` may hold plain floats or any objects overloading
    those operators (e.g. JAX arrays/tracers).

    Args:
        node: root of the (sub)expression to evaluate.
        values: mapping from identifier to its current value; every name
            referenced by ``node`` must be present.

    Returns:
        The value of the expression.

    Raises:
        KeyError: if a referenced name is missing from ``values``.
        ValueError: if an operator has an unsupported number of operands.
        TypeError: if the node or function type is not supported.
    """
    node_type = node.getType()

    if node.isOperator():
        operands = [
            eval_node(node.getChild(i), values) for i in range(node.getNumChildren())
        ]
        if not operands:
            raise ValueError(node.getNumChildren())

        # MathML <plus/> and <times/> are n-ary: libSBML may keep e.g. a*b*c as
        # a single node with three children, and a lone operand is allowed
        # (unary plus / one-factor product). Fold left-to-right.
        if node_type == libsbml.AST_PLUS:
            total = operands[0]
            for operand in operands[1:]:
                total = total + operand
            return total
        if node_type == libsbml.AST_TIMES:
            product = operands[0]
            for operand in operands[1:]:
                product = product * operand
            return product

        # Only <minus/> has a one-operand (negation) form; previously any
        # one-child operator was negated, flipping the sign of unary plus.
        if node_type == libsbml.AST_MINUS and len(operands) == 1:
            return -operands[0]

        if len(operands) != 2:
            raise ValueError(node.getNumChildren())
        left, right = operands
        if node_type == libsbml.AST_MINUS:
            return left - right
        if node_type == libsbml.AST_DIVIDE:
            return left / right
        if node_type == libsbml.AST_POWER:
            return left**right
        raise TypeError(f"Unsupported operator node type {node_type}")

    if node.isFunction():
        if node_type == libsbml.AST_FUNCTION_POWER and node.getNumChildren() == 2:
            base = eval_node(node.getLeftChild(), values)
            exponent = eval_node(node.getRightChild(), values)
            return base**exponent
        if node_type == libsbml.AST_FUNCTION_ABS and node.getNumChildren() == 1:
            return abs(eval_node(node.getChild(0), values))
        raise TypeError(
            f"Unsupported function node type {node_type} "
            f"(name={node.getName()!r}, children={node.getNumChildren()})"
        )

    if node.isNumber():
        # Read integer literals via getInteger() instead of relying on
        # getReal(), which targets the real / e-notation / rational forms.
        if node.isInteger():
            return float(node.getInteger())
        return node.getReal()

    if node.isName():
        return values[node.getName()]

    raise TypeError(
        f"Unsupported AST node type {node_type} (name={node.getName()!r})"
    )


def ast_to_lambda(root: libsbml.ASTNode) -> Callable[[Dict[str, float]], float]:
    """Bind ``root`` into an evaluator ``f(values) -> float`` backed by ``eval_node``.

    ``root`` is bound positionally, so the evaluator accepts ``values`` either
    positionally or as ``values=...``. Binding it as ``node=root`` made
    ``f(values)`` fail with "got multiple values for argument 'node'".

    Args:
        root: root node of the expression to evaluate.

    Returns:
        A ``functools.partial`` of ``eval_node`` with ``root`` fixed.
    """
    return partial(eval_node, root)


def model_to_lambda(
    model: libsbml.Model,
    extra_values: Union[Dict[str, float], None] = None,
) -> Callable[[Dict[str, float]], Dict[str, float]]:
    """Build the reaction-network right-hand side ``dy/dt = f(y)`` of ``model``.

    Everything that does not depend on the state (rate-law evaluators,
    parameter values, stoichiometries) is read from ``model`` once, here, so
    the returned function only does arithmetic on ``y``. Later edits to
    ``model`` are therefore not seen by an already-built function.

    Args:
        model: libSBML model whose reactions define the ODE system.
        extra_values: additional named values visible to the kinetic laws.
            Defaults to the size of every compartment of ``model`` whose size
            is set (replacing a previously hard-coded ``{"Italy": 1.0}``).

    Returns:
        A function mapping a ``{species_id: value}`` state to
        ``{species_id: d value / dt}`` with the same keys.

    Raises:
        ValueError: if a reaction has no kinetic law.
    """
    if extra_values is None:
        extra_values = {
            compartment.getIdAttribute(): compartment.getSize()
            for compartment in model.getListOfCompartments()
            if compartment.isSetSize()
        }

    # Name-resolution precedence (later wins): state y, extra values, global
    # parameters, then the reaction's local parameters.
    constants = {**extra_values, **parameters_to_dict(model.getListOfParameters())}

    # SBML: reactions do not change species with boundaryCondition="true".
    boundary_species = {
        species.getIdAttribute()
        for species in model.getListOfSpecies()
        if species.getBoundaryCondition()
    }

    def _stoichiometries(references) -> list:
        """(species_id, stoichiometry) pairs, skipping boundary species."""
        return [
            (ref.getSpecies(), ref.getStoichiometry())
            for ref in references
            if ref.getSpecies() not in boundary_species
        ]

    compiled_reactions = []
    for reaction in model.getListOfReactions():
        kinetic_law: libsbml.KineticLaw = reaction.getKineticLaw()
        if kinetic_law is None:
            raise ValueError(
                f"Reaction {reaction.getIdAttribute()!r} has no kinetic law"
            )
        compiled_reactions.append(
            (
                ast_to_lambda(kinetic_law.getMath()),
                parameters_to_dict(kinetic_law.getListOfParameters()),
                _stoichiometries(reaction.getListOfReactants()),
                _stoichiometries(reaction.getListOfProducts()),
            )
        )

    def ode_fn(y: Dict[str, float]) -> Dict[str, float]:
        dy_dt = {species_id: 0.0 for species_id in y}
        state_and_constants = {**y, **constants}

        for rate_fn, local_parameters, reactants, products in compiled_reactions:
            velocity = rate_fn(values={**state_and_constants, **local_parameters})
            for species_id, stoichiometry in reactants:
                dy_dt[species_id] -= stoichiometry * velocity
            for species_id, stoichiometry in products:
                dy_dt[species_id] += stoichiometry * velocity

        return dy_dt

    return ode_fn

In [None]:
import jax

# Configure JAX first: 64-bit floats and the CPU platform. This happens before
# diffrax is imported, matching the original ordering.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import diffrax

model_fn = model_to_lambda(model)
species_dict = species_to_dict(model.getListOfSpecies())


def vector_field(t, y, args):
    """diffrax vector field; the model is autonomous, so ``t`` and ``args`` are unused."""
    return model_fn(y)


terms = diffrax.ODETerm(vector_field)

In [None]:
import warnings

solution = diffrax.diffeqsolve(
    terms=terms,
    solver=diffrax.Kvaerno5(),
    t0=0.0,
    t1=100.0,
    dt0=0.01,
    y0=species_dict,
    saveat=diffrax.SaveAt(dense=True),
    stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-4),
    max_steps=100000,
    throw=False,
)

# With throw=False, diffrax reports failures (e.g. running out of max_steps)
# through `solution.result` instead of raising. Surface them rather than
# continuing silently with a possibly truncated solution.
if not bool(solution.result == diffrax.RESULTS.successful):
    warnings.warn(f"diffeqsolve did not finish successfully: {solution.result}")

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Time window to inspect. NOTE(review): this is only the start of the solved
# interval; widen PLOT_T1 to see the full trajectory.
PLOT_T0, PLOT_T1, N_PLOT_POINTS = 0.0, 0.1, 100

plot_ts = jnp.linspace(PLOT_T0, PLOT_T1, N_PLOT_POINTS)
# solution.evaluate interpolates the dense solution at a single time; vmap it
# over the time grid.
plot_ys = jax.vmap(solution.evaluate)(plot_ts)

# Plot against time (previously the x-axis was the sample index 0..99).
fig, ax = plt.subplots()
ax.plot(plot_ts, plot_ys["ACE"], label="ACE")
ax.set(xlabel="t", ylabel="ACE", title="ACE trajectory")
ax.legend()
plt.show()