### We want a structure, similar to cobra, that allows us to build kinetic models in a similar fashion
1. I think it should be reaction-centric.

In [162]:
%cd /home/plent/Documenten/Gitlab/NeuralODEs/jax_neural_odes
from source.kinetic_mechanisms import JaxKineticMechanisms as jm
import jax.numpy as jnp
import jax
import numpy as np
import pandas as pd
from source.utils import get_logger
logger = get_logger(__name__)

/home/plent/Documenten/Gitlab/NeuralODEs/jax_neural_odes


In [248]:





class Reaction:
    """Single reaction used as the building unit of a (reaction-centric) kinetic model.

    The following must be specified: the reaction name, the species involved,
    the stoichiometric coefficient and compartment of each species, and the
    rate-law mechanism (whose attributes name the symbols it uses).

    Args:
        name: reaction identifier (used as the column label of the stoichiometric matrix).
        species: names of the species taking part in the reaction; must be unique.
        stoichiometry: stoichiometric coefficient for each entry of ``species`` (same order).
        compartments: compartment label for each entry of ``species`` (same order).
        mechanism: rate-law object; the values of its instance attributes
            (``vars(mechanism)``) are taken as the symbol names it uses.

    Raises:
        ValueError: if ``species``, ``stoichiometry`` and ``compartments`` differ in
            length, or if a species is listed more than once (either would otherwise
            be silently dropped/overwritten by ``dict(zip(...))``).
    """

    def __init__(self, name: str, species: list, stoichiometry: list, compartments: list, mechanism):
        if not (len(species) == len(stoichiometry) == len(compartments)):
            raise ValueError(
                f"Reaction {name}: species ({len(species)}), stoichiometry ({len(stoichiometry)}) "
                f"and compartments ({len(compartments)}) must have the same length"
            )
        if len(set(species)) != len(species):
            raise ValueError(f"Reaction {name}: species must be unique, got {species}")

        self.name = name
        self.species = species
        self.stoichiometry = dict(zip(species, stoichiometry))
        self.mechanism = mechanism
        self.compartments = dict(zip(species, compartments))

        # Symbol names used by the mechanism (attribute values of the mechanism object).
        symbols = list(vars(mechanism).values())
        # Symbols that are not species are parameters; species are kept separately.
        # Note: np.setdiff1d / np.intersect1d return sorted, de-duplicated arrays.
        self.parameters = np.setdiff1d(symbols, species)
        self.species_in_mechanism = np.intersect1d(symbols, species)
    


# Toy pathway for exercising the model builder: A -> B -> C, and C + D -> E.

# A -> B, Michaelis-Menten in A.
ReactionA = Reaction(
    name="ReactionA",
    species=["A", "B"],
    stoichiometry=[-1, 1],
    compartments=["c", "c"],
    mechanism=jm.Jax_MM(substrate="A", vmax="A_Vmax", km="A_Km"),
)

# B -> C, Michaelis-Menten in B.
ReactionB = Reaction(
    name="ReactionB",
    species=["B", "C"],
    stoichiometry=[-1, 1],
    compartments=["c", "c"],
    mechanism=jm.Jax_MM(substrate="B", vmax="B_Vmax", km="B_Km"),
)

# C + D -> E, irreversible bi-substrate mechanism.
# NOTE(review): both Km arguments use the same symbol "C_Km", so they share one
# parameter; confirm this is intended and not a copy-paste of "D_Km".
ReactionC = Reaction(
    name="ReactionC",
    species=["C", "D", "E"],
    stoichiometry=[-1, -1, 1],
    compartments=["c", "c", "c"],
    mechanism=jm.Jax_MM_Irrev_Bi(
        substrate1="C",
        substrate2="D",
        vmax="C_Vmax",
        km_substrate1="C_Km",
        km_substrate2="C_Km",
    ),
)

# Only A and B are assembled into the model below; ReactionC is defined but not used.
reactions = [ReactionA, ReactionB]



In [292]:

def flatten(xss):
    """Concatenate an iterable of iterables into a single flat list, preserving order."""
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat


class JaxKineticModel_Build:
    """Assemble kinetic-model components from a list of reaction objects.

    Attributes set on construction:
        reactions: the reaction objects, as given.
        S: stoichiometric matrix as a DataFrame (rows = species, columns = reaction
            names, entries absent from a reaction filled with 0).
        reaction_names / species_names: column / row labels of ``S``.
        species_compartments: species name -> compartment label.
        compartment_values: jnp array of compartment values ordered like ``species_names``.
        v: each reaction's mechanism object, in reaction order.
        parameter_names: concatenation of every reaction's ``parameters``.
    """

    def __init__(self,
                 reactions: list,
                 compartment_values: dict):
        """Build the model from its reactions.

        Args:
            reactions: objects exposing ``name``, ``stoichiometry`` (species -> coefficient),
                ``compartments`` (species -> label), ``mechanism`` and ``parameters``.
            compartment_values: compartment label -> value; must contain every label
                used by a species (a missing label raises KeyError).
        """
        self.reactions = reactions

        self.S = self._get_stoichiometry()
        self.reaction_names = self.S.columns.to_list()
        self.species_names = self.S.index.to_list()

        self.species_compartments = self._get_compartments_species()
        per_species_values = [
            compartment_values[self.species_compartments[species_name]]
            for species_name in self.species_names
        ]
        self.compartment_values = jnp.array(per_species_values)

        self.v = [rxn.mechanism for rxn in self.reactions]
        self.parameter_names = flatten([rxn.parameters for rxn in self.reactions])

    def _get_stoichiometry(self):
        """Return the species x reaction stoichiometric matrix as a DataFrame (missing entries -> 0)."""
        columns_by_reaction = {rxn.name: rxn.stoichiometry for rxn in self.reactions}
        return pd.DataFrame(columns_by_reaction).fillna(value=0)

    def _get_compartments_species(self):
        """Map every species to its compartment label.

        The first label encountered for a species is kept; each later reaction that
        gives that species a different label produces an error log, but the build
        continues.
        """
        labels = {}
        for rxn in self.reactions:
            for species, values in rxn.compartments.items():
                if species not in labels:
                    labels[species] = values
                elif labels[species] != values:
                    logger.error(f"Species {species} has ambiguous compartment values, please check consistency in the reaction definition")
        return labels
    





# Compartment label -> value ('e' is not used by the reactions assembled above).
compartment_values = {'c': 1.0, 'e': 3.0}

kmodel = JaxKineticModel_Build(reactions, compartment_values)
kmodel.v
# Smoke-test inputs: every parameter set to 1.0 and the species set to 1, 2, 3, ...
# The species values are generated from the species count so that zip() never
# silently drops or omits species when the reaction list changes.
params = dict(zip(kmodel.parameter_names, np.ones(len(kmodel.parameter_names))))
y = dict(zip(kmodel.species_names, range(1, len(kmodel.species_names) + 1)))
evaluation_dictionary = {**params, **y}
# evaluation_dictionary




0.5
0.6666666666666666


In [250]:
class JaxKineticModel:
    """ODE right-hand side dY/dt = (S @ v(y)) / compartment_values, callable as f(t, y, args)."""

    def __init__(self, v,
                 S,
                 flux_point_dict,
                 species_names,
                 reaction_names,
                 compartment_values,):
        """Store the model components.

        Args:
            v: mapping reaction name -> flux function (e.g. lambdified jax functions);
                each is called with keyword arguments built from species values,
                parameters, local parameters and time-dependent values.
            S: stoichiometric matrix (only dense matrices for now; sparse support may come later).
            flux_point_dict: mapping reaction name -> indices into ``y`` / ``species_names``
                selecting the species that reaction's flux function needs (may be empty).
                Should be consistent with ``S``.
            species_names: species labels, indexed in the same way as ``y``.
            reaction_names: reaction labels; fluxes are evaluated and stacked in this order.
            compartment_values: per-species values that the raw rates ``S @ v`` are divided by.
        """
        self.func = v
        self.stoichiometry = S
        self.flux_point_dict = flux_point_dict
        self.species_names = np.array(species_names)
        self.reaction_names = np.array(reaction_names)
        self.compartment_values = jnp.array(compartment_values)

    def __call__(self, t, y, args):
        """Evaluate dY/dt at time ``t`` for state ``y``.

        Parameters are passed explicitly through ``args`` (for gradient calculations;
        whether this is strictly necessary is still an open question).

        Args:
            t: current time.
            y: state vector, indexed like ``species_names``.
            args: tuple ``(params, local_params, time_dict)``: ``params`` and
                ``local_params`` map reaction name -> dict of keyword values;
                ``time_dict`` is a callable t -> mapping reaction name -> dict of
                time-dependent keyword values (e.g. for event functions).

        Returns:
            jnp array ``S @ v / compartment_values``.
        """
        params, local_params, time_dict = args
        time_values = time_dict(t)

        def flux_of(reaction):
            # Look up per-reaction inputs in the same order as the original call site.
            indices = self.flux_point_dict[reaction]
            local_values = local_params[reaction]
            timed_values = time_values[reaction]

            if len(indices) == 0:
                species_values = {}
            else:
                species_values = dict(zip(self.species_names[indices], y[indices]))

            kwargs = {**species_values, **params[reaction], **local_values, **timed_values}
            return self.func[reaction](**kwargs)

        # Python-level loop over reactions (not vectorized).
        fluxes = jnp.stack([flux_of(reaction) for reaction in self.reaction_names])
        raw_rates = jnp.matmul(self.stoichiometry, fluxes)
        return raw_rates / self.compartment_values
    

class NeuralODE:
    """Solve a kinetic model (wrapped as a ``JaxKineticModel``) over time with diffrax.

    NOTE(review): ``diffrax``, ``time_dependency_symbols``, ``separate_params`` and
    ``construct_param_point_dictionary`` are not imported/defined in the visible
    cells; presumably they are provided elsewhere in the notebook — confirm.
    """

    func: JaxKineticModel

    def __init__(self,
                 v,
                 S,
                 met_point_dict,
                 v_symbol_dictionaries,
                 compartment_values,
                 *,
                 max_steps=200000,
                 rtol=1e-7,
                 atol=1e-9,
                 dt0=1e-6):
        """Build the ODE right-hand side and the time-dependency evaluator.

        Args:
            v: per-reaction flux functions, forwarded to ``JaxKineticModel``.
            S: stoichiometric matrix as a DataFrame (index = species, columns = reactions).
            met_point_dict: per-reaction indices into the state vector
                (``JaxKineticModel``'s ``flux_point_dict``).
            v_symbol_dictionaries: per-reaction symbol information, used to build
                time-dependent values and to map global parameters onto fluxes.
            compartment_values: per-species values the rates are divided by.
            max_steps: maximum number of solver steps.
            rtol, atol: relative / absolute tolerances of the PID step-size controller.
            dt0: initial step size.
            (The keyword-only solver settings default to the previously hard-coded values.)
        """
        self.func = JaxKineticModel(v,
                                    jnp.array(S),
                                    met_point_dict,
                                    list(S.index),
                                    list(S.columns),
                                    compartment_values)
        self.reaction_names = list(S.columns)
        self.v_symbol_dictionaries = v_symbol_dictionaries
        self.Stoichiometry = S

        # Solver settings; kept as attributes so they can still be changed after construction.
        self.max_steps = max_steps
        self.rtol = rtol
        self.atol = atol
        self.dt0 = dt0

        def wrap_time_symbols(t):
            # Presumably returns, per flux, its time-dependent symbol values at time t.
            return time_dependency_symbols(v_symbol_dictionaries, t)

        # Jitted callable t -> per-flux time-dependent values; handed to the RHS via args.
        self.time_dict = jax.jit(wrap_time_symbols)

    def __call__(self, ts, y0, params):
        """Integrate from ``ts[0]`` to ``ts[-1]`` and return the states saved at every time in ``ts``.

        Args:
            ts: save times; the first and last entries are the integration bounds.
            y0: initial state.
            params: parameter collection, split into global and local parts by ``separate_params``.

        Returns:
            ``solution.ys`` from ``diffrax.diffeqsolve``.
        """
        global_params, local_params = separate_params(params)

        # Re-key global parameters per flux so each reaction receives its own dict
        # (required for the per-reaction lookup inside JaxKineticModel).
        global_params = construct_param_point_dictionary(self.v_symbol_dictionaries,
                                                         self.reaction_names,
                                                         global_params)

        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Kvaerno5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=self.dt0,
            y0=y0,
            args=(global_params, local_params, self.time_dict),
            stepsize_controller=diffrax.PIDController(rtol=self.rtol, atol=self.atol, pcoeff=0.4, icoeff=0.3, dcoeff=0),
            saveat=diffrax.SaveAt(ts=ts),
            max_steps=self.max_steps,
        )

        return solution.ys
