Skip to content

Commit

Permalink
Fix silly type error involving dims_out sometimes being a thunk and…
Browse files Browse the repository at this point in the history
… sometimes not.

PiperOrigin-RevId: 548343565
  • Loading branch information
axch authored and jax authors committed Jul 15, 2023
1 parent 651f877 commit cd39128
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def process_call(self, call_primitive, f, tracers, params):
vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out())
src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())]
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]

def post_process_call(self, call_primitive, out_tracers, params):
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
Expand Down

0 comments on commit cd39128

Please sign in to comment.