### We want a structure, similar to cobra (COBRApy), that allows us to build kinetic models in a similar fashion
1. I think it should be reaction-centric.

In [61]:
# Change into the project root, presumably so the local `source` package imported below resolves.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
%cd /home/plent/Documenten/Gitlab/NeuralODEs/jax_neural_odes
from source.kinetic_mechanisms import JaxKineticMechanisms as jm
import jax.numpy as jnp
import jax
import numpy

/home/plent/Documenten/Gitlab/NeuralODEs/jax_neural_odes


In [64]:





class Reaction:
    """A single reaction, used as a building block for kinetic models.

    Parameters
    ----------
    name : str
        Identifier of the reaction (e.g. ``"ReactionA"``); stored as ``self.name``.
    species : list
        Names of the species taking part in the reaction.
    stoichiometry : list
        Stoichiometric coefficient for each entry of ``species``, in the same
        order (e.g. ``[-1, 1]`` for converting the first species into the second).
        Stored as a ``{species: coefficient}`` dict in ``self.stoichiometry``.
    mechanism
        Kinetic mechanism object carrying the rate law and its named parameters.
    compartment : str
        Compartment in which the reaction takes place.

    Raises
    ------
    ValueError
        If ``species`` and ``stoichiometry`` do not have the same length.
    """

    def __init__(self, name: str, species: list, stoichiometry: list, mechanism, compartment: str):
        species = list(species)
        stoichiometry = list(stoichiometry)
        # zip() would silently drop unmatched entries and hide a modelling error.
        if len(species) != len(stoichiometry):
            raise ValueError(
                f"Reaction {name!r}: got {len(species)} species but "
                f"{len(stoichiometry)} stoichiometric coefficients"
            )
        self.name = name
        self.species = species
        self.stoichiometry = dict(zip(species, stoichiometry))
        self.mechanism = mechanism
        self.compartment = compartment

    # TODO: add retrieve_parameters() to expose the mechanism's named parameters.
    



# Two-step pathway A -> B -> C inside compartment "c". Each step uses the
# jm.Jax_MM mechanism with its own named Vmax/Km parameter keys.
ReactionA = Reaction(name="ReactionA",
                     species=['A', 'B'],
                     stoichiometry=[-1, 1],
                     mechanism=jm.Jax_MM(substrate="A", vmax="A_Vmax", km="A_Km"),
                     compartment="c")

ReactionB = Reaction(name="ReactionB",
                     species=['B', 'C'],
                     stoichiometry=[-1, 1],
                     mechanism=jm.Jax_MM(substrate="B", vmax="B_Vmax", km="B_Km"),
                     compartment="c")



TypeError: Reaction.__init__() got an unexpected keyword argument 'name'

In [62]:
# Show the attributes (named parameter keys) stored on ReactionB's mechanism object.
ReactionB.mechanism.__dict__


{'vmax': 'B_Vmax', 'km': 'B_Km', 'substrate': 'B'}

In [None]:
class JaxKineticModel:
    """Right-hand side dY/dt of a kinetic model assembled from per-reaction flux functions.

    Instances are callable as ``model(t, y, args)`` and return
    ``(S @ v) / compartment_values``, where ``v`` stacks one flux per reaction
    in ``reaction_names`` order.
    """

    def __init__(self, v,
                 S,
                 flux_point_dict,
                 species_names,
                 reaction_names,
                 compartment_values,):  # kinetic params are passed at call time via `args`
        """Initialize the model.

        Args:
            v: mapping from reaction name to its flux function (e.g. lambdified
                jax functions). Each function is called with keyword arguments
                only: the species values it needs, its parameters, its local
                parameters and its time-dependent values.
            S: stoichiometric matrix. Only dense matrices are supported for now
                (sparse support may be added later). ``S @ v`` must be valid, so
                S needs one column per entry of ``reaction_names``, in that order.
            flux_point_dict: mapping from reaction name to the indices into ``y``
                (and ``species_names``) of the species that reaction's flux
                function needs. An empty entry means no species are passed.
            species_names: names of the species, in the order used by ``y``.
            reaction_names: names of the reactions; fixes the order of ``v``.
            compartment_values: divisor applied element-wise to ``S @ v``
                (presumably one compartment volume per species — TODO confirm).
        """
        self.func = v
        self.stoichiometry = S
        self.flux_point_dict = flux_point_dict
        # Use `numpy` directly: this notebook imports `numpy`, not the `np` alias,
        # so `np` is undefined on a fresh kernel.
        self.species_names = numpy.array(species_names)
        self.reaction_names = numpy.array(reaction_names)
        self.compartment_values = jnp.array(compartment_values)

    def __call__(self, t, y, args):
        """Evaluate dY/dt at time ``t``.

        Parameters are passed through ``args`` instead of being stored on the
        instance (presumably so gradients can be taken with respect to them —
        TODO confirm this is necessary).

        Args:
            t: current time.
            y: current species values, indexed in the order of ``species_names``.
            args: tuple ``(params, local_params, time_dict)``. ``params[name]``
                and ``local_params[name]`` are mappings of keyword arguments for
                reaction ``name``. ``time_dict`` is a callable such that
                ``time_dict(t)[name]`` is a mapping of time-dependent keyword
                arguments (e.g. event functions) for that reaction.

        Returns:
            ``(S @ v) / compartment_values``, with ``v`` the stacked fluxes in
            ``reaction_names`` order.
        """
        params, local_params, time_dict = args

        # Evaluate the time-dependent values at time t (e.g. for event functions).
        time_values = time_dict(t)

        def apply_func(name, indices, reaction_local_params, reaction_time_values):
            """Evaluate the flux of reaction ``name`` for the current state ``y``."""
            if len(indices) != 0:
                # JAX rejects indexing with plain Python lists, so normalize to a
                # numpy index array (valid for both `y` and `species_names`).
                idx = numpy.asarray(indices)
                species_values = dict(zip(self.species_names[idx], y[idx]))
            else:
                species_values = {}

            # Later mappings override earlier ones on duplicate keys.
            eval_dict = {**species_values, **params[name],
                         **reaction_local_params, **reaction_time_values}
            return self.func[name](**eval_dict)

        # TODO: each reaction may have a different flux function, so this is a
        # Python loop rather than a vectorized map; look for a better way.
        v = jnp.stack([apply_func(name,
                                  self.flux_point_dict[name],
                                  local_params[name],
                                  time_values[name])
                       for name in self.reaction_names])
        dY = jnp.matmul(self.stoichiometry, v)  # dM/dt = S * v(t)
        return dY / self.compartment_values