In [1]:
import numpy as np

import jax
import jax.numpy as jnp

# -------------- helper libraries -------------- #
import sys
import os
import time
import pickle

import matplotlib.pyplot as plt
import pandas as pd
import importlib
import itertools

def append_path(path):
    """Append `path` to sys.path unless it is already listed.

    Re-running the setup cell therefore never accumulates duplicate
    entries in the module search path.
    """
    if path in sys.path:
        return
    sys.path.append(path)

# Make the sibling project folders importable. The paths are resolved
# relative to the current working directory ('..'), so the notebook is
# expected to be launched from its own folder. These must run before the
# project-local imports below.
append_path(os.path.abspath(os.path.join('..', '00_utils')))
append_path(os.path.abspath(os.path.join('..', '00_utils_training')))
append_path(os.path.abspath(os.path.join('..', '00_models')))

# reload() makes re-running this cell pick up edits to run_train_toy
# without restarting the kernel.
import run_train_toy
importlib.reload(run_train_toy)
Trainer = run_train_toy.TrainerToy

# reload_module is a project helper; from its usage it appears to reload the
# named module and return the named attribute — NOTE(review): confirm in
# analyse_results.
import analyse_results
reload_module = analyse_results.reload_module

Graphs = reload_module('analyse_results', 'Graphs')
Results = reload_module('analyse_results', 'Results')
convert_lists_in_tuple = reload_module('analyse_results', 'convert_lists_in_tuple')

# Only ERROR-level records (e.g. the failed-training messages logged by run())
# go to error_log.txt in the working directory. Note basicConfig is a no-op if
# the root logger already has handlers (e.g. when this cell is re-run).
import logging
logging.basicConfig(level=logging.ERROR, filename='error_log.txt')

In [2]:
reg_list = [1e-5, 1e-4, 1e-3]
max_iter_li = [500, 1000, 1500]

# Full grid: every regularization strength paired with every iteration
# budget, with the regularization value varying slowest.
param_combinations = [(reg, n_iter) for reg in reg_list for n_iter in max_iter_li]

for c in param_combinations:
    reg, n_iter = c
    print(str(c))
    print("reg:", reg, "max_iter:", n_iter)

(1e-05, 500)
reg: 1e-05 max_iter: 500
(1e-05, 1000)
reg: 1e-05 max_iter: 1000
(1e-05, 1500)
reg: 1e-05 max_iter: 1500
(0.0001, 500)
reg: 0.0001 max_iter: 500
(0.0001, 1000)
reg: 0.0001 max_iter: 1000
(0.0001, 1500)
reg: 0.0001 max_iter: 1500
(0.001, 500)
reg: 0.001 max_iter: 500
(0.001, 1000)
reg: 0.001 max_iter: 1000
(0.001, 1500)
reg: 0.001 max_iter: 1500


In [20]:
def _load_trainer(dataset):
    """Load a TrainerToy for `dataset` with the settings shared by all sweeps."""
    return Trainer.load_trainer(dataset, spacing_type="uniform", model_type="jax_diffrax")


def _build_param_combinations(optimization_type, params_model):
    """Return the list of sweep points for `optimization_type`.

    May adjust `params_model` in place when a sweep needs a different base
    configuration (currently only 'training_convergence', which sets 'log').

    Raises:
        ValueError: if `optimization_type` is not a known sweep.
    """
    if optimization_type == 'network_size':
        lw = [[2, 8, 2], [2, 16, 2], [2, 32, 2], [2, 16, 16, 2], [2, 32, 32, 2]]
        reg_list = [1e-5, 1e-4, 1e-3]
        max_iter_li = [[1000, 5000], [1000, 7500], [1000, 10000]]
        return list(itertools.product(lw, reg_list, max_iter_li))
    if optimization_type == 'tolerance':
        rtol = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
        atol = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
        return list(itertools.product(rtol, atol))
    if optimization_type == 'none':
        return [0]
    if optimization_type == 'activation':
        act = ['tanh', 'relu', 'sigmoid']
        data = ['vdp', 'ho', 'do']
        return list(itertools.product(act, data))
    if optimization_type == 'training_convergence':
        data = ['ho', 'vdp', 'do']
        pretrain = [False, [0.2, 1]]
        # Integer 'log' value — presumably a logging interval in iterations;
        # TODO confirm against TrainerToy.train.
        params_model['log'] = 1000
        return list(itertools.product(data, pretrain))
    if optimization_type == 'regularization':
        return [0, 1e-06, 1.0e-04, 1.0e-03, 1.0e-02, 1.0e-01, 1]
    raise ValueError("Invalid optimization type")


def _apply_param_comb(optimization_type, param_comb, params_model, trainer):
    """Write one sweep point into `params_model` (in place).

    Returns the trainer to use for this point: a freshly loaded one for the
    sweeps that vary the dataset, otherwise `trainer` unchanged.
    """
    if optimization_type == 'network_size':
        params_model['layer_widths'] = param_comb[0]
        params_model['penalty_lambda_reg'] = param_comb[1]
        params_model['max_iter'] = param_comb[2]
    elif optimization_type == 'regularization':
        params_model['penalty_lambda_reg'] = param_comb
    elif optimization_type == 'tolerance':
        params_model['rtol'], params_model['atol'] = param_comb
    elif optimization_type == 'activation':
        params_model['act_func'] = param_comb[0]
        trainer = _load_trainer(param_comb[1])
    elif optimization_type == 'training_convergence':
        trainer = _load_trainer(param_comb[0])
        params_model['pretrain'] = param_comb[1]
        # pretrain is either False or a list; only the list form gets the
        # two-stage [1000, 30000] iteration budget.
        params_model['max_iter'] = 30000 if not param_comb[1] else [1000, 30000]
    else:  # 'none' — the type was already validated by _build_param_combinations
        params_model['log'] = True
    return trainer


def _result_key(optimization_type, param_comb):
    """Turn a sweep point into a hashable key for the results dict.

    Tuples may contain lists (layer widths, max_iter pairs, pretrain
    schedules), and lists are unhashable, so tuples go through
    `convert_lists_in_tuple` (presumably converting nested lists to tuples).
    For 'training_convergence' the pretrain entry is collapsed to a bool.
    """
    if isinstance(param_comb, tuple):
        param_comb = convert_lists_in_tuple(param_comb)
    if optimization_type == 'training_convergence':
        param_comb = (param_comb[0], bool(param_comb[1]))
    return param_comb


def _train_once(trainer, params_model):
    """Train once and return `trainer.extract_results()`, or None on failure.

    Failures are printed and logged (with traceback) so that one bad
    configuration does not abort the whole sweep.
    """
    try:
        trainer.train(params_model)
    except Exception as e:
        print("Failed to complete training: {}".format(e))
        logging.exception("Failed to complete training: {}".format(e))
        return None
    return trainer.extract_results()


def run(optimization_type, averaged=False, n_repeats=5):
    """Train the toy model over a hyper-parameter sweep and collect the results.

    Args:
        optimization_type: Sweep to run: 'network_size', 'tolerance', 'none',
            'activation', 'training_convergence' or 'regularization'.
        averaged: If True, train each sweep point `n_repeats` times and store
            the mean 'mse_train' / 'mse_test' / 'time_elapsed' instead of a
            single run's full result dict.
        n_repeats: Number of repetitions per sweep point when `averaged`.

    Returns:
        dict mapping a hashable key per sweep point to its results. Points
        whose training failed (every repeat, when averaged) are omitted. For
        'training_convergence' (non-averaged) each entry also carries the
        trainer's 'training_loss'.

    Raises:
        ValueError: if `optimization_type` is not recognised.
    """
    params_model = {
        'layer_widths': [2, 32, 2],
        'penalty_lambda_reg': 1e-3,
        'time_invariant': True,
        'learning_rate': 1e-3,
        'max_iter': [1000, 5000],
        'pretrain': [0.2, 1],
        'verbose': False,
        'rtol': 1e-3,
        'atol': 1e-6,
        "log": False,
        'act_func': 'tanh',
        'split_time': False
    }

    # Validate the sweep type before doing any (potentially slow) loading.
    param_combinations = _build_param_combinations(optimization_type, params_model)

    # Dataset sweeps load their own trainer per point; the others share one.
    trainer = None
    if optimization_type not in ('activation', 'training_convergence'):
        trainer = _load_trainer("ho")

    results = {}
    total_iter = len(param_combinations)

    # enumerate keeps the progress counter correct even when a point fails.
    for i, param_comb in enumerate(param_combinations, start=1):
        trainer = _apply_param_comb(optimization_type, param_comb, params_model, trainer)
        key = _result_key(optimization_type, param_comb)

        if not averaged:
            result = _train_once(trainer, params_model)
            if result is not None:
                print(key)
                if optimization_type == 'training_convergence':
                    result['training_loss'] = trainer.losses
                results[key] = result

                print(result['mse_train'])
                print(result['mse_test'])
                print(result['time_elapsed'])
        else:
            runs = []
            for _ in range(n_repeats):
                result = _train_once(trainer, params_model)
                if result is not None:
                    runs.append(result)

            if not runs:
                # np.mean([]) would silently store NaN; skip the point instead.
                logging.error("All {} training runs failed for {}".format(n_repeats, key))
            else:
                results[key] = {
                    'mse_train': np.mean([r['mse_train'] for r in runs]),
                    'mse_test': np.mean([r['mse_test'] for r in runs]),
                    'time_elapsed': np.mean([r['time_elapsed'] for r in runs])
                }

                print(results[key]['mse_train'])
                print(results[key]['mse_test'])

        print("Iteration:", i, "/", total_iter)

    return results

In [21]:
# Sweep the regularization strength (penalty_lambda_reg) and keep the results.
RESULTS = run(optimization_type='regularization')

False
0
0.0005807362238536533
0.002839467146660001
6.474164009094238
Iteration: 1 / 7
False
1e-06
0.0005805351335806764
0.0028393911030052707
9.054824829101562
Iteration: 2 / 7
False
0.0001
0.0005636946629840601
0.0028195188482113822
6.025352954864502
Iteration: 3 / 7
False
0.001
0.0005001933289043781
0.002514104245043799
6.005992889404297
Iteration: 4 / 7
False
0.01
0.0008958695642470911
0.002069021581028907
6.002652883529663
Iteration: 5 / 7
False
0.1
0.42673512278258735
0.5947926526669257
5.7831830978393555
Iteration: 6 / 7
False
1
0.5204039110209011
0.8310408361571286
5.708235025405884
Iteration: 7 / 7


In [22]:
reload = True  # set to False to skip (re)writing the results pickle
if reload:
    filename = 'results/jax_regularization_new.pkl'
    # On a fresh checkout the output folder may not exist yet; create it
    # rather than letting open() fail with FileNotFoundError.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'wb') as file:
        pickle.dump(RESULTS, file)

    print(f"Results saved to {filename}")

Results saved to results/jax_regularization_new.pkl
