Skip to content

Commit

Permalink
rewrite axis_index implementation, use custom bind (#2807)
Browse files Browse the repository at this point in the history
* rewrite axis_index implementation, use custom bind

fixes #2716

Co-authored-by: Trevor Cai <tycai@google.com>

* add test for #2716

Co-authored-by: Trevor Cai <tycai@google.com>
  • Loading branch information
mattjj and trevorcai committed Apr 23, 2020
1 parent 13a1728 commit d2653a1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
27 changes: 14 additions & 13 deletions jax/interpreters/pxla.py
Expand Up @@ -383,26 +383,27 @@ def axis_index(axis_name):
[0 1]
[0 1]]
"""
return axis_index_p.bind(axis_name=axis_name)

def _axis_index_bind(*, axis_name):
dynamic_axis_env = _thread_local_state.dynamic_axis_env
frame = dynamic_axis_env[axis_name]
sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
nreps = dynamic_axis_env.nreps
dummy_arg = frame.pmap_trace.pure(core.unit)
if frame.soft_trace:
dummy_arg = frame.soft_trace.pure(dummy_arg)

return axis_index_p.bind(dummy_arg, nreps=nreps, sizes=sizes,
soft_size=frame.soft_size, axis_name=axis_name)
trace = frame.pmap_trace

def _axis_index_partial_eval(trace, _, **params):
# This partial_eval rule adds the axis_index primitive into the jaxpr formed
# during pmap lowering. It is like the standard JaxprTrace.process_primitive
# rule except that we don't attempt to lower out of the trace.
out_aval = ShapedArray((), onp.int32)
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p, params)
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
dict(nreps=nreps, sizes=sizes,
soft_size=frame.soft_size, axis_name=axis_name))
out_tracer.recipe = eqn
return out_tracer

if not frame.soft_trace:
return out_tracer
else:
val_out = out_tracer * frame.soft_size + onp.arange(frame.soft_size)
return SplitAxisTracer(frame.soft_trace, axis_name, val_out)

def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):
div = c.Constant(onp.array(nreps // prod(sizes), dtype=onp.uint32))
Expand All @@ -411,8 +412,8 @@ def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):
return c.ConvertElementType(unsigned_index, xb.dtype_to_etype(onp.int32))

axis_index_p = core.Primitive('axis_index')
axis_index_p.def_custom_bind(_axis_index_bind)
xla.translations[axis_index_p] = _axis_index_translation_rule
pe.custom_partial_eval_rules[axis_index_p] = _axis_index_partial_eval


### lazy device-memory persistence and result handling
Expand Down
11 changes: 11 additions & 0 deletions tests/pmap_test.py
Expand Up @@ -1037,6 +1037,17 @@ def distributed_matrix_vector(x, y):
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)

def testAxisIndexRemat(self):
# https://github.com/google/jax/issues/2716
n = len(jax.devices())

def f(key):
key = random.fold_in(key, jax.lax.axis_index('i'))
return random.bernoulli(key, p=0.5)

keys = random.split(random.PRNGKey(0), n)
jax.pmap(jax.remat(f), axis_name='i')(keys)


class PmapWithDevicesTest(jtu.JaxTestCase):

Expand Down

0 comments on commit d2653a1

Please sign in to comment.