Skip to content

Commit

Permalink
Test and implement broadcast_in_dim for all permutations of ragged axes.
Browse files Browse the repository at this point in the history
Add tests for
- Broadcasting an already-ragged array
- Broadcasting that creates an array that's ragged in two dimensions
  • Loading branch information
axch committed Jun 8, 2023
1 parent 0119945 commit 611c12b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 17 deletions.
64 changes: 47 additions & 17 deletions jax/_src/lax/lax.py
Expand Up @@ -777,7 +777,8 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
operand: an array
shape: the shape of the target array
broadcast_dimensions: to which dimension in the target shape each dimension
of the operand shape corresponds to
of the operand shape corresponds to. That is, dimension i of the operand
becomes dimension broadcast_dimensions[i] of the result.
Returns:
An array containing the result.
Expand Down Expand Up @@ -2811,26 +2812,55 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,

def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
                                 broadcast_dimensions):
  """Batching (vmap) rule for `broadcast_in_dim`, supporting ragged axes.

  Args:
    batched_args: `(operand, *dyn_shape)`.  `dyn_shape` is the dynamic
      portion of the target shape (one entry per dynamic dimension).
    batch_dims: batch-axis specs for each of `batched_args`.  The operand's
      spec may be a plain int axis or a `RaggedAxis`.
    shape: the target shape, with `None` standing in for each dynamic
      section (whose actual sizes arrive via `dyn_shape`).
    broadcast_dimensions: dimension i of the operand becomes dimension
      broadcast_dimensions[i] of the result.

  Returns:
    A pair of the batched result and its output batch-axis spec, which
    records any axes made ragged by broadcasting to dynamic sizes.
  """
  operand, *dyn_shape = batched_args
  operand_bdim, *dyn_shape_bdims = batch_dims

  # Move the operand's stacked (batch) axis to the front, remembering the
  # stacked size so it can be cross-checked against any ragged size arrays
  # encountered below.
  stacked_size = None
  if operand_bdim is not None:
    if isinstance(operand_bdim, RaggedAxis):
      stacked_axis = operand_bdim.stacked_axis
    else:
      stacked_axis = operand_bdim
    new_operand = batching.moveaxis(operand, stacked_axis, 0)
    if isinstance(operand_bdim, RaggedAxis):
      stacked_size = operand_bdim.size
    else:
      stacked_size = operand.shape[stacked_axis]
    # All original broadcast targets shift right by one to make room for
    # the new leading stacked axis.
    new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
  else:
    new_operand = operand
    new_broadcast_dimensions = tuple(np.add(1, broadcast_dimensions))

  # TODO(reviewer) This section assumes that the shape of the operand
  # is broadcast-compatible with the requested shape.  Where should
  # that be checked, and what should the new rules be, in light of
  # raggedness?
  # Resolve each dynamic target dimension to its static upper bound, and
  # collect the per-example size arrays that make the output ragged.
  dyn_limits = []
  out_ragged_sizes = []
  for sizes, bdim in zip(dyn_shape, dyn_shape_bdims):
    if bdim is None:
      # TODO(mattjj,axch) Is this what bdim == None means?
      assert isinstance(sizes, int)
      bound = sizes
    else:
      bound = sizes.dtype.bound
      out_ragged_sizes.append(sizes)
      if stacked_size is None:
        stacked_size = len(sizes)
      else:
        msg = "All segment lengths arrays must be the same length"
        assert len(sizes) == stacked_size, msg
    dyn_limits.append(bound)
  new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits)
  result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions)
  # Output dimensions that were dynamic in `shape` are the ragged ones,
  # shifted by one for the new leading stacked axis.
  out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
  out_bdim = batching.make_batch_axis(0, zip(out_ragged_axes, out_ragged_sizes))
  return result, out_bdim

def _broadcast_in_dim_fwd_rule(eqn):
v, *dyn = eqn.invars
Expand Down
34 changes: 34 additions & 0 deletions tests/dynamic_api_test.py
Expand Up @@ -1540,6 +1540,40 @@ def test_pile_map_matrix_dot(self):
self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)),
check_dtypes=False)

def test_broadcast_while_ragged(self):
  # Broadcast an already-ragged 1-D array to 2-D: each segment of length
  # `size` (bounded above by 5) gains a trailing axis of 7.
  sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

  def expand(size):
    row = jnp.arange(size, dtype='int32')
    return jax.lax.broadcast_in_dim(row, (size, 7), (0,))

  piled = jax.vmap(expand, out_axes=batching.pile_axis)(sizes)
  self.assertIsInstance(piled, batching.Pile)
  expected = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
  self.assertAllClose(piled.data, expected)

def test_broadcast_to_ragged(self):
  # Broadcast a static 1-D array to a shape whose leading dimension is
  # ragged, creating raggedness that the input did not have.
  sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

  def expand(size):
    row = jnp.arange(12, dtype='int32')
    return jax.lax.broadcast_in_dim(row, (size, 12), (1,))

  piled = jax.vmap(expand, out_axes=batching.pile_axis)(sizes)
  self.assertIsInstance(piled, batching.Pile)
  expected = jax.lax.broadcasted_iota('int32', (3, 5, 12), 2)
  self.assertAllClose(piled.data, expected)

def test_broadcast_to_doubly_ragged(self):
  # Broadcast a ragged 1-D array to a shape that is ragged in BOTH
  # dimensions: axis 0 from the input's raggedness (bound 5), axis 1 from
  # a second, independent dynamic size (bound 6).
  sizes_a = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
  sizes_b = lax.convert_element_type(jnp.array([2, 5, 1]), core.bint(6))

  def expand(size1, size2):
    row = jnp.arange(size1, dtype='int32')
    return jax.lax.broadcast_in_dim(row, (size1, size2), (0,))

  piled = jax.vmap(expand, out_axes=batching.pile_axis)(sizes_a, sizes_b)
  self.assertIsInstance(piled, batching.Pile)
  expected = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1)
  self.assertAllClose(piled.data, expected)

def pile_map(f):
def mapped(*piles):
return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,
Expand Down

0 comments on commit 611c12b

Please sign in to comment.