Skip to content

Commit

Permalink
MAINT Use a generator expression in tuple([... for ... in ...])
Browse files Browse the repository at this point in the history
In a few cases I also replaced tuple([*xs, *ys]) with (*xs, ys), because
tuple literals support unpacking as well.
  • Loading branch information
superbobry committed Sep 21, 2023
1 parent daa8ec2 commit df7f6a0
Show file tree
Hide file tree
Showing 12 changed files with 30 additions and 25 deletions.
2 changes: 1 addition & 1 deletion jax/_src/core.py
Expand Up @@ -2864,7 +2864,7 @@ def _check_call(ctx_factory, prim, in_atoms, params):
env: dict[Var, Atom] = {}
def substitute(aval: AbstractValue):
if isinstance(aval, DShapedArray):
aval = aval.update(shape=tuple([env.get(d, d) for d in aval.shape])) # type: ignore
aval = aval.update(shape=tuple(env.get(d, d) for d in aval.shape)) # type: ignore
return aval
for v, x in zip(call_jaxpr.invars, in_atoms):
if not typecompat(substitute(v.aval), x.aval):
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/interpreters/batching.py
Expand Up @@ -132,7 +132,7 @@ def move_axis(ax):
if self.stacked_axis < ax and ax <= dst:
return ax - 1
return ax
new_axes = tuple([(move_axis(ax), sizes) for ax, sizes in self.ragged_axes])
new_axes = tuple((move_axis(ax), sizes) for ax, sizes in self.ragged_axes)
return RaggedAxis(dst, new_axes)

def transpose_ragged_axes(dim: RaggedAxis, perm: tuple[int, ...]) -> RaggedAxis:
Expand Down Expand Up @@ -726,8 +726,8 @@ def resolve_ragged_axes(vals, dims):
idxs = {lengths_idx.val for d in dims if isinstance(d, RaggedAxis)
for (_, lengths_idx) in d.ragged_axes}
dims = [RaggedAxis(d.stacked_axis,
tuple([(ragged_axis, vals[lengths_idx.val])
for ragged_axis, lengths_idx in d.ragged_axes]))
tuple((ragged_axis, vals[lengths_idx.val])
for ragged_axis, lengths_idx in d.ragged_axes))
if isinstance(d, RaggedAxis) else d for d in dims]
vals = [x for i, x in enumerate(vals) if i not in idxs]
return vals, dims
Expand All @@ -741,8 +741,8 @@ def fetch(idx):
return out_vals[idx.val]

dims = [RaggedAxis(d.stacked_axis,
tuple([(ragged_axis, fetch(lengths_idx))
for ragged_axis, lengths_idx in d.ragged_axes]))
tuple((ragged_axis, fetch(lengths_idx))
for ragged_axis, lengths_idx in d.ragged_axes))
if isinstance(d, RaggedAxis) else d for d in dims]
return dims

Expand Down
6 changes: 3 additions & 3 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -1449,7 +1449,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
result_names=result_names).name.value
tokens = [tokens_in.get(eff) for eff in effects]
args = tuple([*dim_var_values, *tokens, *args])
args = (*dim_var_values, *tokens, *args)
call = func_dialect.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(symbol_name),
flatten_lowering_ir_args(args))
Expand Down Expand Up @@ -2086,10 +2086,10 @@ def _wrapped_callback(*args):
# callback only the non-empty results, and we will create empty constants
# in the receiving computation.
# TODO(b/238239458): fix TPU Recv to work with empty arrays.
non_empty_out_vals = tuple([
non_empty_out_vals = tuple(
out_val
for out_val, result_aval in zip(out_vals, result_avals)
if not is_empty_shape(result_aval.shape)])
if not is_empty_shape(result_aval.shape))
return non_empty_out_vals
else:
return out_vals
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/partial_eval.py
Expand Up @@ -648,7 +648,7 @@ def trace_to_subjaxpr_nounits_dyn(
id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore
if c is not None}
fwds: list[int | None] = [id_map.get(id(c)) for c in res]
res = tuple([c for c, fwd in zip(res, fwds) if fwd is None])
res = tuple(c for c, fwd in zip(res, fwds) if fwd is None)

del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \
in_tracers, in_args, ans, out_tracers, out_avals
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
Expand Up @@ -2858,7 +2858,7 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
"shape; got operand of shape {}, target broadcast shape {}, "
"broadcast_dimensions {} ")
raise TypeError(msg.format(
tuple([core.replace_tracer_for_error_message(d) for d in operand.shape]),
tuple(core.replace_tracer_for_error_message(d) for d in operand.shape),
shape, broadcast_dimensions))
if (len(broadcast_dimensions) != len(set(broadcast_dimensions)) or
tuple(broadcast_dimensions) != tuple(sorted(broadcast_dimensions))):
Expand Down
5 changes: 2 additions & 3 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -395,9 +395,8 @@ def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,
)
splatted_idx, splatted_idx_avals = unzip2(splatted_idx_idx_avals)
if non_slice_idx:
(int_indexer_shape,) = set([idx_aval.shape for idx_aval
in splatted_idx_avals
if not isinstance(idx_aval, primitives.Slice)])
(int_indexer_shape,) = {idx_aval.shape for idx_aval in splatted_idx_avals
if not isinstance(idx_aval, primitives.Slice)}
else:
int_indexer_shape = ()
nd_indexer = NDIndexer(splatted_idx, ref_aval.shape, int_indexer_shape)
Expand Down
14 changes: 10 additions & 4 deletions jax/_src/scipy/fft.py
Expand Up @@ -165,13 +165,19 @@ def idctn(x: Array, type: int = 2,

def _dct_deinterleave(x: Array, axis: int) -> Array:
empty_slice = slice(None, None, None)
ix0 = tuple([slice(None, math.ceil(x.shape[axis]/2), 1) if i == axis else empty_slice for i in range(len(x.shape))])
ix1 = tuple([slice(math.ceil(x.shape[axis]/2), None, 1) if i == axis else empty_slice for i in range(len(x.shape))])
ix0 = tuple(
slice(None, math.ceil(x.shape[axis]/2), 1) if i == axis else empty_slice
for i in range(len(x.shape)))
ix1 = tuple(
slice(math.ceil(x.shape[axis]/2), None, 1) if i == axis else empty_slice
for i in range(len(x.shape)))
v0 = x[ix0]
v1 = lax.rev(x[ix1], (axis,))
out = jnp.zeros(x.shape, dtype=x.dtype)
evens = tuple([slice(None, None, 2) if i == axis else empty_slice for i in range(len(x.shape))])
odds = tuple([slice(1, None, 2) if i == axis else empty_slice for i in range(len(x.shape))])
evens = tuple(
slice(None, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
odds = tuple(
slice(1, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
out = out.at[evens].set(v0)
out = out.at[odds].set(v1)
return out
2 changes: 1 addition & 1 deletion jax/_src/state/discharge.py
Expand Up @@ -174,7 +174,7 @@ def _prepend_scatter(x, idx, indexed_dims, val, *, add=False):

def _indexer(idx, indexed_dims):
idx_ = iter(idx)
indexer = tuple([next(idx_) if b else slice(None) for b in indexed_dims])
indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
assert next(idx_, None) is None
return indexer

Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/impl_no_xla.py
Expand Up @@ -1154,7 +1154,7 @@ def shift_axes_forward(operand,
inverse: bool = False,
forward: bool = True):
"""Shifts the tuple of axes to the front of an array"""
other_axes = tuple([i for i in range(len(operand.shape)) if i not in axes])
other_axes = tuple(i for i in range(len(operand.shape)) if i not in axes)
fwd_order = axes + other_axes if forward else other_axes + axes
order = fwd_order if not inverse else _invert_permutation(fwd_order)
return tf.transpose(operand, order)
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/sparse/bcoo.py
Expand Up @@ -673,7 +673,7 @@ def _bcoo_rdot_general(lhs: Array, rhs_data: Array, rhs_indices: Array, *,
preferred_element_type=preferred_element_type)
n_contract, n_batch = (len(d[0]) for d in dimension_numbers)
n_swap = len(rhs_spinfo.shape) - n_contract
permutation = tuple([*range(n_batch), *range(n_swap, result.ndim), *range(n_batch, n_swap)])
permutation = (*range(n_batch), *range(n_swap, result.ndim), *range(n_batch, n_swap))
return lax.transpose(result, permutation)

def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers,
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/sparse/transform.py
Expand Up @@ -825,10 +825,10 @@ def _scan_sparse(spenv, *spvalues, jaxpr, num_consts, num_carry, **params):
# params['linear'] has one entry per arg; expand it to match the sparsified args.
const_linear, carry_linear, xs_linear = split_list(
params.pop('linear'), [num_consts, num_carry])
sp_linear = tuple([
sp_linear = (
*_duplicate_for_sparse_spvalues(const_spvalues, const_linear),
*_duplicate_for_sparse_spvalues(carry_spvalues, carry_linear),
*_duplicate_for_sparse_spvalues(xs_spvalues, xs_linear)])
*_duplicate_for_sparse_spvalues(xs_spvalues, xs_linear))

out = lax.scan_p.bind(*consts, *carry, *xs, jaxpr=sp_jaxpr, linear=sp_linear,
num_consts=len(consts), num_carry=len(carry), **params)
Expand Down
4 changes: 2 additions & 2 deletions tests/state_test.py
Expand Up @@ -418,7 +418,7 @@ def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims,
int_ = (jnp.dtype('int64') if jax.config.jax_enable_x64 else
jnp.dtype('int32'))
axis_size = 7
out_shape = tuple([d for d, b in zip(ref_shape, indexed_dims) if not b])
out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b)
if any(indexed_dims):
out_shape = (*idx_shape, *out_shape)

Expand All @@ -439,7 +439,7 @@ def maybe_insert(shape, idx):

def f(x_ref, *idxs):
idxs_ = iter(idxs)
indexer = tuple([next(idxs_) if b else slice(None) for b in indexed_dims])
indexer = tuple(next(idxs_) if b else slice(None) for b in indexed_dims)
return op(x_ref, indexer)

rng = self.rng()
Expand Down

0 comments on commit df7f6a0

Please sign in to comment.