Skip to content

Commit

Permalink
[JAX] Remove references to deprecated types DeviceArray and ShardedDe…
Browse files Browse the repository at this point in the history
…viceArray.

Both are slated for deletion, and neither can be produced at runtime if jax.Array is enabled, which will be the only option as of March 15 (go/jax-array).

PiperOrigin-RevId: 517200975
  • Loading branch information
hawkinsp authored and romanngg committed Apr 19, 2023
1 parent 5f2dab2 commit 9eafe24
Showing 1 changed file with 0 additions and 4 deletions.
4 changes: 0 additions & 4 deletions neural_tangents/_src/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from jax import device_put, devices
from jax import jit
from jax import pmap
from jax.interpreters.pxla import ShardedDeviceArray
from jax import random
import jax.numpy as np
from jax.tree_util import tree_all
Expand Down Expand Up @@ -706,9 +705,6 @@ def jit_or_pmap_broadcast(f: Callable, device_count: int = -1) -> Callable:
def broadcast(arg: np.ndarray) -> np.ndarray:
if device_count == 0:
return arg
# If the argument has already been sharded, no need to broadcast it.
if isinstance(arg, ShardedDeviceArray) and arg.shape[0] == device_count:
return arg
return np.broadcast_to(arg, (device_count,) + arg.shape)

@utils.wraps(f)
Expand Down

0 comments on commit 9eafe24

Please sign in to comment.