In [46]:
from jaxnrsur.NRSur7dq4 import NRSur7dq4Model
from jaxnrsur.PolyPredictor import PolyPredictor
import jax.numpy as jnp
import jax
import equinox as eqx
import optax
from jaxtyping import PyTree, Float
jax.config.update("jax_enable_x64", True)  # Use double precision

In [38]:
time = jnp.linspace(-5000, 100, 10000)
params = jnp.array([0.9, 0.1, 0.4, 0.1, 0.5, 0.1, 0.3])
model = NRSur7dq4Model()
target_h = model(time, params)

Try loading file from cache
Cache found and loading data


As an toy example, we will optimize the model parameters of a `NRSur7dq4Model` to match a target waveform generated with the same model but with slightly different parameters. This example does not have inherent physical meaning, but serves to illustrate how to use the model and its gradients for parameter tuning.

In [33]:
def filter_func(x):
    return isinstance(x, PolyPredictor)
filtered_module, filter_static = eqx.partition(model, filter_func, is_leaf=filter_func)
dynamic, static = eqx.partition(filtered_module, eqx.is_array)

In [41]:
def loss(model, time, params):
    model = eqx.combine(eqx.combine(dynamic, static), filter_static, is_leaf=filter_func)
    hp, hc = model(time, params)
    return jnp.mean( (hp - target_h[0])**2 + (hc - target_h[1])**2 )

In [45]:
loss_value = loss(dynamic, time, params)
optim = optax.adam(learning_rate=1e-3)
opt_state = optim.init(eqx.filter(eqx.filter(model, filter_func, is_leaf=filter_func), eqx.is_array))

Up to this point 

In [55]:
def optimize_loss(model: NRSur7dq4Model, opt_state:  PyTree) -> tuple[NRSur7dq4Model, PyTree, Float]:
    values, grads = eqx.filter_jit(eqx.filter_value_and_grad(loss))(model, time, params+1e-3)
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(eqx.filter(model, filter_func, is_leaf=filter_func), eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, values

In [56]:
optimized_model, opt_state, loss_value = optimize_loss(dynamic, opt_state)

In [57]:
combined_model = eqx.combine(eqx.combine(dynamic, static), filter_static, is_leaf=filter_func)

In [58]:
new_h = combined_model(time, params+1e-3)

In [59]:
new_h[0] - target_h[0], new_h[1] - target_h[1]  # Check the difference

(Array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        5.45840428e-06, 6.28605965e-06, 6.61686950e-06], dtype=float64),
 Array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
        -5.09789162e-06, -4.36389742e-06, -3.19168348e-06], dtype=float64))