In [1]:
import numpy as np
import itertools

import jax
import jax.numpy as jnp

# -------------- helper libraries -------------- #
import sys
import os
import time
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import importlib

def append_path(path):
    """Add ``path`` to the end of ``sys.path`` unless it is already listed there.

    Guarding against duplicates keeps ``sys.path`` from growing every time the
    notebook cell is re-executed.
    """
    if path in sys.path:
        return
    sys.path.append(path)

def reload_module(module_name, class_name):
    """Re-execute ``module_name`` and return its attribute ``class_name``.

    Reloading picks up edits made to the module's source since it was first
    imported, so notebook cells see the latest version without a kernel restart.
    """
    loaded = importlib.import_module(module_name)
    importlib.reload(loaded)  # re-runs the module source in place on the same module object
    return getattr(loaded, class_name)
        
# Make the shared project folders importable; they are resolved relative to the
# parent of the current working directory.
for _folder in ('00_utils', '00_utils_training', '00_models'):
    append_path(os.path.abspath(os.path.join('..', _folder)))
del _folder

import logging
# Only ERROR-level records are kept, and they are written to error_log.txt
# instead of the notebook output.
logging.basicConfig(level=logging.ERROR, filename='error_log.txt')

1. Train the model with Pyomo
    * the main hyperparameter is the number of epochs (the IPOPT iteration limit `max_iter`)
2. Initialize the JAX model with the trained Pyomo weights and biases

In [2]:
def _failed_pyomo_result():
    """Return the NaN placeholder metrics stored for a sweep point that failed."""
    return {'time_elapsed': np.array(np.nan), 'mse_train': np.array(np.nan), 'mse_test': np.array(np.nan)}


def _pyomo_param_combinations(optimization_type):
    """Return the list of values that ``run`` sweeps over for ``optimization_type``.

    Raises:
        ValueError: if ``optimization_type`` is not 'regularization', 'max_iter' or 'none'.
    """
    if optimization_type == 'regularization':
        # candidate values for penalty_lambda_reg
        return [0, 1e-6, 0.0001, 0.01, 0.1, 1]
    if optimization_type == 'max_iter':
        # IPOPT iteration limits 1..99, as builtin ints for the solver options
        return list(range(1, 100))
    if optimization_type == 'none':
        # a single run with the default settings
        return [0]
    raise ValueError(f"Invalid optimization type {optimization_type}")


def run(optimization_type, max_iter):
    """Train the Pyomo/IPOPT model on the 'vdp' problem, optionally sweeping one setting.

    Args:
        optimization_type: 'regularization' sweeps ``penalty_lambda_reg``; 'max_iter'
            sweeps the IPOPT iteration limit and stops once IPOPT reports a termination
            containing 'optimal'; 'none' performs a single run.
        max_iter: IPOPT iteration limit used unless the 'max_iter' sweep overrides it.

    Returns:
        ``(trainer, results)``. ``trainer`` is the object returned by
        ``TrainerToy.load_trainer("vdp")``, reused for every sweep point. ``results``
        maps each swept value to ``trainer.extract_results_pyomo()``, or to NaN
        placeholder metrics when training or extraction failed.

    Raises:
        ValueError: for an unknown ``optimization_type``. It is raised before any
            trainer is loaded.
    """
    # Validate up front so a typo does not cost a trainer load.
    param_combinations = _pyomo_param_combinations(optimization_type)

    Trainer = reload_module('run_train_toy', 'TrainerToy')

    tol = 1e-4
    params_model = {
        'layer_widths': [2, 32, 2],
        'act_func': 'tanh',
        'penalty_lambda_reg': 0.001,
        'time_invariant': True,
        'w_init_method': 'xavier',
        "reg_norm": False,
        "skip_collocation": np.inf,
        'params': {
            "tol": tol,
            "acceptable_iter": 0,
            "halt_on_ampl_error": 'yes',
            "print_level": 5,
            "max_iter": max_iter,
        },
    }

    trainer = Trainer.load_trainer("vdp")
    results = {}
    total_iter = len(param_combinations)
    converged = False

    for step, param_comb in enumerate(param_combinations, start=1):
        # PARAMETER UPDATE
        if optimization_type == 'regularization':
            params_model['penalty_lambda_reg'] = param_comb
        elif optimization_type == 'max_iter':
            if converged:
                # An optimal point was already reached with a smaller limit, so larger
                # limits add nothing.
                break
            params_model['params']['max_iter'] = param_comb

        try:
            trainer.train_pyomo(params_model)
            if optimization_type == 'max_iter' and 'optimal' in trainer.termination:
                print(f"Optimal solution found at/before iteration {param_comb}")
                converged = True
        except Exception as e:
            # Best-effort sweep: record the failure and move on to the next value.
            results[param_comb] = _failed_pyomo_result()
            logging.error("Failed to complete training: {}".format(e))
            print(f"{e}")
            continue

        try:
            results[param_comb] = trainer.extract_results_pyomo()
        except Exception as e:
            results[param_comb] = _failed_pyomo_result()
            logging.error("Failed to extract results: {}".format(e))
            print(f"{e}")

        print("Iteration:", step, "/", total_iter)

    return trainer, results

In [3]:
# Single Pyomo/IPOPT training run (no sweep) capped at 100 solver iterations.
trainer, results = run(optimization_type='none', max_iter=100)



current_16_08




Ipopt 3.14.16: tol=0.0001
acceptable_iter=0
halt_on_ampl_error=yes
print_level=5
max_iter=100


******************************************************************************
This program contains Ipopt, a library for large-scale nonlinear optimization.
 Ipopt is released as open source code under the Eclipse Public License (EPL).
         For more information visit https://github.com/coin-or/Ipopt
******************************************************************************

This is Ipopt version 3.14.16, running with linear solver MUMPS 5.7.1.

Number of nonzeros in equality constraint Jacobian...:   131342
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:    64729

Total number of variables............................:      562
                     variables with only lower bounds:        0
                variables with lower and upper bounds:      562
                     variables with only upper bounds:       

In [11]:
def run_diffrax(optimization_type, layer_width, learning_rate = 1e-3, reg_lambda = 1e-5, custom_params=None, pretrain = False, max_iter = 1000):
    """Train the JAX/diffrax model on the 'vdp' problem and collect its metrics.

    Args:
        optimization_type: 'none' for a single run. 'training_convergence' also sets
            ``log=100`` in the trainer config and stores ``trainer.losses`` under the
            'training_loss' key of the result.
        layer_width: list of layer widths, e.g. ``[2, 32, 2]``.
        learning_rate: optimizer learning rate passed to the trainer.
        reg_lambda: regularization weight (``penalty_lambda_reg``).
        custom_params: optional parameter tree forwarded to ``trainer.train``, e.g.
            weights extracted from the Pyomo model. Presumably ``None`` means the
            trainer's default initialization (TODO confirm).
        pretrain: if True, the pretrain schedule is ``[0.2, 1]`` and ``max_iter`` is sent
            as a list with one entry per schedule element; if False, pretraining is
            disabled and ``max_iter`` is sent unchanged.
        max_iter: iteration budget. With ``pretrain=True`` it may be a single int (reused
            for every schedule element) or an explicit list/tuple such as
            ``[10000, 30000]``, which is used as given.

    Returns:
        dict mapping each parameter combination (currently only ``0``) to the dict
        returned by ``trainer.extract_results()``. Empty if training raised.

    Raises:
        ValueError: for an unknown ``optimization_type``. It is raised before any
            trainer is loaded.
    """
    if optimization_type == 'none':
        log = False
    elif optimization_type == 'training_convergence':
        log = 100
    else:
        raise ValueError("Invalid optimization type")
    param_combinations = [0]

    if pretrain:
        pretrain_schedule = [0.2, 1]
        # An explicit per-element sequence is used as given. Wrapping it again would
        # hand the trainer nested lists like [[a, b], [a, b]].
        if isinstance(max_iter, (list, tuple)):
            max_iter_setting = list(max_iter)
        else:
            max_iter_setting = [max_iter, max_iter]
    else:
        pretrain_schedule = False
        max_iter_setting = max_iter

    Trainer = reload_module('run_train_toy', 'TrainerToy')
    trainer = Trainer.load_trainer("vdp", 'uniform', 'jax_diffrax')
    results = {}

    params_model = {
        'layer_widths': layer_width,
        'penalty_lambda_reg': reg_lambda,
        'time_invariant': True,
        'learning_rate': learning_rate,
        'max_iter': max_iter_setting,
        'pretrain': pretrain_schedule,
        'verbose': False,
        'rtol': 1e-3,
        'atol': 1e-6,
        "log": log,
        'act_func': 'tanh',
        'split_time': True
    }

    total_iter = len(param_combinations)
    for step, param_comb in enumerate(param_combinations, start=1):
        try:
            trainer.train(params_model, custom_params)
        except Exception as e:
            # Best-effort: report the failure and leave this combination out of the results.
            print("Failed to complete training: {}".format(e))
            logging.error("Failed to complete training: {}".format(e))
            continue

        result = trainer.extract_results()
        if optimization_type == 'training_convergence':
            result['training_loss'] = trainer.losses
        results[param_comb] = result

        print(result['mse_train'])
        print(result['mse_test'])
        print("Iteration:", step, "/", total_iter)

    return results

In [13]:
# Convert the trained Pyomo weights into a parameter tree keyed like the JAX Dense layers.
wb_trained = trainer.extract_pyomo_weights()

# Layer k of the JAX model takes Pyomo's W{k+1} (transposed) and b{k+1}. The transpose
# presumably maps Pyomo's (out, in) layout to the Dense kernel's (in, out); confirm.
custom_params = {
    f'Dense_{layer}': {
        'kernel': jnp.array(wb_trained[f'W{layer + 1}']).T,
        'bias': jnp.array(wb_trained[f'b{layer + 1}']),
    }
    for layer in range(2)
}

Trainer = reload_module('run_train_toy', 'TrainerToy')
trainer_jax = Trainer.load_trainer('vdp', 'uniform', 'jax_diffrax')

# NOTE(review): custom_params is built above but NOT passed to this run, so the JAX model
# is not warm-started from the Pyomo weights here. The warm-start variant of this
# experiment used: custom_params=custom_params, layer_width=trainer.layer_widths,
# learning_rate=1e-7, reg_lambda=1e-5, max_iter=20000.
results_jax = run_diffrax(
    'none',
    [2, 32, 2],
    learning_rate=1e-3,
    reg_lambda=1e-5,
    pretrain=True,
    max_iter=30000,
)
results_jax

False
Failed to complete training: '>' not supported between instances of 'int' and 'list'


{}

`lr = 1e-7`
0.013535604903391951
0.30661465729396137

In [9]:
# Convergence run: 'training_convergence' also stores the trainer's loss history
# under results[0]['training_loss'].
# Warm-start alternative: custom_params=custom_params, layer_width=trainer.layer_widths,
# learning_rate=1e-7, max_iter=20000.
results = run_diffrax(
    optimization_type='training_convergence',
    layer_width=[2, 32, 2],
    learning_rate=1e-3,
    reg_lambda=1e-5,
    pretrain=True,
    max_iter=[10000, 30000],
)
results

100
Failed to complete training: 'epoch_recording_step'


{}

In [13]:
suffix = '_32'

# Controls whether the pickles below are (re)written; set to False to skip saving.
reload = True


def _save_pickle(obj, filename):
    """Pickle ``obj`` to ``filename``, creating the parent directory if it is missing."""
    os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
    with open(filename, 'wb') as file:
        pickle.dump(obj, file)
    print(f"Results saved to {filename}")


if reload:
    # `results` comes from the training-convergence run; `results_jax` from the single JAX run.
    _save_pickle(results, f'results/diffrax_pyomo_pretraining{suffix}.pkl')
    _save_pickle(results_jax, f'results/diffrax_pyomo_pretraining_time{suffix}.pkl')

Results saved to results/diffrax_pyomo_pretraining_32.pkl
Results saved to results/diffrax_pyomo_pretraining_time_32.pkl
