Skip to content

Commit

Permalink
device_put_sharded: remove incorrect type annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 21, 2020
1 parent 695e8d8 commit 05cc7e7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/api.py
Expand Up @@ -2031,7 +2031,7 @@ def device_put(x, device: Optional[xc.Device] = None):
return tree_map(lambda y: xla.device_put_p.bind(y, device=device), x)


def device_put_sharded(x: Sequence[Any], devices: Sequence[xc.Device]) -> pxla.ShardedDeviceArray:
def device_put_sharded(x: Sequence[Any], devices: Sequence[xc.Device]):
"""Transfers pre-sharded input to the specified devices, returning ShardedDeviceArrays.
Args:
Expand Down

0 comments on commit 05cc7e7

Please sign in to comment.