Skip to content

Commit

Permalink
Test and implement squeeze under ragged batching.
Browse files Browse the repository at this point in the history
  • Loading branch information
axch committed Jun 8, 2023
1 parent ab59fef commit 5aa6cc3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
20 changes: 20 additions & 0 deletions jax/_src/interpreters/batching.py
Expand Up @@ -123,6 +123,16 @@ def size(self):
# same length!
return len(self.ragged_axes[0][1])

def move_stacked_axis(self, dst):
# Assumes that all stored and incoming axes are already canonicalized
def move_axis(ax):
if self.stacked_axis > ax and ax >= dst:
return ax + 1
if self.stacked_axis < ax and ax <= dst:
return ax - 1
return ax
new_ragged_axes = [(move_axis(ax), sizes) for ax, sizes in self.ragged_axes]
return RaggedAxis(dst, new_ragged_axes)

def make_batch_axis(ndim, stacked_axis, ragged_axes):
if ragged_axes:
Expand Down Expand Up @@ -931,6 +941,16 @@ def _mask_one_ragged_axis(operand, ident, axis_spec):
mask = positions < limits
return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape))

def move_stacked_axis(operand, bdim, dst):
dst = canonicalize_axis(dst, operand.ndim)
if isinstance(bdim, int):
return moveaxis(operand, bdim, dst), dst
elif isinstance(bdim, RaggedAxis):
result = moveaxis(operand, bdim.stacked_axis, dst)
return result, bdim.move_stacked_axis(dst)
else:
raise TypeError("Unrecognized batch dimension type {}".format(bdim))

### general utilities for manipulating axes on jaxpr types (not vmappables)

def broadcast(x, sz, axis):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
Expand Up @@ -3232,9 +3232,9 @@ def _squeeze_transpose_rule(t, operand, *, dimensions):
def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
operand, = batched_args
bdim, = batch_dims
operand = batching.moveaxis(operand, bdim, 0)
operand, bdim_out = batching.move_stacked_axis(operand, bdim, 0)
dimensions = tuple(np.add(1, dimensions))
return squeeze(operand, dimensions=dimensions), 0
return squeeze(operand, dimensions=dimensions), bdim_out

squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
'squeeze')
Expand Down
12 changes: 12 additions & 0 deletions tests/dynamic_api_test.py
Expand Up @@ -1574,6 +1574,18 @@ def func(size1, size2):
data = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1)
self.assertAllClose(p.data, data)

def test_squeeze_ragged(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def func(size):
one_d = jnp.arange(size, dtype='int32')
two_d = jax.lax.broadcast_in_dim(one_d, (size, 1), (0,))
one_again = jax.lax.squeeze(two_d, dimensions=[1])
return one_again
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data)

def pile_map(f):
def mapped(*piles):
return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,
Expand Down

0 comments on commit 5aa6cc3

Please sign in to comment.