Skip to content

Commit

Permalink
Make empirical NTK support list of primitive outputs of length 1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 502904434
  • Loading branch information
romanngg committed Feb 10, 2023
1 parent db1c240 commit 5854a14
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,7 +1991,7 @@ def _backprop_step(
"""Adapted from `jax.interpreters.ad`."""
invals = map(functools.partial(_read_primal, primal_env), eqn.invars)
cts_in = map(read_cotangent, eqn.outvars)
if not eqn.primitive.multiple_results:
if len(cts_in) == 1:
cts_in = cts_in[0]
else:
raise NotImplementedError(
Expand Down

0 comments on commit 5854a14

Please sign in to comment.