diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index d8c84779b6c5..7df20bcef0cd 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -42,11 +42,11 @@
 zip, unsafe_zip = safe_zip, zip
 
 
-# Piles
+# Jumbles
 
 # i:(Fin 3) => f32[[3, 1, 4].i]
 @dataclasses.dataclass(frozen=True)
-class PileTy:
+class JumbleTy:
   binder: core.Var
   length: Union[int, Tracer, core.Var]
   elt_ty: core.DShapedArray
@@ -63,41 +63,41 @@ def __repr__(self) -> str:
     return f'{str(self.lengths)}.Var{id(self.idx)}'
   replace = dataclasses.replace
 
-# Pile(aval=a:3 => f32[[3 1 4].a],
-#      data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
+# Jumble(aval=a:3 => f32[[3 1 4].a],
+#        data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
 @dataclasses.dataclass(frozen=True)
-class Pile:
-  aval: PileTy
+class Jumble:
+  aval: JumbleTy
   data: Array
 
-# To vmap over a pile, one must specify the axis as PileAxis.
-class PileAxis: pass
-pile_axis = PileAxis()
+# To vmap over a jumble, one must specify the axis as JumbleAxis.
+class JumbleAxis: pass
+jumble_axis = JumbleAxis()
 
 # As a temporary measure before we have more general JITable / ADable interfaces
-# (analogues to vmappable), to enable Piles to be used with other
+# (analogues to vmappable), to enable Jumbles to be used with other
 # transformations and higher-order primitives (primarily jit, though also grad
 # with allow_int=True) we register them as pytrees.
 # TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
-def _pile_flatten(pile):
+def _jumble_flatten(jumble):
   lengths = []
   new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths))
                if type(d) is IndexedAxisSize else d
-               for d in pile.aval.elt_ty.shape]
-  elt_ty = pile.aval.elt_ty.update(shape=tuple(new_shape))
-  aval = pile.aval.replace(elt_ty=elt_ty)
-  return (lengths, pile.data), aval
-def _pile_unflatten(aval, x):
+               for d in jumble.aval.elt_ty.shape]
+  elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape))
+  aval = jumble.aval.replace(elt_ty=elt_ty)
+  return (lengths, jumble.data), aval
+def _jumble_unflatten(aval, x):
   lengths, data = x
   new_shape = [d.replace(lengths=lengths[d.lengths - 1])
                if type(d) is IndexedAxisSize else d
                for d in aval.elt_ty.shape]
   elt_ty = aval.elt_ty.update(shape=tuple(new_shape))
   aval = aval.replace(elt_ty=elt_ty)
-  return Pile(aval, data)
-register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
+  return Jumble(aval, data)
+register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten)
 
-def _pile_result(axis_size, stacked_axis, ragged_axes, x):
+def _jumble_result(axis_size, stacked_axis, ragged_axes, x):
   binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
   if stacked_axis != 0:
     raise NotImplementedError  # TODO Transpose x so the stacked axis is axis 0
@@ -106,14 +106,14 @@ def _pile_result(axis_size, stacked_axis, ragged_axes, x):
   for ragged_axis, segment_lens in ragged_axes:
     shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens)
   elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
-  return Pile(PileTy(binder, axis_size, elt_ty), x)
+  return Jumble(JumbleTy(binder, axis_size, elt_ty), x)
 
 @dataclasses.dataclass(frozen=True)
 class RaggedAxis:
   stacked_axis: int
   # For each axis, we store its index and the corresponding segment lengths.
-  # For example, the pile i:(Fin 3) => f32[lens1.i, 7, lens2.i]
+  # For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
   # would be represented with ragged_axes = [(1, lens1), (3, lens2)]
   ragged_axes: tuple[tuple[int, Array], ...]
 
@@ -234,9 +234,9 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
   handler = to_elt_handlers.get(type(x))
   if handler:
     return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
-  elif type(x) is Pile:
-    if spec is not pile_axis:
-      raise TypeError("pile input without using pile_axis in_axes spec")
+  elif type(x) is Jumble:
+    if spec is not jumble_axis:
+      raise TypeError("jumble input without using jumble_axis in_axes spec")
     ias: IndexedAxisSize  # Not present in the AxisSize union in core.py
     (d, ias), = ((i, sz)  # type: ignore
                   for i, sz in enumerate(x.aval.elt_ty.shape)
@@ -259,10 +259,10 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
   x_ = trace.full_raise(x)
   val, bdim = x_.val, x_.batch_dim
   if type(bdim) is RaggedAxis:
-    if spec is not pile_axis:
+    if spec is not jumble_axis:
       # TODO(mattjj): improve this error message
-      raise TypeError("ragged output without using pile_axis out_axes spec")
-    return _pile_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
+      raise TypeError("ragged output without using jumble_axis out_axes spec")
+    return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
   else:
     return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
 from_elt_handlers: dict[type, FromEltHandler] = {}
@@ -284,7 +284,7 @@ def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
   from_elt_handlers[data_type] = from_elt
   if make_iota: make_iota_handlers[axis_size_type] = make_iota
 vmappables: dict[type, tuple[type, type]] = {}
-spec_types: set[type] = {PileAxis}
+spec_types: set[type] = {JumbleAxis}
 
 def unregister_vmappable(data_type: type) -> None:
   spec_type, axis_size_type = vmappables.pop(data_type)
@@ -295,7 +295,7 @@ def unregister_vmappable(data_type: type) -> None:
     del make_iota_handlers[axis_size_type]
 
 def is_vmappable(x: Any) -> bool:
-  return type(x) is Pile or type(x) in vmappables
+  return type(x) is Jumble or type(x) in vmappables
 
 @lu.transformation_with_aux
 def flatten_fun_for_vmap(in_tree, *args_flat):
@@ -1089,12 +1089,12 @@ def broadcast(x, sz, axis):
   return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
 
 def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
-  if dst == pile_axis:
+  if dst == jumble_axis:
     x = bdim_at_front(x, src, sz)
     elt_ty = x.aval.update(shape=x.shape[1:])
-    aval = PileTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
-                  x.shape[0], elt_ty)
-    return Pile(aval, x)
+    aval = JumbleTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
+                    x.shape[0], elt_ty)
+    return Jumble(aval, x)
   try:
     _ = core.get_aval(x)
   except TypeError as e:
diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py
index 04382a30cdb9..8bc3a3e94d7e 100644
--- a/jax/interpreters/batching.py
+++ b/jax/interpreters/batching.py
@@ -29,9 +29,9 @@
   MakeIotaHandler as MakeIotaHandler,
   MapSpec as MapSpec,
   NotMapped as NotMapped,
-  Pile as Pile,
-  PileAxis as PileAxis,
-  PileTy as PileTy,
+  Jumble as Jumble,
+  JumbleAxis as JumbleAxis,
+  JumbleTy as JumbleTy,
   ToEltHandler as ToEltHandler,
   Vmappable as Vmappable,
   Zero as Zero,
@@ -60,7 +60,7 @@
   matchaxis as matchaxis,
   moveaxis as moveaxis,
   not_mapped as not_mapped,
-  pile_axis as pile_axis,
+  jumble_axis as jumble_axis,
   primitive_batchers as primitive_batchers,
   reducer_batcher as reducer_batcher,
   register_vmappable as register_vmappable,
diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py
index 6a4d0dd5fd36..0c3b45150b5f 100644
--- a/tests/dynamic_api_test.py
+++ b/tests/dynamic_api_test.py
@@ -1490,68 +1490,69 @@ def f(i):
 
 @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow",
                  jax_traceback_filtering='off')
-class PileTest(jtu.JaxTestCase):
+class JumbleTest(jtu.JaxTestCase):
 
   @parameterized.parameters((True,), (False,))
-  def test_internal_pile(self, disable_jit):
+  def test_internal_jumble(self, disable_jit):
     config.update('jax_disable_jit', disable_jit)
     ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
     xs = jax.vmap(lambda n: jax.lax.iota('int32', n).sum())(ins)
     self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False)
 
-  def test_pile_escapes(self):
+  def test_jumble_escapes(self):
     ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
     xs = jax.vmap(jax.jit(lambda n: jax.lax.iota('int32', n)),
-                  out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(xs, batching.Pile)
+                  out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(xs, batching.Jumble)
     data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
     self.assertAllClose(xs.data, data, check_dtypes=False)
 
-  def test_make_pile_from_dynamic_shape(self):
-    # We may not want to support returning piles from vmapped functions (instead
-    # preferring to have a separate API which allows piles). But for now it
-    # makes for a convenient way to construct piles for the other tests!
+  def test_make_jumble_from_dynamic_shape(self):
+    # We may not want to support returning jumbles from vmapped functions
+    # (instead preferring to have a separate API which allows jumbles). But for
+    # now it makes for a convenient way to construct jumbles for the other
+    # tests!
     ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
     p = jax.vmap(partial(jnp.arange, dtype='int32'),
-                 out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(p, batching.Pile)
+                 out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(p, batching.Jumble)
     self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
     data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
     self.assertAllClose(p.data, data, check_dtypes=False)
 
-  def test_pile_map_eltwise(self):
+  def test_jumble_map_eltwise(self):
     ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
     p = jax.vmap(partial(jnp.arange, dtype='int32'),
-                 out_axes=batching.pile_axis)(ins)
-    p = pile_map(jax.jit(lambda x: x ** 2))(p)
-    self.assertIsInstance(p, batching.Pile)
+                 out_axes=batching.jumble_axis)(ins)
+    p = jumble_map(jax.jit(lambda x: x ** 2))(p)
+    self.assertIsInstance(p, batching.Jumble)
     self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
     data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
     self.assertAllClose(p.data, data, check_dtypes=False)
 
-  def test_pile_map_vector_dot(self):
+  def test_jumble_map_vector_dot(self):
     ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
     p = jax.vmap(partial(jnp.arange, dtype='int32'),
-                 out_axes=batching.pile_axis)(ins)
-    y = pile_map(jnp.dot)(p, p)
-    self.assertIsInstance(y, batching.Pile)
+                 out_axes=batching.jumble_axis)(ins)
+    y = jumble_map(jnp.dot)(p, p)
+    self.assertIsInstance(y, batching.Jumble)
     self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32'))
 
   @parameterized.parameters((True,), (False,))
-  def test_pile_map_matrix_dot_ragged_contract(self, disable_jit):
+  def test_jumble_map_matrix_dot_ragged_contract(self, disable_jit):
     config.update('jax_disable_jit', disable_jit)
     sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
-    p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.pile_axis
+    p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.jumble_axis
                   )(sizes)
-    p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.pile_axis
+    p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.jumble_axis
                   )(sizes)
-    y = jax.vmap(jnp.dot, in_axes=batching.pile_axis, out_axes=0,
+    y = jax.vmap(jnp.dot, in_axes=batching.jumble_axis, out_axes=0,
                  axis_size=3)(p1, p2)
     self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)),
                         check_dtypes=False)
 
   @parameterized.parameters((True,), (False,))
-  def test_pile_map_matrix_dot_ragged_tensor(self, disable_jit):
+  def test_jumble_map_matrix_dot_ragged_tensor(self, disable_jit):
     config.update('jax_disable_jit', disable_jit)
     sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
     def func(size):
@@ -1559,8 +1560,8 @@ def func(size):
       lhs_two_d = jax.lax.broadcast_in_dim(lhs_one_d, (size, 2), (0,))
       rhs = jax.lax.broadcasted_iota('int32', (2, 4), 0) + 1
       return jnp.dot(lhs_two_d, rhs)
-    p = jax.vmap(func, out_axes=batching.pile_axis)(sizes)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(sizes)
+    self.assertIsInstance(p, batching.Jumble)
     self.assertEqual(p.data.shape, (3, 5, 4))
 
   def test_broadcast_in_dim_while_ragged(self):
@@ -1569,8 +1570,8 @@ def func(size):
       one_d = jnp.arange(size, dtype='int32')
       two_d = jax.lax.broadcast_in_dim(one_d, (size, 7), (0,))
       return two_d
-    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(p, batching.Jumble)
     data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
     self.assertAllClose(p.data, data)
 
@@ -1580,8 +1581,8 @@ def func(size):
       one_d = jnp.arange(12, dtype='int32')
       two_d = jax.lax.broadcast_in_dim(one_d, (size, 12), (1,))
       return two_d
-    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(p, batching.Jumble)
     data = jax.lax.broadcasted_iota('int32', (3, 5, 12), 2)
     self.assertAllClose(p.data, data)
 
@@ -1595,7 +1596,7 @@ def func(size):
       return two_d
     msg = r"got operand of shape \(\[dynamic\],\), target broadcast shape \(4, 5\)"
     with self.assertRaisesRegex(TypeError, msg):
-      jax.vmap(func, out_axes=batching.pile_axis)(ins)
+      jax.vmap(func, out_axes=batching.jumble_axis)(ins)
 
   def test_broadcast_in_dim_to_doubly_ragged(self):
     ins1 = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
@@ -1604,8 +1605,8 @@ def func(size1, size2):
      one_d = jnp.arange(size1, dtype='int32')
       two_d = jax.lax.broadcast_in_dim(one_d, (size1, size2), (0,))
       return two_d
-    p = jax.vmap(func, out_axes=batching.pile_axis)(ins1, ins2)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(ins1, ins2)
+    self.assertIsInstance(p, batching.Jumble)
     data = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1)
     self.assertAllClose(p.data, data)
 
@@ -1616,8 +1617,8 @@ def func(size):
      two_d = jax.lax.broadcast_in_dim(one_d, (size, 1), (0,))
       one_again = jax.lax.squeeze(two_d, dimensions=[1])
       return one_again
-    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(p, batching.Jumble)
     data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
     self.assertAllClose(p.data, data)
 
@@ -1627,8 +1628,8 @@ def func(size):
       one_d = jnp.arange(size, dtype='int32')
       two_d = jnp.broadcast_to(one_d, (4, size))
       return two_d
-    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(p, batching.Jumble)
     data = jax.lax.broadcasted_iota('int32', (3, 4, 5), 2)
     self.assertAllClose(p.data, data)
 
@@ -1638,8 +1639,8 @@ def func(size):
       one_d = jnp.arange(size, dtype='int32')
       two_d = jnp.broadcast_to(one_d, (size, size))
       return two_d
-    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(p, batching.Jumble)
     data = jax.lax.broadcasted_iota('int32', (3, 5, 5), 2)
     self.assertAllClose(p.data, data)
 
@@ -1649,8 +1650,8 @@ def func(size):
       one_d = jnp.arange(size, dtype='int32')
       two_d = jnp.broadcast_to(one_d, (7, size))
       return jnp.transpose(two_d, [1, 0])
-    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(p, batching.Jumble)
     data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
     self.assertAllClose(p.data, data)
 
@@ -1662,8 +1663,8 @@ def fprop_layer(x_size):
       wqkv = jax.lax.broadcasted_iota('int32', (3, 2, 7, 11), 1)
       qkv = jnp.einsum('te,ihqe->ithq', x, wqkv)
       return qkv
-    p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(x_sizes)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(x_sizes)
+    self.assertIsInstance(p, batching.Jumble)
     self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
     self.assertEqual(p.data.shape, (3, 3, 5, 2, 7))
 
@@ -1677,8 +1678,8 @@ def fprop_layer(ragged_size):
      v = jax.lax.broadcast_in_dim(one_d, (ragged_size, 2, 7), [0])
       inner = jnp.einsum('tsh,shq->thq', alpha, v)
       return inner
-    p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(ragged_sizes)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(ragged_sizes)
+    self.assertIsInstance(p, batching.Jumble)
     self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
     self.assertEqual(p.data.shape, (3, 5, 2, 7))
 
@@ -1689,14 +1690,14 @@ def func(size):
       two_d = jnp.broadcast_to(one_d, (2, size))
       part_1, part_2 = two_d
       return part_1
-    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
-    self.assertIsInstance(p, batching.Pile)
+    p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
+    self.assertIsInstance(p, batching.Jumble)
     self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
     data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
     self.assertAllClose(p.data, data)
 
   @parameterized.parameters((True,), (False,))
-  def test_pile_map_end_to_end_fprop_layer(self, disable_jit):
+  def test_jumble_map_end_to_end_fprop_layer(self, disable_jit):
     config.update('jax_disable_jit', disable_jit)
 
     def fprop_layer(params, x):
@@ -1731,13 +1732,12 @@ def fprop_layer(params, x):
         jnp.zeros((420, 128)),
     ]
 
-    def pile_stack(xs: list[jax.Array]) -> batching.Pile:
+    def jumble_stack(xs: list[jax.Array]) -> batching.Jumble:
       max_length = max(len(x) for x in xs)
       lengths = jnp.array([len(x) for x in xs])
       lengths = jax.lax.convert_element_type(lengths, core.bint(max_length))
       xs_padded = jnp.stack([jnp.zeros((max_length, 128), dtype=x.dtype
                              ).at[:x.shape[0]].set(x) for x in xs])
-      # jax.vmap(lambda l, xp: xp[:l, :], out_axes=pile_axis)(lengths, xs_padded)
 
       # binder = i
       binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
@@ -1745,26 +1745,26 @@ def pile_stack(xs: list[jax.Array]) -> batching.Pile:
       elt_ty = core.DShapedArray((batching.IndexedAxisSize(binder, lengths), 128),
                                  xs_padded.dtype)
       # aval = i:(Fin 3) => f32[[3, 1, 4].i, 128]
-      aval = batching.PileTy(binder, len(xs), elt_ty)
-      xs_pile = batching.Pile(aval, xs_padded)
-      return xs_pile
+      aval = batching.JumbleTy(binder, len(xs), elt_ty)
+      xs_jumble = batching.Jumble(aval, xs_padded)
+      return xs_jumble
 
-    xs_pile = pile_stack(xs)
+    xs_jumble = jumble_stack(xs)
 
     fprop_batched = jax.vmap(fprop_layer,
-                             in_axes=(None, batching.pile_axis),
-                             out_axes=batching.pile_axis,
+                             in_axes=(None, batching.jumble_axis),
+                             out_axes=batching.jumble_axis,
                              axis_size=3)
-
-    result_pile = fprop_batched(params, xs_pile)
-    self.assertIsInstance(result_pile, batching.Pile)
-    self.assertRegex(str(result_pile.aval), r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]')
-    self.assertAllClose(result_pile.data.shape, (3, 512, 128))
-
-def pile_map(f):
-  def mapped(*piles):
-    return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,
-                    axis_size=piles[0].aval.length)(*piles)
+    result_jumble = fprop_batched(params, xs_jumble)
+    self.assertIsInstance(result_jumble, batching.Jumble)
+    regex = r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]'
+    self.assertRegex(str(result_jumble.aval), regex)
+    self.assertAllClose(result_jumble.data.shape, (3, 512, 128))
+
+def jumble_map(f):
+  def mapped(*jumbles):
+    return jax.vmap(f, in_axes=batching.jumble_axis, out_axes=batching.jumble_axis,
+                    axis_size=jumbles[0].aval.length)(*jumbles)
   return mapped
 
 if __name__ == '__main__':
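
The short sketch below is not part of the patch; it is a hedged usage example assembled from the renamed tests above (test_make_jumble_from_dynamic_shape, test_jumble_map_vector_dot, and the jumble_map helper). It shows what the rename amounts to in practice: vmapping over bint-typed ragged sizes with out_axes=batching.jumble_axis yields a batching.Jumble, which can then be mapped over again by passing jumble_axis as in_axes/out_axes. It relies on internal modules (jax._src.core) and the experimental jax_dynamic_shapes flag used by JumbleTest, so treat it as illustrative rather than a supported public API.

# Minimal sketch mirroring JumbleTest above; assumes the experimental
# jax_dynamic_shapes flag and the internal jax._src.core module shown in the diff.
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import core
from jax.interpreters import batching  # re-exports Jumble / jumble_axis per the diff

jax.config.update('jax_dynamic_shapes', True)

# Ragged lengths 3, 1, 4 with a static bound of 5, as in the tests.
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

# vmap with out_axes=jumble_axis packages the ragged per-element outputs
# (arange of length 3, 1, 4) into a single Jumble.
p = jax.vmap(partial(jnp.arange, dtype='int32'),
             out_axes=batching.jumble_axis)(ins)
assert isinstance(p, batching.Jumble)

# Mapping over the jumble, in the style of the jumble_map helper in the test:
# jumble_axis on both in_axes and out_axes, with an explicit axis_size.
y = jax.vmap(jnp.dot,
             in_axes=batching.jumble_axis,
             out_axes=batching.jumble_axis,
             axis_size=p.aval.length)(p, p)
assert isinstance(y, batching.Jumble)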