In [1]:
import sys
import torch
import random
sys.path.insert(0, '..')


import os
import optuna
from ode_equation.Millard_dicts import ode_parameters_dict, ode_parameter_ranges_dict
from script.pinn_score import pinn_score
from optuna.pruners import SuccessiveHalvingPruner


DATA_FILE = "../data/"
RESULT_FILE = "../result/"

random.seed(42)
torch.manual_seed(42)


DATA_FILE = "../data/"
RESULT_FILE = "../result/"

STUDY_NAME = "test"


# Define an objective function to be minimized.
def objective(trial):

    epoch = trial.suggest_int("epoch",10000,200000)
    lr=trial.suggest_float("lr",1e-6,1)
    base_lr=trial.suggest_float("base_lr",1e-6,1)
    multiple_loss_method = trial.suggest_categorical("multiple_loss_method",["prior_losses", "soft_adapt"]) # other method
    if multiple_loss_method=="prior_losses":
        prior_losses_t=trial.suggest_categorical("prior_losses_t",[0,100])
        method_parameters = {"prior_losses_t":prior_losses_t}
    if multiple_loss_method == "soft_adapt":
        method_parameters = {"soft_adapt_t":100,
                             "soft_adapt_beta":10,
                             "soft_adapt_by_type":True,
                             "soft_adapt_normalize":True,
                             }
    res_weights = [trial.suggest_float("res_weight_{}".format(i),1e-16,1) for i in range(6)]
    

    training_dict = {"epoch":epoch,
                     "optimizer" : {"name": "Adam",
                                    "parameters": {"lr":lr},
                                    },
                     "scheduler": {"name": "CyclicLR",
                                    "parameters":{"base_lr":base_lr, 
                                                  "max_lr":base_lr, 
                                                  "step_size_up":100,
                                                  "scale_mode":"exp_range",
                                                  "gamma":0.999,
                                                  "cycle_momentum":False,
                                                  }
                                  },
                     "multiple_loss_method": {"name": multiple_loss_method,
                                              "method_parameters": method_parameters,
                                              "manual_variable_weights":None,
                                              "manual_residual_weights":res_weights, #[1e-4,1e-1,1e-1,1e-14,1e-15,1e-8],
                                              },
                     }

    data_dict = {"file":os.path.join(DATA_FILE,'generated_Millard_data.csv'),
                 "observables":["GLC","ACE_env","X","ACCOA","ACP","ACE_cell"],
                 "unknown_variable":{},
                 "parameter_names": ["v_max_AckA",
                                     "v_max_Pta",
                                     "v_max_glycolysis",
                                     "Ki_ACE_glycolysis",
                                     "Km_ACCOA_TCA_cycle",
                                     "v_max_TCA_cycle",
                                     "Ki_ACE_TCA_cycle", 
                                     "Y","v_max_acetate_exchange",
                                     "Km_ACE_acetate_exchange"],
                 }

    ode_dict = {"ode_parameter_dict":ode_parameters_dict,
                "ode_parameter_ranges_dict":ode_parameter_ranges_dict,
                }
    
    if trial.should_prune():
        raise optuna.TrialPruned()
    return pinn_score(training_dict, data_dict, ode_dict,seed=42,optuna=True)




pruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=2, min_early_stopping_rate=0)
study = optuna.create_study(study_name=STUDY_NAME,direction="minimize",pruner=pruner)  # Create a new study.
study.optimize(objective, n_trials=100)  # Invoke optimization of the objective function.

  from .autonotebook import tqdm as notebook_tqdm
[I 2024-09-17 15:43:01,473] A new study created in memory with name: test
Training the neural network: 100%|███████████████████████████████████████████████████████████████████████████████| 1873/1873 [00:17<00:00, 105.89it/s]
[I 2024-09-17 15:43:19,793] Trial 0 finished with value: 7591.132371172179 and parameters: {'epoch': 1873, 'lr': 0.03957040502019648, 'base_lr': 0.8428797365972374, 'multiple_loss_method': 'prior_losses', 'prior_losses_t': 100, 'res_weight_0': 0.07188065452614553, 'res_weight_1': 0.16270947963048618, 'res_weight_2': 0.3685862589306291, 'res_weight_3': 0.03869073026077972, 'res_weight_4': 0.14906152673990844, 'res_weight_5': 0.6830808848348532}. Best is trial 0 with value: 7591.132371172179.
Training the neural network: 100%|███████████████████████████████████████████████████████████████████████████████| 1681/1681 [00:15<00:00, 106.67it/s]
[I 2024-09-17 15:43:35,568] Trial 1 finished with value: 7150.021746765913 and

In [3]:
study.best_trial

FrozenTrial(number=53, state=1, values=[2220.2709911200036], datetime_start=datetime.datetime(2024, 9, 17, 13, 37, 49, 379871), datetime_complete=datetime.datetime(2024, 9, 17, 13, 38, 7, 380995), params={'epoch': 1916, 'lr': 0.6974565804541575, 'base_lr': 0.32807475814462794, 'multiple_loss_method': 'prior_losses', 'res_weight_0': 0.813025259863673, 'res_weight_1': 0.34296939633391493, 'res_weight_2': 0.23945110390610402, 'res_weight_3': 0.10507226958307188, 'res_weight_4': 0.5114470191398341, 'res_weight_5': 0.5355253344631333}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'epoch': IntDistribution(high=2000, log=False, low=1000, step=1), 'lr': FloatDistribution(high=1.0, log=False, low=1e-06, step=None), 'base_lr': FloatDistribution(high=1.0, log=False, low=1e-06, step=None), 'multiple_loss_method': CategoricalDistribution(choices=('prior_losses', 'soft_adapt')), 'res_weight_0': FloatDistribution(high=1.0, log=False, low=1e-16, step=None), 'res_weight_1': Fl

In [4]:
df = study.trials_dataframe()
df.loc[67]

number                                                 67
value                                       263722.616852
datetime_start                 2024-09-17 13:40:54.528374
datetime_complete              2024-09-17 13:41:05.944628
duration                           0 days 00:00:11.416254
params_base_lr                                   0.141315
params_epoch                                         1210
params_lr                                         0.84295
params_multiple_loss_method                  prior_losses
params_res_weight_0                              0.991488
params_res_weight_1                              0.580332
params_res_weight_2                              0.152025
params_res_weight_3                              0.375836
params_res_weight_4                              0.626401
params_res_weight_5                              0.542482
state                                            COMPLETE
Name: 67, dtype: object

In [5]:
df

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_base_lr,params_epoch,params_lr,params_multiple_loss_method,params_res_weight_0,params_res_weight_1,params_res_weight_2,params_res_weight_3,params_res_weight_4,params_res_weight_5,state
0,0,1.759035e+06,2024-09-17 13:26:18.541506,2024-09-17 13:26:33.128690,0 days 00:00:14.587184,0.706059,1445,0.677175,prior_losses,0.411307,0.022291,0.308646,0.096293,0.661964,0.339263,COMPLETE
1,1,,2024-09-17 13:26:33.129743,2024-09-17 13:26:34.121478,0 days 00:00:00.991735,0.092919,1300,0.879107,soft_adapt,0.390042,0.332694,0.086263,0.529091,0.268308,0.527043,PRUNED
2,2,,2024-09-17 13:26:34.122282,2024-09-17 13:26:34.288510,0 days 00:00:00.166228,0.165446,1545,0.617540,soft_adapt,0.546804,0.121539,0.542496,0.795425,0.536388,0.873912,PRUNED
3,3,8.600318e+08,2024-09-17 13:26:34.289342,2024-09-17 13:26:51.084052,0 days 00:00:16.794710,0.289459,1743,0.472614,prior_losses,0.415751,0.607390,0.195773,0.665536,0.613384,0.683412,COMPLETE
4,4,4.807157e+05,2024-09-17 13:26:51.085005,2024-09-17 13:27:04.467088,0 days 00:00:13.382083,0.191801,1426,0.417179,prior_losses,0.046027,0.912560,0.910812,0.058671,0.825933,0.270633,COMPLETE
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,95,7.314540e+03,2024-09-17 13:47:20.149694,2024-09-17 13:47:34.410291,0 days 00:00:14.260597,0.067633,1521,0.719492,prior_losses,0.435411,0.769861,0.606258,0.280915,0.107171,0.577481,COMPLETE
96,96,3.218975e+09,2024-09-17 13:47:34.411244,2024-09-17 13:47:49.544018,0 days 00:00:15.132774,0.117226,1609,0.834706,prior_losses,0.468096,0.722645,0.159499,0.322235,0.504984,0.859163,COMPLETE
97,97,5.141554e+04,2024-09-17 13:47:49.545067,2024-09-17 13:47:59.018811,0 days 00:00:09.473744,0.216956,1007,0.810480,prior_losses,0.901743,0.041372,0.294930,0.149156,0.051960,0.436802,COMPLETE
98,98,,2024-09-17 13:47:59.019815,2024-09-17 13:47:59.770778,0 days 00:00:00.750963,0.171231,1642,0.861117,soft_adapt,0.870202,0.958231,0.200496,0.415069,0.984183,0.972720,PRUNED
