# Paul's data PINN implementation

This notebook imports the iML1515 SBML model and reduces it. It then builds the dictionaries and other elements needed to train the PINN on Paul's E. coli growth measurements. It is the basis for the code that creates these dictionaries and that imports and reduces an SBML model, and it can serve as a reference when reviewing the source code of these two features.

In [None]:
pip install python-libsbml

Collecting python-libsbml
  Downloading python_libsbml-5.20.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (666 bytes)
Downloading python_libsbml-5.20.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.1/8.1 MB[0m [31m49.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: python-libsbml
Successfully installed python-libsbml-5.20.5


In [None]:
pip install cobra

Collecting cobra
  Downloading cobra-0.29.1-py2.py3-none-any.whl.metadata (9.3 kB)
Collecting appdirs~=1.4 (from cobra)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting depinfo~=2.2 (from cobra)
  Downloading depinfo-2.2.0-py3-none-any.whl.metadata (3.8 kB)
Collecting diskcache~=5.0 (from cobra)
  Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 kB)
Collecting optlang~=1.8 (from cobra)
  Downloading optlang-1.8.3-py2.py3-none-any.whl.metadata (8.2 kB)
Collecting ruamel.yaml~=0.16 (from cobra)
  Downloading ruamel.yaml-0.18.13-py3-none-any.whl.metadata (24 kB)
Collecting swiglpk (from cobra)
  Downloading swiglpk-5.0.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting ruamel.yaml.clib>=0.2.7 (from ruamel.yaml~=0.16->cobra)
  Downloading ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.7 kB)
Downloading cobra-0.29.1-py2.py3-none-any.whl (1.2 MB)
[2K   [90m━━━━━━━━━

## First: importing iML1515 and preparing the equation and parameter dictionaries

In [None]:
#Imports
import libsbml
import numpy as np
import cobra
import cobra.manipulation as manip
from cobra import Reaction, Metabolite, Model
from cobra.flux_analysis import pfba
import sympy as sp
import sys

In [None]:
#Loading the sbml model as "model_"
def load_sbml_model(filepath):
    """Read an SBML file with libsbml and return its libsbml Model.

    Only diagnostics of severity error or fatal abort the load; warnings and
    informational messages from the reader do not.

    Parameters:
    - filepath: path of the SBML file

    Returns:
    - the libsbml Model contained in the document

    Raises:
    - ValueError: if the reader reports errors, or the document has no model
    """
    doc = libsbml.SBMLReader().readSBML(filepath)

    # getNumErrors() also counts warnings/infos, so filter on the severity
    problems = [doc.getError(i) for i in range(doc.getNumErrors())]
    blocking = [err for err in problems if err.isError() or err.isFatal()]
    if blocking:
        details = "\n".join(err.getMessage().strip() for err in blocking)
        raise ValueError(f"SBML read error in {filepath}:\n{details}")

    sbml_model = doc.getModel()
    if sbml_model is None:
        raise ValueError(f"SBML read error in {filepath}: no model found")
    return sbml_model

# Alternative (non-duplicated) model: load_sbml_model("iML1515.xml")
model_ = load_sbml_model("iML1515_duplicated-2.xml")

# Get all species IDs; print only the count and a short preview
species_ids = [s.getId() for s in model_.getListOfSpecies()]
print(f"Species ({len(species_ids)}):", species_ids[:10], "...")

Species: ['M_octapb_c', 'M_cysi__L_e', 'M_dhap_c', 'M_prbatp_c', 'M_10fthf_c', 'M_btal_c', 'M_6pgg_c', 'M_co2_e', 'M_akg_e', 'M_gsn_e', 'M_pydx5p_c', 'M_3dhgulnp_c', 'M_g3ps_c', 'M_adphep_LD_c', 'M_lyx__L_c', 'M_din_p', 'M_2pg_c', 'M_ptrc_p', 'M_malt_p', 'M_pppn_p', 'M_arbtn_p', 'M_hphhlipa_c', 'M_phphhlipa_c', 'M_13dpg_c', 'M_murein3px4p_p', 'M_34dhpac_e', 'M_1odec11eg3p_c', 'M_12dgr181_p', 'M_anhgm_e', 'M_prbamp_c', 'M_dsbdrd_c', 'M_cu2_p', 'M_sla_c', 'M_14glucan_p', 'M_grdp_c', 'M_ribflv_p', 'M_dms_e', 'M_pgp141_c', 'M_cysi__L_c', 'M_fpram_c', 'M_f1p_c', 'M_dsbard_p', 'M_thr__L_c', 'M_dcyt_p', 'M_2ddglcn_c', 'M_fum_p', 'M_galctn__L_e', 'M_btn_e', 'M_pydxn_p', 'M_ocdcea_e', 'M_preq0_c', 'M_tyr__L_p', 'M_dtdp4d6dg_c', 'M_acmum_e', 'M_man_p', 'M_adocbl_e', 'M_iscu_2fe2s_c', 'M_frulysp_c', 'M_dump_p', 'M_novbcn_e', 'M_feenter_e', 'M_eca4und_p', 'M_gg4abut_c', 'M_flxr_c', 'M_pa160_c', 'M_lcts_c', 'M_arbtn_e', 'M_cdpdhdec9eg_c', 'M_aso3_c', 'M_progly_c', 'M_3ohdcoa_c', 'M_clpn161_p', 'M_1

In [None]:
model_.getNumReactions()  # number of reactions in the loaded (unreduced) model

3682

**Reduction of the model** (work in progress)

In [None]:
#Getting Paul's data and reformatting
import pandas as pd

data = pd.read_csv("Growth_curves_copie.csv", sep=";")
iML1515dat = pd.read_csv("iML1515.csv")

#Data curation and addings
# Drop the first column and the last three columns of the raw growth-curve table
data = data.drop([data.columns[0], data.columns[-1], data.columns[-2], data.columns[-3]], axis=1)
# Strip the "R_" prefix from the column names
data.columns = [col[2:] if col.startswith('R_') else col for col in data.columns]

# Snapshot the variable medium columns (all but the last column) BEFORE columns
# are appended to `data` below; evaluating data.columns[:-1] inside the loop
# would shift that window as columns are added.
variable_cols = list(data.columns[:-1])

# Medium columns flagged with 1 in row 0 of iML1515.csv. `.copy()` makes new_med
# an independent frame, so adding columns does not raise SettingWithCopyWarning.
new_med = iML1515dat.loc[:, iML1515dat.loc[0] == 1].copy()
base_medium_cols = list(new_med.columns)

for col in variable_cols:
    # one value per row of new_med (4 rows in the displayed output)
    new_med[col] = [100, 2.2, np.nan, np.nan]  #Adding columns for all metabolites that are variable in the medium
for col in base_medium_cols:
    if col not in variable_cols:
        data[col] = 1  #Adding 1 columns for all necessary metabolites present in all mediums (broadcast to every row)

data.to_csv('curated_dataset.csv')
new_med.to_csv('new_iML1515.csv')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_med[col]=[100, 2.2,np.nan, np.nan] #Adding columns for all metabolites that are variable in the medium
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_med[col]=[100, 2.2,np.nan, np.nan] #Adding columns for all metabolites that are variable in the medium
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-

In [None]:
data  # display the curated dataset (pandas truncates long frames)

Unnamed: 0,EX_glc__D_e_i,EX_xyl__D_e_i,EX_succ_e_i,EX_ala__L_e_i,EX_arg__L_e_i,EX_asn__L_e_i,EX_asp__L_e_i,EX_cys__L_e_i,EX_glu__L_e_i,EX_gln__L_e_i,...,EX_mobd_e_i,EX_so4_e_i,EX_nh4_e_i,EX_k_e_i,EX_na1_e_i,EX_cl_e_i,EX_o2_e_i,EX_tungs_e_i,EX_slnt_e_i,EX_glyc_e_i
0,1,0,0,0,0,0,0,0,0,0,...,1,1,1,1,1,1,1,1,1,1
1,1,0,0,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
2,1,0,0,0,0,0,0,0,0,0,...,1,1,1,1,1,1,1,1,1,1
3,1,0,0,0,1,0,0,0,0,0,...,1,1,1,1,1,1,1,1,1,1
4,1,0,0,0,1,0,0,1,0,0,...,1,1,1,1,1,1,1,1,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
275,0,0,0,1,0,1,1,0,1,1,...,1,1,1,1,1,1,1,1,1,1
276,0,0,0,0,1,0,0,0,1,1,...,1,1,1,1,1,1,1,1,1,1
277,0,0,0,0,1,0,0,0,1,1,...,1,1,1,1,1,1,1,1,1,1
278,0,0,0,0,1,0,0,0,1,1,...,1,1,1,1,1,1,1,1,1,1


In [None]:
new_med  # display the medium table that was written to new_iML1515.csv

Unnamed: 0,EX_co2_e_i,EX_fe3_e_i,EX_h_e_i,EX_mn2_e_i,EX_fe2_e_i,EX_zn2_e_i,EX_mg2_e_i,EX_ca2_e_i,EX_ni2_e_i,EX_cu2_e_i,...,EX_phe__L_e_i,EX_ser__L_e_i,EX_trp__L_e_i,EX_tyr__L_e_i,EX_val__L_e_i,EX_ade_e_i,EX_gua_e_i,EX_csn_e_i,EX_ura_e_i,EX_thymd_e_i
0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0
1,2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2,...,2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2
2,,,,,,,,,,,...,,,,,,,,,,
3,,,,,,,,,,,...,,,,,,,,,,


In [None]:
# Settings for the medium file read below (read_csv appends ".csv")
mediumname = "new_iML1515"   # medium file written by new_med.to_csv
method = 'EXP'               # experimental dataset
mediumsize = 51              # medium size required by the EXP branch

# Full cobra model to be reduced
model = cobra.io.read_sbml_model("iML1515_duplicated-2.xml")

def read_csv(filename):
    """Read ``<filename>.csv`` with pandas, using the first row as header.

    Returns ``(HEADER, DATA)``: the list of column names and the table values
    as a 2D numpy array.
    """
    frame = pd.read_csv(f"{filename}.csv", header=0)
    return frame.columns.tolist(), np.asarray(frame.values)

H, M = read_csv(mediumname)  #H: header ; M: data matrix

# Column 0 of the medium CSV is the row index written by `new_med.to_csv(...)`
# (pandas writes the index by default), so data columns start at position 1.
first_col = 1
if 'EXP' in method:  # Experimental dataset: the medium size must be given
    if mediumsize < 1:
        sys.exit('must indicate medium size with experimental dataset')
    medium = H[first_col:first_col + mediumsize]
else:
    medium = H[first_col:]

# Reactions to preserve regardless of their flux
#measure = [r.id for r in model.reactions]
#measure=[r.id for r in model.reactions if "BIOMASS" in r.id.upper()] #Conservation of all the biomass ODE
measure = []

# NOTE(review): reduce_model treats the columns after the medium as per-reaction
# fluxes in model order; confirm the CSV actually contains such columns.
flux = M[:, first_col + len(medium):]

def reduce_model(model, medium, measure, flux, verbose=False):
    """Remove zero-flux reactions from a cobra model, in place.

    A reaction is removed when its column of ``flux`` is zero in every row and
    its ID is neither in ``medium`` nor in ``measure``. Reactions left without
    metabolites, and metabolites left without reactions, are then removed too.

    Parameters:
    - model: cobra.Model, modified in place
    - medium: iterable of reaction IDs to keep
    - measure: iterable of reaction IDs to keep
    - flux: 2D array; column i is assumed to correspond to model.reactions[i]
      (TODO confirm the column order matches the model)
    - verbose: if True, print the ID of every reaction removed for zero flux

    Returns:
    - the same (reduced) model object

    Raises:
    - ValueError: if ``flux`` has more columns than the model has reactions
    """
    if flux.shape[1] > len(model.reactions):
        raise ValueError(
            f"flux has {flux.shape[1]} columns but the model has "
            f"{len(model.reactions)} reactions")

    keep = set(medium) | set(measure)

    # Collect first, then delete, so the index -> reaction mapping stays valid
    remove = [
        model.reactions[i]
        for i in range(flux.shape[1])
        if np.count_nonzero(flux[:, i]) == 0 and model.reactions[i].id not in keep
    ]
    if verbose:
        for rxn in remove:
            print(f"Removing {rxn.id} (flux = 0)")
    model.remove_reactions(remove)

    # cobra.manipulation.delete.prune_unused_reactions/_metabolites return a
    # pruned *copy* and leave `model` untouched, so prune in place here. Build
    # the lists first: removing while iterating over the model skips entries.
    empty_rxns = [r for r in model.reactions if len(r.metabolites) == 0]
    model.remove_reactions(empty_rxns)
    orphan_mets = [m for m in model.metabolites if len(m.reactions) == 0]
    model.remove_metabolites(orphan_mets)

    print('reduced numbers of metabolites and reactions:',
          len(model.metabolites), len(model.reactions))

    return model

In [None]:
reduce_model(model, medium=medium, measure=measure, flux=flux)  # reduces `model` in place and displays it

reduced numbers of metabolites and reactions: 1877 3682


0,1
Name,iML1515
Memory address,78f524138050
Number of metabolites,1877
Number of reactions,3682
Number of genes,1516
Number of groups,0
Objective expression,1.0*BIOMASS_Ec_iML1515_core_75p37M - 1.0*BIOMASS_Ec_iML1515_core_75p37M_reverse_35685
Compartments,"cytosol, extracellular space, periplasm"


The previous code does not work: the model is not reduced (it still has 3682 reactions). The ChatGPT-generated code below is an attempt to actually obtain a reduction.

In [None]:
import numpy as np
from cobra.manipulation import delete

In [None]:
def reduce_model_GPT(model, medium, measure, verbose=False):
    """
    Reduce a cobra model by removing reactions that:
    - Are not part of the medium or 'measure' list
    - Have zero flux in a single FBA solution

    The model is modified in place. Reactions left without metabolites and
    metabolites left without reactions are removed as well.

    Caveat: FBA optima may not be unique, so a reaction with zero flux in this
    particular solution is not necessarily blocked; the reduced model may be
    unable to reproduce other flux distributions.

    Parameters:
    - model: cobra.Model object (modified in place)
    - medium: list of exchange reaction IDs that should be kept
    - measure: list of reaction IDs to preserve
    - verbose: if True, print info about removed reactions

    Returns:
    - (reduced cobra.Model object, fluxes of the FBA solution computed on the
      model before reduction)

    Raises:
    - RuntimeError: if the optimization status is not 'optimal'
    """
    # Solve the model to get fluxes
    solution = model.optimize()

    if solution.status != 'optimal':
        raise RuntimeError("Model optimization failed. Check constraints or medium setup.")

    fluxes = solution.fluxes
    keep = set(medium) | set(measure)
    remove = []

    for rxn in model.reactions:
        rxn_id = rxn.id
        if np.isclose(fluxes[rxn_id], 0.0) and rxn_id not in keep:
            remove.append(rxn)
            if verbose:
                print(f"Removing {rxn_id} (flux = 0)")

    model.remove_reactions(remove)

    # cobra.manipulation.delete.prune_unused_reactions/_metabolites return a
    # pruned *copy* and leave `model` untouched (while deep-copying it), so
    # prune in place here instead.
    empty_rxns = [r for r in model.reactions if len(r.metabolites) == 0]
    model.remove_reactions(empty_rxns)
    orphan_mets = [m for m in model.metabolites if len(m.reactions) == 0]
    model.remove_metabolites(orphan_mets)

    if verbose:
        print(f"Reduced model: {len(model.metabolites)} metabolites, {len(model.reactions)} reactions.")

    return model, fluxes

In [None]:
# Re-read the model from disk: reduce_model above removes reactions from
# `model` in place, so start the new reduction from the file again.
from cobra.io import read_sbml_model

model = read_sbml_model("iML1515_duplicated-2.xml")
reduced, fluxes = reduce_model_GPT(
    model,
    medium=medium,
    measure=measure,
    verbose=True,
)

Removing CYTDK2 (flux = 0)
Removing XPPT (flux = 0)
Removing HXPRT (flux = 0)
Removing NDPK5_for (flux = 0)
Removing NDPK6_for (flux = 0)
Removing NDPK8_for (flux = 0)
Removing DHORTS_for (flux = 0)
Removing PYNP2r_for (flux = 0)
Removing ALATA_L2 (flux = 0)
Removing DURIPP_for (flux = 0)
Removing ACALD_for (flux = 0)
Removing PTRCTA (flux = 0)
Removing ACS (flux = 0)
Removing CYSDS (flux = 0)
Removing MAN6PI_for (flux = 0)
Removing TRPAS2_for (flux = 0)
Removing PPCK (flux = 0)
Removing ME1 (flux = 0)
Removing ALATA_L_for (flux = 0)
Removing XYLK (flux = 0)
Removing RBK (flux = 0)
Removing GLYK (flux = 0)
Removing PPM_for (flux = 0)
Removing ASPTA_for (flux = 0)
Removing ACP1_FMN (flux = 0)
Removing NDP3 (flux = 0)
Removing CDPPH (flux = 0)
Removing NDP7 (flux = 0)
Removing FBP (flux = 0)
Removing EX_pi_e_o (flux = 0)
Removing GLGC (flux = 0)
Removing ALATA_D2 (flux = 0)
Removing EX_met__L_e_o (flux = 0)
Removing GTHOr_for (flux = 0)
Removing ILETA_for (flux = 0)
Removing DHORD5 (flux

In [None]:
#Getting back an SBML model: from here on `model` is the libsbml Model read
#from the written file (the following cells use the libsbml API)
from cobra.io import write_sbml_model

reduced_path = "reduced_model.xml"
write_sbml_model(reduced, reduced_path)
model = load_sbml_model(reduced_path)

In [None]:
# ID of the first reaction of the reloaded libsbml model (getId must be called;
# without parentheses it is only the bound method)
model.getReaction(0).getId()

**Adding biomass**: biomass is not a species of the model. To track it, we add it to the system as a new species that is produced by the biomass reactions.

In [None]:
#Adding biomass as a species of the (libsbml) model, unless it already exists
if model.getSpecies("biomass") is None:
    biomass_species = model.createSpecies()
    biomass_species.setId("biomass")
    biomass_species.setName("Biomass")
    # Place biomass in the cytosol compartment "c"
    if model.getCompartment("c") is None:
        print("Error : no cytosol compartment") #This is not supposed to happen with our model
    else:
        biomass_species.setCompartment("c")
    biomass_species.setInitialAmount(0.0)
    biomass_species.setBoundaryCondition(False)
    biomass_species.setHasOnlySubstanceUnits(False)
    biomass_species.setConstant(False)

In [None]:
#Getting the biomass reactions (IDs containing "BIOMASS", case-insensitive)
biomass_reactions = [rxn.getId() for rxn in model.getListOfReactions() if "BIOMASS" in rxn.getId().upper()]

# Make each biomass reaction produce one unit of the "biomass" species
for rxn_id in biomass_reactions:
    biomass_rxn = model.getReaction(rxn_id)
    # Skip reactions that already produce biomass, so re-running this cell
    # does not append a duplicate product
    if any(prod.getSpecies() == "biomass" for prod in biomass_rxn.getListOfProducts()):
        continue
    biomass_product = biomass_rxn.createProduct()
    biomass_product.setSpecies("biomass")
    biomass_product.setStoichiometry(1.0)
    biomass_product.setConstant(True)

In [None]:
#Checking that all the changes have occurred: the biomass species exists...
print(model.getSpecies("biomass") is not None)

#...and the biomass reactions list it among their products
for rxn_id in biomass_reactions:
    for prod in model.getReaction(rxn_id).getListOfProducts():
        print(f"{prod.getStoichiometry()} {prod.getSpecies()}")

True
75.37723 M_adp_c
75.37723 M_h_c
75.37323 M_pi_c
0.773903 M_ppi_c
1.0 biomass


**Building the ODEs:** this version does not use the SBML kinetic laws. Instead, a generic convenience-kinetics rate law is written for every reaction.

In [None]:
##IN CASE OF
def get_products_reactants(reaction):
    """Return the display names of a libsbml reaction's participants.

    Returns a ``(reactants, products)`` tuple of lists. Each entry is the
    species name, or the species ID when the name is empty. Species are
    looked up in the module-level libsbml ``model``.
    """
    def label(ref):
        species = model.getSpecies(ref.getSpecies())
        return species.getName() or species.getId()

    return (
        [label(ref) for ref in reaction.getListOfReactants()],
        [label(ref) for ref in reaction.getListOfProducts()],
    )

#To check how it works: participants of the first biomass reaction
get_products_reactants(
    [rxn for rxn in model.getListOfReactions() if "BIOMASS" in rxn.id.upper()][0]
)

(['10-Formyltetrahydrofolate',
  '[2Fe-2S] iron-sulfur cluster',
  '2-Octaprenyl-6-hydroxyphenol',
  '[4Fe-4S] iron-sulfur cluster',
  'L-Alanine',
  'S-Adenosyl-L-methionine',
  'L-Arginine',
  'L-Asparagine',
  'L-Aspartate',
  'ATP C10H12N5O13P3',
  'Biotin',
  'Calcium',
  'Chloride',
  'Coenzyme A',
  'Co2+',
  'CTP C9H12N3O14P3',
  'Copper',
  'L-Cysteine',
  'DATP C10H12N5O12P3',
  'DCTP C9H12N3O13P3',
  'DGTP C10H12N5O13P3',
  'DTTP C10H13N2O14P3',
  'Flavin adenine dinucleotide oxidized',
  'Fe2+ mitochondria',
  'Iron (Fe3+)',
  'L-Glutamine',
  'L-Glutamate',
  'Glycine',
  'GTP C10H12N5O14P3',
  'H2O H2O',
  'L-Histidine',
  'L-Isoleucine',
  'Potassium',
  'KDO(2)-lipid IV(A)',
  'L-Leucine',
  'L-Lysine',
  'L-Methionine',
  'Magnesium',
  '5,10-Methylenetetrahydrofolate',
  'Manganese',
  'Molybdate',
  'Two disacharide linked murein units, pentapeptide crosslinked tetrapeptide (A2pm->D-ala) (middle of chain)',
  'Nicotinamide adenine dinucleotide',
  'Nicotinamide adeni

In [None]:
#Modif with ChatGPT – working code
from collections import defaultdict

def build_odes_nolaws(model):
    """Build symbolic ODEs for every species, ignoring the model's kinetic laws.

    Each reaction ``rid`` gets a convenience-kinetics-style rate

        v_max_<rid> * prod_S [(S/Km_S) / (1 + S/Km_S)] * prod_P [1 / (1 + P/Km_P)]

    over its reactants S and products P, with one symbol ``Km_<species>_<rid>``
    per participant. Each species accumulates ``stoichiometry * rate`` from
    every reaction it appears in (stoichiometry negated for reactants).

    Returns:
    - dict species ID -> sympy expression (sum of its contributions), ordered
      by first appearance (reactions in model order, reactants before products)
    """
    terms_by_species = defaultdict(list)
    sym = {s.getId(): sp.Symbol(s.getId()) for s in model.getListOfSpecies()}

    for reaction in model.getListOfReactions():
        rid = reaction.getId()
        reactant_refs = reaction.getListOfReactants()
        product_refs = reaction.getListOfProducts()
        v_max = sp.Symbol(f"v_max_{rid}")

        # Saturation factor of each reactant
        saturations = []
        for ref in reactant_refs:
            ratio = sym[ref.getSpecies()] / sp.Symbol(f"Km_{ref.getSpecies()}_{rid}")
            saturations.append(ratio / (1 + ratio))

        # Inhibition factor of each product
        inhibitions = []
        for ref in product_refs:
            ratio = sym[ref.getSpecies()] / sp.Symbol(f"Km_{ref.getSpecies()}_{rid}")
            inhibitions.append(1 / (1 + ratio))

        rate = v_max * sp.Mul(*saturations) * sp.Mul(*inhibitions)

        # Stoichiometric contributions: consumed by reactants, produced for products
        for ref in reactant_refs:
            terms_by_species[ref.getSpecies()].append(-ref.getStoichiometry() * rate)
        for ref in product_refs:
            terms_by_species[ref.getSpecies()].append(ref.getStoichiometry() * rate)

    return {species: sum(terms) for species, terms in terms_by_species.items()}

In [None]:
ode_system = build_odes_nolaws(model)
# One line per species: d<species>/dt = <right-hand side>
for name, rhs in ode_system.items():
    print(f"d{name}/dt = {rhs}")

dM_3dhsk_c/dt = -1.0*M_3dhsk_c*M_h_c*M_nadph_c*v_max_R_SHK3Dr_for/(Km_M_3dhsk_c_R_SHK3Dr_for*Km_M_h_c_R_SHK3Dr_for*Km_M_nadph_c_R_SHK3Dr_for*(1 + M_3dhsk_c/Km_M_3dhsk_c_R_SHK3Dr_for)*(1 + M_h_c/Km_M_h_c_R_SHK3Dr_for)*(1 + M_nadp_c/Km_M_nadp_c_R_SHK3Dr_for)*(1 + M_nadph_c/Km_M_nadph_c_R_SHK3Dr_for)*(1 + M_skm_c/Km_M_skm_c_R_SHK3Dr_for)) + 1.0*M_3dhq_c*v_max_R_DHQTi/(Km_M_3dhq_c_R_DHQTi*(1 + M_3dhq_c/Km_M_3dhq_c_R_DHQTi)*(1 + M_3dhsk_c/Km_M_3dhsk_c_R_DHQTi)*(1 + M_h2o_c/Km_M_h2o_c_R_DHQTi))
dM_h_c/dt = 1.0*M_uacgam_c*M_uagmda_c*v_max_R_UAGPT3/(Km_M_uacgam_c_R_UAGPT3*Km_M_uagmda_c_R_UAGPT3*(1 + M_h_c/Km_M_h_c_R_UAGPT3)*(1 + M_uaagmda_c/Km_M_uaagmda_c_R_UAGPT3)*(1 + M_uacgam_c/Km_M_uacgam_c_R_UAGPT3)*(1 + M_uagmda_c/Km_M_uagmda_c_R_UAGPT3)*(1 + M_udp_c/Km_M_udp_c_R_UAGPT3)) + 2.0*M_uaagmda_c*v_max_R_MPTG/(Km_M_uaagmda_c_R_MPTG*(1 + M_h_c/Km_M_h_c_R_MPTG)*(1 + M_murein5p5p_p/Km_M_murein5p5p_p_R_MPTG)*(1 + M_uaagmda_c/Km_M_uaagmda_c_R_MPTG)*(1 + M_udcpdp_c/Km_M_udcpdp_c_R_MPTG)) + 1.0*M_ru5p

In [None]:
#for species, ode in ode_system.items():
#    symbs=ode.free_symbols
#    if sp.Symbol('biomass') in symbs:
#      print("ode_s", symbs)

**Construction du dictionnaire des paramètres** : construction validée, manque des valeurs réelles des paramètres. Pas d'ajustement de la std sur la biomasse (à ajouter dans un second temps).

In [None]:
#Building dictionary – Parameters
def get_all_symbols(ode_exprs):
    """Return every free symbol of every ODE right-hand side, sorted by name."""
    collected = set().union(*(rhs.free_symbols for rhs in ode_exprs.values()))
    return sorted(collected, key=str)
all_symbs = get_all_symbols(ode_exprs=ode_system)  # every species and parameter symbol, sorted by name

In [None]:
def split_species_and_params(expr, species_ids):
    """Split the free symbols of ``expr`` into ``(species, parameters)``.

    A symbol is a species when its name is in ``species_ids``; every other
    free symbol is a parameter. Both lists are sorted by name.
    """
    known_species = set(map(sp.Symbol, species_ids))
    free = expr.free_symbols
    species = sorted((s for s in free if s in known_species), key=str)
    parameters = sorted((s for s in free if s not in known_species), key=str)
    return species, parameters

In [None]:
# TODO: add the dictionary of true parameter values once they are known, e.g.
#   ode_parameters_dict = {symb: value for symb in all_symbs}

#Ranges dictionary and variables std
DEFAULT_PARAM_RANGE = (1e-5, 1e5)

ode_parameter_ranges_dict = {}
variables_standard_dev_dict = {}
species_ids = [s.getId() for s in model.getListOfSpecies()]
for species, expr in ode_system.items():
    ode_vars, params = split_species_and_params(expr, species_ids)
    # Keys are sympy Symbols; a key seen in several ODEs keeps the position of
    # its first appearance. (No per-variable print: it flooded the output.)
    for p in params:
        ode_parameter_ranges_dict[p] = DEFAULT_PARAM_RANGE
    for v in ode_vars:
        variables_standard_dev_dict[v] = 1

#Adding some complementary parameters (stored under a plain str key)
ode_parameter_ranges_dict["volume"] = DEFAULT_PARAM_RANGE

[1;30;43mLe flux de sortie a été tronqué et ne contient que les 5000 dernières lignes.[0m
M_dctp_c
M_dgtp_c
M_dtdp_c
M_dttp_c
M_fad_c
M_fe2_c
M_fe3_c
M_gln__L_c
M_glu__L_c
M_gly_c
M_gtp_c
M_h2o_c
M_h_c
M_his__L_c
M_ile__L_c
M_k_c
M_kdo2lipid4_e
M_leu__L_c
M_lys__L_c
M_met__L_c
M_mg2_c
M_mlthf_c
M_mn2_c
M_mobd_c
M_murein5px4p_p
M_nad_c
M_nadp_c
M_nh4_c
M_ni2_c
M_pe160_p
M_pe161_p
M_phe__L_c
M_pheme_c
M_pi_c
M_ppi_c
M_pro__L_c
M_pydx5p_c
M_ribflv_c
M_ser__L_c
M_sheme_c
M_so4_c
M_succoa_c
M_thf_c
M_thmpp_c
M_thr__L_c
M_trp__L_c
M_tyr__L_c
M_udcpdp_c
M_utp_c
M_val__L_c
M_zn2_c
biomass
M_ac_c
M_acg5sa_c
M_acorn_c
M_akg_c
M_glu__L_c
M_h2o_c
M_orn_c
M_ac_c
M_acorn_c
M_cbp_c
M_citr__L_c
M_h2o_c
M_h_c
M_orn_c
M_pi_c
M_10fthf_c
M_2fe2s_c
M_2ohph_c
M_3fe4s_c
M_4fe4s_c
M_adp_c
M_ala__L_c
M_amet_c
M_arg__L_c
M_asn__L_c
M_asp__L_c
M_atp_c
M_btn_c
M_ca2_c
M_cl_c
M_coa_c
M_cobalt2_c
M_ctp_c
M_cu2_c
M_cys__L_c
M_datp_c
M_dctp_c
M_dgtp_c
M_dttp_c
M_fad_c
M_fadh2_c
M_fe2_c
M_fe2_p
M_fe3_c
M_gln__L_c
M_

In [None]:
len(ode_parameter_ranges_dict)  # Number of parameters to estimate

2427

In [None]:
len(variables_standard_dev_dict)  # Number of studied metabolites (ODE state variables)

472

In [None]:
sp.Symbol('biomass') in variables_standard_dev_dict  # biomass must be a state variable

True

**Construction du dictionnaire et des ODE** : construction du dictionnaire des résidus, et de la fonction permettant l'intégration odeint. Working

In [None]:
#Building dictionary – Equations
def build_lambda_dict(ode_exprs):
    """Compile each ODE right-hand side into a numpy function.

    Returns three dicts keyed like ``ode_exprs``:
    - lambda_dict: numpy callable whose positional arguments are the
      expression's free symbols sorted by name
    - symbol_dict: that sorted argument list
    - ode_dict: the original sympy expression
    """
    lambda_dict, symbol_dict, ode_dict = {}, {}, {}
    for species, expr in ode_exprs.items():
        ordered_args = sorted(expr.free_symbols, key=str)
        symbol_dict[species] = ordered_args
        lambda_dict[species] = sp.lambdify(ordered_args, expr, modules='numpy')
        ode_dict[species] = expr
    return lambda_dict, symbol_dict, ode_dict
lambda_dict, symbol_dict, ode_dict = build_lambda_dict(ode_exprs=ode_system)

In [None]:
list(lambda_dict).index('biomass')  # position of the biomass ODE

431

In [None]:
list(ode_parameter_ranges_dict).index('volume')  # position of the volume parameter

2426

In [None]:
def sort_symbols_by_name(symbols, names):
    """Reorder ``symbols`` by their paired entry in ``names``.

    ``symbols[i]`` is paired with ``names[i]``; the sort is stable, and
    extra items in the longer sequence are dropped (as with ``zip``).
    """
    paired = list(zip(names, symbols))
    paired.sort(key=lambda item: item[0])
    return [symbol for _name, symbol in paired]

In [None]:
import sympy as sp

def create_dict_lambda(expression_func, var_keys, d_dt_keys, value_keys, min_keys, max_keys):
    """Build a callable that evaluates a symbolic expression from five dicts.

    One sympy symbol named ``"<dict_name>__<key>"`` is created per key of each
    group. ``expression_func`` receives the mapping ``(dict_name, key) -> symbol``
    and must return a sympy expression built from those symbols. The expression
    is lambdified with the numpy backend and wrapped in a function taking
    ``(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)`` that looks
    up every key in the matching dict.
    """
    groups = (
        ("var_dict", var_keys),
        ("d_dt_var_dict", d_dt_keys),
        ("value", value_keys),
        ("min_var_dict", min_keys),
        ("max_var_dict", max_keys),
    )
    # (dict_name, key) pairs, in the positional order of the compiled function
    slots = [(group, key) for group, keys in groups for key in keys]

    symbol_map = {}
    for group, key in slots:
        symbol_map[(group, key)] = sp.Symbol(f"{group}__{key}")

    expr = expression_func(symbol_map)
    compiled = sp.lambdify([symbol_map[slot] for slot in slots], expr, modules="numpy")

    def final_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict):
        sources = {
            "var_dict": var_dict,
            "d_dt_var_dict": d_dt_var_dict,
            "value": value,
            "min_var_dict": min_var_dict,
            "max_var_dict": max_var_dict,
        }
        return compiled(*(sources[group][key] for group, key in slots))

    return final_lambda

In [None]:
# Disabled usage example of create_dict_lambda: it is wrapped in a string
# literal so it does not run (evaluating the cell only displays the string).
'''# Define keys
var_keys = ["X", "ACCOA", "ACE_env"]
d_dt_keys = ["X"]
value_keys = ["v_max_TCA_cycle", "Km_ACCOA_TCA_cycle", "Ki_ACE_TCA_cycle"]
min_keys = []
max_keys = []

# Define symbolic expression builder
def expr_func(symbol_map):
    get = lambda d, k: symbol_map[(d, k)]
    def v_TCA_cycle(v_max_TCA_cycle, ACCOA, Km_ACCOA_TCA_cycle, ACE_env, Ki_ACE_TCA_cycle):
        return v_max_TCA_cycle * ACCOA / (Km_ACCOA_TCA_cycle + ACCOA) * (1 / (1 + ACE_env / Ki_ACE_TCA_cycle))

    return get("d_dt_var_dict", "X") - (
        get("var_dict", "X") * v_TCA_cycle(
            get("value", "v_max_TCA_cycle"),
            get("var_dict", "ACCOA"),
            get("value", "Km_ACCOA_TCA_cycle"),
            get("var_dict", "ACE_env"),
            get("value", "Ki_ACE_TCA_cycle"),
        )
    )

# Create the lambda
my_lambda = create_dict_lambda(expr_func, var_keys, d_dt_keys, value_keys, min_keys, max_keys)

# Example usage
var_dict = {"X": 1.0, "ACCOA": 0.5, "ACE_env": 0.2}
d_dt_var_dict = {"X": 0.05}
value = {"v_max_TCA_cycle": 2.0, "Km_ACCOA_TCA_cycle": 0.1, "Ki_ACE_TCA_cycle": 0.05}
min_var_dict = {}
max_var_dict = {}

# Evaluate
result = my_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)
print("Result:", result)'''

'# Define keys\nvar_keys = ["X", "ACCOA", "ACE_env"]\nd_dt_keys = ["X"]\nvalue_keys = ["v_max_TCA_cycle", "Km_ACCOA_TCA_cycle", "Ki_ACE_TCA_cycle"]\nmin_keys = []\nmax_keys = []\n\n# Define symbolic expression builder\ndef expr_func(symbol_map):\n    get = lambda d, k: symbol_map[(d, k)]\n    def v_TCA_cycle(v_max_TCA_cycle, ACCOA, Km_ACCOA_TCA_cycle, ACE_env, Ki_ACE_TCA_cycle):\n        return v_max_TCA_cycle * ACCOA / (Km_ACCOA_TCA_cycle + ACCOA) * (1 / (1 + ACE_env / Ki_ACE_TCA_cycle))\n\n    return get("d_dt_var_dict", "X") - (\n        get("var_dict", "X") * v_TCA_cycle(\n            get("value", "v_max_TCA_cycle"),\n            get("var_dict", "ACCOA"),\n            get("value", "Km_ACCOA_TCA_cycle"),\n            get("var_dict", "ACE_env"),\n            get("value", "Ki_ACE_TCA_cycle"),\n        )\n    )\n\n# Create the lambda\nmy_lambda = create_dict_lambda(expr_func, var_keys, d_dt_keys, value_keys, min_keys, max_keys)\n\n# Example usage\nvar_dict = {"X": 1.0, "ACCOA": 0.5, "A

In [None]:
def build_ode_dict(lambda_dict,ode_d):
    """Build one PINN residual function per species ODE.

    For every species S (in the key order of ``lambda_dict``) the residual is

        dS/dt - f_S(...) * biomass * volume / (max_S - min_S)

    where f_S is ``lambda_dict[S]``, the numpy-lambdified right-hand side
    ``ode_d[S]``. The residual is compiled with ``create_dict_lambda``:
    species values come from ``var_dict``, kinetic parameters and ``volume``
    from ``value``, dS/dt from ``d_dt_var_dict`` and the scaling bounds of S
    from ``min_var_dict`` / ``max_var_dict``.

    Parameters:
    - lambda_dict: species ID -> callable whose positional arguments are the
      free symbols of ``ode_d[S]`` sorted by name (as build_lambda_dict builds it)
    - ode_d: species ID -> sympy right-hand side

    Returns:
    - dict mapping ``"ode_<i>"`` (i = position of S in ``lambda_dict``) to the
      residual ``f(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)``

    Uses the module-level libsbml ``model`` to tell species from parameters.
    """
    ode_dict = {}
    species_ids = [s.getId() for s in model.getListOfSpecies()]
    species_names = set(species_ids)

    for i, new_symbol_var in enumerate(lambda_dict):
        expr_ = ode_d[new_symbol_var]
        species_syms, param_syms = split_species_and_params(expr_, species_ids)

        list_nsv = [new_symbol_var]  # d_dt, min and max keys
        symbs_var = [str(v) for v in species_syms if str(v) != new_symbol_var]
        var_keys = symbs_var + list_nsv
        if 'biomass' not in var_keys:
            var_keys = var_keys + ["biomass"]
        values_keys = [str(p) for p in param_syms] + ["volume"]

        # Positional argument names of lambda_dict[new_symbol_var]: the RHS free
        # symbols sorted by name, as in build_lambda_dict.
        arg_names = sorted(str(s) for s in expr_.free_symbols)

        # Loop values are bound as defaults so the closure does not rely on the
        # late binding of loop variables.
        def expr_func(symbol_map, rhs=lambda_dict[new_symbol_var],
                      target=new_symbol_var, arg_names=arg_names):
            get = lambda d, k: symbol_map[(d, k)]
            args = [get("var_dict", name) if name in species_names else get("value", name)
                    for name in arg_names]
            scale = get("max_var_dict", target) - get("min_var_dict", target)
            return get("d_dt_var_dict", target) - (rhs(*args) * get("var_dict", "biomass") * get("value", "volume")) / scale

        ode_dict[f"ode_{i}"] = create_dict_lambda(expr_func, var_keys, list_nsv, values_keys, list_nsv, list_nsv)

    return ode_dict

ode_dict_2 = build_ode_dict(lambda_dict, ode_d=ode_dict)  # one residual function per species ODE

In [None]:
## OLD VERSION
# Earlier IndexedBase-based implementation of build_ode_dict, disabled by
# wrapping it in a string literal (evaluating the cell only displays the string).
'''def build_ode_dict(lambda_dict,ode_d):
    ode_dict={}
    list_spes=list(lambda_dict.keys())
    symbols= [sp.IndexedBase(f"var_dict"),sp.IndexedBase(f"d_dt_var_dict"),sp.IndexedBase(f"value"),
              sp.IndexedBase(f"min_var_dict"), sp.IndexedBase(f"max_var_dict")]
    idx_biomass=list(lambda_dict.keys()).index('biomass')
    biomass=sp.Symbol(list_spes[idx_biomass]) #f"{}"
    volume=sp.Symbol("volume")
    species_ids = [s.getId() for s in model.getListOfSpecies()]

    for i in range(len(list_spes)):
      new_symbol_var=list_spes[i]
      expr_=ode_d[list_spes[i]]
      vars, params =split_species_and_params(expr_, species_ids)
      symbs_var= [str(v) for v in vars if str(v) != list_spes[i]] #sp.Symbol(v)
      symbs_params=[str(p) for p in params] #sp.Symbol(p)
      names_args=[sv for sv in symbs_var] + [sp for sp in symbs_params] + [new_symbol_var]
      args=[symbols[0][f"{str(sv)}"] for sv in symbs_var] + [symbols[2][f"{str(sp)}"] for sp in symbs_params]+[symbols[0][f"{str(new_symbol_var)}"]]
      #print(args)
      args=sort_symbols_by_name(args, names_args) ## PROBLEM : sorting the list
      args=tuple(args)
      expr= symbols[1][f"{str(new_symbol_var)}"] - ((lambda_dict[list_spes[i]](*args))*symbols[0][f"{str(biomass)}"]*symbols[2][f"{str(volume)}"])/(symbols[4][f"{str(new_symbol_var)}"]-symbols[3][f"{str(new_symbol_var)}"])
      ode_dict[f"ode_{i}"] = sp.lambdify(symbols, expr, modules='numpy')

    return ode_dict

ode_dict_2 = build_ode_dict(lambda_dict, ode_dict)'''

'def build_ode_dict(lambda_dict,ode_d):\n    ode_dict={}\n    list_spes=list(lambda_dict.keys())\n    symbols= [sp.IndexedBase(f"var_dict"),sp.IndexedBase(f"d_dt_var_dict"),sp.IndexedBase(f"value"),\n              sp.IndexedBase(f"min_var_dict"), sp.IndexedBase(f"max_var_dict")]\n    idx_biomass=list(lambda_dict.keys()).index(\'biomass\')\n    biomass=sp.Symbol(list_spes[idx_biomass]) #f"{}"\n    volume=sp.Symbol("volume")\n    species_ids = [s.getId() for s in model.getListOfSpecies()]\n\n    for i in range(len(list_spes)):\n      new_symbol_var=list_spes[i]\n      expr_=ode_d[list_spes[i]]\n      vars, params =split_species_and_params(expr_, species_ids)\n      symbs_var= [str(v) for v in vars if str(v) != list_spes[i]] #sp.Symbol(v)\n      symbs_params=[str(p) for p in params] #sp.Symbol(p)\n      names_args=[sv for sv in symbs_var] + [sp for sp in symbs_params] + [new_symbol_var]\n      args=[symbols[0][f"{str(sv)}"] for sv in symbs_var] + [symbols[2][f"{str(sp)}"] for sp in symbs_

In [None]:
ode_dict_2  # display the residual functions, keyed "ode_<i>"

{'ode_0': <function __main__.create_dict_lambda.<locals>.final_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)>,
 'ode_1': <function __main__.create_dict_lambda.<locals>.final_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)>,
 'ode_2': <function __main__.create_dict_lambda.<locals>.final_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)>,
 'ode_3': <function __main__.create_dict_lambda.<locals>.final_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)>,
 'ode_4': <function __main__.create_dict_lambda.<locals>.final_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)>,
 'ode_5': <function __main__.create_dict_lambda.<locals>.final_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)>,
 'ode_6': <function __main__.create_dict_lambda.<locals>.final_lambda(var_dict, d_dt_var_dict, value, min_var_dict, max_var_dict)>,
 'ode_7': <function __main__.create_dict_lambda.<locals>.final_lambda(var_di

## PINN implementation

In [None]:
pip install optuna

Collecting optuna
  Downloading optuna-4.3.0-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.16.1-py3-none-any.whl.metadata (7.3 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Downloading optuna-4.3.0-py3-none-any.whl (386 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m386.6/386.6 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.16.1-py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.9.0-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, alembic, optuna
Successfully installed alembic-1.16.1 colorlog-6.9.0 optuna-4.3.0


In [None]:
#Imports
import torch
import random
import os
from numpy import genfromtxt
from tools import random_ranges #lib
import numpy as np
from pinn_warming_softadapt import Pinn #lib
import matplotlib.pyplot as plt
import jax.numpy as jnp
import pandas as pd

# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch) so training runs are reproducible. NumPy was previously imported
# but left unseeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the notebook

<torch._C.Generator at 0x7aaeb2bade30>

In [None]:
# Load the experimental growth curves (semicolon-separated file).
all_data = pd.read_csv('Growth_curves_copie_nodots.csv', delimiter=';')

# Row 1: optical density (OD). Keep only columns without missing values,
# drop the "Title" column, convert OD to biomass with the usual E. coli
# factor of 0.44 (first draft, to be refined), and take the single row
# as a 1-D array.
data = (
    all_data.loc[[1]]
    .dropna(axis=1)
    .drop("Title", axis=1)
    .to_numpy()
    * 0.44
)[0]

# Row 0: sampling times. Same column filtering, minus both label columns.
data_t = (
    all_data.loc[[0]]
    .dropna(axis=1)
    .drop(["Media ID", "Title"], axis=1)
    .to_numpy()
)[0]

#Data Auxiliary – To adapt
#data_aux_1mM=[torch.tensor([12.89999655, 0.9200020244, 0.06999993881, 0.27305, 0.063, 1.035]),
#              torch.tensor([0.77457781,  3.99120414,  0.85387902,  0.54487839,  0.12541807, 4.07405124])]

In [None]:
# Experimental case: only biomass has measurements. Every other state variable
# is passed to the PINN without data (value None).
observables = ["biomass"]
variable_data = {"biomass": data}
variable_no_data = {
    str(key): None
    for key in variables_standard_dev_dict.keys()
    if key != sp.Symbol("biomass")
}

# Parameter names as strings (what Pinn receives) and the raw dictionary keys
# (used to look values up in ode_parameter_ranges_dict).
parameter_names = [str(key) for key in ode_parameter_ranges_dict.keys()]
param_names = list(ode_parameter_ranges_dict.keys())

In [None]:
# Parameter search ranges, in the same order as param_names / parameter_names.
# They are taken directly from ode_parameter_ranges_dict; no per-name override
# is needed because both name lists come from that same dictionary.
ranges = [ode_parameter_ranges_dict[key] for key in param_names]

# No parameter is held constant: every entry is None.
# NOTE(review): confirm that Pinn treats a None constant as "learn this parameter".
constants_dict = {str(key): None for key in ode_parameter_ranges_dict.keys()}

In [None]:
# Number of training epochs
epoch_number = 150000

# Optimizer: Adam with beta2 = 0.8 (PyTorch's Adam default is 0.999)
optimizer_type = "Adam"
optimizer_hyperparameters = dict(lr=1e-4, betas=(0.9, 0.8))

# Learning-rate scheduler settings. base_lr == max_lr, so a cyclic schedule
# built from these would keep the rate fixed at 1e-4.
# NOTE(review): in torch.optim.lr_scheduler.CyclicLR, "exp_range" is a value
# for `mode`; `scale_mode` only accepts "cycle" or "iterations". Confirm how
# Pinn consumes this key before relying on `gamma`.
scheduler_hyperparameters = dict(
    base_lr=1e-4,
    max_lr=1e-4,
    step_size_up=100,
    scale_mode="exp_range",
    gamma=0.999,
    cycle_momentum=False,
)

# Strategy used to combine the individual loss terms
multiple_loss_method = "soft_adapt"

In [None]:
# Per-residual loss magnitudes recorded from an earlier run (presumably the
# first epoch, per the get_weighting comment — TODO confirm). They were kept to
# derive residual weights "because otherwise the loss explodes".
# NOTE(review): not read by any active statement in this section.
res_loss=[842870144.0, 17186459648.0, 2.798671245574951, 1.7806094884872437, 0.022711969912052155, 0.002836819738149643, 506990848.0, 0.21996110677719116, 0.07969094812870026, 786085120.0, 3141105664.0, 1437840512.0, 321427800064.0, 0.033832889050245285, 1191327872.0, 359538528.0, 0.0033868050668388605, 12.888894081115723, 35.17216491699219, 1.0654918014552095e-07, 0.03697841241955757, 7578031104.0, 445359616.0, 445388192.0, 631991488.0, 631994176.0, 81.48145294189453, 323542624.0, 74.53617095947266, 0.016376590356230736, 1294176640.0, 323548000.0, 329867680.0, 3353050112.0, 0.02652144990861416, 4.064201354980469, 2.816359758377075, 0.009874053299427032, 8.576565630117872e+18, 317588544.0, 0.157928004860878, 704844800.0, 0.06266576796770096, 382942592.0, 1.9477361945519078e-08, 693273344.0, 693252416.0, 0.023737825453281403, 0.040480777621269226, 823198464.0, 1301722240.0, 1.5333547592163086, 0.02839387208223343, 734194752.0, 0.2676365077495575, 870460736.0, 852284544.0, 870459264.0, 718452800.0, 0.7231525778770447, 0.024585021659731865, 0.01785426400601864, 0.2448606938123703, 7.681457611991941e+18, 0.04756324365735054, 11.032358169555664, 6.143470727693057e-06, 0.0749589204788208, 1348611200.0, 1348584448.0, 0.7545976638793945, 0.036011405289173126, 422762528.0, 0.024269383400678635, 567569472.0, 1379312896.0, 1379301376.0, 0.09160599112510681, 0.005137595813721418, 302362432.0, 275468384.0, 86517016.0, 4802625024.0, 0.047364071011543274, 772462912.0, 4.5821376915000656e-08, 651116544.0, 651101056.0, 0.0027889825869351625, 0.40470945835113525, 2.2622189987941965e-07, 948804672.0, 8.326557576765481e-08, 577215104.0, 577216704.0, 0.07763822376728058, 4.821667687338049e-08, 0.07277147471904755, 0.009972860105335712, 0.0027702213265001774, 522780032.0, 2840978.25, 0.0013896515592932701, 0.03831253573298454, 0.006458199582993984, 0.0012931155506521463, 251561216.0, 251561456.0, 6002137088.0, 84108192.0, 0.00785770732909441, 0.4588962197303772, 0.10060746967792511, 
0.00586169445887208, 0.05625240504741669, 0.056035641580820084, 421794496.0, 421785760.0, 0.005542227067053318, 3.053621292114258, 0.03563125804066658, 518946912.0, 0.02163054421544075, 458983328.0, 458969792.0, 252501712.0, 252494016.0, 0.4402288496494293, 0.04056740179657936, 949384192.0, 2381426432.0, 0.039642855525016785, 0.024405699223279953, 754498368.0, 754489152.0, 1.8631023834814187e-08, 0.35780513286590576, 422689568.0, 422714176.0, 0.034757401794195175, 0.05379288271069527, 0.8570560812950134, 0.05712774023413658, 0.012863499112427235, 0.33305367827415466, 2328568576.0, 0.02830183319747448, 5.948107073550091e-08, 0.07482540607452393, 0.028649497777223587, 1814332032.0, 0.03830275312066078, 176757696.0, 0.01765320636332035, 2929657856.0, 2929678080.0, 0.004105704370886087, 0.08508935570716858, 1.9191037381460774e-07, 0.07598941028118134, 0.06788276880979538, 0.17982129752635956, 0.041696321219205856, 0.0020222426392138004, 0.0015304891858249903, 0.06266263127326965, 0.3965437114238739, 0.01785818114876747, 376873568.0, 2.438305139541626, 498600064.0, 498601504.0, 3.6510680700985176e-08, 0.00035541821853257716, 0.052199166268110275, 0.046189043670892715, 567551360.0, 718465536.0, 3.052615165710449, 0.01226806640625, 0.0015696686459705234, 0.04051772505044937, 0.029865572229027748, 0.08106798678636551, 429277760.0, 0.007032346911728382, 0.027725283056497574, 1265630848.0, 0.017894743010401726, 0.049202557653188705, 0.049154311418533325, 948764224.0, 3.9539717278103126e-08, 0.128246009349823, 0.035301920026540756, 0.08347871899604797, 1.1742972816364272e-07, 0.9258205890655518, 0.009413857012987137, 376882400.0, 0.2988339960575104, 1.250818542075649e-07, 0.22013624012470245, 0.003219499019905925, 382933504.0, 0.03239552676677704, 0.04353613406419754, 0.0052312221378088, 0.04778160899877548, 0.04799743369221687, 747895168.0, 408917376.0, 408915168.0, 408942208.0, 0.07139603793621063, 0.04837152734398842, 0.13872092962265015, 0.0054526012390851974, 
2.8772493720907732e-08, 5.489173560135896e-08, 0.1987171620130539, 0.013612412847578526, 1.573526020592908e-08, 734065856.0, 786085120.0, 0.0003791186027228832, 0.07410339266061783, 0.04342708736658096, 0.04965328797698021, 9.378187115771652e-08, 0.0017039981903508306, 0.03410724550485611, 0.131031796336174, 3.408129245485725e-08, 0.29818639159202576, 526857248.0, 0.061076149344444275, 0.06117855757474899, 526855072.0, 486075712.0, 486080256.0, 0.012955710291862488, 581696768.0, 581700992.0, 3.838744078166201e-08, 0.027638893574476242, 0.12102679163217545, 0.07173839211463928, 7.777106958428703e-08, 0.013798055239021778, 0.00020227275672368705, 0.04779655858874321, 0.05525387078523636, 0.06939173489809036, 0.0005139486747793853, 5.955865708529018e-05, 0.09012752026319504, 0.45898714661598206, 0.07294610142707825, 1.93587602552725e-05, 0.03574739024043083, 1280870400.0, 0.07079622149467468, 0.007378921844065189, 0.018154116347432137, 0.07088840007781982, 587482496.0, 0.01597529463469982, 731340992.0, 0.3351377546787262, 949384192.0, 759847552.0, 3.779211610321909e-09, 0.007210777141153812, 0.18031719326972961, 0.00885346531867981, 0.21822482347488403, 0.07574082165956497, 0.07759197056293488, 0.0017782142385840416, 0.02230025641620159, 0.027699680998921394, 0.006606159266084433, 0.025788333266973495, 0.011401587165892124, 0.026736345142126083, 416235616.0, 569355008.0, 0.02943551354110241, 63769992.0, 0.00048652137047611177, 359436800.0, 3.848699165785719e+18, 6.301233518819148e+18, 1814314496.0, 0.07772038131952286, 7.458944496076587e+18, 522656960.0, 6.084959581635609e+18, 692031296.0, 25902140.0, 0.08960293978452682, 0.0015344693092629313, 0.018256476148962975, 587482496.0, 3.509342323576144e+18, 8125896.0, 0.011119825765490532, 0.01215053629130125, 7.707346163023741e+18, 211545488.0, 258917024.0, 0.0004045231326017529, 0.008046606555581093, 17121166.0, 0.01594538614153862, 0.0003416405525058508, 0.016035793349146843, 180787120.0, 1.1001255728615433e-07, 
7.170765247215567e+18, 3.1174706070007316e+18, 157443456.0, 5.219346512680583e+18, 3.7530529241633915e+18, 369958816.0, 4.672745747408486e+18, 525204640.0, 0.007144563365727663, 0.1423605978488922, 3.592766119066206e+18, 76062264.0, 6.191669933890339e+18, 0.02280287817120552, 0.12722386419773102, 0.16489265859127045, 3.930150782593955e-08, 0.016125615686178207, 4.931549893866029e+18, 69872512.0, 4.723997832669823e+18, 327488352.0, 0.006713880691677332, 7.42426425006909e+18, 246461152.0, 0.060209497809410095, 0.029082361608743668, 1.5728344848753295e-08, 361938624.0, 0.0013339106226339936, 0.0013399237068369985, 0.04851968213915825, 0.01913989707827568, 0.013485024683177471, 0.06609039753675461, 1263928448.0, 1044511616.0, 772445824.0, 0.035566944628953934, 945123264.0, 1819382912.0, 1819426176.0, 0.07857595384120941, 548733440.0, 0.036114491522312164, 0.0006440538563765585, 6.179427458619102e-08, 0.0006575316656380892, 0.013536144979298115, 0.06496839225292206, 0.053243525326251984, 1191294080.0, 6.70619794498021e+18, 819040320.0, 0.006818392314016819, 3.222622879661685e-08, 1.1805235544670722e-07, 2.318761538333547e-08, 0.055456992238759995, 1155344000.0, 2725938944.0, 1.0161588903656593e-07, 627493760.0, 303945696.0, 1.0095458030700684, 4.6301760647793344e-08, 7.341079566458575e-08, 0.008360564708709717, 354044160.0, 0.028453001752495766, 1.8412282543067704e-07, 1776275712.0, 1.6001246549990356e-08, 5.258352331338756e-08, 1.399137943280948e-07, 699525312.0, 699525312.0, 0.047652050852775574, 0.017373640090227127, 1776275712.0, 440826176.0, 0.1240222305059433, 0.4278763234615326, 0.09155859053134918, 0.05156620591878891, 0.05153803154826164, 1.3625592210075865e-08, 2.375575292035137e-08, 2.9061729378554446e-07, 0.10269053280353546, 3.491695679258555e-07, 0.35644394159317017, 0.0763024389743805, 0.07614729553461075, 0.02042684704065323, 0.2153996378183365, 7.207560059896423e-08, 1.946272476516242e-07, 0.04487757384777069, 526289376.0, 987415552.0, 2929683712.0, 
987428800.0, 3.7811491715444845e-09, 0.09501170367002487, 1.0232344749283584e-07, 5.2068029617657885e-05, 0.052289657294750214, 0.15607240796089172, 0.06038887798786163, 1.0007180861748566e-07, 6.2639604614389555e-09, 3.311684492640387e+18, 759836416.0, 5.434651230364238e+18, 8.745170240679182e+18, 8.725093708111806e+18, 8.668402888583676e+18, 6.044956050082234e+18, 5.078638711139205e+18, 6.226982399083807e+18, 6.774639795030196e+18, 5.30667632322832e+18, 5.460024110442611e+18, 4.372913599438389e+18, 8.141214550855778e+18, 3.6293339516606874e+18, 4.2604786398936433e+18, 6.025471054770602e+18, 3.3010610112928154e+18, 3.910775118388789e+18, 0.022770065814256668, 7.880055800247353e+18, 3.9108768232143585e+18, 4.170918195131056e+18, 7.996383030954426e+18, 4.864363686094963e+18, 0.09781336039304733, 43777904.0, 57635852.0, 4.0690652103699333e+18, 6.878559037017817e+18, 7.084311197679354e+18, 4.566423522758361e+18, 1.024829288098772e-07, 7.36194722954163e+18, 8.721376259298296e+18, 7.672881421295288e+18, 3.8862455637289206e+18, 7.299444391549075e+18, 8.591622342348636e+18, 168239248.0]

def first_non_empty_index(lst):
    """Return the position of the first truthy element of ``lst``.

    Raises:
        StopIteration: if ``lst`` is empty or every element is falsy.
    """
    for position, item in enumerate(lst):
        if item:
            return position
    raise StopIteration

#Function to optimise the weighting based on first values of the loss function.
def get_weighting(list_loss):
    """Return one power-of-ten weight per loss value, normalising its magnitude.

    For a loss ``x`` with decimal exponent ``e = floor(log10(|x|))``:

    * ``|x| >= 1`` -> weight ``10**(-e)``     (weighted loss lands in [1, 10))
    * ``|x| < 1``  -> weight ``10**(-e - 1)`` (weighted loss lands in [0.1, 1))
    * ``x == 0``   -> weight ``1`` (nothing to rescale)

    This matches counting the digits of the plain decimal representation.
    It also works for values that ``str()`` prints in scientific notation
    (e.g. ``8.5e+18`` or ``1.06e-07``).

    Args:
        list_loss: iterable of finite loss values.

    Returns:
        list of weights, same length and order as ``list_loss``.

    Raises:
        ValueError: if a loss value is NaN or infinite.
    """
    import math
    from decimal import Decimal

    weightings = []
    # Every element gets a weight. The branch below previously sat outside
    # the loop, so only the last loss was weighted.
    for elt in list_loss:
        if not math.isfinite(elt):
            raise ValueError(f"cannot derive a weight from non-finite loss {elt!r}")
        if elt == 0:
            weightings.append(1)
            continue
        # Decimal(float) is exact, so adjusted() is floor(log10(|elt|)) with no
        # rounding error near powers of ten and no dependence on str() format.
        exponent = Decimal(abs(elt)).adjusted()
        if exponent >= 0:
            weightings.append(10 ** (-exponent))
        else:
            weightings.append(10 ** (-exponent - 1))
    return weightings

# One zero weight per unobserved variable plus one more (presumably for
# biomass, which variable_no_data excludes).
# NOTE(review): all-zero weights presumably disable the ODE residual terms;
# the alternative considered was get_weighting(res_loss).
residual_weights = [0] * (len(variable_no_data) + 1)

In [None]:
# Build the PINN from the ODE residual functions, the biomass observations,
# the parameter ranges and the optimisation settings.
pinn_cell = Pinn(
    ode_residual_dict=ode_dict_2,
    parameter_names=parameter_names,
    ranges=ranges,
    constants_dict=constants_dict,
    data_t=data_t,
    variables_data=variable_data,
    variables_no_data=variable_no_data,
    residual_weights=residual_weights,
    multi_loss_method=multiple_loss_method,
    optimizer_type=optimizer_type,
    optimizer_hyperparameters=optimizer_hyperparameters,
    scheduler_hyperparameters=scheduler_hyperparameters,
    # data_aux=data_aux_1mM,  # auxiliary data currently disabled
)

# Train. train() returns seven results in the order below; aux_losses is not
# unpacked because no auxiliary data is passed.
# NOTE(review): the recorded run stopped at epoch 1 with a NaN loss.
(r2_score, pred_variables, losses, variable_fit_losses,
 residual_losses, all_learned_parameters, learning_rates) = pinn_cell.train(epoch_number)
biomass = pred_variables

Training the neural network:   0%|                                                                                         | 0/150000 [00:00<?, ?it/s]

[1.1864223773003875e-09, 5.818534011548857e-11, 0.3573124215933275, 0.561605453899705, 44.02964621176906, 352.50741756762613, 1.9724221925205245e-09, 4.546258266525939, 12.548476627295335, 1.2721268658539168e-09, 3.18359236195373e-10, 6.954874283024793e-10, 3.111118577176238e-12, 29.557038375141367, 8.393994831340603e-10, 2.781343088771838e-09, 295.26352425513795, 0.07758617564133441, 0.02843157372769185, 9385337.349703083, 27.042805100824406, 1.3196039792871242e-10, 2.2453764644884192e-09, 2.2452324016708553e-09, 1.5822997919870718e-09, 1.5822930621436612e-09, 0.012272731571356648, 3.090782870080203e-09, 0.013416304958081724, 61.062771813153034, 7.726920492089859e-10, 3.0907315143348127e-09, 3.031518577388364e-09, 2.982359244859386e-10, 37.705329212608405, 0.24605080129077553, 0.35506827457875945, 101.27553190927455, 1.165967874703078e-19, 3.1487281858630267e-09, 6.331999197234971, 1.418752043002942e-09, 15.957675656594803, 2.6113574746994977e-09, 51341655.13775124, 1.4424324960055005

Training the neural network:   0%|                                                                             | 1/150000 [00:08<338:38:48,  8.13s/it]

[1.1864223773003875e-09, 5.818534011548857e-11, 0.3573124215933275, 0.561605453899705, 44.02964621176906, 352.50741756762613, 1.9724221925205245e-09, 4.546258266525939, 12.548476627295335, 1.2721268658539168e-09, 3.18359236195373e-10, 6.954874283024793e-10, 3.111118577176238e-12, 29.557038375141367, 8.393994831340603e-10, 2.781343088771838e-09, 295.26352425513795, 0.07758617564133441, 0.02843157372769185, 9385337.349703083, 27.042805100824406, 1.3196039792871242e-10, 2.2453764644884192e-09, 2.2452324016708553e-09, 1.5822997919870718e-09, 1.5822930621436612e-09, 0.012272731571356648, 3.090782870080203e-09, 0.013416304958081724, 61.062771813153034, 7.726920492089859e-10, 3.0907315143348127e-09, 3.031518577388364e-09, 2.982359244859386e-10, 37.705329212608405, 0.24605080129077553, 0.35506827457875945, 101.27553190927455, 1.165967874703078e-19, 3.1487281858630267e-09, 6.331999197234971, 1.418752043002942e-09, 15.957675656594803, 2.6113574746994977e-09, 51341655.13775124, 1.4424324960055005

Training the neural network:   0%|                                                                             | 1/150000 [00:11<466:24:47, 11.19s/it]


ValueError: loss is not a number (nan) anymore. Consider changing the hyperparameters. This happened at epoch 1.

In [None]:
# Learning-rate history over the training epochs
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(learning_rates[0:], color='teal', linewidth=4)
ax.grid(True)
ax.set_xlabel('Epochs', fontsize=15)
ax.set_ylabel('Learning rate', fontsize=15)
fig.tight_layout()

# To save the figure:
#fig_name = 'learning_rate'
#fig.savefig(fig_name+'.png', format='png')

plt.show()

In [None]:
## Print and plot the loss histories returned by pinn_cell.train
print("Loss: ","%.5g" % losses[-1])

# (history, y-axis label) for each panel. The auxiliary-loss panel is added only
# when `aux_losses` exists: the training cell does not unpack it (data_aux is
# disabled), so plotting it unconditionally raised a NameError.
# NOTE(review): a stale aux_losses from an earlier kernel session would be plotted too.
loss_panels = [
    (losses, 'Loss'),
    (variable_fit_losses, 'Variable fit loss'),
    (residual_losses, 'Residual loss'),
]
if 'aux_losses' in globals():
    loss_panels.append((aux_losses, 'Auxiliary loss'))

fig, axs = plt.subplots(1, len(loss_panels), figsize=(15, 6))
for ax, (history, label) in zip(axs, loss_panels):
    ax.plot(history[0:], color='teal', linewidth=4)
    ax.grid(True)
    ax.set_xlabel('Epochs', fontsize=15)
    ax.set_ylabel(label, fontsize=15)
    ax.set_xscale('log')
    ax.set_yscale('log')

fig.tight_layout()

#fig_name = 'loss_short'
#fig.savefig(fig_name+'.png', format='png')

plt.show()

In [None]:
## Print and plot the R2 score history returned by pinn_cell.train
print("r2 :","%.5g" % r2_score[-1])

plt.figure(figsize=(5,3))
plt.plot(r2_score, color = 'black',linewidth=4)
plt.grid(True)
plt.xlabel('Epochs',fontsize=15)
# The curve is the r2_score history, so label it as such.
plt.ylabel('R2 score',fontsize=15)
plt.xscale('log')
# NOTE(review): a log y-axis silently drops R2 values <= 0; use 'linear' or
# 'symlog' if such values occur.
plt.yscale('log')

plt.tight_layout()

#fig_name = 'error_short'
#plt.savefig(fig_name+'.png', format='png')

plt.show()

In [None]:
## Comparing learned vs reference parameters
# NOTE(review): there are no true parameter values for this experimental data,
# so this cell does not currently work. ode_parameters_dict must hold reference
# values keyed by the string names in parameter_names.
learned_parameters=[pinn_cell.output_param_range(v,i).item() for (i,(k,v)) in enumerate(pinn_cell.ode_parameters.items())]
true_parameters=[ode_parameters_dict[key] for key in parameter_names]

plt.grid(True)
# Identity line: points on it are perfectly recovered parameters.
plt.plot([0, 2*10**10], [0, 2*10**10],color='black')
plt.scatter(true_parameters,learned_parameters)
plt.xscale('log')
plt.yscale('log')

# One colour per parameter range. plt.cm.get_cmap was removed in Matplotlib 3.9;
# pyplot.get_cmap(name, lut) is the supported equivalent.
cmap = plt.get_cmap('viridis', len(ranges))
colors = [cmap(i) for i in range(len(ranges))]

for i, (true_val, learned_val) in enumerate(zip(true_parameters, learned_parameters)):
    plt.scatter(true_val, learned_val, s=70, color=colors[i], label=f'Range {i}' if i == 0 else "",zorder=3)

    # Vertical bar spanning the allowed range of this parameter, same colour
    plt.vlines(x=true_val, ymin=ranges[i][0], ymax=ranges[i][1], colors=colors[i],zorder=2,linewidth=3)

# Overall range extent (not used for the axis limits, which are fixed below)
min_value = min(r[0] for r in ranges)
max_value = max(r[1] for r in ranges)
plt.ylim([10**(-10),2*10**10])
plt.xlim([10**(-10),2*10**10])

plt.xlabel('True parameters',fontsize=20)
plt.ylabel('Learned parameters',fontsize=20)

plt.tight_layout()

#fig_name = 'params_short'
#plt.savefig(fig_name+'.png', format='png')

plt.show()

In [None]:
## Percentage error on parameters (disabled: no ground-truth parameter values)
#err=np.array([(abs(true_parameters[i]-learned_parameters[i])/true_parameters[i])*100 for i in range(len(true_parameters))])
#print("percentage error", np.mean(err))

## AIC = 2k + 2 * (final loss)
# NOTE(review): k = 20 + 7*20 + 20 = 180 is hard-coded, and the final training
# loss stands in for the negative log-likelihood. Confirm both for this model.
AIC = 2*(20+7*20+20)+2*(losses[-1])
print("AIC",AIC)

## R2 score of the prediction on the training data
# Import under an alias so the `r2_score` history returned by pinn_cell.train
# is not shadowed; otherwise re-running the R2 plot cell would fail.
from sklearn.metrics import r2_score as sklearn_r2_score
# The training cell stores the prediction as `biomass`; `biomass_pred` is not defined there.
R2_scores_train_data=[sklearn_r2_score(data.T, biomass.detach().numpy())]
print("R2_scores_train_data",R2_scores_train_data)

In [None]:
# Predicted vs measured biomass over time. The training cell stores the
# prediction as `biomass`; `biomass_pred` is not defined there. A fresh figure
# is created so nothing is drawn onto a leftover figure 1.
fig, ax = plt.subplots()
ax.plot(data_t, biomass.detach().numpy(), linestyle='dashed', label='biomass_pred', color='b')
ax.plot(data_t, data, 'o', label='biomass_data', color='b')
ax.set_title('biomass')
ax.set_xlabel('t(h)')
ax.set_ylabel('biomass')
ax.legend()
ax.grid(True)
plt.show()