In [2]:
import jax 
import scipy
import jax.numpy as jnp
import diffrax as dfx
import optax as opx

In [3]:
from simulation import generate_y0, simulate_pdu, measure_pdu
from pdu_rhs import V_RATIO
import numpy as np

np.random.seed(42)

data_ICs = jnp.array([80, 0, 0, 0]) 
params = jnp.array([.5, .4, .8, .9, .6])
y0 = generate_y0(params, data_ICs)

t0 = 0.0
t1 = 10.0*3600
dt = 1e-2

params = jnp.array([ .5, .4, .8, 1e8, 40, \
1e3, -3000, 100, 50, 200, 140, \
1e4, -5000, 40, 10, 100, 70, \
1e9,  -10000, 50, 85, \
1e9, -10000, 10, 100, \
-5, -6, \
-4, 1.2*V_RATIO])

t_out, y_out = measure_pdu(simulate_pdu(params, y0, t0, t1, dt, KO=None))
y_out_sim = y_out + np.random.normal(0, 1, y_out.shape)

In [4]:
from objective import construct_loss, parameterize_loss, objective

options = {"loss_fn": "MSE",
           "t_weight": None,
           "weight": None,
           }

loss_fn = construct_loss(options)
loss_fn = parameterize_loss(y_out, t_out, loss_fn)

m = loss_fn(y_out, y_out_sim)

In [5]:
m

Array(0.97860841, dtype=float64)

In [6]:
#objective_jit = jax.jit(objective, static_argnums=0)
from collections import namedtuple

KO = None
jnp.squeeze(objective(loss_fn, params, y_out_sim, t_out, dt, data_ICs, KO))

Array(0.97860841, dtype=float64)

In [7]:
from run_adam_tune import run_adam_loop, batch_run_adam

obj = lambda p : jnp.squeeze(objective(loss_fn, p, y_out_sim, t_out, dt, data_ICs, KO)).reshape(())

#batch_run_adam(jnp.array([1.01*params, 1.0001*params, 2*params]), 1e-3, .9, .999, 100, obj)

In [8]:
from lhs_sampling import lh_samples

In [11]:
pars = lh_samples(100)[50:55]
batch_run_adam(pars, 1e-3, .9, .999, 100, obj)

Array([  82.47785683, 1231.69895018,   38.45305812,   42.25219098,
        184.13129774], dtype=float64)

In [9]:
"""from tuning import tune_adam_hyperparams

best = tune_adam_hyperparams(obj, pars, num_steps=10, num_trials=5)"""

'from tuning import tune_adam_hyperparams\n\nbest = tune_adam_hyperparams(obj, pars, num_steps=10, num_trials=5)'

In [None]:
from run_adam_full import run_adam_with_outputs

run_adam_with_outputs(pars[2], 1e-3, .9, .999, 100, obj)

{'optimal_params': Array([-1.23649948e-01,  1.38424377e-01,  4.64136999e-01,  2.99341918e+07,
         9.83041921e+00,  6.56734128e+05, -3.73873497e+04,  1.19624224e+02,
         1.35542248e+02,  3.83614137e+02,  1.09832043e+02,  5.58582573e+05,
        -6.93863542e+02,  3.83072980e+01,  1.12800823e+01,  1.76500584e+02,
         5.73229690e+01,  8.51042498e+05,  2.46648537e+04,  3.70562664e+01,
         1.67781024e+02,  9.30549865e+02, -3.66088139e+04,  7.57422066e+01,
         3.55876177e+02, -5.94864119e+00, -9.70998575e+00, -3.45326162e+00,
         9.33680100e+03], dtype=float64),
 'loss': Array(1.6418833e+09, dtype=float64),
 'grad': Array([-1.92201054e-02,  0.00000000e+00,  0.00000000e+00,  1.09781527e+02,
         2.80449542e+04, -5.81336978e-27,  3.68853315e-39,  3.19152611e-23,
         2.81671462e-23, -2.05804340e-41,  4.07748415e-39,  4.64506832e-27,
        -1.07336697e-43, -6.77326345e-23, -2.30020858e-22, -6.24328401e-39,
        -6.90853449e-39, -8.96334059e-33, -2.67578

In [10]:
J = jnp.array([-3.76411976e-05,  0.00000000e+00,  0.00000000e+00, -7.34038476e-11,
         1.62801508e-04,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00, -2.06295554e+00,  0.00000000e+00,  3.15587644e-02,
         6.39481504e-05])

In [11]:
jax.jacobian(obj)(J)

ERROR:2025-04-11 00:35:08,899:jax._src.callback:102: jax.pure_callback failed
Traceback (most recent call last):
  File "/Users/ccat/.pyenv/versions/anaconda3-2024.10-1/envs/pdu_fit/lib/python3.13/site-packages/jax/_src/callback.py", line 100, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
                                          ~~~~~~~~^^^^^^^
  File "/Users/ccat/.pyenv/versions/anaconda3-2024.10-1/envs/pdu_fit/lib/python3.13/site-packages/jax/_src/callback.py", line 77, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/ccat/.pyenv/versions/anaconda3-2024.10-1/envs/pdu_fit/lib/python3.13/site-packages/equinox/_errors.py", line 89, in raises
    raise _EquinoxRuntimeError(
    ...<9 lines>...
    )
equinox._errors._EquinoxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.


--------------------
An 

XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
  File "/Users/ccat/.pyenv/versions/anaconda3-2024.10-1/envs/pdu_fit/lib/python3.13/site-packages/jax/_src/callback.py", line 778, in _wrapped_callback
  File "/Users/ccat/.pyenv/versions/anaconda3-2024.10-1/envs/pdu_fit/lib/python3.13/site-packages/jax/_src/callback.py", line 224, in _callback
  File "/Users/ccat/.pyenv/versions/anaconda3-2024.10-1/envs/pdu_fit/lib/python3.13/site-packages/jax/_src/callback.py", line 103, in pure_callback_impl
  File "/Users/ccat/.pyenv/versions/anaconda3-2024.10-1/envs/pdu_fit/lib/python3.13/site-packages/jax/_src/callback.py", line 77, in __call__
  File "/Users/ccat/.pyenv/versions/anaconda3-2024.10-1/envs/pdu_fit/lib/python3.13/site-packages/equinox/_errors.py", line 89, in raises
_EquinoxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.


--------------------
An error occurred during the runtime of your JAX program! Unfortunately you do not appear to be using `equinox.filter_jit` (perhaps you are using `jax.jit` instead?) and so further information about the error cannot be displayed. (Probably you are seeing a very large but uninformative error message right now.) Please wrap your program with `equinox.filter_jit`.
--------------------


In [14]:
objective_fn = lambda x: jnp.sum((x - 1.0)**2)
hessian_fn = jax.jacobian(jax.grad(objective_fn))
hess = hessian_fn(jnp.array([1.0, 2.0, 3.0]))


In [16]:
objective_fn(jnp.array([1.0, 2.0, 3.0]))

Array(5., dtype=float64)