Skip to content

Commit

Permalink
Teach make_batch_axis to canonicalize the axes as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
axch committed Jun 8, 2023
1 parent 611c12b commit ab59fef
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
13 changes: 7 additions & 6 deletions jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ def size(self):
return len(self.ragged_axes[0][1])


def make_batch_axis(ndim, stacked_axis, ragged_axes):
  """Build a batch-axis descriptor with all axis indices canonicalized.

  Args:
    ndim: rank of the array being batched; used to canonicalize possibly
      negative axis indices into the range [0, ndim).
    stacked_axis: the (possibly negative) axis along which elements are
      stacked.
    ragged_axes: sequence of ``(axis, sizes)`` pairs describing ragged
      dimensions, where ``axis`` may be negative.

  Returns:
    A ``RaggedAxis`` (with canonicalized stacked and ragged axes) when
    ``ragged_axes`` is non-empty; otherwise the canonicalized stacked axis
    as a plain integer.
  """
  if ragged_axes:
    # Canonicalize each ragged axis index as well as the stacked axis, so
    # downstream consumers never see negative axis values.
    canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
    return RaggedAxis(canonicalize_axis(stacked_axis, ndim), canonical)
  else:
    return canonicalize_axis(stacked_axis, ndim)


def _update_annotation(
Expand Down Expand Up @@ -196,8 +197,8 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
raise TypeError("pile input without using pile_axis in_axes spec")
(d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
if type(sz) is IndexedAxisSize)
return BatchTracer(
trace, x.data, make_batch_axis(0, [(d+1, ias.lengths)])) # type: ignore
batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)])
return BatchTracer(trace, x.data, batch_axis) # type: ignore
elif isinstance(spec, int) or spec is None:
spec = spec and canonicalize_axis(spec, len(np.shape(x)))
return (BatchTracer(trace, x, spec, source_info_util.current())
Expand Down Expand Up @@ -905,7 +906,7 @@ def out_axis(axes, axis):
operand = mask_ragged_axes(
operand, ident, RaggedAxis(bdim.stacked_axis, axes_to_mask))
result = prim.bind(operand, axes=axes, **params)
return result, make_batch_axis(bdim_out, ragged_axes_out)
return result, make_batch_axis(operand.ndim, bdim_out, ragged_axes_out)
else:
assert False

Expand Down
3 changes: 2 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2859,7 +2859,8 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits)
result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions)
out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
out_bdim = batching.make_batch_axis(0, zip(out_ragged_axes, out_ragged_sizes))
out_bdim = batching.make_batch_axis(
result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes))
return result, out_bdim

def _broadcast_in_dim_fwd_rule(eqn):
Expand Down

0 comments on commit ab59fef

Please sign in to comment.