Skip to content

Commit

Permalink
Upgrade most .py sources to 3.9
Browse files Browse the repository at this point in the history
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
  • Loading branch information
superbobry committed Dec 8, 2023
1 parent 7af1c14 commit 36f6b52
Show file tree
Hide file tree
Showing 36 changed files with 198 additions and 194 deletions.
4 changes: 2 additions & 2 deletions jax/_src/api.py
Expand Up @@ -29,7 +29,7 @@
import math
import typing
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
cast, Optional)
cast)
import weakref

import numpy as np
Expand Down Expand Up @@ -2461,7 +2461,7 @@ def make_jaxpr_f(*args, **kwargs):
make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
return make_jaxpr_f

def _infer_src_sharding(src, x) -> Optional[Sharding]:
def _infer_src_sharding(src, x) -> Sharding | None:
if src is not None:
return src
if isinstance(x, array.ArrayImpl):
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/basearray.py
Expand Up @@ -16,7 +16,8 @@

import abc
import numpy as np
from typing import Any, Sequence, Union
from typing import Any, Union
from collections.abc import Sequence

# TODO(jakevdp): fix import cycles and define these.
Shard = Any
Expand Down
56 changes: 28 additions & 28 deletions jax/_src/config.py
Expand Up @@ -22,7 +22,7 @@
import os
import sys
import threading
from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar
import warnings

from jax._src import lib
Expand Down Expand Up @@ -134,7 +134,7 @@ def _read(self, name):
raise AttributeError(f"Unrecognized config option: {name}")

def add_option(self, name, default, opt_type, meta_args, meta_kwargs,
update_hook: Optional[Callable[[Any], None]] = None):
update_hook: Callable[[Any], None] | None = None):
if name in self.values:
raise Exception(f"Config option {name} already defined")
self.values[name] = default
Expand Down Expand Up @@ -238,7 +238,7 @@ class _Unset: pass

class _StateContextManager(Generic[_T]):
def __init__(self, name, help, update_thread_local_hook,
validate_new_val_hook: Optional[Callable[[Any], None]] = None,
validate_new_val_hook: Callable[[Any], None] | None = None,
extra_description: str = "", default_value: Any = no_default):
self._name = name
self.__name__ = name[4:] if name.startswith('jax_') else name
Expand Down Expand Up @@ -302,8 +302,8 @@ def define_bool_state(
default: bool,
help: str,
*,
update_global_hook: Optional[Callable[[bool], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None,
update_global_hook: Callable[[bool], None] | None = None,
update_thread_local_hook: Callable[[bool | None], None] | None = None,
upgrade: bool = False,
extra_description: str = '',
) -> _StateContextManager[bool]:
Expand Down Expand Up @@ -375,11 +375,11 @@ def define_bool_state(
def define_enum_state(
name: str,
enum_values: list[str],
default: Optional[str],
default: str | None,
help: str,
*,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
"""Set up thread-local state and return a contextmanager for managing it.
Expand Down Expand Up @@ -420,11 +420,11 @@ def validate(new_val):

def define_int_state(
name: str,
default: Optional[int],
default: int | None,
help: str,
*,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[int]:
"""Set up thread-local state and return a contextmanager for managing it.
Expand Down Expand Up @@ -463,11 +463,11 @@ def validate(new_val):

def define_float_state(
name: str,
default: Optional[float],
default: float | None,
help: str,
*,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[float]:
"""Set up thread-local state and return a contextmanager for managing it.
Expand Down Expand Up @@ -508,11 +508,11 @@ def validate(new_val):

def define_string_state(
name: str,
default: Optional[str],
default: str | None,
help: str,
*,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
"""Set up thread-local state and return a contextmanager for managing it.
Expand Down Expand Up @@ -552,9 +552,9 @@ def define_string_or_object_state(
default: Any,
help: str,
*,
update_global_hook: Optional[Callable[[Any], None]] = None,
update_thread_local_hook: Optional[Callable[[Any], None]] = None,
validate_new_val_hook: Optional[Callable[[Any], None]] = None,
update_global_hook: Callable[[Any], None] | None = None,
update_thread_local_hook: Callable[[Any], None] | None = None,
validate_new_val_hook: Callable[[Any], None] | None = None,
) -> _StateContextManager[Any]:
"""Set up thread-local state and return a contextmanager for managing it.
Expand Down Expand Up @@ -651,9 +651,9 @@ def DEFINE_enum(name, default, *args, **kwargs) -> FlagHolder[str]:
# a global/thread-local state. These methods allow updates to part of the
# state when a configuration value changes.
class _GlobalExtraJitContext(NamedTuple):
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
numpy_rank_promotion: str | None = None
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False
Expand All @@ -675,12 +675,12 @@ class _ThreadLocalExtraJitContext(NamedTuple):
The initialization, which uses both config.py and core.py is done using
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
dynamic_trace_state: Optional[Any] = None
dynamic_trace_state: Any | None = None
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
numpy_rank_promotion: str | None = None
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False
Expand Down Expand Up @@ -1320,7 +1320,7 @@ def transfer_guard(new_val: str) -> Iterator[None]:
yield


def _update_debug_log_modules(module_names_str: Optional[str]):
def _update_debug_log_modules(module_names_str: str | None):
logging_config.disable_all_debug_logging()
if not module_names_str:
return
Expand Down
14 changes: 7 additions & 7 deletions jax/_src/dtypes.py
Expand Up @@ -24,7 +24,7 @@
import abc
import builtins
import functools
from typing import cast, overload, Any, Literal, Optional, Union
from typing import cast, overload, Any, Literal, Union
import warnings

import ml_dtypes
Expand Down Expand Up @@ -207,7 +207,7 @@ def to_complex_dtype(dtype: DTypeLike) -> DType:


@functools.cache
def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> Union[DType, ExtendedDType]:
def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> DType | ExtendedDType:
if issubdtype(dtype, extended):
if not allow_extended_dtype:
raise ValueError(f"Internal: canonicalize_dtype called on extended dtype {dtype} "
Expand All @@ -227,10 +227,10 @@ def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: An
def canonicalize_dtype(dtype: Any, allow_extended_dtype: Literal[False] = False) -> DType: ...

@overload
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> Union[DType, ExtendedDType]: ...
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType: ...

@export
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> Union[DType, ExtendedDType]:
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType:
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype) # pytype: disable=bad-return-type

Expand Down Expand Up @@ -292,7 +292,7 @@ def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType:
return dtype


def coerce_to_array(x: Any, dtype: Optional[DTypeLike] = None) -> np.ndarray:
def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray:
"""Coerces a scalar or NumPy array to an np.array.
Handles Python scalar type promotion according to JAX's rules, not NumPy's
Expand Down Expand Up @@ -643,10 +643,10 @@ def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> tuple[DType
def result_type(*args: Any, return_weak_type_flag: Literal[False] = False) -> DType: ...

@overload
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]: ...
def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]: ...

@export
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]:
def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]:
"""Convenience function to apply JAX argument dtype promotion.
Args:
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/extend/random.py
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Hashable
from typing import Callable
from collections.abc import Hashable

from jax import Array

Expand Down
10 changes: 5 additions & 5 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -25,7 +25,7 @@
import operator
import re
import typing
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
from typing import Any, Callable, NamedTuple, Protocol, Union
import warnings

import numpy as np
Expand Down Expand Up @@ -707,7 +707,7 @@ def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str |
return layout._to_xla_layout()


def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None:
if s is None:
return None
assert isinstance(s, sharding_impls.XLACompatibleSharding)
Expand Down Expand Up @@ -1454,7 +1454,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_override_lowering_rule(eqn.primitive)
platform_rules: dict[str, LoweringRule] = {}
default_rule: Optional[LoweringRule] = None
default_rule: LoweringRule | None = None
# See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule`
if override_rule is not None:
default_rule = override_rule
Expand Down Expand Up @@ -1525,7 +1525,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
def lower_per_platform(ctx: LoweringRuleContext,
description: str,
platform_rules: dict[str, LoweringRule],
default_rule: Optional[LoweringRule],
default_rule: LoweringRule | None,
effects: effects_lib.Effects,
*rule_args: ir.Value,
**rule_kwargs) -> ir.Value:
Expand Down Expand Up @@ -1710,7 +1710,7 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
return func_op


def check_backend_matches(inner_backend: Optional[str],
def check_backend_matches(inner_backend: str | None,
lowering_platforms: Sequence[str]):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/interpreters/pxla.py
Expand Up @@ -25,7 +25,8 @@
import logging
import math
import threading
from typing import (Any, Callable, NamedTuple, Iterator, Optional, Union, cast, TypeVar)
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
from collections.abc import Iterator
import warnings

import numpy as np
Expand Down Expand Up @@ -2914,8 +2915,8 @@ def _compile_replicated_mesh_executable_from_hlo(

@lru_cache
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,
memory_kind: Optional[str] = None) -> sharding_impls.NamedSharding:
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
memory_kind: str | None = None) -> sharding_impls.NamedSharding:
if pspec is None:
pspec, parsed_pspec = PartitionSpec(), None
return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/jaxpr_util.py
Expand Up @@ -202,7 +202,7 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes:
pprof tool for visualization.
"""
d: DefaultDict[tuple[Optional[xla_client.Traceback], core.Primitive], int]
d = collections.defaultdict(lambda: 0)
d = collections.defaultdict(int)
for _, eqn in all_eqns(jaxpr):
d[(eqn.source_info.traceback, eqn.primitive)] += 1
return _pprof_profile(d)
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/loops.py
Expand Up @@ -19,7 +19,7 @@
import inspect
import itertools
import operator
from typing import Any, Callable, Optional, TypeVar
from typing import Any, Callable, TypeVar

import jax
import weakref
Expand Down Expand Up @@ -104,7 +104,7 @@ def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
length: Optional[int] = None,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1) -> tuple[Carry, Y]:
"""Scan a function over leading array axes while carrying along state.
Expand Down

0 comments on commit 36f6b52

Please sign in to comment.