Skip to content

Commit

Permalink
[remove-units] avoid making xmap partial eval deal with units
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 28, 2022
1 parent b852778 commit ca112da
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 150 deletions.
180 changes: 72 additions & 108 deletions jax/experimental/maps.py
Expand Up @@ -48,9 +48,10 @@
from jax.interpreters import ad
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.util import (safe_map, safe_zip, HashableFunction,
as_hashable_function, unzip2, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name)
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2,
as_hashable_function, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name,
merge_lists, partition_list)
from jax import lax

source_info_util.register_exclusion(__file__)
Expand Down Expand Up @@ -1126,122 +1127,85 @@ def restore_units(is_unit, vals):


def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params):
from jax.interpreters.partial_eval import (
PartialVal, JaxprTracer, _drop_vars, _dce_open_jaxpr,
convert_constvars_jaxpr, new_eqn_recipe)
assert primitive is xmap_p
in_axes = params['in_axes']
donated_invars = params['donated_invars']
global_axis_sizes = params['global_axis_sizes']

in_pvals = [t.pval for t in tracers]
in_pvals = [pval if pval.is_known()
else PartialVal.unknown(_delete_aval_axes(pval[0], axes, global_axis_sizes))
for pval, axes in zip(in_pvals, in_axes)]

const_axes_s = lu.Store()
def app(f, *args):
args_no_units, in_units = filter_units(args)
f, out_units = hide_units(f, tuple(in_units))
f, out_named_shapes = out_local_named_shapes(f, frozenset(global_axis_sizes))
out_axes_thunk = params['out_axes_thunk']
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
out_axes = out_axes_thunk()
axes_units, const_units = split_list(out_units(), [len(out_axes)])
assert not any(const_units)
num_consts = len(const_units)
out_axes_no_units = [a for a, u in zip(out_axes, axes_units) if not u]
const_axes: Sequence[AxisNamePos]
if num_consts == 0:
const_axes = ()
else:
const_axes = [
AxisNamePos(zip(sort_named_shape, range(len(sort_named_shape))),
user_repr=f'<internal: {sort_named_shape}>')
for named_shape in out_named_shapes()[-num_consts:]
# We sort here to make the iteration order deterministic
for sort_named_shape in [sorted(named_shape, key=str)]
]
if not const_axes_s: # NOTE: This can be called multiple times
const_axes_s.store(const_axes)
assert const_axes_s.val == const_axes
return (*out_axes_no_units, *const_axes)
pe_params = dict(
params,
in_axes=tuple(a for a, u in zip(in_axes, in_units) if not u),
donated_invars=tuple(a for a, u in zip(donated_invars, in_units) if not u),
out_axes_thunk=new_out_axes_thunk)
outs_no_units = primitive.bind(f, *args_no_units, **pe_params)
new_out_axes_thunk() # Make sure it is called at least once to compute const_axes
return restore_units(out_units(), outs_no_units)

jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
f, in_pvals, app, instantiate=False)

out_axes = params['out_axes_thunk']()
const_axes = const_axes_s.val
out_axes_thunk = params['out_axes_thunk']

# Adjust input tracers' pvals for mapped axes, and unpack.
in_pvals = [t.pval if t.pval.is_known() else
pe.PartialVal.unknown(
_delete_aval_axes(t.pval.get_aval(), axes, global_axis_sizes))
for t, axes in zip(tracers, in_axes)]
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)

# Wrap f to perform partial evaluation, and plumb out aux data.
f = pe.trace_to_subjaxpr_nounits(f, self.main, False)
f, aux = pe.partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
# Also grab the local named shapes of the output (known and res).
f, out_named_shapes = out_local_named_shapes(f, frozenset(global_axis_sizes))

# Adjust params for knowns (donated_invars, in_axes, out_axes_thunk).
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
out_knowns, _, _, _ = aux()
_, out_axes = partition_list(out_knowns, out_axes_thunk())
return (*out_axes, *res_axes())
def res_axes():
_, _, jaxpr, _ = aux()
num_res = len(jaxpr.constvars)
res_named_shapes = out_named_shapes()[-num_res:] if num_res else []
sorted_named_shapes = [sorted(ns, key=str) for ns in res_named_shapes]
return [AxisNamePos(zip(named_shape, range(len(named_shape))),
user_repr=f'<internal: {named_shape}>')
for named_shape in sorted_named_shapes]
known_params = dict(
params, in_axes=tuple(a for a, k in zip(in_axes, in_knowns) if k),
donated_invars=tuple(d for d, k in zip(donated_invars, in_knowns) if k),
out_axes_thunk=new_out_axes_thunk)

# Run the known part.
outs = primitive.bind(f, *in_consts, **known_params)
out_knowns, out_avals, jaxpr, env = aux()
known_outvals, res = split_list(outs, [len(outs) - len(jaxpr.constvars)])
with core.extend_axis_env_nd(global_axis_sizes.items()):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)

# Set up new params.
out_axes = [a for a, k in zip(out_axes_thunk(), out_knowns) if not k]
unknown_params = dict(
params, call_jaxpr=jaxpr, out_axes=tuple(out_axes), spmd_out_axes=None,
donated_invars=(*(False for _ in res),
*(d for d, t in zip(donated_invars, tracers)
if not t.pval.is_known())),
in_axes=(*res_axes(), *(None for _ in env),
*(a for a, t in zip(in_axes, tracers) if not t.pval.is_known())))
del unknown_params['out_axes_thunk']
del unknown_params['spmd_out_axes_thunk']
# Create input tracers for unknown part.
res_tracers = map(self.new_instantiated_const, res)
env_tracers = map(self.full_raise, env)
unknown_arg_tracers = [t for t in tracers if not t.pval.is_known()]
# Create output tracers for unknown part, adjusting avals.
axis_resource_count = _get_axis_resource_count(
params['axis_resources'], params['resource_env'],
params['in_positional_semantics'])
local_axis_sizes = {
axis: axis_resource_count[axis].to_local(params['out_positional_semantics'], global_size)
for axis, global_size in global_axis_sizes.items()
}
out_pvals = [pval if pval.is_known() else
PartialVal.unknown(_insert_aval_axes(pval[0], axes, local_axis_sizes))
for pval, axes in zip(out_pvals, out_axes)]

with core.extend_axis_env_nd(global_axis_sizes.items()):
# Skip known invars and outvars, and lift constants as regular invars
in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
jaxpr = _drop_vars(jaxpr, in_knowns, (False,) * len(jaxpr.outvars))
jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)
jaxpr = convert_constvars_jaxpr(jaxpr)

# Known tracers get propagated as if they were constants
known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
if pval.is_known()]

# I'm not 100% if that's correct, but it is an assumption that
# JaxprTrace.process_call already makes.
if any(t.pval.is_known() for t in env_tracers):
raise AssertionError("Please open a bug report!")
# Unknown tracers need to have the jaxpr set up as their recipe
unknown_tracers_in = (*env_tracers, *(t for t in tracers if not t.pval.is_known()))
unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
if not pval.is_known()]
const_tracers = map(self.new_instantiated_const, consts)

# Set up new params
new_in_axes = (*const_axes,
*(None for _ in env_tracers),
*(axis for axis, t in zip(in_axes, tracers)
if not t.pval.is_known()))
new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals)
if not pval.is_known())

assert params['spmd_in_axes'] is None and params['spmd_out_axes_thunk'] is None
new_params = dict(
params,
call_jaxpr=jaxpr,
donated_invars=(*(False for _ in const_tracers),
*(d for d, t in zip(donated_invars, tracers) if not t.pval.is_known())),
in_axes=new_in_axes,
out_axes=new_out_axes,
spmd_out_axes=None)
del new_params['out_axes_thunk']
del new_params['spmd_out_axes_thunk']

eqn = new_eqn_recipe((*const_tracers, *unknown_tracers_in),
unknown_tracers_out,
primitive, new_params, jaxpr.effects, source_info_util.current())
ax: axis_resource_count[ax].to_local(
params['out_positional_semantics'], global_size)
for ax, global_size in global_axis_sizes.items()}
out_pvals = [pe.PartialVal.unknown(_insert_aval_axes(a, ax, local_axis_sizes))
for a, ax in zip(out_avals, out_axes)]
unknown_tracers_out = [pe.JaxprTracer(self, pval, None) for pval in out_pvals]
# Build eqn to be staged out and attach it to unknown output tracers.
eqn = pe.new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers),
unknown_tracers_out, primitive, unknown_params,
jaxpr.effects, source_info_util.current())
for t in unknown_tracers_out: t.recipe = eqn
return pe._zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
return merge_lists(out_knowns, unknown_tracers_out, known_outvals)
pe.JaxprTrace.process_xmap = _jaxpr_trace_process_xmap


def _batch_trace_update_spmd_axes(
spmd_in_axes, spmd_out_axes_thunk,
axis_name, dims, dims_out_thunk):
Expand Down
43 changes: 1 addition & 42 deletions jax/interpreters/partial_eval.py
Expand Up @@ -377,19 +377,6 @@ def out_axes_transform(out_axes):
def _current_truncated_name_stack(self):
return source_info_util.current_name_stack()[len(self.name_stack):]

def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
instantiate: bool):
"""Partially evaluate f on a sequence of PartialVals."""
in_avals, in_consts = unzip2(pvals)
f = trace_to_subjaxpr(f, self.main, instantiate)
f, aux = partial_eval_wrapper(f, tuple(in_avals))
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvs = map(PartialVal, zip(out_avals, out_consts))
env_tracers = map(self.full_raise, env)
return jaxpr, out_pvs, consts, env_tracers

def process_custom_jvp_call(self, prim, fun, jvp, tracers):
tracers = map(self.instantiate_const_abstracted, tracers)
in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units
Expand Down Expand Up @@ -751,7 +738,7 @@ def getconstvar(c):
t, "Tracer not among input tracers {}".format(t))
assert in_tracers, "Lambda binding with no args"
elif isinstance(recipe, FreeVar):
env[cast(Var, getvar(t))] = recipe.val
env[getvar(t)] = recipe.val # type: ignore
elif isinstance(recipe, ConstVar):
v = t_to_var[id(t)] = getconstvar(recipe.val)
consts[v] = recipe.val
Expand Down Expand Up @@ -2217,31 +2204,3 @@ def fun(*vals):
in_avals_2 = [*in_avals_2, *res_avals]

return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out

@weakref_lru_cache
def _drop_vars(jaxpr: Jaxpr, drop_ins: Tuple[bool, ...], drop_outs: Tuple[bool, ...]):
return Jaxpr(jaxpr.constvars,
[v for v, d in zip(jaxpr.invars, drop_ins) if not d],
[v for v, d in zip(jaxpr.outvars, drop_outs) if not d],
jaxpr.eqns, jaxpr.effects)

@weakref_lru_cache
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan, call, or other higher-order primitives.
# TODO(mattjj): better DCE (i.e. use above dce_jaxpr)
if drop_outputs:
new_outvars = [var for var, output in zip(jaxpr.outvars, outputs) if output]
else:
new_outvars = [var if output else core.unitvar
for var, output in zip(jaxpr.outvars, outputs)]

needed_vars = {v for v in new_outvars if type(v) is not Literal}
new_eqns = []
for eqn in jaxpr.eqns[::-1]:
if set(eqn.outvars) & needed_vars:
new_eqns.append(eqn)
needed_vars.update(v for v in eqn.invars if type(v) is not Literal)
new_eqns = new_eqns[::-1]
return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, new_eqns,
jaxpr.effects)

0 comments on commit ca112da

Please sign in to comment.