Skip to content

Commit

Permalink
add scan dce rule tests, fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed May 6, 2022
1 parent d57e364 commit d0863a1
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 71 deletions.
15 changes: 10 additions & 5 deletions jax/_src/lax/control_flow.py
Expand Up @@ -2037,30 +2037,36 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):

def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], core.JaxprEqn]:
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_xs = len(jaxpr.in_avals) - num_consts - num_carry
used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry])
for i in range(1 + num_carry):
used_outputs = used_carry_out + used_extensive_out
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'].jaxpr, used_outputs)
jaxpr_dce, used_inputs = pe.dce_jaxpr(
jaxpr.jaxpr, used_outputs,
instantiate=[False] * num_consts + used_carry_out + [False] * num_xs)
used_consts, used_carry_in, used_extensive_in = \
split_list(used_inputs, [num_consts, num_carry])
if used_carry_in == used_carry_out:
if list(used_carry_in) == list(used_carry_out):
break
else:
used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
else:
assert False, "Fixpoint not reached"
core.check_jaxpr(jaxpr.jaxpr)

new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
new_params = dict(eqn.params, num_consts=sum(used_consts),
num_carry=sum(used_carry_in), linear=tuple(new_linear),
jaxpr=core.ClosedJaxpr(jaxpr, eqn.params['jaxpr'].consts))
jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts))
new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs)
if used],
[v for v, used in zip(eqn.outvars, used_outputs)
if used],
eqn.primitive, new_params, eqn.effects,
eqn.source_info)
assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals )
assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
return used_inputs, new_eqn

Expand Down Expand Up @@ -2133,8 +2139,7 @@ def scan_bind(*args, **params):
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
pe.padding_rules[scan_p] = _scan_padding_rule
# TODO(mattjj): re-enable
# pe.dce_rules[scan_p] = _scan_dce_rule
pe.dce_rules[scan_p] = _scan_dce_rule


@api_boundary
Expand Down
48 changes: 28 additions & 20 deletions jax/interpreters/partial_eval.py
Expand Up @@ -1161,12 +1161,16 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
for v in jaxpr.outvars]


def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[bool]]:
return _dce_jaxpr(jaxpr, tuple(used_outputs))
if type(instantiate) is bool:
instantiate = (instantiate,) * len(jaxpr.invars)
return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate))

@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
instantiate: Tuple[bool, ...]
) -> Tuple[Jaxpr, List[bool]]:
env: Dict[Var, bool] = {}

Expand All @@ -1177,26 +1181,23 @@ def write(x: Atom, b: bool) -> None:
if type(x) is Var:
env[x] = read(x) or b

def has_effects(e: JaxprEqn) -> bool:
return bool(e.effects) or core.primitive_uses_outfeed(e.primitive, e.params)

new_eqns = []
map(write, jaxpr.outvars, used_outputs)
for eqn in jaxpr.eqns[::-1]:
used_outs = map(read, eqn.outvars)
# If any outputs are used, then we need to keep a version of the eqn and
# potentially mark some inputs as used. Otherwise mark all inputs as unused.
if any(used_outs) or core.primitive_uses_outfeed(eqn.primitive, eqn.params):
# If there's a rule for modifying the eqn and computing used inputs, apply
# it. Otherwise, keep the eqn unmodified and mark all inputs as used.
rule = dce_rules.get(eqn.primitive)
if rule:
used_ins, new_eqn = rule(used_outs, eqn)
else:
used_ins = [True] * len(eqn.invars)
new_eqn = eqn
new_eqns.append(new_eqn)
else:
if not any(used_outs) and not has_effects(eqn):
used_ins = [False] * len(eqn.invars)
else:
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
used_ins, new_eqn = rule(used_outs, eqn)
if new_eqn is not None:
new_eqns.append(new_eqn)
map(write, eqn.invars, used_ins)
used_inputs = map(read, jaxpr.invars)
used_inputs = map(op.or_, instantiate, used_inputs)

new_jaxpr = Jaxpr(jaxpr.constvars,
[v for v, b in zip(jaxpr.invars, used_inputs) if b],
Expand All @@ -1206,7 +1207,13 @@ def write(x: Atom, b: bool) -> None:

return new_jaxpr, used_inputs

DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], JaxprEqn]]
DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], Optional[JaxprEqn]]]

def _default_dce_rule(
used_outs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], JaxprEqn]:
return [True] * len(eqn.invars), eqn

dce_rules: Dict[Primitive, DCERule] = {}


Expand All @@ -1217,9 +1224,10 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
update_params = call_param_updaters.get(eqn.primitive)
if update_params:
new_params = update_params(new_params, used_inputs, 0)
new_eqn = new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
new_eqn = new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
Expand Down

0 comments on commit d0863a1

Please sign in to comment.