Skip to content

Commit

Permalink
simplify standard named_shape_rule
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
2 people authored and jekbradbury committed Mar 9, 2021
1 parent c622422 commit e779ed8
Showing 1 changed file with 31 additions and 56 deletions.
87 changes: 31 additions & 56 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,10 +1975,9 @@ def _argnum_weak_type(*argnums):
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
weak_type_rule=None, named_shape_rule=None):
weak_type_rule = weak_type_rule or _standard_weak_type_rule
named_shape_rule = named_shape_rule or standard_named_shape_rule
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
named_shape_rule = named_shape_rule or partial(
fallback_named_shape_rule, prim)
prim.def_abstract_eval(
partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
weak_type_rule, named_shape_rule))
Expand All @@ -2005,7 +2004,8 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
raise TypeError(avals, least_specialized)

def standard_multi_result_abstract_eval(
prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs):
prim, shape_rule, dtype_rule, weak_type_rule,
named_shape_rule, *avals, **kwargs):
assert prim.multiple_results
assert all(isinstance(aval, UnshapedArray) for aval in avals), avals
least_specialized = _max(map(type, avals),
Expand All @@ -2018,66 +2018,23 @@ def standard_multi_result_abstract_eval(
elif least_specialized is ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
return [ShapedArray(s, d, weak_type=weak_type)
for s, d, weak_type in safe_zip(out_shapes, out_dtypes, weak_types)]
out_named_shapes = named_shape_rule(*avals, **kwargs)
return [ShapedArray(s, d, weak_type=weak_type, named_shape=named_shape)
for s, d, weak_type, named_shape
in safe_zip(out_shapes, out_dtypes, weak_types, out_named_shapes)]
elif least_specialized is UnshapedArray:
out_dtypes = dtype_rule(*avals, **kwargs)
return [UnshapedArray(dtype, weak_type=weak_type)
for dtype, weak_type in safe_zip(out_dtypes, weak_types)]
else:
raise TypeError(avals, least_specialized)


def standard_translate(name, c, *args, **kwargs):
  """Translate a primitive by calling the same-named op on `xops`.

  The snake_case `name` is converted to CamelCase by capitalizing each
  underscore-separated piece and concatenating them; the attribute of that
  name is looked up on `xops` and called with `args` and `kwargs`.

  Args:
    name: snake_case name used to derive the `xops` attribute to call.
    c: accepted for signature compatibility; not used in this function.
    *args: positional arguments forwarded unchanged to the op.
    **kwargs: keyword arguments forwarded unchanged to the op.

  Returns:
    Whatever the looked-up `xops` attribute returns.
  """
  camel_case_pieces = [piece.capitalize() for piece in name.split('_')]
  op_builder = getattr(xops, ''.join(camel_case_pieces))
  return op_builder(*args, **kwargs)


@lu.transformation_with_aux
def get_dims_out(dims_in, *args, **kwargs):
  """Pass `dims_in` alongside the arguments and hand the output dims back.

  The first yield sends `(args, dims_in)` as the positional arguments and
  `kwargs` as the keyword arguments; the value sent back is unpacked as a
  `(vals_out, dims_out)` pair, which is then yielded again as a tuple.
  `fallback_named_shape_rule` applies this as `get_dims_out(f, vmap_dims)`
  and calls the second returned value to read the output dims.
  """
  wrapped_inputs = (args, dims_in)
  # The wrapped callable is expected to answer with an (outputs, dims) pair.
  outs, out_dims = yield wrapped_inputs, kwargs
  yield outs, out_dims

def fallback_named_shape_rule(prim, *avals, **params):
  """Derive per-output named shapes for `prim` by probing its batching rule.

  For every `(name, size)` entry found in any input aval's `named_shape`, the
  named axis is turned into a leading positional axis of length `size` on the
  avals that carry it, the primitive's batching rule is traced on those avals,
  and each output that comes back mapped gets `name: size` added to its named
  shape.

  Args:
    prim: the primitive whose batching rule is looked up in the `batching`
      registries.
    *avals: input abstract values; each must expose `named_shape`, `shape`,
      `update` and `strip_named_shape`.
    **params: primitive parameters, passed through to the batching rule.

  Returns:
    A list with one named-shape dict per output (a one-element list for
    primitives without `multiple_results`), or `None` when no input aval
    has any named axis, because the loop body then never runs.

  Raises:
    NotImplementedError: if `prim` is registered in `batching.collective_rules`
      or has no batching rule in any of the consulted registries.
  """
  # Distinct (axis name, axis size) pairs over all inputs.
  all_named_shape_tuples = set(
      item for aval in avals for item in aval.named_shape.items())
  out_named_shape = None
  # note: this relies on independence of batching over different axes
  # (= commutativity of batching rules). That isn't true for at least
  # `while_loop`.
  for name, size in all_named_shape_tuples:
    # Avals that carry `name` get a new leading axis of that size and lose
    # their named shape; the others are passed through untouched.
    vmap_avals = [aval.update(shape=(size, *aval.shape),
                              weak_type=False).strip_named_shape()
                  if name in aval.named_shape else aval for aval in avals]
    # Matching batch dims: 0 for the avals rewritten above, not_mapped else.
    vmap_dims = [0 if name in aval.named_shape else batching.not_mapped
                 for aval in avals]
    # Pick the batching rule; collectives are rejected outright.
    if prim in batching.collective_rules:
      raise NotImplementedError(
          f"collective {prim} should have a custom abstract_eval rule")
    elif prim in batching.initial_style_batchers:
      rule = partial(batching.initial_style_batchers[prim], axis_name=name)
    elif prim in batching.primitive_batchers:
      rule = batching.primitive_batchers[prim]
    else:
      raise NotImplementedError(f"primitive {prim} cannot be used with named "
                                f"axes because it has no batching rule")
    f = lu.wrap_init(rule, params)
    f, vmap_dims_out = get_dims_out(f, vmap_dims)
    _, in_tree = tree_util.tree_flatten(vmap_avals)
    f, _ = api_util.flatten_fun_nokwargs(f, in_tree)
    # The trace results are discarded; the calls are made only so that
    # `vmap_dims_out()` below can report the rule's output batch dims.
    # NOTE(review): presumably this traces on abstract values without
    # running the op -- confirm against `pe`.
    if config.omnistaging_enabled:
      pe.trace_to_jaxpr_dynamic(f, vmap_avals)
    else:
      vmap_pvals = [pe.PartialVal.unknown(aval) for aval in vmap_avals]
      pe.trace_to_jaxpr(f, vmap_pvals)
    dims_out = vmap_dims_out()
    # Normalize the single-result case to a list so both cases share code.
    if not prim.multiple_results: dims_out = [dims_out]
    mapped_out = [d is not batching.not_mapped for d in dims_out]
    # First iteration: start every output with an empty named shape.
    if out_named_shape is None:
      out_named_shape = [{} for m in mapped_out]
    # Add this axis to each output that came back mapped; entries already in
    # `ns` take precedence because `**ns` is unpacked last.
    out_named_shape = [{name: size, **ns} if m else ns
                       for ns, m in safe_zip(out_named_shape, mapped_out)]
  return out_named_shape
def standard_named_shape_rule(*avals, **kwargs):
  """Default named-shape rule: join the named shapes of all inputs.

  Args:
    *avals: input abstract values, each exposing a `named_shape` attribute.
    **kwargs: primitive parameters; accepted but not used by this rule.

  Returns:
    The result of `core.join_named_shapes` applied to every input's
    `named_shape`, in argument order.
  """
  input_named_shapes = [aval.named_shape for aval in avals]
  return core.join_named_shapes(*input_named_shapes)


def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
Expand Down Expand Up @@ -5011,15 +4968,19 @@ def _reduce_translation_rule(c, *values, computation, jaxpr,

def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
consts, dimensions):
# TODO(mattjj,frostig): use batch_jaxpr, delete computation (assumes poly??)
num_operands = len(batched_args) // 2
operands, init_values = split_list(batched_args, [num_operands])
operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands])
if all(init_value_bdim is None for init_value_bdim in init_value_bdims):
if all(init_value_bdim is batching.not_mapped
for init_value_bdim in init_value_bdims):
# Assume all batch dims are the same for each of the operands
assert all(operand_bdim is not None for operand_bdim in operand_bdims)
assert all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims)
# TODO(sharadmv): handle the case when batch dims are different across
# operands or when some are unbatched
if not all(operand_bdim is not batching.not_mapped for operand_bdim in operand_bdims):
raise NotImplementedError
if not all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims):
raise NotImplementedError
operand_bdim = operand_bdims[0]
new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
new_operand_bdim = operand_bdim - int(np.sum(np.less(dimensions, operand_bdim)))
Expand Down Expand Up @@ -5060,12 +5021,26 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
bind = prim_bind if input_shape is None else partial(prim_bind, input_shape=padded_shape)
return bind(masked_val, axes=axes)

def _reduce_named_shape_rule(*avals, computation, jaxpr, consts, dimensions):
  """Named-shape rule for `reduce`: one named shape per operand.

  `avals` is split in half: the first half are the operand avals, the second
  half the init-value avals. Only the case where no init value carries named
  axes and all operands share the same named shape is supported.

  Args:
    *avals: operand avals followed by the same number of init-value avals.
    computation, jaxpr, consts, dimensions: reduce parameters; accepted but
      not inspected by this rule.

  Returns:
    A list holding the `named_shape` of each operand aval.

  Raises:
    NotImplementedError: if any init-value aval has a non-empty named shape,
      or if the operands' named shapes are not all equal.
  """
  # TODO(mattjj,frostig): see the TODOs noting limitations/assumptions in
  # _reduce_batch_rule. We're making the same assumptions here for now.
  operand_avals, init_avals = split_list(avals, [len(avals) // 2])
  for init_aval in init_avals:
    if init_aval.named_shape:
      raise NotImplementedError
  named_shapes = [operand_aval.named_shape for operand_aval in operand_avals]
  # Compare against the first entry inside the loop so that an empty operand
  # list is never indexed.
  for candidate in named_shapes:
    if not (named_shapes[0] == candidate):
      raise NotImplementedError
  return named_shapes


reduce_p = core.Primitive('reduce')
reduce_p.multiple_results = True
reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
reduce_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
_reduce_dtype_rule, _reduce_weak_type_rule))
_reduce_dtype_rule, _reduce_weak_type_rule,
_reduce_named_shape_rule))
xla.translations[reduce_p] = _reduce_translation_rule
batching.primitive_batchers[reduce_p] = _reduce_batch_rule

Expand Down

0 comments on commit e779ed8

Please sign in to comment.