In [5]:
import jax 
import scipy
import jax.numpy as jnp
import diffrax as dfx
import optax as opx

In [6]:
from simulation import generate_y0, simulate_pdu, measure_pdu
from pdu_rhs import V_RATIO
import numpy as np

# Initial-condition vector handed to generate_y0 together with a parameter vector.
data_ICs = jnp.array([80, 0, 0, 0])

# NOTE(review): the initial state is built from this 5-element vector, NOT from
# the 29-element `params` defined below -- confirm that this is intended.
# It has its own name so it no longer shadows/gets shadowed by `params`.
y0_params = jnp.array([.5, .4, .8, .9, .6])
y0 = generate_y0(y0_params, data_ICs)

# Integration window and step passed to simulate_pdu.
t0 = 0.0
t1 = 10.0*3600  # presumably 10 hours expressed in seconds -- TODO confirm units
dt = 1e-2

# Parameter vector used for the simulation and, in later cells, as the
# reference point for the objective (29 entries; grouping follows the
# original line layout).
params = jnp.array([
    .5, .4, .8, 1e8, 40,
    1e3, -3000, 100, 50, 200, 140,
    1e4, -5000, 40, 10, 100, 70,
    1e9, -10000, 50, 85,
    1e9, -10000, 10, 100,
    -5, -6,
    -4, 1.2*V_RATIO,
])

# Simulate with no KO setting and reduce the result to the measured outputs.
t_out, y_out = measure_pdu(simulate_pdu(params, y0, t0, t1, dt, KO=None))

# Synthetic "observed" data: simulated output plus unit-variance Gaussian noise.
# A seeded generator keeps the noisy data (and every loss value derived from
# it) reproducible under Restart & Run All.
NOISE_SEED = 0
noise_rng = np.random.default_rng(NOISE_SEED)
y_out_sim = y_out + noise_rng.normal(0, 1, y_out.shape)

In [7]:
from objective import construct_loss, parameterize_loss, objective

# Loss configuration: "MSE" loss with both weighting options left unset (None).
options = {"loss_fn": "MSE", "t_weight": None, "weight": None}

# Build the loss from the options, then bind it to the simulated outputs and
# their time points in a single step.
loss_fn = parameterize_loss(y_out, t_out, construct_loss(options))

# Loss between the clean simulated output and its noisy copy.
loss_fn(y_out, y_out_sim)

Array(1.04712089, dtype=float64)

In [8]:
#objective_jit = jax.jit(objective, static_argnums=0)
from collections import namedtuple

# KO stays unset (None), matching the KO=None used for the simulation.
KO = None

# Objective evaluated at the parameter vector that generated the data.
objective(
    loss_fn,
    params,
    y_out_sim,
    t_out,
    dt,
    data_ICs,
    KO,
)

Array(1.04712089, dtype=float64)

In [13]:
from run_adam_tune import run_adam_loop, batch_run_adam

def obj(p):
    """Evaluate `objective` at parameter vector `p`.

    Every other argument (loss function, noisy data, time points, step size,
    initial conditions, KO) is taken from the notebook globals.
    """
    return objective(loss_fn, p, y_out_sim, t_out, dt, data_ICs, KO)


# Starting points: the reference parameters scaled by 1.01, 1.0001 and 2.
start_points = jnp.array([scale * params for scale in (1.01, 1.0001, 2)])

# NOTE(review): the numeric arguments are presumably the Adam learning rate,
# beta1, beta2 and iteration count -- confirm against run_adam_tune.
batch_run_adam(start_points, 1e-3, .9, .999, 100, obj)

Array([1.04740582e+00, 1.04685095e+00, 1.56712577e+03], dtype=float64)

In [14]:
from lhs_sampling import lh_samples

In [18]:
# Request 100 samples from lh_samples and keep the five at indices 50..54 as
# starting points.
lhs_draws = lh_samples(100)
pars = lhs_draws[50:55]

# Same optimiser settings as the perturbed-start run, now from the sampled points.
batch_run_adam(pars, 1e-3, .9, .999, 100, obj)

Array([1179.88251476, 1119.48839873,  383.26850638,  833.74952948,
         43.1409696 ], dtype=float64)