cleanup: use itertools.chain.from_iterable
jakevdp committed Sep 18, 2020
1 parent 8bb730e commit eeb0e3d
Showing 4 changed files with 8 additions and 12 deletions.
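
For context on the cleanup: `itertools.chain(*xs)` and `itertools.chain.from_iterable(xs)` yield the same flattened sequence, but the starred form realizes the whole outer iterable up front just to build chain's argument tuple, while `from_iterable` consumes it lazily. A minimal sketch of the equivalence (illustrative only, not part of the diff):

```python
import itertools

nested = [[1, 2], [3, 4], [5, 6]]

# Starred form: `nested` is unpacked into positional arguments,
# so the outer iterable is fully consumed before chaining begins.
flat_star = list(itertools.chain(*nested))

# from_iterable form: the outer iterable is pulled from lazily,
# one inner iterable at a time.
flat_lazy = list(itertools.chain.from_iterable(nested))

assert flat_star == flat_lazy == [1, 2, 3, 4, 5, 6]
```
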
2 changes: 1 addition & 1 deletion jax/experimental/host_callback.py
@@ -913,7 +913,7 @@ def _initialize_outfeed_receiver(
   clients = xla_client._get_local_backends().values()  # type: ignore[protected-class]
   # Drop the interpreter clients
   clients = tuple([c for c in clients if c.platform != "interpreter"])  # type: ignore
-  devices = list(itertools.chain(*[backend.devices() for backend in clients]))
+  devices = list(itertools.chain.from_iterable(backend.devices() for backend in clients))
   _outfeed_receiver.clients = clients  # type: ignore[assignment]
   _outfeed_receiver.devices = devices  # type: ignore[assignment]
   logging.vlog(
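
The host_callback.py change flattens per-backend device lists into one device list. A small sketch of the same shape, using `SimpleNamespace` stand-ins rather than real `xla_client` backend objects:

```python
import itertools
from types import SimpleNamespace

# Stand-in backend clients; the real ones come from
# xla_client._get_local_backends() and expose .platform / .devices().
clients = (
    SimpleNamespace(platform="cpu", devices=lambda: ["cpu:0"]),
    SimpleNamespace(platform="gpu", devices=lambda: ["gpu:0", "gpu:1"]),
)

devices = list(itertools.chain.from_iterable(
    backend.devices() for backend in clients))
assert devices == ["cpu:0", "gpu:0", "gpu:1"]
```
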
11 changes: 4 additions & 7 deletions jax/interpreters/pxla.py
@@ -237,20 +237,17 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
 
 shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
 shard_arg_handlers[core.Unit] = \
-    lambda x, devices, _: list(it.chain(
-        *(xla.device_put(core.unit, d) for d in devices)))
+    lambda x, devices, _: list(it.chain.from_iterable(xla.device_put(core.unit, d) for d in devices))
 def _shard_array(x, devices, indices):
-  return list(it.chain(
-      *(xla.device_put(x[i], d) for (i, d) in zip(indices, devices))))
+  return list(it.chain.from_iterable(xla.device_put(x[i], d) for (i, d) in zip(indices, devices)))
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array
 
 def _shard_device_array(x, devices, indices):
   start_indices, limit_indices, removed_dims = map(tuple, unzip3(
       _as_slice_indices(x, idx) for idx in indices))
   shards = x._multi_slice(start_indices, limit_indices, removed_dims)
-  return list(it.chain(
-      *(xla.device_put(s, d) for s, d in zip(shards, devices))))
+  return list(it.chain.from_iterable(xla.device_put(s, d) for s, d in zip(shards, devices)))
 shard_arg_handlers[xla.DeviceArray] = _shard_device_array
 
 # NOTE(skye): we could refactor to generate _multi_slice parameters directly
@@ -865,7 +862,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
   replicated_aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
   # TODO(skye): figure out how partitioning should work here
   sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, True)
-  device_buffers = list(it.chain(*(xla.device_put(val, d) for d in devices)))
+  device_buffers = list(it.chain.from_iterable(xla.device_put(val, d) for d in devices))
   return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)
 
 def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped):
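
The shard handlers above rely on `xla.device_put` returning a sequence of buffers, which is why each result is flattened. A toy version of the `_shard_array` pattern, with a hypothetical single-buffer `device_put` standing in for the real one:

```python
import itertools as it

def device_put(x, device):
    # Hypothetical stand-in: the real xla.device_put returns a
    # sequence of buffers for x on the given device.
    return [f"buf({x!r} @ {device})"]

def shard_array(x, devices, indices):
    # Same shape as _shard_array above: one device_put per
    # (index, device) pair, flattened into a single buffer list.
    return list(it.chain.from_iterable(
        device_put(x[i], d) for i, d in zip(indices, devices)))

print(shard_array([10, 20, 30], ["cpu:0", "cpu:1"], [0, 2]))
# ['buf(10 @ cpu:0)', 'buf(30 @ cpu:1)']
```
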
5 changes: 2 additions & 3 deletions jax/interpreters/xla.py
@@ -782,16 +782,15 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions):
 
 def _execute_compiled(compiled: XlaExecutable, nouts, handlers, *args):
   device, = compiled.local_devices()
-  input_bufs = list(it.chain(
-      *(device_put(x, device) for x in args if x is not token)))
+  input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
   out_bufs = compiled.execute(input_bufs)
   if FLAGS.jax_debug_nans:
     check_nans(xla_call_p, out_bufs)
   return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(nouts, out_bufs))]
 
 def _execute_replicated(compiled: XlaExecutable, nouts, handlers, *args):
   input_bufs = [
-      list(it.chain(*(device_put(x, device) for x in args if x is not token)))
+      list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
       for device in compiled.local_devices()]
   out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
   if FLAGS.jax_debug_nans:
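
`_execute_replicated` produces one flattened buffer list per local device, i.e. a list of lists. A toy rendering of that nesting, under the same hypothetical single-buffer `device_put` assumption as above:

```python
import itertools as it

def device_put(x, device):
    # Hypothetical stand-in for xla.device_put; pretend each value
    # maps to exactly one buffer on the device.
    return [f"buf({x} @ {device})"]

args = (1, 2)
local_devices = ["cpu:0", "cpu:1"]

# Mirrors _execute_replicated: one flattened buffer list per device.
input_bufs = [
    list(it.chain.from_iterable(device_put(x, d) for x in args))
    for d in local_devices]

print(input_bufs)
# [['buf(1 @ cpu:0)', 'buf(2 @ cpu:0)'],
#  ['buf(1 @ cpu:1)', 'buf(2 @ cpu:1)']]
```
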
2 changes: 1 addition & 1 deletion tests/pmap_test.py
@@ -1084,7 +1084,7 @@ def testReshardInput(self):
     # subsequent pmap
     shard_shape = (3,2)
     shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
-    bufs = list(it.chain(*(xla.device_put(shard, d) for d in xla_bridge.devices()[:4])))
+    bufs = list(it.chain.from_iterable(xla.device_put(shard, d) for d in xla_bridge.devices()[:4]))
     aval = ShapedArray((6,4), shard.dtype)
     sharding_spec = pxla.ShardingSpec(
         shards_per_axis=(2, 2),
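
Beyond readability, `from_iterable` changes when the outer iterable is consumed: `chain(*gen())` exhausts a generator immediately just to build chain's arguments, while `chain.from_iterable(gen())` pulls from it on demand. A quick demonstration (illustrative only):

```python
import itertools

def inner_lists():
    for i in range(3):
        print("producing", i)
        yield [i]

# Starred form: the generator runs to completion here, before a
# single chained element has been requested.
eager = itertools.chain(*inner_lists())               # prints producing 0, 1, 2

# from_iterable form: nothing is produced until the chain is consumed.
lazy = itertools.chain.from_iterable(inner_lists())   # prints nothing yet
next(lazy)                                            # prints producing 0, yields 0
```
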
