Commit 14f1a34
roll back breakage
PiperOrigin-RevId: 472949225
jax authors committed Sep 8, 2022
1 parent 672c497 commit 14f1a34
Showing 14 changed files with 53 additions and 198 deletions.
2 changes: 0 additions & 2 deletions jax/_src/ad_checkpoint.py
@@ -439,8 +439,6 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
 ad.primitive_jvps[remat_p] = remat_jvp
 
 remat_allowed_effects: Set[core.Effect] = set()
-remat_allowed_effects.add(lax.lax.InOutFeedEffect.Infeed)
-remat_allowed_effects.add(lax.lax.InOutFeedEffect.Outfeed)
 
 def remat_partial_eval(trace, *tracers, jaxpr, **params):
   assert not jaxpr.constvars
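The hunk above empties `remat_allowed_effects`, the allowlist that gates which side effects `jax.checkpoint`/remat will accept in a jaxpr; after the rollback, effectful primitives under remat are rejected again. A minimal sketch of the allowlist pattern, using toy string effects rather than JAX's real effect types and checker:

```python
# Toy sketch of the effect-allowlist pattern (stand-ins, not JAX's
# real effect objects or its actual check).
from typing import Set

remat_allowed_effects: Set[str] = set()  # empty again after the rollback

def check_remat_effects(jaxpr_effects: Set[str]) -> None:
    disallowed = jaxpr_effects - remat_allowed_effects
    if disallowed:
        raise NotImplementedError(
            f"effects not supported under remat: {sorted(disallowed)}")

check_remat_effects(set())         # a pure jaxpr passes
# check_remat_effects({"infeed"})  # raises until "infeed" is allowlisted
```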
5 changes: 0 additions & 5 deletions jax/_src/custom_derivatives.py
@@ -26,7 +26,6 @@
     register_pytree_node_class, tree_leaves)
 from jax._src import custom_api_util
 from jax._src import dtypes
-from jax._src.lax import lax
 from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
 from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
 from jax.core import raise_to_shaped
@@ -340,10 +339,6 @@ def _apply_todos(todos, outs):
 
 
 allowed_effects: Set[core.Effect] = set()
-allowed_effects.add(lax.InOutFeedEffect.Infeed)
-allowed_effects.add(lax.InOutFeedEffect.Outfeed)
-
-
 custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
 
 def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts):
36 changes: 17 additions & 19 deletions jax/_src/dispatch.py
@@ -316,25 +316,22 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
         "programming model, please read "
         "https://jax.readthedocs.io/en/latest/multi_process.html.")
 
-  if not in_shardings:
-    inp_device_assignment = da
-  else:
-    inp_device_assignment = None
-
   # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
   # the number of output avals at this stage. lower_sharding_computation will
   # apply it to all out_avals.
   return pxla.lower_sharding_computation(
       fun, 'jit', name, in_shardings, pjit._UNSPECIFIED,
       donated_invars, in_avals,
       in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
-      committed=committed, inp_device_assignment=inp_device_assignment).compile(
+      committed=committed).compile(
           _allow_propagation_to_outputs=True).unsafe_call
 
 
 def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                            donated_invars, keep_unused, *arg_specs):
-  if config.jax_array:
+  # TODO(yashkatariya): Remove the `and arg_specs` from here once
+  # lower_sharding_computation supports no avals as input.
+  if config.jax_array and arg_specs:
     return sharded_lowering(fun, device, backend, name,
                             donated_invars, keep_unused, *arg_specs)
   else:
@@ -344,13 +341,6 @@ def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
 xla_callable = lu.cache(_xla_callable_uncached)
 
 
-def is_single_device_sharding(sharding) -> bool:
-  from jax.experimental.sharding import PmapSharding
-  # Special case PmapSharding here because PmapSharding maps away an axis
-  # and needs to be handled separately.
-  return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
-
-
 @contextlib.contextmanager
 def log_elapsed_time(fmt: str):
   if _on_exit:
@@ -541,11 +531,19 @@ def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
 
 def _prune_unused_inputs(
     jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
-  used_outputs = [True] * len(jaxpr.outvars)
-  new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
-  kept_const_idx = {i for i, b in enumerate(used_consts) if b}
-  kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
-  return new_jaxpr, kept_const_idx, kept_var_idx
+  used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
+  # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
+  # applications that do not produce used outputs. Must handle side-effecting
+  # primitives and nested jaxpr.
+  used.update(
+      v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var))
+  kept_const_idx, new_constvars = util.unzip2(
+      (i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
+  kept_var_idx, new_invars = util.unzip2(
+      (i, v) for i, v in enumerate(jaxpr.invars) if v in used)
+  new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns,
+                         jaxpr.effects)
+  return new_jaxpr, set(kept_const_idx), set(kept_var_idx)
 
 
 # We can optionally set a Jaxpr rewriter that can be applied just before
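The restored `_prune_unused_inputs` applies a conservative liveness rule: a constant or input is kept if any equation or any output mentions it, even when its consumer is itself dead (hence the TODO about a stronger DCE). A self-contained sketch of that rule on a toy IR; the `Eqn` tuple and variable names here are illustrative, not JAX's:

```python
# Toy version of the conservative pruning rule restored above: keep every
# input referenced by an output or by any equation; equations themselves
# are never removed.
from typing import List, NamedTuple, Set, Tuple

class Eqn(NamedTuple):
    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]

def prune_unused_inputs(invars: List[str], outvars: List[str],
                        eqns: List[Eqn]) -> Tuple[List[str], Set[int]]:
    used = set(outvars)
    used.update(v for eqn in eqns for v in eqn.invars)
    kept_idx = {i for i, v in enumerate(invars) if v in used}
    return [v for i, v in enumerate(invars) if i in kept_idx], kept_idx

# 'b' feeds no equation and no output, so only 'a' survives.
print(prune_unused_inputs(['a', 'b'], ['c'], [Eqn(('a',), ('c',))]))
# -> (['a'], {0})
```

By contrast, the deleted version delegated to `pe.dce_jaxpr_consts`, which also removes dead equations.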
3 changes: 0 additions & 3 deletions jax/_src/lax/control_flow/common.py
@@ -21,7 +21,6 @@
 from jax import linear_util as lu
 from jax.api_util import flatten_fun_nokwargs
 from jax.interpreters import partial_eval as pe
-from jax._src.lax import lax
 from jax._src import ad_util
 from jax._src import util
 from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
@@ -30,8 +29,6 @@
 map, unsafe_map = safe_map, map
 
 allowed_effects: Set[core.Effect] = set()
-allowed_effects.add(lax.InOutFeedEffect.Infeed)
-allowed_effects.add(lax.InOutFeedEffect.Outfeed)
 
 
 def _abstractify(x):
16 changes: 5 additions & 11 deletions jax/_src/lax/lax.py
@@ -30,7 +30,6 @@
 from jax._src import api
 from jax._src import api_util
 from jax._src import device_array
-from jax._src import dispatch
 from jax import linear_util as lu
 from jax._src import dtypes
 from jax import tree_util
@@ -1327,7 +1326,7 @@ def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
   # (so it works in staged-out code as well as 'eager' code). Related to
   # equi-sharding.
   if (config.jax_array and hasattr(x, 'sharding') and
-      not dispatch.is_single_device_sharding(x.sharding)):
+      not isinstance(x.sharding, sharding.SingleDeviceSharding)):
     return array.make_array_from_callback(
         fill_shape, x.sharding, lambda idx: val[idx])  # type: ignore[arg-type]
   return val
@@ -4109,9 +4108,6 @@ def _after_all_lowering(ctx, *operands):
 mlir.register_lowering(after_all_p, _after_all_lowering)
 
 
-InOutFeedEffect = enum.Enum('InOutFeedEffect', ['Infeed', 'Outfeed'])
-
-
 def infeed(token, shape=None, partitions=None):
   """Consumes an infeed value of `shape` from the host. Experimental.
@@ -4137,14 +4133,13 @@ def infeed(token, shape=None, partitions=None):
 def _infeed_abstract_eval(token, *, shapes, partitions):
   if token is not abstract_token:
     raise TypeError("First argument to infeed must be a token")
-  return (*shapes, abstract_token), {InOutFeedEffect.Infeed}
+  return shapes + (abstract_token,)
 
 
 infeed_p = Primitive("infeed")
 infeed_p.multiple_results = True
 infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
-infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
-mlir.lowerable_effects.add(InOutFeedEffect.Infeed)
+infeed_p.def_abstract_eval(_infeed_abstract_eval)
 
 
 def _infeed_lowering(ctx, token, *, shapes, partitions):
@@ -4190,12 +4185,11 @@ def outfeed(token, xs, partitions = None):
 def _outfeed_abstract_eval(token, *xs, partitions):
   if token is not abstract_token:
     raise TypeError("First argument to outfeed must be a token")
-  return abstract_token, {InOutFeedEffect.Outfeed}
+  return abstract_token
 
 outfeed_p = Primitive("outfeed")
 outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
-outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
-mlir.lowerable_effects.add(InOutFeedEffect.Outfeed)
+outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
 
 
 def _outfeed_lowering(ctx, token, *xs, partitions):
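These lax.py hunks flip infeed and outfeed from effectful abstract-eval rules back to plain ones: an effectful rule returns a pair of output avals and an effect set, while a plain rule returns only the avals. A toy contrast of the two contracts (stand-in objects, not JAX's registration API):

```python
# Toy contrast of the two rule shapes toggled above.
INFEED_EFFECT = object()  # stand-in for the deleted InOutFeedEffect.Infeed

def plain_abstract_eval(token, shapes):
    return shapes + (token,)                   # output avals only

def effectful_abstract_eval(token, shapes):
    return (*shapes, token), {INFEED_EFFECT}   # (output avals, effects)

tok = object()
assert plain_abstract_eval(tok, ("f32[3]",)) == ("f32[3]", tok)
avals, effects = effectful_abstract_eval(tok, ("f32[3]",))
assert avals == ("f32[3]", tok) and INFEED_EFFECT in effects
```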
4 changes: 2 additions & 2 deletions jax/_src/prng.py
@@ -31,7 +31,7 @@
 from jax.interpreters import pxla
 from jax.interpreters import xla
 from jax.experimental.sharding import (
-    MeshPspecSharding, PmapSharding, OpShardingSharding)
+    MeshPspecSharding, SingleDeviceSharding, PmapSharding, OpShardingSharding)
 
 from jax._src import dispatch
 from jax._src import dtypes
@@ -364,7 +364,7 @@ def global_sharded_result_handler(aval, out_sharding, committed,
   phys_handler_maker = pxla.global_result_handlers[
       (core.ShapedArray, output_type)]
 
-  if dispatch.is_single_device_sharding(out_sharding):
+  if isinstance(out_sharding, SingleDeviceSharding):
     phys_sharding = out_sharding
   elif isinstance(out_sharding, MeshPspecSharding):
     trailing_spec = [None] * len(key_shape)
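Several files here swap `dispatch.is_single_device_sharding(s)` for `isinstance(s, SingleDeviceSharding)`. The two predicates are not equivalent: the deleted helper (quoted in the dispatch.py hunk above) accepted any sharding over one device except `PmapSharding`, while the isinstance check accepts only the exact class. A sketch of the difference with stub classes, not the real `jax.experimental.sharding` types:

```python
# Stub classes showing where the two predicates disagree.
class Sharding:
    def __init__(self, device_set): self.device_set = device_set

class SingleDeviceSharding(Sharding): pass
class PmapSharding(Sharding): pass

def is_single_device_sharding(s: Sharding) -> bool:
    # the deleted helper's rule: one device, but never PmapSharding,
    # since PmapSharding maps away an axis
    return len(s.device_set) == 1 and not isinstance(s, PmapSharding)

one_dev = Sharding({"dev0"})  # some other one-device sharding
print(is_single_device_sharding(one_dev))                 # True
print(isinstance(one_dev, SingleDeviceSharding))          # False: stricter
print(is_single_device_sharding(PmapSharding({"dev0"})))  # False either way
```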
2 changes: 0 additions & 2 deletions jax/core.py
@@ -208,7 +208,6 @@ def __repr__(self):
   def replace(self, *args, **kwargs):
     return self._replace(*args, **kwargs)
 
-# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
 def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
   source_info = source_info or source_info_util.new_source_info()
   if config.jax_enable_checks:
@@ -2163,7 +2162,6 @@ def _unmap_dshaped_array(
     DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
     ShapedArray: (_map_shaped_array, _unmap_shaped_array),
     ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
-    AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
 }
 
 @contextmanager
6 changes: 3 additions & 3 deletions jax/experimental/array.py
@@ -451,7 +451,7 @@ def _device_put_array(x, device: Optional[Device]):
   # TODO(yashkatariya): Remove this restriction and the round trip via host
   # once lowering to XLA goes through `lower_mesh_computation`.
   assert x.is_fully_addressable()
-  if dispatch.is_single_device_sharding(x.sharding):
+  if isinstance(x.sharding, SingleDeviceSharding):
     x = dispatch._copy_device_array_to_device(pxla._set_aval(x._arrays[0]), device)
     return (x,)
   else:
@@ -462,7 +462,7 @@ def _device_put_array(x, device: Optional[Device]):
 
 
 def _array_pmap_shard_arg(x, devices, indices, mode):
-  if dispatch.is_single_device_sharding(x.sharding):
+  if isinstance(x.sharding, SingleDeviceSharding):
     return pxla._shard_device_array(x, devices, indices, mode)
 
   if x._fast_path_args is None:
@@ -484,7 +484,7 @@ def _array_shard_arg(x, devices, indices, mode):
   if mode == pxla.InputsHandlerMode.pmap:
     return _array_pmap_shard_arg(x, devices, indices, mode)
   else:
-    if dispatch.is_single_device_sharding(x.sharding):
+    if isinstance(x.sharding, SingleDeviceSharding):
       return [buf if buf.device() == d else buf.copy_to_device(d)
               for buf, d in safe_zip(x._arrays, devices)]
     # If PmapSharding exists, then do a round trip via host. This will happen
22 changes: 0 additions & 22 deletions jax/interpreters/partial_eval.py
@@ -961,17 +961,6 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
   config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
   return lifted_jaxpr
 
-@weakref_lru_cache
-def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
-  """Move n invars to constvars. Like an inverse of convert_constvars_jaxpr."""
-  config.jax_enable_checks and core.check_jaxpr(jaxpr)
-  constvars, invars = split_list(jaxpr.invars, [n])
-  lifted_jaxpr = Jaxpr(constvars=tuple(constvars), invars=invars,
-                       outvars=jaxpr.outvars, eqns=jaxpr.eqns,
-                       effects=jaxpr.effects)
-  config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
-  return lifted_jaxpr
-
 def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
   config.jax_enable_checks and core.check_jaxpr(jaxpr)
   env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
@@ -1317,17 +1306,6 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
     instantiate = (instantiate,) * len(jaxpr.invars)
   return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate))
 
-
-def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
-                     instantiate: Union[bool, Sequence[bool]] = False,
-                     ) -> Tuple[Jaxpr, List[bool], List[bool]]:
-  jaxpr_ = convert_constvars_jaxpr(jaxpr)
-  new_jaxpr_, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs)
-  used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)])
-  new_jaxpr = convert_invars_to_constvars(new_jaxpr_, sum(used_consts))
-  return new_jaxpr, used_consts, used_inputs
-
-
 @weakref_lru_cache
 def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
                instantiate: Tuple[bool, ...]
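The deleted `dce_jaxpr_consts` reduced DCE over a jaxpr with constvars to plain `dce_jaxpr`: fold the constants in as leading invars, run one DCE pass, then split the per-input liveness flags back into constant and input halves. The round-trip pattern, sketched over plain lists instead of jaxprs (the `dce` stand-in is hypothetical):

```python
# Sketch of the wrap/DCE/unwrap pattern of the removed dce_jaxpr_consts.
from typing import List, Set, Tuple

def dce(inputs: List[str], live: Set[str]) -> List[bool]:
    # stand-in for dce_jaxpr: one liveness flag per input
    return [name in live for name in inputs]

def dce_with_consts(constvars: List[str], invars: List[str],
                    live: Set[str]) -> Tuple[List[bool], List[bool]]:
    flags = dce(constvars + invars, live)  # consts treated as leading invars
    n = len(constvars)
    return flags[:n], flags[n:]            # split flags back apart

used_consts, used_inputs = dce_with_consts(['c0', 'c1'], ['x'], {'c1', 'x'})
print(used_consts, used_inputs)  # [False, True] [True]
```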
