In [3]:
import numpy as np

import jax
import jax.numpy as jnp

# -------------- helper libraries -------------- #
import sys
import os
import time
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import importlib

# Make the sibling helper directories importable from this notebook.
# Each one is resolved against the current working directory and appended
# to sys.path only if it is not already listed, so re-running the cell
# does not create duplicate entries.
for path_ in (
    os.path.abspath(os.path.join('..', folder))
    for folder in ('00_utils', '00_models', '00_utils_training')
):
    if path_ not in sys.path:
        sys.path.append(path_)
    
import logging
# Configure the root logger to write records at ERROR level or above to
# 'error_log.txt' (a path relative to the current working directory).
logging.basicConfig(level=logging.ERROR, filename='error_log.txt')

In [8]:
# `run_train_toy` is not imported by any earlier cell shown in this notebook,
# so on a fresh kernel the reload below would fail with a NameError. Import it
# here, after the sys.path setup in the first cell has run.
# NOTE(review): presumably the module lives in one of the directories added to
# sys.path above — confirm.
import run_train_toy

# Reload so edits made to the module since its first import are picked up
# without restarting the kernel.
importlib.reload(run_train_toy)
TrainerToy = run_train_toy.TrainerToy

def load_trainer(type_, spacing_type="uniform"):
    """Build a ``TrainerToy`` for one of the preset toy ODE problems.

    Parameters
    ----------
    type_ : str
        Preset key: ``"ho"`` (harmonic oscillator), ``"vdp"`` (van der Pol)
        or ``"do"`` (damped oscillation).
    spacing_type : str, optional
        Value stored under the ``'spacing_type'`` key of the data parameters
        passed to ``TrainerToy``. Defaults to ``"uniform"``.

    Returns
    -------
    TrainerToy
        Trainer constructed with ``model_type="pytorch"`` on which
        ``prepare_inputs()`` has already been called.

    Raises
    ------
    ValueError
        If ``type_`` is not one of the preset keys.
    """

    def preset(ode_type, noise_level, data_param, end_time):
        # Every preset shares N, the start time, the spacing and the initial
        # state; only the four arguments vary between problems.
        return {
            'N': 200,
            'noise_level': noise_level,
            'ode_type': ode_type,
            'data_param': data_param,
            'start_time': 0,
            'end_time': end_time,
            'spacing_type': spacing_type,
            'initial_state': np.array([0.0, 1.0]),
        }

    if type_ == "ho":
        data_params = preset("harmonic_oscillator", 0.2, {"omega_squared": 2}, 10)
    elif type_ == "vdp":
        data_params = preset("van_der_pol", 0.1, {"mu": 1, "omega": 1}, 15)
    elif type_ == "do":
        data_params = preset(
            "damped_oscillation",
            0.1,
            {"damping_factor": 0.1, "omega_squared": 1},
            10,
        )
    else:
        raise ValueError(f"Invalid type {type_}")

    toy_trainer = TrainerToy(data_params, model_type="pytorch")
    toy_trainer.prepare_inputs()
    return toy_trainer

# Trainer for the harmonic-oscillator preset, using the default "uniform" spacing.
trainer_ho = load_trainer("ho")

In [17]:
# Hyperparameters passed to `trainer_ho.train(...)` in the next cell.
# NOTE(review): the meaning of each key is defined by TrainerToy, which is not
# shown in this notebook — confirm against that class before tuning.
params_model = dict(
    layer_widths=[2, 32, 2],
    penalty_lambda_reg=1e-5,
    time_invariant=True,
    learning_rate=1e-3,
    max_iter=200,
    pretrain=[0.2, 1],
)

In [18]:
# Train the harmonic-oscillator trainer with the hyperparameters in `params_model`.
trainer_ho.train(params_model)

Epoch 0, Loss: 1.2716577053070068
Epoch 100, Loss: 0.04143500700592995
Epoch 0, Loss: 2.9366061687469482
Epoch 100, Loss: 0.059821780771017075


In [19]:
# Bare last expression: the notebook displays the returned dict, which
# includes the full 'odeint_pred' tensor.
# NOTE(review): this prints every row of the tensor; consider binding the
# result to a name and displaying only a slice or the shapes.
trainer_ho.extract_results()

{'odeint_pred': tensor([[ 0.0000,  1.0000],
         [ 0.0610,  0.9930],
         [ 0.1212,  0.9811],
         [ 0.1803,  0.9642],
         [ 0.2379,  0.9426],
         [ 0.2938,  0.9165],
         [ 0.3477,  0.8859],
         [ 0.3992,  0.8512],
         [ 0.4483,  0.8126],
         [ 0.4946,  0.7702],
         [ 0.5381,  0.7244],
         [ 0.5784,  0.6754],
         [ 0.6155,  0.6235],
         [ 0.6493,  0.5690],
         [ 0.6796,  0.5121],
         [ 0.7064,  0.4533],
         [ 0.7295,  0.3926],
         [ 0.7490,  0.3306],
         [ 0.7647,  0.2674],
         [ 0.7766,  0.2034],
         [ 0.7848,  0.1387],
         [ 0.7891,  0.0739],
         [ 0.7897,  0.0090],
         [ 0.7864, -0.0556],
         [ 0.7794, -0.1196],
         [ 0.7686, -0.1828],
         [ 0.7542, -0.2450],
         [ 0.7361, -0.3057],
         [ 0.7145, -0.3649],
         [ 0.6894, -0.4222],
         [ 0.6609, -0.4775],
         [ 0.6291, -0.5304],
         [ 0.5942, -0.5807],
         [ 0.5562, -0.6284],