omnistaging wip
mattjj committed May 29, 2020
1 parent 853bcda commit 7128a12
Showing 30 changed files with 1,325 additions and 1,368 deletions.
1 change: 0 additions & 1 deletion docs/jax.lax.rst
@@ -123,7 +123,6 @@ Operators
square
sub
tan
tie_in
top_k
transpose

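Context for the doc change above: the tie_in entry is dropped from the lax operator list because omnistaging makes the trick unnecessary — constants built under jit are staged into the traced computation without being tied to an argument. A minimal sketch (not part of this commit, assuming the pre-omnistaging lax.tie_in API):

import jax
import jax.numpy as jnp

# Pre-omnistaging, a constant built inside jit had to be tied to an argument
# to keep it from being evaluated eagerly outside the staged computation:
#     big = lax.tie_in(x, jnp.ones((1000, 1000)))
# Under omnistaging the constant is staged out automatically:

@jax.jit
def f(x):
    big = jnp.ones((1000, 1000))   # traced into the jaxpr, no tie_in needed
    return x * jnp.sum(big)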
8 changes: 6 additions & 2 deletions jax/ad_util.py
@@ -13,7 +13,8 @@
# limitations under the License.


from .core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
from jax import core
from .core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
valid_jaxtype)
from .tree_util import register_pytree_node
from typing import Any, Dict
@@ -27,7 +28,10 @@
jaxval_adders[Unit] = lambda _, __: unit

def add_jaxvals(x, y):
return add_jaxvals_p.bind(x, y)
if core.get_aval(x) is core.get_aval(y) is core.abstract_unit:
return core.unit
else:
return add_jaxvals_p.bind(x, y)

add_jaxvals_p = Primitive('add_any')

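The new guard keeps add_jaxvals from binding the add_any primitive on trivial cotangents: when both operands are the unit value (their abstract value is core.abstract_unit), it returns core.unit directly rather than staging an addition. A rough usage sketch, assuming the jax internals of this era (core.unit, jax.ad_util), not code from the commit:

import jax.numpy as jnp
from jax import core
from jax.ad_util import add_jaxvals

a = jnp.ones(3)
b = 2.0 * jnp.ones(3)
print(add_jaxvals(a, b))                   # binds add_any: [3. 3. 3.]
print(add_jaxvals(core.unit, core.unit))   # short-circuits to core.unit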
69 changes: 18 additions & 51 deletions jax/api.py
@@ -208,8 +208,7 @@ def xla_computation(fun: Callable,
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
backend: Optional[str] = None,
tuple_args: bool = False,
instantiate_const_outputs: bool = True) -> Callable:
tuple_args: bool = False) -> Callable:
"""Creates a function that produces its XLA computation given example args.
Args:
@@ -226,13 +225,6 @@ def xla_computation(fun: Callable,
tuple_args: Optional bool, defaults to False. If True, the resulting XLA
computation will have a single tuple argument that is unpacked into the
specified function arguments.
instantiate_const_outputs: Optional bool, defaults to True. If False, then
``xla_computation`` does not instantiate constant-valued outputs in the
XLA computation, and so the result is closer to the computation that
``jax.jit`` produces and may be more useful for studying ``jit`` behavior.
If True, then constant-valued outputs are instantiated in the XLA
computation, which may be more useful for staging computations out of JAX
entirely.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns a
@@ -301,11 +293,11 @@ def xla_computation(fun: Callable,

def make_axis_env(nreps):
if axis_env is None:
return xla.AxisEnv(nreps)
return xla.AxisEnv(nreps, (), (), None)
else:
nreps = nreps * prod(size for name, size in axis_env)
names, sizes = zip(*axis_env)
return xla.AxisEnv(nreps, names, sizes)
return xla.AxisEnv(nreps, names, sizes, None)

def abstractify(x):
return ShapedArray(onp.shape(x), dtypes.result_type(x))
@@ -321,10 +313,7 @@ def computation_maker(*args, **kwargs):
jax_args, in_tree = tree_flatten((args, kwargs))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
avals = map(abstractify, jax_args)
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals,
instantiate=instantiate_const_outputs,
stage_out=True)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr, _ = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
@@ -1026,6 +1015,9 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
else:
static_broadcasted_tuple = tuple(static_broadcasted_argnums)

if not all(axis == 0 for axis in tree_leaves(in_axes)):
raise ValueError("pmap only supports in_axes leaves of 0 or None")

# axis_size is an optional integer representing the global axis size.
# The aggregate size (across all hosts) size of the mapped axis must match
# the given value. This argument is mutually exclusive with ``devices``.
@@ -1054,8 +1046,6 @@ def f_pmapped(*args, **kwargs):
dyn_args, dyn_in_axes = args, in_axes
args, in_tree = tree_flatten((dyn_args, kwargs))
in_axes_flat = _flatten_axes(in_tree, (dyn_in_axes, 0))
assert all(axis in (0, None) for axis in in_axes_flat), \
"pmap currently only supports mapping over the leading axis"
local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap")
_check_args(args)
flat_fun, out_tree = flatten_fun(f, in_tree)
@@ -1086,42 +1076,26 @@ def __eq__(self, other):
return self.obj is other.obj


def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, *,
in_axes=0, backend: Optional[str] = None) -> Callable:
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
) -> Callable:
warn("soft_pmap is an experimental feature and probably has bugs!")
_check_callable(fun)
axis_name = _TempAxisName(fun) if axis_name is None else axis_name

if not all(axis == 0 for axis in tree_leaves(in_axes)):
raise ValueError("soft_pmap only supports in_axes leaves of 0 or None")

@wraps(fun)
def f_pmapped(*args, **kwargs):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten((args, kwargs))
in_axes_flat = _flatten_axes(in_tree, (in_axes, 0))
assert all(axis in (0, None) for axis in in_axes_flat), \
"soft_pmap currently only supports mapping over the leading axis"
mapped_invars = tuple(axis is not None for axis in in_axes_flat)
axis_size = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "soft_pmap")
_check_args(args_flat)
flat_fun, out_tree = flatten_fun(f, in_tree)

chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count(backend))
if chunk_size == 0 and leftover:
return pmap(fun, axis_name, backend=backend)(*args) # can map directly onto hardware
elif leftover:
msg = ("soft_pmap mapped axis size must be divisible by the number of "
"XLA devices (or be less than or equal to that number), but got "
"an axis size of {} with {} devices.")
raise ValueError(msg.format(axis_size, pxla.unmapped_device_count()))
num_chunks = axis_size // chunk_size

reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]
soft_mapped_fun = pxla.split_axis(flat_fun, axis_name, chunk_size)
reshaped_outs = pxla.xla_pmap(soft_mapped_fun, *reshaped_args, backend=backend,
axis_name=axis_name, axis_size=num_chunks,
global_axis_size=None, devices=None,
name=soft_mapped_fun.__name__,
mapped_invars=mapped_invars)
outs = [_reshape_merge(out) for out in reshaped_outs]
outs = pxla.soft_pmap(flat_fun, *args_flat, axis_name=axis_name,
axis_size=axis_size, mapped_invars=mapped_invars)
return tree_unflatten(out_tree(), outs)

namestr = "soft_pmap({}, axis_name={})".format
@@ -1535,10 +1509,6 @@ def make_jaxpr(fun: Callable,
if isinstance(static_argnums, int):
static_argnums = (static_argnums,)

def pv_like(x):
aval = xla.abstractify(x)
return pe.PartialVal.unknown(aval)

@wraps(fun)
def jaxpr_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun)
@@ -1549,11 +1519,9 @@ def jaxpr_maker(*args, **kwargs):
dyn_args = args
jax_args, in_tree = tree_flatten((args, kwargs))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
in_pvals = map(pv_like, jax_args)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, in_pvals, instantiate=True, stage_out=True)
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
in_avals = tuple(raise_to_shaped(in_aval) for in_aval, _ in in_pvals)
in_avals = map(xla.abstractify, jax_args)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
in_avals = tuple(raise_to_shaped(in_aval) for in_aval in in_avals)
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
return typed_jaxpr

@@ -1736,8 +1704,7 @@ def __call__(self, *args):
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_pvals = [pe.PartialVal.unknown(raise_to_shaped(core.get_aval(x)))
for x in args_flat]
with core.initial_style_staging():
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
in_tree=in_tree, out_tree=out_tree(),
num_consts=len(consts))
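The through-line of the api.py changes above: xla_computation and make_jaxpr no longer build PartialVals and call pe.trace_to_jaxpr(..., stage_out=True); they abstract the arguments to ShapedArrays and trace with the new pe.trace_to_jaxpr_dynamic entry point. A rough sketch of that pattern (internal names taken from the diff and assumed to exist as of this commit; not code from the commit):

import numpy as onp
from jax import core, dtypes, linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten
from jax.abstract_arrays import ShapedArray
from jax.interpreters import partial_eval as pe

def sketch_make_jaxpr(fun, *example_args):
    # Flatten pytree arguments and wrap fun for the tracing machinery.
    args_flat, in_tree = tree_flatten(example_args)
    flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    # Keep only shape/dtype information; no PartialVals involved.
    in_avals = [ShapedArray(onp.shape(x), dtypes.result_type(x)) for x in args_flat]
    # Trace directly from avals to a jaxpr.
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
    return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)

# e.g. sketch_make_jaxpr(lambda x, y: x * y + 1., onp.ones(3), onp.ones(3))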
