In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta

# jax
import jax.numpy as jnp
import time

import sys
import os
import importlib
import pickle
import itertools

# Make the sibling helper directories importable from this notebook. Each
# directory is resolved against the current working directory and appended to
# sys.path only if it is not already there, so re-running the cell does not
# accumulate duplicate entries.
for path_ in (
    os.path.abspath(os.path.join('..', folder))
    for folder in ('00_utils', '00_models', '00_utils_training')
):
    if path_ not in sys.path:
        sys.path.append(path_)

import logging
# Configure the root logger to write records of severity ERROR and above to
# 'error_log.txt' (a path relative to the current working directory).
logging.basicConfig(
    filename='error_log.txt',
    level=logging.ERROR,
)

import optimize_diffrax_rl

In [2]:
# Re-execute the optimize_diffrax_rl module so that edits made to its source
# since the first import are picked up without restarting the kernel.
importlib.reload(optimize_diffrax_rl)
# Bind the alias only after the reload: the reload re-runs the module's code,
# so an alias taken earlier would still point at the previous class object.
ExperimentRunner = optimize_diffrax_rl.ExperimentRunner

In [3]:
# Overrides passed to ExperimentRunner. The three top-level keys are the ones
# this notebook sets; the meaning of each nested value is defined by
# optimize_diffrax_rl, not here.
# NOTE(review): 'num_epochs' and 'pretrain' are two-element lists and the run
# log shows two 1000-epoch loss traces -- presumably one entry per training
# phase; confirm against ExperimentRunner.
extra_inputs = {
    'params_sequence': {'sequence_len': 1, 'frequency': 1},
    'params_model': {
        'layer_sizes': [6, 64, 64, 1],
        'penalty': 1e-5,
        'learning_rate': 1e-3,
        'num_epochs': [1000, 1000],
        'pretrain': [0.2, 1],
    },
    'params_results': {'plot': True, 'log': True, 'split_time': False},
}

# Positional arguments are forwarded unchanged: a date string, an experiment
# label, and the overrides above.
runner = ExperimentRunner('2015-01-01', 'convergence', extra_inputs)
runner.run()

Generating default parameters for data
Generating default parameters for model
Running iteration 1 with parameters: 1
Epoch 100, Loss: 0.3188007581995243
Epoch 200, Loss: 0.1864720219847577
Epoch 300, Loss: 0.12195367010197833
Epoch 400, Loss: 0.106541392718977
Epoch 500, Loss: 0.07727697536717684
Epoch 600, Loss: 0.0703752283469337
Epoch 700, Loss: 0.06722180054425332
Epoch 800, Loss: 0.06625444706657413
Epoch 900, Loss: 0.06561430375345993
Epoch 1000, Loss: 0.06527610776913062
Epoch 100, Loss: 0.9683698168357237
Epoch 200, Loss: 0.6790588103575996
Epoch 300, Loss: 0.5873211843950249
Epoch 400, Loss: 0.5305257229001522
Epoch 500, Loss: 0.468424984767969
Epoch 600, Loss: 0.4238035587598031
Epoch 700, Loss: 0.37655865223278073
Epoch 800, Loss: 0.3449382346713796
Epoch 900, Loss: 0.31029459840754486
Epoch 1000, Loss: 0.2753071923473948
Iteration i: 1/1 completed


In [4]:
# Display the loss history recorded by the runner.
# NOTE(review): this prints the entire nested list (the saved output is an
# empty list followed by a long list of floats). Consider showing a slice or a
# plot instead so the notebook output stays readable.
runner.losses

[[],
 [8.446103300526982,
  5.805316639509833,
  3.524640687784096,
  2.4377769860867113,
  1.5833811846072259,
  1.2155161374043022,
  1.074547381597868,
  1.0195482226963086,
  0.9872856516882951,
  0.9606146497034942,
  0.9207337557760201,
  0.8805106779528538,
  0.8401760269342512,
  0.8089829212594358,
  0.7817880269401392,
  0.7593166155607913,
  0.7326345454614562,
  0.713485540006522,
  0.6908473045474617,
  0.6782385414990931,
  0.6647463617835965,
  0.6503505955068942,
  0.6413676645703013,
  0.6287670626058331,
  0.6276101547614725,
  0.612799962742723,
  0.6069228011423209,
  0.5982171603131671,
  0.5918575835495666,
  0.587312329303742,
  0.5829715299671815,
  0.5724641573204492,
  0.5677018060680624,
  0.5791266471975405,
  0.5562726048959056,
  0.5524653020624192,
  0.5446380295768134,
  0.5403018440016342,
  0.5282337349684305,
  0.5300616163259032,
  0.5225740760655045,
  0.518704600360735,
  0.5105770044191555,
  0.5058092818215015,
  0.490964715175195,
  0.4889518381

In [13]:
# Display the aggregated metrics stored on the runner.
# NOTE(review): this cell's execution count (13) does not follow the previous
# cell's (4), so the saved output may come from a different `runner` than the
# one created above. Re-run with Restart Kernel -> Run All before relying on it.
runner.results_avg

{'times_elapsed': 9.685075481732687,
 'mse_diffrax': 0.2775134084976922,
 'mse_diffrax_test': 0.8211285403569759}

In [14]:
# Display the per-configuration metrics stored on the runner.
# NOTE(review): the saved output has six entries keyed by (value, date), while
# the run logged earlier in this notebook reports a single iteration -- this
# output appears to belong to a different run. All six entries also report
# identical 'mse_diffrax' and 'mse_diffrax_test' values; verify that the swept
# value in the key is actually reaching the model.
runner.results_full

{(0, '2015-01-01'): {'times_elapsed': 9.472509860992432,
  'mse_diffrax': 0.2775134084976922,
  'mse_diffrax_test': 0.8211285403569759},
 (1e-07, '2015-01-01'): {'times_elapsed': 9.560502052307129,
  'mse_diffrax': 0.2775134084976922,
  'mse_diffrax_test': 0.8211285403569759},
 (1e-05, '2015-01-01'): {'times_elapsed': 9.58388876914978,
  'mse_diffrax': 0.2775134084976922,
  'mse_diffrax_test': 0.8211285403569759},
 (0.001, '2015-01-01'): {'times_elapsed': 9.641717910766602,
  'mse_diffrax': 0.2775134084976922,
  'mse_diffrax_test': 0.8211285403569759},
 (0.01, '2015-01-01'): {'times_elapsed': 9.6963791847229,
  'mse_diffrax': 0.2775134084976922,
  'mse_diffrax_test': 0.8211285403569759},
 (0.1, '2015-01-01'): {'times_elapsed': 10.155455112457275,
  'mse_diffrax': 0.2775134084976922,
  'mse_diffrax_test': 0.8211285403569759}}