In [29]:
import numpy as np

import jax
import jax.numpy as jnp

# -------------- helper libraries -------------- #
import sys
import os
import time
import pickle

import matplotlib.pyplot as plt
import pandas as pd
import importlib

# Put the sibling project directories ../00_utils and ../00_models on the
# import path so the project-local modules imported in later cells resolve.
for subdir in ('00_utils', '00_models'):
    path_ = os.path.abspath(os.path.join('..', subdir))
    # Membership guard keeps the cell idempotent: re-running it never
    # appends a duplicate entry.
    if path_ not in sys.path:
        sys.path.append(path_)

import logging

# Only ERROR-and-above records are kept, and they are written to
# error_log.txt rather than the notebook output (the training loops
# below report failed runs through this logger).
logging.basicConfig(filename='error_log.txt', level=logging.ERROR)

In [31]:
import run_train_toy

# Reload so edits to run_train_toy.py take effect without restarting the kernel.
importlib.reload(run_train_toy)
PyomoTrainerToy = run_train_toy.PyomoTrainerToy

# Dataset description handed to PyomoTrainerToy for the harmonic-oscillator run.
# NOTE(review): noise_level suggests random noise is injected; if so, seed it
# so the sweep below is reproducible — TODO confirm inside prepare_inputs.
data_params = dict(
    N=200,
    noise_level=0.2,
    ode_type="harmonic_oscillator",
    data_param={"omega_squared": 2},
    start_time=0,
    end_time=10,
    spacing_type="uniform",
    initial_state=np.array([0.0, 1.0]),
)

trainer = PyomoTrainerToy(data_params, model_type='jax_diffrax')
trainer.prepare_inputs()

In [32]:
import itertools

reg_list = [1e-5, 1e-4, 1e-3]
max_iter_li = [500, 1000, 1500]

# Cartesian product: every regularisation strength paired with every
# iteration budget, ordered reg-major (all budgets for 1e-5 first, ...).
param_combinations = list(itertools.product(reg_list, max_iter_li))

for reg, n_iter in param_combinations:
    print((reg, n_iter))
    print("reg:", reg, "max_iter:", n_iter)

(1e-05, 500)
reg: 1e-05 max_iter: 500
(1e-05, 1000)
reg: 1e-05 max_iter: 1000
(1e-05, 1500)
reg: 1e-05 max_iter: 1500
(0.0001, 500)
reg: 0.0001 max_iter: 500
(0.0001, 1000)
reg: 0.0001 max_iter: 1000
(0.0001, 1500)
reg: 0.0001 max_iter: 1500
(0.001, 500)
reg: 0.001 max_iter: 500
(0.001, 1000)
reg: 0.001 max_iter: 1000
(0.001, 1500)
reg: 0.001 max_iter: 1500


In [36]:
# Grid search for the harmonic oscillator: every architecture in lw_list is
# trained once per (penalty_lambda_reg, max_iter) pair in param_combinations.
# Results are stored as results[str(layer_widths)][str((reg, max_iter))].

# Base hyper-parameters shared by every run; per-run values are filled in below.
params_model = {
    'layer_widths': [2, 32, 32, 2],
    'penalty_lambda_reg': 1e-5,
    'time_invariant': True,
    'learning_rate': 1e-3,
    'max_iter': 1000,
    # False is the value that was actually in effect: an earlier duplicate
    # 'pretrain': [0.2, 1] entry in this literal was silently overridden.
    'pretrain': False,
    'verbose': False
}

# Architectures to sweep (2 inputs -> hidden layers -> 2 outputs).
lw_list = [[2, 16, 2], [2, 32, 2], [2, 16, 16, 2], [2, 32, 32, 2]]

results = {}
failed_runs = []
total_iter = len(param_combinations) * len(lw_list)
i = 1

for lw in lw_list:
    results[str(lw)] = {}

    for param_comb in param_combinations:
        reg, max_iter = param_comb
        # Fresh config per run instead of mutating params_model: keeps the
        # cell idempotent and protects against a trainer holding a reference.
        run_params = {
            **params_model,
            'layer_widths': lw,
            'penalty_lambda_reg': reg,
            'max_iter': max_iter,
        }

        # Count every attempt (not just successes) so progress stays accurate.
        print("Iteration:", i, "/", total_iter)
        i += 1

        try:
            trainer.train(run_params)
        except Exception:
            # logging.exception keeps the traceback and records which config failed.
            logging.exception(
                "Failed to complete training (layer_widths=%s, params=%s)",
                lw, param_comb,
            )
            failed_runs.append((str(lw), str(param_comb)))
            print("  training failed; see error_log.txt")
            continue

        result = trainer.extract_results()
        results[str(lw)][str(param_comb)] = result
        print(result['mse_train'])
        print(result['mse_test'])

if failed_runs:
    print("Failed runs:", len(failed_runs), "/", total_iter)

Iteration: 1 / 36
0.3262873680277081
38.032072042269164
Iteration: 2 / 36
0.32417534610396764
47.57851353068948
Iteration: 3 / 36
0.3273842871854942
52.6441986517983
Iteration: 4 / 36
0.32560319404789817
40.469810561895386
Iteration: 5 / 36
0.3255149320741678
41.499469035869076
Iteration: 6 / 36
0.3247006791686451
45.99606219448478
Iteration: 7 / 36
0.3256874752271199
37.739296620021435
Iteration: 8 / 36
0.3248118810548344
40.977289496661314
Iteration: 9 / 36
0.32524583653768857
39.42075421666091
Iteration: 10 / 36
0.42661691521075484
1.0199080924141397
Iteration: 11 / 36
0.3655938190120422
1.4685059162901104
Iteration: 12 / 36
0.34858786352952414
1.5986266620360312
Iteration: 13 / 36
0.42667877094568496
1.0198322704823926
Iteration: 14 / 36
0.3656914455979986
1.4689719551581377
Iteration: 15 / 36
0.3487180656488711
1.597966668496386
Iteration: 16 / 36
0.42731320554084257
1.0190219763139698
Iteration: 17 / 36
0.3666817065600658
1.4736479933520994
Iteration: 18 / 36
0.3500315748863642
1

---
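### Van der Pol oscillator

The same architecture / regularisation sweep is repeated on the Van der Pol system ($\mu = 1$, $\omega = 1$) with longer training budgets.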

In [68]:
import run_train_toy

# Reload so edits to run_train_toy.py take effect without restarting the kernel.
importlib.reload(run_train_toy)
PyomoTrainerToy = run_train_toy.PyomoTrainerToy

# Dataset description handed to PyomoTrainerToy for the Van der Pol run.
# NOTE(review): noise_level suggests random noise is injected; if so, seed it
# so the sweep below is reproducible — TODO confirm inside prepare_inputs.
data_params = dict(
    N=200,
    noise_level=0.1,
    ode_type="van_der_pol",
    data_param={"mu": 1, "omega": 1},
    start_time=0,
    end_time=10,
    spacing_type="uniform",
    initial_state=np.array([0.0, 1.0]),
)

trainer = PyomoTrainerToy(data_params, model_type='jax_diffrax')
trainer.prepare_inputs()

In [65]:
import itertools

reg_list = [1e-5, 1e-4, 1e-3]
max_iter_li = [2500, 5000, 7500]

# Cartesian product: every regularisation strength paired with every
# iteration budget, ordered reg-major (all budgets for 1e-5 first, ...).
param_combinations = list(itertools.product(reg_list, max_iter_li))

for reg, n_iter in param_combinations:
    print((reg, n_iter))
    print("reg:", reg, "max_iter:", n_iter)

(1e-05, 2500)
reg: 1e-05 max_iter: 2500
(1e-05, 5000)
reg: 1e-05 max_iter: 5000
(1e-05, 7500)
reg: 1e-05 max_iter: 7500
(0.0001, 2500)
reg: 0.0001 max_iter: 2500
(0.0001, 5000)
reg: 0.0001 max_iter: 5000
(0.0001, 7500)
reg: 0.0001 max_iter: 7500
(0.001, 2500)
reg: 0.001 max_iter: 2500
(0.001, 5000)
reg: 0.001 max_iter: 5000
(0.001, 7500)
reg: 0.001 max_iter: 7500


In [66]:
# Grid search for the Van der Pol system: every architecture in lw_list is
# trained once per (penalty_lambda_reg, max_iter) pair in param_combinations.
# Results are stored as results[str(layer_widths)][str((reg, max_iter))].

# Base hyper-parameters shared by every run; per-run values are filled in below.
params_model = {
    'layer_widths': [2, 32, 32, 2],
    'penalty_lambda_reg': 1e-5,
    'time_invariant': True,
    'learning_rate': 1e-3,
    'max_iter': 1000,
    'pretrain': [0.2, 1],
    'verbose': False
}

# Architectures to sweep: one- and two-hidden-layer networks of increasing width.
lw_list = [
    [2, 8, 2], [2, 16, 2], [2, 32, 2], [2, 64, 2], [2, 128, 2],
    [2, 8, 8, 2], [2, 16, 16, 2], [2, 32, 32, 2], [2, 64, 64, 2],
]

results = {}
failed_runs = []
total_iter = len(param_combinations) * len(lw_list)
i = 1

for lw in lw_list:
    results[str(lw)] = {}

    for param_comb in param_combinations:
        reg, max_iter = param_comb
        # Fresh config per run instead of mutating params_model: keeps the
        # cell idempotent and protects against a trainer holding a reference.
        run_params = {
            **params_model,
            'layer_widths': lw,
            'penalty_lambda_reg': reg,
            'max_iter': max_iter,
        }

        # Count every attempt (not just successes) so progress stays accurate.
        print("Iteration:", i, "/", total_iter)
        i += 1

        try:
            trainer.train(run_params)
        except Exception:
            # logging.exception keeps the traceback and records which config failed.
            logging.exception(
                "Failed to complete training (layer_widths=%s, params=%s)",
                lw, param_comb,
            )
            failed_runs.append((str(lw), str(param_comb)))
            print("  training failed; see error_log.txt")
            continue

        result = trainer.extract_results()
        results[str(lw)][str(param_comb)] = result
        print(result['mse_train'])
        print(result['mse_test'])

if failed_runs:
    print("Failed runs:", len(failed_runs), "/", total_iter)

Iteration: 1 / 81
1.787088743987356
22.12355326334817
Iteration: 2 / 81
2.5408369442347056
15.080965587903588
Iteration: 3 / 81
2.4896505032844107
16.624313706069945
Iteration: 4 / 81
1.8064064796027453
15.684628472208875
Iteration: 5 / 81
1.6399427658805485
159.88876395524184
Iteration: 6 / 81
1.7358080682026906
35.93328377184612
Iteration: 7 / 81
0.16509604215116255
6.470829221212435
Iteration: 8 / 81
1.7688202286162205
4.278827383726954
Iteration: 9 / 81
0.16990788730139855
8.444133465382732
Iteration: 10 / 81
0.14775642096851177
5.081362999233868
Iteration: 11 / 81
0.1254225765460013
3.2772578722205266
Iteration: 12 / 81
0.11938660487521546
2.1403030806692853
Iteration: 13 / 81
0.1590658891892935
5.570381585701149
Iteration: 14 / 81
0.1243242233561474
3.0443150428423484
Iteration: 15 / 81
0.11925377782529417
2.4211936171821136
Iteration: 16 / 81
0.33503071164007353
7.840579248820341
Iteration: 17 / 81
0.1456031936092291
5.423557956138534
Iteration: 18 / 81
0.13577209618719605
3.773

In [67]:
import pickle

# Persist the sweep results (nested dict keyed by layer widths, then by the
# (reg, max_iter) pair) for later analysis.
# NOTE(review): `results` is re-initialised by each sweep cell, so only the
# most recent sweep ends up in this file.
output_path = 'jax_diffrax_results.pkl'
with open(output_path, mode='wb') as fh:
    pickle.dump(results, fh)