
Rename Piles to Jumbles, to avoid unfortunate Imperial entanglements.
axch committed Jul 14, 2023
1 parent f348366 commit fbb5872
Showing 3 changed files with 103 additions and 103 deletions.
66 changes: 33 additions & 33 deletions jax/_src/interpreters/batching.py
@@ -42,11 +42,11 @@
zip, unsafe_zip = safe_zip, zip


# Piles
# Jumbles

# i:(Fin 3) => f32[[3, 1, 4].i]
@dataclasses.dataclass(frozen=True)
class PileTy:
class JumbleTy:
binder: core.Var
length: Union[int, Tracer, core.Var]
elt_ty: core.DShapedArray
@@ -63,41 +63,41 @@ def __repr__(self) -> str:
return f'{str(self.lengths)}.Var{id(self.idx)}'
replace = dataclasses.replace

# Pile(aval=a:3 => f32[[3 1 4].a],
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
# Jumble(aval=a:3 => f32[[3 1 4].a],
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
@dataclasses.dataclass(frozen=True)
class Pile:
aval: PileTy
class Jumble:
aval: JumbleTy
data: Array

# To vmap over a pile, one must specify the axis as PileAxis.
class PileAxis: pass
pile_axis = PileAxis()
# To vmap over a jumble, one must specify the axis as JumbleAxis.
class JumbleAxis: pass
jumble_axis = JumbleAxis()

# As a temporary measure before we have more general JITable / ADable interfaces
# (analogues to vmappable), to enable Piles to be used with other
# (analogues to vmappable), to enable Jumbles to be used with other
# transformations and higher-order primitives (primarily jit, though also grad
# with allow_int=True) we register them as pytrees.
# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
def _pile_flatten(pile):
def _jumble_flatten(jumble):
lengths = []
new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths))
if type(d) is IndexedAxisSize else d
for d in pile.aval.elt_ty.shape]
elt_ty = pile.aval.elt_ty.update(shape=tuple(new_shape))
aval = pile.aval.replace(elt_ty=elt_ty)
return (lengths, pile.data), aval
def _pile_unflatten(aval, x):
for d in jumble.aval.elt_ty.shape]
elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape))
aval = jumble.aval.replace(elt_ty=elt_ty)
return (lengths, jumble.data), aval
def _jumble_unflatten(aval, x):
lengths, data = x
new_shape = [d.replace(lengths=lengths[d.lengths - 1])
if type(d) is IndexedAxisSize else d
for d in aval.elt_ty.shape]
elt_ty = aval.elt_ty.update(shape=tuple(new_shape))
aval = aval.replace(elt_ty=elt_ty)
return Pile(aval, data)
register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
return Jumble(aval, data)
register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten)

def _pile_result(axis_size, stacked_axis, ragged_axes, x):
def _jumble_result(axis_size, stacked_axis, ragged_axes, x):
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
if stacked_axis != 0:
raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0
@@ -106,14 +106,14 @@ def _pile_result(axis_size, stacked_axis, ragged_axes, x):
for ragged_axis, segment_lens in ragged_axes:
shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens)
elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
return Pile(PileTy(binder, axis_size, elt_ty), x)
return Jumble(JumbleTy(binder, axis_size, elt_ty), x)


@dataclasses.dataclass(frozen=True)
class RaggedAxis:
stacked_axis: int
# For each axis, we store its index and the corresponding segment lengths.
# For example, the pile i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# would be represented with ragged_axes = [(1, lens1), (3, lens2)]
ragged_axes: tuple[tuple[int, Array], ...]

@@ -234,9 +234,9 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
handler = to_elt_handlers.get(type(x))
if handler:
return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
elif type(x) is Pile:
if spec is not pile_axis:
raise TypeError("pile input without using pile_axis in_axes spec")
elif type(x) is Jumble:
if spec is not jumble_axis:
raise TypeError("jumble input without using jumble_axis in_axes spec")
ias: IndexedAxisSize # Not present in the AxisSize union in core.py
(d, ias), = ((i, sz) # type: ignore
for i, sz in enumerate(x.aval.elt_ty.shape)
@@ -259,10 +259,10 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
x_ = trace.full_raise(x)
val, bdim = x_.val, x_.batch_dim
if type(bdim) is RaggedAxis:
if spec is not pile_axis:
if spec is not jumble_axis:
# TODO(mattjj): improve this error message
raise TypeError("ragged output without using pile_axis out_axes spec")
return _pile_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
raise TypeError("ragged output without using jumble_axis out_axes spec")
return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
else:
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
from_elt_handlers: dict[type, FromEltHandler] = {}
@@ -284,7 +284,7 @@ def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
from_elt_handlers[data_type] = from_elt
if make_iota: make_iota_handlers[axis_size_type] = make_iota
vmappables: dict[type, tuple[type, type]] = {}
spec_types: set[type] = {PileAxis}
spec_types: set[type] = {JumbleAxis}

def unregister_vmappable(data_type: type) -> None:
spec_type, axis_size_type = vmappables.pop(data_type)
@@ -295,7 +295,7 @@ def unregister_vmappable(data_type: type) -> None:
del make_iota_handlers[axis_size_type]

def is_vmappable(x: Any) -> bool:
return type(x) is Pile or type(x) in vmappables
return type(x) is Jumble or type(x) in vmappables

@lu.transformation_with_aux
def flatten_fun_for_vmap(in_tree, *args_flat):
@@ -1089,12 +1089,12 @@ def broadcast(x, sz, axis):
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)

def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
if dst == pile_axis:
if dst == jumble_axis:
x = bdim_at_front(x, src, sz)
elt_ty = x.aval.update(shape=x.shape[1:])
aval = PileTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
x.shape[0], elt_ty)
return Pile(aval, x)
aval = JumbleTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
x.shape[0], elt_ty)
return Jumble(aval, x)
try:
_ = core.get_aval(x)
except TypeError as e:
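
For reference, the Jumble(aval=a:3 => f32[[3 1 4].a], data=...) comment in the hunk above can be spelled out by constructing the renamed classes directly. This is a sketch against internal APIs (core.Var, core.DShapedArray, IndexedAxisSize), mirroring how _jumble_result builds its binder; the flat row-after-row layout of data is taken from that comment rather than from any documented contract.

import numpy as np
import jax.numpy as jnp
from jax._src import core
from jax._src.interpreters.batching import IndexedAxisSize, Jumble, JumbleTy

# The binder `a`, ranging over Fin 3, as in the type a:3 => f32[[3 1 4].a].
# Constructed the same way _jumble_result constructs its binder.
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
lengths = jnp.array([3, 1, 4], dtype='int32')

# Element type: a rank-1 float32 array whose single axis size is indexed by the binder.
elt_ty = core.DShapedArray((IndexedAxisSize(binder, lengths),),
                           np.dtype('float32'), False)
aval = JumbleTy(binder, 3, elt_ty)

# Per the comment, the payload stores the three ragged rows back to back.
data = jnp.array([0., 1., 2., 0., 0., 1., 2., 3.], dtype='float32')
jumble = Jumble(aval, data)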
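
Because Jumble is registered with register_pytree_node above, generic tree utilities can carry it through jit and other transformations that know nothing about jumbles, which is the point of the temporary pytree registration. A minimal round trip, continuing from the jumble constructed in the previous sketch:

import jax

leaves, treedef = jax.tree_util.tree_flatten(jumble)
# For this one-ragged-axis jumble the leaves are [lengths, data]; the treedef
# carries the JumbleTy with the concrete lengths replaced by positional indices,
# exactly as _jumble_flatten computes them.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert isinstance(restored, Jumble)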
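
The RaggedAxis comment above can likewise be written out concretely. The lens2 values here are made up for illustration; note that the ragged axis indices count the stacked batch axis, which is why _jumble_result subtracts one when mapping them onto the element type's shape.

import jax.numpy as jnp
from jax._src.interpreters import batching

# Batch descriptor for the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]:
# three elements stacked on axis 0, ragged on batched axes 1 and 3.
lens1 = jnp.array([3, 1, 4], dtype='int32')
lens2 = jnp.array([2, 5, 2], dtype='int32')
bdim = batching.RaggedAxis(stacked_axis=0, ragged_axes=((1, lens1), (3, lens2)))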
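
Putting the jumble_axis specs in to_elt and from_elt together, here is a rough end-to-end sketch modeled on how the JAX test suite exercises this machinery around the time of this commit. It relies on the experimental jax_dynamic_shapes mode and on internal modules (core.bint, batching.jumble_axis), so treat every call here as an assumption rather than stable API.

from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import core
from jax._src.interpreters import batching

jax.config.update('jax_dynamic_shapes', True)

# Ragged lengths 3, 1, 4 carried as bounded ints (bint) so they can act as
# data-dependent axis sizes under vmap.
sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

# out_axes=jumble_axis makes from_elt package the ragged results as a Jumble via
# _jumble_result instead of raising the "ragged output without using jumble_axis
# out_axes spec" error above.
jumble = jax.vmap(partial(jnp.arange, dtype='int32'),
                  out_axes=batching.jumble_axis)(sizes)

# Mapping over an existing jumble uses jumble_axis as the in_axes spec, matching
# the check in to_elt.
squared = jax.vmap(lambda x: x * x,
                   in_axes=batching.jumble_axis,
                   out_axes=batching.jumble_axis)(jumble)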
8 changes: 4 additions & 4 deletions jax/interpreters/batching.py
@@ -29,9 +29,9 @@
MakeIotaHandler as MakeIotaHandler,
MapSpec as MapSpec,
NotMapped as NotMapped,
Pile as Pile,
PileAxis as PileAxis,
PileTy as PileTy,
Jumble as Jumble,
JumbleAxis as JumbleAxis,
JumbleTy as JumbleTy,
ToEltHandler as ToEltHandler,
Vmappable as Vmappable,
Zero as Zero,
@@ -60,7 +60,7 @@
matchaxis as matchaxis,
moveaxis as moveaxis,
not_mapped as not_mapped,
pile_axis as pile_axis,
jumble_axis as jumble_axis,
primitive_batchers as primitive_batchers,
reducer_batcher as reducer_batcher,
register_vmappable as register_vmappable,
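
The second file only re-exports the renamed names through the jax.interpreters.batching shim, so code importing from that path picks up the new spelling. A quick check, assuming the aliases stay importable from the shim after this commit:

from jax.interpreters import batching

# The shim forwards to jax._src.interpreters.batching, so the renamed names are
# visible here after this change.
print(batching.Jumble, batching.JumbleTy, batching.jumble_axis)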
