In [356]:
%cd /home/plent/Documenten/Gitlab/NeuralODEs/jax_neural_odes
from source.kinetic_mechanisms import JaxKineticMechanisms as jm
from source.building_models import JaxKineticModelBuild as jkm
from source.load_sbml.sbml_load import *
from source.load_sbml.sbml_model import SBMLModel
import jax.numpy as jnp
import jax
import numpy as np
from source.utils import get_logger
logger = get_logger(__name__)
import diffrax 
import matplotlib.pyplot as plt
import pandas as pd
import itertools
import sys


/home/plent/Documenten/Gitlab/NeuralODEs/jax_neural_odes


In [411]:
# Load a small SBML model (Smallbone 2011 trehalose biosynthesis) and build its JAX kinetic model.
filepath = "models/sbml_models/working_models/Smallbone2011_TrehaloseBiosynthesis.xml"

model = SBMLModel(filepath)
S = model._get_stoichiometric_matrix()
# jit-compile the kinetic model once so repeated simulations reuse the compiled version
JaxKmodel = jax.jit(model.get_kinetic_model())

# Simulation grid: 2000 evenly spaced time points on [0, 15].
ts = jnp.linspace(0, 15, 2000)

# Combine the model's local and global parameters into one dict;
# on a name clash the global value overrides the local one.
global_params = get_global_parameters(model.model)
params = dict(model.local_params)
params.update(global_params)

ys = JaxKmodel(ts, model.y0, params)

15:21:15,721 - source.load_sbml.sbml_model - INFO - No internal inconsistencies found
15:21:15,722 - source.load_sbml.sbml_model - INFO - Model loaded.
15:21:15,722 - source.load_sbml.sbml_model - INFO -  number of species: 16
15:21:15,723 - source.load_sbml.sbml_model - INFO -  number of reactions: 8
15:21:15,723 - source.load_sbml.sbml_model - INFO -  number of global parameters: 13
15:21:15,723 - source.load_sbml.sbml_model - INFO -  number of constant boundary metabolites: 10
15:21:15,724 - source.load_sbml.sbml_model - INFO -  number of lambda function definitions: 0
15:21:15,724 - source.load_sbml.sbml_model - INFO -  number of assignment rules: 6
15:21:15,724 - source.load_sbml.sbml_model - INFO -  number of event rules: 0


Assume that boundary is constant for level 2
Assume that boundary is constant for level 2
Assume that boundary is constant for level 2
Assume that boundary is constant for level 2
Assume that boundary is constant for level 2
Assume that boundary is constant for level 2
Assume that boundary is constant for level 2
Assume that boundary is constant for level 2
Assume that boundary is constant for level 2
Assume that boundary is constant for level 2


In [412]:
# ys=JaxKmodel(ts,model.y0,params)



# global_params, local_params = separate_params(params)
# global_params = construct_param_point_dictionary(JaxKmodel.v_symbol_dictionaries,
#                                                          JaxKmodel.reaction_names,
#                                                          global_params) 
# args=(global_params, local_params,JaxKmodel.time_dict)
# a=[]
# for i in range(2000):
#     flux=JaxKmodel.func(0,ys[i,:],args)
#     flux=flux[0]
#     a.append(flux)

In [425]:
class DesignBuildTestLearnCycle:
    """A simulated metabolic-engineering Design-Build-Test-Learn (DBTL) process.

    The underlying process is a kinetic model (parameterized and with initial
    conditions). It can be used to simulate scenarios that might occur in real
    optimization campaigns.

    Input:
    1. A model: either built or loaded from SBML (stored as ``kinetic_model``).
    2. Parameters: reference parameter values keyed by name. These represent the
       current state of the optimization process.
    3. Initial conditions for the model.
    4. Time evaluation scale (timespan) of the process.
    """

    def __init__(self,
                 kinetic_model,
                 parameters: dict,
                 initial_conditions: jnp.array,
                 timespan: jnp.array):
        self.kinetic_model = kinetic_model
        self.parameters = parameters
        self.initial_conditions = initial_conditions
        self.timespan = timespan
        self.cycle_status = 0
        # Library units define the building blocks of actions when constructing
        # ME scenarios; populated by DESIGN_establish_library_elements.
        self.library_units = None
        # Generated strains, keyed by str(cycle_status).
        self.designs_per_cycle = {}

    def DESIGN_establish_library_elements(self, parameter_target_names, parameter_perturbation_values):
        """Define the actions that can be taken when sampling Design-phase scenarios.

        From an experimental perspective this is the library design phase: each
        target parameter is paired with every one of its promoter values.

        Input:
        - parameter_target_names: names of the parameters that we wish to perturb.
        - parameter_perturbation_values: one sequence of perturbation (promoter)
          values per target name, in the same order. These are defined RELATIVE
          to the reference state (they are later used as multiplicative factors).

        Returns a DataFrame with columns ``parameter_name`` and ``promoter_value``
        (also stored in ``self.library_units``), or None if the input is invalid,
        in which case ``self.library_units`` is left untouched.
        """
        # zip() below would silently drop unmatched targets/value lists.
        if len(parameter_target_names) != len(parameter_perturbation_values):
            logger.error(
                f"Got {len(parameter_target_names)} parameter targets but "
                f"{len(parameter_perturbation_values)} lists of perturbation values."
            )
            return None

        # Check that all parameter_target_names are valid
        for pt in parameter_target_names:
            if pt not in self.parameters:
                logger.error(f"Parameter target {pt} not in the model. Perhaps a spelling mistake?")
                return None  # do not overwrite self.library_units

        # One row per (parameter, promoter value) pair.
        flattened_combinations = [
            (name, value)
            for name, values in zip(parameter_target_names, parameter_perturbation_values)
            for value in values
        ]
        elementary_actions = pd.DataFrame(flattened_combinations, columns=['parameter_name', 'promoter_value'])

        self.library_units = elementary_actions
        return elementary_actions

    def DESIGN_assign_probabilities(self, occurence_list=None):
        """Assign a sampling probability to each element of the library.

        Can be viewed as changing relative concentrations in a library design,
        which biases sampling during the DoE.

        Input:
        - occurence_list: optional non-negative weights, one per library unit
          (row of ``self.library_units``). They are normalised to sum to 1.
          If omitted, every unit gets equal probability.

        Returns "manual probabilities", "equal probabilities", or "None" when the
        probabilities could not be assigned (nothing is written in that case).
        """
        if self.library_units is None:
            logger.error("No library units defined; call DESIGN_establish_library_elements first.")
            return "None"

        rows = len(self.library_units)
        if occurence_list is None:
            self.library_units['probability'] = np.ones(rows) / rows
            return "equal probabilities"

        if len(occurence_list) != rows:
            logger.error(f"Length of list of occurences of promoters is not matching "
                         f"(got {len(occurence_list)}, expected {rows})")
            return "None"

        weights = np.asarray(occurence_list, dtype=float)
        total = weights.sum()
        # Guard against division by zero / meaningless negative probabilities.
        if total <= 0 or np.any(weights < 0):
            logger.error("Occurences of promoters must be non-negative and not all zero.")
            return "None"
        self.library_units['probability'] = weights / total
        return "manual probabilities"

    def DESIGN_generate_strains(self, elements, samples, replacement=False, random_state=None):
        """Sample strain designs from the library using the assigned probabilities.

        Input:
        - elements: number of library units drawn per strain (typically 6).
        - samples: number of strains to generate.
        - replacement: whether the same library unit may be drawn more than once
          within one strain (i.e. whether duplicates are allowed).
        - random_state: optional int seed or numpy RandomState/Generator for
          reproducible sampling. None uses NumPy's global random state.

        If a parameter is drawn more than once in a strain, its promoter values are
        summed before perturbing the model. Every action is defined w.r.t. the
        reference strain, so [E_ref]*p1 + [E_ref]*p2 = [E_ref]*(p1 + p2).

        Returns a list of parameter dicts (one per strain). Each is a copy of the
        reference parameters with the targeted entries multiplied by their summed
        promoter value. The list is also stored in
        ``self.designs_per_cycle[str(self.cycle_status)]``. Returns None if no
        library has been established.
        """
        if self.library_units is None:
            logger.error("No library units defined; call DESIGN_establish_library_elements first.")
            return None

        # Without assigned probabilities, fall back to uniform sampling.
        if 'probability' in self.library_units.columns:
            weights = self.library_units['probability']
        else:
            weights = None

        # Build one RNG shared by all samples: passing the same int seed to every
        # .sample() call would make every strain identical.
        rng = random_state
        if isinstance(random_state, (int, np.integer)):
            rng = np.random.RandomState(random_state)

        strains = []
        for _ in range(samples):
            sample = self.library_units.sample(n=elements, weights=weights,
                                               replace=replacement, random_state=rng)

            # Sum promoter values of parameters drawn more than once.
            strain = {}
            for param, value in zip(sample['parameter_name'].values, sample['promoter_value'].values):
                strain[param] = strain.get(param, 0) + value

            # Overwrite reference parameters (on a copy) with the perturbed values.
            perturbed_parameters = self.parameters.copy()
            for param, factor in strain.items():
                perturbed_parameters[param] = perturbed_parameters[param] * factor
            strains.append(perturbed_parameters)

        self.designs_per_cycle[str(self.cycle_status)] = strains
        return strains

            
    



# Reference DBTL process built from the SBML model loaded above.
dbtl_cycle = DesignBuildTestLearnCycle(kinetic_model=JaxKmodel,
                                       parameters=params,
                                       initial_conditions=model.y0,
                                       timespan=ts)

# Library design: each target parameter gets a set of promoter strengths,
# expressed relative to its reference value.
parameter_target_names = ['lp.pgi.Kf6p', 'lp.hxk.Vmax', 'lp.tpp.Vmax']
parameter_perturbation_value = [[0.5, 1, 1.5], [1.2, 1.5, 1.8], [1.1, 1.6, 1.3]]

dbtl_cycle.DESIGN_establish_library_elements(parameter_target_names,
                                             parameter_perturbation_value)
# One occurrence weight per library unit (3 targets x 3 promoters = 9 rows);
# the 4th unit ('lp.hxk.Vmax' with promoter 1.2) is twice as likely to be drawn.
dbtl_cycle.DESIGN_assign_probabilities([1, 1, 1, 2, 1, 1, 1, 1, 1])

# Seed NumPy's global RNG (used by DataFrame.sample when no random_state is given)
# so the sampled strains are reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)
strains_perturbed = dbtl_cycle.DESIGN_generate_strains(elements=10, samples=50, replacement=True)


model.species_names
# for i in strains_perturbed:
#     ys=JaxKmodel(ts,model.y0,i)
#     ys=ys[:,1]
#     plt.plot(ts,ys)




['glc', 'g1p', 'g6p', 'trh', 't6p', 'udg']

In [145]:
# The library has one row per (target parameter, promoter) pair, e.g.
#   Gene A  p1
#   Gene A  p2
# plus an occurrence weight per row that is normalised into a sampling probability.
parameter_target_names = ['lp.pgi.Kf6p', 'lp.hxk.Vmax', 'lp.tpp.Vmax']
parameter_perturbation_value = [[0.5, 1, 1.5], [1.2, 1.5, 1.8], [2, 4, 6]]
parameter_perturbation_occurence = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]

# zip() below silently truncates mismatched inputs, so check the shapes agree.
assert len(parameter_target_names) == len(parameter_perturbation_value) == len(parameter_perturbation_occurence)
assert all(len(v) == len(o) for v, o in zip(parameter_perturbation_value, parameter_perturbation_occurence))

# Flatten the per-target lists into (target, value, occurrence) rows.
flattened_combinations = [
    (name, value, occurrence)
    for name, values, occurrences in zip(parameter_target_names,
                                         parameter_perturbation_value,
                                         parameter_perturbation_occurence)
    for value, occurrence in zip(values, occurrences)
]

elementary_actions = pd.DataFrame(flattened_combinations,
                                  columns=['parameter_target', 'parameter_perturbation', 'occurence'])
# Normalise occurrence counts into probabilities that sum to 1.
elementary_actions['occurence'] = elementary_actions['occurence'] / elementary_actions['occurence'].sum()




array(['lp.tpp.Vmax', 'lp.pgi.Kf6p', 'lp.hxk.Vmax', 'lp.tpp.Vmax',
       'lp.pgi.Kf6p', 'lp.hxk.Vmax'], dtype=object)

In [None]:
def equal_sampling_scenario(enz_names, perturb_range, N, rng=None):
    """
    Generate N random designs in which every enzyme is perturbed with equal chance.

    For each design and each enzyme, one perturbation value is drawn uniformly at
    random from that enzyme's sequence of candidate values.

    :param enz_names: a list of strings representing enzyme names
    :param perturb_range: a list of sequences, one per enzyme (same order as
        ``enz_names``), holding the discrete candidate perturbation values for that
        enzyme. A value is *chosen from* this sequence, not drawn from a continuous
        interval: a 2-tuple ``(lo, hi)`` yields either ``lo`` or ``hi``.
    :param N: an integer representing the number of designs to generate
    :param rng: optional ``numpy.random.Generator`` / ``RandomState`` used for the
        draws; defaults to NumPy's global random state.

    :return: a tuple ``(designs_list, cart)``. ``designs_list`` is a list of N dicts
        mapping enzyme name -> chosen perturbation value; ``cart`` holds the same
        designs as a list of N lists of values ordered like ``enz_names``.
    :raises ValueError: if ``enz_names`` and ``perturb_range`` differ in length.
    """
    # zip() would silently drop unmatched enzymes and fail later with a KeyError.
    if len(enz_names) != len(perturb_range):
        raise ValueError(
            f"enz_names has {len(enz_names)} entries but perturb_range has {len(perturb_range)}"
        )

    sampler = np.random if rng is None else rng
    # map each enzyme to its candidate perturbation values
    library_choices = dict(zip(enz_names, perturb_range))

    cart = []  # designs as lists of values, ordered like enz_names
    designs_list = []  # the same designs as enzyme -> value dicts
    for _ in range(N):
        design = [sampler.choice(library_choices[name]) for name in enz_names]
        cart.append(design)
        designs_list.append(dict(zip(enz_names, design)))
    return designs_list, cart