
Commit

Merge pull request #439 from google/pjit
fix nested pjit transpose bug
mattjj committed Feb 24, 2019
2 parents 12ed52e + dd75c56 commit daf3e3f
Showing 4 changed files with 115 additions and 22 deletions.
41 changes: 33 additions & 8 deletions jax/interpreters/ad.py
@@ -16,13 +16,15 @@
from __future__ import division
from __future__ import print_function

import itertools as it

from . import partial_eval as pe
from .. import core as core
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, zero, Zero)
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
from ..tree_util import process_pytree, build_tree, register_pytree_node, prune
from ..tree_util import process_pytree, build_tree, register_pytree_node, prune, tree_map
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init

from six.moves import builtins, reduce
@@ -138,11 +140,10 @@ def read_cotangent(v):

    if cts_out is zero:
      cts_out = [zero for _ in eqn.invars]
    map(write_cotangent, eqn.invars, cts_out)

    for var, ct in zip(eqn.invars, cts_out):
      write_cotangent(var, ct)

  cotangents_out = map(read_cotangent, jaxpr.invars)
  cotangents_out = [read_cotangent(var) if argval is None else None
                    for var, argval in zip(jaxpr.invars, args)]
  freevar_cts = map(read_cotangent, jaxpr.freevars)
  return freevar_cts, cotangents_out
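A minimal standalone sketch of the bookkeeping this hunk changes (hypothetical names, not the JAX internals above): cotangents are accumulated per variable while the equations are walked in reverse, and a cotangent is returned for an input position only when no primal value was supplied for it (argval is None), i.e. only for the linear inputs.

import operator

def toy_backward_pass(eqns, invars, args, ct_env):
  for eqn in reversed(eqns):
    cts_in = [ct_env.get(v, 0.0) for v in eqn['outvars']]
    cts_out = eqn['transpose'](cts_in)           # per-primitive transpose rule
    for var, ct in zip(eqn['invars'], cts_out):
      ct_env[var] = ct_env.get(var, 0.0) + ct    # accumulate rather than overwrite
  return [ct_env.get(v) if arg is None else None
          for v, arg in zip(invars, args)]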

@@ -392,12 +393,36 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
  ans = primitive.bind(fun, all_args, **params)
  return build_tree(out_tree_def(), ans)

@transformation_with_aux
def transposed_mapped(jaxpr, in_tree_def, freevar_vals, args):
  args, consts, ct = args
  args, ct = build_tree(in_tree_def, (args, ct))
  freevar_cts, cotangents_out = yield jaxpr, consts, freevar_vals, args, ct
  out_jtuple, tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
  yield out_jtuple, tree_def

def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
  jaxpr, = jaxpr
  consts, = consts
  freevar_vals, = freevar_vals
  (args, ct), in_tree_def = tree_to_jaxtuples((args, ct))
  fun = wrap_init(backward_pass)
  fun, out_tree_def = transposed_mapped(fun, jaxpr, in_tree_def, tuple(freevar_vals))
  all_args = pack((pack(args), pack(consts), ct))
  ans = primitive.bind(fun, all_args, **params)
  cts_out, freevar_cts = build_tree(out_tree_def(), ans)
  freevar_cts = tree_map(_sum_leading_axis, freevar_cts)
  return cts_out, freevar_cts

def _sum_leading_axis(x):
  try:
    return x.sum(0)
  except AttributeError:
    return onp.sum(x, 0)


primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.compiled_call_p] = partial(call_transpose, pe.compiled_call_p)


tree_to_jaxtuples = partial(process_pytree, pack)


call_primitive_jvp_params = {}
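The new map_transpose and _sum_leading_axis above encode the transpose of broadcasting: in the forward mapped computation a closed-over (free) value is shared by every instance along the mapped axis, so under transposition its cotangent must be summed over that axis. A small numpy illustration of that duality (a sketch of the underlying fact, not of the JAX code path):

import numpy as onp

w = onp.float32(2.)                      # closed-over value, shared across mapped instances
x = onp.arange(4., dtype=onp.float32)    # mapped axis of size 4
y = w * x                                # forward: w is broadcast along the mapped axis
ct_y = onp.ones_like(y)                  # incoming cotangent for y
ct_w = (ct_y * x).sum(0)                 # backward: sum the cotangent over the leading axis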
47 changes: 35 additions & 12 deletions jax/interpreters/pxla.py
@@ -143,36 +143,58 @@ def axis_env_extend(name, size):
    else:
      ans = translation_rule(eqn.primitive)(c, *in_nodes, **eqn.params)

    out_nodes = xla_destructure(c, ans) if eqn.destructure else [ans]
    try:
      out_nodes = xla_destructure(c, ans) if eqn.destructure else [ans]
    except:
      import ipdb; ipdb.set_trace()
    map(write, eqn.outvars, out_nodes)
  return c.Build(read(jaxpr.outvar))

def xla_split(c, axis_sizes, x):
  def _xla_split(shape, x):
    if shape.is_tuple():
      elts = map(_xla_split, shape.tuple_shapes(), xla_destructure(c, x))
      return c.Tuple(*elts)
  def split_array(shape, x):
    if xb.get_replica_count() == 1:
      # TODO(mattjj): remove this special case, used for debugging on CPU
      # because CPU doesn't have some collectives implemented
      dims = c.GetShape(x).dimensions()
      return c.Reshape(x, None, dims[1:])
    else:
      size = onp.array(prod(axis_sizes), onp.uint32)
      idx = c.Rem(c.ReplicaId(), c.Constant(size))
      dims = list(shape.dimensions())
      zero = onp.zeros(len(dims) - 1, onp.uint32)
      start_indices = c.Concatenate([c.Reshape(idx, None, [1]),
                                     c.Constant(zero)], 0)
      start_indices = c.Concatenate([c.Reshape(idx, None, [1]), c.Constant(zero)], 0)
      return c.Reshape(c.DynamicSlice(x, start_indices, [1] + dims[1:]),
                       None, dims[1:])

  def _xla_split(shape, x):
    if shape.is_tuple():
      elts = map(_xla_split, shape.tuple_shapes(), xla_destructure(c, x))
      return c.Tuple(*elts)
    else:
      return split_array(shape, x)

  return _xla_split(c.GetShape(x), x)

# TODO(b/110096942): more efficient gather
def xla_join(c, device_groups, x):
  def join_arrays(x):
    # TODO(mattjj): remove this special case, used for debugging on CPU
    # because CPU doesn't have some collectives implemented
    if xb.get_replica_count() == 1:
      dims = c.GetShape(x).dimensions()
      return c.Reshape(x, None, (1,) + tuple(dims))
    else:
      group_size = len(device_groups[0])
      broadcasted = c.Broadcast(x, (group_size,))
      return c.AllToAll(broadcasted, 0, 0, device_groups)

  def _xla_join(shape, x):
    if shape.is_tuple():
      elts = map(_xla_join, shape.tuple_shapes(), xla_destructure(c, x))
      return c.Tuple(*elts)
    else:
      group_size = len(device_groups[0])
      broadcasted = c.Broadcast(x, (group_size,))
      return c.AllToAll(broadcasted, 0, 0, device_groups)
      return join_arrays(x)

  return _xla_join(c.GetShape(x), x)
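Both xla_split and xla_join special-case a replica count of one (the CPU-debugging path called out in the TODOs): splitting then just drops the leading axis with a Reshape, and joining re-adds it, instead of using DynamicSlice and AllToAll. A numpy analogue of that special case (a sketch only, not the XLA builder calls):

import numpy as onp

x = onp.arange(6.).reshape(1, 2, 3)           # leading axis plays the replica/mapped role
split = x.reshape(x.shape[1:])                # what xla_split does when there is one replica
joined = split.reshape((1,) + split.shape)    # what xla_join does in the same case
assert (joined == x).all()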


@@ -223,9 +245,10 @@ def execute_replicated(compiled, pval, axis_size, out_tree, *args):
xla_pcall = partial(core.call_bind, xla_pcall_p)
xla_pcall_p.def_custom_bind(xla_pcall)
xla_pcall_p.def_impl(xla_pcall_impl)
ad.primitive_transposes[xla_pcall_p] = partial(ad.call_transpose, xla_pcall_p)
# xla.translations[xla_pcall_p] = xla.xla_call_translation_rule # TODO(mattjj)
ad.primitive_transposes[xla_pcall_p] = partial(ad.map_transpose, xla_pcall_p)
pe.map_primitives.add(xla_pcall_p)
# TODO(mattjj): enable pjit inside jit
# xla.translations[xla_pcall_p] = xla.xla_call_translation_rule


parallel_translation_rules = {}
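The registration lines above switch xla_pcall_p's transpose rule from ad.call_transpose to ad.map_transpose and add the primitive to pe.map_primitives. A hedged usage sketch of the kind of program this targets, reverse-mode differentiation through a pjit-mapped call, assuming this commit's jax.api.pjit and the single-replica setup used in the test below (f and x are illustrative names):

from functools import partial
import numpy as onp
import jax.numpy as np
from jax.api import pjit, grad

@partial(pjit, axis_name='i')
def f(x):
  return 3. * np.sin(x)

x = onp.arange(3., dtype=onp.float32).reshape(1, 3)   # leading axis sized to the replica count
g = grad(lambda x: np.sum(f(x)))(x)                    # transposes the mapped call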
13 changes: 13 additions & 0 deletions jax/lax.py
@@ -773,6 +773,10 @@ def __hash__(self):
def tie_in(x, y):
  return tie_in_p.bind(x, y)

def shaped_identity(x):
  return shaped_identity_p.bind(x, shape=x.shape)


def full(shape, fill_value, dtype):
  try:
    shape = tuple(map(int, shape))
@@ -3311,6 +3315,15 @@ def _tie_in_batch_rule(batched_args, batch_dims):
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule


shaped_identity_p = Primitive('shape_id')
shaped_identity_p.def_impl(lambda x, shape: x)
shaped_identity_p.def_abstract_eval(lambda x, shape: x)
xla.translations[shaped_identity_p] = lambda c, x, shape: x
ad.deflinear(shaped_identity_p, lambda t, shape: [shaped_identity(t)])
batching.primitive_batchers[shaped_identity_p] = \
    lambda a, d, shape: (shaped_identity(a[0]), d[0])
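shaped_identity is a new identity primitive ('shape_id') that carries its operand's shape as a static parameter, with trivial impl and abstract-eval rules; because the identity map is linear, ad.deflinear gives it a transpose that is again shaped_identity, and the batching rule passes the batch dimension straight through. A brief usage sketch, assuming this commit's API (x and y are illustrative names):

import jax.numpy as np
from jax import lax
from jax.api import make_jaxpr

x = np.ones((2, 3))
y = lax.shaped_identity(x)                   # same values as x; the primitive records shape=(2, 3)
print(make_jaxpr(lax.shaped_identity)(x))    # the jaxpr shows shape_id with its shape parameter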


### constants


36 changes: 34 additions & 2 deletions tests/pjit_test.py
@@ -16,21 +16,53 @@
from __future__ import division
from __future__ import print_function

from functools import partial

import numpy as onp
from absl.testing import absltest
from absl.testing import parameterized

import jax.numpy as np
from jax import test_util as jtu
from jax.api import pjit, pmap, jvp, grad
from jax import lax
from jax.api import pjit, pmap, vmap, jvp, grad, make_jaxpr, linearize
from jax.lax import psum
from jax.lib import xla_bridge

from jax.config import config
config.parse_flags_with_absl()


class PjitTest(jtu.JaxTestCase):
  pass # TODO(mattjj)

  @jtu.skip_on_devices("gpu", "tpu")
  def testNestedWithClosure(self):
    assert xla_bridge.get_replica_count() == 1 # OSS CPU testing only
    x = onp.arange(3, dtype=onp.float32).reshape(1, 1, 3)

    @partial(pjit, axis_name='i')
    def test_fun(x):
      y = np.sum(np.sin(x))

      @partial(pjit, axis_name='j')
      def g(z):
        return 3. * np.exp(np.sin(x).sum() * np.cos(y) * np.tan(z))

      return grad(lambda w: np.sum(g(w)))(x)

    @vmap
    def baseline_fun(x):
      y = np.sum(np.sin(x))

      @vmap
      def g(z):
        return 3. * np.exp(np.sin(x).sum() * np.cos(y) * np.tan(z))

      return grad(lambda w: np.sum(g(w)))(x)

    ans = grad(lambda x: np.sum(test_fun(x)))(x)
    expected = grad(lambda x: np.sum(baseline_fun(x)))(x)
    self.assertAllClose(ans, expected, check_dtypes=True)
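The new testNestedWithClosure builds the same computation twice, once with nested pjit and once with nested vmap as a pure-function baseline, and checks that the reverse-mode gradients agree. The inner g closes over both x and the intermediate y, so closures like these exercise the free-variable cotangent handling added in ad.map_transpose, while the nesting covers the nested-transpose case named in the commit title. A vmap-only reduction of the same pattern, runnable without the replica-count constraint (a sketch that closely follows baseline_fun above; fun and g are illustrative names):

import numpy as onp
import jax.numpy as np
from jax.api import vmap, grad

@vmap
def fun(x):
  y = np.sum(np.sin(x))

  @vmap
  def g(z):
    return np.exp(np.cos(y) * np.tan(z))     # closes over the intermediate y

  return grad(lambda w: np.sum(g(w)))(x)

x = onp.arange(3., dtype=onp.float32).reshape(1, 3)
ans = grad(lambda x: np.sum(fun(x)))(x)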


if __name__ == '__main__':
