Skip to content

Commit

Permalink
Small fix to scan type-check error message.
Browse files Browse the repository at this point in the history
  • Loading branch information
LenaMartens committed Dec 9, 2022
1 parent 440b25b commit 7fe466c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
Expand Up @@ -965,7 +965,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry
if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)):
raise core.JaxprTypeError(
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
f'called with sequence of type\n{_avals_short(x_avals)}')
f'called with sequence whose items have type\n{_avals_short(x_avals_mapped)}')
return [*init_avals, *y_avals], jaxpr.effects

def _scan_pp_rule(eqn, context, settings):
Expand Down

0 comments on commit 7fe466c

Please sign in to comment.