Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
4e9782b by Jake VanderPlas <jakevdp@google.com>:

Make array_copy a primitive

Co-authored-by: Matthew Johnson <mattjj@google.com>
COPYBARA_INTEGRATE_REVIEW=#9264 from jakevdp:copy-primitive 4e9782b
PiperOrigin-RevId: 423328684
  • Loading branch information
2 people authored and jax authors committed Jan 21, 2022
1 parent 9074ed4 commit 31b5308
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 18 deletions.
33 changes: 15 additions & 18 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,6 @@

T = TypeVar("T")

@functools.partial(jax.jit, inline=True)
def _array_copy(arr):
"""Return an on-device copy of a DeviceArray.
This is a private method; users can access this via ``jnp.array(x, copy=True)``.
Why do we need copies in a purely functional langauge? Well, JAX is *almost*
purely functional: the semantics of `donate_argnums` mean that sometimes buffers
are consumed, and you actually need to ensure a copy is generated on device.
"""
# TODO(jakevdp): There is no XLA copy operation, so for the time being we rely
# on an implementation detail: although XLA will optimize away non-operations like
# adding zero, it still results in a copied buffer. Eventually, we should move to
# a more direct method that avoids inserting a spurious add_p/or_p into the jaxpr.
if arr.dtype == bool:
return bitwise_or(arr, _const(arr, False))
return add(arr, _const(arr, 0))

def _try_broadcast_shapes(
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
assert shapes
Expand Down Expand Up @@ -4182,6 +4164,21 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
RandomAlgorithm = xops.RandomAlgorithm
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name # type: ignore[assignment]

def _array_copy(arr):
return copy_p.bind(arr)

# The copy_p primitive exists for expressing making copies of runtime arrays.
# For that reason we don't simplify it out of jaxprs (e.g. for jit invariance).
# It's used in jnp.array(x, copy=True), which is the user-facing API.
copy_p = core.Primitive('copy')
copy_p.def_impl(partial(xla.apply_primitive, copy_p))
copy_p.def_abstract_eval(lambda x: x)
xla.register_translation(copy_p, lambda ctx, avals_in, avals_out, x: [x])
mlir.register_lowering(copy_p, lambda ctx, x: [x])
ad.deflinear(copy_p, lambda t: [copy_p.bind(t)])
batching.defvectorized(copy_p)
masking.defvectorized(copy_p)


def rng_bit_generator(key, shape, dtype=np.uint32,
algorithm=RandomAlgorithm.RNG_DEFAULT):
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal:

tf_impl[ad_util.add_jaxvals_p] = _add
tf_impl[dispatch.device_put_p] = lambda x, device=None: x
tf_impl[lax_internal.copy_p] = lambda x: x

def _neg(x: TfVal) -> TfVal:
if x.dtype.is_unsigned:
Expand Down
21 changes: 21 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3745,6 +3745,27 @@ def testArrayCopy(self, dtype):
self.assertFalse(x_copy.is_deleted())
self.assertFalse(x_copy_jit.is_deleted())

def testArrayCopyAutodiff(self):
f = lambda x: jnp.array(x, copy=True)

x = jnp.ones(10)
xdot = jnp.ones(10)
y, ydot = jax.jvp(f, (x,), (xdot,))
self.assertIsNot(x, y)
self.assertIsNot(xdot, ydot)

ybar = jnp.ones(10)
y, f_vjp = jax.vjp(f, x)
xbar, = f_vjp(ybar)
self.assertIsNot(x, y)
self.assertIsNot(xbar, ybar)

def testArrayCopyVmap(self):
f = lambda x: jnp.array(x, copy=True)
x = jnp.ones(10)
y = jax.vmap(f)(x)
self.assertIsNot(x, y)

def testArrayUnsupportedDtypeError(self):
with self.assertRaisesRegex(TypeError,
"JAX only supports number and bool dtypes.*"):
Expand Down

0 comments on commit 31b5308

Please sign in to comment.