Cleanups for laziness. No functional changes intended.

Use None as a trivial lazy expression in more places. Simplify some code.
hawkinsp committed Mar 7, 2021
1 parent 12c2d0d commit 2469ad1
Showing 6 changed files with 19 additions and 28 deletions.
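The change in miniature: represent "no pending lazy computation" as None rather than as a dedicated identity expression, so consumers can pass the value through (or test `is None`) instead of calling an is_trivial() predicate. Below is a minimal, self-contained sketch of that pattern; the names (LazyExpr, force) are illustrative stand-ins, not JAX's actual internals.

from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np

@dataclass
class LazyExpr:
    """A deferred transformation of an underlying buffer."""
    fn: Callable[[np.ndarray], np.ndarray]

def force(buffer: np.ndarray, lazy_expr: Optional[LazyExpr]) -> np.ndarray:
    # None is the trivial expression: nothing to compute, hand back the
    # buffer unchanged. No sentinel identity object, no is_trivial() call.
    if lazy_expr is None:
        return buffer
    return lazy_expr.fn(buffer)

x = np.arange(6)
assert force(x, None) is x                         # trivial: no-op
y = force(x, LazyExpr(lambda b: b.reshape(2, 3)))  # deferred reshape
assert y.shape == (2, 3)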
3 changes: 1 addition & 2 deletions jax/_src/dlpack.py
@@ -14,7 +14,6 @@

 from jax import core
 from jax import numpy as jnp
-from jax import lazy
 from jax.interpreters import xla
 from jax.lib import xla_client
 from jax.lib import xla_bridge
@@ -62,4 +61,4 @@ def from_dlpack(dlpack, backend=None):
   xla_shape = buf.xla_shape()
   assert not xla_shape.is_tuple()
   aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
-  return xla.make_device_array(aval, buf.device(), lazy.array(aval.shape), buf)  # pytype: disable=attribute-error
+  return xla.make_device_array(aval, buf.device(), None, buf)  # pytype: disable=attribute-error
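For context, this touches the DLPack import path. A usage sketch against the jax.dlpack API of this era (assuming its companion to_dlpack export alongside the from_dlpack shown above); the array built by from_dlpack now carries None as its lazy expression instead of a freshly built identity expression:

import jax.numpy as jnp
from jax.dlpack import to_dlpack, from_dlpack

x = jnp.arange(4.0)
capsule = to_dlpack(x)    # export the device buffer as a DLPack capsule
y = from_dlpack(capsule)  # wrap the same buffer in a new device array
print(y)                  # [0. 1. 2. 3.]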
18 changes: 6 additions & 12 deletions jax/_src/lax/lax.py
@@ -3311,10 +3311,8 @@ def _broadcast_in_dim_impl(operand, *, shape, broadcast_dimensions):
     shape = _broadcast_in_dim_shape_rule(
       operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
     aval = ShapedArray(shape, _dtype(operand), weak_type=dtypes.is_weakly_typed(operand))
-    if operand._lazy_expr is None:
-      lazy_expr = lazy.broadcast(lazy.array(operand.shape), shape, broadcast_dimensions)
-    else:
-      lazy_expr = lazy.broadcast(operand._lazy_expr, shape, broadcast_dimensions)
+    lazy_expr = operand._lazy_expr or lazy.array(operand.shape)
+    lazy_expr = lazy.broadcast(lazy_expr, shape, broadcast_dimensions)
     return xla._DeviceArray(aval, operand._device, lazy_expr, operand.device_buffer)
   else:
     return xla.apply_primitive(broadcast_in_dim_p, operand, shape=shape,
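The rewritten branch leans on Python's `or` for defaulting. One caveat worth noting: `a or b` substitutes b for any falsy a, not only for None, so the idiom matches the old explicit None check only because a genuine lazy expression object is always truthy. A quick illustration:

class Expr:  # stand-in for a lazy expression object
    pass

e = Expr()
assert (e or "default") is e               # truthy instance survives
assert (None or "default") == "default"    # None gets the default
assert (0 or "default") == "default"       # but so does any other falsy value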
@@ -3627,10 +3625,8 @@ def _reshape_impl(operand, *, new_sizes, dimensions):
     bcast_dims = _is_singleton_reshape(old_sizes, new_sizes)
     if bcast_dims is not None:
       aval = ShapedArray(new_sizes, operand.dtype)
-      if operand._lazy_expr is None:
-        lazy_expr = lazy.broadcast(lazy.array(operand.shape), new_sizes, bcast_dims)
-      else:
-        lazy_expr = lazy.broadcast(operand._lazy_expr, new_sizes, bcast_dims)
+      lazy_expr = operand._lazy_expr or lazy.array(operand.shape)
+      lazy_expr = lazy.broadcast(lazy_expr, new_sizes, bcast_dims)
       return xla._DeviceArray(aval, operand._device, lazy_expr, operand.device_buffer)
   return xla.apply_primitive(reshape_p, operand, new_sizes=new_sizes,
                              dimensions=dimensions)
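Why _reshape_impl has a lazy path at all: a reshape that only inserts length-1 dimensions is expressible as a broadcast, so it can stay lazy. A hypothetical helper in the spirit of _is_singleton_reshape (not the actual implementation), returning broadcast dimensions when the rewrite applies and None otherwise:

from typing import Optional, Sequence, Tuple

def singleton_reshape_dims(old: Sequence[int],
                           new: Sequence[int]) -> Optional[Tuple[int, ...]]:
    # Greedily match old dims inside new; any unmatched new dim must be 1.
    dims, i = [], 0
    for j, n in enumerate(new):
        if i < len(old) and n == old[i]:
            dims.append(j)
            i += 1
        elif n != 1:
            return None
    return tuple(dims) if i == len(old) else None

assert singleton_reshape_dims((2, 3), (2, 1, 3)) == (0, 2)  # lazy broadcast works
assert singleton_reshape_dims((2, 3), (3, 2)) is None       # real data movement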
@@ -3743,10 +3739,8 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions):

 def _transpose_impl(operand, *, permutation):
   if xla.type_is_device_array(operand):
-    if operand._lazy_expr is None:
-      lazy_expr = lazy.transpose(lazy.array(operand.shape), permutation)
-    else:
-      lazy_expr = lazy.transpose(operand._lazy_expr, permutation)
+    lazy_expr = operand._lazy_expr or lazy.array(operand.shape)
+    lazy_expr = lazy.transpose(lazy_expr, permutation)
     aval = ShapedArray(lazy_expr.shape, operand.dtype)
     return xla._DeviceArray(aval, operand._device, lazy_expr, operand.device_buffer)
   else:
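The transpose path is lazy for the same reason: a transpose need not move any data and can be recorded as a permutation over the same buffer, much as NumPy records a view:

import numpy as np

x = np.zeros((2, 3))
y = x.T                   # records a stride permutation, copies nothing
assert y.base is x        # same underlying buffer
assert y.shape == (3, 2)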
2 changes: 1 addition & 1 deletion jax/api.py
@@ -333,7 +333,7 @@ def cache_miss(_, *args, **kwargs):
       for result_handler in result_handlers:
         aval, sticky_device, lazy_expr = result_handler.args
         avals.append(aval)
-        lazy_exprs.append(None if xla.lazy.is_trivial(lazy_expr) else lazy_expr)
+        lazy_exprs.append(lazy_expr)
       assert len(avals) == len(out_flat)
       fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
     else:
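The jit fast path can drop its normalization because the producer, the result handler, now stores None for trivial expressions, so the consumer appends values verbatim. A toy before/after showing normalization moving to the producer (is_trivial here is a stand-in for xla.lazy.is_trivial):

def is_trivial(e):
    return e == "identity"

raw = ["identity", "broadcast"]        # what handlers used to hold
before = [None if is_trivial(e) else e for e in raw]

normalized = [None, "broadcast"]       # what handlers hold now
after = list(normalized)               # consumer: append verbatim

assert before == after == [None, "broadcast"]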
3 changes: 1 addition & 2 deletions jax/interpreters/pxla.py
@@ -44,7 +44,6 @@
 from ..config import flags, config
 from .. import core
 from .. import linear_util as lu
-from .. import lazy
 from ..abstract_arrays import array_types
 from ..core import ConcreteArray, ShapedArray
 from .._src.util import (partial, unzip2, unzip3, prod, safe_map, safe_zip,
@@ -563,7 +562,7 @@ def __getitem__(self, idx):
     if buf_idx is not None:
       buf = self.device_buffers[buf_idx]
       aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
-      return xla.make_device_array(aval, None, lazy.array(aval.shape), buf)
+      return xla.make_device_array(aval, None, None, buf)
     return xla.DeviceArray.__getitem__(self, idx)
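Surrounding context for the pxla change: this __getitem__ fast path wraps a single shard's buffer directly when an index lands entirely on one device, skipping a gather; the diff only swaps the trivial lazy expression for None. A rough, hypothetical sketch of such a fast path:

import numpy as np

device_buffers = [np.full(4, d, dtype=np.float32) for d in range(2)]
index_to_buffer = {0: 0, 1: 1}  # e.g. row i is stored whole on device i

def getitem(idx):
    buf_idx = index_to_buffer.get(idx)
    if buf_idx is not None:
        return device_buffers[buf_idx]  # single shard: no gather needed
    raise NotImplementedError("generic indexing fallback")

assert getitem(1)[0] == 1.0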


15 changes: 7 additions & 8 deletions jax/interpreters/xla.py
@@ -1320,15 +1320,14 @@ def _copy_device_array_to_device(x: Union[DeviceArrayProtocol, _DeviceArray], de
 def _force(x: DeviceArrayProtocol) -> DeviceArrayProtocol:
   if lazy.is_trivial(x._lazy_expr):
     return x
+  # force x on the device where it lives, but preserve stickiness on result
+  if x._device:
+    device = x._device
   else:
-    # force x on the device where it lives, but preserve stickiness on result
-    if x._device:
-      device = x._device
-    else:
-      device = x.device_buffer.device()
-    force_fun = _lazy_force_computation(x.aval, device, x._lazy_expr)
-    result = force_fun(x)
-    return make_device_array(x.aval, x._device, None, result)
+    device = x.device_buffer.device()
+  force_fun = _lazy_force_computation(x.aval, device, x._lazy_expr)
+  result = force_fun(x)
+  return make_device_array(x.aval, x._device, None, result)

 @cache()
 def _lazy_force_computation(aval: core.ShapedArray,
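Note that _force still guards with lazy.is_trivial rather than an `is None` check, so the predicate must treat None as trivial for both conventions to coexist. A hypothetical predicate of that shape (not JAX's actual implementation):

from typing import Optional

class LazyExpr:
    def __init__(self, identity: bool):
        self.identity = identity

def is_trivial(lexpr: Optional[LazyExpr]) -> bool:
    # None is trivial by convention; a real expression is trivial only if
    # it is the identity.
    return lexpr is None or lexpr.identity

assert is_trivial(None)
assert is_trivial(LazyExpr(identity=True))
assert not is_trivial(LazyExpr(identity=False))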
6 changes: 3 additions & 3 deletions tests/custom_object_test.py
@@ -18,7 +18,7 @@

 from jax import test_util as jtu
 import jax.numpy as jnp
-from jax import core, jit, lax, lazy, make_jaxpr
+from jax import core, jit, lax, make_jaxpr
 from jax.interpreters import xla
 from jax.lib import xla_client
 xops = xla_client.ops
@@ -101,8 +101,8 @@ class ConcreteSparseArray(AbstractSparseArray):

 def sparse_array_result_handler(device, aval):
   def build_sparse_array(data_buf, indices_buf):
-    data = xla.make_device_array(aval.data_aval, device, lazy.array(aval.data_aval.shape), data_buf)
-    indices = xla.make_device_array(aval.indices_aval, device, lazy.array(aval.indices_aval.shape), indices_buf)
+    data = xla.make_device_array(aval.data_aval, device, None, data_buf)
+    indices = xla.make_device_array(aval.indices_aval, device, None, indices_buf)
     return SparseArray(aval, data, indices)
   return build_sparse_array

