In [26]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta

# jax
import jax.numpy as jnp
import time

import sys
import os
import importlib
import pickle
import itertools

def append_path(path):
    """Make ``path`` importable by appending it to ``sys.path``.

    The path is added at the end (lowest priority) and only once; calling this
    repeatedly, e.g. when a notebook cell is re-run, is a no-op after the first call.
    """
    if path in sys.path:
        # Already importable -- avoid growing sys.path with duplicates.
        return
    sys.path.append(path)

def reload_module(module_name, class_name):
    """Import ``module_name`` (reloading it if already loaded) and return one of its attributes.

    Used from a notebook so that re-running a cell picks up edits made to the
    module's source file without restarting the kernel.

    If the module is already in ``sys.modules`` it is reloaded. Otherwise it is
    imported exactly once. (Importing and then immediately reloading, as this
    helper previously did, executes the module's top-level code twice on first use.)

    Args:
        module_name: Dotted module name, resolvable via ``sys.path``.
        class_name: Name of the attribute (typically a class) to fetch from the module.

    Returns:
        The attribute ``class_name`` of the freshly (re)loaded module.

    Raises:
        ModuleNotFoundError: If ``module_name`` cannot be imported.
        AttributeError: If the module has no attribute ``class_name``.
    """
    if module_name in sys.modules:
        # Re-execute the module's source so edits since the last import take effect.
        module = importlib.reload(sys.modules[module_name])
    else:
        module = importlib.import_module(module_name)
    return getattr(module, class_name)
        
# Make the sibling project folders (one level up from the notebook) importable.
for _subdir in ('00_utils', '00_utils_training', '00_models'):
    append_path(os.path.abspath(os.path.join('..', _subdir)))
del _subdir

import logging
# Route ERROR-and-above log records to error_log.txt (relative to the current
# working directory). Note: basicConfig does nothing if the root logger already
# has handlers, e.g. when this cell is re-run in the same kernel.
logging.basicConfig(level=logging.ERROR, filename='error_log.txt')

import optimize_diffrax_rl

# Go through reload_module so re-running this cell picks up source edits
# to optimize_diffrax_rl.
ExperimentRunner = reload_module('optimize_diffrax_rl', 'ExperimentRunner')

In [74]:
# Pickled trained weights/biases, optionally passed to the experiment below
# via extra_inputs['trained_wb'].
# Alternative checkpoint previously used here:
#   '../00_trained_wb/rl_32_weights_2024-09-02_11-32-54.pkl'
path = '../00_trained_wb/trained_wb_2024-09-02_14-40-45.pkl'

# NOTE: pickle.load can execute arbitrary code -- only load checkpoints you trust.
with open(path, 'rb') as file:
    trained_wb_raw = pickle.load(file)

# Convert the stored weights into the layout ExperimentRunner works with
# (presumably weights exported from a Pyomo model -- confirm in optimize_diffrax_rl).
trained_wb = ExperimentRunner.format_weights_from_pyomo(trained_wb_raw)

In [79]:
# Reload the experiment module so edits to its source take effect without a
# kernel restart, then rebind the class from the freshly loaded module.
importlib.reload(optimize_diffrax_rl)
ExperimentRunner = optimize_diffrax_rl.ExperimentRunner

start_date = '2015-01-10'

# Set to True to warm-start from `trained_wb` (requires the weight-loading cell
# to have been run first in this kernel).
USE_TRAINED_WB = False

extra_inputs = {
    # Model hyperparameters.
    'params_model': {'layer_sizes': [7, 32, 32, 1], 'penalty': 1e-5, 'learning_rate': 1e-2,
                     'num_epochs': 20000, 'pretrain': False},
    # Data settings. `encoding` presumably maps CSV column names to the variable
    # roles used inside optimize_diffrax_rl -- TODO confirm.
    'params_data': {'file_path': '../00_data/df_train.csv', 'start_date': start_date,
                    'n_points': 300, 'split': 200, 'n_days': 1, 'm': 1,
                    'prev_hour': False, 'prev_week': True, 'prev_year': True,
                    'spacing': 'uniform',
                    'encoding': {'settlement_date': 't', 'temperature': 'var1', 'hour': 'var2', 'nd': 'y'}},
    'params_sequence': {'sequence_len': 10, 'frequency': 35},
    'params_results': {'plot': True, 'log': False, 'split_time': False},
}
if USE_TRAINED_WB:
    extra_inputs['trained_wb'] = trained_wb

runner = ExperimentRunner(start_date, 'default', extra_inputs)
runner.run()

Generating default parameters for data
Generating default parameters for model
Running iteration 1 with parameters: 1
Start Data: 2015-01-10 00:00:00
days_offset: 1
Offset: 2015-01-09 00:00:00
Epoch 100, Loss: 0.5808260649058886
Epoch 200, Loss: 0.2611207497737063
Epoch 300, Loss: 0.15181834528485907
Epoch 400, Loss: 0.09622062042604294
Epoch 500, Loss: 0.07358478881024964
Epoch 600, Loss: 0.058200305890362965
Epoch 700, Loss: 0.06254173408365848
Epoch 800, Loss: 0.06679710072977299
Epoch 900, Loss: 0.05095791189409308
Epoch 1000, Loss: 0.04990359369442999
Epoch 1100, Loss: 0.042075525415067115
Epoch 1200, Loss: 0.07182199174888158
Epoch 1300, Loss: 0.09213322938268506
Epoch 1400, Loss: 0.050988476420629036
Epoch 1500, Loss: 0.044300226467768905
Epoch 1600, Loss: 0.03908287598436878
Epoch 1700, Loss: 0.03570385412060241
Epoch 1800, Loss: 0.036795350863467594
Epoch 1900, Loss: 0.054972452330798366
Epoch 2000, Loss: 0.08542006668878481
Epoch 2100, Loss: 0.03662204332955993
Epoch 2200, Lo

In [81]:
# Summary results of the run above, keyed by iteration number. Each entry holds
# 'times_elapsed', 'mse_diffrax' and 'mse_diffrax_test' (see the saved output).
# NOTE(review): in the saved output 'mse_diffrax_test' is ~85x 'mse_diffrax',
# which may indicate overfitting -- worth confirming.
runner.results_avg

{1: {'times_elapsed': 65.04188876152038,
  'mse_diffrax': 0.003666992289213589,
  'mse_diffrax_test': 0.31264357364636086}}

In [73]:
# Persist the runner's results under a descriptive tag. Per the saved output,
# this writes timestamped *_full.pkl and *_avg.pkl files under results/.
# NOTE(review): this cell's execution count (73) is lower than the run cell's (79),
# so the saved files may predate the run shown above -- re-run after runner.run().
results_tag = 'short_training_10_jax'
runner.save_results(results_tag)

Results saved to results/short_training_10_jax_2024-09-08_13-37-55_full.pkl
Results saved to results/short_training_10_jax_2024-09-08_13-37-55_avg.pkl


In [None]:
# Saved full-results file of an earlier JAX convergence run without pretraining,
# bound to a name so it can be loaded later (the bare literal was unusable).
CONVERGENCE_RESULTS_PATH = 'results/convergence_jax_no_pretrain_2024-09-02_18-25-17_full.pkl'
CONVERGENCE_RESULTS_PATH