From f2ce5dbd0146924c21fae1e8ba42af3601448005 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Oct 2023 15:11:15 +0100 Subject: [PATCH] MAINT Do not use `str()` and `repr()` in f-string replacement fields `str()` is called by default by the formatting machinery, and `repr()` only needs `!r`. --- jax/__init__.py | 2 +- jax/_src/api_util.py | 2 +- jax/_src/array.py | 4 ++-- jax/_src/checkify.py | 2 +- jax/_src/config.py | 2 +- jax/_src/core.py | 8 ++++---- jax/_src/interpreters/batching.py | 4 ++-- jax/_src/interpreters/mlir.py | 2 +- jax/_src/lax/lax.py | 4 ++-- jax/_src/numpy/array_methods.py | 4 ++-- jax/_src/numpy/lax_numpy.py | 2 +- jax/_src/pjit.py | 4 ++-- jax/_src/sharding_impls.py | 4 ++-- jax/_src/source_info_util.py | 2 +- jax/_src/test_util.py | 2 +- jax/_src/tree_util.py | 8 ++++---- jax/experimental/export/shape_poly.py | 2 +- jax/experimental/jax2tf/impl_no_xla.py | 2 +- .../jax2tf/tests/back_compat_test_util.py | 12 ++++++------ 19 files changed, 36 insertions(+), 36 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index fc34f39ebf72..18d5b304fc0c 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -28,7 +28,7 @@ except Exception as exc: # Defensively swallow any exceptions to avoid making jax unimportable from warnings import warn as _warn - _warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report " + _warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report " f"an issue at https://github.com/google/jax/issues") del _warn del _cloud_tpu_init diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 58b661ad81fb..9c098c99b80b 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -123,7 +123,7 @@ def flatten_fun_nokwargs2(in_tree, *args_flat): pair = yield py_args, {} if not isinstance(pair, (list, tuple)) or len(pair) != 2: raise TypeError("expected function with aux output to return a two-element " - f"tuple, but got type {type(pair)} with value {repr(pair)}") + f"tuple, but got type {type(pair)} with value {pair!r}") ans, aux = pair ans_flat, ans_tree = tree_flatten(ans) aux_flat, aux_tree = tree_flatten(aux) diff --git a/jax/_src/array.py b/jax/_src/array.py index e39142cc95e5..aa8ff62cb919 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -71,10 +71,10 @@ def __init__(self, device: Device, sharding: Sharding, global_shape: Shape, def __repr__(self): try: - return (f'Shard(device={repr(self.device)}, index={self.index}, ' + return (f'Shard(device={self.device!r}, index={self.index}, ' f'replica_id={self.replica_id}, data={self.data})') except ValueError: - return f'Shard(device={repr(self.device)}, data={self.data})' + return f'Shard(device={self.device!r}, data={self.data})' @functools.cached_property def index(self) -> Index: diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 11a293ff70e2..e151ea3f2c31 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -1176,7 +1176,7 @@ def _check(pred, msg, debug, *fmt_args, **fmt_kwargs): if not isinstance(arg, (Array, np.ndarray)): raise TypeError('Formatting arguments to checkify.check need to be ' 'PyTrees of arrays, but got ' - f'{repr(arg)} of type {type(arg)}.') + f'{arg!r} of type {type(arg)}.') new_error = FailedCheckError(get_traceback(), msg, *fmt_args, **fmt_kwargs) error = assert_func(init_error, jnp.logical_not(pred), new_error) _check_error(error, debug=debug) diff --git a/jax/_src/config.py b/jax/_src/config.py index 9e7bc5902240..f930bb99d4c4 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1015,7 +1015,7 @@ def _validate_default_device(val): repr(val), type(val)) return raise ValueError('jax.default_device must be passed a Device object (e.g. ' - f"`jax.devices('cpu')[0]`), got: {repr(val)}") + f"`jax.devices('cpu')[0]`), got: {val!r}") # TODO(skye): default_device only accepts devices for now. Make it work with diff --git a/jax/_src/core.py b/jax/_src/core.py index 8c41dc6f2ebd..6ceb3d48ec9f 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1115,7 +1115,7 @@ def _why_alive_container_info(container, obj_id) -> str: ', '.join(map(repr, keys))) if hasattr(container, '__dict__'): keys = [k for k in vars(container) if id(vars(container)[k]) == obj_id] - if len(keys) == 1: return f'{name}.{str(keys[0])}' + if len(keys) == 1: return f'{name}.{keys[0]}' elif len(keys) > 1: return f'{name} in vars ' + ', '.join(map(repr, keys)) if isinstance(container, (list, tuple)): idxs = [i for i, x in enumerate(container) if id(x) == obj_id] @@ -1123,7 +1123,7 @@ def _why_alive_container_info(container, obj_id) -> str: else: return f'{name} at indices ' + ', '.join(map(str, idxs)) if isinstance(container, dict): keys = [k for k in container if id(container[k]) == obj_id] - if len(keys) == 1: return f'{name}[{repr(keys[0])}]' + if len(keys) == 1: return f'{name}[{keys[0]!r}]' else: return f'{name} at keys ' + ', '.join(map(repr, keys)) if isinstance(container, types.ModuleType): return f' named {container.__name__}' @@ -1411,7 +1411,7 @@ def valid_jaxtype(x) -> bool: def check_valid_jaxtype(x): if not valid_jaxtype(x): raise TypeError( - f"Value {repr(x)} of type {type(x)} is not a valid JAX type") + f"Value {x!r} of type {type(x)} is not a valid JAX type") def concrete_aval(x): @@ -1420,7 +1420,7 @@ def concrete_aval(x): if handler: return handler(x) if hasattr(x, '__jax_array__'): return concrete_aval(x.__jax_array__()) - raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX " + raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX " "type") diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 37d449be8516..aa33bb4dd726 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -61,7 +61,7 @@ class IndexedAxisSize: idx: core.Var lengths: Array | core.Var | Tracer def __repr__(self) -> str: - return f'{str(self.lengths)}.Var{id(self.idx)}' + return f'{self.lengths}.Var{id(self.idx)}' replace = dataclasses.replace # Jumble(aval=a:3 => f32[[3 1 4].a], @@ -1101,7 +1101,7 @@ def matchaxis(axis_name, sz, src, dst, x, sum_match=False): try: _ = core.get_aval(x) except TypeError as e: - raise TypeError(f"Output from batched function {repr(x)} with type " + raise TypeError(f"Output from batched function {x!r} with type " f"{type(x)} is not a valid JAX type") from e if src == dst: return x diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 9d04e9d29ddf..e576c4907e3a 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -346,7 +346,7 @@ def _source_info_to_location( primitive: core.Primitive, params: dict, source_info: source_info_util.SourceInfo, name_stack: source_info_util.NameStack) -> ir.Location: - eqn_str = (f'{str(source_info.name_stack)}/' + eqn_str = (f'{source_info.name_stack}/' f'{core.str_eqn_compact(primitive.name, params)}') if config.include_full_tracebacks_in_locations.value: if source_info.traceback is None: diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9b8a2485b677..e1026d5dfae0 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2442,8 +2442,8 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype): if dim_size * old_dtype.itemsize != new_dtype.itemsize: raise ValueError( f"Attempting to convert array of shape {operand.shape} " - f"from {str(old_dtype)} of size {old_dtype.itemsize} " - f"to {str(new_dtype)} of size {new_dtype.itemsize}, " + f"from {old_dtype} of size {old_dtype.itemsize} " + f"to {new_dtype} of size {new_dtype.itemsize}, " f"but {dim_size} * {old_dtype.itemsize} != {new_dtype.itemsize}") return operand.shape[:-1] diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 5f7dca10a3ab..84b34d5fc32e 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -435,7 +435,7 @@ def __getitem__(self, index): return _IndexUpdateRef(self.array, index) def __repr__(self): - return f"_IndexUpdateHelper({repr(self.array)})" + return f"_IndexUpdateHelper({self.array!r})" class _IndexUpdateRef: @@ -452,7 +452,7 @@ def __init__(self, array, index): self.index = index def __repr__(self): - return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})" + return f"_IndexUpdateRef({self.array!r}, {self.index!r})" def get(self, *, indices_are_sorted=False, unique_indices=False, mode=None, fill_value=None): diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index dadc7d887cce..83b8afcf7717 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4154,7 +4154,7 @@ def take_along_axis( idx_shape = shape(indices) if not dtypes.issubdtype(index_dtype, integer): raise TypeError("take_along_axis indices must be of integer type, got " - f"{str(index_dtype)}") + f"{index_dtype}") if axis is None: if ndim(indices) != 1: msg = "take_along_axis indices must be 1D if axis=None, got shape {}" diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 21cb48e66863..6e654453335f 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -989,7 +989,7 @@ def pjit_check_aval_sharding( except ValueError as e: raise ValueError( f'One of {what_aval}{name_str} is incompatible with its sharding ' - f'annotation {s}: {str(e)}') + f'annotation {s}: {e}') # Use the `OpSharding` proto to find out how many ways each dimension of # the aval is sharded. This approach will work across all # XLACompatibleSharding. @@ -1467,7 +1467,7 @@ def _pjit_batcher_for_sharding( if isinstance(getattr(s, '_original_sharding', None), NamedSharding): mesh = s._original_sharding.mesh # type: ignore if mesh is None or mesh.empty: - s_type = (f', got: {repr(s._original_sharding)}' + s_type = (f', got: {s._original_sharding!r}' if hasattr(s, '_original_sharding') else '') raise ValueError( 'If you are using xmap or spmd_axis_name parameter of jax.vmap,' diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 520154c7ff6f..70743d559c9b 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -397,7 +397,7 @@ def __reduce__(self): def __repr__(self): mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' - return f"SingleDeviceSharding(device={repr(self._device)}{mem})" + return f"SingleDeviceSharding(device={self._device!r}{mem})" def __hash__(self): if not hasattr(self, '_hash'): @@ -873,7 +873,7 @@ def __hash__(self): def __repr__(self): mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' - return f'GSPMDSharding({repr(self._hlo_sharding)}{mem})' + return f'GSPMDSharding({self._hlo_sharding!r}{mem})' def is_compatible_aval(self, aval_shape: Shape): num_ways_dim_sharded, _ = get_num_ways_dim_sharded(self._hlo_sharding) diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index 0226e6b35b30..efec6cd2e561 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -90,7 +90,7 @@ def extend(self, name: Union[tuple[str, ...], str]) -> 'NameStack': def wrap_name(self, name: str) -> str: if not self.stack: return name - return f'{str(self)}/{name}' + return f'{self}/{name}' def transform(self, transform_name: str) -> 'NameStack': return NameStack((*self.stack, Transform(transform_name))) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 015a688a53b9..988a1dc88c23 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1275,7 +1275,7 @@ def parameterized_filterable(*, for kw in kwargs: testcase_name = kw.get("testcase_name") if testcase_name is None: - testcase_name = "_".join(f"{k}={str(kw[k])}" # type: ignore + testcase_name = "_".join(f"{k}={kw[k]}" # type: ignore for k in sorted(kw.keys())) kw["testcase_name"] = sanitize_test_name(testcase_name) # type: ignore diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 4a7a832999b2..04e565f30222 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -361,7 +361,7 @@ def __eq__(self, other): return self.fun == other def __repr__(self): - return f'_HashableCallableShim({repr(self.fun)})' + return f'_HashableCallableShim({self.fun!r})' class Partial(functools.partial): @@ -565,7 +565,7 @@ def pprint(self) -> str: class GetitemKeyPathEntry(_DeprecatedKeyPathEntry): def pprint(self) -> str: - return f'[{repr(self.key)}]' + return f'[{self.key!r}]' def __str__(self): return self.pprint() @@ -579,13 +579,13 @@ def __str__(self): class SequenceKey(): idx: int def __str__(self): - return f'[{repr(self.idx)}]' + return f'[{self.idx!r}]' @dataclass(frozen=True) class DictKey(): key: Hashable def __str__(self): - return f'[{repr(self.key)}]' + return f'[{self.key!r}]' @dataclass(frozen=True) class GetAttrKey(): diff --git a/jax/experimental/export/shape_poly.py b/jax/experimental/export/shape_poly.py index 0edacf0ada23..9a63d1678c14 100644 --- a/jax/experimental/export/shape_poly.py +++ b/jax/experimental/export/shape_poly.py @@ -910,7 +910,7 @@ def __init__(self, *dim_specs): def __new__(cls, *dim_specs): for ds in dim_specs: if not isinstance(ds, (int, str)) and ds != ...: - msg = (f"Invalid polymorphic shape element: {repr(ds)}; must be a string " + msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string " "representing a dimension variable, or an integer, or ...") raise ValueError(msg) return tuple.__new__(PolyShape, dim_specs) diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index d3d001ec3e68..a85a7d06d251 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -1090,7 +1090,7 @@ def _gather(operand, start_indices, *, dimension_numbers, try: return gather_fn(gather_args) except ValueError as e: - errors.append(f"{gather_fn}: {repr(e)}") + errors.append(f"{gather_fn}: {e!r}") error_msg = (f"Unsupported arguments for gather: {gather_args}, errors:\n" + "\n".join(errors)) diff --git a/jax/experimental/jax2tf/tests/back_compat_test_util.py b/jax/experimental/jax2tf/tests/back_compat_test_util.py index 70d059285d37..45a7e540c094 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test_util.py +++ b/jax/experimental/jax2tf/tests/back_compat_test_util.py @@ -214,13 +214,13 @@ def run_one_test(self, func: Callable[..., jax.Array], # Pasted from the test output (see back_compat_test_util.py module docstring) data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( testdata_version={CURRENT_TESTDATA_VERSION}, - platform={repr(self.default_jax_backend())}, - custom_call_targets={repr(current_custom_call_targets)}, - serialized_date={repr(datetime.date.today())}, - inputs={repr(data.inputs)}, - expected_outputs={repr(res_run_current)}, + platform={self.default_jax_backend()!r}, + custom_call_targets={current_custom_call_targets!r}, + serialized_date={datetime.date.today()!r}, + inputs={data.inputs!r}, + expected_outputs={res_run_current!r}, mlir_module_text=r\"\"\"\n{module_str}\"\"\", - mlir_module_serialized={repr(serialized)}, + mlir_module_serialized={serialized!r}, xla_call_module_version={module_version}, ) # End paste