Skip to content

Commit

Permalink
Merge pull request #19096 from mattjj:no-more-add-jaxval-primitive
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 592985946
  • Loading branch information
jax authors committed Dec 22, 2023
2 parents 3021e90 + be3ca50 commit 32e1a0c
Show file tree
Hide file tree
Showing 16 changed files with 39 additions and 132 deletions.
9 changes: 0 additions & 9 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@
canonicalize_shape = core.canonicalize_shape
raise_to_shaped = core.raise_to_shaped

def zeros_like_array(x):
dtype, weak_type = dtypes._lattice_result_type(x)
dtype = dtypes.canonicalize_dtype(dtype)
aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
return ad_util.zeros_like_aval(aval)

numpy_scalar_types: set[type] = { # pylint: disable=g-bare-generic
np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
Expand All @@ -62,11 +56,9 @@ def masked_array_error(*args, **kwargs):
"Use arr.filled() to convert the value to a standard numpy array.")

core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
ad_util.jaxval_zeros_likers[np.ma.MaskedArray] = masked_array_error

for t in array_types:
core.pytype_aval_mappings[t] = canonical_concrete_aval
ad_util.jaxval_zeros_likers[t] = zeros_like_array

core.literalable_types.update(array_types)

Expand All @@ -82,6 +74,5 @@ def _make_concrete_python_scalar(t, x):

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)

core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) # type: ignore
48 changes: 16 additions & 32 deletions jax/_src/ad_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
from __future__ import annotations

import types
from typing import Any, Callable, TypeVar
from typing import Any, Callable, Type, TypeVar

from jax._src import core
from jax._src import traceback_util
from jax._src.core import (lattice_join, Primitive, valid_jaxtype,
raise_to_shaped, get_aval)
from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval
from jax._src.tree_util import register_pytree_node
from jax._src.typing import Array, ArrayLike
from jax._src.util import safe_map
Expand All @@ -30,44 +29,23 @@

map = safe_map

jaxval_adders: dict[type, Callable[[ArrayLike, ArrayLike], Array]] = {}

def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:
return add_jaxvals_p.bind(x, y)

add_jaxvals_p: Primitive = Primitive('add_any')
add_any_p = add_jaxvals_p

@add_jaxvals_p.def_impl
def add_impl(xs, ys):
return jaxval_adders[type(xs)](xs, ys)
aval = core.raise_to_shaped(core.get_aval(x))
return aval_adders[type(aval)](x, y)
aval_adders: dict[Type[core.AbstractValue], Callable] = {}

@add_jaxvals_p.def_abstract_eval
def add_abstract(xs, ys):
return lattice_join(xs, ys)
def zeros_like_aval(aval: core.AbstractValue) -> Array:
return aval_zeros_likers[type(aval)](aval)
aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

jaxval_zeros_likers: dict[type, Callable[[Any], Array]] = {}
def zeros_like_jaxval(val):
return zeros_like_aval(core.raise_to_shaped(core.get_aval(val)))

def instantiate(z: Zero | Array) -> Array:
if isinstance(z, Zero):
return zeros_like_aval(z.aval)
return z

def zeros_like_aval(aval: core.AbstractValue) -> Array:
return aval_zeros_likers[type(aval)](aval)

aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

def zeros_like_jaxval(val: ArrayLike) -> Array:
return zeros_like_p.bind(val)

zeros_like_p: Primitive = Primitive('zeros_like')

@zeros_like_p.def_impl
def zeros_like_impl(example):
return jaxval_zeros_likers[type(example)](example)

zeros_like_p.def_abstract_eval(lambda x: x)

class Zero:
__slots__ = ['aval']
Expand Down Expand Up @@ -128,3 +106,9 @@ def replace_internal_symbolic_zeros(
def replace_rule_output_symbolic_zeros(
x: JaxTypeOrTracer | SymbolicZero) -> JaxTypeOrTracer | Zero:
return Zero(x.aval) if type(x) is SymbolicZero else x


# TODO(mattjj): remove these after fixing downstream users relying on them
add_jaxvals_p: Primitive = Primitive('add_any')
add_any_p = add_jaxvals_p
zeros_like_p: Primitive = Primitive('zeros_like')
22 changes: 0 additions & 22 deletions jax/_src/internal_test_util/test_harnesses.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,19 +703,6 @@ def _to_equivalence_class(dtype):
"dtypes_to_new_dtypes", dtype=dtype, new_dtype=new_dtype)


def _make_add_any_harness(name, *, shapes=((2,), (2,)), dtype=np.float32):
define(
ad_util.add_any_p,
f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}",
ad_util.add_jaxvals_p.bind,
list(map(lambda s: RandArg(s, dtype), shapes)),
dtype=dtype,
shapes=shapes)


for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
_make_add_any_harness("dtypes", dtype=dtype)

for dtype in jtu.dtypes.all:
shape: tuple[int, ...] = (20, 20)
define(
Expand Down Expand Up @@ -765,15 +752,6 @@ def _make_comparator_harness(name,
"broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
op=op, op_name=op_name)

for dtype in jtu.dtypes.all:
shape = (3, 4, 5)
define(
"zeros_like",
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
ad_util.zeros_like_p.bind, [RandArg(shape, dtype)],
shape=shape,
dtype=dtype)

for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
if np.issubdtype(dtype, np.unsignedinteger):
arg = np.array([0, 1, 2], dtype=dtype)
Expand Down
9 changes: 3 additions & 6 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from jax._src import core
from jax._src import source_info_util
from jax._src.ad_util import (
add_jaxvals, add_jaxvals_p, replace_internal_symbolic_zeros,
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval,
zeros_like_jaxval, zeros_like_p)
add_jaxvals, replace_internal_symbolic_zeros, zeros_like_jaxval,
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval)
from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
Expand Down Expand Up @@ -584,9 +584,6 @@ def zero_jvp(primitive, primals, tangents, **params):
return r, Zero.from_value(r)


deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))

def instantiate_zeros(tangent):
return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent

Expand Down
29 changes: 2 additions & 27 deletions jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
from jax._src import core
from jax._src import source_info_util
from jax._src import linear_util as lu
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, Zero, SymbolicZero,
replace_rule_output_symbolic_zeros, instantiate)
from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
replace_rule_output_symbolic_zeros)
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
Expand Down Expand Up @@ -1126,27 +1125,3 @@ def bdim_at_front(x, bdim, size):
return broadcast(x, size, 0)
else:
return moveaxis(x, bdim, 0)

# sets up primitive batchers for ad_util and xla primitives

def add_batched(batched_args, batch_dims):
bdx, bdy = batch_dims
x, y = batched_args
if bdx == bdy:
return add_jaxvals(x, y), bdx
elif bdx is not_mapped:
x = broadcast(x, y.shape[bdy], bdy)
return add_jaxvals(x, y), bdy
elif bdy is not_mapped:
y = broadcast(y, x.shape[bdx], bdx)
return add_jaxvals(x, y), bdx
else:
x = moveaxis(x, bdx, bdy)
return add_jaxvals(x, y), bdy
primitive_batchers[add_jaxvals_p] = add_batched

def zeros_like_batched(batched_args, batch_dims):
val, = batched_args
bdim, = batch_dims
return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched
10 changes: 0 additions & 10 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,16 +1993,6 @@ def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> i
zero = ir_constant(np.array(value, dtypes.canonicalize_dtype(aval.dtype)))
return broadcast_in_dim(ctx, zero, aval, broadcast_dimensions=())

def zeros_like_lowering(ctx, x):
aval, = ctx.avals_in
assert isinstance(aval, core.ShapedArray), aval
return [full_like_aval(ctx, 0, aval)]
register_lowering(ad_util.zeros_like_p, zeros_like_lowering)

def add_jaxvals_lowering(ctx, x, y):
return [hlo.add(x, y)]
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)

register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])


Expand Down
3 changes: 1 addition & 2 deletions jax/_src/lax/control_flow/solves.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import split_list, safe_map
import numpy as np
Expand Down Expand Up @@ -152,7 +151,7 @@ def _root_jvp(const_lengths, jaxprs, primals, tangents):
operator.neg, linearize_and_solve(*solution, *rhs))
# append aux, create symbolic zero tangents for the aux values
solution += aux
solution_dot += _map(lax.zeros_like_array, aux)
solution_dot += _map(ad_util.zeros_like_jaxval, aux)

return solution, solution_dot

Expand Down
13 changes: 4 additions & 9 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.interpreters import ad
Expand Down Expand Up @@ -1224,12 +1223,15 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None) ->

def zeros_like_shaped_array(aval: ShapedArray) -> Array:
assert isinstance(aval, ShapedArray)
if aval.dtype == dtypes.float0:
if dtypes.issubdtype(aval.dtype, dtypes.extended):
scalar_zero = aval.dtype._rules.zero(aval)
elif aval.dtype == dtypes.float0:
scalar_zero = np.zeros((), dtype=aval.dtype)
else:
scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
return broadcast(scalar_zero, aval.shape)

ad_util.aval_adders[ShapedArray] = add
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array

def iota(dtype: DTypeLike, size: int) -> Array:
Expand Down Expand Up @@ -1507,16 +1509,9 @@ def _iter(tracer):
ShapedArray._iter = staticmethod(_iter)
core.DShapedArray._iter = staticmethod(_iter)

# Add some ad handlers that use (or could use) lax primitives

def zeros_like_array(x: ArrayLike) -> Array:
return full_like(x, 0)

for t in itertools.chain(
dtypes.python_scalar_dtypes.keys(), array_types, [array.ArrayImpl]):
ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[array.ArrayImpl] = zeros_like_array


### primitives

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,7 +2150,7 @@ def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers,
raise NotImplementedError(
"scatter_mul gradients are only implemented if `unique_indices=True`")
return lax.mul(x, scatter_add(
lax.zeros_like_array(x), i, g, dimension_numbers=dimension_numbers,
ad_util.zeros_like_jaxval(x), i, g, dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode))

Expand Down
10 changes: 8 additions & 2 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from jax import numpy as jnp
from jax import tree_util

from jax._src import ad_util
from jax._src import api_util
from jax._src import api
from jax._src import basearray
Expand Down Expand Up @@ -409,7 +408,6 @@ def transpose(self, *_, **__) -> PRNGKeyArray: assert False
basearray.Array.register(PRNGKeyArrayImpl)

api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')
ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like # type: ignore[has-type]

def prngkeyarrayimpl_flatten(x):
return (x._base_array,), x._impl
Expand Down Expand Up @@ -593,9 +591,17 @@ def device_put_replicated(val, aval, sharding, devices):
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
return random_wrap(physical_result, impl=aval.dtype._impl)

@staticmethod
def tangent_dtype(_):
return dtypes.float0

# TODO(mattjj,frostig): even though the key dtype shouldn't appear in
# tangents, our ad.replace_float0s in custom_jvp/vjp means passing in zeros
# like the primal to user rules
@staticmethod
def zero(aval):
return lax_internal.zeros_like_shaped_array(aval.update(dtype=dtypes.float0))


class KeyTy(dtypes.ExtendedDType):
_impl: PRNGImpl # TODO(mattjj,frostig): protocol really
Expand Down
2 changes: 0 additions & 2 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,14 +1514,12 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
tf_impl[unconsumed_copy_p] = lambda x: x

tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
tf_impl[ad_util.zeros_like_p] = tf.zeros_like


def _add(x: TfVal, y: TfVal) -> TfVal:
return tf.raw_ops.AddV2(x=x, y=y)


tf_impl[ad_util.add_jaxvals_p] = _add
tf_impl[dispatch.device_put_p] = lambda x, device=None, src=None: x
tf_impl[lax_internal.copy_p] = lambda x: x

Expand Down
1 change: 0 additions & 1 deletion jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def linear_prop(prim, primals_in, series_in, **params):
deflinear(lax.conj_p)
deflinear(lax.imag_p)
deflinear(lax.add_p)
deflinear(ad_util.add_jaxvals_p)
deflinear(lax.sub_p)
deflinear(lax.convert_element_type_p)
deflinear(lax.broadcast_in_dim_p)
Expand Down
3 changes: 0 additions & 3 deletions jax/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
Vmappable as Vmappable,
Zero as Zero,
ZeroIfMapped as ZeroIfMapped,
add_batched as add_batched,
axis_primitive_batchers as axis_primitive_batchers,
batch as batch,
batch_custom_jvp_subtrace as batch_custom_jvp_subtrace,
Expand Down Expand Up @@ -73,6 +72,4 @@
vmappables as vmappables,
vtile as vtile,
zero_if_mapped as zero_if_mapped,
zeros_like_batched as zeros_like_batched,
zeros_like_p as zeros_like_p,
)
1 change: 1 addition & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6072,6 +6072,7 @@ def f(x):
self.assertTrue(any(' cos ' in line for line in l.output))


@jtu.with_config(jax_pprint_use_color=False)
class JaxprTest(jtu.JaxTestCase):

def test_scalar_literals(self):
Expand Down
6 changes: 2 additions & 4 deletions tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,7 @@ def testDotGrad(self, lhs_shape, rhs_shape, dtype):
atol=tol, rtol=tol)
# check that precision config is preserved
result, pullback = jax.vjp(dot, lhs, rhs)
gresult = lax.zeros_like_array(result)
s = str(jax.make_jaxpr(pullback)(gresult))
s = str(jax.make_jaxpr(pullback)(result))
assert "Precision.HIGHEST" in s

@jtu.sample_product(
Expand Down Expand Up @@ -422,8 +421,7 @@ def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
modes=["fwd", "rev"], atol=atol)
# check that precision config is preserved
result, pullback = jax.vjp(dot_general, lhs, rhs)
gresult = lax.zeros_like_array(result)
s = str(jax.make_jaxpr(pullback)(gresult))
s = str(jax.make_jaxpr(pullback)(result))
assert "Precision.HIGHEST" in s

def testDotPreferredElementType(self):
Expand Down

0 comments on commit 32e1a0c

Please sign in to comment.