Skip to content

Commit

Permalink
Merge pull request #431 from google/pjit
Browse files Browse the repository at this point in the history
update and simplify pjit
  • Loading branch information
mattjj committed Feb 23, 2019
2 parents aca4941 + eacf065 commit b1686a3
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 500 deletions.
3 changes: 1 addition & 2 deletions jax/ad_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from __future__ import division
from __future__ import print_function

from .core import JaxTuple, lattice_join
from .interpreters.partial_eval import Primitive
from .core import JaxTuple, lattice_join, Primitive
from .tree_util import register_pytree_node
from .util import safe_map

Expand Down
18 changes: 11 additions & 7 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import numpy as onp
from contextlib import contextmanager
from distutils.util import strtobool
from six.moves import reduce

from . import core
from . import linear_util as lu
Expand Down Expand Up @@ -405,20 +406,23 @@ def batched_fun(*args, **kwargs):
return batched_fun


def pjit(fun, axis_name, in_axes=0, out_axes=0, mesh_axis=0):
def pjit(fun, axis_name):
"""Set up SPMD function for JIT compilation and parallel execution with XLA."""
@wraps(fun)
def f_jitted(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
leaves, _ = tree_flatten(args)
axis_sizes = set(onp.shape(leaf)[0] for leaf in leaves)
if len(axis_sizes) != 1:
msg = "pjit requires all leading axes to have equal length, got {}."
raise TypeError(msg.format(axis_sizes))
axis_size = axis_sizes.pop()

jaxtupletree_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
_check_args(jaxtupletree_args)
f = lu.wrap_init(fun, kwargs)
f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
chunksize = pxla.chunk_size(axis_name, mesh_axis, in_axes_, jaxtupletree_args)
f = pxla.chunk_transform(f, chunksize, axis_name, in_axes_, out_axes)
jaxtupletree_out = pxla.xla_pcall(f, *jaxtupletree_args,
axis_name=axis_name, in_axes=in_axes_,
out_axes=out_axes, mesh_axis=mesh_axis)
axis_name=axis_name, axis_size=axis_size)
return build_tree(out_tree(), jaxtupletree_out)

f_jitted.__name__ = "pjit({})".format(f_jitted.__name__)
Expand Down
8 changes: 8 additions & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def def_impl(self, impl):
self.impl = impl
return impl

def def_abstract_eval(self, abstract_eval):
self.abstract_eval = abstract_eval
return abstract_eval

def def_custom_bind(self, bind):
self.bind = bind
return bind
Expand All @@ -90,6 +94,10 @@ def impl(self, *args, **kwargs):
raise NotImplementedError("Evaluation rule for '{}' not implemented"
.format(self.name))

def abstract_eval(self, *args, **kwargs):
raise NotImplementedError("Abstract evaluation for '{}' not implemented"
.format(self.name))


# -------------------- lifting --------------------

Expand Down
51 changes: 21 additions & 30 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from __future__ import print_function

from . import partial_eval as pe
from . import xla
from . import pxla
from .. import core as core
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
Expand All @@ -31,6 +29,8 @@

zip = safe_zip
map = safe_map
def identity(x): return x


def jvp(fun):
return jvpfun(jvp_subtrace(fun))
Expand Down Expand Up @@ -75,7 +75,8 @@ def vjp(traceable, primals):
def vjp_(ct):
ct = ignore_consts(ct, pval)
dummy_primal_and_ct = pack((core.unit, ct))
_, arg_cts = backward_pass(jaxpr, consts, (), dummy_primal_and_ct)
dummy_args = (None,) * len(jaxpr.invars)
_, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primal_and_ct)
return instantiate_zeros(pack(primals), arg_cts[1])

return out_primal, vjp_
Expand All @@ -100,7 +101,7 @@ def unpair_pval(pval):
aval_1, aval_2 = aval
return (aval_1, const_1), (aval_2, const_2)

def backward_pass(jaxpr, consts, freevar_vals, cotangent_in):
def backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in):
def write_cotangent(v, ct):
# assert v not in primal_env
if ct is not None:
Expand All @@ -109,10 +110,11 @@ def write_cotangent(v, ct):
def read_cotangent(v):
return ct_env.get(v, zero)

primal_env = {v: val
for v, val in zip(jaxpr.freevars, freevar_vals)
primal_env = {v: val for v, val in zip(jaxpr.freevars, freevar_vals)
if val is not None}
primal_env.update(zip(jaxpr.constvars, consts))
primal_env.update((v, val) for v, val in zip(jaxpr.invars, args)
if val is not None)
ct_env = {jaxpr.outvar: cotangent_in}

for eqn in jaxpr.eqns[::-1]:
Expand Down Expand Up @@ -185,12 +187,7 @@ def process_call(self, call_primitive, f, tracers, params):
tangents = [t.tangent for t in tracers]
nonzero_tangents, in_tree_def = tree_to_jaxtuples(tangents)
f, out_tree_def = traceable(jvp_subtrace(f, self.master), in_tree_def)
if call_primitive is pxla.xla_pcall_p:
in_ax, out_ax = params['in_axes'], params['out_axes']
new_params = dict(params, in_axes=(in_ax, in_ax), out_axes=(out_ax, out_ax))
result = call_primitive.bind(f, pack(primals), nonzero_tangents, **new_params)
else:
result = call_primitive.bind(f, pack(primals), nonzero_tangents, **params)
result = call_primitive.bind(f, pack(primals), nonzero_tangents, **params)
primal_out, tangent_out = build_tree(out_tree_def(), result)
return JVPTracer(self, primal_out, tangent_out)

Expand Down Expand Up @@ -299,8 +296,8 @@ def defjvp(primitive, *jvprules):

def standard_jvp(jvprules, primitive, primals, tangents, **params):
val_out = primitive.bind(*primals, **params)
tangents_out = (rule(t, *primals, **params) for rule, t in zip(jvprules, tangents)
if rule is not None and t is not zero)
tangents_out = [rule(t, *primals, **params) for rule, t in zip(jvprules, tangents)
if rule is not None and t is not zero]
return val_out, reduce(add_tangents, tangents_out, zero)


Expand Down Expand Up @@ -376,9 +373,9 @@ def traceable(in_tree_def, new_primals, new_tangents):

@transformation_with_aux
def transposed_fun(jaxpr, in_tree_def, args):
consts, freevar_vals, ct = args
ct, freevar_vals = build_tree(in_tree_def, (ct, freevar_vals))
freevar_cts, cotangents_out = yield jaxpr, consts, freevar_vals, ct
args, consts, freevar_vals, ct = args
args, ct, freevar_vals = build_tree(in_tree_def, (args, ct, freevar_vals))
freevar_cts, cotangents_out = yield jaxpr, consts, freevar_vals, args, ct
out_jtuple, tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
yield out_jtuple, tree_def

Expand All @@ -387,26 +384,20 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
consts, = consts
freevar_vals, = freevar_vals
assert isinstance(jaxpr, core.Jaxpr)
assert all(a is None for a in args), "TODO(dougalm): handle non-tangent primal args"
(ct, freevar_vals), in_tree_def = tree_to_jaxtuples((ct, freevar_vals))
(args, ct, freevar_vals), in_tree_def = tree_to_jaxtuples((args, ct, freevar_vals))
fun = wrap_init(backward_pass)
fun, out_tree_def = transposed_fun(fun, jaxpr, in_tree_def)
all_args = pack((pack(consts), pack(freevar_vals), ct))
all_args = pack((pack(args), pack(consts), pack(freevar_vals), ct))
# TODO(dougalm): consider signalling to bind that no traces in fun closure
if primitive is pxla.xla_pcall_p:
in_axes, out_axes = params['in_axes'], params['out_axes']
trans_in_axes = (None, None, out_axes),
trans_out_axes = (in_axes, None)
new_params = dict(params, in_axes=trans_in_axes, out_axes=trans_out_axes)
ans = primitive.bind(fun, all_args, **new_params)
else:
ans = primitive.bind(fun, all_args, **params)
ans = primitive.bind(fun, all_args, **params)
return build_tree(out_tree_def(), ans)


primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.compiled_call_p] = partial(call_transpose, pe.compiled_call_p)
primitive_transposes[xla.xla_call_p] = partial(call_transpose, xla.xla_call_p)
primitive_transposes[pxla.xla_pcall_p] = partial(call_transpose, pxla.xla_pcall_p)


tree_to_jaxtuples = partial(process_pytree, pack)


call_primitive_jvp_params = {}
123 changes: 0 additions & 123 deletions jax/interpreters/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,129 +164,6 @@ def pack(self, tracers):
pmap_primitive_rules = {}


### axis variable splitting and computation chunking


@lu.transformation
def axisvar_split(name, new_names, *args):
with new_master(SplitTrace) as master:
trace = SplitTrace(master, core.cur_sublevel())
in_tracers = map(partial(SplitTracer, trace, name, new_names), args)
ans = yield in_tracers
out_tracer = trace.full_raise(ans)
out_val = out_tracer.val
del master, out_tracer
yield out_val

@lu.transformation
def axisvar_split_subtrace(master, name, new_names, *vals):
trace = SplitTrace(master, core.cur_sublevel())
ans = yield map(partial(SplitTracer, trace, name, new_names), vals)
out_tracer = trace.full_raise(ans)
out_val = out_tracer.val
yield out_val

class SplitTracer(Tracer):
def __init__(self, trace, name, new_names, val):
self.trace = trace
self.name = name
self.new_names = new_names
self.val = val

@property
def aval(self):
return core.get_aval(self.val)

def unpack(self):
if self.name is None:
return self.full_lower()
else:
elt_tracer = partial(SplitTracer, self.trace, self.name, self.new_names)
return map(elt_tracer, self.val)

def full_lower(self):
if self.name is None:
return core.full_lower(self.val)
else:
return self

class SplitTrace(Trace):
def pure(self, val):
return SplitTracer(self, None, (), val)

def lift(self, val):
return SplitTracer(self, None, (), val)

def sublift(self, val):
return SplitTracer(self, val.name, val.new_names, val.val)

def process_primitive(self, primitive, tracers, params):
names_in, vals_in = unzip2((t.name, t.val) for t in tracers)
if all(name is None for name in names_in):
return primitive.bind(*vals_in, **params)
else:
name = next(name for name in names_in if name is not None)
new_names = next(t.new_names for t in tracers if t.name is not None)
if primitive in pmap_primitive_rules:
val_in, = vals_in
if name == params['axis_name']:
new_params = {k: params[k] for k in params if k != 'axis_name'}
val = val_in
for new_name in new_names:
val = primitive.bind(val, axis_name=new_name, **new_params)
val_out = val
return SplitTracer(self, name, new_names, val_out)
else:
val_out = primitive.bind(val_in, **params)
return SplitTracer(self, name, new_names, val_out)
else:
val_out = primitive.bind(*vals_in, **params)
return SplitTracer(self, name, new_names, val_out)

def process_call(self, call_primitive, f, tracers, params):
names_in, vals_in = unzip2((t.name, t.val) for t in tracers)
if all(name is None for name in names_in):
return call_primitive.bind(f, *vals, **params)
else:
name = next(name for name in names_in if name is not None)
new_names = next(t.new_names for t in tracers if t.name is not None)
f = axisvar_split_subtrace(f, self.master, name, new_names)
val_out = call_primitive.bind(f, *vals_in, **params)
return SplitTracer(self, name, new_names, val_out)

def post_process_call(self, _, out_tracer):
name, new_names, val = out_tracer.name, out_tracer.new_names, out_tracer.val
master = self.master
def todo(x):
trace = SplitTrace(master, core.cur_sublevel())
return SplitTracer(trace, name, new_names, x)

return val, todo

def pack(self, tracers):
vals = core.pack([t.val for t in tracers])
name = next(t.name for t in tracers if t.name is not None)
new_names = next(t.new_names for t in tracers if t.name is not None)
return SplitTracer(self, name, new_names, vals)

def reshape_axis(chunksize, in_axis, arg):
aval = core.get_aval(arg)
if type(aval) is core.AbstractTuple:
if type(in_axis) is int:
return core.pack(map(partial(reshape_axis, chunksize, in_axis), arg))
elif isinstance(in_axis, (list, tuple)):
return core.pack(map(partial(reshape_axis, chunksize), in_axis, arg))
else:
raise TypeError("unexpected in_axis type: {}".format(type(in_axis)))
elif isinstance(aval, ShapedArray):
in_axis = in_axis % arg.ndim
split_shape = (arg.shape[in_axis] // chunksize, chunksize)
new_shape = arg.shape[:in_axis] + split_shape + arg.shape[in_axis+1:]
return arg.reshape(new_shape)
else:
raise TypeError(type(arg))


### papply


Expand Down

0 comments on commit b1686a3

Please sign in to comment.