From 5aa6cc354278f47d992b1fe2845d24ab090a0c9f Mon Sep 17 00:00:00 2001 From: Alexey Radul Date: Tue, 6 Jun 2023 09:45:32 -0400 Subject: [PATCH] Test and implement squeeze under ragged batching. --- jax/_src/interpreters/batching.py | 20 ++++++++++++++++++++ jax/_src/lax/lax.py | 4 ++-- tests/dynamic_api_test.py | 12 ++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index e1ebd88a493d..2c67230c324c 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -123,6 +123,16 @@ def size(self): # same length! return len(self.ragged_axes[0][1]) + def move_stacked_axis(self, dst): + # Assumes that all stored and incoming axes are already canonicalized + def move_axis(ax): + if self.stacked_axis > ax and ax >= dst: + return ax + 1 + if self.stacked_axis < ax and ax <= dst: + return ax - 1 + return ax + new_ragged_axes = [(move_axis(ax), sizes) for ax, sizes in self.ragged_axes] + return RaggedAxis(dst, new_ragged_axes) def make_batch_axis(ndim, stacked_axis, ragged_axes): if ragged_axes: @@ -931,6 +941,16 @@ def _mask_one_ragged_axis(operand, ident, axis_spec): mask = positions < limits return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape)) +def move_stacked_axis(operand, bdim, dst): + dst = canonicalize_axis(dst, operand.ndim) + if isinstance(bdim, int): + return moveaxis(operand, bdim, dst), dst + elif isinstance(bdim, RaggedAxis): + result = moveaxis(operand, bdim.stacked_axis, dst) + return result, bdim.move_stacked_axis(dst) + else: + raise TypeError("Unrecognized batch dimension type {}".format(bdim)) + ### general utilities for manipulating axes on jaxpr types (not vmappables) def broadcast(x, sz, axis): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c8470f687180..6970a75ad9aa 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3232,9 +3232,9 @@ def _squeeze_transpose_rule(t, operand, *, dimensions): def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): operand, = batched_args bdim, = batch_dims - operand = batching.moveaxis(operand, bdim, 0) + operand, bdim_out = batching.move_stacked_axis(operand, bdim, 0) dimensions = tuple(np.add(1, dimensions)) - return squeeze(operand, dimensions=dimensions), 0 + return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, 'squeeze') diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index e98411594634..99d5260bc7f1 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -1574,6 +1574,18 @@ def func(size1, size2): data = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1) self.assertAllClose(p.data, data) + def test_squeeze_ragged(self): + ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) + def func(size): + one_d = jnp.arange(size, dtype='int32') + two_d = jax.lax.broadcast_in_dim(one_d, (size, 1), (0,)) + one_again = jax.lax.squeeze(two_d, dimensions=[1]) + return one_again + p = jax.vmap(func, out_axes=batching.pile_axis)(ins) + self.assertIsInstance(p, batching.Pile) + data = jax.lax.broadcasted_iota('int32', (3, 5), 1) + self.assertAllClose(p.data, data) + def pile_map(f): def mapped(*piles): return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,