In [1]:
import msprime as msp
import demes
import demesdraw

demo = msp.Demography()
demo.add_population(initial_size = 5000, name = "anc")
demo.add_population(initial_size = 5000, name = "P0")
demo.add_population(initial_size = 5000, name = "P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)
tmp = [f"P{i}" for i in range(2)]
demo.add_population_split(time = 4000, derived=tmp, ancestral="anc")
g = demo.to_demes()
sample_size = 4
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e7, random_seed = 12)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed = 13)

# demesdraw.tubes(g)

from demesinfer.fit.read_data import get_het_data
import sgkit
import bio2zarr.tskit
import tempfile

d = tempfile.TemporaryDirectory() 
bio2zarr.tskit.convert(ts, d.name + "/ts")
ds = sgkit.load_dataset(d.name + "/ts")

het_matrix, cfg_list = get_het_data(ts, ds, num_samples=5)

In [14]:
paths = { ('migrations', 0, 'rate'):0.0009,
          ('migrations', 1, 'rate'):0.0009,
        }

from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple
from demesinfer.fit.util import _vec_to_dict, _vec_to_dict_jax, process_data
from demesinfer.iicr import IICRCurve
from phlashlib.iicr import PiecewiseConstant
from phlashlib.loglik import loglik
import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap
from loguru import logger
logger.disable("demesinfer")

Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

path_order: List[Var] = list(paths)
cfg_mat, deme_names, unique_cfg, matching_indices = process_data(cfg_list)
num_samples = len(cfg_mat)
rho = theta = 1e-8 * 100
k = 2
iicr = IICRCurve(demo=g, k=2)
iicr_call = jax.jit(iicr.__call__)

def process_base_model(deme_names, cfg):
    curve = iicr.curve(num_samples=dict(zip(deme_names, cfg)))
    timepoints = jax.vmap(curve.quantile)(jnp.linspace(0, 1, 64)[1:-1])
    timepoints = jnp.insert(timepoints, 0, 0.0)
    return timepoints

times = jax.vmap(process_base_model, in_axes=(None, 0))(deme_names, jnp.array(unique_cfg))

def compute_loglik(c_index, data, c_map, times, theta, rho):
    c = c_map[c_index]
    t = times[c_index]
    eta = PiecewiseConstant(c=c, t=t)
    return loglik(data, eta, t, theta, rho)

def neg_loglik(vec, path_order, unique_cfg, times, matching_indices, het_matrix, theta, rho, num_samples, deme_names):
    if (vec > jnp.array([0.001, 0.001])).any():
        return jnp.inf

    if (vec < jnp.array([0, 0])).any():
        return jnp.inf 
        
    params = _vec_to_dict_jax(vec, path_order)
    jax.debug.print("Param values: {}", jnp.array(list(params.values())))
        
    c_map = jax.vmap(lambda cfg, time: iicr_call(params=params, t=time, num_samples=dict(zip(deme_names, cfg)))["c"])(
        jnp.array(unique_cfg), times
    )

    # Batched over cfg_mat (matching_indices) and all_tmrca_spans (het_matrix) 
    batched_loglik = vmap(compute_loglik, in_axes=(0, 0, None, None, None, None))(matching_indices, het_matrix, c_map, times, theta, rho)
    loss = -jnp.sum(batched_loglik) / num_samples
    jax.debug.print("Loss: {loss}", loss=loss)
    return loss

vec = jnp.array([0.0001, 0.0001])
loss = neg_loglik(vec, path_order, unique_cfg, times, matching_indices, het_matrix, theta, rho, num_samples, deme_names)
print(loss)
grad = jax.grad(neg_loglik)(vec, path_order, unique_cfg, times, matching_indices, jnp.array(het_matrix), theta, rho, num_samples, deme_names)
print(grad)

Param values: [0.0001 0.0001]
Loss: 11591.9771971039
11591.9771971039
Param values: [0.0001 0.0001]


ERROR:2025-10-15 19:22:53,744:jax._src.callback:94: jax.pure_callback failed
Traceback (most recent call last):
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 436, in wrapped_fn_impl
    param_fn(*args, **kwargs)
  File "<@beartype(phlashlib.gpu._call_kernel) at 0x144c7980d800>", line 74, in _call_kernel
beartype.roar.BeartypeCallHintParamViolation: Function phlashlib.gpu._call_kernel() parameter grad="Array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  Tru...)" violates type hint typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool, numpy.number, bool], as <class "jaxlib._jax.ArrayImpl"> "Array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  Tru...)" not <class "numpy.number">, <class "jaxtyping.Bool[ndarray, '']">, <class "numpy.bool">, bool, or <class "jaxtyping.Bool[Array, '']">.

During handling of the abo

XlaRuntimeError: INTERNAL: CpuCallback error calling callback: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 211, in start
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/base_events.py", line 639, in run_forever
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/base_events.py", line 1985, in _run_once
  File "/home/jkliang/miniconda3/lib/python3.12/asyncio/events.py", line 88, in _run
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in process_one
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3116, in run_cell
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3171, in _run_cell
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3394, in run_cell_async
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3639, in run_ast_nodes
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3699, in run_code
  File "/tmp/ipykernel_3213029/267963253.py", line 66, in <module>
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/api.py", line 405, in grad_f
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/api.py", line 478, in value_and_grad_f
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/api.py", line 2134, in _vjp
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/ad.py", line 313, in vjp
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/ad.py", line 287, in linearize
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/ad.py", line 261, in direct_linearize
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/api_util.py", line 90, in flatten_fun_nokwargs
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 429, in _get_result_paths_thunk
  File "/tmp/ipykernel_3213029/267963253.py", line 58, in neg_loglik
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/api.py", line 1114, in vmap_f
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 644, in _batch_outer
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 660, in _batch_inner
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 342, in flatten_fun_for_vmap
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 429, in _get_result_paths_thunk
  File "/tmp/ipykernel_3213029/267963253.py", line 41, in compute_loglik
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 549, in wrapped_fn
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 473, in wrapped_fn_impl
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/phlashlib/loglik.py", line 87, in loglik
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/api.py", line 1114, in vmap_f
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 644, in _batch_outer
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 660, in _batch_inner
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 342, in flatten_fun_for_vmap
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 429, in _get_result_paths_thunk
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/custom_derivatives.py", line 747, in __call__
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/custom_derivatives.py", line 999, in bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in _true_bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/custom_derivatives.py", line 1003, in bind_with_trace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 621, in process_custom_vjp_call
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/custom_derivatives.py", line 1003, in bind_with_trace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 621, in process_custom_vjp_call
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/custom_derivatives.py", line 1003, in bind_with_trace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/ad.py", line 955, in process_custom_vjp_call
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 709, in batch_subtrace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 709, in batch_subtrace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/custom_derivatives.py", line 848, in _flatten_fwd
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 429, in _get_result_paths_thunk
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 549, in wrapped_fn
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 473, in wrapped_fn_impl
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/phlashlib/gpu.py", line 162, in _gpu_ll_fwd
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/callback.py", line 388, in pure_callback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 634, in bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in _true_bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 662, in bind_with_trace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 531, in process_primitive
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/ffi.py", line 704, in ffi_batching_rule
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 634, in bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in _true_bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 662, in bind_with_trace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/batching.py", line 531, in process_primitive
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/ffi.py", line 704, in ffi_batching_rule
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 634, in bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in _true_bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 662, in bind_with_trace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 1189, in process_primitive
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 268, in cache_miss
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 634, in bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in _true_bind
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 662, in bind_with_trace
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 1189, in process_primitive
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1829, in _pjit_call_impl
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1805, in call_impl_cache_miss
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1780, in _pjit_call_impl_python
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 364, in wrapper
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1372, in __call__
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/callback.py", line 784, in _wrapped_callback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/callback.py", line 223, in _callback
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/callback.py", line 95, in pure_callback_impl
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jax/_src/callback.py", line 70, in __call__
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 551, in wrapped_fn
  File "/home/jkliang/demesinfer_updated_env/demesinfer_private/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 470, in wrapped_fn_impl
TypeCheckError: Type-check error whilst checking the parameters of phlashlib.gpu._call_kernel.
The problem arose whilst typechecking parameter 'grad'.
Actual value: bool[5,101]
Expected type: typing.Union[Bool[Array, ''], Bool[ndarray, ''], numpy.bool, numpy.number, bool].
----------------------
Called with parameters: {
  'pp':
  PSMCParams(
    b=f64[5,101,63],
    d=f64[5,101,63],
    u=f64[5,101,63],
    v=f64[5,101,63],
    emis0=f64[5,101,63],
    emis1=f64[5,101,63],
    pi=f64[5,101,63]
  ),
  'data': i8[5,101,999],
  'grad': bool[5,101],
  'float32': bool[5,101]
}
Parameter annotations: (pp: PyTree[Float[Array, '*batch M'] | Float[ndarray, '*batch M'], 'PSMCParams'], data: Union[Int8[Array, '*batch L'], Int8[ndarray, '*batch L']], grad: Union[Bool[Array, ''], Bool[ndarray, ''], numpy.bool, numpy.number, bool], float32: Union[Bool[Array, ''], Bool[ndarray, ''], numpy.bool, numpy.number, bool]) -> Any.
The current values for each jaxtyping axis annotation are as follows.
M=63
L=999
batch=(5, 101)
The current values for each jaxtyping PyTree structure annotation are as follows.
PSMCParams=PyTreeDef(CustomNode(namedtuple[PSMCParams], [*, *, *, *, *, *, *]))