diff --git a/neural_tangents/_src/empirical.py b/neural_tangents/_src/empirical.py index 9009012..1c41abf 100644 --- a/neural_tangents/_src/empirical.py +++ b/neural_tangents/_src/empirical.py @@ -125,6 +125,7 @@ from jax.core import Var from jax.extend import linear_util as lu +from jax.extend import source_info_util from jax.interpreters import ad from jax.interpreters.ad import UndefinedPrimal @@ -1863,7 +1864,7 @@ def read_cotangent(v: Var) -> Union[jnp.ndarray, Zero]: map(functools.partial(_write_primal, primal_env), jaxpr.invars, primals_in) ct_env: dict[Var, jnp.ndarray] = {} - ctx = ad.source_info_util.transform_name_stack('transpose') + ctx = source_info_util.transform_name_stack('transpose') with ctx: map(functools.partial(_write_cotangent, 'outvars', ct_env), jaxpr.outvars, cotangents_in) @@ -2174,10 +2175,10 @@ def _eqn_vjp_fn( # Identity function return cts_in, - name_stack = (ad.source_info_util.current_name_stack() + + name_stack = (source_info_util.current_name_stack() + eqn.source_info.name_stack) - with ad.source_info_util.user_context(eqn.source_info.traceback, - name_stack=name_stack): + with source_info_util.user_context(eqn.source_info.traceback, + name_stack=name_stack): if eqn.primitive.call_primitive or eqn.primitive.map_primitive: cts_in_avals = [v.aval for v in eqn.outvars] params = dict(eqn.params) @@ -2263,7 +2264,7 @@ def _write_cotangent( return ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct - if ad.config.jax_enable_checks: + if jax.config.jax_enable_checks: ct_aval = core.get_aval(ct_env[v]) joined_aval = core.lattice_join( v.aval, ct_aval).strip_weak_type().strip_named_shape()