In [1]:
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import run_train_toy
import importlib

In [2]:
importlib.reload(run_train_toy)
TrainerToy = run_train_toy.TrainerToy

def load_trainer(type_, spacing_type="uniform"):
    data_params_ho = {
        'N': 200,
        'noise_level': 0.2,
        'ode_type': "harmonic_oscillator",
        'data_param': {"omega_squared": 2},
        'start_time': 0,
        'end_time': 10,
        'spacing_type': spacing_type,
        'initial_state': np.array([0.0, 1.0])
    }

    data_params_vdp = {
        'N': 200,
        'noise_level': 0.1,
        'ode_type': "van_der_pol",
        'data_param': {"mu": 1, "omega": 1},
        'start_time': 0,
        'end_time': 15,
        'spacing_type': spacing_type,
        'initial_state': np.array([0.0, 1.0])
    }

    data_params_do = {
        'N': 200,
        'noise_level': 0.1,
        'ode_type': "damped_oscillation",
        'data_param': {"damping_factor": 0.1, "omega_squared": 1},
        'start_time': 0,
        'end_time': 10,
        'spacing_type': spacing_type,
        'initial_state': np.array([0.0, 1.0])
    }

    if type_ == "ho":
        p_ = data_params_ho
    elif type_ == "vdp":
        p_ = data_params_vdp
    elif type_ == "do":
        p_ = data_params_do
    else:
        raise ValueError(f"Invalid type {type_}")

    trainer = TrainerToy(p_, model_type="jax_diffrax")
    trainer.prepare_inputs()
    return trainer

### Regularization

In [5]:
fn = '2024-08-15_12-41-10_jax_reg.pkl'

with open(f'results/{fn}', 'rb') as file:
    results_reg_ho = pickle.load(file)
    
def extract_metrics(results_dict):
    time_elapsed, mse_train, mse_test = [], [], []
    for key, value in results_dict.items():
        time_elapsed.append(value['time_elapsed'])
        mse_train.append(value['mse_train'].item())
        mse_test.append(value['mse_test'].item())
    return time_elapsed, mse_train, mse_test

In [7]:
time_elapsed_ho, mse_train_ho, mse_test_ho = extract_metrics(results_reg_ho)

df = pd.DataFrame({
    'Reg Strength': results_reg_ho.keys(),
    'Time Elapsed HO': time_elapsed_ho,
    'MSE Train HO': mse_train_ho,
    'MSE Test HO': mse_test_ho
})

display(df.T)

Unnamed: 0,0,1,2,3,4,5,6,7
Reg Strength,0.0,1e-06,1e-05,0.0001,0.001,0.01,0.1,1.0
Time Elapsed HO,4.886813,6.083797,4.76736,4.742946,6.147441,4.505516,4.431655,5.879815
MSE Train HO,0.001406,0.001406,0.001401,0.001355,0.001045,0.034854,0.511213,0.520404
MSE Test HO,0.003212,0.00321,0.003206,0.00304,0.002383,0.064625,0.779294,0.831041
