Skip to content

Commit

Permalink
MAINT Do not use str() and repr() in f-string replacement fields
Browse files Browse the repository at this point in the history
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
  • Loading branch information
superbobry committed Oct 23, 2023
1 parent c569f8e commit f2ce5db
Show file tree
Hide file tree
Showing 19 changed files with 36 additions and 36 deletions.
2 changes: 1 addition & 1 deletion jax/__init__.py
Expand Up @@ -28,7 +28,7 @@
except Exception as exc:
# Defensively swallow any exceptions to avoid making jax unimportable
from warnings import warn as _warn
_warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
_warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report "
f"an issue at https://github.com/google/jax/issues")
del _warn
del _cloud_tpu_init
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/api_util.py
Expand Up @@ -123,7 +123,7 @@ def flatten_fun_nokwargs2(in_tree, *args_flat):
pair = yield py_args, {}
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
raise TypeError("expected function with aux output to return a two-element "
f"tuple, but got type {type(pair)} with value {repr(pair)}")
f"tuple, but got type {type(pair)} with value {pair!r}")
ans, aux = pair
ans_flat, ans_tree = tree_flatten(ans)
aux_flat, aux_tree = tree_flatten(aux)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/array.py
Expand Up @@ -71,10 +71,10 @@ def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,

def __repr__(self):
try:
return (f'Shard(device={repr(self.device)}, index={self.index}, '
return (f'Shard(device={self.device!r}, index={self.index}, '
f'replica_id={self.replica_id}, data={self.data})')
except ValueError:
return f'Shard(device={repr(self.device)}, data={self.data})'
return f'Shard(device={self.device!r}, data={self.data})'

@functools.cached_property
def index(self) -> Index:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/checkify.py
Expand Up @@ -1176,7 +1176,7 @@ def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
if not isinstance(arg, (Array, np.ndarray)):
raise TypeError('Formatting arguments to checkify.check need to be '
'PyTrees of arrays, but got '
f'{repr(arg)} of type {type(arg)}.')
f'{arg!r} of type {type(arg)}.')
new_error = FailedCheckError(get_traceback(), msg, *fmt_args, **fmt_kwargs)
error = assert_func(init_error, jnp.logical_not(pred), new_error)
_check_error(error, debug=debug)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/config.py
Expand Up @@ -1015,7 +1015,7 @@ def _validate_default_device(val):
repr(val), type(val))
return
raise ValueError('jax.default_device must be passed a Device object (e.g. '
f"`jax.devices('cpu')[0]`), got: {repr(val)}")
f"`jax.devices('cpu')[0]`), got: {val!r}")


# TODO(skye): default_device only accepts devices for now. Make it work with
Expand Down
8 changes: 4 additions & 4 deletions jax/_src/core.py
Expand Up @@ -1115,15 +1115,15 @@ def _why_alive_container_info(container, obj_id) -> str:
', '.join(map(repr, keys)))
if hasattr(container, '__dict__'):
keys = [k for k in vars(container) if id(vars(container)[k]) == obj_id]
if len(keys) == 1: return f'{name}.{str(keys[0])}'
if len(keys) == 1: return f'{name}.{keys[0]}'
elif len(keys) > 1: return f'{name} in vars ' + ', '.join(map(repr, keys))
if isinstance(container, (list, tuple)):
idxs = [i for i, x in enumerate(container) if id(x) == obj_id]
if len(idxs) == 1: return f'{name}[{idxs[0]}]'
else: return f'{name} at indices ' + ', '.join(map(str, idxs))
if isinstance(container, dict):
keys = [k for k in container if id(container[k]) == obj_id]
if len(keys) == 1: return f'{name}[{repr(keys[0])}]'
if len(keys) == 1: return f'{name}[{keys[0]!r}]'
else: return f'{name} at keys ' + ', '.join(map(repr, keys))
if isinstance(container, types.ModuleType):
return f' named {container.__name__}'
Expand Down Expand Up @@ -1411,7 +1411,7 @@ def valid_jaxtype(x) -> bool:
def check_valid_jaxtype(x):
if not valid_jaxtype(x):
raise TypeError(
f"Value {repr(x)} of type {type(x)} is not a valid JAX type")
f"Value {x!r} of type {type(x)} is not a valid JAX type")


def concrete_aval(x):
Expand All @@ -1420,7 +1420,7 @@ def concrete_aval(x):
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return concrete_aval(x.__jax_array__())
raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
"type")


Expand Down
4 changes: 2 additions & 2 deletions jax/_src/interpreters/batching.py
Expand Up @@ -61,7 +61,7 @@ class IndexedAxisSize:
idx: core.Var
lengths: Array | core.Var | Tracer
def __repr__(self) -> str:
return f'{str(self.lengths)}.Var{id(self.idx)}'
return f'{self.lengths}.Var{id(self.idx)}'
replace = dataclasses.replace

# Jumble(aval=a:3 => f32[[3 1 4].a],
Expand Down Expand Up @@ -1101,7 +1101,7 @@ def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
try:
_ = core.get_aval(x)
except TypeError as e:
raise TypeError(f"Output from batched function {repr(x)} with type "
raise TypeError(f"Output from batched function {x!r} with type "
f"{type(x)} is not a valid JAX type") from e
if src == dst:
return x
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/mlir.py
Expand Up @@ -346,7 +346,7 @@ def _source_info_to_location(
primitive: core.Primitive, params: dict,
source_info: source_info_util.SourceInfo,
name_stack: source_info_util.NameStack) -> ir.Location:
eqn_str = (f'{str(source_info.name_stack)}/'
eqn_str = (f'{source_info.name_stack}/'
f'{core.str_eqn_compact(primitive.name, params)}')
if config.include_full_tracebacks_in_locations.value:
if source_info.traceback is None:
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
Expand Up @@ -2442,8 +2442,8 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
if dim_size * old_dtype.itemsize != new_dtype.itemsize:
raise ValueError(
f"Attempting to convert array of shape {operand.shape} "
f"from {str(old_dtype)} of size {old_dtype.itemsize} "
f"to {str(new_dtype)} of size {new_dtype.itemsize}, "
f"from {old_dtype} of size {old_dtype.itemsize} "
f"to {new_dtype} of size {new_dtype.itemsize}, "
f"but {dim_size} * {old_dtype.itemsize} != {new_dtype.itemsize}")
return operand.shape[:-1]

Expand Down
4 changes: 2 additions & 2 deletions jax/_src/numpy/array_methods.py
Expand Up @@ -435,7 +435,7 @@ def __getitem__(self, index):
return _IndexUpdateRef(self.array, index)

def __repr__(self):
return f"_IndexUpdateHelper({repr(self.array)})"
return f"_IndexUpdateHelper({self.array!r})"


class _IndexUpdateRef:
Expand All @@ -452,7 +452,7 @@ def __init__(self, array, index):
self.index = index

def __repr__(self):
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"
return f"_IndexUpdateRef({self.array!r}, {self.index!r})"

def get(self, *, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/lax_numpy.py
Expand Up @@ -4154,7 +4154,7 @@ def take_along_axis(
idx_shape = shape(indices)
if not dtypes.issubdtype(index_dtype, integer):
raise TypeError("take_along_axis indices must be of integer type, got "
f"{str(index_dtype)}")
f"{index_dtype}")
if axis is None:
if ndim(indices) != 1:
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/pjit.py
Expand Up @@ -989,7 +989,7 @@ def pjit_check_aval_sharding(
except ValueError as e:
raise ValueError(
f'One of {what_aval}{name_str} is incompatible with its sharding '
f'annotation {s}: {str(e)}')
f'annotation {s}: {e}')
# Use the `OpSharding` proto to find out how many ways each dimension of
# the aval is sharded. This approach will work across all
# XLACompatibleSharding.
Expand Down Expand Up @@ -1467,7 +1467,7 @@ def _pjit_batcher_for_sharding(
if isinstance(getattr(s, '_original_sharding', None), NamedSharding):
mesh = s._original_sharding.mesh # type: ignore
if mesh is None or mesh.empty:
s_type = (f', got: {repr(s._original_sharding)}'
s_type = (f', got: {s._original_sharding!r}'
if hasattr(s, '_original_sharding') else '')
raise ValueError(
'If you are using xmap or spmd_axis_name parameter of jax.vmap,'
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/sharding_impls.py
Expand Up @@ -397,7 +397,7 @@ def __reduce__(self):

def __repr__(self):
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
return f"SingleDeviceSharding(device={repr(self._device)}{mem})"
return f"SingleDeviceSharding(device={self._device!r}{mem})"

def __hash__(self):
if not hasattr(self, '_hash'):
Expand Down Expand Up @@ -873,7 +873,7 @@ def __hash__(self):

def __repr__(self):
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
return f'GSPMDSharding({repr(self._hlo_sharding)}{mem})'
return f'GSPMDSharding({self._hlo_sharding!r}{mem})'

def is_compatible_aval(self, aval_shape: Shape):
num_ways_dim_sharded, _ = get_num_ways_dim_sharded(self._hlo_sharding)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/source_info_util.py
Expand Up @@ -90,7 +90,7 @@ def extend(self, name: Union[tuple[str, ...], str]) -> 'NameStack':
def wrap_name(self, name: str) -> str:
if not self.stack:
return name
return f'{str(self)}/{name}'
return f'{self}/{name}'

def transform(self, transform_name: str) -> 'NameStack':
return NameStack((*self.stack, Transform(transform_name)))
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/test_util.py
Expand Up @@ -1275,7 +1275,7 @@ def parameterized_filterable(*,
for kw in kwargs:
testcase_name = kw.get("testcase_name")
if testcase_name is None:
testcase_name = "_".join(f"{k}={str(kw[k])}" # type: ignore
testcase_name = "_".join(f"{k}={kw[k]}" # type: ignore
for k in sorted(kw.keys()))
kw["testcase_name"] = sanitize_test_name(testcase_name) # type: ignore

Expand Down
8 changes: 4 additions & 4 deletions jax/_src/tree_util.py
Expand Up @@ -361,7 +361,7 @@ def __eq__(self, other):
return self.fun == other

def __repr__(self):
return f'_HashableCallableShim({repr(self.fun)})'
return f'_HashableCallableShim({self.fun!r})'


class Partial(functools.partial):
Expand Down Expand Up @@ -565,7 +565,7 @@ def pprint(self) -> str:

class GetitemKeyPathEntry(_DeprecatedKeyPathEntry):
def pprint(self) -> str:
return f'[{repr(self.key)}]'
return f'[{self.key!r}]'
def __str__(self):
return self.pprint()

Expand All @@ -579,13 +579,13 @@ def __str__(self):
class SequenceKey():
idx: int
def __str__(self):
return f'[{repr(self.idx)}]'
return f'[{self.idx!r}]'

@dataclass(frozen=True)
class DictKey():
key: Hashable
def __str__(self):
return f'[{repr(self.key)}]'
return f'[{self.key!r}]'

@dataclass(frozen=True)
class GetAttrKey():
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/export/shape_poly.py
Expand Up @@ -910,7 +910,7 @@ def __init__(self, *dim_specs):
def __new__(cls, *dim_specs):
for ds in dim_specs:
if not isinstance(ds, (int, str)) and ds != ...:
msg = (f"Invalid polymorphic shape element: {repr(ds)}; must be a string "
msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string "
"representing a dimension variable, or an integer, or ...")
raise ValueError(msg)
return tuple.__new__(PolyShape, dim_specs)
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/impl_no_xla.py
Expand Up @@ -1090,7 +1090,7 @@ def _gather(operand, start_indices, *, dimension_numbers,
try:
return gather_fn(gather_args)
except ValueError as e:
errors.append(f"{gather_fn}: {repr(e)}")
errors.append(f"{gather_fn}: {e!r}")

error_msg = (f"Unsupported arguments for gather: {gather_args}, errors:\n" +
"\n".join(errors))
Expand Down
12 changes: 6 additions & 6 deletions jax/experimental/jax2tf/tests/back_compat_test_util.py
Expand Up @@ -214,13 +214,13 @@ def run_one_test(self, func: Callable[..., jax.Array],
# Pasted from the test output (see back_compat_test_util.py module docstring)
data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
testdata_version={CURRENT_TESTDATA_VERSION},
platform={repr(self.default_jax_backend())},
custom_call_targets={repr(current_custom_call_targets)},
serialized_date={repr(datetime.date.today())},
inputs={repr(data.inputs)},
expected_outputs={repr(res_run_current)},
platform={self.default_jax_backend()!r},
custom_call_targets={current_custom_call_targets!r},
serialized_date={datetime.date.today()!r},
inputs={data.inputs!r},
expected_outputs={res_run_current!r},
mlir_module_text=r\"\"\"\n{module_str}\"\"\",
mlir_module_serialized={repr(serialized)},
mlir_module_serialized={serialized!r},
xla_call_module_version={module_version},
) # End paste
Expand Down

0 comments on commit f2ce5db

Please sign in to comment.