In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

import diffrax
import matplotlib.pyplot as plt
import optax
from jaxtyping import Array, PyTree

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments.examples as examples
import optimal_control.environments.sbml as sbmlenv
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers

In [None]:
import optimal_control.sbml as sbmlutils

# Pretty-print the Giordano2020 SBML model with the library helper from
# optimal_control.sbml (the following cell defines a local function of the same name).
# NOTE(review): absolute, user-specific path repeated across cells — consider a single
# DATA_DIR / MODEL_PATH constant in the configuration cell.
sbmlutils.pprint_model("/home/lena/master-thesis/repos/optimal-control/data/Giordano2020.xml")

In [None]:
import rich.console
import rich.table
import libsbml
from typing import Union

def _print_table(console, title: str, columns: tuple, rows) -> None:
    """Build a rich table titled ``title`` and print it to ``console``.

    Args:
        console: Destination ``rich.console.Console``.
        title: Table title.
        columns: Column header strings, in display order.
        rows: Iterable of tuples of strings, one tuple per row, matching ``columns``.
    """
    table = rich.table.Table(title=title)
    for column in columns:
        table.add_column(column)
    for row in rows:
        table.add_row(*row)
    console.print(table)


def _math_to_string(math) -> str:
    """Format a libsbml math AST as an SBML Level 3 infix string ('' if it is unset)."""
    if math is None:
        return ""
    return libsbml.formulaToL3String(math)


def _kinetic_law_string(reaction) -> str:
    """Infix string of a reaction's kinetic law, or '' if the reaction has none."""
    # getKineticLaw() returns None when the reaction declares no kinetic law; calling
    # .getMath() on it directly would raise AttributeError.
    kinetic_law = reaction.getKineticLaw()
    if kinetic_law is None:
        return ""
    return _math_to_string(kinetic_law.getMath())


def pprint_model(
    model_or_filepath: Union[libsbml.Model, str],
    console: Union[rich.console.Console, None] = None,
):
    """Pretty-print an SBML model as a series of rich tables.

    Prints, in order, tables of species (name, id, initial concentration), parameters
    (name, id, value), reactions (name, id, kinetic law), function definitions
    (name, id, definition) and compartments (name, id, size).

    Args:
        model_or_filepath: An already-parsed ``libsbml.Model``, or a path to an SBML
            file which is read with ``libsbml.readSBMLFromFile``.
        console: Console to print to. When omitted, a fresh ``rich.console.Console()``
            is created for this call (rather than one instance shared by every call
            and created when the function is defined).

    Raises:
        ValueError: If ``model_or_filepath`` is a path and the document read from it
            contains no model (missing file, parse errors, ...).
    """
    if console is None:
        console = rich.console.Console()

    model: libsbml.Model
    if isinstance(model_or_filepath, libsbml.Model):
        model = model_or_filepath
    else:
        document = libsbml.readSBMLFromFile(model_or_filepath)
        model = document.getModel()
        # readSBMLFromFile does not raise on failure; it records errors on the
        # document and getModel() returns None.
        if model is None:
            raise ValueError(
                f"Could not read an SBML model from {model_or_filepath!r} "
                f"({document.getNumErrors()} error(s) reported by libsbml)"
            )

    _print_table(
        console,
        "Species",
        ("Species", "ID", "Initial Concentration"),
        (
            (
                species.getName(),
                species.getIdAttribute(),
                f"{species.getInitialConcentration()}",
            )
            for species in model.getListOfSpecies()
        ),
    )

    _print_table(
        console,
        "Parameters",
        ("Parameter", "ID", "Value"),
        (
            (
                parameter.getName(),
                parameter.getIdAttribute(),
                f"{parameter.getValue()}",
            )
            for parameter in model.getListOfParameters()
        ),
    )

    _print_table(
        console,
        "Reactions",
        ("Reaction", "ID", "Kinetic Law"),
        (
            (
                reaction.getName(),
                reaction.getIdAttribute(),
                _kinetic_law_string(reaction),
            )
            for reaction in model.getListOfReactions()
        ),
    )

    _print_table(
        console,
        "Function Definitions",
        ("Function", "ID", "Definition"),
        (
            (
                fdef.getName(),
                fdef.getIdAttribute(),
                _math_to_string(fdef.getMath()),
            )
            for fdef in model.getListOfFunctionDefinitions()
        ),
    )

    _print_table(
        console,
        "Compartments",
        ("Compartment", "ID", "Size"),
        (
            (
                compartment.getName(),
                compartment.getIdAttribute(),
                f"{compartment.getSize()}",
            )
            for compartment in model.getListOfCompartments()
        ),
    )

# Render the same model's tables with the locally defined pprint_model above.
# NOTE(review): hard-coded absolute path, repeated in several cells.
pprint_model("/home/lena/master-thesis/repos/optimal-control/data/Giordano2020.xml")

In [None]:
def control_output_fn(control_values: Array) -> PyTree:
    """Map the two raw control signals to SBML parameter overrides.

    ``control_values[0]`` (u0) drives a multiplicative factor ``exp(-u0)`` applied to
    the baseline values of ``alpha``, ``beta``, ``gamma`` and ``delta``.
    ``control_values[1]`` (u1) drives a factor ``2 * (1 - exp(-u1))`` applied to
    ``epsilon`` and ``theta``.

    Returns:
        Dict mapping parameter name -> overridden value, with keys in the order
        alpha, beta, gamma, delta, epsilon, theta.

    NOTE(review): the baseline constants are hard-coded; presumably they are the
    model's nominal parameter values — confirm against the SBML file. Also note that
    at u1 = 0 the testing factor is 0 (epsilon = theta = 0), whereas the spread factor
    is 1 (nominal values) at u0 = 0.
    """
    spread_factor = jnp.exp(-control_values[0])
    testing_factor = 2 * (1 - jnp.exp(-control_values[1]))

    spread_baselines = {"alpha": 0.57, "beta": 0.011, "gamma": 0.456, "delta": 0.011}
    testing_baselines = {"epsilon": 0.171, "theta": 0.371}

    overrides = {name: base * spread_factor for name, base in spread_baselines.items()}
    overrides.update(
        (name, base * testing_factor) for name, base in testing_baselines.items()
    )
    return overrides


# Training environment: integrates the Giordano2020 SBML model over t in [0, 1000],
# turning control values into parameter overrides via control_output_fn.
# Only the final state is saved (SaveAt(t1=True, dense=False)), which is the only
# state reward_fn reads. dtmax=1.0 caps the adaptive step size of the controller.
environment: sbmlenv.SBMLEnvironment = sbmlenv.SBMLEnvironment(
    model_or_filepath="/home/lena/master-thesis/repos/optimal-control/data/Giordano2020.xml",
    control_output_fn=control_output_fn,
    t0=0.0,
    t1=1000.0,
    dt0=0.1,
    saveat=diffrax.SaveAt(t1=True, dense=False),
    solver=diffrax.Dopri8(),
    stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-4, dtmax=1.0),
)
environment_state = environment.init()

# Solver from optimal_control.solvers: Adam (lr=1e-4) with 1024 control points.
# NOTE(review): ignore_nans semantics inferred from the name only — presumably NaN
# losses/gradients are not skipped; confirm in optimal_control.solvers.
solver = solvers.DirectSolver(
    optax.adam(learning_rate=1e-4), num_control_points=1024, ignore_nans=False
)

# Fixed seed for reproducibility: `subkey` initialises the Siren weights below and
# `key` is later handed to the trainer.
key = jax.random.PRNGKey(1234)
key, subkey = jax.random.split(key)
# Neural (implicit) control: a Siren network with 1 input and 2 outputs; the two
# outputs are consumed as control_values[0] and [1] by control_output_fn.
# NOTE(review): 0.0 / 1000.0 are presumably the control's time domain and duplicate
# the environment's t0/t1 — keep them in sync (e.g. via shared constants).
control = controls.ImplicitControl(
    controls.Siren(
        in_features=1, out_features=2, hidden_features=64, hidden_layers=2, key=subkey
    ),
    0.0,
    1000.0,
)


def reward_fn(solution: diffrax.Solution) -> float:
    """Reward: the ``"Susceptible"`` entry of the last saved state of ``solution``.

    With the training environment's ``SaveAt(t1=True, ...)`` the saved states
    presumably contain only the final time point, so this is the susceptible value at
    ``t1``. The value is a JAX scalar rather than a Python float.
    """
    susceptible_series = solution.ys["Susceptible"]
    return susceptible_series[-1]


# Constraints passed to the trainer and to build_control.
# NOTE(review): semantics inferred from the class name — presumably each of the two
# control channels is kept non-negative with its time-integral fixed to 0.25; confirm.
constraint_chain = [
    constraints.NonNegativeConstantIntegralConstraint(integral=jnp.asarray([0.25, 0.25]))
]

In [None]:
# Optimise the control for 1024 * 4 = 4096 training steps.
# Returns a (reward, control) pair — presumably the reward reached and the trained
# control; confirm in optimal_control.trainers. pbar_interval presumably sets how
# often the progress bar is refreshed.
optimized_reward, optimized_control = trainers.solve_optimal_control_problem(
    num_train_steps=1024 * 4,
    environment=environment,
    reward_fn=reward_fn,
    constraint_chain=constraint_chain,
    solver=solver,
    control=control,
    key=key,
    pbar_interval=8,
)

In [None]:
# Evaluation environment: same model, control mapping and time span as for training,
# but with dense output so the solution can be evaluated at arbitrary times when
# plotting. This re-binds `environment` / `environment_state` for all later cells.
# NOTE(review): unlike the training environment, no dtmax=1.0 is set on the step-size
# controller, so evaluation may take larger steps than training did and the evaluated
# trajectory may differ from what was optimised — confirm this is intended.
# NOTE(review): near-duplicate of the training setup; consider a small factory
# function parameterised by `saveat` / `stepsize_controller`.
environment: sbmlenv.SBMLEnvironment = sbmlenv.SBMLEnvironment(
    model_or_filepath="/home/lena/master-thesis/repos/optimal-control/data/Giordano2020.xml",
    control_output_fn=control_output_fn,
    t0=0.0,
    t1=1000.0,
    dt0=0.1,
    saveat=diffrax.SaveAt(t1=True, dense=True),
    solver=diffrax.Dopri8(),
    stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-4),
)
environment_state = environment.init()

In [None]:
from optimal_control.solvers.base import build_control

# Constrained version of the untrained control (not used in the visible cells).
# NOTE(review): 1024 presumably must match the solver's num_control_points.
start_control = build_control(control, constraint_chain, 1024)

# Apply the constraint chain to the optimised control and integrate it in the dense
# evaluation environment.
eval_control = build_control(optimized_control, constraint_chain, 1024)
eval_solution: diffrax.Solution = environment.integrate(
    eval_control, environment_state, None
)

# NOTE(review): this call is identical to the one above (same control, state and
# args), so it only recomputes the same solution. The name suggests a "no control"
# baseline was intended (e.g. integrating start_control or a zero control) — confirm
# the intent and fix or remove.
eval_solution_no_: diffrax.Solution = environment.integrate(
    eval_control, environment_state, None
)

In [None]:
# Display the solver statistics diffrax recorded for the evaluation run.
eval_solution.stats

In [None]:
# Sample the dense evaluation solution and the constrained, optimised control on a
# uniform 1024-point grid over [0, 1000].
# NOTE(review): 0.0 / 1000.0 duplicate the environment's t0/t1 — keep them in sync.
eval_ts = jnp.linspace(0.0, 1000.0, 1024)
eval_ys = jax.vmap(eval_solution.evaluate)(eval_ts)
eval_cs = jax.vmap(eval_control.__call__)(eval_ts)

fig, ax = plt.subplots(2, 1, sharex=True, figsize=(10, 7.5))

# Top panel: one line per entry of the evaluated state dict.
for k in eval_ys:
    ax[0].plot(eval_ts, eval_ys[k], label=k)
ax[0].set(title="State trajectories under the optimised control", ylabel="State value")
ax[0].legend()

# Bottom panel: the control signals. Axes.plot draws one line per column of a 2-D
# array and returns the lines, so labels are attached afterwards; zip() tolerates any
# number of columns. Column i is presumably what control_output_fn receives as
# control_values[i] (0: spread factor, 1: testing factor).
control_lines = ax[1].plot(eval_ts, eval_cs)
for control_line, control_label in zip(
    control_lines, ("control[0] (spread)", "control[1] (testing)")
):
    control_line.set_label(control_label)
ax[1].set(title="Optimised control", xlabel="Time", ylabel="Control value")
ax[1].legend()
plt.show()

In [None]:
import libsbml

# Sanity check: integrate the raw SBML model directly with diffrax, bypassing
# SBMLEnvironment, with tight tolerances (rtol = atol = 1e-8).
# args={} passes no parameter overrides — presumably the model's nominal parameters
# are used; confirm in sbmlutils.model_to_lambda.
# NOTE(review): libsbml is re-imported in this cell (already imported earlier);
# getModel() returns None if reading fails, which would only surface inside
# model_to_lambda; the horizon here is 365 rather than the 1000 used elsewhere;
# `solution` is not used in the visible cells.
model = libsbml.readSBMLFromFile(
    "/home/lena/master-thesis/repos/optimal-control/data/Giordano2020.xml"
).getModel()

ode_fn = sbmlutils.model_to_lambda(model)
y0 = sbmlutils.species_to_dict(model.getListOfSpecies())

solution = diffrax.diffeqsolve(
    terms=diffrax.ODETerm(ode_fn),
    solver=diffrax.Dopri8(),
    t0=0.0,
    t1=365.0,
    dt0=0.1,
    y0=y0,
    args={},
    saveat=diffrax.SaveAt(t1=True, dense=True),
    stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-8),
)

In [None]:
# Integrate the untrained control in the (dense) evaluation environment.
# NOTE(review): unlike the evaluation cell, `control` is passed without
# build_control(...), so the constraint chain is not applied; use `start_control` if
# a constrained baseline is intended.
environment.integrate(control, environment_state, None)