<a href="https://colab.research.google.com/github/guibuzi/bioinfo/blob/master/Simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
try:
    import diffrax as dfx
    import equinox as eqx
except ImportError:
    !pip install -U pip diffrax
    import diffrax as dfx
    import equinox as eqx

import jax
import jax.numpy as jnp
import jax.random as jrn
import time
import scipy.stats.qmc as qmc
import numpy as np

from frozendict import frozendict
from diffrax.custom_types import Scalar
from typing import Dict, Tuple, Optional, Any, List
from datetime import datetime
from functools import partial

# Run every JAX computation in double precision; the models below build
# their arrays with dtype=jnp.float64.
jax.config.update('jax_enable_x64', True)

# Optional (state, args) overrides for a simulation. Either entry may be None,
# in which case the model default is used for that entry.
_InitialConditions_t = Optional[Tuple[Optional[jax.Array], Optional[jax.Array]]]
# Carry threaded through the scan of the step-by-step simulator:
# (state, args, event_mask).
_Carry_t = Tuple[jax.Array, jax.Array, jax.Array]
# Result of a step-by-step simulation: (state trajectory, args trajectory, ts).
_Solution_t = Tuple[jax.Array, jax.Array, jax.Array]

# Fix the NumPy global seed so the randomly drawn arguments are reproducible.
np.random.seed(1024)

## Model Definition

In [None]:
class Model(eqx.Module):
    """ Four-state ODE model (Tstar, V, Vin, Vni) driven by six arguments.

    The argument vector is ordered as (default, NN, T0, K0, c, delta); every
    method reads the states and the arguments positionally in that order.
    """
    # Number of state variables (the default state vector has four entries).
    system_dimension: int = 4
    # Static (non-traced) flag. No method of this class reads it.
    for_diffrax: bool = eqx.static_field(default=True)
    # NOTE(review): appears to be the number of trailing arguments that the
    # surrounding notebook treats as unknowns -- confirm.
    to_estimate: int = 2

    @property
    def event_mask(self) -> jax.Array:
        """ Return the event mask; empty because this model defines no events """
        return jnp.array([])

    @property
    def defaults(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """ Return default values for states and arguments """
        states = jnp.array(
            [
               15061.32075, # Tstar
               1860000.0, # V
               1860000.0, # Vin
               0.0, # Vni
            ], dtype=jnp.float64)

        args = jnp.array(
            [
               1.0, # default
               480.0, # NN
               11000.0, # T0
               3.9e-07, # K0
               2.06, # c
               0.53, # delta
            ], dtype=jnp.float64)

        return states, args

    @eqx.filter_jit
    def initial(self, t: float, state: jnp.ndarray, args: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """ Compute initial equations with given initial value

        This model has no initial assignments: the first four state entries
        and the first six argument entries are repacked unchanged into new
        arrays. `t` is accepted for interface compatibility and is not used.
        """

        Tstar = state[0]
        V = state[1]
        Vin = state[2]
        Vni = state[3]
        default = args[0]
        NN = args[1]
        T0 = args[2]
        K0 = args[3]
        c = args[4]
        delta = args[5]

        state = jnp.array([Tstar, V, Vin, Vni])

        args = jnp.array([default, NN, T0, K0, c, delta])

        return state, args

    @jax.jit
    def __call__(self, t: float, state: jax.Array, args: jax.Array) -> jax.Array:
        """ Returns the ODEs

        Evaluates the right-hand side at `state` with arguments `args` and
        returns the derivatives stacked as (dTstar, dV, dVin, dVni). The
        vector field does not depend on `t` or on the current value of V.
        """
        Tstar = state[0]
        Vin = state[2]
        Vni = state[3]

        default = args[0]
        NN = args[1]
        T0 = args[2]
        K0 = args[3]
        c = args[4]
        delta = args[5]

        # Reaction rates
        v1 = K0*T0*Vin
        v2 = Tstar*delta
        v3 = Vin*c
        v4 = Vni*c
        v5 = NN*Tstar*delta

        # Every derivative is scaled by 1/default (args[0]).
        dTstar_dt = (v1 - v2)/default
        dV_dt = (-v3 - v4 + v5)/default
        dVin_dt = -v3/default
        dVni_dt = (-v4 + v5)/default

        return jnp.stack([
            dTstar_dt, dV_dt, dVin_dt, dVni_dt
        ], dtype=jnp.float64)

    @jax.jit
    def algebraic(self, t: float, state: jax.Array, args: jax.Array) -> jax.Array:
        """ Apply algebraic equations and return the modified arguments

        This model has no algebraic equations, so `args` is returned as is.
        """
        return args

    @jax.jit
    def events(
            self, t: float, state: jax.Array, args: jax.Array, event_mask: jax.Array
    ) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """ Execute all the events of the model

        This model has no events, so the three inputs are returned unchanged.
        """

        return state, args, event_mask

## Diffrax interface (assuming the `args` vector trajectory is not required as a result)

In [None]:
@partial(jax.vmap, in_axes=(None,None,None,None,0,0))
def dfxsimulate(
    model: eqx.Module,
    t0: float,
    t1: float,
    duration: int,
    x0: jax.Array,
    args0: jax.Array
) -> dfx.Solution:
    """ Integrate `model` from `t0` to `t1` with one `diffeqsolve` call.

    The function is vmapped over the leading axis of `x0` and `args0`; the
    model, the time bounds and `duration` are shared by all batch entries.
    The state is saved on `duration + 1` evenly spaced times in [t0, t1].
    """
    # Output grid, and the first step size proposed to the adaptive controller
    save_times = jnp.linspace(t0, t1, duration + 1)
    first_step = (t1 - t0) / duration

    # A single diffeqsolve call is enough here because the model has no events
    return dfx.diffeqsolve(
        terms=dfx.ODETerm(model),
        solver=dfx.Kvaerno5(),
        t0=t0,
        t1=t1,
        dt0=first_step,
        y0=x0,
        args=args0,
        saveat=dfx.SaveAt(ts=save_times),
        stepsize_controller=dfx.PIDController(rtol=1.0e-6, atol=1.0e-9),
        max_steps=500000,
    )

## Step-by-step simulation Interface

In [None]:
def _simulation_loop(
    x0: jnp.ndarray,
    args0: jnp.ndarray,
    event_mask: jnp.ndarray,
    ts: jnp.ndarray,
    model: eqx.Module,
    dt0: jnp.ndarray,
    t0: jnp.ndarray,
    ode_term: dfx.AbstractTerm,
    pid_controller: dfx.AbstractStepSizeController,
    solver: dfx.AbstractSolver,
    saveat: dfx.SaveAt
) -> _Solution_t:
    """ Simulate for a single initial condition

    The interval covered by `ts` is integrated one output interval at a time:
    each scan iteration applies the model events, integrates over one interval
    of length `dt0`, and then applies the algebraic equations.

    Returns `(state_trajectory, args_trajectory, ts)`, where both trajectories
    have one row per entry of `ts` and row 0 holds the values obtained after
    the initial equations and the initial events.
    """

    def _simulation_step(
        carry: Tuple[jax.Array, jax.Array, jax.Array], x: Scalar
    ) -> Tuple[Tuple[jax.Array, jax.Array, jax.Array], List[jax.Array]]:
        """ Execute a single integration step of the simulation loop """
        _y0, _args0, event_mask = carry
        _t0 = x
        _t1 = _t0 + dt0

        # Apply Events
        # NOTE(review): events are evaluated at the interval end time before
        # the interval is integrated -- confirm this is intended.
        _y0, _args0, event_mask = model.events(_t1, _y0, _args0, event_mask)

        # Run the integration step and other middle steps in case they are necessary
        # for the adaptive step size controller
        _solution = dfx.diffeqsolve(
            terms=ode_term, solver=solver, t0=_t0, t1=_t1, dt0=dt0, y0=_y0, args=_args0,
            saveat=saveat, stepsize_controller=pid_controller, max_steps=50000000)

        # Take the solutions: `ys` holds one entry per sub-saveat (state first,
        # args second), each with a single saved row.
        _y0, _args0 = _solution.ys[0][0], _solution.ys[1][0]

        # Apply algebraic equations
        _args0 = model.algebraic(_t1, _y0, _args0)
        return (_y0, _args0, event_mask), [_y0, _args0]

    # Apply initial equations
    x0, args0 = model.initial(t0, x0, args0)

    # Execute initial events
    x0, args0, event_mask = model.events(t0, x0, args0, event_mask)

    # Run the entire simulation. Only interval *start* times are scanned: the
    # last entry of `ts` is the end of the final interval, and starting a step
    # there would integrate beyond the requested range only to be discarded.
    initial_carry = (x0, args0, event_mask)
    _, ys = jax.lax.scan(_simulation_step, initial_carry, ts[:-1])

    # Prepend the initial conditions so the trajectories line up with `ts`
    state_trajectory = jnp.vstack([x0, ys[0]])
    args_trajectory = jnp.vstack([args0, ys[1]])

    return state_trajectory, args_trajectory, ts


@eqx.filter_jit
def simulate(
    model: eqx.Module,
    t0: float,
    t1: float,
    duration: int,
    initial_conditions: _InitialConditions_t=(None, None),
    controller_atol: float=1.0e-9,
    controller_rtol: float=1.0e-6,
    controller_pcoeff: float=.0,
    controller_icoeff: float=1.0,
    controller_dcoeff: float=.0,
    stiff: bool=True
) -> _Solution_t:
    """ Simulate `model` interval by interval between `t0` and `t1`.

    Args:
        model: the model to integrate; it must expose `defaults`,
            `event_mask`, `initial`, `events` and `algebraic`.
        t0: start time.
        t1: end time.
        duration: number of output intervals; trajectories contain
            `duration + 1` samples.
        initial_conditions: `(state, args)` overrides. An entry that is None
            (or passing None instead of the tuple) selects the model default.
            If either entry has a leading batch axis, the other one is tiled
            to match and the simulations are vmapped.
        controller_atol, controller_rtol, controller_pcoeff,
        controller_icoeff, controller_dcoeff: forwarded to
            `dfx.PIDController`.
        stiff: use `dfx.Kvaerno5` when True, `dfx.Tsit5` otherwise.

    Returns:
        `(state_trajectory, args_trajectory, ts)`. For batched inputs the
        result is produced by `jax.vmap`, so each entry carries a leading
        batch axis.

    Raises:
        ValueError: if `duration` is smaller than 1.
    """
    # A non-positive number of intervals would give a non-finite step size.
    if duration < 1:
        raise ValueError(f"duration must be at least 1, got {duration}")

    # Define the ODE term used by Diffrax for defining the vectorial field
    ode_term = dfx.ODETerm(model)

    # Obtain initial assignments and event condition mask. The type alias
    # allows None for the whole tuple, so treat it like (None, None).
    if initial_conditions is None:
        initial_conditions = (None, None)
    initial_x, initial_args = initial_conditions
    x0_model, args0_model = model.defaults
    x0 = x0_model if initial_x is None else initial_x
    args0 = args0_model if initial_args is None else initial_args
    event_mask = model.event_mask

    # Define the solver and the step size controller
    solver = dfx.Kvaerno5() if stiff else dfx.Tsit5()
    pid_controller = dfx.PIDController(
        atol=controller_atol,
        rtol=controller_rtol,
        pcoeff=controller_pcoeff,
        icoeff=controller_icoeff,
        dcoeff=controller_dcoeff)

    # Define the saveat: at the end of each integrated interval keep both the
    # state and the arguments
    saveat_y = dfx.SubSaveAt(t1=True, fn=lambda t,y,args: y)
    saveat_args = dfx.SubSaveAt(t1=True, fn=lambda t,y,args: args)
    saveat = dfx.SaveAt(subs=[saveat_y, saveat_args])

    # Define the integration time sequence
    time_sequence = jnp.linspace(t0, t1, duration + 1)

    # Make the variables Traced objects for future uses
    t0 = jnp.asarray(t0, dtype=jnp.float64)
    t1 = jnp.asarray(t1, dtype=jnp.float64)
    dt0 = jnp.asarray((t1 - t0) / duration, dtype=jnp.float64)

    # First of all check whether exactly one of x0 and args0 has more than
    # 1 dimension. In this case we need to tile the other to be able to VMAP.
    if x0.ndim > 1 and args0.ndim == 1: args0 = jnp.tile(args0, (x0.shape[0], 1))
    if args0.ndim > 1 and x0.ndim == 1: x0 = jnp.tile(x0, (args0.shape[0], 1))

    def _single_simulation(x: jax.Array, y: jax.Array) -> _Solution_t:
        """ Run the step-by-step loop for one (state, args) pair """
        return _simulation_loop(
            x, y, event_mask, time_sequence, model, dt0, t0, ode_term,
            pid_controller, solver, saveat)

    # If both x0 and args0 have one single dimension, then we do not need to VMAP
    if x0.ndim == 1 and args0.ndim == 1:
        return _single_simulation(x0, args0)

    # Otherwise, they have both multiple dimensions and we VMAP
    return jax.vmap(_single_simulation)(x0, args0)

## Run simulations

In [None]:
# Build two instances of the same model that differ only in the static
# `for_diffrax` flag: one for the plain Diffrax interface and one for the
# step-by-step simulator.
dfx_model = Model(for_diffrax=True)
my_model = Model(for_diffrax=False)
(dfx_model, my_model)

(Model(system_dimension=4, for_diffrax=True, to_estimate=2),
 Model(system_dimension=4, for_diffrax=False, to_estimate=2))

### Generate initial conditions

In [None]:
# Start from the model defaults
x0, args0 = dfx_model.defaults

num_simulations = 1000
# Number of trailing arguments that receive a random value in every simulation
num_estimated = dfx_model.to_estimate
args_noise = jnp.array(np.random.uniform(low=1.0e-5, high=100000, size=(num_simulations, num_estimated)))

# One row of arguments per simulation; overwrite the last `num_estimated`
# columns with the random draws. The first overwritten column is derived from
# `to_estimate` so the slice always matches the width of `args_noise`.
multi_args0 = jnp.tile(args0, (num_simulations, 1))
first_estimated = multi_args0.shape[1] - num_estimated
multi_args0 = multi_args0.at[:, first_estimated:].set(args_noise)

# Every simulation starts from the same default state
multi_x0 = jnp.tile(x0, (num_simulations, 1))

### Comparison of the Diffrax simulation with the step-by-step simulation

In [None]:
def timeis(f, prefix: str=""):
    """ Wrap `f` so that every call prints its elapsed wall time.

    The wrapper calls `f`, waits for every JAX array in the result with
    `jax.block_until_ready` (so asynchronously dispatched work is included in
    the measurement), prints "<prefix> Elapsed Time: <seconds> sec" and
    returns the result of `f` unchanged.
    """
    # Local import: only `partial` is imported from functools at file level.
    from functools import wraps

    @wraps(f)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high resolution, unlike time.time()
        start = time.perf_counter()
        result = f(*args, **kwargs)
        jax.block_until_ready(result)
        print(f"{prefix} Elapsed Time: {time.perf_counter() - start:.3f} sec")
        return result

    return wrapper

In [None]:
# Integration window [t0, t1] shared by all the simulations below
t0, t1 = 0.0, 7.0
# Number of output intervals (the trajectories contain duration + 1 samples)
duration = 700

In [None]:
""" Single simulation performance comparision """
# First compile both the step-by-step simulate and the Diffrax simulate.
# JAX dispatches work asynchronously, so wait for each warm-up result:
# otherwise warm-up work still in flight would be charged to the first
# timed call below.
jax.block_until_ready(dfxsimulate(dfx_model, t0, t1, duration, x0[None, :], args0[None, :]))
jax.block_until_ready(simulate(my_model, t0, t1, duration, (x0, args0)))

# Then time the Diffrax solution and the step-by-step solution
solution = timeis(dfxsimulate, prefix="Diffrax")(dfx_model, t0, t1, duration, x0[None, :], args0[None, :])
results = timeis(simulate, prefix="Mine")(my_model, t0, t1, duration, (x0, args0))

Diffrax Elapsed Time: 0.086 sec
Mine Elapsed Time: 1.329 sec


In [None]:
""" Multiple simulation performance comparision """
# Recompile both for the batched inputs; for the step-by-step version the
# compilation run is timed as well. The Diffrax warm-up is waited for
# explicitly so that its asynchronously dispatched work is not charged to the
# "Comp Mine" measurement that follows.
jax.block_until_ready(dfxsimulate(dfx_model, t0, t1, duration, multi_x0, multi_args0))
timeis(simulate, prefix="Comp Mine")(my_model, t0, t1, duration, (multi_x0, multi_args0))

# Both functions are compiled at this point, so these runs measure execution only
solution = timeis(dfxsimulate, prefix="Diffrax")(dfx_model, t0, t1, duration, multi_x0, multi_args0)
results = timeis(simulate, prefix="Mine")(my_model, t0, t1, duration, (multi_x0, multi_args0))

Comp Mine Elapsed Time: 6.433 sec
Diffrax Elapsed Time: 1.899 sec
Mine Elapsed Time: 3.512 sec
