Skip to content

Commit

Permalink
Merge pull request #12476 from jakevdp:match-sharding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 476220059
  • Loading branch information
jax authors committed Sep 22, 2022
2 parents dfdf00c + 3d23592 commit 11a6fd9
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,7 +1311,7 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:

### convenience wrappers around traceables

def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
def full_like(x: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
shape: Optional[Shape] = None) -> Array:
"""Create a full array like np.full based on the example array `x`.
Expand All @@ -1325,7 +1325,8 @@ def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None
An ndarray with the same shape as `x` with its entries set equal to
`fill_value`, similar to the output of np.full.
"""
from jax.experimental import sharding, array
from jax.experimental import array
from jax.experimental.sharding import PmapSharding

fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape)
weak_type = dtype is None and dtypes.is_weakly_typed(x)
Expand All @@ -1338,11 +1339,12 @@ def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None
# probably in the form of a primitive like `val = match_sharding_p.bind(x, val)`
# (so it works in staged-out code as well as 'eager' code). Related to
# equi-sharding.
if (config.jax_array and hasattr(x, 'sharding') and
not dispatch.is_single_device_sharding(x.sharding) and
not isinstance(x.sharding, sharding.PmapSharding)):
return array.make_array_from_callback(
fill_shape, x.sharding, lambda idx: val[idx]) # type: ignore[arg-type]
if config.jax_array and shape is None and hasattr(x, 'sharding'):
sharding = x.sharding # type: ignore[union-attr]
if (not dispatch.is_single_device_sharding(sharding) and
not isinstance(sharding, PmapSharding)):
return array.make_array_from_callback(
type_cast(array.Shape, fill_shape), sharding, lambda idx: val[idx])
return val


Expand Down

0 comments on commit 11a6fd9

Please sign in to comment.