Skip to content

Commit

Permalink
[remove units] make JaxprTrace.process_call not introduce units
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Feb 13, 2022
1 parent 06c4012 commit d59af33
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 107 deletions.
19 changes: 12 additions & 7 deletions jax/core.py
Expand Up @@ -1795,10 +1795,9 @@ def new_out_axes_thunk():
out_axes = t(out_axes)
return out_axes
params = dict(params, out_axes_thunk=new_out_axes_thunk)
params_tuple = tuple(params.items())
top_trace = find_top_trace(args)
fun, todo_and_xforms = process_env_traces_map(
fun, primitive, top_trace and top_trace.level, params_tuple)
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
outs = primitive.process(top_trace, fun, tracers, params)
env_trace_todo, _ = todo_and_xforms()
Expand Down Expand Up @@ -1826,14 +1825,16 @@ def process_env_traces_map(primitive: MapPrimitive, level: int,
yield outs, (tuple(todo), tuple(out_axes_transforms))


def mapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
def mapped_aval(size: int, axis: Optional[int], aval: AbstractValue
) -> AbstractValue:
handler, _ = aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
return handler(size, axis, aval)
else:
raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")

def unmapped_aval(size: int, axis_name, axis: int, aval: AbstractValue) -> AbstractValue:
def unmapped_aval(size: int, axis_name, axis: Optional[int], aval: AbstractValue
) -> AbstractValue:
_, handler = aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
return handler(size, axis_name, axis, aval)
Expand All @@ -1843,16 +1844,20 @@ def unmapped_aval(size: int, axis_name, axis: int, aval: AbstractValue) -> Abstr
def _map_unit(*_) -> AbstractUnit:
  """Ignore every positional argument and return the module-level ``abstract_unit``.

  NOTE(review): presumably used as the (un)mapping handler for unit avals;
  the handler registration is not visible in this view -- confirm there.
  """
  return abstract_unit

def _map_shaped_array(size: int, axis: int, aval: ShapedArray) -> ShapedArray:
assert aval.shape[axis] == size
def _map_shaped_array(size: int, axis: Optional[int], aval: ShapedArray
) -> ShapedArray:
assert axis is None or aval.shape[axis] == size
# TODO: Extend the named shape
if axis is None: return aval
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
named_shape=aval.named_shape)

def _unmap_shaped_array(size: int, axis_name, axis: int, aval: ShapedArray) -> ShapedArray:
def _unmap_shaped_array(size: int, axis_name, axis: Optional[int],
aval: ShapedArray) -> ShapedArray:
named_shape = dict(aval.named_shape)
# TODO: Make this mandatory
named_shape.pop(axis_name, None)
if axis is None: return aval.replace(named_shape=named_shape)
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
named_shape=named_shape)

Expand Down
195 changes: 95 additions & 100 deletions jax/interpreters/partial_eval.py
Expand Up @@ -37,8 +37,8 @@
as_hashable_function)
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, raise_to_shaped, Var, Atom,
JaxprEqn, Primitive, DShapedArray)
ConcreteArray, raise_to_shaped, Var, Atom, JaxprEqn,
Primitive, DShapedArray, mapped_aval, unmapped_aval)
from jax._src import source_info_util
from jax.config import config

Expand Down Expand Up @@ -221,71 +221,71 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
return merge_lists(out_knowns, out_tracers, out_consts)

def process_map(self, primitive, f: lu.WrappedFun, tracers, params):
in_pvals = [t.pval for t in tracers]
mapped_aval = partial(core.mapped_aval, params['axis_size'])
in_pvals = [pval if pval.is_known() or in_axis is None
else PartialVal.unknown(mapped_aval(in_axis, pval[0]))
for pval, in_axis in zip(in_pvals, params['in_axes'])]

def app(f, *args):
f, num_outputs = count_outputs(f)
out_axes_thunk = params['out_axes_thunk']
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
out_axes = out_axes_thunk()
return out_axes + (0,) * (num_outputs() - len(out_axes))
pe_params = dict(params, out_axes_thunk=new_out_axes_thunk)
return primitive.bind(f, *args, **pe_params)
update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers])

# This method is like process_call above, except:
# 1. we delete an axis from mapped-over input avals' shapes, and
# analogously add an axis to mapped-over output avals' shapes;
# 2. we update the in_axes and out_axes/out_axes_thunk parameters to
# reflect the inputs and outputs pruned from the unknown/known sides.

# Map (delete an axis from) unknown inputs' avals as dictated by in_axes.
unk_in_axes, const_in_axes = partition_list(in_knowns, params['in_axes'])
in_avals_mapped = [mapped_aval(params['axis_size'], ax, aval)
for ax, aval in zip(unk_in_axes, in_avals)]

# Wrap f to perform partial evaluation and plumb out aux data.
f = trace_to_subjaxpr_nounits(f, self.main, False)
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns),
tuple(in_avals_mapped))
# Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk)
const_params = update_params(params, in_knowns, 0) # handles donated_invars
out_axes_thunk = params['out_axes_thunk']
@as_hashable_function(closure=out_axes_thunk)
def const_out_axes_thunk():
out_knowns, _, jaxpr, _ = aux()
_, out_axes = partition_list(out_knowns, out_axes_thunk())
return tuple(out_axes) + (0,) * len(jaxpr.constvars) # res mapped axis 0
const_params = dict(const_params, in_axes=tuple(const_in_axes),
out_axes_thunk=const_out_axes_thunk)

# Run the map, getting known out vals and aux data used for staged-out map.
with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
f, in_pvals, app, instantiate=False)
unmapped_aval = partial(core.unmapped_aval, params['axis_size'], params['axis_name'])
out_axes = params['out_axes_thunk']()
out_pvals = [pval if pval.is_known() else
PartialVal.unknown(unmapped_aval(out_axis, pval[0])) if out_axis is not None else
PartialVal.unknown(pval[0])
for pval, out_axis in zip(out_pvals, out_axes)]

# Skip known invars and outvars, and lift constants as regular invars
in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
jaxpr = _drop_vars(jaxpr, in_knowns, (False,) * len(jaxpr.outvars))
jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)

# Known tracers get propagated as if they were constants
known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
if pval.is_known()]

# Unknown tracers need to have the jaxpr set up as their recipe
unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
if not pval.is_known()]
unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
const_tracers = map(self.new_instantiated_const, consts)
in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)

# Set up new params
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
in_axes = params['in_axes']
# NOTE: const_tracers are added as map outputs, and we always map them
# along axis 0 (see `new_out_axes_thunk` above).
new_in_axes = ((0,) * len(const_tracers) + (None,) * len(env_tracers) +
tuple(axis for axis, t in zip(in_axes, tracers)
if not t.pval.is_known()))
new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals)
if not pval.is_known())
new_params = dict(new_params, in_axes=new_in_axes, out_axes=new_out_axes)
del new_params['out_axes_thunk']
update_params = call_param_updaters.get(primitive)
if update_params:
num_new_inputs = len(const_tracers) + len(env_tracers)
unknown_args = [not t.pval.is_known() for t in tracers]
new_params = update_params(new_params, unknown_args, num_new_inputs)
out = primitive.bind(f, *in_consts, **const_params)
out_knowns, out_avals_mapped, jaxpr, env = aux()
# Split apart known outputs from the original call and residuals.
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])

# We can only check_jaxpr with the dynamic axis environment extended:
with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
call_jaxpr = convert_constvars_jaxpr(jaxpr)

# Compute staged and const out_axes, taking into account residuals.
out_axes = params['out_axes_thunk']()
staged_out_axes, _ = partition_list(out_knowns, out_axes)
staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,)

eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
# Create the input tracers for the staged-out (unknown-value) call.
const_tracers = map(self.new_instantiated_const, res)
env_tracers = map(self.full_raise, env)
unknown_arg_tracers = [t for t in tracers if not t.is_known()]
# Adjust params for staged-out call on unknown values.
num_new_args = len(const_tracers) + len(env_tracers)
staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
staged_params = dict(staged_params, in_axes=staged_in_axes,
out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr)
# The outputs of the staged-out call are Tracers with the new eqn as recipe.
out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a)
for ax, a in zip(staged_out_axes, out_avals_mapped)]
out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
for a in out_avals]
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params,
source_info_util.current())
for t in unknown_tracers_out: t.recipe = eqn
return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
for t in out_tracers: t.recipe = eqn

return merge_lists(out_knowns, out_tracers, out_consts)

def post_process_call(self, primitive, out_tracers, params):
unknown_out_tracers = [t for t in out_tracers if not t.is_known()]
Expand Down Expand Up @@ -313,45 +313,45 @@ def todo(out):
return out, todo

def post_process_map(self, primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
nconsts = len(consts)
del consts, out_pv_consts
unknown_out_tracers = [t for t in out_tracers if not t.is_known()]
jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers)
out_pvals = [t.pval for t in out_tracers]
out_knowns, out_avals_mapped, out_consts = partition_pvals(out_pvals)
out = [*out_consts, *res]
main = self.main

def todo(x):
out_axes_ = params['out_axes_thunk']()
assert len(out_axes_) == len(out_pvs) + nconsts # transformed below
out_axes = out_axes_[:len(out_pvs)]
sz = params['axis_size']
out_pvs_ = [core.unmapped_aval(sz, params['axis_name'], ax, pv)
if pv is not None else None for pv, ax in zip(out_pvs, out_axes)]

n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(main, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs_, out_pv_consts)]
in_tracers = (*const_tracers, *map(trace.full_raise, env))
with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
call_jaxpr = convert_constvars_jaxpr(jaxpr)

new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
# NOTE: We've assigned axis 0 to const tracers below, in out_axes_transform.
new_in_axes = (0,) * len(const_tracers) + (None,) * len(env)
new_params = dict(new_params, in_axes=new_in_axes, out_axes=out_axes)
del new_params['out_axes_thunk']
update_params = call_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, [], len(in_tracers))
def todo(out):
trace = main.with_cur_sublevel()
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env)

eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
source_info_util.current())
staged_out_axes = tuple(out_axes_unknown) # set by out_axes_transform
staged_in_axes = (0,) * len(res) + (None,) * len(env)

update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
staged_params = update_params(params, [], len(res) + len(env))
staged_params = dict(staged_params, in_axes=staged_in_axes,
out_axes=tuple(staged_out_axes),
call_jaxpr=call_jaxpr)

out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], d, a)
for d, a in zip(staged_out_axes, out_avals_mapped)]
out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None)
for a in out_avals]
eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
primitive, staged_params, source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers
return merge_lists(out_knowns, out_tracers, out_consts)

def out_axes_transform(out_axes):
return out_axes + (0,) * nconsts
nonlocal out_axes_unknown
out_axes_unknown, out_axes_known = partition_list(out_knowns, out_axes)
return tuple(out_axes_known) + (0,) * len(jaxpr.constvars)
out_axes_unknown: Optional[list] = None

return out, (todo, out_axes_transform)

Expand Down Expand Up @@ -470,11 +470,6 @@ def partial_eval_wrapper_nounits(
out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
yield (*out_consts, *res), (out_knowns, out_avals, jaxpr, env)

@lu.transformation_with_aux
def count_outputs(*args, **kwargs):
  """Pass a call through unchanged and report how many outputs it produced.

  Yields ``(args, kwargs)`` unmodified, receives the result of that yield as
  ``outputs``, then yields the result paired with ``len(outputs)`` as the
  auxiliary value.
  """
  outputs = yield args, kwargs
  num_outputs = len(outputs)
  yield outputs, num_outputs

# Module-level registries keyed by Primitive; each entry maps to a callable.
# NOTE(review): the consumers of the first two tables are not visible in this
# view -- confirm the expected rule signatures at their call sites.
custom_partial_eval_rules: Dict[Primitive, Callable] = {}
call_partial_eval_rules: Dict[Primitive, Callable] = {}
# Looked up with `.get(primitive)` in process_map / post_process_map; when an
# entry exists it is called as updater(params, per_input_flags, num_new_inputs)
# and its return value is used as the adjusted params dict.
call_param_updaters: Dict[Primitive, Callable] = {}
Expand Down

0 comments on commit d59af33

Please sign in to comment.