Merge pull request #9805 from froystig:lax-cleanup
PiperOrigin-RevId: 433347321
jax authors committed Mar 9, 2022
2 parents 972070a + bea7771 commit 537e35b
Showing 9 changed files with 69 additions and 63 deletions.
31 changes: 17 additions & 14 deletions jax/_src/api.py
@@ -41,30 +41,33 @@
 import jax
 from jax import core
 from jax import linear_util as lu
-from jax._src import dtypes
 from jax.core import eval_jaxpr
+from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
+                           tree_structure, tree_transpose, tree_leaves,
+                           tree_multimap, treedef_is_leaf, treedef_children,
+                           Partial, PyTreeDef, all_leaves, treedef_tuple)
 
 from jax._src.api_util import (
     flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
     argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
-from jax._src import traceback_util
-from jax._src.traceback_util import api_boundary
-from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
-                           tree_structure, tree_transpose, tree_leaves,
-                           tree_multimap, treedef_is_leaf, treedef_children,
-                           Partial, PyTreeDef, all_leaves, treedef_tuple)
-from jax._src.tree_util import broadcast_prefix
-from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
-                           extend_name_stack, new_name_stack, wrap_name, cache, wraps,
-                           HashableFunction)
-from jax._src import source_info_util
+from jax._src import device_array
+from jax._src import dispatch
+from jax._src import dtypes
+from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
+from jax._src import source_info_util
+from jax._src import traceback_util
+from jax._src.traceback_util import api_boundary
+from jax._src.tree_util import broadcast_prefix
+from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
+                           extend_name_stack, new_name_stack, wrap_name, cache,
+                           wraps, HashableFunction)
 
 # Unused imports to be exported
 from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
                                      local_devices, process_index,
@@ -1082,7 +1085,7 @@ def value_and_grad_f(*args, **kwargs):
           f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
     _check_scalar(ans)
     tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
-    g = vjp_py(jax.lax._one(ans))
+    g = vjp_py(lax_internal._one(ans))
     g = g[0] if isinstance(argnums, int) else g
     if not has_aux:
       return ans, g
@@ -1368,7 +1371,7 @@ def _possible_downcast(x, example):
     x = x.real
   dtype = None if example is None else _dtype(example)
   weak_type = None if example is None else dtypes.is_weakly_typed(example)
-  return jax._src.lax.lax._convert_element_type(x, dtype, weak_type)
+  return lax_internal._convert_element_type(x, dtype, weak_type)
 
 def _unravel_array_into_pytree(pytree, axis, example, arr):
   """Unravel an array into a PyTree with a given structure.
65 changes: 33 additions & 32 deletions jax/_src/numpy/lax_numpy.py
@@ -2214,7 +2214,7 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
   if out is not None:
     raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
   _check_arraylike(name, a)
-  lax._check_user_dtype_supported(dtype, name)
+  lax_internal._check_user_dtype_supported(dtype, name)
   axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
 
   if initial is None and not has_identity:
@@ -2405,7 +2405,7 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
 def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=False, *, where=None):
   _check_arraylike("mean", a)
-  lax._check_user_dtype_supported(dtype, "mean")
+  lax_internal._check_user_dtype_supported(dtype, "mean")
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
 
@@ -2493,7 +2493,7 @@ def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
 def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, ddof=0, keepdims=False, *, where=None):
   _check_arraylike("var", a)
-  lax._check_user_dtype_supported(dtype, "var")
+  lax_internal._check_user_dtype_supported(dtype, "var")
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
 
@@ -2549,7 +2549,7 @@ def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
 def _std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, ddof=0, keepdims=False, *, where=None):
   _check_arraylike("std", a)
-  lax._check_user_dtype_supported(dtype, "std")
+  lax_internal._check_user_dtype_supported(dtype, "std")
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
   return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
@@ -2664,7 +2664,7 @@ def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
 @partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
 def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=None, initial=None, where=None):
-  lax._check_user_dtype_supported(dtype, "nanprod")
+  lax_internal._check_user_dtype_supported(dtype, "nanprod")
   return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
                         axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                         initial=initial, where=where)
@@ -2676,7 +2676,7 @@ def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
 @partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
 def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
             out=None, keepdims=None, initial=None, where=None):
-  lax._check_user_dtype_supported(dtype, "nanprod")
+  lax_internal._check_user_dtype_supported(dtype, "nanprod")
   return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
                         axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                         initial=initial, where=where)
@@ -2686,7 +2686,7 @@ def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
 def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
             out=None, keepdims=False, where=None):
   _check_arraylike("nanmean", a)
-  lax._check_user_dtype_supported(dtype, "nanmean")
+  lax_internal._check_user_dtype_supported(dtype, "nanmean")
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
   if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
@@ -2705,7 +2705,7 @@ def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
 def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, ddof=0, keepdims=False, where=None):
   _check_arraylike("nanvar", a)
-  lax._check_user_dtype_supported(dtype, "nanvar")
+  lax_internal._check_user_dtype_supported(dtype, "nanvar")
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
 
@@ -2733,7 +2733,7 @@ def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
 def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, ddof=0, keepdims=False, where=None):
   _check_arraylike("nanstd", a)
-  lax._check_user_dtype_supported(dtype, "nanstd")
+  lax_internal._check_user_dtype_supported(dtype, "nanstd")
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
   return sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
@@ -2754,7 +2754,7 @@ def _cumulative_reduction(a,
   if out is not None:
     raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
                               f"is not supported.")
-  lax._check_user_dtype_supported(dtype, np_reduction.__name__)
+  lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)
 
   if axis is None or isscalar(a):
     a = ravel(a)
@@ -3338,7 +3338,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
     raise NotImplementedError("Only implemented for order='K'")
 
   # check if the given dtype is compatible with JAX
-  lax._check_user_dtype_supported(dtype, "array")
+  lax_internal._check_user_dtype_supported(dtype, "array")
 
   # Here we make a judgment call: we only return a weakly-typed array when the
   # input object itself is weakly typed. That ensures asarray(x) is a no-op whenever
@@ -3417,7 +3417,7 @@ def _convert_to_array_if_dtype_fails(x):
 
 @_wraps(np.asarray, lax_description=_ARRAY_DOC)
 def asarray(a, dtype=None, order=None):
-  lax._check_user_dtype_supported(dtype, "asarray")
+  lax_internal._check_user_dtype_supported(dtype, "asarray")
   dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
   return array(a, dtype=dtype, copy=False, order=order)
 
@@ -3430,7 +3430,7 @@ def copy(a, order=None):
 @_wraps(np.zeros_like)
 def zeros_like(a, dtype=None, shape=None):
   _check_arraylike("zeros_like", a)
-  lax._check_user_dtype_supported(dtype, "zeros_like")
+  lax_internal._check_user_dtype_supported(dtype, "zeros_like")
   if np.isscalar(shape):
     shape = (shape,)
   return lax.full_like(a, 0, dtype, shape)
@@ -3439,15 +3439,15 @@ def zeros_like(a, dtype=None, shape=None):
 @_wraps(np.ones_like)
 def ones_like(a, dtype=None, shape=None):
   _check_arraylike("ones_like", a)
-  lax._check_user_dtype_supported(dtype, "ones_like")
+  lax_internal._check_user_dtype_supported(dtype, "ones_like")
   if np.isscalar(shape):
     shape = (shape,)
   return lax.full_like(a, 1, dtype, shape)
 
 
 @_wraps(np.full)
 def full(shape, fill_value, dtype=None):
-  lax._check_user_dtype_supported(dtype, "full")
+  lax_internal._check_user_dtype_supported(dtype, "full")
   _check_arraylike("full", fill_value)
   if ndim(fill_value) == 0:
     shape = (shape,) if ndim(shape) == 0 else shape
@@ -3458,7 +3458,7 @@ def full(shape, fill_value, dtype=None):
 
 @_wraps(np.full_like)
 def full_like(a, fill_value, dtype=None, shape=None):
-  lax._check_user_dtype_supported(dtype, "full_like")
+  lax_internal._check_user_dtype_supported(dtype, "full_like")
   _check_arraylike("full_like", a, fill_value)
   if shape is not None:
     shape = (shape,) if ndim(shape) == 0 else shape
@@ -3474,15 +3474,15 @@ def full_like(a, fill_value, dtype=None, shape=None):
 def zeros(shape, dtype=None):
   if isinstance(shape, types.GeneratorType):
     raise TypeError("expected sequence object with len >= 0 or a single integer")
-  lax._check_user_dtype_supported(dtype, "zeros")
+  lax_internal._check_user_dtype_supported(dtype, "zeros")
   shape = canonicalize_shape((shape,) if ndim(shape) == 0 else shape)
   return lax.full(shape, 0, _jnp_dtype(dtype))
 
 @_wraps(np.ones)
 def ones(shape, dtype=None):
   if isinstance(shape, types.GeneratorType):
     raise TypeError("expected sequence object with len >= 0 or a single integer")
-  lax._check_user_dtype_supported(dtype, "ones")
+  lax_internal._check_user_dtype_supported(dtype, "ones")
   shape = canonicalize_shape((shape,) if ndim(shape) == 0 else shape)
   return lax.full(shape, 1, _jnp_dtype(dtype))
 
@@ -3522,25 +3522,25 @@ def array_equiv(a1, a2):
 
 @_wraps(np.eye)
 def eye(N, M=None, k=0, dtype=None):
-  lax._check_user_dtype_supported(dtype, "eye")
+  lax_internal._check_user_dtype_supported(dtype, "eye")
   N = core.canonicalize_dim(N, "'N' argument of jnp.eye()")
   M = N if M is None else core.canonicalize_dim(M, "'M' argument of jnp.eye()")
   if N < 0 or M < 0:
     raise ValueError(f"negative dimensions are not allowed, got {N} and {M}")
   k = operator.index(k)
-  return lax._eye(_jnp_dtype(dtype), (N, M), k)
+  return lax_internal._eye(_jnp_dtype(dtype), (N, M), k)
 
 
 @_wraps(np.identity)
 def identity(n, dtype=None):
-  lax._check_user_dtype_supported(dtype, "identity")
+  lax_internal._check_user_dtype_supported(dtype, "identity")
   return eye(n, dtype=dtype)
 
 
 @_wraps(np.arange)
 def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
            step: Optional[core.DimSize]=None, dtype=None):
-  lax._check_user_dtype_supported(dtype, "arange")
+  lax_internal._check_user_dtype_supported(dtype, "arange")
   require = partial(core.concrete_or_error, None)
   msg = "It arose in jax.numpy.arange argument `{}`.".format
   if _any(core.is_special_dim_size(d) for d in (start, stop, step)):
@@ -3590,7 +3590,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
 def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
               axis: int = 0):
   """Implementation of linspace differentiable in start and stop args."""
-  lax._check_user_dtype_supported(dtype, "linspace")
+  lax_internal._check_user_dtype_supported(dtype, "linspace")
   if num < 0:
     raise ValueError(f"Number of samples, {num}, must be non-negative.")
   _check_arraylike("linspace", start, stop)
@@ -3653,7 +3653,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
 def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
               axis: int = 0):
   """Implementation of logspace differentiable in start and stop args."""
-  lax._check_user_dtype_supported(dtype, "logspace")
+  lax_internal._check_user_dtype_supported(dtype, "logspace")
   if dtype is None:
     dtype = result_type(start, stop, dtypes.canonicalize_dtype(float_))
   dtype = _jnp_dtype(dtype)
@@ -3676,7 +3676,7 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
 @partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
 def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
   """Implementation of geomspace differentiable in start and stop args."""
-  lax._check_user_dtype_supported(dtype, "geomspace")
+  lax_internal._check_user_dtype_supported(dtype, "geomspace")
   if dtype is None:
     dtype = result_type(start, stop, dtypes.canonicalize_dtype(float_))
   dtype = _jnp_dtype(dtype)
@@ -4140,10 +4140,10 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
 
 @_wraps(np.tri)
 def tri(N, M=None, k=0, dtype=None):
-  lax._check_user_dtype_supported(dtype, "tri")
+  lax_internal._check_user_dtype_supported(dtype, "tri")
   M = M if M is not None else N
   dtype = dtype or float32
-  return lax._tri(dtype, (N, M), k)
+  return lax_internal._tri(dtype, (N, M), k)
 
 
 @_wraps(np.tril)
@@ -4174,7 +4174,7 @@ def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
   _check_arraylike("trace", a)
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
-  lax._check_user_dtype_supported(dtype, "trace")
+  lax_internal._check_user_dtype_supported(dtype, "trace")
 
   axis1 = _canonicalize_axis(axis1, ndim(a))
   axis2 = _canonicalize_axis(axis2, ndim(a))
@@ -4780,7 +4780,7 @@ def sum_repeats(operand, names, counts, keep_names):
     for name, count in counts.items():
       if count > 1:
         axes = [i for i, n in enumerate(names) if n == name]
-        eye = lax._delta(operand.dtype, operand.shape, axes)
+        eye = lax_internal._delta(operand.dtype, operand.shape, axes)
         if name not in keep_names:
           operand = sum(operand * eye, axes)
           names = names.replace(name, '')
@@ -6422,7 +6422,7 @@ def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
 def _astype(arr, dtype):
   if dtype is None:
     dtype = dtypes.canonicalize_dtype(float_)
-  lax._check_user_dtype_supported(dtype, "astype")
+  lax_internal._check_user_dtype_supported(dtype, "astype")
   return lax.convert_element_type(arr, dtype)
 
 
@@ -6445,7 +6445,7 @@ def _clip(number, min=None, max=None, out=None, *, a_min=None, a_max=None):
 
 
 def _view(arr, dtype=None, type=None):
-  lax._check_user_dtype_supported(dtype, "view")
+  lax_internal._check_user_dtype_supported(dtype, "view")
   if type is not None:
     raise NotImplementedError("`type` argument of array.view()")
   if dtype is None:
@@ -6808,7 +6808,8 @@ def apply(self, func, indices_are_sorted=False, unique_indices=False,
     """
     def _scatter_apply(x, indices, _, dims, **kwargs):
       return lax.scatter_apply(x, indices, func, dims, **kwargs)
-    return scatter._scatter_update(self.array, self.index, lax._zero(self.array.dtype),
+    return scatter._scatter_update(self.array, self.index,
+                                   lax_internal._zero(self.array.dtype),
                                    _scatter_apply,
                                    indices_are_sorted=indices_are_sorted,
                                    unique_indices=unique_indices, mode=mode)
2 changes: 1 addition & 1 deletion jax/_src/prng.py
@@ -297,7 +297,7 @@ def _threefry2x32_abstract_eval(*args):
     raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                     .format(args))
   if all(isinstance(arg, core.ShapedArray) for arg in args):
-    shape = lax._broadcasting_shape_rule(*args)
+    shape = lax_internal._broadcasting_shape_rule(*args)
     named_shape = core.join_named_shapes(*(a.named_shape for a in args))
     aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
   else:
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -30,6 +30,7 @@
 from jax import linear_util as lu
 import jax.numpy as jnp
 from jax._src import test_util as jtu
+from jax._src.lax import lax as lax_internal
 from jax._src.lax import control_flow as lax_control_flow
 from jax._src import util
 import numpy as np
@@ -1252,7 +1253,7 @@ def _make_harness(group_name: str, name: str,
                   [RandArg((3, 4, 5), _f32)],
                   poly_axes=[0]),
     _make_harness("delta", "0",
-                  lambda x: lax._delta(_f32, x.shape, axes=(0, 1)),
+                  lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)),
                   [RandArg((3, 4), _f32)],
                   poly_axes=[0]),
     _make_harness("dot_general", "",
12 changes: 8 additions & 4 deletions jax/experimental/jet.py
@@ -617,14 +617,18 @@ def chooser_taylor_rule(primals_in, series_in, **params):
     location_indicators = lax.convert_element_type(
         lax_internal._eq_meet(operand, lax.reshape(primal_out, shape)),
         primal_dtype)
-    counts = lax._reduce_sum(location_indicators, axes)
+    counts = lax_internal._reduce_sum(location_indicators, axes)
     def _reduce_chooser_taylor_rule(g):
-      return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
+      return lax.div(
+          lax_internal._reduce_sum(lax.mul(g, location_indicators), axes),
+          counts)
     series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
     return primal_out, series_out
   return chooser_taylor_rule
-jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax._reduce_max)
-jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax._reduce_min)
+jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(
+    lax_internal._reduce_max)
+jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(
+    lax_internal._reduce_min)
 
 def _abs_taylor_rule(x, series_in, **params):
   x, = x
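Note on the pattern: every hunk above swaps an implicit route to private lax helpers (`lax._check_user_dtype_supported`, `jax.lax._one`, `jax._src.lax.lax._convert_element_type`, ...) for one explicit alias, `from jax._src.lax import lax as lax_internal`. A minimal sketch of the convention follows; the two imports and the two helper calls are taken verbatim from the diff, while `my_ones_like` is a hypothetical wrapper invented here for illustration:

# Sketch of the import-alias convention this PR adopts. Only the imports and
# the helper calls mirror the diff above; my_ones_like is hypothetical.
from jax import lax                            # public API, e.g. lax.full_like
from jax._src.lax import lax as lax_internal   # private helpers, e.g. _check_user_dtype_supported

def my_ones_like(a, dtype=None, shape=None):
  # Validate the user-supplied dtype via the internal helper, as jnp.ones_like
  # does after this change ...
  lax_internal._check_user_dtype_supported(dtype, "my_ones_like")
  # ... and go through the public lax API for the actual array construction.
  return lax.full_like(a, 1, dtype, shape)

Presumably the payoff is readability: call sites now distinguish public API (`lax.full_like`) from private internals (`lax_internal._eye`, `lax_internal._tri`, ...) at a glance, without underscore-prefixed names leaking through the public `jax.lax` namespace.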
