Skip to content

Commit

Permalink
Ran pyupgrade --py39-plus **/*.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565650743
  • Loading branch information
superbobry authored and Copybara-Service committed Sep 15, 2023
1 parent e1bcd2c commit cbcbbaa
Show file tree
Hide file tree
Showing 71 changed files with 415 additions and 343 deletions.
3 changes: 2 additions & 1 deletion docs/ext/coverage_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# ==============================================================================
"""Asserts all public symbols are covered in the docs."""

from typing import Any, Mapping
from collections.abc import Mapping
from typing import Any

import haiku as hk
from haiku._src import test_utils
Expand Down
9 changes: 5 additions & 4 deletions examples/imagenet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# ==============================================================================
"""ImageNet dataset with typical pre-processing."""

from collections.abc import Iterable, Iterator, Mapping, Sequence
import enum
import itertools as it
import types
from typing import Iterable, Iterator, Mapping, Optional, Sequence, Tuple
from typing import Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -204,7 +205,7 @@ def _to_tfds_split(split: Split) -> tfds.Split:
return tfds.Split.VALIDATION


def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
def _shard(split: Split, shard_index: int, num_shards: int) -> tuple[int, int]:
"""Returns [start, end) for the given shard index."""
assert shard_index < num_shards
arange = np.arange(split.num_examples)
Expand Down Expand Up @@ -250,8 +251,8 @@ def _distorted_bounding_box_crop(
jpeg_shape: tf.Tensor,
bbox: tf.Tensor,
min_object_covered: float,
aspect_ratio_range: Tuple[float, float],
area_range: Tuple[float, float],
aspect_ratio_range: tuple[float, float],
area_range: tuple[float, float],
max_attempts: int,
) -> tf.Tensor:
"""Generates cropped_image using one of the bboxes randomly distorted."""
Expand Down
7 changes: 4 additions & 3 deletions examples/imagenet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# ==============================================================================
"""ResNet50 on ImageNet2012."""

from collections.abc import Iterable, Mapping
import contextlib
import functools
import timeit
from typing import Iterable, Mapping, NamedTuple, Tuple
from typing import NamedTuple

from absl import app
from absl import flags
Expand Down Expand Up @@ -131,7 +132,7 @@ def loss_fn(
state: hk.State,
loss_scale: jmp.LossScale,
batch: dataset.Batch,
) -> Tuple[jax.Array, Tuple[jax.Array, hk.State]]:
) -> tuple[jax.Array, tuple[jax.Array, hk.State]]:
"""Computes a regularized loss for the given batch."""
logits, state = forward.apply(params, state, None, batch, is_training=True)
labels = jax.nn.one_hot(batch['labels'], 1000)
Expand All @@ -148,7 +149,7 @@ def loss_fn(
def train_step(
train_state: TrainState,
batch: dataset.Batch,
) -> Tuple[TrainState, Scalars]:
) -> tuple[TrainState, Scalars]:
"""Applies an update to parameters and returns new state."""
params, state, opt_state, loss_scale = train_state
grads, (loss, new_state) = (
Expand Down
4 changes: 2 additions & 2 deletions examples/impala/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""A stateless agent interface."""
import collections
import functools
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional

import dm_env
import haiku as hk
Expand Down Expand Up @@ -81,7 +81,7 @@ def step(
params: hk.Params,
timestep: dm_env.TimeStep,
state: Nest,
) -> Tuple[AgentOutput, Nest]:
) -> tuple[AgentOutput, Nest]:
"""For a given single-step, unbatched timestep, output the chosen action."""
# Pad timestep, state to be [T, B, ...] and [B, ...] respectively.
timestep = jax.tree_util.tree_map(lambda t: t[None, None, ...], timestep)
Expand Down
2 changes: 1 addition & 1 deletion examples/impala/haiku_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __call__(self, x):
padding='SAME',
)
for j in range(num_blocks):
block = ResidualBlock(num_channels, name='residual_{}_{}'.format(i, j))
block = ResidualBlock(num_channels, name=f'residual_{i}_{j}')
torso_out = block(torso_out)

torso_out = jax.nn.relu(torso_out)
Expand Down
5 changes: 2 additions & 3 deletions examples/impala/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import itertools
import queue
import threading
from typing import Dict, Tuple
import warnings

import dm_env
Expand Down Expand Up @@ -92,7 +91,7 @@ def _loss(
self,
theta: hk.Params,
trajectories: util.Transition,
) -> Tuple[jax.Array, Dict[str, jax.Array]]:
) -> tuple[jax.Array, dict[str, jax.Array]]:
"""Compute vtrace-based actor-critic loss."""
initial_state = jax.tree_util.tree_map(
lambda t: t[0], trajectories.agent_state)
Expand Down Expand Up @@ -163,7 +162,7 @@ def enqueue_traj(self, traj: util.Transition):
"""Enqueue trajectory."""
self._host_q.put(traj)

def params_for_actor(self) -> Tuple[int, hk.Params]:
def params_for_actor(self) -> tuple[int, hk.Params]:
return self._params_for_actor

def host_to_device_worker(self):
Expand Down
3 changes: 1 addition & 2 deletions examples/impala/run_catch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Single-process IMPALA wiring."""

import threading
from typing import List

from absl import app
from bsuite.environments import catch
Expand All @@ -37,7 +36,7 @@
FRAMES_PER_ITER = ACTION_REPEAT * BATCH_SIZE * UNROLL_LENGTH


def run_actor(actor: actor_lib.Actor, stop_signal: List[bool]):
def run_actor(actor: actor_lib.Actor, stop_signal: list[bool]):
"""Runs an actor to produce num_trajectories trajectories."""
while not stop_signal[0]:
frame_count, params = actor.pull_params()
Expand Down
8 changes: 4 additions & 4 deletions examples/impala_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import functools
import queue
import threading
from typing import Any, Callable, NamedTuple, Tuple
from typing import Any, Callable, NamedTuple

from absl import app
from absl import logging
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(self, num_actions: int):
def __call__(
self,
timestep: dm_env.TimeStep,
) -> Tuple[jax.Array, jax.Array]:
) -> tuple[jax.Array, jax.Array]:
"""Process a batch of observations."""
torso = hk.Sequential([hk.Flatten(),
hk.Linear(128), jax.nn.relu,
Expand All @@ -78,7 +78,7 @@ def step(
params: hk.Params,
rng: jax.Array,
timestep: dm_env.TimeStep,
) -> Tuple[jax.Array, jax.Array]:
) -> tuple[jax.Array, jax.Array]:
"""Steps on a single observation."""
timestep = jax.tree_util.tree_map(lambda t: jnp.expand_dims(t, 0), timestep)
logits, _ = self._net(params, timestep)
Expand Down Expand Up @@ -199,7 +199,7 @@ def update(
params: hk.Params,
opt_state: optax.OptState,
trajs: Transition,
) -> Tuple[hk.Params, optax.OptState]:
) -> tuple[hk.Params, optax.OptState]:
g = jax.grad(self._agent.loss)(params, trajs)
updates, new_opt_state = self._opt_update(g, opt_state)
return optax.apply_updates(params, updates), new_opt_state
Expand Down
3 changes: 2 additions & 1 deletion examples/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# ==============================================================================
"""A minimal MNIST classifier example."""

from typing import Iterator, NamedTuple
from collections.abc import Iterator
from typing import NamedTuple

from absl import app
import haiku as hk
Expand Down
9 changes: 5 additions & 4 deletions examples/mnist_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# ==============================================================================
"""MNIST classifier with pruning as in https://arxiv.org/abs/1710.01878 ."""

from collections.abc import Iterator, Mapping, Sequence
import functools
from typing import Callable, Iterator, Mapping, Sequence, Tuple
from typing import Callable

from absl import app
import haiku as hk
Expand All @@ -28,7 +29,7 @@
Batch = Mapping[str, np.ndarray]
Predicate = Callable[[str, str, jax.Array], bool]
PredicateMap = Mapping[Predicate, jax.Array]
ModuleSparsity = Sequence[Tuple[Predicate, jax.Array]]
ModuleSparsity = Sequence[tuple[Predicate, jax.Array]]


def topk_mask(value: jax.Array, density_fraction: float) -> jax.Array:
Expand Down Expand Up @@ -76,7 +77,7 @@ def zhugupta_func(progress: float) -> float:

def _create_partitions(
module_sparsity: ModuleSparsity, params: hk.Params
) -> Tuple[Sequence[hk.Params], Sequence[jax.Array], hk.Params]:
) -> tuple[Sequence[hk.Params], Sequence[jax.Array], hk.Params]:
"""Partition params based on sparsity_predicate_map.
Args:
Expand Down Expand Up @@ -238,7 +239,7 @@ def get_updates(
params: hk.Params,
opt_state: optax.OptState,
batch: Batch,
) -> Tuple[hk.Params, optax.OptState]:
) -> tuple[hk.Params, optax.OptState]:
"""Learning rule (stochastic gradient descent)."""
grads = jax.grad(loss)(params, batch)
updates, opt_state = opt.update(grads, opt_state)
Expand Down
2 changes: 1 addition & 1 deletion examples/rnn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Tiny Shakespeare as a language modelling dataset."""

from typing import Iterator, Mapping
from collections.abc import Iterator, Mapping

import numpy as np
import tensorflow.compat.v2 as tf
Expand Down
3 changes: 2 additions & 1 deletion examples/transformer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# ==============================================================================
"""A simple example loader for an ASCII language-modelling dataset."""

from collections.abc import Iterable, Iterator
import itertools
import random
from typing import Iterable, Iterator, NamedTuple, TypeVar
from typing import NamedTuple, TypeVar

import numpy as np

Expand Down
8 changes: 5 additions & 3 deletions examples/transformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
--dataset_path=/tmp/shakespeare.txt --alsologtostderr
"""

from collections.abc import MutableMapping
import time
from typing import Any, MutableMapping, NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Union

from absl import app
from absl import flags
Expand Down Expand Up @@ -126,7 +127,8 @@ def init(rng: jax.Array, data: _Batch) -> TrainingState:

@jax.jit
def update(
state: TrainingState, data: _Batch) -> Tuple[TrainingState, _Metrics]:
state: TrainingState, data: _Batch
) -> tuple[TrainingState, _Metrics]:
"""Does an SGD step, returning a new training state and metrics."""
rng, net_rng = jax.random.split(state.rng_key)
loss_and_grad_fn = jax.value_and_grad(loss_fn.apply)
Expand All @@ -152,7 +154,7 @@ def update(
def main(_):

# Create the dataset.
with open(DATASET_PATH.value, mode='r') as file:
with open(DATASET_PATH.value) as file:
train_dataset = dataset.load_ascii_dataset(
corpus=file.read(),
batch_size=BATCH_SIZE,
Expand Down
5 changes: 3 additions & 2 deletions examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
[0]https://arxiv.org/abs/1312.6114
"""

from collections.abc import Iterator, Sequence
import dataclasses
from typing import Iterator, Tuple, NamedTuple, Sequence
from typing import NamedTuple

from absl import app
from absl import flags
Expand Down Expand Up @@ -65,7 +66,7 @@ class Encoder(hk.Module):
latent_size: int
hidden_size: int = 512

def __call__(self, x: jax.Array) -> Tuple[jax.Array, jax.Array]:
def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Encodes an image as an isotropic Guassian latent code."""
x = hk.Flatten()(x)
x = hk.Linear(self.hidden_size)(x)
Expand Down
20 changes: 10 additions & 10 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@
"""Base Haiku module."""

import collections
from collections.abc import Iterable, Iterator, Mapping, Sequence
import contextlib
import functools
import itertools as it
from typing import (Callable, Iterator, Iterable, NamedTuple, Optional, Set,
Tuple, Union, Any, Sequence, Mapping, FrozenSet, TypeVar)
from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
import warnings

from haiku._src import config
from haiku._src import data_structures
from haiku._src.typing import ( # pylint: disable=g-multiple-import
Initializer,
LiftingModuleType,
Params,
Module,
MutableParams,
State,
MutableState,
Module,
PRNGKey,
Params,
State,
)
import jax
from jax import config as jax_config
Expand Down Expand Up @@ -92,7 +92,7 @@ class Frame(NamedTuple):
freeze_params: bool
module_stack: Stack[ModuleState]
counter_stack: Stack[collections.Counter]
used_names_stack: Stack[Set[str]]
used_names_stack: Stack[set[str]]
jax_trace_stack: Stack[JaxTraceLevel]
frame_id: int

Expand Down Expand Up @@ -999,7 +999,7 @@ def assert_is_prng_key(key: PRNGKey):
f"actual=(shape={key.shape}, dtype={key.dtype}){config_hint}")


PRNGSequenceState = Tuple[PRNGKey, Iterable[PRNGKey]]
PRNGSequenceState = tuple[PRNGKey, Iterable[PRNGKey]]


class PRNGSequence(Iterator[PRNGKey]):
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def __next__(self) -> PRNGKey:

next = __next__

def take(self, num) -> Tuple[PRNGKey, ...]:
def take(self, num) -> tuple[PRNGKey, ...]:
self.reserve(max(num - len(self._subkeys), 0))
return tuple(next(self) for _ in range(num))

Expand Down Expand Up @@ -1407,7 +1407,7 @@ def with_rng(key: PRNGKey):
return current_frame().rng_stack(PRNGSequence(key))


def param_names() -> FrozenSet[Tuple[str, str]]:
def param_names() -> frozenset[tuple[str, str]]:
"""Returns all module and parameter names as a set of pairs."""
out = []
params = current_frame().params
Expand All @@ -1431,7 +1431,7 @@ def assert_no_new_parameters():
raise AssertionError(f"New parameters were created: {list(sorted(diff))}")


def _get_ids(collection_name: str) -> FrozenSet[int]:
def _get_ids(collection_name: str) -> frozenset[int]:
"""Returns the identity for all state in the current context."""
out = []
collection = getattr(current_frame(), collection_name)
Expand Down
5 changes: 3 additions & 2 deletions haiku/_src/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# ==============================================================================
"""Basic Haiku modules and functions."""

from collections.abc import Iterable, Sequence
import functools
from typing import Any, Callable, Iterable, Optional, Sequence, Type
from typing import Any, Callable, Optional

from haiku._src import base
from haiku._src import initializers
Expand Down Expand Up @@ -357,7 +358,7 @@ def __call__(self, *args, **kwargs) -> Any:
raise NotImplementedError


def to_module(f: Callable[..., Any]) -> Type[CallableModule]:
def to_module(f: Callable[..., Any]) -> type[CallableModule]:
"""Converts a function into a callable module class.
Sample usage:
Expand Down
Loading

0 comments on commit cbcbbaa

Please sign in to comment.