Skip to content

Commit

Permalink
add scan dce rule
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 28, 2022
1 parent a161d6a commit 4608d36
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 5 deletions.
30 changes: 28 additions & 2 deletions jax/_src/lax/control_flow.py
Expand Up @@ -24,7 +24,7 @@
import itertools
import operator
import os
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, List

import numpy as np

Expand Down Expand Up @@ -1980,6 +1980,32 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)

def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], core.JaxprEqn]:
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry])
for i in range(1 + num_carry):
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'].jaxpr,
used_carry_out + used_extensive_out)
used_consts, used_carry_in, used_extensive_in = \
split_list(used_inputs, [num_consts, num_carry])
if used_carry_in == used_carry_out:
break
else:
used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
else:
assert False, "Fixpoint not reached"

new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
new_params = dict(eqn.params, num_consts=sum(used_consts),
num_carry=sum(used_carry_in), linear=tuple(new_linear),
jaxpr=core.ClosedJaxpr(jaxpr, eqn.params['jaxpr'].consts))
new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, eqn.effects,
eqn.source_info)
return used_inputs, new_eqn

def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
tc = partial(_typecheck_param, 'scan')
Expand Down Expand Up @@ -2049,7 +2075,7 @@ def scan_bind(*args, **params):
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
pe.padding_rules[scan_p] = _scan_padding_rule

pe.dce_rules[scan_p] = _scan_dce_rule


@api_boundary
Expand Down
4 changes: 1 addition & 3 deletions jax/interpreters/partial_eval.py
Expand Up @@ -1185,15 +1185,13 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
for v in jaxpr.outvars]


# TODO(mattjj): unify with dce code below
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
) -> Tuple[Jaxpr, List[bool]]:
return _dce_jaxpr(jaxpr, tuple(used_outputs))

@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
) -> Tuple[Jaxpr, List[bool]]:
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
env: Dict[Var, bool] = {}

def read(v: Var) -> bool:
Expand Down Expand Up @@ -1224,7 +1222,7 @@ def write(x: Atom, b: bool) -> None:
map(write, eqn.invars, used_ins)
used_inputs = map(read, jaxpr.invars)

new_jaxpr = Jaxpr((),
new_jaxpr = Jaxpr(jaxpr.constvars,
[v for v, b in zip(jaxpr.invars, used_inputs) if b],
[v for v, b in zip(jaxpr.outvars, used_outputs) if b],
new_eqns[::-1], jaxpr.effects)
Expand Down
71 changes: 71 additions & 0 deletions tests/api_test.py
Expand Up @@ -50,6 +50,7 @@
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import partial_eval as pe
from jax.interpreters.pxla import PartitionSpec as P
from jax._src import device_array
import jax._src.lib
Expand Down Expand Up @@ -4492,6 +4493,76 @@ def test_pretty_print_unitvar(self):
self.assertIn('in (*,)', str(jaxpr))
self.assertNotIn('in (a,)', str(jaxpr))

def test_dce_jaxpr_scan(self):
@api.remat
def scanned_f(c, x):
out = jnp.tanh(c * x)
return out, out

def f(xs):
return lax.scan(scanned_f, 1., xs)

jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(jnp.arange(10.)).jaxpr
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))

self.assertLen(jaxpr.eqns, 1)
self.assertLen(jaxpr.eqns[-1].params['jaxpr'].jaxpr.eqns, 2)

def test_dce_jaxpr_scan_nontrivial_fixedpoint(self):
def f(lst):
def body(c, _):
return [c[0]] + [c1 + c2 for c1, c2 in zip(c[:-1], c[1:])], None
out, _ = jax.lax.scan(body, lst, None, length=len(lst))
return out
jaxpr = api.make_jaxpr(f)([1, 2, 3, 4]).jaxpr
self.assertLen(jaxpr.eqns, 1)
self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 3)

# If we use all but the last element, only one eqn is pruned.
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, True, False])
self.assertLen(jaxpr_pruned.eqns, 1)
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 2)
# And all but the first input is used.
self.assertEqual(used_inputs, [True, True, True, False])

# If we use all but the last two elements, two eqns can be pruned.
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, False, False])
self.assertLen(jaxpr_pruned.eqns, 1)
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 1)
# And the last two inputs are not used.
self.assertEqual(used_inputs, [True, True, False, False])

# If we only use the last element, no eqns can be pruned.
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [False, False, False, True])
self.assertLen(jaxpr_pruned.eqns, 1)
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
# And all inputs are used.
self.assertEqual(used_inputs, [True, True, True, True])

def test_dce_jaxpr_scan_const_in_jvp(self):
@api.custom_jvp
def f(x):
return x * np.arange(3.)
@f.defjvp
def f_jvp(primals, tangents):
(x,), (xdot,) = primals, tangents
return f(x), xdot * np.arange(3.)

def g(x):
def body(c, _):
return f(c), None
y, _ = jax.lax.scan(body, x, None, length=1)
return y

jvp_jaxpr = api.make_jaxpr(lambda x, xdot: api.jvp(g, (x,), (xdot,)))(
np.arange(3.), np.arange(3.)).jaxpr

jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, True])
self.assertTrue(all(used_inputs))

jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, False])
self.assertEqual(used_inputs, [True, False])


class CustomJVPTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 4608d36

Please sign in to comment.