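To use an adaptive step size via `PIDController` when simulating the SDE, we need an SDE solver with an embedded error estimate; for this, SPaRK is recommended. (The data below are generated with Heun's method and a constant step size instead.)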

## Preliminaries

### Code Setup

In [1]:
# Main imports
import sys
from itertools import count

# Import jax and utils
from jax import numpy as jnp
import jax.random as jr
from typing import Union
from jaxtyping import Float, Array

# Additional, custom codebase
sys.path.append("../..")
sys.path.append("../../..")

# Import dynamax
from dynamax.parameters import ParameterProperties
# import discrete-time filters
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_smoother, unscented_kalman_smoother

# Our own custom src codebase
# continuous-discrete nonlinear Gaussian SSM codebase
from continuous_discrete_nonlinear_gaussian_ssm import ContDiscreteNonlinearGaussianSSM
from continuous_discrete_nonlinear_gaussian_ssm import cdnlgssm_filter, cdnlgssm_forecast
# Load models
from continuous_discrete_nonlinear_gaussian_ssm.models import *

# Useful utility functions
from simulation_utils import *

# Plotting
import matplotlib
%matplotlib inline
# Our own custom plotting codebase
from utils.plotting_utils import *
from utils.test_utils import compare_structs
from lorenz_plotting import *
# Feel free to change the default figure size
#matplotlib.rcParams['figure.figsize'] = [16, 9]

# For pretty print of ndarrays
jnp.set_printoptions(formatter={"float_kind": "{:.2f}".format})

import diffrax as dfx

- the fwd and bwd functions take an extra `perturbed` argument, which     indicates which primals actually need a gradient. You can use this     to skip computing the gradient for any unperturbed value. (You can     also safely just ignore this if you wish.)
- `None` was previously passed to indicate a symbolic zero gradient for     all objects that weren't inexact arrays, but all inexact arrays     always had an array-valued gradient. Now, `None` may also be passed     to indicate that an inexact array has a symbolic zero gradient.


## True model setup

We generate data from a stochastic Lorenz 63 system, whose dynamics follow the stochastic differential equations:

\begin{align*}
\frac{d x}{d t} &= a(y-x) + \sigma w_x(t) \\
\frac{d y}{d t} &= x(b-z) - y + \sigma w_y(t) \\
\frac{d z}{d t} &= xy - cz + \sigma w_z(t),
\end{align*}

With parameters $a=10$, $b=28$, and $c=8/3$, the system exhibits chaotic behavior; we set the diffusion coefficient to $\sigma=1.0$.

To generate data, we numerically approximate sample paths of this SDE using Heun's method (i.e., the improved Euler method), as implemented in [Diffrax](https://docs.kidger.site/diffrax/api/solvers/sde_solvers/).


We assume the observation model is
\begin{align*}
y(t) &= H x(t) + r(t) \\
r(t) &\sim N(0,R),
\end{align*}
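where $x(t) = [x, y, z]^\top$ is the full Lorenz state (the observation $y(t)$ should not be confused with the state component $y$), and we choose $R=I$.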

That is, we impose partial observability with $H=[1, 0, 0]$ (only the first state component is observed); observations are noisy and sampled at irregular time intervals.

### True, data-generating model definition

In [2]:
## Main settings
state_dim = 3     # Lorenz 63 state: (x, y, z)
emission_dim = 1  # a single observed channel

# Custom drift model for the Lorenz 63 system, built on LearnableFunction
class lorenz63_drift(LearnableFunction):
    """Lorenz 63 drift with learnable coefficients ``params = [a, b, c]``.

    Evaluates
        dx/dt = a * (y - x)
        dy/dt = b * x - y - x * z
        dz/dt = -c * z + x * y
    """

    params: Union[Float[Array, "state_dim"], ParameterProperties]

    def f(self, x, u=None, t=None):
        """Return the drift vector at state ``x``; ``u`` and ``t`` are accepted but unused."""
        a, b, c = self.params[0], self.params[1], self.params[2]
        dx = a * (x[1] - x[0])
        dy = b * x[0] - x[1] - x[0] * x[2]
        dz = -c * x[2] + x[0] * x[1]
        return jnp.array([dx, dy, dz])

# Define the true parameters of the drift function: [a, b, c] = [10, 28, 8/3]
true_l63_drift_params = jnp.array([10.0, 28.0, 8 / 3])
# And the corresponding Lorenz 63 system
true_drift = {
    "params": lorenz63_drift(
        params=true_l63_drift_params
    ),
    "props": lorenz63_drift(
        params=ParameterProperties()
    ),
}

# Define the true diffusion covariance (identity), constrained to stay PSD
true_diffusion_cov = {
    "params": LearnableMatrix(
        params=jnp.eye(state_dim)
    ),
    "props": LearnableMatrix(
        params=ParameterProperties(
            constrainer=RealToPSDBijector()
        )
    ),
}

# Define the true diffusion coefficient: sigma * I, with sigma = 1.0
true_diffusion_coefficient_param = 1.0
true_diffusion_coefficient = {
    "params": LearnableMatrix(
        params=true_diffusion_coefficient_param * jnp.eye(state_dim)
    ),
    "props": LearnableMatrix(
        params=ParameterProperties()
    ),
}

# Define the true parameters of the emission function
# Partial observability: only the first state component is observed
H=jnp.array(
    [[1.0, 0.0, 0.0]]
)
true_emission = {
    "params": LearnableLinear(
        weights=H,
        bias=jnp.zeros(emission_dim)
    ),
    "props": LearnableLinear(
        weights=ParameterProperties(),
        bias=ParameterProperties()
    ),
}

# Define the true emission covariance R = I, constrained to stay PSD
R=jnp.eye(emission_dim)
true_emission_cov = {
    "params": LearnableMatrix(
        params=R
    ),
    "props": LearnableMatrix(
        params=ParameterProperties(
            constrainer=RealToPSDBijector()
        )
    ),
}

# Define the true initial mean (zero)...
true_initial_mean = {
    "params": LearnableVector(
        params=jnp.zeros(state_dim)
    ),
    "props": LearnableVector(
        params=ParameterProperties()
    ),
}

# ...and the true initial covariance (5 * I), constrained to stay PSD
true_initial_cov_param = 5.0
true_initial_cov = {
    "params": LearnableMatrix(
        params=true_initial_cov_param*jnp.eye(state_dim)
    ),
    "props": LearnableMatrix(
        params=ParameterProperties(
            constrainer=RealToPSDBijector()
        )
    ),
}

# Collect all parameters in one dictionary, passed as kwargs to `model.initialize`
all_true_params = {
    'initial_mean': true_initial_mean,
    'initial_cov': true_initial_cov,
    'dynamics_drift': true_drift,
    'dynamics_diffusion_coefficient': true_diffusion_coefficient,
    'dynamics_diffusion_cov': true_diffusion_cov,
    'dynamics_approx_order': 2.,  # TODO: confirm the intended approximation order
    'emission_function': true_emission,
    'emission_cov': true_emission_cov,
}

In [3]:
# High-fidelity diffeqsolve settings for *simulating* the SDE (sampling paths).
# Heun with a small constant step: an adaptive PIDController would require an
# SDE solver with an embedded error estimate (e.g. SPaRK).
hifi_forward_settings = {
    "solver": dfx.Heun(),
    "dt0": 1e-5,
    "stepsize_controller": dfx.ConstantStepSize(),
    # dfx.PIDController(atol=1e-9, rtol=1e-7, pcoeff=0.1, icoeff=0.3, dcoeff=0),
    "tol_vbt": 1e-5,  # presumably the VirtualBrownianTree tolerance — confirm
    "max_steps": 1_000_000,  # integer step budget (diffrax expects an int)
}  # an empty dict would fall back to the default settings

# High-fidelity diffeqsolve settings passed to the EKF (via EKFHyperParams)
# when computing the reference marginal log prob and its gradient.
hifi_logprob_settings = {
    "solver": dfx.Tsit5(),
    # "dt0": 1e-5,
    "stepsize_controller": dfx.PIDController(atol=1e-9, rtol=1e-9),
    "tol_vbt": 1e-5,
    "max_steps": 10_000_000,  # integer step budget (diffrax expects an int)
}  # an empty dict would fall back to the default settings

# Simulation 1: Fast sample rate and modest initial state covariance

### Simulation set-up

In [4]:
# Set up seeds for the simulation: an endless stream of PRNG keys (seeds 0, 1, 2, ...).
# Each `next(keys)` consumes one key, so re-running this cell restarts the stream.
keys = map(jr.PRNGKey, count())

# Simulation parameters
T_total = 50
T_filter = int(0.8 * T_total)  # horizon used for filtering; the remainder is presumably forecast
sim1_nominal_dt = 0.005  # determines the number of emission times: T_total / dt
num_timesteps_total = int(T_total / sim1_nominal_dt)

# Generate (irregular) time points for measurements, filtering and forecasting
sim1_t, sim1_t_filter, sim1_t_forecast, \
    sim1_num_timesteps, sim1_num_timesteps_filter, sim1_num_timesteps_forecast= \
    generate_irregular_t_emissions(
        T_total=T_total,
        num_timesteps=num_timesteps_total,
        T_filter=T_filter,
        key=next(keys)
    )


# Create CD-NLGSSM model, configured with the hifi forward (simulation) settings
hifi_model = ContDiscreteNonlinearGaussianSSM(state_dim, emission_dim, diffeqsolve_settings=hifi_forward_settings)
model_params, _ = hifi_model.initialize(next(keys), **all_true_params)

# Generate hifi synthetic data.
# transition_type="path" samples by solving the SDE path, presumably with the
# model's diffeqsolve settings (Heun in `hifi_forward_settings`, not Euler-Maruyama).
hifi_states, hifi_emissions = hifi_model.sample(
    model_params, next(keys), sim1_num_timesteps, sim1_t, transition_type="path"
)

Sampling from SDE solver path (this may be an unnecessarily poor approximation if you're simulating from a linear SDE). It is an appropriate choice for non-linear SDEs.


In [5]:
print("Running grad and loss via EKF with hifi diffeqsolve settings...")
# compute a hifi marginal_log_prob (via EKF)
def loss_fn(params):
    return hifi_model.marginal_log_prob(
        params,
        hifi_emissions,
        sim1_t,
        filter_hyperparams=EKFHyperParams(diffeqsolve_settings=hifi_logprob_settings),
    )

loss_grad_fn = jax.value_and_grad(loss_fn)

# evaluate the loss and the gradient and time them
%timeit -n 1 -r 1 loss_grad_fn(model_params)
hifi_loss, hifi_grad = loss_grad_fn(model_params)

print(f"Hifi Loss: {hifi_loss}")

Running grad and loss via EKF with hifi diffeqsolve settings...
8.74 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Hifi Loss: -14591.8759765625


In [20]:
print("Running grad and loss via EKF with default diffeqsolve settings...")

def loss_fn(params):
    return hifi_model.marginal_log_prob(
        params,
        hifi_emissions,
        sim1_t,
        filter_hyperparams=EKFHyperParams(),
    )

# first time how long it takes to compute the marginal log prob
%timeit -n 1 -r 1 loss_fn(model_params)

# Next, time how long it takes to compute the gradient and loss together
loss_grad_fn = jax.value_and_grad(loss_fn)

# evaluate the loss and the gradient
%timeit -n 1 -r 1 loss_grad_fn(model_params)
default_loss, default_grad = loss_grad_fn(model_params)

print(f"Hifi marginal log prob: {hifi_loss}")
print(f"Default marginal log prob: {default_loss}")

compare_structs(hifi_grad, default_grad, max_tol=-3, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, default_grad, max_tol=-2, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, default_grad, max_tol=-1, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, default_grad, max_tol=0, max_tol_rel=-5, accept_failure=True)



Running grad and loss via EKF with default diffeqsolve settings...
582 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
3.64 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Hifi marginal log prob: -14591.8759765625
Default marginal log prob: -14591.8759765625
Fields that are close within atol=0.001 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.001 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']
Fields that are close within atol=0.01 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.01 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissi

In [21]:
print("Running grad and loss via EKF with NEW diffeqsolve settings...")
new_diffeqsolve_settings = {
    "solver": dfx.Tsit5(),
    "dt0": 0.1,
    "stepsize_controller": dfx.PIDController(rtol=1e-3, atol=1e-6),
    "adjoint": dfx.RecursiveCheckpointAdjoint(checkpoints=100), #more checkpoints should give faster gradients
    # "tol_vbt": 1, # This won't matter for ODEs
    "max_steps": 100,
}  # empty uses default settings

def loss_fn(params):
    return hifi_model.marginal_log_prob(
        params,
        hifi_emissions,
        sim1_t,
        filter_hyperparams=EKFHyperParams(diffeqsolve_settings=new_diffeqsolve_settings),
    )

# first time how long it takes to compute the marginal log prob
%timeit -n 1 -r 1 loss_fn(model_params)

# Next, time how long it takes to compute the gradient and loss together
loss_grad_fn = jax.value_and_grad(loss_fn)

# evaluate the loss and the gradient
%timeit -n 1 -r 1 loss_grad_fn(model_params)
new_loss, new_grad = loss_grad_fn(model_params)

print(f"Hifi marginal log prob: {hifi_loss}")
print(f"New marginal log prob: {new_loss} ")

compare_structs(hifi_grad, new_grad, max_tol=-3, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, new_grad, max_tol=-2, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, new_grad, max_tol=-1, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, new_grad, max_tol=0, max_tol_rel=-5, accept_failure=True)



Running grad and loss via EKF with NEW diffeqsolve settings...
638 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
4.13 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Hifi marginal log prob: -14591.8759765625
New marginal log prob: -14591.8740234375 
Fields that are close within atol=0.001 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.001 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']
Fields that are close within atol=0.01 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.01 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emi

In [22]:
print("Running grad and loss via EKF with lowfi diffeqsolve settings...")
lowfi_diffeqsolve_settings = {
    "solver": dfx.Euler(),
    "dt0": 1e-4,
    # "stepsize_controller": dfx.ConstantStepSize(),
    "max_steps": 1e3,
}  # empty uses default settings

def loss_fn(params):
    return hifi_model.marginal_log_prob(
        params,
        hifi_emissions,
        sim1_t,
        filter_hyperparams=EKFHyperParams(diffeqsolve_settings=lowfi_diffeqsolve_settings),
    )

# first time how long it takes to compute the marginal log prob
%timeit -n 1 -r 1 loss_fn(model_params)

# Next, time how long it takes to compute the gradient and loss together
loss_grad_fn = jax.value_and_grad(loss_fn)

# evaluate the loss and the gradient
%timeit -n 1 -r 1 loss_grad_fn(model_params)
lowfi_loss, lowfi_grad = loss_grad_fn(model_params)

print(f"Hifi marginal log prob: {hifi_loss}")
print(f"Lowfi marginal log prob: {lowfi_loss} ")

compare_structs(hifi_grad, lowfi_grad, max_tol=-3, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, lowfi_grad, max_tol=-2, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, lowfi_grad, max_tol=-1, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, lowfi_grad, max_tol=0, max_tol_rel=-5, accept_failure=True)



Running grad and loss via EKF with lowfi diffeqsolve settings...
423 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
2.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Hifi marginal log prob: -14591.8759765625
Lowfi marginal log prob: -14592.40625 
Fields that are close within atol=0.001 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.001 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']
Fields that are close within atol=0.01 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.01 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emiss

In [23]:
print("Running grad and loss via EKF with lowfi diffeqsolve settings...")
lowfi_diffeqsolve_settings = {
    "solver": dfx.Heun(), # cheap ODE solver recommended for Neural ODE
    "dt0": 1e-3,
    # "stepsize_controller": dfx.ConstantStepSize(),
    "max_steps": 1e4,
}  # empty uses default settings

def loss_fn(params):
    return hifi_model.marginal_log_prob(
        params,
        hifi_emissions,
        sim1_t,
        filter_hyperparams=EKFHyperParams(diffeqsolve_settings=lowfi_diffeqsolve_settings),
    )

# first time how long it takes to compute the marginal log prob
%timeit -n 1 -r 1 loss_fn(model_params)

# Next, time how long it takes to compute the gradient and loss together
loss_grad_fn = jax.value_and_grad(loss_fn)

# evaluate the loss and the gradient
%timeit -n 1 -r 1 loss_grad_fn(model_params)
lowfi_loss, lowfi_grad = loss_grad_fn(model_params)

print(f"Hifi marginal log prob: {hifi_loss}")
print(f"Lowfi marginal log prob: {lowfi_loss} ")

compare_structs(hifi_grad, lowfi_grad, max_tol=-3, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, lowfi_grad, max_tol=-2, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, lowfi_grad, max_tol=-1, max_tol_rel=-5, accept_failure=True)
compare_structs(hifi_grad, lowfi_grad, max_tol=0, max_tol_rel=-5, accept_failure=True)



Running grad and loss via EKF with lowfi diffeqsolve settings...
541 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
3.32 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Hifi marginal log prob: -14591.8759765625
Lowfi marginal log prob: -14591.9013671875 
Fields that are close within atol=0.001 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.001 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']
Fields that are close within atol=0.01 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.01 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions

In [29]:
print("Running grad and loss via EKF with lowfi diffeqsolve settings...")
lowfi_diffeqsolve_settings = {
    "solver": dfx.Heun(), # cheap ODE solver recommended for Neural ODE
    "dt0": 1e-3,
    # "stepsize_controller": dfx.ConstantStepSize(),
    "max_steps": 1e2,
}  # empty uses default settings

def loss_fn(params):
    return hifi_model.marginal_log_prob(
        params,
        hifi_emissions,
        sim1_t,
        filter_hyperparams=EKFHyperParams(diffeqsolve_settings=lowfi_diffeqsolve_settings),
    )

# first time how long it takes to compute the marginal log prob
%timeit -n 1 -r 1 loss_fn(model_params)

# Next, time how long it takes to compute the gradient and loss together
loss_grad_fn = jax.value_and_grad(loss_fn)

# evaluate the loss and the gradient
%timeit -n 1 -r 1 loss_grad_fn(model_params)
lowfi_loss, lowfi_grad = loss_grad_fn(model_params)

print(f"Hifi marginal log prob: {hifi_loss}")
print(f"Lowfi marginal log prob: {lowfi_loss} ")

compare_structs(hifi_grad, lowfi_grad, max_tol=0, max_tol_rel=-3, accept_failure=True)



Running grad and loss via EKF with lowfi diffeqsolve settings...
579 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
3.2 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Hifi marginal log prob: -14591.8759765625
Lowfi marginal log prob: -14591.9013671875 
Fields that are close within atol=1 and rtol=0.001: ['dynamics.diffusion_cov.params', 'dynamics.approx_order', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params']
Fields that are different within tol=1 and rtol=0.001: ['dynamics.drift.params', 'dynamics.diffusion_coefficient.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']


In [31]:
# Show both gradient pytrees side by side (hifi vs. lowfi diffeqsolve settings)
(hifi_grad, lowfi_grad)

(ParamsCDNLGSSM(initial=ParamsLGSSMInitial(mean=LearnableVector(params=Array([0.13, -0.56, 0.51], dtype=float32)), cov=LearnableMatrix(params=Array([[-0.09, -0.31, 0.30],
        [0.22, 0.06, -0.14],
        [-0.23, -0.14, 0.04]], dtype=float32))), dynamics=ParamsCDNLGSSMDynamics(drift=lorenz63_drift(params=Array([-22.51, -15.90, -72.80], dtype=float32)), diffusion_coefficient=LearnableMatrix(params=Array([[1.90, 1.44, -5.00],
        [1.44, 3.88, -1.36],
        [-5.00, -1.36, 3.90]], dtype=float32)), diffusion_cov=LearnableMatrix(params=Array([[0.95, 1.11, -2.51],
        [0.33, 1.94, -0.67],
        [-2.49, -0.69, 1.95]], dtype=float32)), approx_order=Array(0.00, dtype=float32, weak_type=True)), emissions=ParamsCDNLGSSMEmissions(emission_function=LearnableLinear(weights=Array([[335.84, 332.45, 1477.29]], dtype=float32), bias=Array([42.21], dtype=float32)), emission_cov=LearnableMatrix(params=Array([[17.11]], dtype=float32)))),
 ParamsCDNLGSSM(initial=ParamsLGSSMInitial(mean=Learnabl

In [25]:
# Print the three gradient pytrees for a side-by-side look
for grad_line in (
    f"Hifi marginal grad: {hifi_grad} ",
    f"Default marginal grad: {default_grad}",
    f"Lowfi marginal grad: {lowfi_grad}",
):
    print(grad_line)

Hifi marginal grad: ParamsCDNLGSSM(initial=ParamsLGSSMInitial(mean=LearnableVector(params=Array([0.13, -0.56, 0.51], dtype=float32)), cov=LearnableMatrix(params=Array([[-0.09, -0.31, 0.30],
       [0.22, 0.06, -0.14],
       [-0.23, -0.14, 0.04]], dtype=float32))), dynamics=ParamsCDNLGSSMDynamics(drift=lorenz63_drift(params=Array([-22.51, -15.90, -72.80], dtype=float32)), diffusion_coefficient=LearnableMatrix(params=Array([[1.90, 1.44, -5.00],
       [1.44, 3.88, -1.36],
       [-5.00, -1.36, 3.90]], dtype=float32)), diffusion_cov=LearnableMatrix(params=Array([[0.95, 1.11, -2.51],
       [0.33, 1.94, -0.67],
       [-2.49, -0.69, 1.95]], dtype=float32)), approx_order=Array(0.00, dtype=float32, weak_type=True)), emissions=ParamsCDNLGSSMEmissions(emission_function=LearnableLinear(weights=Array([[335.84, 332.45, 1477.29]], dtype=float32), bias=Array([42.21], dtype=float32)), emission_cov=LearnableMatrix(params=Array([[17.11]], dtype=float32)))) 
Default marginal grad: ParamsCDNLGSSM(initi

In [26]:
# Print the three marginal log probs for a side-by-side look
for log_prob_line in (
    f"Hifi marginal log prob: {hifi_loss} ",
    f"Default marginal log prob: {default_loss}",
    f"Lowfi marginal log prob: {lowfi_loss}",
):
    print(log_prob_line)

Hifi marginal log prob: -14591.8759765625 
Default marginal log prob: -14591.8759765625
Lowfi marginal log prob: -14591.9013671875


In [27]:
print("Comparing hifi and lowfi gradients...")

# Strict (atol = 1e-3) and loose (atol = 1e0) comparisons; max_tol is log10 of
# the absolute tolerance, and the relative tolerance is fixed at 1e-5.
for log10_atol in (-3, 0):
    compare_structs(hifi_grad, lowfi_grad, max_tol=log10_atol, max_tol_rel=-5, accept_failure=True)

Comparing hifi and lowfi gradients...
Fields that are close within atol=0.001 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.001 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']
Fields that are close within atol=1 and rtol=1e-05: ['dynamics.diffusion_cov.params', 'dynamics.approx_order', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params']
Fields that are different within tol=1 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_coefficient.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']


In [28]:
print("Comparing hifi and default gradients...")
# Sweep absolute tolerance 1e-3 ... 1e0 (max_tol is log10 of atol); rtol fixed at 1e-5
for log10_atol in (-3, -2, -1, 0):
    compare_structs(hifi_grad, default_grad, max_tol=log10_atol, max_tol_rel=-5, accept_failure=True)

Comparing hifi and default gradients...
Fields that are close within atol=0.001 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.001 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']
Fields that are close within atol=0.01 and rtol=1e-05: ['dynamics.approx_order']
Fields that are different within tol=0.01 and rtol=1e-05: ['dynamics.drift.params', 'dynamics.diffusion_cov.params', 'dynamics.diffusion_coefficient.params', 'initial.mean.params', 'initial.cov.params', 'emissions.emission_cov.params', 'emissions.emission_function.weights', 'emissions.emission_function.bias']
Fields that are close within atol=0.1 and rtol=1e-05: ['dynamics.approx_order', 'initial.mean.params', 'initial.cov.params']
Fields that are different within tol=0.1 and rtol