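## Bug #1: `diffeqsolve` hits `max_steps` in the migration lift

Evaluating the whitened SFS objective at the parameter vector in the last cell raises `EquinoxRuntimeError: The maximum number of solver steps was reached`, coming from `dfx.diffeqsolve` in `_lift_cm_exp` (`demestats/sfs/migration.py`).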

In [1]:
import msprime as msp

# Simulate an isolation-with-migration model: "anc" splits into P0 and P1 at
# generation 1000, and P0/P1 exchange migrants at a symmetric rate of 5e-5.
demo = msp.Demography()
demo.add_population(initial_size=4000, name="anc")
demo.add_population(initial_size=3000, name="P0")
demo.add_population(initial_size=5000, name="P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.00005)
derived_pops = [f"P{i}" for i in range(2)]
demo.add_population_split(time=1000, derived=derived_pops, ancestral="anc")
# Named `demes_graph` rather than `g` so it is not confused with (or silently
# overwritten by) the objective function `g` built in a later cell.
demes_graph = demo.to_demes()

sample_size = 10  # individuals sampled per derived population
samples = {pop: sample_size for pop in derived_pops}
seed = 90
anc = msp.sim_ancestry(
    samples=samples,
    demography=demo,
    recombination_rate=1e-8,
    sequence_length=1e8,
    random_seed=seed,
)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed=seed + 1)

# demesdraw.tubes(demes_graph)

# Number of sampled genomes per population (two per individual).
afs_samples = {pop: int(sample_size * 2) for pop in derived_pops}
# Build one sample set per key of `afs_samples`, looking each population's ID
# up by name. This guarantees that the AFS axes line up with `afs_samples`
# instead of relying on hard-coded population IDs.
afs = ts.allele_frequency_spectrum(
    sample_sets=[ts.samples(population=demo[pop].id) for pop in afs_samples],
    span_normalise=False,
    polarised=True,
)

In [2]:
from typing import Any, List, Mapping, Set, Tuple

import jax.numpy as jnp
import numpy as np
from loguru import logger

from demestats.fit.fit_sfs import _compute_sfs_likelihood
from demestats.fit.util import (
    _dict_to_vec,
    make_whitening_from_hessian,
    pullback_objective,
)
from demestats.loglik.sfs_loglik import prepare_projection
from demestats.sfs import ExpectedSFS

# Starting values for the fit. Each key is a frozenset of demes-graph paths that
# share a single value. For reference, the simulation above used a migration
# rate of 5e-5, sizes 3000 (P0) and 5000 (P1), and a split time of 1000.
paths = {
    frozenset({("migrations", 0, "rate")}): 0.00001,
    frozenset({("migrations", 1, "rate")}): 0.00001,
    # Ancestral (deme 0) size is not included in the fit:
    # frozenset({('demes', 0, 'epochs', 0, 'end_size'),
    #   ('demes', 0, 'epochs', 0, 'start_size')}): 3000.,
    # P0 (deme 1): start and end size tied together.
    frozenset(
        {("demes", 1, "epochs", 0, "end_size"), ("demes", 1, "epochs", 0, "start_size")}
    ): 7000.0,
    # P1 (deme 2): start and end size tied together.
    frozenset(
        {("demes", 2, "epochs", 0, "end_size"), ("demes", 2, "epochs", 0, "start_size")}
    ): 300.0,
    # Split time: the ancestral epoch's end, both derived demes' start times and
    # both migrations' start times move together as one parameter.
    frozenset(
        {
            ("demes", 0, "epochs", 0, "end_time"),
            ("demes", 1, "start_time"),
            ("demes", 2, "start_time"),
            ("migrations", 0, "start_time"),
            ("migrations", 1, "start_time"),
        }
    ): 7000,
}

logger.disable("demestats")

# Type aliases for fit variables: a single path or a set of tied paths.
Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

# Fixed parameter ordering shared by x0, the Hessian and the whitening matrix.
path_order: List[Var] = list(paths)
x0 = jnp.array(_dict_to_vec(paths, path_order))

esfs = ExpectedSFS(demo.to_demes(), num_samples=afs_samples)

projection = True
sequence_length = None
num_projections = 150
theta = None
# Derived seed kept under its own name so that re-running this cell does not
# keep multiplying `seed` (and silently change the random projections).
projection_seed = seed * 3
if projection:
    proj_dict, einsum_str, input_arrays = prepare_projection(
        afs, afs_samples, sequence_length, num_projections, projection_seed
    )
else:
    proj_dict, einsum_str, input_arrays = None, None, None

# Extra positional arguments forwarded as `f(params, *args)` to
# `_compute_sfs_likelihood`.
args = (
    path_order,
    esfs,
    proj_dict,
    einsum_str,
    input_arrays,
    sequence_length,
    theta,
    projection,
    afs,
)
# `make_whitening_from_hessian` evaluates a finite-difference Hessian of the
# loss at x0. `pullback_objective` re-expresses the loss in whitened
# coordinates, presumably x = x0 + LinvT @ z (TODO confirm in
# demestats.fit.util).
L, LinvT = make_whitening_from_hessian(_compute_sfs_likelihood, x0, *args)
g = pullback_objective(_compute_sfs_likelihood, x0, LinvT, *args)

[32m2025-10-31 16:19:06.367[0m | [34m[1mDEBUG   [0m | [36mdemestats.event_tree[0m:[36m_add_node[0m:[36m511[0m - [34m[1mcreating node 0 with attributes {'event': PopulationStart(), 'block': frozenset({'anc'}), 't': ('demes', 0, 'epochs', 0, 'end_time')}[0m
[32m2025-10-31 16:19:06.369[0m | [34m[1mDEBUG   [0m | [36mdemestats.event_tree[0m:[36m_add_node[0m:[36m511[0m - [34m[1mcreating node 1 with attributes {'event': PopulationStart(), 'block': frozenset({'P0'}), 't': ('demes', 1, 'epochs', 0, 'end_time')}[0m
[32m2025-10-31 16:19:06.370[0m | [34m[1mDEBUG   [0m | [36mdemestats.event_tree[0m:[36m_add_node[0m:[36m511[0m - [34m[1mcreating node 2 with attributes {'event': PopulationStart(), 'block': frozenset({'P1'}), 't': ('demes', 2, 'epochs', 0, 'end_time')}[0m
[32m2025-10-31 16:19:06.373[0m | [34m[1mDEBUG   [0m | [36mdemestats.event_tree[0m:[36m_add_node[0m:[36m511[0m - [34m[1mcreating node 3 with attributes {'event': Merge(), 'block': 

Params: [1.1e-05 1.0e-05 7.0e+03 3.0e+02 7.0e+03]
Loss: 16994736.484189354
Params: [9.e-06 1.e-05 7.e+03 3.e+02 7.e+03]
Loss: 17029483.74286553
Params: [1.0e-05 1.1e-05 7.0e+03 3.0e+02 7.0e+03]
Loss: 17010728.420447085
Params: [1.e-05 9.e-06 7.e+03 3.e+02 7.e+03]
Loss: 17012437.63652862
Params: [1.e-05 1.e-05 7.e+03 3.e+02 7.e+03]
Loss: 17011497.901058327
Params: [1.e-05 1.e-05 7.e+03 3.e+02 7.e+03]
Loss: 17011497.900951758
Params: [1.00000000e-05 1.00000000e-05 7.00000000e+03 3.00000001e+02
 7.00000000e+03]
Loss: 17011497.89925146
Params: [1.00000000e-05 1.00000000e-05 7.00000000e+03 2.99999999e+02
 7.00000000e+03]
Loss: 17011497.902758624
Params: [1.e-05 1.e-05 7.e+03 3.e+02 7.e+03]
Loss: 17011497.901008915
Params: [1.e-05 1.e-05 7.e+03 3.e+02 7.e+03]
Loss: 17011497.90100117


In [3]:
# Reproduce the failure: evaluate the whitened objective `g` at the parameter
# vector below, listed in `path_order` order.
# NOTE(review): the P0 size here (~6.6) is tiny compared with the simulated
# 3000. A very small size may be what drives the migration ODE solver to
# `max_steps`; confirm.
error_vec = jnp.array(
    [
        1.92047513e-05,  # ("migrations", 0, "rate")
        1.92047513e-05,  # ("migrations", 1, "rate")
        6.60044829e00,  # P0 size (deme 1)
        4.32221720e02,  # P1 size (deme 2)
        5.80089035e03,  # split time (deme 0 end_time and tied start times)
    ]
)
# Map the target parameters into the whitened coordinates expected by `g`. The
# "Params" line printed by `g` reproduces `error_vec`, confirming the inverse.
offset_from_x0 = error_vec - x0
untransformed = L.T @ offset_from_x0
# The first output is presumably the loss (negative log-likelihood returned by
# `_compute_sfs_likelihood`), not the likelihood itself.
likelihood, gradients = g(untransformed)

Params: [1.92047513e-05 1.92047513e-05 6.60044829e+00 4.32221720e+02
 5.80089035e+03]


E1031 16:22:28.579905 2403750 pjrt_stream_executor_client.cc:3314] Execution of replica 0 failed: INTERNAL: CpuCallback error calling callback: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 211, in start
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/base_events.py", line 639, in run_forever
  File "/home/jkliang/minico

EquinoxRuntimeError: Above is the stack outside of JIT. Below is the stack inside of JIT:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 211, in start
    self.asyncio_loop.run_forever()
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/base_events.py", line 639, in run_forever
    self._run_once()
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/base_events.py", line 1985, in _run_once
    handle._run()
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    await self.process_one()
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    await dispatch(*args)
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    await super().execute_request(stream, ident, parent)
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    reply_content = await reply_content
                    ^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    res = shell.run_cell(
          ^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3116, in run_cell
    result = self._run_cell(
             ^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3171, in _run_cell
    result = runner(coro)
             ^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3394, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3639, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3699, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_2403750/850597006.py", line 47, in <module>
    L, LinvT = make_whitening_from_hessian(_compute_sfs_likelihood, x0, *args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/fit/util.py", line 53, in make_whitening_from_hessian
    H = finite_difference_hessian(f, x0, *args)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/fit/util.py", line 46, in finite_difference_hessian
    grad_plus_i = grad_f(x_plus)[i]
                  ^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/fit/util.py", line 35, in loglik_static
    return f(params, *args)
           ^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/fit/fit_sfs.py", line 19, in _compute_sfs_likelihood
    loss = -projection_sfs_loglik(esfs, params, proj_dict, einsum_str, input_arrays, sequence_length, theta)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/loglik/sfs_loglik.py", line 42, in projection_sfs_loglik
    result1 = esfs.tensor_prod(proj_dict, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/__init__.py", line 151, in tensor_prod
    states = _call(
             ^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/__init__.py", line 209, in _call
    states, _ = traverse(et, states, node_callback, lift_callback, aux=aux)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/traverse.py", line 129, in traverse
    state, node_aux = lift_callback(
                      ^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/__init__.py", line 205, in lift_callback
    return events.lift(
           ^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/events/lift.py", line 160, in lift
    etbl = f(pl0.untag(*pops), True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/penzai/core/named_axes.py", line 310, in wrapped_fun
    result_data = recursive_vectorize_step(named_array_arg_leaves, all_names)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/penzai/core/named_axes.py", line 251, in recursive_vectorize_step
    return flat_array_fun([view.unwrap() for view in current_views])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/penzai/core/named_axes.py", line 235, in flat_array_fun
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/events/lift.py", line 158, in f
    return lift_cm(pl, t0, t1, etas, mu, demo, aux["mats"][pops], etbl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/migration.py", line 73, in lift_cm
    return f(pl, t0, t1, etas, mu, demo, aux, etbl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/migration.py", line 196, in _lift_cm_exp
    res = dfx.diffeqsolve(
          ^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/diffrax/_integrate.py", line 1466, in diffeqsolve
    sol = result.error_if(sol, jnp.invert(is_okay(result)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

equinox.EquinoxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.

-------------------

An error occurred during the runtime of your JAX program.

1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful
way to debug such errors. This can be interacted with using most of the usual commands
for the Python debugger: `u` and `d` to move up and down frames, the name of a variable
to print its value, etc.

2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can
(mostly) inspect the state of your program as if it was normal Python.

3) See `https://docs.kidger.site/equinox/api/debug/` for more suggestions.


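## Bug #2: the same `max_steps` failure from a random starting point

This uses the same simulated data as Bug #1, but the starting point `x0` is drawn uniformly within `x0_bounds`. Evaluating the objective at the failing parameter vector again raises `EquinoxRuntimeError: The maximum number of solver steps was reached` from `_lift_cm_exp`.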

In [1]:
import msprime as msp

# Simulate an isolation-with-migration model: "anc" splits into P0 and P1 at
# generation 1000, and P0/P1 exchange migrants at a symmetric rate of 5e-5.
demo = msp.Demography()
demo.add_population(initial_size=4000, name="anc")
demo.add_population(initial_size=3000, name="P0")
demo.add_population(initial_size=5000, name="P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.00005)
derived_pops = [f"P{i}" for i in range(2)]
demo.add_population_split(time=1000, derived=derived_pops, ancestral="anc")
# Named `demes_graph` rather than `g` so it is not confused with (or silently
# overwritten by) the objective function `g` built in a later cell.
demes_graph = demo.to_demes()

sample_size = 10  # individuals sampled per derived population
samples = {pop: sample_size for pop in derived_pops}
seed = 90
anc = msp.sim_ancestry(
    samples=samples,
    demography=demo,
    recombination_rate=1e-8,
    sequence_length=1e8,
    random_seed=seed,
)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed=seed + 1)

# demesdraw.tubes(demes_graph)

# Number of sampled genomes per population (two per individual).
afs_samples = {pop: int(sample_size * 2) for pop in derived_pops}
# Build one sample set per key of `afs_samples`, looking each population's ID
# up by name. This guarantees that the AFS axes line up with `afs_samples`
# instead of relying on hard-coded population IDs.
afs = ts.allele_frequency_spectrum(
    sample_sets=[ts.samples(population=demo[pop].id) for pop in afs_samples],
    span_normalise=False,
    polarised=True,
)

In [5]:
from typing import Any, List, Mapping, Set, Tuple

import jax.numpy as jnp
import numpy as np
from loguru import logger

from demestats.event_tree import EventTree
from demestats.fit.fit_sfs import _compute_sfs_likelihood
from demestats.fit.util import (
    _dict_to_vec,
    make_whitening_from_hessian,
    pullback_objective,
)
from demestats.loglik.sfs_loglik import prepare_projection
from demestats.sfs import ExpectedSFS

logger.disable("demestats")

# Parameters to fit and the (low, high) range each starting value is drawn
# from. The rows of `x0_bounds` correspond 1:1 with `param_paths`. A former
# sixth row, [100, 3000], had no matching path and was silently dropped by
# `zip`, so it has been removed.
param_paths = [
    ["migrations", 0, "rate"],
    ["migrations", 1, "rate"],
    ["demes", 1, "epochs", 0, "end_size"],
    ["demes", 2, "epochs", 0, "end_size"],
    ["demes", 0, "epochs", 0, "end_time"],
]
x0_bounds = jnp.array(
    [[0, 0.001], [0, 0.001], [100, 10000], [100, 10000], [100, 10000]]
)

et = EventTree(demo.to_demes())
# Derived seed kept under its own name so that re-running this cell does not
# keep multiplying `seed`. It seeds both the x0 draw and the projections.
init_seed = seed * 3
# A local RandomState seeded like the former global `np.random.seed` call
# produces the same draws without mutating NumPy's global RNG.
init_rng = np.random.RandomState(init_seed)
# Map each path to its (possibly tied) fit variable with a uniform start value.
paths = {
    et.variable_for(tuple(path)): init_rng.uniform(low, high)
    for path, (low, high) in zip(param_paths, x0_bounds, strict=True)
}

# Type aliases for fit variables: a single path or a set of tied paths.
Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

# Fixed parameter ordering shared by x0, the Hessian and the whitening matrix.
path_order: List[Var] = list(paths)
x0 = jnp.array(_dict_to_vec(paths, path_order))

esfs = ExpectedSFS(demo.to_demes(), num_samples=afs_samples)

projection = True
sequence_length = None
num_projections = 150
theta = None
if projection:
    proj_dict, einsum_str, input_arrays = prepare_projection(
        afs, afs_samples, sequence_length, num_projections, init_seed
    )
else:
    proj_dict, einsum_str, input_arrays = None, None, None

# Extra positional arguments forwarded as `f(params, *args)` to
# `_compute_sfs_likelihood`.
args = (
    path_order,
    esfs,
    proj_dict,
    einsum_str,
    input_arrays,
    sequence_length,
    theta,
    projection,
    afs,
)
# `make_whitening_from_hessian` evaluates a finite-difference Hessian of the
# loss at x0. `pullback_objective` re-expresses the loss in whitened
# coordinates, presumably x = x0 + LinvT @ z (TODO confirm in
# demestats.fit.util).
L, LinvT = make_whitening_from_hessian(_compute_sfs_likelihood, x0, *args)
g = pullback_objective(_compute_sfs_likelihood, x0, LinvT, *args)

Params: [6.95124950e-04 8.08429479e-04 5.76764227e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765298.558383767
Params: [6.93124950e-04 8.08429479e-04 5.76764227e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765300.531173177
Params: [6.94124950e-04 8.09429479e-04 5.76764227e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765213.995668586
Params: [6.94124950e-04 8.07429479e-04 5.76764227e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765385.290961705
Params: [6.94124950e-04 8.08429479e-04 5.76764228e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765299.57818421
Params: [6.94124950e-04 8.08429479e-04 5.76764226e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765299.578556452
Params: [6.94124950e-04 8.08429479e-04 5.76764227e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765299.578375665
Params: [6.94124950e-04 8.08429479e-04 5.76764227e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765299.578364989
Params: [6.94124950e-04 8.08429479e-04 5.76764227e+02 9.04407443e+03
 1.44992273e+03]
Loss: 15765299.578385301
Pa

In [7]:
# Reproduce the failure: evaluate the whitened objective `g` at the parameter
# vector below, listed in `path_order` order.
# NOTE(review): the P1 size here (~16) lies far below the lower bound of 100
# used to draw x0. A very small size may be what drives the migration ODE
# solver to `max_steps`; confirm.
error_vec = jnp.array(
    [
        4.17206325e-04,  # ("migrations", 0, "rate")
        4.17206325e-04,  # ("migrations", 1, "rate")
        4.00882442e03,  # P0 size (deme 1)
        1.62232203e01,  # P1 size (deme 2)
        2.71517317e03,  # split time (deme 0 end_time and tied start times)
    ]
)
# Map the target parameters into the whitened coordinates expected by `g`. The
# "Params" line printed by `g` reproduces `error_vec`, confirming the inverse.
offset_from_x0 = error_vec - x0
untransformed = L.T @ offset_from_x0
# The first output is presumably the loss (negative log-likelihood returned by
# `_compute_sfs_likelihood`), not the likelihood itself.
likelihood, gradients = g(untransformed)

Params: [4.17206325e-04 4.17206325e-04 4.00882442e+03 1.62232203e+01
 2.71517317e+03]


E1031 16:27:08.148294 2409806 pjrt_stream_executor_client.cc:3314] Execution of replica 0 failed: INTERNAL: CpuCallback error calling callback: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 211, in start
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/base_events.py", line 639, in run_forever
  File "/home/jkliang/minico

EquinoxRuntimeError: Above is the stack outside of JIT. Below is the stack inside of JIT:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 211, in start
    self.asyncio_loop.run_forever()
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/base_events.py", line 639, in run_forever
    self._run_once()
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/base_events.py", line 1985, in _run_once
    handle._run()
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    await self.process_one()
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    await dispatch(*args)
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    await super().execute_request(stream, ident, parent)
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    reply_content = await reply_content
                    ^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    res = shell.run_cell(
          ^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3116, in run_cell
    result = self._run_cell(
             ^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3171, in _run_cell
    result = runner(coro)
             ^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3394, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3639, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3699, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_2409806/3869709451.py", line 41, in <module>
    L, LinvT = make_whitening_from_hessian(_compute_sfs_likelihood, x0, *args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/fit/util.py", line 53, in make_whitening_from_hessian
    H = finite_difference_hessian(f, x0, *args)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/fit/util.py", line 46, in finite_difference_hessian
    grad_plus_i = grad_f(x_plus)[i]
                  ^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/fit/util.py", line 35, in loglik_static
    return f(params, *args)
           ^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/fit/fit_sfs.py", line 19, in _compute_sfs_likelihood
    loss = -projection_sfs_loglik(esfs, params, proj_dict, einsum_str, input_arrays, sequence_length, theta)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/loglik/sfs_loglik.py", line 42, in projection_sfs_loglik
    result1 = esfs.tensor_prod(proj_dict, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/__init__.py", line 151, in tensor_prod
    states = _call(
             ^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/__init__.py", line 209, in _call
    states, _ = traverse(et, states, node_callback, lift_callback, aux=aux)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/traverse.py", line 129, in traverse
    state, node_aux = lift_callback(
                      ^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/__init__.py", line 205, in lift_callback
    return events.lift(
           ^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/events/lift.py", line 160, in lift
    etbl = f(pl0.untag(*pops), True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/penzai/core/named_axes.py", line 310, in wrapped_fun
    result_data = recursive_vectorize_step(named_array_arg_leaves, all_names)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/penzai/core/named_axes.py", line 251, in recursive_vectorize_step
    return flat_array_fun([view.unwrap() for view in current_views])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/penzai/core/named_axes.py", line 235, in flat_array_fun
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/events/lift.py", line 158, in f
    return lift_cm(pl, t0, t1, etas, mu, demo, aux["mats"][pops], etbl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/migration.py", line 73, in lift_cm
    return f(pl, t0, t1, etas, mu, demo, aux, etbl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/src/demestats/sfs/migration.py", line 196, in _lift_cm_exp
    res = dfx.diffeqsolve(
          ^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/diffrax/_integrate.py", line 1466, in diffeqsolve
    sol = result.error_if(sol, jnp.invert(is_okay(result)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jkliang/demestats_updated_env/demestats_private/.venv/lib/python3.12/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

equinox.EquinoxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.

-------------------

An error occurred during the runtime of your JAX program.

1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful
way to debug such errors. This can be interacted with using most of the usual commands
for the Python debugger: `u` and `d` to move up and down frames, the name of a variable
to print its value, etc.

2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can
(mostly) inspect the state of your program as if it was normal Python.

3) See `https://docs.kidger.site/equinox/api/debug/` for more suggestions.
